Note: Much of this code is adapted from the Neuromatch NeuroAI 2024 Microlearning tutorial.

## Dependencies

In [None]:
# dependencies
from IPython.display import Image, SVG, display
import os
from pathlib import Path

import random
from tqdm import tqdm
import warnings
import numpy as np
import matplotlib.pyplot as plt
import scipy
import torch
import torchvision
import contextlib
import io
from scipy.ndimage import uniform_filter1d

## Plotting and metrics imports
from metrics import get_plotting_color, plot_examples, plot_class_distribution, plot_results, plot_scores_per_class, plot_weights

## Other functions imports
from helpers import sigmoid, ReLU, add_bias, create_batches, calculate_accuracy, calculate_cosine_similarity, calculate_grad_snr

## MLP imports
from MLP import MLP, NodePerturbMLP, KolenPollackMLP

## Data

In [None]:
# Download MNIST function
def download_mnist(train_prop=0.8, keep_prop=0.5):
  """Download MNIST and return torch subsets plus flattened numpy arrays.

  Args:
    train_prop: Fraction of the *retained* part of the training set that
      ``random_split`` assigns to ``train_set``; the complement goes to
      ``valid_set``.
    keep_prop: Fraction of each torch dataset that is retained; the rest is
      split off and discarded.

  Returns:
    A 9-tuple ``(train_set, valid_set, test_set, train_images, valid_images,
    test_images, train_labels, valid_labels, test_labels)``:

    * ``*_set``: ``random_split`` subsets of the torchvision datasets (these
      carry the ``ToTensor`` + ``Normalize`` transform defined below).
    * ``*_images``: float arrays of shape ``(784, n)``, one flattened image
      per column, raw pixel values divided by 255 (``Normalize`` is NOT
      applied to these arrays).
    * ``*_labels``: one-hot arrays of shape ``(10, n)``, one example per
      column.

    The numpy train/validation split is a fixed (seed 42) shuffle of the full
    training set. The first 50000 shuffled examples are used for training and
    the remaining ones for validation. This split does not depend on
    ``train_prop`` or ``keep_prop``.
  """
  valid_prop = 1 - train_prop
  discard_prop = 1 - keep_prop

  transform = torchvision.transforms.Compose(
      [torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.1307,), (0.3081,))]
      )

  with contextlib.redirect_stdout(io.StringIO()):  # to suppress output

    full_train_set = torchvision.datasets.MNIST(
          root="./data/", train=True, download=True, transform=transform)
    full_test_set = torchvision.datasets.MNIST(
          root="./data/", train=False, download=True, transform=transform)

    # Shuffle the indices of the WHOLE training set, then slice them into a
    # training part (the first `train_num`) and a validation part (the rest).
    # Permuting only `train_num` indices would leave the validation slice empty.
    rng_data = np.random.default_rng(seed=42)
    train_num = 50000
    shuffled_train_idx = rng_data.permutation(len(full_train_set))
    train_idx = shuffled_train_idx[:train_num]
    valid_idx = shuffled_train_idx[train_num:]

    full_train_images = full_train_set.data.numpy().astype(float) / 255
    train_images = full_train_images[train_idx].reshape((-1, 784)).T.copy()
    valid_images = full_train_images[valid_idx].reshape((-1, 784)).T.copy()
    test_images = (full_test_set.data.numpy().astype(float) / 255).reshape((-1, 784)).T

    full_train_labels = torch.nn.functional.one_hot(full_train_set.targets, num_classes=10).numpy()
    train_labels = full_train_labels[train_idx].T.copy()
    valid_labels = full_train_labels[valid_idx].T.copy()
    test_labels = torch.nn.functional.one_hot(full_test_set.targets, num_classes=10).numpy().T

    # Fractional lengths: random_split scales them by the dataset size.
    train_set, valid_set, _ = torch.utils.data.random_split(
      full_train_set, [train_prop * keep_prop, valid_prop * keep_prop, discard_prop])
    test_set, _ = torch.utils.data.random_split(
      full_test_set, [keep_prop, discard_prop])

  print("Number of examples retained:")
  print(f"  {len(train_set)} (training)")
  print(f"  {len(valid_set)} (validation)")
  print(f"  {len(test_set)} (test)")

  return train_set, valid_set, test_set, train_images, valid_images, test_images, train_labels, valid_labels, test_labels

