In this demo, we will train a fully-connected network on MNIST for 10 epochs using the Burstprop learning rule with updating feedback weights, and observe the alignment between feedforward and feedback weights before and after training.

## Import dependencies

In [1]:
import os
import math

import torch
import torch.nn.functional as F

from plotly.subplots import make_subplots
import plotly.graph_objects as go

## Setting the hyperparameters

In [34]:
# Directory handed to the training script and read back when loading the models
directory = "./Demo"

# --- Training schedule ---
n_epochs = 10             # Number of passes over the training set
batch_size = 32           # Examples per mini-batch

# --- Architecture ---
n_hidden_layers = 1       # A single hidden layer

# --- Optimisation ---
hidden_lr = 1.0           # Learning rate of the hidden layer
output_lr = 0.01          # Learning rate of the output layer
momentum = 0.9            # Momentum coefficient
weight_decay = 1e-6       # Weight decay coefficient

# --- Feedback weights ---
weight_fa_std = 0.1       # Standard deviation used to initialize the feedback weights
weight_fa_learning = True # Update feedback weights with the Kolen-Pollack learning rule

## Running the training script

We'll use Jupyter's magic %run command to run the main script:

In [None]:
# Launch train_mnist.py in this kernel, interpolating the hyperparameters defined above
# into its command-line flags. The first positional argument is the results directory.
# NOTE(review): `output_lr` is defined in the hyperparameter cell but is not forwarded
# here, so the script runs with whatever output learning rate it defaults to -- confirm
# whether train_mnist.py accepts an `-output_lr` flag and add it if so.
# Presumably the script writes "initial_model.pth" and "model.pth" into `directory`
# (the loading cell further down reads them) -- verify against train_mnist.py.
%run train_mnist.py {directory} -momentum {momentum} -n_epochs {n_epochs} -weight_decay {weight_decay} -n_hidden_layers {n_hidden_layers} -hidden_lr {hidden_lr} -weight_fa_std {weight_fa_std} -batch_size {batch_size} -weight_fa_learning {weight_fa_learning} -validation True

Test Loss: 0.488 | Acc: 14.210% (1421/10000): 100%|██████████| 100/100 [00:02<00:00, 33.98it/s]
  0%|          | 0/1563 [00:00<?, ?it/s]


Epoch 1.


Train Loss: 0.133 | Acc: 12.884% (1645/12768):  25%|██▌       | 397/1563 [00:16<00:50, 23.05it/s]

## Loading the model

In [69]:
# Load the two checkpoints found in `directory`.
# NOTE: torch.load unpickles complete model objects, which can execute arbitrary code;
# only load checkpoint files you produced yourself.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects pickled model objects -- confirm against the installed version.
# map_location="cpu" lets checkpoints saved on a GPU load on a machine without one.
net_pre = torch.load(os.path.join(directory, "initial_model.pth"), map_location="cpu") # Before training
net = torch.load(os.path.join(directory, "model.pth"), map_location="cpu")             # After training for n_epochs epochs

## Compare weight alignment before and after training

In [59]:
# Look up the layer of interest once per checkpoint, then unpack its
# feedforward (`weight`) and feedback (`weight_fa`) tensors.
layer_pre = net_pre.classification_layers[1]
layer_post = net.classification_layers[1]

weight_pre, weight_fa_pre = layer_pre.weight, layer_pre.weight_fa
weight, weight_fa = layer_post.weight, layer_post.weight_fa

In [68]:
# Plot sample weights before and after training
n_plotted = 500  # Only the first 500 entries of each tensor are drawn


def _weight_sample(tensor):
    """Return the first `n_plotted` entries of `tensor` as a flat NumPy array."""
    # detach() and cpu() make the NumPy conversion succeed for tensors that
    # track gradients or live on a GPU.
    return tensor.detach().flatten()[:n_plotted].cpu().numpy()


fig = make_subplots(rows=1, cols=2)

fig.add_trace(go.Scatter(x=_weight_sample(weight_pre), y=_weight_sample(weight_fa_pre), mode='markers', name='Before training'), row=1, col=1)
fig.add_trace(go.Scatter(x=_weight_sample(weight), y=_weight_sample(weight_fa), mode='markers', name=f'After {n_epochs} epochs'), row=1, col=2)

# update_xaxes/update_yaxes apply to every subplot; xaxis_title/yaxis_title in
# update_layout would label the first subplot only.
fig.update_xaxes(title_text="Feedforward weight")
fig.update_yaxes(title_text="Feedback weight")
fig.update_layout(title=f"Hidden layer feedforward vs. feedback weights, before training and after {n_epochs} epochs")
fig.show()

In [63]:
def weight_alignment(weight, weight_fa):
    """
    Given feedforward and feedback weights, compute the
    angle between the weight vectors, in degrees.

    Both tensors are flattened and compared through their cosine similarity,
    so 0 means perfectly aligned, 90 orthogonal and 180 anti-aligned.

    Arguments:
    weight: Feedforward weight tensor.
    weight_fa: Feedback weight tensor. It is matched entry-by-entry with
        `weight` after flattening, so it is assumed to share its layout
        -- TODO confirm.

    Returns:
    0-dimensional tensor holding the angle between the two weight vectors,
    in degrees.
    """

    cosine = F.cosine_similarity(weight.flatten(), weight_fa.flatten(), dim=0)

    # Rounding error can push the cosine marginally outside [-1, 1], where
    # acos returns NaN; clamp so (anti-)aligned weights give a finite angle.
    cosine = cosine.clamp(-1.0, 1.0)

    return (180/math.pi)*torch.acos(cosine)

In [66]:
# Report the feedforward/feedback weight angle for both checkpoints
angle_pre = weight_alignment(weight_pre, weight_fa_pre)
angle_post = weight_alignment(weight, weight_fa)

print(f"Before training: {angle_pre}")
print(f"After {n_epochs} epochs: {angle_post}")

Before training: 90.62793731689453
After 10 epochs: 67.65705108642578
