In [1]:
%reload_ext autoreload
%autoreload 2

import time
from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import v2 as transforms

from matplotlib import pyplot as plt

from networkAlignmentAnalysis.models.registry import get_model
from networkAlignmentAnalysis.datasets import get_dataset
from networkAlignmentAnalysis.experiments.registry import get_experiment
from networkAlignmentAnalysis import utils
from networkAlignmentAnalysis import files
from networkAlignmentAnalysis import train

# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('using device: ', DEVICE)

using device:  cuda


In [3]:
# TODO
# 1.1. include additional AlignmentModel methods stored in extra class in base model

# 4. Rewrite existing analysis pipelines
# 5. SLURM!!!!

# I don't like how by_stride is buried in the layers. Rewrite it.

# Figure out why convolutional alignment measurement is slow...
# still working on if it's possible to speed up measure_alignment for convolutional layers

# Basic alignment_comparison Analyses (or maybe for alignment_stats):
# - compare initial to final alignment...
# - compare initial alignment to delta weight norm...
# - observe alignment of delta weight
# - compare alignment to outgoing delta weight norm!

# Eigenfeature analyses:
# done: - start by just looking at amplitude of activity on each eigenvector within each layer
# - Determine the contribution of each eigenfeature to performance with an eigenvector dropout experiment
# - Measure beta_adversarial (figure out how adversarial examples map onto eigenvectors)

# alignmentShaping.ipynb has an adversarial experiment worth looking at

# Consider Valentin's idea about measuring an error threshold given signal and noise for a given level of alignment
# e.g. plot a 2d heatmap comparing the noise amplitude and the average alignment
# and then think about how to apply this to network design...

In [2]:
# Session configuration: which registered model / dataset to build, and whether
# the (currently commented-out) eigenfeature calls below should work per stride.
model_name, dataset_name = 'CNN2P2', 'MNIST'
by_stride = True

# Build the network first, then move it onto the selected device.
net = get_model(model_name, build=True, dataset=dataset_name)
net = net.to(DEVICE)

# The dataset receives the network as its `transform_parameters` argument.
dataset = get_dataset(
    dataset_name,
    build=True,
    transform_parameters=net,
    device=DEVICE,
)

# optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
# results = train.train([net], [optimizer], dataset, num_epochs=100, alignment=False)

# beta, eigenvalue, eigenvector = net.measure_eigenfeatures(dataset.test_loader, by_stride=by_stride)
# dropout_results = train.eigenvector_dropout([net], dataset, [eigenvalue], [eigenvector], train_set=False, by_stride=by_stride, by_layer=True)

# plt.close('all')
# plt.plot(results['accuracy'])
# plt.show()

In [None]:
# places that use 'unfold' or should use it
# -- done -- AN.get_alignment_weights()
# -- done -- AN.forward_eigenvector.dropout()  (didn't use it originally, but it should)
# -- done -- AN.measure_eigenfeatures()
# AN.measure_class_eigenfeatures()

In [39]:
# ------------------ alignment functions ----------------------
def alignment(input, weight, method='alignment'):
    """
    Proportion of the variance in **input** that is captured by each row of **weight**.

    For every weight vector w (a row of **weight**) this evaluates the Rayleigh quotient
    (w^T M w) / (w^T w), where M is a second-order structure matrix of **input**, and
    then divides by trace(M). With the covariance matrix, trace(M) is the total variance
    of **input**, so the result is the fraction of that variance explained by projecting
    **input** onto w. Typically **input** is the activity of layer L-1 and **weight** is
    the weight matrix of layer L.

    args
    ----
        input: (batch, neurons) torch tensor
            - activity fed into the weight layer
        weight: (num_out, num_in) torch tensor
            - weights applied to **input**; num_in must match the neurons dimension
        method: string, default='alignment'
            - selects the structure matrix M measured from **input**
            - 'alignment' uses the covariance matrix (torch.cov)
            - 'similarity' uses the correlation matrix (utils.smartcorr)

    returns
    -------
        alignment: (num_out, ) torch tensor
            - one value per weight vector

    raises
    ------
        AssertionError if **method** is neither 'alignment' nor 'similarity'
        (ValueError instead when assertions are disabled)
    """
    assert method in ('alignment', 'similarity'), "method must be set to either 'alignment' or 'similarity' (or None, default is alignment)"

    # pick the structure matrix; both helpers expect (neurons, batch), hence the transpose
    if method == 'similarity':
        structure = utils.smartcorr(input.T)
    elif method == 'alignment':
        structure = torch.cov(input.T)
    else:
        # only reachable when assertions are stripped (python -O)
        raise ValueError(f"did not recognize method ({method}), must be 'alignment' or 'similarity'")

    # Rayleigh quotient of every weight vector against the structure matrix
    projected = torch.matmul(weight, structure)
    numerator = (projected * weight).sum(dim=1)
    squared_norm = (weight * weight).sum(dim=1)
    rayleigh = numerator / squared_norm

    # normalise by the trace so the values are proportions of the total
    return rayleigh / torch.trace(structure)

