### This code illustrates the fast AI implementation of the unsupervised "biological" learning algorithm from [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116) on the MNIST data set.
If you want to learn more about this work you can also check out this [lecture](https://www.youtube.com/watch?v=4lY-oAY0aQU) from MIT's [6.S191 course](http://introtodeeplearning.com/). 

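The cells below import the required libraries and set the hyperparameters. The data cell further down loads MNIST and normalizes the pixel values to the [0,1] range.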

In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
eps0=4e-2    # learning rate
Kx=10
Ky=10
hid_disp=Kx*Ky    # number of hidden units that are displayed in Ky by Kx array
N_hid = 1000  # number of hidden units
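mu=0.0       # mean of the gaussian used to initialize the weights
sigma=1.0    # standard deviation of that gaussian
Nep=1000      # number of epochs
N_batch=100      # size of the minibatch
prec=1e-30   # lower bound on the update normalization (avoids division by zero)
delta=0.4    # Strength of the anti-hebbian learning
p=2.0        # Lebesgue norm of the weights
k=2          # ranking parameter; must be an integer greater than or equal to 2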

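Nc = 10            # number of classes (digits 0-9)
N_in = 784         # input dimension: 28*28 pixels
val_split = 1/6    # fraction of the training set held out for validation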

In [None]:
def shuffle(xt, yt, xv, yv):
    pt = np.random.permutation(len(xt))
    pv = np.random.permutation(len(xv))
    return(xt[pt], yt[pt], xv[pv], yv[pv])

## Data
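MNIST is loaded through `keras.datasets`. Each 28×28 image is flattened to a 784-dimensional vector, the training and test sets are shuffled, pixel values are scaled to the [0,1] range, and a random `val_split` fraction of the training images is held out for validation.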
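A minimal sketch of the same preprocessing on a synthetic stand-in for MNIST (random `uint8` images, so it runs without downloading anything). It flattens each 28×28 image, scales it to [0,1], and holds out a random `val_split` fraction of the rows. The helper name `preprocess` and the 600-image toy set are assumptions for illustration only.

```python
import numpy as np

def preprocess(images, labels, val_split=1/6, seed=0):
    """Hypothetical helper mirroring the data cell: flatten, scale to [0, 1], split."""
    rng = np.random.default_rng(seed)
    n = images.shape[0]
    x = images.reshape(n, -1).astype("float32") / 255   # (n, 784), values in [0, 1]
    val_idx = rng.choice(n, int(val_split * n), replace=False)
    mask = np.zeros(n, dtype=bool)
    mask[val_idx] = True                                 # True = validation row
    return x[~mask], labels[~mask], x[mask], labels[mask]

# Synthetic stand-in for keras.datasets.mnist (assumption: 600 random images)
rng = np.random.default_rng(1)
imgs = rng.integers(0, 256, size=(600, 28, 28)).astype(np.uint8)
labs = rng.integers(0, 10, size=600)
x_tr, y_tr, x_va, y_va = preprocess(imgs, labs)
print(x_tr.shape, x_va.shape)
```

The boolean mask plays the role of `np.delete` in the data cell: rows selected by `val_idx` go to validation, the rest stay in training.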

In [None]:
# Cell derived from this page: https://keras.io/examples/vision/mnist_convnet/

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], x_train.shape[1]* x_train.shape[2])
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1]* x_test.shape[2])

x_train, y_train, x_test, y_test = shuffle(x_train, y_train, x_test, y_test)

x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

val_idx = np.random.choice(x_train.shape[0], int(val_split * x_train.shape[0]), replace=False)

x_val = x_train[val_idx]
y_val = y_train[val_idx]
x_train = np.delete(x_train, val_idx, axis=0)
y_train = np.delete(y_train, val_idx, axis=0)

# Images stay flattened to shape (N, 784); the (28, 28, 1) reshape from the
# convnet example is not needed for the dense models used here.
# x_train = np.expand_dims(x_train, -1)
# x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print("x_val shape:", x_val.shape)
print("y_val shape:", y_val.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, Nc)
y_val = keras.utils.to_categorical(y_val, Nc)
y_test = keras.utils.to_categorical(y_test, Nc)

In [None]:
N_train = x_train.shape[0]
N_val = x_val.shape[0]
N_test = x_test.shape[0]

To draw a heatmap of the weights, we define a helper function that arranges the first `Kx*Ky` hidden units in a `Ky` by `Kx` grid of 28×28 images:
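The double loop in `draw_weights` places row `y*Kx + x` of `synapses` at tile position `(y, x)`. A sketch of the same mosaic built with a single `reshape`/`transpose` (the name `tile_weights` is hypothetical), checked against the loop with the plotting removed:

```python
import numpy as np

def tile_weights(synapses, Kx, Ky, side=28):
    """Hypothetical vectorized version of the mosaic built in draw_weights."""
    tiles = synapses[:Kx * Ky].reshape(Ky, Kx, side, side)   # (tile row, tile col, y, x)
    # Put pixel rows next to tile rows: (Ky, side, Kx, side) -> (Ky*side, Kx*side)
    return tiles.transpose(0, 2, 1, 3).reshape(Ky * side, Kx * side)

def tile_weights_loop(synapses, Kx, Ky, side=28):
    """The loop from draw_weights, without the plotting."""
    HM = np.zeros((side * Ky, side * Kx))
    yy = 0
    for y in range(Ky):
        for x in range(Kx):
            HM[y*side:(y+1)*side, x*side:(x+1)*side] = synapses[yy].reshape(side, side)
            yy += 1
    return HM

W = np.random.default_rng(0).normal(size=(12, 784))   # 12 = Kx*Ky hidden units
print(np.array_equal(tile_weights(W, 4, 3), tile_weights_loop(W, 4, 3)))  # True
```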

In [None]:
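def draw_weights(synapses, Kx, Ky):
    # Tile the first Kx*Ky rows of `synapses` into a Ky-by-Kx mosaic of 28x28 images
    yy=0
    HM=np.zeros((28*Ky,28*Kx))
    for y in range(Ky):
        for x in range(Kx):
            HM[y*28:(y+1)*28,x*28:(x+1)*28]=synapses[yy,:].reshape(28,28)
            yy += 1
    fig = plt.gcf()   # draw into the current figure instead of relying on a global `fig`
    plt.clf()
    nc=np.amax(np.absolute(HM))
    im=plt.imshow(HM,cmap='bwr',vmin=-nc,vmax=nc)   # color scale symmetric around 0
    fig.colorbar(im,ticks=[np.amin(HM), 0, np.amax(HM)])
    plt.axis('off')
    fig.canvas.draw()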

The parameters of the algorithm, set in the parameter cell near the top, are:
- `eps0` - initial learning rate, linearly annealed during training;
- `Kx`, `Ky`, `hid_disp` - the helper function above displays `hid_disp = Kx*Ky` hidden units as a `Ky` by `Kx` array;
- `N_hid` - number of hidden units;
- `mu` - mean of the gaussian distribution that initializes the weights;
- `sigma` - standard deviation of that gaussian;
- `Nep` - number of epochs;
- `N_batch` - size of the minibatch;
- `prec` - lower bound on the normalization of the weight updates, which keeps the division numerically safe;
- `delta` - strength of the anti-hebbian learning;
- `p` - Lebesgue norm of the weights;
- `k` - ranking parameter.
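The training loop anneals the learning rate linearly as `eps = eps0*(1 - nep/Nep)`, so it starts at `eps0` and reaches `eps0/Nep` in the last epoch (`nep = Nep-1`). A quick check of that schedule with the values used here (the names `lr0`, `n_epochs` and `eps_schedule` are local to this sketch so they do not overwrite the notebook's variables):

```python
import numpy as np

lr0, n_epochs = 4e-2, 1000                  # same values as eps0 and Nep
epochs = np.arange(n_epochs)                # epoch indices 0 .. n_epochs-1
eps_schedule = lr0 * (1 - epochs / n_epochs)  # formula from the training loop

print(eps_schedule[0], eps_schedule[n_epochs // 2], eps_schedule[-1])
```

This prints 0.04 and 0.02 for the first and middle epochs and approximately 4e-05 for the last one.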

The next two cells define the main code. `forward` computes the overlap with the data, `tot_input`, for every data point in a minibatch and every hidden unit. The variable `y` holds the indices of the hidden units sorted by the strength of their activations. The variable `yl` stores the activations of the post-synaptic cells; it is denoted by g(Q) in Eq 3 of [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116), see also Eq 9 and Eq 10. The variable `ds` is the right-hand side of Eq 3. In the training cell, the outer loop runs over epochs `nep` and the inner loop runs over minibatches. The weights are updated after each minibatch so that the largest update equals the learning rate `eps` of that epoch, and the helper function displays the weights after each epoch.
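Written out in the notation of the code (hidden unit μ, pixel i, minibatch sample v), one call of `forward` followed by the update in the training loop amounts to the following. This transcribes what the code computes and is not meant to reproduce the paper's equations verbatim; δ is `delta`, ε is `eps`, and ΔW is `ds`:

$$
I_\mu(v) = \sum_{i} \operatorname{sgn}(W_{\mu i})\,|W_{\mu i}|^{p-1}\,v_i,
\qquad
g_v(\mu) =
\begin{cases}
1 & \text{if } I_\mu(v) \text{ is the largest of } I_1(v),\dots,I_{N_\text{hid}}(v),\\
-\delta & \text{if } I_\mu(v) \text{ is the } k\text{-th largest},\\
0 & \text{otherwise,}
\end{cases}
$$

$$
\Delta W_{\mu i} = \sum_{v \in \text{batch}} g_v(\mu)\,\bigl(v_i - I_\mu(v)\,W_{\mu i}\bigr),
\qquad
W_{\mu i} \leftarrow W_{\mu i} + \varepsilon\,\frac{\Delta W_{\mu i}}{\max\bigl(\max_{\nu,j}|\Delta W_{\nu j}|,\ \texttt{prec}\bigr)}.
$$

For p = 2 the first factor reduces to the ordinary dot product, I_μ(v) = Σ_i W_μi v_i.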

In [None]:
def forward(inputs, synapses, p, N_hid, N_batch, training=True):
    inputs = np.transpose(inputs)
    sig=np.sign(synapses)
    tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs) # with p=2, this is equal to <W.v> = I 
    y=np.argsort(tot_input,axis=0) # using tot_input (I) as proxy for h

    if training == False:
        return tot_input.T
    
    # TODO: compute the hidden activations h; tot_input (I) is used as a proxy for now

    yl=np.zeros((N_hid, N_batch)) # yl = g(Q)
    yl[y[N_hid-1],np.arange(N_batch)]=1.0 # g(largest activation in I) = 1
    yl[y[N_hid-k],np.arange(N_batch)]=-delta # g(k-th largest activation) = -delta
#     if training == False:
#         return yl.T
    xx=np.sum(np.multiply(yl,tot_input),1) # per hidden unit: sum over the batch of g(Q) * <W, v>
    ds=np.dot(yl,np.transpose(inputs)) - np.multiply(np.tile(xx.reshape(xx.shape[0],1),(1,N_in)),synapses)
    # ds = sum over the batch of g(Q) (v_i - <W,v> W_i)
    nc=np.amax(np.absolute(ds))
    return ds, nc

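A self-contained sketch of one full minibatch step, combining the ranking and `ds` from `forward` with the normalized update from the training loop, on random data. It builds the subtraction term with NumPy broadcasting, `xx[:, None] * synapses`, and checks that this gives the same result as tiling `xx` across the inputs with `np.tile`, and that the largest weight change equals `eps`. The function name `bio_update` and the small sizes are assumptions for illustration.

```python
import numpy as np

def bio_update(inputs, synapses, eps, p=2.0, k=2, delta=0.4, prec=1e-30):
    """Hypothetical one-step version of forward() followed by the weight update."""
    n_hid, n_batch = synapses.shape[0], inputs.shape[0]
    # Overlap I = <W, v> for every hidden unit and sample: shape (N_hid, N_batch)
    tot_input = (np.sign(synapses) * np.abs(synapses) ** (p - 1)) @ inputs.T
    order = np.argsort(tot_input, axis=0)               # ascending ranks per sample
    cols = np.arange(n_batch)
    yl = np.zeros((n_hid, n_batch))                     # g(Q)
    yl[order[n_hid - 1], cols] = 1.0                    # winner: Hebbian
    yl[order[n_hid - k], cols] = -delta                 # k-th ranked: anti-Hebbian
    xx = np.sum(yl * tot_input, axis=1)                 # shape (N_hid,)
    ds = yl @ inputs - xx[:, None] * synapses           # broadcasting instead of np.tile
    nc = max(np.max(np.abs(ds)), prec)                  # largest update, floored at prec
    return synapses + eps * ds / nc, xx

rng = np.random.default_rng(0)
W = rng.normal(size=(20, 784))      # assumption: 20 hidden units for a quick check
batch = rng.random((8, 784))        # assumption: 8 random "images" in [0, 1)
W_new, xx = bio_update(batch, W, eps=4e-2)

tiled = np.tile(xx.reshape(xx.shape[0], 1), (1, W.shape[1])) * W
print(np.allclose(xx[:, None] * W, tiled))            # True
print(np.isclose(np.max(np.abs(W_new - W)), 4e-2))    # True
```

Broadcasting avoids materializing an `(N_hid, N_in)` copy of `xx`, which matters little at this size but is the idiomatic NumPy form.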
In [None]:
%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(12.9,10))

synapses = np.random.normal(mu, sigma, (N_hid, N_in)) # W
print(synapses.shape)
for nep in range(Nep):
    eps=eps0*(1-nep/Nep)
    perm = np.random.permutation(N_train)   # reshuffle the training images at every epoch
    for i in range(N_train//N_batch):
        inputs=x_train[perm[i*N_batch:(i+1)*N_batch],:] # v_i
        ds, nc = forward(inputs, synapses, p, N_hid, N_batch)
        if nc<prec:
            nc=prec
        synapses += eps*np.true_divide(ds,nc)
        
    if nep%20 == 0:    
        print('epoch ' + str(nep))
    draw_weights(synapses, Kx, Ky)
        

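`forward` sorts all `N_hid` activations of every sample with `np.argsort`, although only two ranks are used: the largest and the `k`-th largest. As a possible speed-up (a swapped-in technique, not part of the original code), `np.argpartition` can place just those two ranks in their sorted positions without a full sort. A sketch, with the hypothetical helper name `top_and_kth`, checked against `argsort` on random data where ties are practically impossible:

```python
import numpy as np

def top_and_kth(tot_input, k=2):
    """Hypothetical helper: indices of the largest and k-th largest entry per column."""
    n_hid = tot_input.shape[0]
    # After partitioning, row n_hid-1 holds the index of each column's maximum and
    # row n_hid-k holds the index of its k-th largest value; other rows are unordered.
    part = np.argpartition(tot_input, (n_hid - k, n_hid - 1), axis=0)
    return part[n_hid - 1], part[n_hid - k]

rng = np.random.default_rng(0)
I = rng.normal(size=(1000, 100))     # (N_hid, N_batch)
top, kth = top_and_kth(I, k=2)
order = np.argsort(I, axis=0)
print(np.array_equal(top, order[-1]), np.array_equal(kth, order[-2]))  # True True
```

The returned index arrays can replace `y[N_hid-1]` and `y[N_hid-k]` when filling `yl`.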
In [None]:
np.save('synapses_hid1000_epoch1000_eps4e-2_p2_k2_batch100.npy', synapses)
# synapses = np.load('synapses_hid1000_epoch1000_eps4e-2_p2_k2_batch100.npy')


In [None]:
%matplotlib inline
%matplotlib notebook
fig=plt.figure()
draw_weights(synapses, Kx, Ky)

In [None]:
x_train0 = x_train - x_train.mean(axis=1, keepdims=True) 
x_val0 = x_val - x_val.mean(axis=1, keepdims=True)
x_hid_train = forward(x_train0/np.linalg.norm(x_train0, ord=2, axis=1, keepdims=True), synapses, p, N_hid, N_train, training=False)
x_hid_val = forward(x_val0/np.linalg.norm(x_val0, axis=1, ord=2, keepdims=True), synapses, p, N_hid, N_val, training=False)

# x_hid_train /= np.linalg.norm(x_hid_train, ord=2, axis=1, keepdims=True)
# x_hid_val /= np.linalg.norm(x_hid_val, ord=2, axis=1, keepdims=True)


# x_hid_train = forward(x_train, synapses, p, N_hid, N_train, training=False)
# x_hid_val = forward(x_val, synapses, p, N_hid, N_val, training=False)
# x_hid_test = forward(x_test, synapses, p, N_hid, N_test, training=False)

print("x_hid_train shape:", x_hid_train.shape)
print("x_hid_val shape:", x_hid_val.shape)
# print("x_hid_test shape:", x_hid_test.shape)
print(x_hid_val.max())


In [None]:
n=4.5
x_hid_train = (x_hid_train * (x_hid_train>0)) ** n
x_hid_val = (x_hid_val * (x_hid_val>0)) ** n
# x_hid_test = (x_hid_test * (x_hid_test>0)) ** n

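The two preceding cells turn the learned synapses into features for the classifier: each image is centered and scaled to unit L2 norm, projected onto the synapses, and passed through the rectified power function f(x) = max(x, 0)^n with n = 4.5. A sketch of that pipeline as one function (`bio_features` is a hypothetical name). Note that f is not idempotent, so applying the rectification twice changes the features.

```python
import numpy as np

def bio_features(x, synapses, p=2.0, n=4.5):
    """Hypothetical helper: center, L2-normalize, project, rectify, raise to n."""
    x0 = x - x.mean(axis=1, keepdims=True)                      # zero mean per image
    x0 = x0 / np.linalg.norm(x0, ord=2, axis=1, keepdims=True)  # unit L2 norm per image
    weights = np.sign(synapses) * np.abs(synapses) ** (p - 1)   # same factor as in forward
    h = x0 @ weights.T                                          # (N_samples, N_hid)
    return np.maximum(h, 0) ** n

rng = np.random.default_rng(0)
W = rng.normal(size=(50, 784))      # assumption: 50 hidden units
x = rng.random((10, 784))           # assumption: 10 random "images"
feats = bio_features(x, W)
print(feats.shape, feats.min() >= 0)
```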
In [None]:
bio_model = keras.Sequential([
    layers.Input(shape=(N_hid,)),
#     layers.Activation('elu'),
#     layers.BatchNormalization(),
    layers.Dense(Nc),
#     layers.BatchNormalization(),
    layers.Activation('softmax')
])
# print(bio_model.summary())
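# Staircase decay: with Nep=1000, decay_steps is 20000 optimizer steps, i.e. the learning rate
# is multiplied by 0.7 every 40 epochs (500 minibatches of 100 per epoch)
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-1, decay_steps=Nep*20, decay_rate=0.7, staircase=True)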
opt = keras.optimizers.Adam(learning_rate=lr_schedule)
bio_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
bio_logs = bio_model.fit(x_hid_train, y_train, validation_data=(x_hid_val, y_val), batch_size=100, epochs=300, verbose=1)

In [None]:
bp_model = keras.Sequential([
    layers.Input(shape=(N_in,)),
    layers.Dense(N_hid),
#     layers.BatchNormalization(),
    layers.Activation("relu"),
    layers.Dense(Nc, activation="softmax")
])
bp_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
bp_logs = bp_model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=100, epochs=100)

In [None]:
bio_history = bio_logs.history
bp_history = bp_logs.history

plt.figure(figsize=(10, 5))
plt.plot(bio_history['loss'])
plt.plot(bio_history['val_loss'])

plt.plot(bp_history['loss'])
plt.plot(bp_history['val_loss'])

plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
# plt.ylim(0, 1)
plt.xlim(0, 50)
plt.legend(['Train_BIO', 'Validation_BIO', 'Train_BP', 'Validation_BP'], loc='upper left')
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(bio_history['accuracy'])
plt.plot(bio_history['val_accuracy'])

plt.plot(bp_history['accuracy'])
plt.plot(bp_history['val_accuracy'])

plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.ylim(.9, 1)
plt.xlim(0, 50)
plt.legend(['Train_BIO', 'Validation_BIO', 'Train_BP', 'Validation_BP'], loc='upper left')
plt.show()

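The curves are easier to compare with a number attached. A sketch of a helper (hypothetical name `best_epoch`) that reads a Keras `History.history` dict such as `bio_history` or `bp_history` and reports the epoch with the best value of a metric. The toy dict below is made up and only stands in for a real training log.

```python
def best_epoch(history, metric="val_accuracy", mode="max"):
    """Hypothetical helper: (1-based epoch, value) of the best score for `metric`."""
    values = history[metric]
    pick = max if mode == "max" else min          # accuracy: max, loss: min
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Toy stand-in for bio_logs.history (made-up numbers, not results from this notebook)
toy = {"val_accuracy": [0.90, 0.95, 0.97, 0.96], "val_loss": [0.4, 0.2, 0.1, 0.15]}
print(best_epoch(toy))                           # (3, 0.97)
print(best_epoch(toy, "val_loss", mode="min"))   # (3, 0.1)
```

With the real logs, `best_epoch(bio_history)` and `best_epoch(bp_history)` give one number per model to go with the plots.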
### Control Test: A model with only one output layer
If a model with only an output layer (a single softmax layer, after batch normalization, on the centered pixels) performs on par with BIO, then the images themselves are as informative a representation as the biological hidden units.
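Part of what this comparison controls for is capacity. A `Dense` layer has in×out weights plus out biases, and `BatchNormalization` keeps four vectors per feature (trainable gamma and beta, non-trainable moving mean and variance). A quick tally for the three classifiers in this notebook (the helper names are hypothetical); we expect these totals to match what `model.summary()` reports. The BIO count covers only the supervised head, since the 784×1000 synapses are learned without labels.

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out          # weights + biases

def batchnorm_params(n_features):
    return 4 * n_features                # gamma, beta (trainable) + moving mean, variance

n_in, n_hid, n_c = 784, 1000, 10         # same values as N_in, N_hid, Nc

n_bio_head = dense_params(n_hid, n_c)                            # softmax on bio features
n_bp = dense_params(n_in, n_hid) + dense_params(n_hid, n_c)      # backprop, one hidden layer
n_control = batchnorm_params(n_in) + dense_params(n_in, n_c)     # BN + softmax on pixels

print(n_bio_head, n_bp, n_control)  # 10010 795010 10986
```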

In [None]:
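control_model = keras.Sequential([
    layers.Input(shape=(N_in,)),
    layers.BatchNormalization(),
    layers.Dense(Nc, activation='softmax')
])
# Use a fresh optimizer: reusing the instance already built for bio_model can fail in recent Keras versions
control_opt = keras.optimizers.Adam(learning_rate=lr_schedule)
control_model.compile(loss='categorical_crossentropy', optimizer=control_opt, metrics=['accuracy'])
control_logs = control_model.fit(x_train0, y_train, validation_data=(x_val0, y_val), batch_size=100, epochs=100)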

In [None]:
print("Final Training Accuracy:", control_logs.history['accuracy'][-1])
print("Final Validation Accuracy:", control_logs.history['val_accuracy'][-1])

### short-term 
- load data from keras (MNIST)
- add a softmax layer to train the network (SGD, ...)
- get the network run on a simple validation set
- get the accuracy ...

### long-term
- get the result from the bio network
- transfer to gpu?
- train a "usual" net
- compare the results 
- visualize the weights in the backprop-based network
- do the bio computation on gpu?
- ...