In [11]:
%reload_ext autoreload
%autoreload 2

import time
from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import v2 as transforms

from matplotlib import pyplot as plt

from networkAlignmentAnalysis.models.registry import get_model
from networkAlignmentAnalysis.datasets import get_dataset
from networkAlignmentAnalysis.experiments.registry import get_experiment
from networkAlignmentAnalysis import utils
from networkAlignmentAnalysis import files

# Prefer the GPU whenever PyTorch reports a usable CUDA device; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('using device: ', DEVICE)

using device:  cuda


In [8]:
# TODO
# 1.1. include additional AlignmentModel methods stored in extra class in base model

# I should also break apart the measure_eigenfeatures method for better parameter handling
# and also to allow use of components of the method in other contexts

# 4. Rewrite existing analysis pipelines
# 5. SLURM!!!!

# Figure out why convolutional alignment measurement is slow...

# 2. Make CIFAR class -- prioritize

# Basic alignment_comparison Analyses (or maybe for alignment_stats):
# - compare initial to final alignment...
# - compare initial alignment to delta weight norm...
# - observe alignment of delta weight
# - compare alignment to outgoing delta weight norm!

# Eigenfeature analyses:
# - Measure beta_adversarial (figure out how adversarial examples map onto eigenvectors)
# - Determine the contribution of each eigenfeature to performance with an eigenvector dropout experiment
# ------ start by just looking at amplitude of activity on each eigenvector within each layer

# alignmentShaping.ipynb has an adversarial experiment worth looking at

In [133]:
# cProfile is only needed by the profiling experiments further down the notebook.
import cProfile

# Build the 'CNN2P2' model on the selected device and the MNIST dataset wrapper.
# The network itself is handed to the dataset as `transform_parameters`.
net = get_model('CNN2P2', build=True, each_stride=True).to(DEVICE)
dataset = get_dataset('MNIST', build=True, transform_parameters=net, device=DEVICE)

# Grab a single test batch; `unwrap_batch` splits it into inputs and targets.
batch = next(iter(dataset.test_loader))
images, labels = dataset.unwrap_batch(batch)

# Perform forward pass
# NOTE(review): store_hidden=True presumably caches the hidden activations that
# later cells read back via get_activations(precomputed=True) -- confirm.
output = net(images, store_hidden=True)

In [134]:
def measure_alignment(self, x, swap_index=(0, 1, 2, 3), precomputed=False, method='alignment'):
    """
    Measure alignment for a selection of layers and time each measurement.

    Args:
        self: network object exposing ``get_activations``, ``get_alignment_layers``
            and ``get_alignment_metaparameters`` (passed explicitly because this
            is a free function, not a bound method).
        x: network input; used as the pre-layer activation of the first layer and
            forwarded to ``get_activations``.
        swap_index: iterable of layer indices to measure, in the order in which
            they should be measured. Defaults to the first four layers.
        precomputed: forwarded to ``get_activations``.
        method: forwarded to each layer's ``'alignment_method'`` callable.

    Returns:
        ``(alignment, duration)``: two lists with one entry per index in
        ``swap_index`` -- the value returned by the layer's alignment method and
        the wall-clock seconds that call took.

    Raises:
        IndexError: if an entry of ``swap_index`` is out of range.
    """
    # Pre-layer activations start with input (x) and ignore output
    pre_layer = [x, *self.get_activations(x=x, precomputed=precomputed)[:-1]]
    # Query the getters once instead of once per selected index.
    all_layers = self.get_alignment_layers()
    all_metaparams = self.get_alignment_metaparameters()

    alignment = []
    duration = []
    for i in swap_index:
        activation, layer, metaprms = pre_layer[i], all_layers[i], all_metaparams[i]

        # CUDA kernels are launched asynchronously: without synchronising before
        # and after the call, the timer would capture launch overhead (or work
        # queued by a previous layer) rather than this layer's computation.
        on_gpu = getattr(activation, 'is_cuda', False)
        if on_gpu:
            torch.cuda.synchronize()
        # perf_counter is monotonic and has a finer resolution than time.time().
        start = time.perf_counter()
        alignment.append(metaprms['alignment_method'](activation, layer, method=method))
        if on_gpu:
            torch.cuda.synchronize()
        duration.append(time.perf_counter() - start)
    return alignment, duration

# Single call with the default layer selection; the (alignment, duration) result is discarded.
_ = measure_alignment(net, images, precomputed=True)

In [139]:
# Layers to time, plus one running total (in seconds) per selected layer.
swap_index = [0, 1, 2, 3]
duration = [0 for _ in swap_index]

for batch in tqdm(dataset.test_loader):
    images, labels = dataset.unwrap_batch(batch)
    # Run the network first so the alignment measurement can use precomputed=True.
    output = net(images, store_hidden=True)
    align, cduration = measure_alignment(net, images, swap_index=swap_index, precomputed=True)
    # Add this batch's per-layer timings onto the running totals.
    duration = [total + elapsed for total, elapsed in zip(duration, cduration)]

print('duration:', duration)

100%|██████████| 10/10 [00:04<00:00,  2.22it/s]

duration: [0.24064135551452637, 0.03815197944641113]





