## This code illustrates the fast AI implementation of the unsupervised "biological" learning algorithm from [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116) on CIFAR-10 data set. 
If you want to learn more about this work you can also check out this [lecture](https://www.youtube.com/watch?v=4lY-oAY0aQU) from MIT's [6.S191 course](http://introtodeeplearning.com/). 

## Imports and Parameters

In [None]:
import time
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

This cell defines the parameters of the algorithm:
- `eps0`: initial learning rate that is linearly annealed during training
- `Kx`, `Ky`: the `draw_weights` function defined below displays the first `hid_disp = Kx*Ky` hidden units as a `Ky` by `Kx` array
- `N_hid`: number of hidden units
- `mu`: the mean of the Gaussian distribution that initializes the weights
- `sigma`: the standard deviation of that Gaussian
- `Nep`: number of epochs
- `N_batch`: size of the minibatch
- `prec`: parameter that controls numerical precision of the weight updates (lower bound on the update normalizer)
- `delta`: the strength of the anti-Hebbian learning
- `p`: Lebesgue norm of the weights
- `k`: ranking parameter; the `k`-th most activated hidden unit receives the anti-Hebbian update
- `N_in`: number of input units
- `Nc`: number of classes
- `val_split`: the fraction of the training set held out for validation

In [None]:
# --- unsupervised ("biological") learning rule -------------------------------
eps0 = 2e-2        # initial learning rate, annealed linearly to 0 over Nep epochs
Nep = 1000         # number of epochs
N_batch = 1000     # minibatch size
prec = 1e-30       # floor for the update normalizer (avoids division by zero)
delta = 0.1        # strength of the anti-Hebbian term
p = 2.0            # Lebesgue norm of the weights
k = 2              # ranking parameter; an integer >= 2

# --- network size and weight initialization ----------------------------------
N_hid = 1000       # number of hidden units
N_in = 32 * 32 * 3 # input units: one per pixel and colour channel (3072)
Nc = 10            # number of classes
mu, sigma = 0.0, 1.0   # mean / std of the Gaussian used to initialize weights

# --- display and data split --------------------------------------------------
Kx, Ky = 5, 5      # hidden units are drawn as a Ky-by-Kx grid
hid_disp = Kx * Ky # how many hidden units get displayed
val_split = 1 / 10 # fraction of the training set held out for validation

In [None]:
def shuffle(xt, yt, xv, yv):
    """Randomly reorder a training set and a held-out set.

    Samples and labels of each set share one permutation, so every sample
    stays aligned with its label. The two sets are permuted independently.

    Args:
        xt, yt: training samples and labels (same length).
        xv, yv: validation/test samples and labels (same length).

    Returns:
        Tuple ``(xt, yt, xv, yv)`` of shuffled copies.
    """
    train_order = np.random.permutation(len(xt))
    held_out_order = np.random.permutation(len(xv))
    shuffled_train = (xt[train_order], yt[train_order])
    shuffled_held_out = (xv[held_out_order], yv[held_out_order])
    return shuffled_train + shuffled_held_out

## Data
Let's load the data from `keras`:
- Use `keras.datasets.cifar10.load_data()` to load the 50000 train and 10000 test samples.
- Flatten each 32x32x3 (height x width x channel) image into a vector of length 3072 (`N_in`) and one-hot encode the labels into vectors of length 10
- Split the 50000 training samples into 45000 training and 5000 validation samples

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# flatten each (32, 32, 3) image to a vector: shape (# samples, 3072).
# The channel axis is last, so the vector is pixel-interleaved (R, G, B, R, G, B, ...).
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

# scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# shuffle
x_train, y_train, x_test, y_test = shuffle(x_train, y_train, x_test, y_test)

# indices of the validation split (drawn without replacement from the train set)
val_idx = np.random.choice(x_train.shape[0], int(val_split * x_train.shape[0]), replace=False)

# split validation off the training set
x_val = x_train[val_idx]
y_val = y_train[val_idx]
x_train = np.delete(x_train, val_idx, axis=0)
y_train = np.delete(y_train, val_idx, axis=0)

# some logging
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print("x_val shape:", x_val.shape)
print("y_val shape:", y_val.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, Nc)
y_val = keras.utils.to_categorical(y_val, Nc)
y_test = keras.utils.to_categorical(y_test, Nc)

In [None]:
# Number of samples in each split.
N_train, N_val, N_test = (split.shape[0] for split in (x_train, x_val, x_test))

## Weights visualization
To draw a heatmap of the weights a helper function is created:

In [None]:
def draw_weights(synapses, Kx, Ky, ax=None):
    """Draw the first ``Kx * Ky`` rows of ``synapses`` as a grid of 32x32 RGB tiles.

    Each row is one flattened CIFAR-10-shaped vector of length 3072. The data
    cell flattens the ``(32, 32, 3)`` channels-last images returned by
    ``keras.datasets.cifar10.load_data()``, so each row is un-flattened with
    ``reshape(32, 32, 3)``. Slicing it into three contiguous 1024-element
    channel planes would scramble pixels and channels.

    Every tile is shifted to be non-negative (if needed) and scaled to
    ``[0, 255]`` independently of the others.

    Args:
        synapses: array-like of shape ``(n_units, >=3072)`` with
            ``n_units >= Kx * Ky``. A CPU torch tensor also works because it is
            converted with ``np.asarray``.
        Kx, Ky: number of tile columns and rows.
        ax: optional matplotlib Axes to draw into. When ``None``, the current
            figure is cleared and reused.
    """
    synapses = np.asarray(synapses)
    HM = np.zeros((32 * Ky, 32 * Kx, 3))
    for yy in range(Kx * Ky):
        y, x = divmod(yy, Kx)  # row-major placement, as in the original nested loops
        tile = synapses[yy, :3072].reshape(32, 32, 3).astype(np.float64)
        tile_min = np.amin(tile)
        if tile_min < 0:
            tile = tile - tile_min
        tile_max = np.amax(tile)
        if tile_max > 0:  # an all-zero/constant tile would otherwise give 0/0 = NaN
            tile = tile * (255 / tile_max)
        HM[y * 32:(y + 1) * 32, x * 32:(x + 1) * 32, :] = tile

    nc = np.amax(np.absolute(HM))
    HM = HM.astype(np.uint8)

    # Use the figure that owns the target axes instead of relying on a global `fig`.
    # Note: matplotlib ignores cmap/vmin/vmax when rendering RGB data, so the
    # colorbar is only a rough value legend.
    if ax is not None:
        fig = ax.figure
        im = ax.imshow(HM, cmap='bwr', vmin=-nc, vmax=nc)
        fig.colorbar(im, ticks=[np.amin(HM), 0, np.amax(HM)], ax=ax)
        ax.axis('off')
    else:
        print(HM.shape)
        fig = plt.gcf()
        plt.clf()
        im = plt.imshow(HM, cmap='bwr', vmin=-nc, vmax=nc)
        fig.colorbar(im, ticks=[np.amin(HM), 0, np.amax(HM)])
        plt.axis('off')
    fig.canvas.draw()

## Forward pass of the unsupervised part
This is one step of the training: it returns the changes that should be applied to the weights in one iteration.

**Note:** In the testing phase we don't need to compute the `h` or `g` functions, so we just return the input currents.
- Given the inputs and the synapses (weights), compute the input currents `tot_input` = <W.v>
    - This is referred to as `I` in the paper.
- If we're in the testing phase, return `tot_input`
- If we're in the training phase, use the rankings of the input currents as proxies for the final `h` and return the changes that should be applied to the weights

In [None]:
def forward(inputs, synapses, p, N_hid, N_batch, training=True, rank=None, anti_hebb=None):
    """Compute hidden currents, or the unsupervised weight update, for a minibatch.

    Args:
        inputs: array of shape ``(n_samples, n_in)``, one sample per row.
        synapses: weight matrix ``W`` of shape ``(n_hid, n_in)``.
        p: Lebesgue-norm exponent. Currents use ``sign(W) * |W|**(p-1)``,
            which reduces to ``W`` for ``p == 2``.
        N_hid, N_batch: kept for backward compatibility only. The hidden and
            batch sizes are read from ``synapses.shape[0]`` and
            ``inputs.shape[0]``, so a short final batch also works.
        training: when False, return only the currents.
        rank: ranking parameter ``k``. The ``rank``-th strongest unit gets
            the anti-Hebbian update. Defaults to the global ``k``.
        anti_hebb: anti-Hebbian strength. Defaults to the global ``delta``.

    Returns:
        If ``training`` is False: the currents ``I = <W, v>`` with shape
        ``(n_samples, n_hid)``.
        Otherwise: ``(ds, nc)``.
        - ``ds`` has shape ``(n_hid, n_in)`` and is the right-hand side of
          Eq. 3, summed over the batch.
        - ``nc = max(|ds|)`` is used by the caller to normalize the step.
    """
    inputs = np.asarray(inputs)
    n_hid = synapses.shape[0]
    n_samples = inputs.shape[0]

    # Currents I = <W, v>: shape (n_hid, n_samples), one column per sample.
    weights = np.sign(synapses) * np.absolute(synapses) ** (p - 1)
    tot_input = np.dot(weights, inputs.T)
    if not training:
        return tot_input.T

    if rank is None:
        rank = k
    if anti_hebb is None:
        anti_hebb = delta

    # Rank the currents and use the ranking as a proxy for the steady state h.
    order = np.argsort(tot_input, axis=0)
    cols = np.arange(n_samples)
    yl = np.zeros((n_hid, n_samples))            # g(Q)
    yl[order[n_hid - 1], cols] = 1.0             # strongest unit: Hebbian (+1)
    yl[order[n_hid - rank], cols] = -anti_hebb   # rank-th strongest: anti-Hebbian (-delta)
    xx = np.sum(yl * tot_input, axis=1)          # sum over batch of g(Q) <W, v>
    # g(Q) (v - <W, v> W), summed over the batch; broadcasting replaces np.tile
    ds = np.dot(yl, inputs) - xx[:, np.newaxis] * synapses
    nc = np.amax(np.absolute(ds))
    return ds, nc

## The unsupervised training of the hidden layer
This cell defines the main code. The external loop runs over epochs `nep`, the internal loop runs over minibatches. For every minibatch the overlap with the data `tot_input` is calculated for each data point and each hidden unit. The sorted strengths of the activations are stored in `y`. The variable `yl` stores the activations of the post synaptic cells - it is denoted by g(Q) in Eq 3 of [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116), see also Eq 9 and Eq 10. The variable `ds` is the right hand side of Eq 3. The weights are updated after each minibatch in a way so that the largest update is equal to the learning rate `eps` at that epoch. The weights are displayed by the helper function after each epoch. 

### Note
This will take some time, especially if the hidden layer is big (e.g. N_hid>500). If you only want to see a demo, skip this cell and load the pre-trained weights a few cells below.

If you do train the model yourself, make sure to comment out the cell that loads the pre-trained weights so you don't replace your trained weights with the pre-trained ones!

In [None]:
%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(12.9,10))