# Synthetic data for benchmarking the two strided-alignment implementations below.
# `input` is created with shape (B, D, S) and `weight` with shape (C, D).
# NOTE(review): the two size presets look like they mirror the unfolded conv-layer
# inputs printed elsewhere in this notebook ([1024, 25, 784] and [1024, 800, 196]) — confirm.
# B, D, S, C = 1024, 25, 784, 32
B, D, S, C = 1024, 800, 196, 64
# standard-normal samples, moved to the selected device
input = torch.normal(0, 1, (B, D, S)).to(DEVICE)
weight = torch.normal(0, 1, (C, D)).to(DEVICE)

def get_align_usual(input, weight):
    """
    Reference (looped) strided alignment: measure alignment separately for every
    index of the last dimension of **input** and combine the results.

    args
    ----
        input: 3D torch tensor, indexed here as input[:, :, i]
            - each slice input[:, :, i] is passed to `alignment` as its (batch, neurons) input
        weight: (num_out, num_in) torch tensor
            - passed unchanged to `alignment` for every slice

    returns
    -------
        the output of utils.weighted_average applied along dim 1 of the stacked
        per-slice alignments (num_out, num_slices), weighted by the per-slice variance
        and called with ignore_nan=True
    """
    # Take the number of slices from the tensor itself rather than from the
    # notebook-global `S`, which other cells rebind to unrelated sizes.
    num_strides = input.size(2)

    # variance over dim=1, averaged over dim=0 -> one weight per slice
    var_stride = torch.mean(torch.var(input, dim=1), dim=0)

    # one alignment vector per slice, stacked as columns
    align_stride = torch.stack([alignment(input[:, :, i], weight) for i in range(num_strides)], dim=1)
    return utils.weighted_average(align_stride, var_stride.view(1, -1), 1, ignore_nan=True)
    
def get_align_new(input, weight):
    """
    Batched strided alignment: same quantity as the looped reference version, but
    the per-slice structure matrices are computed in one call to utils.batch_cov.

    args
    ----
        input: 3D torch tensor (created as (B, D, S) in this notebook)
        weight: 2D torch tensor (created as (C, D) in this notebook)

    returns
    -------
        the output of utils.weighted_average applied along dim 0 of the per-slice
        proportions, weighted by the per-slice variance, with ignore_nan=True
    """
    # variance over dim=1, averaged over dim=0 -> one weight per slice
    stride_variance = torch.var(input, dim=1).mean(dim=0)

    # swap the first and last dims so batch_cov treats each slice as one batch entry
    # (presumably yielding one (D, D) matrix per slice — verify against utils.batch_cov)
    stride_cov = utils.batch_cov(input.transpose(0, 2))

    # Rayleigh quotient of every weight vector against every slice's matrix
    projected = torch.matmul(weight, stride_cov)
    squared_norm = (weight * weight).sum(dim=1)
    rayleigh = (projected * weight).sum(dim=2) / squared_norm

    # normalise each slice by the trace of its own matrix
    total_variance = torch.diagonal(stride_cov, dim1=1, dim2=2).sum(1, keepdim=True)
    proportion = rayleigh / total_variance

    return utils.weighted_average(proportion, stride_variance.view(-1, 1), 0, ignore_nan=True)

# Run both implementations once on the same synthetic tensors so the results can be compared.
au = get_align_usual(input, weight)
an = get_align_new(input, weight)

