<a href="https://colab.research.google.com/github/dvoils/neural-network-experiments/blob/main/associative_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from scipy.integrate import ode

# Distance metric used to compare states against the stored patterns.
def euclidean_distance(x, y):
    """Return the Euclidean (L2) norm of ``x - y``.

    Intended for two equal-length NumPy vectors, e.g. a state and a stored
    pattern.
    """
    displacement = x - y
    distance = np.linalg.norm(displacement)
    return distance

# Stored patterns: the four corners (+/-1, +/-1) of a square in R^2, one per row.
patterns = np.array(
    [
        (1.0, 1.0),
        (-1.0, -1.0),
        (1.0, -1.0),
        (-1.0, 1.0),
    ]
)

# Width of the Gaussian kernel inside the softmax weights.
sigma = 0.5

# Softmax weights
def weights(x, stored_patterns=None, width=None):
    """Return the softmax weights of state ``x`` over the stored patterns.

    Each pattern ``p_i`` gets the Gaussian-kernel score
    ``exp(-||x - p_i||**2 / (2 * width**2))``; the scores are normalised so
    they sum to 1.

    Args:
        x: Current state, broadcast against each stored pattern (a length-d
            vector when the patterns are rows of a (k, d) array).
        stored_patterns: Patterns to weigh, one per row. Defaults to the
            module-level ``patterns`` (looked up at call time).
        width: Kernel width. Defaults to the module-level ``sigma`` (looked up
            at call time).

    Returns:
        Float array with one non-negative weight per stored pattern, summing
        to 1. Empty if there are no stored patterns.
    """
    if stored_patterns is None:
        stored_patterns = patterns
    if width is None:
        width = sigma

    stored = np.asarray(stored_patterns, dtype=float)
    if len(stored) == 0:
        return np.zeros(0)

    diffs = stored - np.asarray(x, dtype=float)
    sq_dists = np.sum(diffs ** 2, axis=1)
    # Log-sum-exp shift: subtracting the smallest distance does not change the
    # softmax, but it makes the largest term exp(0) = 1. Without it, every term
    # underflows to 0 when x is far from all patterns, and 0/0 gives NaN.
    logits = -0.5 * (sq_dists - sq_dists.min()) / width ** 2
    exps = np.exp(logits)
    return exps / exps.sum()

# ODE dynamics
def memory_dynamics(t, x):
    """Right-hand side of the recall ODE, dx/dt = sum_i w_i(x) * (p_i - x).

    Each stored pattern ``p_i`` pulls the state toward itself with strength
    given by its softmax weight ``w_i = weights(x)[i]``. Because the weights
    are normalised to sum to 1, this equals (weighted mean of the patterns) - x.

    Args:
        t: Current time. Unused (the field does not depend on t), but required
            by the ``f(t, y)`` callback signature of ``scipy.integrate.ode``.
        x: Current state, same length as each stored pattern.

    Returns:
        dx/dt as a float NumPy array with the same shape as ``x``.
    """
    state = np.asarray(x, dtype=float)
    w = weights(state)
    # (k,) @ (k, d) -> (d,): one matrix product replaces the Python loop over
    # the pattern indices.
    return w @ (patterns - state)

# Perturb a stored pattern with Gaussian noise, then check that the dynamics
# pull the noisy state back to it.
# Fixed seed so the run (and the printed distances) is reproducible; set
# ``seed = None`` to draw a fresh perturbation on every run.
seed = 0
rng = np.random.default_rng(seed)

true_pattern_index = 0
true_pattern = patterns[true_pattern_index]
perturbation = rng.normal(scale=0.4, size=true_pattern.shape)
x0 = true_pattern + perturbation

# Distance from input to all patterns
input_distances = [euclidean_distance(x0, p) for p in patterns]

# Integrate the ODE from t = 0 to t_max in increments of dt.
solver = ode(memory_dynamics)
solver.set_integrator('dopri5')
solver.set_initial_value(x0, 0)

dt = 0.05
t_max = 10
n_steps = int(round(t_max / dt))
# Target times come from linspace instead of repeatedly adding dt to solver.t.
# Repeated addition drifts, so the final time could land slightly off t_max or
# one step past it; linspace ends exactly at t_max.
for t_target in np.linspace(0, t_max, n_steps + 1)[1:]:
    solver.integrate(t_target)
    if not solver.successful():
        # Fail loudly rather than report a partially integrated state as the recall.
        raise RuntimeError(
            f"ODE integration failed at t={solver.t:.4f} "
            f"(return code {solver.get_return_code()})"
        )

final_state = solver.y
output_distances = [euclidean_distance(final_state, p) for p in patterns]

# Print results
print("Distances from perturbed input to stored patterns:")
for i, d in enumerate(input_distances):
    print(f"  Pattern {i}: {d:.4f}")

print("\nDistances from recalled output to stored patterns:")
for i, d in enumerate(output_distances):
    print(f"  Pattern {i}: {d:.4f}")

print(f"\nRecalled state: {final_state}")
print(f"Original pattern index: {true_pattern_index}, Pattern: {true_pattern}")


Distances from perturbed input to stored patterns:
  Pattern 0: 0.3467
  Pattern 1: 3.1058
  Pattern 2: 2.3450
  Pattern 3: 2.0656

Distances from recalled output to stored patterns:
  Pattern 0: 0.0009
  Pattern 1: 2.8275
  Pattern 2: 1.9993
  Pattern 3: 1.9993

Recalled state: [0.99932746 0.99934214]
Original pattern index: 0, Pattern: [1. 1.]