start  = time.time()

synapses = np.random.normal(mu, sigma, (N_hid, N_in)) # W
print(synapses.shape)
for nep in range(Nep):
    eps=eps0*(1-nep/Nep)
#     M=M[np.random.permutation(N_train),:]
    for i in range(N_train//N_batch):
        inputs=x_train[i*N_batch:(i+1)*N_batch,:] # v_i 
        ds, nc = forward(inputs, synapses, p, N_hid, N_batch)
        if nc<prec:
            nc=prec
        synapses += eps*np.true_divide(ds,nc)
        
    if nep%2 == 0:    
        print('epoch ' nep)
    draw_weights(synapses, Kx, Ky)
    
print('took {:.4f}s to run {} epochs'.format(time.time() - start, Nep))

## Saving or loading synapses
- If you want to save your own synapses from the previous cell, run the first cell below (edit the file name if needed)
- If you want to load a pre-trained model, use the `np.load` call in the second cell

A sample `synapses.npy` file can be found [here](https://drive.google.com/file/d/1b-2FNqVQ1fN-eixeTO9WZin9eNl2gqou/view?usp=sharing)


In [None]:
# Save your own model. The file name is built from the hyper-parameters that
# were actually used, instead of a hard-coded name that did not match them.
# NOTE: with the default hyper-parameters this produces the same name as the
# pre-trained file loaded below, and would overwrite it if it sits in the same directory.
eps_tag = np.format_float_scientific(eps0, trim='-', exp_digits=1)  # 2e-2 -> '2e-2'
np.save(f'CIFAR10_synapses_hid{N_hid}_epoch{Nep}_eps{eps_tag}_p{p:g}_k{k}_batch{N_batch}.npy', synapses)

In [13]:
# Load a pre-trained model (replaces any synapses trained above).
pretrained_file = 'CIFAR10_synapses_hid1000_epoch1000_eps2e-2_p2_k2_batch1000.npy'
synapses = np.load(pretrained_file)

In [14]:
# Show the current synapses as a Ky x Kx grid of image tiles in a new figure.
%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(10, 7))
draw_weights(synapses, Kx, Ky)

<IPython.core.display.Javascript object>

(160, 160, 3)


In [None]:
def normalize(x):
    """Center each row of ``x`` and scale it to unit Euclidean length.

    Args:
        x: 2-D array with one sample per row.

    Returns:
        Array of the same shape. Rows whose centered norm is zero (constant
        rows) are returned as all zeros instead of NaN.
    """
    x_norm = x - x.mean(axis=1, keepdims=True)
    norms = np.linalg.norm(x_norm, ord=2, axis=1, keepdims=True)
    # A constant row centers to all zeros; dividing by 1 keeps it at 0 instead of 0/0.
    norms = np.where(norms > 0, norms, 1)
    return x_norm / norms

## Testing: compute hidden neurons for train, val, and test data
This will be the input to the second layer, which will be trained with backpropagation (the Adam optimizer) to classify the CIFAR-10 images.

In [None]:
# Center and normalize input data to unit vectors
x_train0 = normalize(x_train)
x_val0 = normalize(x_val)
x_test0 = normalize(x_test)

# Hidden currents I = <W, v> for every split (no competition is needed at test time).
x_hid_train = forward(x_train0, synapses, p, N_hid, N_train, training=False)
x_hid_val = forward(x_val0, synapses, p, N_hid, N_val, training=False)
x_hid_test = forward(x_test0, synapses, p, N_hid, N_test, training=False)

# Rectified power activation f(I) = max(I, 0) ** n. It is applied to every
# split, including test, so all splits live in the same feature space.
n = 1.5
x_hid_train = np.maximum(x_hid_train, 0) ** n
x_hid_val = np.maximum(x_hid_val, 0) ** n
x_hid_test = np.maximum(x_hid_test, 0) ** n

## Define the BIO model
- Input is the pre-computed hidden activations.
- Output layer is a fully connected layer with `softmax` activation
- Loss function is the `categorical cross-entropy`
- Optimizer is `Adam` with an initial learning rate of `4e-3`, halved (staircase exponential decay with rate `0.5`) every `Nep*50` optimizer steps
- Batch size is `10` and number of epochs is `100`

Training progress (loss and accuracy after each epoch) is printed because `verbose=1` is set in `model.fit()`; set `verbose=0` to silence it.

In [None]:
# Linear softmax read-out trained on top of the fixed bio hidden activations.
bio_model = keras.Sequential()
bio_model.add(layers.Input(shape=(N_hid,)))
bio_model.add(layers.Dense(Nc))
bio_model.add(layers.Activation('softmax'))

# Staircase decay: the learning rate starts at 4e-3 and halves every Nep*50 optimizer steps.
# NOTE(review): decay_steps is tied to the unsupervised epoch count Nep -- confirm intended.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=4e-3,
    decay_steps=Nep * 50,
    decay_rate=0.5,
    staircase=True,
)
opt = keras.optimizers.Adam(learning_rate=lr_schedule)
bio_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

