### Install required dependencies

In [1]:
# !pip install numpy matplotlib

### Load dataset

In [2]:
import numpy as np

# I took this function from tutorials
def prepare_data(filename):
    """Load binary patterns from a text file and encode them as bipolar vectors.

    The file starts with a header line holding three integers
    ``count width height``. Each of the ``count`` patterns that follow is
    preceded by one separator line (read and discarded) and consists of
    ``height`` rows. Every row is stripped of surrounding whitespace and must
    then hold exactly ``width`` characters; ``'#'`` marks an active cell and
    any other character an inactive one.

    Parameters
    ----------
    filename : str or path-like
        Path of the pattern file.

    Returns
    -------
    patterns : list of numpy.ndarray
        One flat float vector of length ``width * height`` per pattern, with
        ``+1`` where the file has ``'#'`` and ``-1`` elsewhere.
    dim : int
        ``width * height``.

    Raises
    ------
    ValueError
        If the header cannot be parsed into three integers, or a pattern row
        does not contain exactly ``width`` characters (this includes a file
        that ends before all rows were read).
    """
    patterns = []
    with open(filename) as f:
        count, width, height = [int(x) for x in f.readline().split()]  # header
        dim = width * height

        for index in range(count):
            f.readline()  # skip the separator line in front of each pattern
            x = np.empty((height, width))
            for r in range(height):
                row = f.readline().strip()
                # Without this check a one-character row would be silently
                # broadcast across the whole grid row, and a truncated file
                # would surface as an unrelated NumPy error (or not at all).
                if len(row) != width:
                    raise ValueError(
                        f'{filename}: pattern {index}, row {r}: expected '
                        f'{width} characters, got {len(row)}'
                    )
                x[r, :] = [ch == '#' for ch in row]
            # flatten to 1D vector, rescale {0,1} -> {-1,+1}
            patterns.append(2 * x.flatten() - 1)

    return patterns, dim

# Read the stored patterns and their flattened length from the data file.
patterns, dim = prepare_data('data.in')
# Display names; later cells pair them with `patterns` by index.
labels = ['X', 'H', 'O', 'Z']

# Show every loaded vector so the {-1, +1} encoding can be inspected.
for pattern in patterns:
    print(pattern)

[-1.  1. -1. -1. -1.  1. -1. -1. -1.  1. -1.  1. -1. -1. -1. -1. -1.  1.
 -1. -1. -1. -1. -1.  1. -1.  1. -1. -1. -1.  1. -1. -1. -1.  1. -1.]
[-1.  1. -1. -1. -1.  1. -1. -1.  1. -1. -1. -1.  1. -1. -1.  1.  1.  1.
  1.  1. -1. -1.  1. -1. -1. -1.  1. -1. -1.  1. -1. -1. -1.  1. -1.]
[-1. -1.  1.  1.  1. -1. -1. -1.  1. -1. -1. -1.  1. -1. -1.  1. -1. -1.
 -1.  1. -1. -1.  1. -1. -1. -1.  1. -1. -1. -1.  1.  1.  1. -1. -1.]
[-1.  1.  1.  1.  1.  1. -1. -1. -1. -1. -1.  1. -1. -1. -1. -1. -1.  1.
 -1. -1. -1. -1. -1.  1. -1. -1. -1. -1. -1.  1.  1.  1.  1.  1. -1.]


## Corrupt dataset

In [3]:
# Seed NumPy's global generator once so the corruptions drawn below are reproducible.
np.random.seed(42)


def corrupt_patterns(patterns, k):
    """Return noisy copies of ``patterns`` with ``k`` entries negated in each.

    For every pattern, ``k`` distinct positions are drawn from NumPy's global
    random generator and the values at those positions are multiplied by -1.
    The inputs are copied first, so the original patterns are left untouched.

    Parameters
    ----------
    patterns : iterable of numpy.ndarray
        Patterns to corrupt.
    k : int
        Number of distinct positions to flip in each pattern.

    Returns
    -------
    list of numpy.ndarray
        One corrupted copy per input pattern, in the same order.
    """
    def _flip_k(pattern):
        # Copy, then negate k distinct randomly chosen entries of the copy.
        noisy = np.copy(pattern)
        positions = np.random.choice(len(pattern), size=k, replace=False)
        noisy[positions] *= -1
        return noisy

    return [_flip_k(pattern) for pattern in patterns]


# for p, cp in zip(patterns, corrupt_patterns(patterns, 23)):
#     print(f'Original: {p}')
#     print(f'Corrupted: {cp}')
    
#     print()
#     # print(np.sum(p == cp))  # Integer between 0 and len(p)


### Train Hopfield network

In [None]:
from hopfield import HopfieldNetwork

# Build the network from `dim` and train it on the clean patterns.
network = HopfieldNetwork(dim)
network.train(patterns)

# Number of flipped entries per input; for k <= 0 the clean patterns are used as-is.
k_noises = [0, 7, 14, 21]

