*Note*: this notebook is only a usage example for the `layer_sim` library; it does not focus on the performance of the provided NNs or on an accurate analysis of the resulting similarities.
The code runs on MNIST so that the results can be reproduced in a reasonably short time on a medium-sized machine without a CUDA-capable GPU.

In [2]:
import torch
import sys
import os

sys.path.append("..") # so we can import the layer_sim library

from layer_sim import networks
from layer_sim import datasets
from layer_sim import nn_comparison
from layer_sim import preprocessing
from layer_sim.train import train_net, test_net
from layer_sim.pruning.IMP import imp_lrr

## Dataset preparation

In [3]:
train_batch = 128
test_batch = 128
trainloader, testloader = datasets.MNIST("../data", train_batch, test_batch, num_workers=4)

## NN training and testing

In [3]:
lr_init = 0.1
weight_decay = 0.0001
momentum = 0.9
epochs = 15
device = "cpu"

reps = 2

In [5]:
criterion = torch.nn.CrossEntropyLoss()


In [6]:
save_root = "../models/LeNet5"
save_name = "trained_net_{}.pt"
for i in range(reps):
    net = networks.LeNet5(num_classes=10)
    optimizer = torch.optim.Adam(net.parameters())
    tr_loss, tr_perf = train_net(net, epochs, criterion, optimizer, trainloader, device=device)
    te_loss, te_perf = test_net(net, testloader, criterion, device)
    save_dict = {
        "train_loss": tr_loss,
        "train_perf": tr_perf,
        "test_loss": te_loss,
        "test_perf": te_perf,
        "parameters": net.state_dict()
    }
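    # fill in the repetition index so that each run gets its own file (trained_net_0.pt, trained_net_1.pt, ...)
    torch.save(save_dict, os.path.join(save_root, save_name.format(i)))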

===> Epoch 1/15 ### Loss 0.3238158259967963 ### Performance 0.8953333333333333
===> Epoch 2/15 ### Loss 0.06931836195240418 ### Performance 0.9788333333333333
===> Epoch 3/15 ### Loss 0.04757271823982398 ### Performance 0.9855833333333334
===> Epoch 4/15 ### Loss 0.04082236483022571 ### Performance 0.9876833333333334
===> Epoch 5/15 ### Loss 0.032276718960702416 ### Performance 0.9903
===> Epoch 6/15 ### Loss 0.028784222680702805 ### Performance 0.9910666666666667
===> Epoch 7/15 ### Loss 0.02341511012427509 ### Performance 0.9927666666666667
===> Epoch 8/15 ### Loss 0.024698692759002248 ### Performance 0.9923833333333333
===> Epoch 9/15 ### Loss 0.023087299082769703 ### Performance 0.9929666666666667
===> Epoch 10/15 ### Loss 0.020946485105218987 ### Performance 0.9932166666666666
===> Epoch 11/15 ### Loss 0.016838115830874693 ### Performance 0.9945833333333334
LR annealed: previous 0.1, current 0.01
===> Epoch 12/15 ### Loss 0.007709789935398536 ### Performance 0.9976
===> Epoch 13/1

## Store complete model's representation

In [6]:
# build the dataloader used to extract the representations, with train=False
reprloader, _ = datasets.MNIST("../data", 128, train=False, num_workers=4)
datapoints_repr = 500
layers_to_hook = (torch.nn.ReLU, torch.nn.AvgPool2d)

net = networks.LeNet5(num_classes=10)
net.load_state_dict(torch.load("../models/LeNet5/trained_net_0.pt")["parameters"])
net_repr1 = net.extract_network_representation(reprloader, limit_datapoints=datapoints_repr, layer_types_to_hook=layers_to_hook, device="cpu")
net.load_state_dict(torch.load("../models/LeNet5/trained_net_1.pt")["parameters"])
net_repr2 = net.extract_network_representation(reprloader, limit_datapoints=datapoints_repr, layer_types_to_hook=layers_to_hook, device="cpu")

## Compare representations

In [10]:
# Load SVCCA
SVCCA_ROOT = "C:\\Users\\mzullich\\Documents\\svcca"
sys.path.append(os.path.expanduser(SVCCA_ROOT))
from cca_core import get_cca_similarity

# helper returning a scalar: the first entry of the "mean" item in the get_cca_similarity output
def mean_cca_sim(x, y):
    return get_cca_similarity(x, y)["mean"][0]
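
To give an intuition of the scalar that a mean CCA similarity summarises, here is a minimal NumPy sketch of the mean canonical correlation between two representations. It is not the `cca_core` implementation: it uses the textbook QR-then-SVD route, and it assumes inputs shaped (neurons, datapoints) with more datapoints than neurons and full-rank centered data. The function name `mean_canonical_correlation` and the synthetic arrays are hypothetical.

```python
import numpy as np

def mean_canonical_correlation(x, y):
    """Mean canonical correlation of x and y, both shaped (neurons, datapoints).

    Assumption: more datapoints than neurons, and full-rank centered data.
    """
    # center every neuron, then switch to (datapoints, neurons)
    xc = (x - x.mean(axis=1, keepdims=True)).T
    yc = (y - y.mean(axis=1, keepdims=True)).T
    # orthonormal bases of the two column spaces
    qx, _ = np.linalg.qr(xc)
    qy, _ = np.linalg.qr(yc)
    # singular values of qx^T qy are the canonical correlations
    rho = np.linalg.svd(qx.T @ qy, compute_uv=False)
    return float(np.mean(np.clip(rho, 0.0, 1.0)))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 60))                      # hypothetical layer: 4 neurons, 60 datapoints
mixing = rng.normal(size=(4, 4)) + 3 * np.eye(4)  # hypothetical invertible linear map
sim_linear = mean_canonical_correlation(x, mixing @ x)
sim_random = mean_canonical_correlation(x, rng.normal(size=(4, 60)))
print(sim_linear, sim_random)
```

Because CCA is invariant to invertible linear maps, the first value should be numerically 1, while the second compares unrelated random data and should be clearly lower.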

### Preprocess representations

In [11]:
# define a function to get linear kernels
# a linear kernel is just M M^T, where M is a matrix whose rows are datapoints and whose columns are neurons
def get_linear_kernel(matrix):
    # if the matrix is 4-dimensional (e.g. the output of a convolutional layer), reshape it to 2D first
    if len(matrix.shape) == 4:
        matrix_2d = preprocessing.reshape_4d_tensor(matrix)
        return matrix_2d @ matrix_2d.T
    return matrix @ matrix.T
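
The shape logic of the linear kernel can be checked on synthetic data with the NumPy sketch below. It assumes that the first axis indexes datapoints and that all remaining axes are flattened into one feature axis; whether `preprocessing.reshape_4d_tensor` flattens in exactly this way is an assumption. The name `linear_kernel` and the random arrays are hypothetical.

```python
import numpy as np

def linear_kernel(activations):
    """Return the n x n Gram matrix M @ M.T for activations with n datapoints."""
    # Assumption: axis 0 indexes datapoints; every other axis
    # (e.g. channels, height, width) is flattened into one feature axis.
    m = np.asarray(activations, dtype=np.float64)
    m = m.reshape(m.shape[0], -1)
    return m @ m.T

rng = np.random.default_rng(0)
conv_acts = rng.normal(size=(8, 3, 4, 4))  # hypothetical conv-layer output, 8 datapoints
fc_acts = rng.normal(size=(8, 10))         # hypothetical dense-layer output, 8 datapoints
k_conv = linear_kernel(conv_acts)
k_fc = linear_kernel(fc_acts)
print(k_conv.shape, k_fc.shape)
```

Both kernels are square with one row and one column per datapoint, regardless of how many neurons the layer has. This is what makes layers of different widths directly comparable.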

In [13]:
kernels_1 = [get_linear_kernel(r) for r in net_repr1]
kernels_2 = [get_linear_kernel(r) for r in net_repr2]


In [15]:
kernels_1[0].shape

torch.Size([500, 500])

In [16]:
def preprocess_pipeline_cca(tensor, var_kept = .99):
    if len(tensor.shape) == 4:
        tensor = preprocessing.reshape_4d_tensor(tensor, True)
    tensor = preprocessing.svd_reduction(tensor, var_kept)
    return tensor.T
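
As a rough illustration of what a variance-based SVD reduction does, the NumPy sketch below keeps the smallest number of singular directions whose cumulative squared singular values reach `var_kept`. It is written under the assumption that the input is a (datapoints, neurons) matrix, and it is not necessarily how `preprocessing.svd_reduction` is implemented. The name `svd_reduce` and the synthetic data are hypothetical.

```python
import numpy as np

def svd_reduce(matrix, var_kept=0.99):
    """Project a (datapoints, neurons) matrix onto its top singular directions.

    Keeps the fewest directions whose share of total variance is >= var_kept.
    """
    x = np.asarray(matrix, dtype=np.float64)
    x = x - x.mean(axis=0, keepdims=True)        # center every neuron
    u, s, _ = np.linalg.svd(x, full_matrices=False)
    explained = np.cumsum(s ** 2) / np.sum(s ** 2)
    # first index whose cumulative share reaches var_kept, as a count
    k = int(np.searchsorted(explained, var_kept)) + 1
    k = min(k, len(s))                            # guard against round-off when var_kept = 1
    return u[:, :k] * s[:k]                       # (datapoints, k)

rng = np.random.default_rng(0)
latent = rng.normal(size=(100, 2))                # two hypothetical underlying directions
basis, _ = np.linalg.qr(rng.normal(size=(20, 2))) # orthonormal embedding into 20 neurons
data = latent @ basis.T                           # (100, 20) matrix of rank 2
reduced = svd_reduce(data, var_kept=0.99)
print(reduced.shape)
```

The sketch returns a (datapoints, components) matrix; the pipeline above transposes the result of its own reduction so that rows are neurons and columns are datapoints.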

In [17]:
preprocessed_repr1 = [preprocess_pipeline_cca(r) for r in net_repr1]
preprocessed_repr2 = [preprocess_pipeline_cca(r) for r in net_repr2]


### Calculate similarities

In [18]:
# store measurements in a tensor whose dimensions are: metric (0 = mean CCA similarity, 1 = CKA), layer
similarities = torch.zeros([2, len(preprocessed_repr1)])

for l, (r1, r2) in enumerate(zip(preprocessed_repr1, preprocessed_repr2)):
    similarities[0, l] = mean_cca_sim(r1.detach().numpy(), r2.detach().numpy())
for l, (r1, r2) in enumerate(zip(kernels_1, kernels_2)):
    similarities[1, l] = nn_comparison.cka(r1.detach(), r2.detach())

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!


In [19]:
similarities

tensor([[0.7036, 0.7124, 0.6040, 0.7061, 0.7852, 0.7711, 0.9486],
        [0.9557, 0.9576, 0.9430, 0.9324, 0.9384, 0.9489, 0.9330]])
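
For reference, linear CKA can be computed from two Gram matrices as in the NumPy sketch below, following the definition in Kornblith et al.: center both kernels, then normalise their HSIC by the self-HSIC terms. This is a sketch under that definition and may differ in detail from `nn_comparison.cka`. The names `center_gram` and `cka_from_kernels` and the synthetic data are hypothetical.

```python
import numpy as np

def center_gram(k):
    """Double-center a Gram matrix: H K H with H = I - 11^T / n."""
    n = k.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n
    return h @ k @ h

def cka_from_kernels(k, l):
    """CKA between two symmetric Gram matrices built on the same datapoints."""
    kc, lc = center_gram(k), center_gram(l)
    # for symmetric matrices, sum(kc * lc) equals trace(kc @ lc);
    # the (n - 1)^2 factors of HSIC cancel in the ratio
    hsic_kl = np.sum(kc * lc)
    return hsic_kl / np.sqrt(np.sum(kc * kc) * np.sum(lc * lc))

rng = np.random.default_rng(0)
x = rng.normal(size=(30, 5))                     # hypothetical layer: 30 datapoints, 5 neurons
q, _ = np.linalg.qr(rng.normal(size=(5, 5)))     # random orthogonal matrix
y = 3.0 * x @ q                                  # rotated and rescaled copy of x
z = rng.normal(size=(30, 5))                     # unrelated representation
cka_same = cka_from_kernels(x @ x.T, y @ y.T)
cka_diff = cka_from_kernels(x @ x.T, z @ z.T)
print(cka_same, cka_diff)
```

Linear CKA is invariant to orthogonal transformations and isotropic scaling, so the first value should be numerically 1, while the unrelated pair should score lower.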

### Homework

_NB: (§§§) indicates a hard exercise, (§§) a moderately hard exercise_

Reproduce the "Sanity Check for Similarity Indexes" from page 6 of [Similarity of Neural Network Representations Revisited, Kornblith et al.](https://arxiv.org/abs/1905.00414) for the case of Multilayer Perceptrons (MLPs).

1. Start from an MLP with an architecture of your choice.

  a. _(extra 1) The architecture must reach 98% test-set accuracy on average_
    * _Use an appropriate statistical test to check that this threshold is reached_
        
2. (§§) Build a function to extract representations from each layer *after* the application of its activation function
3. Perform a pairwise layer comparison **for each layer in the architecture** for at least 2 parameter sets
    a. Use both CKA and SVCCA
4. (§§§) _(extra 2) Fix the `layer_sim` library so that the gradient of the representations can be retained, and implement CKA with the incorporation of gradient flow. You will need to call `backward` inside the routine that builds the representations, either on the loss or using the additional tricks described in [Similarity of Neural Networks with Gradients](https://arxiv.org/abs/2003.11498)._

    a. _See how this metric compares to *vanilla* CKA_