In [None]:
# Fetch MNIST once: torch subsets, then (784, n) image arrays, then one-hot labels
(train_set, valid_set, test_set,
 train_images, valid_images, test_images,
 train_labels, valid_labels, test_labels) = download_mnist()

## Hyperparameters

In [None]:
# HYPERPARAMETERS

# Network shape: 784 flattened pixel inputs, 10 output classes
NUM_INPUTS = 784
NUM_OUTPUTS = 10
numhidden = 500

# Training settings (`initweight` is passed as `sigma` to the network constructors)
batchsize = 128
initweight = 0.1
learnrate = 0.001
noise = 0.1
numepochs = 25
numrepeats = 1

# Derived lengths: full mini-batches per epoch and total mini-batch updates
numbatches = train_images.shape[1] // batchsize
numupdates = numepochs * numbatches

# Miscellaneous
activation = 'sigmoid'
report = True
rep_rate = 1
seed = 12345

In [None]:
# Seeded generator shared by most of the experiments below
rng_bp2 = np.random.default_rng(seed=seed)

# Fixed evaluation subset: 1000 distinct test-set columns, drawn without replacement
indices = rng_bp2.choice(np.arange(test_images.shape[1]), size=(1000,), replace=False)

## Backprop

In [None]:
# Normal learning: mini-batch training with backprop
netbackprop = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_backprop,
    accuracy_backprop,
    test_loss_backprop,
    snr_backprop,
    cosine_similarity_backprop,
) = netbackprop.train(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='backprop',
    noise=noise,
    report=report,
    report_rate=rep_rate,
)

In [None]:
# Online learning with backprop.
# max_it = numupdates * batchsize equals the number of examples the batched run processes.
net_bp_online = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_bp_online,
    accuracy_bp_online,
    test_loss_bp_online,
    snr_bp_online,
    cosine_similarity_bp_online,
) = net_bp_online.train_online(
    rng_bp2,
    train_images,
    train_labels,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=0.01,
    max_it=numupdates * batchsize,
    conv_loss=1e-1,
    algorithm='backprop',
    noise=noise,
    report=report,
    report_rate=batchsize,
)

In [None]:
# Noisy Input: backprop training with noise_type='gauss'
netbackprop_noisy = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_bp_noisy,
    accuracy_bp_noisy,
    test_loss_bp_noisy,
    snr_bp_noisy,
    cosine_similarity_bp_noisy,
) = netbackprop_noisy.train(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='backprop',
    noise=noise,
    noise_type='gauss',
    report=report,
    report_rate=rep_rate,
)

In [None]:
# Non-Stationary Data: backprop via train_nonstat_data (fourth return value discarded)
net_bp_nonstat = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_bp_nonstat,
    accuracy_bp_nonstat,
    test_loss_bp_nonstat,
    _,
) = net_bp_nonstat.train_nonstat_data(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='backprop',
    noise=noise,
    report=report,
    report_rate=1,
)

## FFA

In [None]:
# Normal learning: mini-batch training with the 'ffa' rule
# NOTE(review): the backprop cells unpack MLP.train into 5 values, this cell into 4;
# confirm that the return arity really depends on `algorithm`.
netffa = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_ffa,
    accuracy_ffa,
    test_loss_ffa,
    _,
) = netffa.train(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='ffa',
    noise=noise,
    report=report,
    report_rate=rep_rate,
)

