In [1]:
%reload_ext autoreload
%autoreload 2

%matplotlib qt

import time
from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import v2 as transforms

from matplotlib import pyplot as plt

from networkAlignmentAnalysis.models.registry import get_model
from networkAlignmentAnalysis.datasets import get_dataset
from networkAlignmentAnalysis.experiments.registry import get_experiment
from networkAlignmentAnalysis import utils
from networkAlignmentAnalysis import files
from networkAlignmentAnalysis import train

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('using device: ', DEVICE)

using device:  cuda


In [3]:
# TODO
# 1.1. include additional AlignmentModel methods stored in extra class in base model

# 4. Rewrite existing analysis pipelines
# 5. SLURM!!!!

# Figure out why convolutional alignment measurement is slow...
# still working out whether it's possible to speed up measure_alignment for convolutional layers

# Basic alignment_comparison Analyses (or maybe for alignment_stats):
# - compare initial to final alignment...
# - compare initial alignment to delta weight norm...
# - observe alignment of delta weight
# - compare alignment to outgoing delta weight norm!

# Eigenfeature analyses:
# done: - start by just looking at amplitude of activity on each eigenvector within each layer
# - Determine contribution of each eigenfeature on performance with an eigenvector dropout experiment
# - Measure beta_adversarial (figure out how adversarial examples map onto eigenvectors)

# alignmentShaping.ipynb has an adversarial experiment worth looking at

# Consider Valentin's idea about measuring an error threshold given signal and noise for a given level of alignment
# e.g. plot a 2d heatmap comparing the noise amplitude and the average alignment
# and then think about how to apply this to network design...

In [2]:
# Build the network and dataset, then train for 100 epochs without alignment tracking.
model_name = 'CNN2P2'
dataset_name = 'MNIST'

net = get_model(model_name, build=True, dataset=dataset_name).to(DEVICE)
dataset = get_dataset(dataset_name, build=True, transform_parameters=net, device=DEVICE)

# The optimizer has to be created in this cell: train.train() below receives it,
# and relying on a leftover kernel variable raises NameError after a restart.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
results = train.train([net], [optimizer], dataset, num_epochs=100, alignment=False)

# beta, eigenvalue, eigenvector = net.measure_eigenfeatures(dataset.test_loader)
# dropout_results = train.eigenvector_dropout([net], dataset, [eigenvalue], [eigenvector], train_set=False, by_layer=True)

# plt.close('all')
# plt.plot(results['accuracy'])
# plt.show()

In [None]:
# places that use 'unfold' or should use it
# -- done -- AN.get_alignment_weights()
# -- done -- AN.forward_eigenvector_dropout() # --- doesn't use it but it should!!! ---
# -- done -- AN.measure_eigenfeatures()
# AN.measure_class_eigenfeatures()

In [47]:
# Sizes for the synthetic benchmark tensors. (1024, 800, 196) matches the second
# shape printed from net._preprocess_inputs in this notebook; the commented
# alternative matches the first.
# B, D, S, C = 1024, 25, 784, 32
B = 1024
D = 800
S = 196
C = 64

# Standard-normal samples, drawn on the CPU and then moved to DEVICE.
input_shape = (B, D, S)
weight_shape = (C, D)
input = torch.normal(0, 1, input_shape).to(DEVICE)
weight = torch.normal(0, 1, weight_shape).to(DEVICE)

def get_align_usual(input, weight):
    """Alignment computed one slice at a time along the last axis of ``input``.

    For every index ``i`` of the last axis, ``utils.alignment`` is called on
    ``input[:, :, i]`` and ``weight``; the per-slice results are stacked along
    dim 1 and combined by ``utils.weighted_average`` (with ``ignore_nan=True``),
    weighted by the per-slice variance.

    The number of slices is read from ``input`` itself rather than from a
    notebook-level global, so the function works for any 3-D input.

    Args:
        input: 3-D tensor; assumed (batch, features, strides) -- TODO confirm.
        weight: 2-D tensor passed straight to ``utils.alignment``.

    Returns:
        Whatever ``utils.weighted_average`` returns for the stacked alignments.
    """
    # variance over dim 1, then averaged over dim 0 -> one weight per slice
    var_stride = torch.mean(torch.var(input, dim=1), dim=0)
    num_strides = input.size(2)
    align_stride = torch.stack(
        [utils.alignment(input[:, :, i], weight) for i in range(num_strides)],
        dim=1,
    )
    return utils.weighted_average(align_stride, var_stride.view(1, -1), 1, ignore_nan=True)
    
