In [1034]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm



In [1035]:
class modern_hopfield():
  
  memory_matrix = []
  func_type = "Polyn"
  def __init__(self, N: int):
    """
    `N`: number of neurons of the network
    """
    self.N = N
    return
  
  def add_pattern(self, p: np.ndarray):
    """
    `p`: pattern to add to memory
    """
    self.memory_matrix.append(p)
    return

  def retrieve_pattern(self, p: np.ndarray, epochs: int):
    """
    `p`: input pattern  \n
    `epochs`: maximum number of iterations \n
    performs the update rule up to `epochs` times, trying to match the input pattern `p` to a pattern stored in memory     
    """
    sigma = p.copy()
    for epoch in tqdm(range(epochs)):
      sigma = self._update_rule(sigma)
    return sigma

  def _update_rule(self, sigma: np.ndarray):
    """
    `sigma`: vector containing the values of the network's neurons at time t. \n 
    The update rule returns the values of the network's neurons at time t+1 
    """
    s = 0
    for xi in self.memory_matrix:
      s += self._F(xi + np.dot(xi, sigma) - np.multiply(xi, sigma)) - self._F(-xi + np.dot(xi, sigma) - np.multiply(xi, sigma))
    return np.sign(s)
  
  n=3
  def _F(self, x):
    """
    Energy function used in the update rule. \n
    Different types of function:
    - `Polyn`: polynomial function of order n
    - `RePn`: rectified polynomial function of order n
    """
    match self.func_type:
      case "Polyn":
        return x**self.n
      case "RePn":
        return np.maximum(0, x**self.n)
  


In [1036]:
# Experiment size: N neurons, n_patterns stored memories, at most `epochs` update sweeps.
N, n_patterns, epochs = 100, 100, 100

model = modern_hopfield(N)

In [1037]:
# Fixed seed so the stored patterns (and therefore every result below) are reproducible.
rng = np.random.default_rng(0)
patterns = [rng.choice([-1, 1], size=N) for _ in range(n_patterns)]

# Start from an empty memory so re-running this cell does not store the patterns twice.
model.memory_matrix = []
for p in patterns:
  model.add_pattern(p)

In [1038]:
# Probe: stored pattern 0 with neurons 3..44 (42 entries) sign-flipped.
p0 = patterns[0].copy()
p0[3:45] *= -1
p0

array([ 1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,
        1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1,
       -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,  1,
        1, -1, -1,  1,  1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1,  1, -1,
       -1,  1, -1, -1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,
       -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1])

In [1039]:
retrieved_p = model.retrieve_pattern(p0, epochs)

print(patterns[0], "\n")
print(retrieved_p, "\n")

if np.array_equal(retrieved_p, patterns[0]): print("they are equal")

# Skip patterns[0]: matching it is the success case reported above, not a "different" match.
for idx, p in enumerate(patterns[1:], start=1):
  if np.array_equal(retrieved_p, p): print(f"different match found (pattern {idx})")

100%|██████████| 100/100 [00:00<00:00, 361.76it/s]

[ 1  1  1 -1 -1  1  1 -1 -1 -1 -1 -1  1  1 -1 -1 -1 -1 -1  1  1 -1  1  1
  1 -1  1  1 -1  1  1  1  1  1  1 -1  1 -1  1  1 -1 -1  1  1  1  1 -1 -1
  1 -1  1  1 -1 -1  1  1 -1  1 -1 -1  1 -1 -1 -1 -1  1  1 -1 -1  1 -1 -1
 -1 -1 -1 -1  1 -1  1  1 -1  1  1 -1 -1 -1  1 -1 -1  1  1  1 -1  1 -1 -1
  1  1  1  1] 

[-1 -1 -1  1 -1  1 -1  1  1  1 -1 -1  1  1 -1 -1  1 -1 -1  1  1  1 -1  1
  1 -1 -1 -1  1 -1  1  1  1  1  1  1  1 -1 -1 -1 -1 -1 -1  1 -1  1 -1 -1
 -1 -1  1  1  1  1  1 -1 -1  1  1  1 -1 -1 -1  1  1 -1 -1  1  1  1 -1  1
  1  1 -1  1  1  1 -1 -1 -1  1 -1 -1 -1 -1 -1  1  1 -1 -1 -1 -1 -1  1  1
 -1 -1 -1 -1] 

different match found



