In [13]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from utils import pattern_utils

In [14]:
class modern_hopfield():
  
  func_type = "Polyn"
  def __init__(self, N: int):
    """
    `N`: number of neurons of the network
    """
    self.N = N
    self.memory_matrix = np.zeros(shape=(1,N))
    return
  
  def add_pattern(self, p: np.ndarray):
    """
    `p`: pattern to add to memory
    """
    if self.memory_matrix[0,0]==0:
      self.memory_matrix = p
    else:
      self.memory_matrix = np.hstack((self.memory_matrix, p))
    return

  def retrieve_pattern(self, p: np.ndarray, epochs: int):
    """
    `p`: input pattern  \n
    `epochs`: maximum number of iterations \n
    performs the update rule up to `epochs` times, trying to match the input pattern `p` to a pattern stored in memory     
    """
    sigma = p.copy()
    for epoch in tqdm(range(epochs)):
      sigma = self._update_rule(sigma)
    return sigma

  def _update_rule(self, sigma: np.ndarray):
    """
    `sigma`: vector containing the values of the network's neurons at time t. \n 
    The update rule returns the values of the network's neurons at time t+1 
    """
    s = 0
    for xi in self.memory_matrix:
      s += self._F(xi + np.dot(xi, sigma) - np.multiply(xi, sigma)) - self._F(-xi + np.dot(xi, sigma) - np.multiply(xi, sigma))
    return np.sign(s)
  
  n=3
  def _F(self, x):
    """
    Energy function used in the update rule. \n
    Different types of function:
    - `Polyn`: polynomial function of order n
    - `RePn`: rectified polynomial function of order n
    """
    match self.func_type:
      case "Polyn":
        return x**self.n
      case "RePn":
        return np.maximum(0, x**self.n)
  


In [15]:
pattern_lenght = 100
n_patterns = 100

model = modern_hopfield(pattern_lenght)

In [16]:
patterns = pattern_utils.generate_patterns(pattern_lenght=pattern_lenght, N=n_patterns)
model.add_pattern(patterns)

In [22]:
corrupted = pattern_utils.corrupt_patterns(patterns=patterns, q=1, corruption_type="Erase")
np.array_equal(corrupted, patterns)
corrupted

array([[-1,  1, -1, ...,  1,  1, -1],
       [-1,  1,  1, ...,  1, -1,  1],
       [ 1, -1, -1, ..., -1,  1, -1],
       ...,
       [ 1,  1, -1, ...,  1, -1,  1],
       [-1, -1, -1, ...,  1,  1, -1],
       [-1, -1, -1, ...,  1,  1,  1]])

In [21]:
print(model.memory_matrix[0,:], "\n")
print(corrupted[0,:])

[-1  1 -1 -1  1 -1  1  1 -1  1 -1 -1  1 -1 -1  1 -1 -1  1  1 -1  1  1 -1
  1  1  1 -1 -1  1  1  1  1  1 -1 -1  1  1  1  1 -1 -1  1  1 -1 -1  1 -1
 -1  1 -1 -1  1 -1 -1 -1 -1 -1  1  1  1  1  1  1  1  1 -1  1 -1  1 -1 -1
 -1 -1 -1  1 -1  1  1 -1 -1 -1  1 -1  1 -1  1 -1  1 -1  1  1 -1 -1  1 -1
 -1  1  1 -1] 

[-1  1 -1 -1  1 -1  1  1 -1  1 -1 -1  1 -1 -1  1 -1 -1  1  1 -1  1  1 -1
  1  1  1 -1 -1  1  1  1  1  1 -1 -1  1  1  1  1 -1 -1  1  1 -1 -1  1 -1
 -1  1 -1 -1  1 -1 -1 -1 -1 -1  1  1  1  1  1  1  1  1 -1  1 -1  1  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0]


In [18]:
for i, c in enumerate(corrupted):
  r = model.retrieve_pattern(c, epochs=100)
  if np.array_equal(r, patterns[i]): print(f"Pattern #{i} recognized")
  else: print(f"Pattern #{i} not recognized")

100%|██████████| 100/100 [00:00<00:00, 679.14it/s]


Pattern #0 recognized


100%|██████████| 100/100 [00:00<00:00, 716.43it/s]


Pattern #1 recognized


100%|██████████| 100/100 [00:00<00:00, 504.37it/s]


Pattern #2 recognized


100%|██████████| 100/100 [00:00<00:00, 1420.79it/s]


Pattern #3 recognized


100%|██████████| 100/100 [00:00<00:00, 476.89it/s]


Pattern #4 recognized


100%|██████████| 100/100 [00:00<00:00, 412.86it/s]


Pattern #5 recognized


100%|██████████| 100/100 [00:00<00:00, 468.22it/s]


Pattern #6 recognized


100%|██████████| 100/100 [00:00<00:00, 1018.41it/s]