def get_align_new(input, weight):
    """Batched variant of the alignment computation (no Python loop over slices).

    Calls ``utils.batch_cov`` on ``input`` with its first and last axes swapped,
    forms ``w @ cov @ w / (w . w)`` for every row ``w`` of ``weight``, divides by
    the trace of each matrix returned by ``batch_cov``, and combines the result
    with ``utils.weighted_average`` along dim 0 (``ignore_nan=True``), weighted
    by the per-slice variance.

    Args:
        input: 3-D tensor; assumed (batch, features, strides) -- TODO confirm.
        weight: 2-D tensor whose rows are the weight vectors.

    Returns:
        Whatever ``utils.weighted_average`` returns for the per-slice ratios.
    """
    # variance over dim 1, then averaged over dim 0 -> one weight per slice
    stride_variance = torch.var(input, dim=1).mean(dim=0)

    # presumably one covariance matrix per slice of the last axis -- verify in utils.batch_cov
    stride_cov = utils.batch_cov(input.transpose(0, 2))

    # quadratic form of each weight row, normalized by its squared norm
    projected = torch.matmul(weight, stride_cov)
    quadratic_form = torch.sum(projected * weight, dim=2)
    weight_sq_norm = torch.sum(weight * weight, dim=1)
    rayleigh = quadratic_form / weight_sq_norm

    # divide by the trace of each matrix (kept as a column for broadcasting)
    trace = torch.diagonal(stride_cov, dim1=1, dim2=2).sum(1, keepdim=True)
    proportion = rayleigh / trace

    return utils.weighted_average(proportion, stride_variance.view(-1, 1), 0, ignore_nan=True)

au = get_align_usual(input, weight)
an = get_align_new(input, weight)

%timeit _ = get_align_usual(input, weight)
%timeit _ = get_align_new(input, weight)

print(torch.allclose(au, an))


# convert batch_cov to allow for batch_corr too (and make it "smart" with zero-var handling)
# integrate batched alignment into pipeline
# check all convolutional code!!!


164 ms ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
188 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True


In [6]:
# Take one batch from the test loader and split it into images and labels.
images, labels = dataset.unwrap_batch(next(iter(dataset.test_loader)))

# Collect the inputs to each layer, preprocess them, and move the results to the CPU.
inputs = net.get_layer_inputs(images, precomputed=False)
processed = [layer_input.cpu() for layer_input in net._preprocess_inputs(inputs)]

# Print the shape of every preprocessed layer input.
for layer_input in processed:
    print(layer_input.shape)

torch.Size([1024, 25, 784])
torch.Size([1024, 800, 196])
torch.Size([1024, 3136])
torch.Size([1024, 128])


In [34]:
# integrate the changes in utils (e.g. smart_pca) into _measure_layer_eigenfeatures

In [15]:
# Rebuild the network and dataset, this time with a loader batch size of 128.
# Training is left commented out, so `net` is not trained by this cell.
model_name = 'CNN2P2'
dataset_name = 'MNIST'

net = get_model(model_name, build=True, dataset=dataset_name).to(DEVICE)
dataset = get_dataset(
    dataset_name,
    build=True,
    transform_parameters=net,
    loader_parameters={'batch_size': 128},
    device=DEVICE,
)

# optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
# results = train.train([net], [optimizer], dataset, num_epochs=10, alignment=True)

# plt.close('all')
# plt.plot(results['accuracy'])
# plt.show()

In [11]:

# Plot the stored alignment for one layer.
# NOTE(review): reads results['alignment'], so `results` must come from a
# train.train(..., alignment=True) run -- confirm which cell produced it.
layer_idx = 0
lalign = results['alignment'][layer_idx][0].T

# x holds the column indices of lalign, repeated once per row of lalign.
num_rows = lalign.size(0)
num_cols = lalign.size(1)
x = torch.arange(0, num_cols).view(1, -1).expand(num_rows, -1)

plt.close('all')
# plt.scatter(x.flatten(), lalign.flatten(), s=3, c=('k', 0.1))
# One thin, semi-transparent black line per row of lalign.
plt.plot(x[0], lalign.T, linewidth=1, color=('k', 0.4))
plt.show()

In [16]:
# Measure eigenfeatures on the test loader (the recorded run took about 2.5 minutes).
eigenfeatures = net.measure_eigenfeatures(dataset.test_loader, with_updates=True)
beta, evals, evecs = eigenfeatures

100%|██████████| 79/79 [00:03<00:00, 23.06it/s]
1it [02:30, 150.37s/it]

In [None]:
# Grab the first batch from the test loader and split it into images and labels.
test_batch = next(iter(dataset.test_loader))
images, labels = dataset.unwrap_batch(test_batch)

In [None]:
# One single-element index list ([0]) per alignment layer.
# presumably selects eigenvector 0 in every layer -- confirm against forward_eigenvector_dropout
num_layers = net.num_alignment_layers()
dropout_indices = [[0] for _ in range(num_layers)]
layer_indices = net.get_alignment_layer_indices()
out = net.forward_eigenvector_dropout(images, evals, evecs, dropout_indices, layer_indices)