bio_fit_args = dict(validation_data=(x_hid_val, y_val), batch_size=10, epochs=100, verbose=1)
bio_logs = bio_model.fit(x_hid_train, y_train, **bio_fit_args)

## Define the backprop model
- One fully-connected hidden layer with the same number of neurons as the BIO model, i.e. `N_hid`
    - Activation is 'relu'
- Another fully-connected output layer with softmax activation
- Same loss and optimizer as BIO

In [None]:
# End-to-end backprop baseline: one ReLU hidden layer of the same width as the bio layer.
bp_model = keras.Sequential([
    layers.Input(shape=(N_in,)),
    layers.Dense(N_hid),
    layers.Activation("relu"),
    layers.Dense(Nc, activation="softmax")
])

# A Keras optimizer keeps per-variable state and its own step counter, so the
# BIO model's `opt` must not be reused here. Sharing it either fails (newer
# optimizers refuse unknown variables) or starts BP with an already-decayed
# learning rate. Build a fresh Adam with the same schedule instead.
bp_opt = keras.optimizers.Adam(learning_rate=lr_schedule)
bp_model.compile(loss='categorical_crossentropy', optimizer=bp_opt, metrics=['accuracy'])
bp_logs = bp_model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=10, epochs=100, verbose=1)

## Plot losses and accuracies