Pattern #7 recognized


100%|██████████| 100/100 [00:00<00:00, 633.69it/s]


Pattern #8 recognized


100%|██████████| 100/100 [00:00<00:00, 1362.67it/s]


Pattern #9 recognized


100%|██████████| 100/100 [00:00<00:00, 1603.38it/s]


Pattern #10 recognized


100%|██████████| 100/100 [00:00<00:00, 1529.47it/s]


Pattern #11 recognized


100%|██████████| 100/100 [00:00<00:00, 1231.47it/s]


Pattern #12 recognized


100%|██████████| 100/100 [00:00<00:00, 1607.92it/s]


Pattern #13 recognized


100%|██████████| 100/100 [00:00<00:00, 1637.48it/s]


Pattern #14 recognized


100%|██████████| 100/100 [00:00<00:00, 1608.96it/s]


Pattern #15 recognized


100%|██████████| 100/100 [00:00<00:00, 1586.54it/s]


Pattern #16 recognized


100%|██████████| 100/100 [00:00<00:00, 1137.06it/s]


Pattern #17 recognized


100%|██████████| 100/100 [00:00<00:00, 696.26it/s]


Pattern #18 recognized


100%|██████████| 100/100 [00:00<00:00, 509.22it/s]


Pattern #19 recognized


100%|██████████| 100/100 [00:00<00:00, 511.74it/s]


Pattern #20 recognized


100%|██████████| 100/100 [00:00<00:00, 1288.79it/s]


Pattern #21 recognized


100%|██████████| 100/100 [00:00<00:00, 1647.77it/s]


Pattern #22 recognized


100%|██████████| 100/100 [00:00<00:00, 711.10it/s]


Pattern #23 recognized


100%|██████████| 100/100 [00:00<00:00, 306.58it/s]


Pattern #24 recognized


100%|██████████| 100/100 [00:00<00:00, 920.55it/s]


Pattern #25 recognized


100%|██████████| 100/100 [00:00<00:00, 654.58it/s]


Pattern #26 recognized


100%|██████████| 100/100 [00:00<00:00, 841.64it/s]


Pattern #27 recognized


100%|██████████| 100/100 [00:00<00:00, 578.56it/s]


Pattern #28 recognized


100%|██████████| 100/100 [00:00<00:00, 537.73it/s]


Pattern #29 recognized


100%|██████████| 100/100 [00:00<00:00, 532.46it/s]


Pattern #30 recognized


100%|██████████| 100/100 [00:00<00:00, 879.47it/s]


Pattern #31 recognized


100%|██████████| 100/100 [00:00<00:00, 1601.88it/s]


Pattern #32 recognized


100%|██████████| 100/100 [00:00<00:00, 1584.59it/s]


Pattern #33 recognized


100%|██████████| 100/100 [00:00<00:00, 1579.25it/s]


Pattern #34 recognized


100%|██████████| 100/100 [00:00<00:00, 1647.51it/s]


Pattern #35 recognized


100%|██████████| 100/100 [00:00<00:00, 1195.99it/s]


Pattern #36 recognized


100%|██████████| 100/100 [00:00<00:00, 1619.35it/s]


Pattern #37 recognized


100%|██████████| 100/100 [00:00<00:00, 1605.93it/s]


Pattern #38 recognized


100%|██████████| 100/100 [00:00<00:00, 1648.70it/s]


Pattern #39 recognized


100%|██████████| 100/100 [00:00<00:00, 1590.58it/s]


Pattern #40 recognized


100%|██████████| 100/100 [00:00<00:00, 1607.99it/s]


Pattern #41 recognized


100%|██████████| 100/100 [00:00<00:00, 1564.22it/s]


Pattern #42 recognized


100%|██████████| 100/100 [00:00<00:00, 1597.75it/s]


Pattern #43 recognized


100%|██████████| 100/100 [00:00<00:00, 1614.21it/s]


Pattern #44 recognized


100%|██████████| 100/100 [00:00<00:00, 1553.91it/s]


Pattern #45 recognized


100%|██████████| 100/100 [00:00<00:00, 1577.39it/s]


Pattern #46 recognized


100%|██████████| 100/100 [00:00<00:00, 1603.73it/s]


Pattern #47 recognized


100%|██████████| 100/100 [00:00<00:00, 1602.47it/s]


Pattern #48 recognized


100%|██████████| 100/100 [00:00<00:00, 906.09it/s]


Pattern #49 recognized


100%|██████████| 100/100 [00:00<00:00, 1558.03it/s]


Pattern #50 recognized


100%|██████████| 100/100 [00:00<00:00, 1575.40it/s]


Pattern #51 recognized


100%|██████████| 100/100 [00:00<00:00, 1559.44it/s]


Pattern #52 recognized


100%|██████████| 100/100 [00:00<00:00, 1542.45it/s]