# Time each implementation (IPython magic; only valid inside a notebook/IPython session).
# NOTE(review): the recorded device is CUDA, where kernels are launched asynchronously;
# without an explicit torch.cuda.synchronize() these timings may not reflect the
# actual compute time — confirm before drawing conclusions.
%timeit _ = get_align_usual(input, weight)
%timeit _ = get_align_new(input, weight)

# Prints True when the looped and batched results agree within torch.allclose's default tolerances.
print(torch.allclose(au, an))


# convert batch_cov to allow for batch_corr too (and make it "smart")
# integrate batched alignment into pipeline


165 ms ± 738 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
189 ms ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True


In [6]:
# Pull a single batch from the test loader and split it into images / labels.
first_batch = next(iter(dataset.test_loader))
images, labels = dataset.unwrap_batch(first_batch)

# Collect the input to each layer, run the network's preprocessing on them,
# and move every resulting tensor to the CPU.
inputs = net.get_layer_inputs(images, precomputed=False)
processed = [layer_input.cpu() for layer_input in net._preprocess_inputs(inputs)]

# Show the shape of every pre-processed layer input.
for i in processed:
    print(i.shape)

torch.Size([1024, 25, 784])
torch.Size([1024, 800, 196])
torch.Size([1024, 3136])
torch.Size([1024, 128])


In [16]:
# Run utils.smart_pca on a copy of the first pre-processed layer input and inspect the result shapes.
input = processed[0].clone()
w, v = utils.smart_pca(input)
# Recorded output: w is [1024, 25] and v is [1024, 25, 25].
# NOTE(review): the leading dimension of the results equals the leading dimension of the
# input (1024); other cells move the last dimension to the front before batched
# covariance — confirm this is the intended batching axis for smart_pca.
print(w.shape, v.shape)

torch.Size([1024, 25]) torch.Size([1024, 25, 25])


In [13]:
# Scratch prototype: eigendecomposition of the covariance measured independently
# at every index of the last dimension of a pre-processed layer input.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so it was
# presumably too slow as written — confirm which step dominates.
centered = True
# work on a copy of the first pre-processed layer input
input = processed[0].clone()

# variance over dim=1, averaged over dim=0 -> one value per index of the last dimension
bvar = torch.mean(torch.var(input, dim=1), dim=0) 

# measuring across each stride independently (for conv layers)
# (the permute reverses the dimension order before calling batch_cov)
bcov = utils.batch_cov(input.permute((2, 1, 0)), centered=centered)
# flag vector: set to 1 below for entries whose covariance has infinite condition number
cond_inf = torch.zeros((len(bvar)))
# NOTE(review): var_zero is allocated but never used in this cell
var_zero = torch.zeros((len(bvar)))

# eigh will fail if condition number is too high (which can happen
# in these strided input activities). If that's the case, we set the 
# covariance to the identity so eigh will work, and set the variance
# to 0 in that stride, so the weighted_average will ignore that stride.
# Note: bvar and bcov are modified in place inside this loop, one matrix at a time.
for ii, bc in enumerate(bcov):
    if torch.isinf(torch.linalg.cond(bc)):
        bvar[ii] = 0 # set variance to 0 to ignore this stride
        bcov[ii] = torch.eye(bcov.size(1)) # set to identity for simple eigvec decomp.
        cond_inf[ii] = 1

# indices of the entries that were replaced above (not used again in this cell)
idx_inf = torch.where(cond_inf)[0]

# measure eigenvalues and eigenvectors, one covariance matrix at a time
we, ve = utils.named_transpose([utils.eigendecomposition(bc, use_rank=True) for bc in bcov])
# stack the per-matrix results into single tensors
we = torch.stack(we)
ve = torch.stack(ve)


KeyboardInterrupt: 

In [50]:
# Inspect the weight shape of the alignment layer at index 1 (recorded output: [64, 32, 5, 5]).
net.get_alignment_layers()[1].weight.data.shape

torch.Size([64, 32, 5, 5])