In [None]:
# Online Learning with the 'ffa' rule.
# NOTE(review): conv_loss is 1e-4 here but 1e-1 in the backprop online cell; confirm intended.
net_ffa_online = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_ffa_online,
    accuracy_ffa_online,
    test_loss_ffa_online,
    _,
) = net_ffa_online.train_online(
    rng_bp2,
    train_images,
    train_labels,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=0.01,
    max_it=numupdates * batchsize,
    conv_loss=1e-4,
    algorithm='ffa',
    noise=noise,
    report=report,
    report_rate=batchsize,
)

In [None]:
# Noisy Input: 'ffa' training with noise_type='gauss'
net_ffa_noisy = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_ffa_noisy,
    accuracy_ffa_noisy,
    test_loss_ffa_noisy,
    _,
) = net_ffa_noisy.train(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='ffa',
    noise=noise,
    noise_type='gauss',
    report=report,
    report_rate=rep_rate,
)

In [None]:
# Non-Stationary Data: 'ffa' rule via train_nonstat_data
net_ffa_nonstat = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)
(
    losses_ffa_nonstat,
    accuracy_ffa_nonstat,
    test_loss_ffa_nonstat,
    _,
) = net_ffa_nonstat.train_nonstat_data(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='ffa',
    noise=noise,
    report=report,
    report_rate=1,
)

## Node Perturbation

In [None]:
# Normal Learning: mini-batch training with node perturbation.
# Everything this cell prints to stdout is captured and discarded.
with contextlib.redirect_stdout(io.StringIO()):
    netnodeperturb = NodePerturbMLP(
        rng_bp2, numhidden, num_inputs=784, sigma=initweight, activation=activation
    )
    (
        losses_node_perturb,
        accuracy_node_perturb,
        test_loss_node_perturb,
        snr_node_perturb,
    ) = netnodeperturb.train(
        rng_bp2,
        train_images,
        train_labels,
        numepochs,
        test_images[:, indices],
        test_labels[:, indices],
        learning_rate=learnrate,
        batch_size=batchsize,
        algorithm='node_perturb',
        noise=noise,
        report=report,
        report_rate=rep_rate,
    )

In [None]:
# Online Learning with node perturbation (stdout silenced)
with contextlib.redirect_stdout(io.StringIO()):
    netnodeperturb_online = NodePerturbMLP(
        rng_bp2, numhidden, num_inputs=784, sigma=initweight, activation=activation
    )
    (
        losses_np_online,
        accuracy_np_online,
        test_loss_np_online,
        _,
    ) = netnodeperturb_online.train_online(
        rng_bp2,
        train_images,
        train_labels,
        test_images[:, indices],
        test_labels[:, indices],
        learning_rate=0.01,
        max_it=numupdates * batchsize,
        conv_loss=1e-4,
        algorithm='node_perturb',
        noise=noise,
        report=report,
        report_rate=batchsize,
    )

In [None]:
# Noisy Input: node perturbation with noise_type='gauss' (stdout silenced)
with contextlib.redirect_stdout(io.StringIO()):
    nodeperturb_noisy = NodePerturbMLP(
        rng_bp2, numhidden, num_inputs=784, sigma=initweight, activation=activation
    )
    (
        losses_np_noisy,
        accuracy_np_noisy,
        test_loss_np_noisy,
        _,
    ) = nodeperturb_noisy.train(
        rng_bp2,
        train_images,
        train_labels,
        numepochs,
        test_images[:, indices],
        test_labels[:, indices],
        learning_rate=learnrate,
        batch_size=batchsize,
        algorithm='node_perturb',
        noise=noise,
        noise_type='gauss',
        report=report,
        report_rate=rep_rate,
    )

