*Note*: this notebook is only an example of how to use the `layer_sim` library. It does not focus on the performance of the provided NNs, nor on an accurate analysis of the resulting similarities.
The code runs on MNIST so that anyone can reproduce the results in a reasonable amount of time on a mid-range machine without a CUDA-capable GPU.

In [1]:
import torch
import sys
import os

sys.path.append("..") # so we can import the layer_sim library

from layer_sim import networks
from layer_sim import datasets
from layer_sim import nn_comparison
from layer_sim import preprocessing
from layer_sim.train import train_net, test_net
from layer_sim.pruning.IMP import imp_lrr

## Dataset and NN preparation

In [2]:
train_batch = 128
test_batch = 128
trainloader, testloader = datasets.MNIST("../data", train_batch, test_batch, num_workers=4)
net = networks.LeNet5(num_classes=10)

### TODO: insert image of LeNet5

## NN training and testing

In [3]:
lr_init = 0.1
weight_decay = 0.0001
momentum = 0.9
epochs = 15
device = "cpu"

In [4]:
lr_annealing_rate = 10
lr_annealing_schedule = [10, 12]
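The two parameters above configure a step decay of the learning rate. The following is a minimal sketch of that schedule. It assumes 0-indexed epochs and that the LR is divided by `lr_annealing_rate` each time the epoch index passes a milestone in `lr_annealing_schedule`. The actual logic lives inside `train_net` and is not shown here, and `annealed_lr` is a hypothetical helper. Under this assumption the sketch reproduces the switch from 0.1 to 0.01 between displayed epochs 11 and 12 in the training log.

```python
def annealed_lr(epoch_idx, lr_init=0.1, annealing_factor=10, milestones=(10, 12)):
    # Assumed step-decay semantics: divide the initial LR by `annealing_factor`
    # once for every milestone that the 0-indexed epoch has already passed.
    n_decays = sum(1 for m in milestones if epoch_idx > m)
    return lr_init / annealing_factor ** n_decays


# schedule for the 15 epochs used in this notebook (printed 1-indexed, as in the log)
for e in range(15):
    print(f"epoch {e + 1:2d}: lr = {annealed_lr(e):g}")
```

PyTorch's `torch.optim.lr_scheduler.MultiStepLR` (with `gamma=1 / lr_annealing_rate`) provides a similar built-in step decay. The exact epoch at which the LR changes depends on when `scheduler.step()` is called.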

In [5]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr_init, momentum=momentum, weight_decay=weight_decay)

In [6]:
tr_loss, tr_perf = train_net(net, epochs, criterion, optimizer, trainloader, device=device, lr_annealing_factor=lr_annealing_rate, epochs_annealing=lr_annealing_schedule)

===> Epoch 1/15 ### Loss 0.3238158259967963 ### Performance 0.8953333333333333
===> Epoch 2/15 ### Loss 0.06931836195240418 ### Performance 0.9788333333333333
===> Epoch 3/15 ### Loss 0.04757271823982398 ### Performance 0.9855833333333334
===> Epoch 4/15 ### Loss 0.04082236483022571 ### Performance 0.9876833333333334
===> Epoch 5/15 ### Loss 0.032276718960702416 ### Performance 0.9903
===> Epoch 6/15 ### Loss 0.028784222680702805 ### Performance 0.9910666666666667
===> Epoch 7/15 ### Loss 0.02341511012427509 ### Performance 0.9927666666666667
===> Epoch 8/15 ### Loss 0.024698692759002248 ### Performance 0.9923833333333333
===> Epoch 9/15 ### Loss 0.023087299082769703 ### Performance 0.9929666666666667
===> Epoch 10/15 ### Loss 0.020946485105218987 ### Performance 0.9932166666666666
===> Epoch 11/15 ### Loss 0.016838115830874693 ### Performance 0.9945833333333334
LR annealed: previous 0.1, current 0.01
===> Epoch 12/15 ### Loss 0.007709789935398536 ### Performance 0.9976
===> Epoch 13/1

In [8]:
te_loss, te_perf = test_net(net, testloader, criterion, device=device)

===> TEST ### Loss 0.02399575710296631 ### Performance 0.9939


#### Optional: save the NN

In the next cell we store the `state_dict` together with some auxiliary data (train/test loss and performance) in a dictionary called `save_dict`. We save this dict under `save_root`, a directory that we will also use as the base path for the IMP checkpoints.

The `save_dict` mimics the structure of the IMP checkpoints (minus the pruning mask, which does not exist for the complete model).

In [9]:
save_dict = {
    "train_loss": tr_loss,
    "train_perf": tr_perf,
    "test_loss": te_loss,
    "test_perf": te_perf,
    "parameters": net.state_dict()
}

In [10]:
save_root = "../models/LeNet5"
save_name = "complete_net.pt"

In [11]:
os.makedirs(save_root, exist_ok=True)  # create the directory if it does not exist yet
torch.save(save_dict, os.path.join(save_root, save_name))

In [7]:
# load the NN
net.load_state_dict(torch.load(os.path.join(save_root, save_name))["parameters"])

<All keys matched successfully>

## Store the model's representations

In [12]:
# get a dataloader for extracting representations, with train=False
reprloader, _ = datasets.MNIST("../data", 128, train=False, num_workers=4)
datapoints_repr = 500
layers_to_hook = (torch.nn.ReLU, torch.nn.AvgPool2d)
compl_repr = net.extract_network_representation(reprloader, limit_datapoints=datapoints_repr, layer_types_to_hook=layers_to_hook, device="cpu", retain_grad=True, loss_fn=criterion)

## Compare representations

In [35]:
# Load SVCCA
SVCCA_ROOT = "../../svcca"
sys.path.append(os.path.expanduser(SVCCA_ROOT))
from cca_core import get_cca_similarity

# helper returning a scalar: the mean SVCCA correlation (first entry of the "mean" tuple)
def mean_cca_sim(x, y):
    return get_cca_similarity(x, y)["mean"][0]
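`get_cca_similarity` comes from Google's SVCCA repository. It expects activations shaped `(neurons, datapoints)` and returns a dictionary whose `"mean"` entry holds the mean canonical correlation. To illustrate what this scalar measures, here is a minimal NumPy sketch of mean CCA computed via QR and SVD. It is not the svcca implementation: it skips the epsilon regularization and thresholding done by `cca_core`, so values may differ slightly. `mean_cca_similarity_sketch` is a hypothetical name.

```python
import numpy as np


def mean_cca_similarity_sketch(acts1, acts2):
    """Mean canonical correlation between two activation matrices.

    acts1: (neurons1, datapoints), acts2: (neurons2, datapoints), the layout
    used by svcca's get_cca_similarity. Assumes more datapoints than neurons
    and full-rank activations; no regularization or thresholding is applied.
    """
    # center every neuron over the datapoints
    x = acts1 - acts1.mean(axis=1, keepdims=True)
    y = acts2 - acts2.mean(axis=1, keepdims=True)
    # orthonormal bases of the subspaces spanned by each layer's neurons
    qx, _ = np.linalg.qr(x.T)
    qy, _ = np.linalg.qr(y.T)
    # the canonical correlations are the singular values of qx^T qy
    rho = np.linalg.svd(qx.T @ qy, compute_uv=False)
    return float(np.mean(rho))


rng = np.random.default_rng(0)
a = rng.standard_normal((20, 500))     # 20 neurons, 500 datapoints
b = rng.standard_normal((20, 20)) @ a  # invertible linear mix of the same neurons
c = rng.standard_normal((20, 500))     # unrelated activations
print(mean_cca_similarity_sketch(a, b))  # close to 1.0
print(mean_cca_similarity_sketch(a, c))  # much lower
```

Because CCA is invariant to invertible linear transformations, a linear mix of the same neurons scores close to 1, while independent random activations score much lower.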

### Preprocess representations

#### Get kernels of representations (for CKA & NBS)

In [36]:
# define a function to get linear kernels
# the linear kernel is just M M^T, where M is a matrix whose rows are datapoints and whose columns are neurons
def get_linear_kernel(matrix, grad=False):
    # if the matrix is 4D (conv activations), flatten it into a 2D matrix first
    if len(matrix.shape) == 4:
        matrix_2d = preprocessing.reshape_4d_tensor(matrix)
        # reshape the gradient stored on the original tensor: the reshaped view has no .grad of its own
        grad_2d = preprocessing.reshape_4d_tensor(matrix.grad) if grad else None
    else:
        matrix_2d = matrix
        grad_2d = matrix.grad if grad else None
    ker = matrix_2d @ matrix_2d.T
    if grad:
        # gradient-weighted kernel: elementwise product of the activation and gradient kernels
        ker = ker * (grad_2d @ grad_2d.T)
    return ker
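For reference, here is the same kernel computed with NumPy. It assumes 4D activations are laid out as `(datapoints, channels, height, width)` and flattened to one row per datapoint; the layout produced by `preprocessing.reshape_4d_tensor` may differ. The optional gradient term is combined with the activation kernel through an elementwise product, mirroring `get_linear_kernel` above. `linear_kernel_sketch` is a hypothetical helper that takes the gradients as an explicit argument instead of reading `.grad`.

```python
import numpy as np


def linear_kernel_sketch(acts, grads=None):
    """Linear kernel K = M M^T with one row of M per datapoint.

    acts: array of shape (N, ...); e.g. conv activations (N, C, H, W) are
    flattened to (N, C*H*W) (an assumed layout).
    grads: optional array with the same shape as acts.
    """
    m = acts.reshape(acts.shape[0], -1)
    ker = m @ m.T  # (N, N) Gram matrix between datapoints
    if grads is not None:
        g = grads.reshape(grads.shape[0], -1)
        ker = ker * (g @ g.T)  # elementwise (Hadamard) product with the gradient kernel
    return ker


rng = np.random.default_rng(0)
acts = rng.standard_normal((8, 6, 4, 4))   # 8 datapoints, 6 channels, 4x4 feature maps
grads = rng.standard_normal((8, 6, 4, 4))
k = linear_kernel_sketch(acts)
kg = linear_kernel_sketch(acts, grads)
print(k.shape, kg.shape)  # (8, 8) (8, 8)
```

By the Schur product theorem, the elementwise product of two positive semidefinite kernels is still positive semidefinite, so the gradient-weighted kernel remains a valid kernel for CKA.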

In [37]:
kernels_compl1 = [get_linear_kernel(r) for r in compl_repr1]
kernels_compl2 = [get_linear_kernel(r) for r in compl_repr2]

In [44]:
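def preprocess_pipeline_cca(tensor, var_kept=0.99):
    # flatten 4D conv activations into a 2D matrix
    if len(tensor.shape) == 4:
        tensor = preprocessing.reshape_4d_tensor(tensor, True)
    # keep only the singular directions explaining `var_kept` of the variance (the "SV" step of SVCCA)
    tensor = preprocessing.svd_reduction(tensor, var_kept)
    # SVCCA's get_cca_similarity expects (neurons, datapoints)
    return tensor.T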
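The `svd_reduction` step keeps only the leading singular directions of a representation. Below is a minimal NumPy sketch of what such a reduction could look like. It assumes rows are datapoints, columns are neurons, and `var_kept` is the fraction of variance (measured by squared singular values) to retain. The real `preprocessing.svd_reduction` is not shown here and may differ, for example in whether it centers the data. `svd_reduction_sketch` is a hypothetical name.

```python
import numpy as np


def svd_reduction_sketch(x, var_kept=0.99):
    """Project x (datapoints x neurons) onto its top singular directions.

    Keeps the smallest number k of directions whose squared singular values
    account for at least `var_kept` of the total.
    """
    x = x - x.mean(axis=0, keepdims=True)  # center each neuron
    u, s, _ = np.linalg.svd(x, full_matrices=False)
    explained = np.cumsum(s ** 2) / np.sum(s ** 2)  # cumulative explained variance
    k = min(int(np.searchsorted(explained, var_kept)) + 1, len(s))
    return u[:, :k] * s[:k]  # (datapoints, k): coordinates along the top-k directions


rng = np.random.default_rng(0)
signal = rng.standard_normal((500, 3)) @ rng.standard_normal((3, 50))  # rank-3 structure
x = signal + 1e-3 * rng.standard_normal((500, 50))                     # plus tiny noise
reduced = svd_reduction_sketch(x, var_kept=0.99)
print(reduced.shape)  # (500, 3)
```

In `preprocess_pipeline_cca` the reduced matrix is then transposed to `(neurons, datapoints)`, the shape SVCCA expects.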

In [42]:
compl_repr1 = [preprocess_pipeline_cca(r) for r in compl_repr1]
compl_repr2 = [preprocess_pipeline_cca(r) for r in compl_repr2]

### Calculate similarities
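The CKA values in the next cell are computed by `nn_comparison.cka` from the linear kernels built earlier. To illustrate the quantity, here is a minimal NumPy sketch of kernel CKA. It follows the standard definition of Kornblith et al. (2019), CKA(K, L) = HSIC(K, L) / sqrt(HSIC(K, K) HSIC(L, L)) with double-centered kernels. Whether `layer_sim` uses exactly this estimator is an assumption. `centered` and `cka_from_kernels` are hypothetical helpers.

```python
import numpy as np


def centered(k):
    """Double-center a kernel matrix: H K H with H = I - 11^T / n."""
    n = k.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n
    return h @ k @ h


def cka_from_kernels(k, l):
    """CKA from two (n, n) kernels; HSIC is taken as <HKH, HLH>_F (constants cancel)."""
    kc, lc = centered(k), centered(l)
    hsic_kl = np.sum(kc * lc)
    return float(hsic_kl / np.sqrt(np.sum(kc * kc) * np.sum(lc * lc)))


rng = np.random.default_rng(0)
x = rng.standard_normal((100, 30))                  # 100 datapoints, 30 neurons
q, _ = np.linalg.qr(rng.standard_normal((30, 30)))  # random orthogonal matrix
y = 2.0 * x @ q                                     # rotated and rescaled copy of x
z = rng.standard_normal((100, 30))                  # unrelated representation
print(cka_from_kernels(x @ x.T, y @ y.T))  # 1.0 up to rounding
print(cka_from_kernels(x @ x.T, z @ z.T))  # much lower than 1
```

Linear CKA is invariant to orthogonal transformations and isotropic scaling of the representations, which is why the rotated and rescaled copy scores 1.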

In [51]:
# store measurements in tensor whose dimensions are: metric, iteration, layer
similarities = torch.zeros([2, len(compl_repr)])

for l, (layer1, layer2) in enumerate(zip(compl_repr1, compl_repr2)):
    similarities[0, l] = mean_cca_sim(layer1.detach().numpy(), layer2.detach().numpy())
for l, (layer1, layer2) in enumerate(zip(kernels_compl1, kernels_compl2)):
    similarities[1, l] = nn_comparison.cka(layer1.detach(), layer2.detach())

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!

In [53]:
similarities

tensor([[[0.9997, 0.9998, 0.9985, 0.9990, 0.9888, 0.9858, 0.9992],
         [0.9980, 0.9985, 0.9889, 0.9914, 0.9331, 0.9302, 0.9944]],

        [[0.4692, 0.3582, 0.7535, 0.6547, 0.8563, 0.7598, 0.9681],
         [0.4725, 0.3637, 0.7509, 0.6518, 0.8333, 0.7490, 0.9577]]])