In [73]:
def batch_cov(input, centered=True, corr=False):
    """
    Performs batched covariance on input data of shape (batch, dim, samples) or (dim, samples)

    Where the resulting batch covariance matrix has shape (batch, dim, dim) or (dim, dim)
    and bcov[i] = torch.cov(input[i]) if input.ndim==3

    args
    ----
        input: (batch, dim, samples) or (dim, samples) torch tensor
        centered: bool, default=True
            - if True, subtracts the mean of each dimension (across samples) first
        corr: bool, default=False
            - if True, divides by the standard deviations to get correlation matrices
            - rows/columns belonging to a dimension with zero standard deviation are set to 0
              (instead of nan)
            - note that it's not really correlation if centered=False! (the standard
              deviation is still the usual mean-subtracted one)

    returns
    -------
        bcov: (batch, dim, dim) torch tensor, or (dim, dim) if no batch dimension was provided
            - normalized by (samples - 1), so a single sample gives non-finite values

    raises
    ------
        ValueError if input is not a 2D or 3D tensor
    """
    if input.ndim not in (2, 3):
        raise ValueError("input must be a 2D or 3D tensor")

    # add an empty batch dimension if one was not provided
    no_batch = input.ndim == 2
    if no_batch:
        input = input.unsqueeze(0)

    # number of samples in each input matrix
    num_samples = input.size(2)

    # subtract the mean if doing centered covariance (creates a new tensor, the
    # caller's data is left untouched)
    if centered:
        input = input - input.mean(dim=2, keepdim=True)

    # sum of outer products across samples, corrected for the number of samples
    bcov = torch.bmm(input, input.transpose(1, 2)) / (num_samples - 1)

    if corr:
        # outer product of the standard deviations of each dimension
        input_dev = torch.std(input, dim=2)
        dev_correction = input_dev.unsqueeze(2) * input_dev.unsqueeze(1)

        # zero-variance dimensions produce 0/0 = nan, so mask their rows and columns to 0
        idx_zeros = input_dev == 0
        zero_mask = torch.logical_or(idx_zeros.unsqueeze(2), idx_zeros.unsqueeze(1))
        bcov = (bcov / dev_correction).masked_fill(zero_mask, 0)

    # remove the empty batch dimension if one was not provided
    if no_batch:
        bcov = bcov.squeeze(0)

    return bcov


def smartcorr(input):
    """
    Correlation matrix of **input** (via torch.corrcoef) in which every pair-wise
    coefficient involving a dimension with no variance (var == 0 along dim 1) is
    set to 0, replacing the nans that torch.corrcoef produces for those dimensions.
    """
    # rows of `input` whose variance across dim 1 is exactly zero
    no_variance = torch.var(input, dim=1) == 0
    # an entry (i, j) is cleared when either row i or row j has no variance
    cleared = no_variance.unsqueeze(1) | no_variance.unsqueeze(0)
    return torch.corrcoef(input).masked_fill(cleared, 0)


# ---- inline prototype of the batch_cov logic, run on synthetic data ----
B, D, S = 1024, 7, 1000
input = torch.normal(0, 1, (B, D, S))
# give two dimensions of the first batch entry zero variance to exercise the masking below
input[0, [0, 3], :] = 0
centered, corr = True, True

assert input.ndim in (2, 3), "input must be a 2D or 3D tensor"
# remember whether a batch dimension was provided
no_batch = (input.ndim == 2)

# promote a lone matrix to a batch of one
if no_batch:
    input = input[None]

# number of samples in each input matrix
S = input.shape[2]

# centre every dimension across samples when doing centered covariance
if centered:
    input = input - torch.mean(input, dim=2, keepdim=True)

if corr:
    # outer product of the per-dimension standard deviations
    input_dev = input.std(dim=2)
    dev_correction = input_dev[:, :, None] * input_dev[:, None, :]

    # entries touching a zero-variance dimension get masked to 0 further down
    idx_zeros = (input_dev == 0)
    zero_mask = idx_zeros[:, :, None] | idx_zeros[:, None, :]

# sum of outer products across samples for every batch entry
bcov = torch.bmm(input, input.transpose(1, 2))

if corr:
    bcov = (bcov / dev_correction).masked_fill(zero_mask, 0)

print(bcov.shape)

# correct for the number of samples
bcov = bcov / (S - 1)

# drop the batch dimension again if it was not provided
if no_batch:
    bcov = bcov.squeeze(0)



torch.Size([1024, 7, 7])


In [34]:
# integrate the changes in utils (e.g. smart_pca into _measure_layer_eigenfeatures)