In [None]:
# Non-Stationary Data: node perturbation via train_nonstat_data (stdout silenced)
with contextlib.redirect_stdout(io.StringIO()):
    nodeperturb_nonstat = NodePerturbMLP(
        rng_bp2, numhidden, num_inputs=784, sigma=initweight, activation=activation
    )
    (
        losses_np_nonstat,
        accuracy_np_nonstat,
        test_loss_np_nonstat,
        _,
    ) = nodeperturb_nonstat.train_nonstat_data(
        rng_bp2,
        train_images,
        train_labels,
        numepochs,
        test_images[:, indices],
        test_labels[:, indices],
        learning_rate=learnrate,
        batch_size=batchsize,
        algorithm='node_perturb',
        noise=noise,
        report=report,
        report_rate=1,
    )

## Kolen-Pollack

In [None]:
# Normal Learning with Kolen-Pollack.
# Dedicated, seeded generator for the Kolen-Pollack runs. The Kolen-Pollack
# cells initialise their networks from `rng_kp`, so it must exist before first use.
rng_kp = np.random.default_rng(seed=seed)

netkolepoll = KolenPollackMLP(rng_kp, numhidden, sigma=initweight, activation=activation)
(losses_kolepoll, accuracy_kolepoll, test_loss_kolepoll, _) = netkolepoll.train(
    rng_kp,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='kolepoll',
    noise=noise,
    report=report,
    report_rate=rep_rate,
)

# Re-seed the shared generator so later cells that train with `rng_bp2`
# start from a fresh, reproducible stream.
rng_bp2 = np.random.default_rng(seed=seed)

In [None]:
# Online Learning with Kolen-Pollack.
# NOTE(review): the network is initialised from `rng_kp`; make sure that generator
# is defined before this cell runs.
net_kp_online = KolenPollackMLP(rng_kp, numhidden, sigma=initweight, activation=activation)
(
    losses_kp_online,
    accuracy_kp_online,
    test_loss_kp_online,
    _,
) = net_kp_online.train_online(
    rng_bp2,
    train_images,
    train_labels,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=0.01,
    max_it=numupdates * batchsize,
    conv_loss=1e-1,
    algorithm='kolepoll',
    noise=noise,
    report=report,
    report_rate=batchsize,
)

In [None]:
# Noisy Input: Kolen-Pollack with noise_type='gauss'.
# NOTE(review): the network is initialised from `rng_kp`; make sure that generator
# is defined before this cell runs.
net_kp_noisy = KolenPollackMLP(rng_kp, numhidden, sigma=initweight, activation=activation)
(
    losses_kp_noisy,
    accuracy_kp_noisy,
    test_loss_kp_noisy,
    _,
) = net_kp_noisy.train(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='kolepoll',
    noise=noise,
    noise_type='gauss',
    report=report,
    report_rate=rep_rate,
)

In [None]:
# Non-Stationary Data: Kolen-Pollack via train_nonstat_data.
# NOTE(review): the network is initialised from `rng_kp`; make sure that generator
# is defined before this cell runs.
net_kp_nonstat = KolenPollackMLP(rng_kp, numhidden, sigma=initweight, activation=activation)
(
    losses_kp_nonstat,
    accuracy_kp_nonstat,
    test_loss_kp_nonstat,
    _,
) = net_kp_nonstat.train_nonstat_data(
    rng_bp2,
    train_images,
    train_labels,
    numepochs,
    test_images[:, indices],
    test_labels[:, indices],
    learning_rate=learnrate,
    batch_size=batchsize,
    algorithm='kolepoll',
    noise=noise,
    report=report,
    report_rate=1,
)

## Final Plots

In [None]:
# create arrays
# Gather the results of every training condition. Each list holds the four
# learning rules in the order [Backprop, FFA, Node Perturbation, Kolen-Pollack].

# normal (mini-batch) training
tr_loss_normal, te_acc_normal, te_loss_normal = (
    [losses_backprop, losses_ffa, losses_node_perturb, losses_kolepoll],
    [accuracy_backprop, accuracy_ffa, accuracy_node_perturb, accuracy_kolepoll],
    [test_loss_backprop, test_loss_ffa, test_loss_node_perturb, test_loss_kolepoll],
)