# results[k][i]     -> (states, energies) returned by run_sync for input i
# overlaps[k][i][j] -> one value per visited state: the fraction of entries
#                      of that state that equal stored pattern j
results = {}
overlaps = {}
for k in k_noises:
    corrupted_patterns = list(patterns) if k <= 0 else corrupt_patterns(patterns, k)
    results[k] = []
    overlaps[k] = [[[] for _ in patterns] for _ in patterns]

    for cpi, cp in enumerate(corrupted_patterns):
        states, energies = network.run_sync(cp, eps=6)
        curves = overlaps[k][cpi]
        # Walk the trajectory once, extending every pattern's curve per state.
        for state in states:
            for curve, stored in zip(curves, patterns):
                curve.append(np.sum(state == stored) / len(state))
        results[k].append((states, energies))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# One figure per stored letter: energy traces on the left, and on the right a
# grid with one overlap panel per stored pattern.
n_patterns = len(patterns)
# Smallest near-square grid that holds every panel (2x2 for four patterns).
# A fixed 2x2 grid would raise IndexError as soon as there are more than
# four patterns, because the panels are indexed by pattern number.
n_cols = max(1, int(np.ceil(np.sqrt(n_patterns))))
n_rows = max(1, int(np.ceil(n_patterns / n_cols)))

for letter_index in range(n_patterns):

    fig = plt.figure(figsize=(10, 5))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1.2])  # left and right half

    # Left half: energy per epoch, one line per noise level k.
    ax0 = fig.add_subplot(gs[0])
    ax0.set_title(f'Letter {labels[letter_index]}')
    ax0.set_xlabel('epochs')
    ax0.set_ylabel('energy')
    for k in k_noises:
        energies = results[k][letter_index][1]
        ax0.plot(energies, label=f'k={k}')
    ax0.legend()

    # Right half: nested grid with one panel per stored pattern.
    gs_right = gridspec.GridSpecFromSubplotSpec(
        n_rows, n_cols, subplot_spec=gs[1], wspace=0.3, hspace=0.4
    )

    for i in range(n_patterns):
        ax = fig.add_subplot(gs_right[i])
        ax.set_title(f'Pattern {labels[i]}')
        ax.set_xlabel('epochs')
        ax.set_ylabel('overlap')
        ax.set_ylim(-0.1, 1.1)
        for k in k_noises:
            # overlaps[k][letter_index][i]: overlap (over epochs) of the run
            # started from corrupted pattern letter_index vs stored pattern i
            ax.plot(overlaps[k][letter_index][i], label=f'k={k}')
        ax.axhline(0, color='gray', linestyle='--', linewidth=0.5)

    fig.tight_layout()
    plt.show()


## 100 random inputs: true vs. false attractors

In [None]:
# Probe the network with random bipolar states and count how many runs end
# exactly on one of the stored patterns.
n_trials = 100
ta_count = 0  # runs whose final state equals a stored pattern
fa_count = 0  # every other run

for i in range(n_trials):
    print(f'Iteration {i}')
    # Use the dataset's pattern length instead of a hard-coded 35 so the
    # probe stays consistent with the network, which is built from `dim`.
    rp = np.random.choice([-1, 1], size=dim)
    print(rp)
    states, energies = network.run_sync(rp, eps=10)
    final_state = states[-1]
    # Only exact matches count: a final state equal to an inverted pattern
    # (-p) is tallied as a false attractor.
    true_attractor = any(np.all(final_state == p) for p in patterns)
    if true_attractor:
        ta_count += 1
    else:
        fa_count += 1

print(f'True attractors: {ta_count}')
print(f'False attractors: {fa_count}')
    


Iteration 0
[ 1  1 -1 -1 -1 -1 -1  1  1 -1  1  1  1 -1  1 -1 -1  1 -1 -1  1  1  1 -1
 -1 -1  1 -1  1 -1  1 -1 -1  1  1]
Iteration 1
[ 1  1 -1 -1  1 -1  1 -1  1  1 -1  1  1  1  1  1  1  1 -1  1  1 -1  1 -1
  1  1 -1 -1 -1 -1 -1  1 -1  1  1]
Iteration 2
[ 1  1  1 -1  1  1 -1 -1 -1 -1 -1 -1 -1  1 -1 -1  1  1  1  1 -1  1 -1  1
  1 -1 -1 -1  1  1  1 -1 -1 -1  1]
Iteration 3
[-1 -1 -1 -1  1  1  1 -1  1 -1  1 -1 -1 -1  1  1  1  1  1  1  1 -1  1 -1
  1 -1 -1  1  1 -1  1  1  1  1 -1]
Iteration 4
[ 1  1  1  1 -1 -1 -1  1 -1  1  1  1 -1 -1 -1  1 -1 -1  1  1 -1 -1  1 -1
 -1 -1 -1  1 -1 -1  1 -1 -1 -1 -1]
Iteration 5
[ 1  1  1 -1  1  1  1  1 -1  1 -1 -1 -1  1  1  1 -1 -1 -1 -1 -1  1  1 -1
  1  1 -1  1  1  1  1  1 -1  1 -1]
Iteration 6
[ 1 -1 -1  1 -1 -1 -1 -1 -1  1 -1 -1 -1  1  1  1 -1  1  1 -1 -1  1 -1  1
  1  1  1 -1  1  1 -1 -1 -1  1 -1]
Iteration 7
[ 1  1  1  1  1  1 -1  1 -1  1  1  1 -1 -1 -1 -1  1  1 -1 -1 -1  1  1  1
 -1  1 -1 -1  1 -1  1  1  1  1 -1]
Iteration 8
[-1 -1  1  1 -1 -1 -1 -1 -1 