In [3]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [4]:
class A:
    pass
self = A()

In [5]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [6]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

In [7]:
self.percent_variance = 0.99
self.normalize_dimensions = True

# Compute Correlations

In [8]:
def compute_correlations(self):
    """
    Set `self.transforms` to be the svcca transform matrix M. 

    If X is the activation tensor, then X M is the svcca tensor. 
    """ 
    # Normalize
    if self.normalize_dimensions:
        for network in tqdm(self.representations_d, desc='mu, sigma'):
            t = self.representations_d[network]
            means = t.mean(0, keepdim=True)
            stdevs = t.std(0, keepdim=True)

            self.representations_d[network] = (t - means) / stdevs

    # Set `whitening_transforms`, `pca_directions`
    whitening_transforms = {} # {network: whitening_tensor}
    pca_directions = {} 
    for network in tqdm(self.representations_d, desc='pca'):
        X = self.representations_d[network]
        U, S, V = torch.svd(X)

        var_sums = torch.cumsum(S.pow(2), 0)
        wanted_size = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()

        print('For network', network, 'wanted size is', wanted_size)

        whitening_transform = torch.mm(V, torch.diag(1/S))
        whitening_transforms[network] = whitening_transform[:, :wanted_size]
        pca_directions[network] = U[:, :wanted_size]

    # Set `self.transforms` to be {network: {other: svcca_transform}}
    self.transforms = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d), desc='cca',
                                       total=len(self.representations_d)**2):

        if network == other_network:
            continue

        if other_network in self.transforms[network].keys(): # TO DO: optimize?
            continue

        X = pca_directions[network]
        Y = pca_directions[other_network]

        # Perform SVD for CCA.
        # u s vt = Xt Y
        # s = ut Xt Y v
        u, s, v = torch.svd(torch.mm(X.t(), Y))

        self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
        self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)

In [9]:
compute_correlations(self)

mu, sigma: 100%|█| 3/3 [00:00<00:00, 1081.19it/s]
pca: 100%|█| 3/3 [00:00<00:00, 256.29it/s]
cca: 100%|█| 9/9 [00:00<00:00, 1364.74it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





# Build function

In [10]:
output_file = "temp"

In [11]:
torch.save(self.transforms, output_file)

In [13]:
torch.load(output_file)

{'foo': {'bar': tensor([[ 4.3389e-03, -4.0581e-03, -4.2364e-03,  ...,  3.1170e-03,
            2.2676e-04,  5.1054e-03],
          [-2.8659e-03, -3.0557e-03,  5.2872e-03,  ..., -3.2413e-03,
           -4.4404e-03,  6.3419e-06],
          [ 4.9767e-03,  1.1773e-05, -2.2116e-03,  ..., -1.4815e-03,
           -1.6110e-03,  2.2029e-05],
          ...,
          [-4.1768e-03, -1.7934e-03, -4.4769e-03,  ...,  3.9952e-03,
            1.2988e-03,  4.4865e-03],
          [-8.5371e-03, -2.1010e-03,  4.0864e-03,  ..., -4.9519e-04,
            1.4108e-03,  1.8515e-03],
          [-1.3414e-03, -9.9751e-04,  6.0142e-03,  ..., -1.7713e-03,
           -2.3972e-03, -4.3518e-03]]),
  'baz': tensor([[ 0.0015, -0.0018,  0.0027,  ...,  0.0008, -0.0025,  0.0014],
          [ 0.0057,  0.0086, -0.0023,  ...,  0.0048, -0.0022,  0.0031],
          [ 0.0081,  0.0026, -0.0021,  ...,  0.0001, -0.0022, -0.0004],
          ...,
          [ 0.0046,  0.0046,  0.0002,  ..., -0.0054, -0.0055, -0.0027],
          [-0.000