Pattern #53 recognized


100%|██████████| 100/100 [00:00<00:00, 1605.83it/s]


Pattern #54 recognized


100%|██████████| 100/100 [00:00<00:00, 1483.72it/s]


Pattern #55 recognized


100%|██████████| 100/100 [00:00<00:00, 1544.81it/s]


Pattern #56 recognized


100%|██████████| 100/100 [00:00<00:00, 1577.55it/s]


Pattern #57 recognized


100%|██████████| 100/100 [00:00<00:00, 1551.29it/s]


Pattern #58 recognized


100%|██████████| 100/100 [00:00<00:00, 1522.53it/s]


Pattern #59 recognized


100%|██████████| 100/100 [00:00<00:00, 1538.36it/s]


Pattern #60 recognized


100%|██████████| 100/100 [00:00<00:00, 410.63it/s]


Pattern #61 recognized


100%|██████████| 100/100 [00:00<00:00, 955.76it/s]


Pattern #62 recognized


100%|██████████| 100/100 [00:00<00:00, 1491.74it/s]


Pattern #63 recognized


100%|██████████| 100/100 [00:00<00:00, 1575.21it/s]


Pattern #64 recognized


100%|██████████| 100/100 [00:00<00:00, 1676.19it/s]


Pattern #65 recognized


100%|██████████| 100/100 [00:00<00:00, 1575.37it/s]


Pattern #66 recognized


100%|██████████| 100/100 [00:00<00:00, 671.52it/s]


Pattern #67 recognized


100%|██████████| 100/100 [00:00<00:00, 546.00it/s]


Pattern #68 recognized


100%|██████████| 100/100 [00:00<00:00, 589.00it/s]


Pattern #69 recognized


100%|██████████| 100/100 [00:00<00:00, 677.88it/s]


Pattern #70 recognized


100%|██████████| 100/100 [00:00<00:00, 1449.32it/s]


Pattern #71 recognized


100%|██████████| 100/100 [00:00<00:00, 569.39it/s]


Pattern #72 recognized


100%|██████████| 100/100 [00:00<00:00, 582.65it/s]


Pattern #73 recognized


100%|██████████| 100/100 [00:00<00:00, 627.75it/s]


Pattern #74 recognized


100%|██████████| 100/100 [00:00<00:00, 752.39it/s]


Pattern #75 recognized


100%|██████████| 100/100 [00:00<00:00, 1656.62it/s]


Pattern #76 recognized


100%|██████████| 100/100 [00:00<00:00, 1643.83it/s]


Pattern #77 recognized


100%|██████████| 100/100 [00:00<00:00, 1592.70it/s]


Pattern #78 recognized


100%|██████████| 100/100 [00:00<00:00, 1647.69it/s]


Pattern #79 recognized


100%|██████████| 100/100 [00:00<00:00, 773.81it/s]


Pattern #80 recognized


100%|██████████| 100/100 [00:00<00:00, 790.85it/s]


Pattern #81 recognized


100%|██████████| 100/100 [00:00<00:00, 1563.03it/s]


Pattern #82 recognized


100%|██████████| 100/100 [00:00<00:00, 732.48it/s]


Pattern #83 recognized


100%|██████████| 100/100 [00:00<00:00, 662.96it/s]


Pattern #84 recognized


100%|██████████| 100/100 [00:00<00:00, 1560.16it/s]


Pattern #85 recognized


100%|██████████| 100/100 [00:00<00:00, 1574.56it/s]


Pattern #86 recognized


100%|██████████| 100/100 [00:00<00:00, 1572.58it/s]


Pattern #87 recognized


100%|██████████| 100/100 [00:00<00:00, 760.80it/s]


Pattern #88 recognized


100%|██████████| 100/100 [00:00<00:00, 564.38it/s]


Pattern #89 recognized


100%|██████████| 100/100 [00:00<00:00, 719.30it/s]


Pattern #90 recognized


100%|██████████| 100/100 [00:00<00:00, 569.69it/s]


Pattern #91 recognized


100%|██████████| 100/100 [00:00<00:00, 428.29it/s]


Pattern #92 recognized


100%|██████████| 100/100 [00:00<00:00, 852.25it/s]


Pattern #93 recognized


100%|██████████| 100/100 [00:00<00:00, 1185.79it/s]


Pattern #94 recognized


100%|██████████| 100/100 [00:00<00:00, 1604.84it/s]


Pattern #95 recognized


100%|██████████| 100/100 [00:00<00:00, 1570.55it/s]


Pattern #96 recognized


100%|██████████| 100/100 [00:00<00:00, 1574.37it/s]


Pattern #97 recognized


100%|██████████| 100/100 [00:00<00:00, 1606.73it/s]


Pattern #98 recognized


100%|██████████| 100/100 [00:00<00:00, 1539.89it/s]

Pattern #99 recognized