In [78]:
def test_measure_alignment(self, x, swap_index, precomputed=False, method='alignment', num_checks=10):
    # Pre-layer activations start with input (x) and ignore output
    activations = [x, *self.get_activations(x=x, precomputed=precomputed)[:-1]]
    alignment = []
    activations = [activations[i] for i in swap_index]
    layers = [self.get_alignment_layers()[i] for i in swap_index]
    metaparams = [self.get_alignment_metaparameters()[i] for i in swap_index]
    zipped = zip(activations, layers, metaparams)
    for idx, (activation, layer, metaprms) in enumerate(zipped):
        t = time.time()
        # _ = metaprms['alignment_method'](activation, layer, method=method)
        alignment.append(metaprms['alignment_method'](activation, layer, method=method))
        print("idx:", swap_index[idx], time.time() - t)
        
        # if idx==2:
        #     profiler = cProfile.Profile()
        #     profiler.enable()
        #     for _ in range(num_checks):
        #         _ = metaprms['alignment_method'](activation, layer, method=method)
        #     profiler.disable()
    return alignment #, profiler

def testidx_measure_alignment(self, x, swap_index, precomputed=False, method='alignment', num_checks=10):
    # Pre-layer activations start with input (x) and ignore output
    activations = [x, *self.get_activations(x=x, precomputed=precomputed)[:-1]]
    alignment = []
    activations = [activations[i] for i in swap_index]
    layers = [self.get_alignment_layers()[i] for i in swap_index]
    metaparams = [self.get_alignment_metaparameters()[i] for i in swap_index]
    for idx in range(len(activations)):
        t = time.time()
        # _ = metaprms['alignment_method'](activation, layer, method=method)
        alignment.append(metaparams[idx]['alignment_method'](activations[idx], layers[idx], method=method))
        print("idx:", swap_index[idx], time.time() - t)
        
        # if idx==2:
        #     profiler = cProfile.Profile()
        #     profiler.enable()
        #     for _ in range(num_checks):
        #         _ = metaparams[idx]['alignment_method'](activations[idx], layers[idx], method=method)
        #     profiler.disable()
    return alignment #, profiler

In [29]:
# The previous version of this cell called measure_alignment(..., num_checks=1000)
# and unpacked its result as (out, profiler). measure_alignment accepts no
# `num_checks` argument and returns (alignment, duration), so the cell failed on a
# fresh kernel. The profiling run is rebuilt explicitly here instead.
PROFILE_LAYER = 2   # index of the layer whose alignment computation is profiled
NUM_CHECKS = 1000   # number of repeated calls collected by the profiler

print("each_stride=", net.each_stride)

# Per-layer wall-clock timings (prints one "idx:" line per layer).
out = test_measure_alignment(net, images, [0, 1, 2, 3], precomputed=True, method='alignment')

# Inputs for the profiled layer, assembled the same way measure_alignment does.
profile_inputs = [images, *net.get_activations(x=images, precomputed=True)[:-1]]
profile_layer = net.get_alignment_layers()[PROFILE_LAYER]
profile_method = net.get_alignment_metaparameters()[PROFILE_LAYER]['alignment_method']

# Profile repeated calls of that single layer's alignment method.
profiler = cProfile.Profile()
profiler.enable()
for _ in range(NUM_CHECKS):
    profile_method(profile_inputs[PROFILE_LAYER], profile_layer, method='alignment')
profiler.disable()
profiler.print_stats(sort='cumulative')

each_stride= False
idx: 0 0.035366058349609375
idx: 1 0.06444048881530762
idx: 2 0.24110698699951172
idx: 3 0.0
         8001 function calls in 0.232 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.003    0.000    0.232    0.000 utils.py:92(alignment_linear)
     1000    0.027    0.000    0.228    0.000 utils.py:53(alignment)
     1000    0.174    0.000    0.174    0.000 {built-in method torch.cov}
     2000    0.012    0.000    0.012    0.000 {built-in method torch.sum}
     1000    0.009    0.000    0.009    0.000 {built-in method torch.matmul}
     1000    0.007    0.000    0.007    0.000 {built-in method torch.trace}
     1000    0.001    0.000    0.001    0.000 module.py:1682(__getattr__)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}




In [None]:
# Collect the input plus the stored activations, and the layers / metaparameters
# used for alignment measurement.
# NOTE(review): unlike measure_alignment, the last activation is not dropped here
# (no [:-1]); this makes no difference for idx = 0 but matters if idx is changed.
activations = [images, *net.get_activations(x=images, precomputed=True)]
layers = net.get_alignment_layers()
metaprms = net.get_alignment_metaparameters()

# Run both implementations on the first layer's input and weight matrix.
# NOTE(review): alignment_slow / alignment_fast are not defined in the visible
# part of this notebook -- confirm they are in scope before running this cell.
idx = 0
x = alignment_slow(activations[idx], layers[idx].weight.data)
y = alignment_fast(activations[idx], layers[idx].weight.data)

# Time each implementation on the same inputs.
%timeit _ = alignment_slow(activations[idx], layers[idx].weight.data)
%timeit _ = alignment_fast(activations[idx], layers[idx].weight.data)

# Check that both implementations agree to within torch.allclose's default tolerances.
print(torch.allclose(x, y))

442 µs ± 1.87 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
443 µs ± 3.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
True