In [None]:
bio_history = bio_logs.history
bp_history = bp_logs.history

# With metrics=['accuracy'], tf.keras in TF2 records 'accuracy' / 'val_accuracy'.
# Older standalone Keras used 'acc' / 'val_acc'. Use whichever key is present.
bio_acc = 'accuracy' if 'accuracy' in bio_history else 'acc'
bp_acc = 'accuracy' if 'accuracy' in bp_history else 'acc'
n_epochs = max(len(bio_history['loss']), len(bp_history['loss']))
curve_names = ['Train_BIO', 'Validation_BIO', 'Train_BP', 'Validation_BP']

plt.figure(figsize=(10, 5))
plt.plot(bio_history['loss'])
plt.plot(bio_history['val_loss'])
plt.plot(bp_history['loss'])
plt.plot(bp_history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.xlim(0, n_epochs)
plt.legend(curve_names, loc='upper right')
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(bio_history[bio_acc])
plt.plot(bio_history['val_' + bio_acc])
plt.plot(bp_history[bp_acc])
plt.plot(bp_history['val_' + bp_acc])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.ylim(0, 1)
plt.xlim(0, n_epochs)
plt.legend(curve_names, loc='lower right')
plt.show()

## Plot weights of the BP and BIO models

In [None]:
# Side-by-side comparison of learned features: backprop (left) vs. bio (right).
fig, ax = plt.subplots(1, 2, figsize=(13, 5))
# First Dense kernel, transposed so there is one row per hidden unit.
bp_weights = bp_model.layers[0].get_weights()[0].T
for panel, panel_title, panel_weights in ((ax[0], 'BP', bp_weights), (ax[1], 'BIO', synapses)):
    panel.set_title(panel_title)
    draw_weights(panel_weights, Kx, Ky, panel)
plt.tight_layout()
plt.show()

### Control Test: A model with only one output layer
If a model with only an output layer (applied directly to the normalized pixels) performs on par with BIO, the images themselves are as informative as the biological hidden neurons.

In [None]:
# Control: a single softmax layer on the centered/normalized pixels (x_train0).
control_model = keras.Sequential([
    layers.Input(shape=(N_in,)),  # shape must be a tuple; (N_in) is just an int
    layers.BatchNormalization(),
    layers.Dense(Nc, activation='softmax')
])
# Fresh optimizer with the same schedule. The shared `opt` already carries
# another model's variables and step count.
control_opt = keras.optimizers.Adam(learning_rate=lr_schedule)
control_model.compile(loss='categorical_crossentropy', optimizer=control_opt, metrics=['accuracy'])
control_logs = control_model.fit(x_train0, y_train, validation_data=(x_val0, y_val), batch_size=100, epochs=100, verbose=0)

In [None]:
control_history = control_logs.history
# metrics=['accuracy'] is logged as 'accuracy' in TF2 tf.keras ('acc' in old standalone Keras).
control_acc = 'accuracy' if 'accuracy' in control_history else 'acc'
print("Final Training Accuracy:", control_history[control_acc][-1])
print("Final Validation Accuracy:", control_history['val_' + control_acc][-1])

## GPU implementation for the unsupervised training of the hidden layer

In [None]:
import torch
import time

use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
device = torch.device("cuda" if use_cuda else "cpu")

# move train data to the selected device once
train_data = torch.from_numpy(x_train).float().to(device)
# column indices for scattering into y1; created on the device so indexing stays on-device
batch_cols = torch.arange(N_batch, device=device)

start = time.time()
synapses = (torch.randn(N_hid, N_in) * sigma + mu).to(device)
for nep in range(Nep):
    eps = eps0 * (1 - nep / Nep)
    # new random sample order each epoch (indices only; the data is not copied)
    order = torch.randperm(N_train, device=device)
    for i in range(N_train // N_batch):
        inputs = train_data[order[i * N_batch:(i + 1) * N_batch]].T  # (N_in, N_batch)
        weights = torch.sign(synapses) * torch.abs(synapses).pow(p - 1)
        tot_input = torch.matmul(weights, inputs)  # currents I, (N_hid, N_batch)

        # top[0] = strongest unit, top[k-1] = k-th strongest unit per sample
        # (replaces a full argsort along the hidden axis)
        top = torch.topk(tot_input, k, dim=0).indices
        y1 = torch.zeros((N_hid, N_batch), device=device)  # g(Q)
        y1[top[0], batch_cols] = 1.0
        y1[top[k - 1], batch_cols] = -delta
        xx = torch.sum(y1 * tot_input, dim=1)

        # g(Q) (v - <W, v> W), summed over the batch
        ds = torch.matmul(y1, inputs.T) - xx.unsqueeze(1) * synapses
        # normalize so the largest update equals eps; prec guards against 0/0
        nc = torch.clamp(torch.max(torch.abs(ds)), min=prec)
        synapses += eps * ds / nc

    if nep % 2 == 0:
        print('epoch', nep)

print('took {:.4f}s to run {} epochs'.format(time.time() - start, Nep))

In [None]:
# draw_weights works on NumPy arrays: bring the trained weights back from the device.
synapses_gpu = synapses.detach().cpu().numpy()
fig = plt.figure(figsize=(10, 7))
draw_weights(synapses_gpu, Kx, Ky)