### This code illustrates the learning algorithm for Dense Associative Memories from [Dense Associative Memory for Pattern Recognition](https://arxiv.org/abs/1606.01164), applied here to three letter classes of the EMNIST data set.
If you want to learn more about Dense Associative Memories, check out a [NIPS 2016 talk](https://channel9.msdn.com/Events/Neural-Information-Processing-Systems-Conference/Neural-Information-Processing-Systems-Conference-NIPS-2016/Dense-Associative-Memory-for-Pattern-Recognition) or a [research seminar](https://www.youtube.com/watch?v=lvuAU_3t134). 

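The next two cells import the dependencies, load the EMNIST letters data, keep only the classes listed in `included_chars`, and normalize the pixel values to the [-1,1] range.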

In [1]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
!pip install emnist
from emnist import extract_training_samples, extract_test_samples





In [52]:
images, labels = extract_training_samples('letters')
images_test, labels_test = extract_test_samples('letters')
print(labels_test.shape)

included_chars = ['g','l','x']

def get_indices_per_char_train(included_chars_str):
    indices_per_char = []

    for char_str in included_chars_str:
        char_num = ord(char_str) - 96
        char_indices = [i for i, x in enumerate(labels) if x == char_num]
        indices_per_char.append(char_indices)
    return indices_per_char

def get_indices_per_char_test(included_chars_str):
    indices_per_char_test = []

    for char_str in included_chars_str:
        char_num = ord(char_str) - 96
        char_indices = [i for i, x in enumerate(labels_test) if x == char_num]
        indices_per_char_test.append(char_indices)
    return indices_per_char_test

indices_per_char = get_indices_per_char_train(included_chars)
indices_per_char_test = get_indices_per_char_test(included_chars)

N=784                                  # total neurons
Nc=len(included_chars)                 # classifier neurons
training_size=4800*len(included_chars)
test_size=800*len(included_chars)

M=np.zeros((0,N))
Lab=np.zeros((Nc,0))
for i in range(Nc):
    flat_images = images[indices_per_char[i]].reshape((4800, 784))
    M=np.concatenate((M, flat_images), axis=0)
    lab1=-np.ones((Nc,4800))
    lab1[i,:]=1.0
    Lab=np.concatenate((Lab,lab1), axis=1)
M=2*M/255.0-1
M=M.T

MT=np.zeros((0,N))
LabT=np.zeros((Nc,0))
for i in range(Nc):
    flat_images = images_test[indices_per_char_test[i]].reshape((800, 784))
    MT=np.concatenate((MT, flat_images), axis=0)
    lab1=-np.ones((Nc,800))
    lab1[i,:]=1.0
    LabT=np.concatenate((LabT,lab1), axis=1)
MT=2*MT/255.0-1
MT=MT.T

# print(indices_per_char_test[0][2])
# plt.imshow(np.reshape(MT[:,800], (28,28)))
# plt.show()

(20800,)
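
The selection, normalization and label encoding done in the cell above can be sketched as one function. This is a minimal sketch, not the notebook's own code: the name `select_and_encode`, the `per_class` argument, and the synthetic arrays are assumptions made for illustration, and boolean masking via `np.flatnonzero` replaces the list comprehensions.

```python
import numpy as np

def select_and_encode(images, labels, class_ids, per_class):
    """Pick `per_class` samples of each class, scale pixels to [-1, 1],
    and build +1/-1 label columns (one row per class)."""
    data_blocks, label_blocks = [], []
    for row, class_id in enumerate(class_ids):
        # Indices of the first `per_class` samples carrying this label
        idx = np.flatnonzero(labels == class_id)[:per_class]
        flat = images[idx].reshape(len(idx), -1).astype(float)
        data_blocks.append(2.0 * flat / 255.0 - 1.0)
        # +1 for the true class, -1 for every other class
        block = -np.ones((len(class_ids), len(idx)))
        block[row, :] = 1.0
        label_blocks.append(block)
    # Columns are samples, matching the layout of M and Lab in the notebook
    return np.concatenate(data_blocks, axis=0).T, np.concatenate(label_blocks, axis=1)

# Tiny synthetic stand-in for the EMNIST arrays (hypothetical data)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(30, 28, 28), dtype=np.uint8)
fake_labels = np.repeat([7, 12, 24], 10)   # 'g', 'l', 'x' map to ord(c) - 96
X, T = select_and_encode(fake_images, fake_labels, [7, 12, 24], per_class=10)
print(X.shape, T.shape)
```

Passing `per_class` explicitly avoids the hard-coded 4800 and 800 sample counts, which only work when every selected class has exactly that many samples.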


In [44]:
import scipy.io
mat = scipy.io.loadmat('mnist_all.mat')
print("hello")
print(mat['train'+str(i)].shape)

hello
(5958, 784)


To draw a heatmap of the weights together with the errors on the training set (blue) and the test set (red), a helper function is created:

In [53]:
def draw_weights(synapses, Kx, Ky, err_tr, err_test):
    fig.clf()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    
    plt.sca(ax1)
    yy=0
    HM=np.zeros((28*Ky,28*Kx))  # Ky rows of tiles, Kx tiles per row
    for y in range(Ky):
        for x in range(Kx):
            HM[y*28:(y+1)*28,x*28:(x+1)*28]=synapses[yy,:].reshape(28,28)
            yy += 1
    nc=np.amax(np.absolute(HM))
    im=plt.imshow(HM,cmap='bwr',vmin=-nc,vmax=nc)
    cbar=fig.colorbar(im,ticks=[np.amin(HM), 0, np.amax(HM)])
    plt.axis('off')
    cbar.ax.tick_params(labelsize=30) 
    
    plt.sca(ax2)
    plt.ylim((0,100))
    plt.xlim((0,len(err_tr)+1))
    ax2.plot(np.arange(1, len(err_tr)+1, 1), err_tr, color='b', linewidth=4)
    ax2.plot(np.arange(1, len(err_test)+1, 1), err_test, color='r',linewidth=4)
    ax2.set_xlabel('Number of epochs', size=30)
    ax2.set_ylabel('Training and test error, %', size=30)
    ax2.tick_params(labelsize=30)

    plt.tight_layout()
    fig.canvas.draw()
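
The tiling step inside the helper can also be written without the double loop. The sketch below assumes a hypothetical function name `tile_memories` and a `side` argument (28 for these images); it is meant to show the index layout, a grid with `Ky` rows and `Kx` columns of tiles, rather than to replace the helper.

```python
import numpy as np

def tile_memories(synapses, Kx, Ky, side=28):
    """Arrange the first Kx*Ky rows of `synapses` as a Ky-by-Kx grid of side x side tiles."""
    grid = synapses[:Kx * Ky].reshape(Ky, Kx, side, side)
    # (Ky, Kx, side, side) -> (Ky, side, Kx, side) -> (Ky*side, Kx*side)
    return grid.transpose(0, 2, 1, 3).reshape(Ky * side, Kx * side)

# Six tiny 2x2 "memories" laid out as 2 rows by 3 columns
demo = np.arange(6 * 4, dtype=float).reshape(6, 4)
HM = tile_memories(demo, Kx=3, Ky=2, side=2)
print(HM.shape)
```

Memory number `y*Kx + x` lands in tile row `y` and tile column `x`, which is the same order in which the loop in `draw_weights` walks through `synapses`.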

This cell defines parameters of the algorithm: `n` - power of the rectified polynomial in [Eq 3](https://arxiv.org/abs/1606.01164); `m` - power of the loss function in [Eq 14](https://arxiv.org/abs/1606.01164); `K` - number of memories, which are displayed as a `Ky` by `Kx` array by the helper function defined above; `eps0` - initial learning rate, which is exponentially annealed during training with the damping parameter `f`, as explained in [Eq 12](https://arxiv.org/abs/1606.01164); `p` - momentum as defined in [Eq 13](https://arxiv.org/abs/1606.01164); `mu` - the mean of the Gaussian distribution that initializes the weights; `sigma` - the standard deviation of that Gaussian; `epochs` - number of epochs; `training_batch_size` - size of the training minibatch; `test_batch_size` - size of the test minibatch; `prec` - parameter that controls numerical precision of the weight updates. Parameter `beta` that is used in [Eq 9](https://arxiv.org/abs/1606.01164) is defined as `beta=1/Temp**n`. The temperature `Temp` is ramped linearly from `Temp_in` to `Temp_f` over the first `thresh_pret` epochs; the choice of temperatures and the duration of the annealing `thresh_pret` are discussed in [Appendix A](https://arxiv.org/abs/1606.01164).

In [54]:
Kx=10              # Number of memories per row on the weights plot
Ky=10              # Number of memories per column on the weights plot
K=Kx*Ky            # Number of memories
n=20               # Power of the interaction vertex in the DAM energy function
m=30               # Power of the loss function
eps0=4.0e-2        # Initial learning rate  
f=0.998            # Damping parameter for the learning rate
p=0.6                         # Momentum
epochs=300                    # Number of epochs
Temp_in=540.                  # Initial temperature
Temp_f=540.                   # Final temperature
thresh_pret=200               # Length of the temperature ramp
training_batch_size=600      # Size of training minibatch     
test_batch_size=1200          # Size of test minibatch 
mu=-0.3            # Weights initialization mean
sigma=0.3          # Weights initialization std
prec=1.0e-30       # Precision of weight update
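
To see how these values interact, the per-epoch learning rate, temperature and `beta` can be computed in isolation. The function name `schedule` is an assumption for illustration; its defaults copy the parameter values above and its body follows the same formulas the training loop uses.

```python
def schedule(epoch, eps0=4.0e-2, f=0.998, Temp_in=540.0, Temp_f=540.0,
             thresh_pret=200, n=20):
    """Return (learning_rate, Temp, beta) for a given epoch."""
    # Exponentially annealed learning rate
    learning_rate = eps0 * f ** epoch
    # Linear temperature ramp over the first thresh_pret epochs
    if epoch <= thresh_pret:
        Temp = Temp_in + (Temp_f - Temp_in) * epoch / thresh_pret
    else:
        Temp = Temp_f
    beta = 1.0 / Temp ** n
    return learning_rate, Temp, beta

for epoch in (0, 100, 299):
    lr, Temp, beta = schedule(epoch)
    print(f"epoch {epoch:3d}: learning_rate={lr:.4f}, Temp={Temp:.1f}, beta={beta:.3e}")
```

Because `Temp_in` and `Temp_f` are both 540 here, the ramp is flat and `beta` stays constant for the whole run; only the learning rate changes from epoch to epoch.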

This cell defines the main code. The external loop runs over epochs `epoch`, the internal loop runs over minibatches. The weights are updated after each minibatch so that the largest update of each memory is equal to the learning rate `learning_rate` at that epoch, see [Eq 13](https://arxiv.org/abs/1606.01164). The weights are displayed by the helper function after each epoch.
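
The update rule described above can be sketched on its own. This is an illustrative sketch under stated assumptions: the function name `normalized_momentum_step` and the random toy gradient are hypothetical, while the arithmetic (momentum `p`, per-memory normalization by the largest absolute component, the `prec` floor, clipping to [-1, 1]) mirrors the training loop.

```python
import numpy as np

def normalized_momentum_step(memories, velocity, grad, learning_rate, p=0.6, prec=1.0e-30):
    """One weight update: momentum, per-row normalization, clipping to [-1, 1]."""
    velocity = p * velocity + grad
    # Largest absolute component of each memory's update, floored at prec
    scale = np.amax(np.abs(velocity), axis=1, keepdims=True)
    scale = np.maximum(scale, prec)
    # After dividing by scale, the largest step in every row equals learning_rate
    memories = np.clip(memories + learning_rate * velocity / scale, -1.0, 1.0)
    return memories, velocity

rng = np.random.default_rng(0)
memories = rng.uniform(-0.5, 0.5, size=(4, 6))   # toy weights, far from the clip limits
velocity = np.zeros_like(memories)
grad = rng.normal(size=(4, 6))                   # stand-in for d_KS
new_memories, velocity = normalized_momentum_step(memories, velocity, grad, learning_rate=0.04)
print(np.max(np.abs(new_memories - memories), axis=1))
```

Normalizing each row separately means the raw gradient magnitude, which can be extremely large or small for high powers `n`, only sets the direction of the step and not its size.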

In [None]:
%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(12,10))

memories=np.random.normal(mu, sigma, (K, N+Nc))
VKS=np.zeros((K, N+Nc))

aux=-np.ones((Nc,training_batch_size*Nc))
for d in range(Nc):
    aux[d,d*training_batch_size:(d+1)*training_batch_size]=1.

auxT=-np.ones((Nc,test_batch_size*Nc))
for d in range(Nc):
    auxT[d,d*test_batch_size:(d+1)*test_batch_size]=1.
    
err_tr=[]
err_test=[]
for epoch in range(epochs):
    learning_rate=eps0*f**epoch
    # Temperature ramp
    if epoch<=thresh_pret:
        Temp=Temp_in+(Temp_f-Temp_in)*epoch/thresh_pret
    else:
        Temp=Temp_f
    beta=1./Temp**n

    # Training
    perm=np.random.permutation(training_size) # random order
    M=M[:,perm]                               # change memory order
    Lab=Lab[:,perm]                           # change label order
    num_correct = 0
    # for every batch
    for k in range(training_size//training_batch_size):       # floor division
        batch_memories = M[:,k*training_batch_size:(k+1)*training_batch_size] 
        batch_labels = Lab[:,k*training_batch_size:(k+1)*training_batch_size]
        t=np.reshape(batch_labels,(1,Nc*training_batch_size))
                
        # u = batch inputs as columns, with all Nc classifier neurons set to -1
        u=np.concatenate((batch_memories, -np.ones((Nc,training_batch_size))),axis=0)
        # uu = u tiled Nc times along the batch axis (one copy per class)
        uu=np.tile(u,(1,Nc))
        
        # vv = same inputs, but in copy d classifier neuron d is set to +1
        vv=np.concatenate((uu[:N,:],aux),axis=0) 
                
        KSvv=np.maximum(np.dot(memories,vv),0)    # memories with positive classifier
        KSuu=np.maximum(np.dot(memories,uu),0)    # memories with negative classifier
        
        # Difference between F(positive classifier) and F(negative classifier)
        Y=np.tanh(beta*np.sum(KSvv**n-KSuu**n, axis=0))  # Forward path, Eq 9
        pred_labels=np.reshape(Y,(Nc,training_batch_size))
        
        # Gradients of the loss function
        d_KS=np.dot(np.tile((t-Y)**(2*m-1)*(1-Y)*(1+Y), (K,1))*KSvv**(n-1),vv.T) - np.dot(np.tile((t-Y)**(2*m-1)*(1-Y)*(1+Y), (K,1))*KSuu**(n-1),uu.T)
        
        VKS=p*VKS+d_KS
        nc=np.amax(np.absolute(VKS),axis=1).reshape(K,1)
        nc[nc<prec]=prec
        ncc=np.tile(nc,(1,N+Nc))
        memories += learning_rate*VKS/ncc
        memories=np.clip(memories, a_min=-1., a_max=1.)
        
        correct=np.argmax(pred_labels,axis=0)==np.argmax(batch_labels,axis=0)
        num_correct += np.sum(correct)
        
    err_tr.append(100.*(1.0-num_correct/training_size))
    
    # Testing
    num_correct = 0
    for k in range(test_size//test_batch_size):
        v=MT[:,k*test_batch_size:(k+1)*test_batch_size]
        t_R=LabT[:,k*test_batch_size:(k+1)*test_batch_size]
        u=np.concatenate((v, -np.ones((Nc,test_batch_size))),axis=0)
        uu=np.tile(u,(1,Nc))
        vv=np.concatenate((uu[:N,:],auxT),axis=0)
        KSvv=np.maximum(np.dot(memories,vv),0)
        KSuu=np.maximum(np.dot(memories,uu),0)
        Y=np.tanh(beta*np.sum(KSvv**n-KSuu**n, axis=0))  # Forward path, Eq 9
        Y_R=np.reshape(Y,(Nc,test_batch_size))
        correct=np.argmax(Y_R,axis=0)==np.argmax(t_R,axis=0)
        num_correct += np.sum(correct)
        print(np.argmax(Y_R,axis=0))
        print(np.argmax(t_R,axis=0))
        print(num_correct)
    errr=100.*(1.0-num_correct/test_size)
    err_test.append(errr)
    draw_weights(memories[:,:N], Kx, Ky, err_tr, err_test)



[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
800
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
800
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
800
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
800
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
800
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
400
[1 1 1 ... 1 1 1]
[1 1 1 ... 2 2 2]
821
[1 1 1 ... 1 1 1]
[0 0 0 ... 1 1 1]
394
[1 1 1 ... 2 2 1]
[1 1 1 ... 2 2 2]
1468
[1 1 2 ... 1 1 1]
[0 0 0 ... 1 1 1]
391
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
1514
[1 2 2 ... 1 1 1]
[0 0 0 ... 1 1 1]
388
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
1525
[1 1 2 ... 1 1 1]
[0 0 0 ... 1 1 1]
409
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
1542
[0 2 2 ... 1 1 1]
[0 0 0 ... 1 1 1]
486
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
1637
[0 2 2 ... 1 1 1]
[0 0 0 ... 1 1 1]
609
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
1757
[0 2 2 ... 1 1 1]
[0 0 0 ... 1 1 1

[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1189
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2355
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1183
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2362
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1187
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2369
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1188
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2369
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1184
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2367
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1186
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2367
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1186
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2369
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1188
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2370
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1185
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2370
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1187
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2368
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1183
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2367
[0 0 0 ... 1 1 1]
[0 0 0 ... 1 1 1]
1188
[1 1 1 ... 2 2 2]
[1 1 1 ... 2 2 2]
2367
[0 0 0 ... 1 1 1
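
The forward path used in both the training and the test loop can be isolated as a small function, which makes it easier to inspect predictions such as those printed above. The sketch assumes a hypothetical name `classify` and loops over classes instead of tiling the batch; the toy sizes and random weights are illustrative only.

```python
import numpy as np

def classify(memories, v, Nc, n, beta):
    """memories: (K, N+Nc) weights; v: (N, B) inputs in [-1, 1].
    Returns an (Nc, B) array of classifier outputs in [-1, 1]."""
    N, B = v.shape
    # State with every classifier neuron switched off (-1)
    off = np.concatenate((v, -np.ones((Nc, B))), axis=0)
    F_off = np.sum(np.maximum(memories @ off, 0.0) ** n, axis=0)
    Y = np.empty((Nc, B))
    for alpha in range(Nc):
        on = off.copy()
        on[N + alpha, :] = 1.0          # switch on classifier neuron alpha only
        F_on = np.sum(np.maximum(memories @ on, 0.0) ** n, axis=0)
        Y[alpha] = np.tanh(beta * (F_on - F_off))
    return Y

rng = np.random.default_rng(0)
K, N, Nc, B, n = 5, 8, 3, 4, 3          # toy sizes, much smaller than the notebook's
memories = rng.uniform(-1, 1, size=(K, N + Nc))
v = rng.uniform(-1, 1, size=(N, B))
Y = classify(memories, v, Nc=Nc, n=n, beta=1.0 / 4.0 ** n)
print(np.argmax(Y, axis=0))             # predicted class index per sample
```

The predicted class of each sample is the row with the largest output, which is what `np.argmax(Y_R,axis=0)` extracts in the test loop.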

In [None]:
print("test")