# online training
# TODO: smoothing of the online training losses is not implemented yet; one
# option is a moving average, e.g. uniform_filter1d(losses_bp_online, size=100).
tr_loss_online, te_acc_online, te_loss_online = (
    [losses_bp_online, losses_ffa_online, losses_np_online, losses_kp_online],
    [accuracy_bp_online, accuracy_ffa_online, accuracy_np_online, accuracy_kp_online],
    [test_loss_bp_online, test_loss_ffa_online, test_loss_np_online, test_loss_kp_online],
)

# noisy input
tr_loss_noisy, te_acc_noisy, te_loss_noisy = (
    [losses_bp_noisy, losses_ffa_noisy, losses_np_noisy, losses_kp_noisy],
    [accuracy_bp_noisy, accuracy_ffa_noisy, accuracy_np_noisy, accuracy_kp_noisy],
    [test_loss_bp_noisy, test_loss_ffa_noisy, test_loss_np_noisy, test_loss_kp_noisy],
)

# non-stationary data
tr_loss_nonstat, te_acc_nonstat, te_loss_nonstat = (
    [losses_bp_nonstat, losses_ffa_nonstat, losses_np_nonstat, losses_kp_nonstat],
    [accuracy_bp_nonstat, accuracy_ffa_nonstat, accuracy_np_nonstat, accuracy_kp_nonstat],
    [test_loss_bp_nonstat, test_loss_ffa_nonstat, test_loss_np_nonstat, test_loss_kp_nonstat],
)

# Names of the training conditions, in plotting order
# (despite the variable name, these are conditions rather than learning algorithms)
algos = ['normal', 'online', 'noisy', 'nonstat']

In [None]:
# plotting
# (label, colour) of each learning rule, in the order the result lists use:
# [Backprop, FFA, Node Perturbation, Kolen-Pollack]
_RULE_STYLES = (
    ("Backprop", 'r'),
    ("FFA", 'gold'),
    ("Node Perturbation", 'c'),
    ("Kolen-Pollack", 'k'),
)


def _plot_condition(condition, tr_loss, te_acc, te_loss):
    """Draw one figure comparing the four learning rules for one training condition.

    Args:
        condition: name of the condition; used as the figure title.
        tr_loss, te_acc, te_loss: sequences of four curves (training loss, test
            accuracy, test loss), ordered like ``_RULE_STYLES``.

    The figure is left open; the caller decides when to call ``plt.show()``.
    """
    plt.figure(figsize=(18, 5))
    # Without a figure title the per-condition figures are indistinguishable.
    plt.suptitle(f"Condition: {condition}")

    # (subplot position, curves, x label, y label, panel title)
    # NOTE(review): the online runs pass report_rate=batchsize, so their test
    # curves may not be sampled once per epoch; confirm the "Epochs" label
    # against MLP.train_online.
    panels = (
        (131, tr_loss, "Updates (every batch)", "MSE", "Training loss"),
        (132, te_acc, "Epochs", "Accuracy", "Performance over time"),
        (133, te_loss, "Epochs", "MSE", "Test loss"),
    )
    for position, curves, xlabel, ylabel, title in panels:
        plt.subplot(position)
        for curve, (label, color) in zip(curves, _RULE_STYLES):
            plt.plot(curve, label=label, color=color)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)


# Results per condition, keyed by the names in `algos`. A dict lookup fails
# loudly (KeyError) on an unknown name instead of silently re-plotting stale data.
results_by_condition = {
    'normal': (tr_loss_normal, te_acc_normal, te_loss_normal),
    'online': (tr_loss_online, te_acc_online, te_loss_online),
    'noisy': (tr_loss_noisy, te_acc_noisy, te_loss_noisy),
    'nonstat': (tr_loss_nonstat, te_acc_nonstat, te_loss_nonstat),
}

for algo in algos:
    tr_loss, te_acc, te_loss = results_by_condition[algo]
    _plot_condition(algo, tr_loss, te_acc, te_loss)
plt.show()