In [2]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [3]:
class A:
    """Bare attribute container; the instance below stands in for `self` in the cells that follow."""


self = A()

In [4]:
# Names used as keys for the three stand-in representation files.
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

# Per-file containers, both keyed by file name.
self.num_neurons_d = dict()      # {fname: int}
self.representations_d = dict()  # {fname: tensor}

In [5]:
# Fill `num_neurons_d` and `representations_d` with fake data:
# one (nword, n_i) standard-normal tensor per file.
nword = 1000
n1, n2, n3 = 100, 80, 70
# The generator is consumed left to right, so the tensors are drawn in the order t1, t2, t3.
t1, t2, t3 = (torch.randn(nword, width) for width in (n1, n2, n3))

self.num_neurons_d.update(zip((f1, f2, f3), (n1, n2, n3)))
self.representations_d.update(zip((f1, f2, f3), (t1, t2, t3)))

In [9]:
# When True, every column of each representation is centered and divided by its standard deviation.
self.normalize_dimensions = True
# Fraction of the total squared-singular-value mass used to decide how many components are kept.
self.percent_variance = 0.99

# Build function

### Normalize

In [10]:
# Normalize: shift every column to zero mean and scale it by its standard deviation.
# Each entry of `self.representations_d` is replaced by its normalized version.
if self.normalize_dimensions:
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means, stdevs = t.mean(dim=0, keepdim=True), t.std(dim=0, keepdim=True)
        self.representations_d[network] = t.sub(means).div(stdevs)

mu, sigma: 100%|█| 3/3 [00:00<00:00, 853.25it/s]


### Set `whitening_transforms`

In [12]:
# loop variable: pin a single network so the per-network steps below can be run one at a time
network = f1

In [29]:
# Start from an empty mapping; entries are added per network further down.
whitening_transforms = {} # {network: whitening_tensor}

In [18]:
# Decompose the selected network's activations and show the shape of each factor.
X = self.representations_d[network]
U, S, V = torch.svd(X)
tuple(factor.shape for factor in (U, S, V))

(torch.Size([1000, 100]), torch.Size([100]), torch.Size([100, 100]))

In [20]:
# Running total of the squared singular values.
var_sums = S.pow(2).cumsum(dim=0)
var_sums

tensor([ 1710.9971,  3373.7769,  5013.2178,  6627.2939,  8212.7188,  9770.2158,
        11315.7871, 12826.0156, 14323.4873, 15794.5332, 17255.8984, 18709.4844,
        20149.7383, 21563.4004, 22964.4492, 24348.6895, 25713.0527, 27054.7852,
        28389.6602, 29696.4648, 30998.3535, 32284.4219, 33557.5664, 34821.3711,
        36069.5586, 37301.2461, 38527.1406, 39734.8242, 40934.7812, 42115.1719,
        43293.6445, 44461.2578, 45624.0352, 46769.0742, 47904.1562, 49035.1484,
        50133.7227, 51228.7070, 52311.4922, 53389.5742, 54460.8555, 55517.8711,
        56562.3164, 57596.4766, 58618.0703, 59632.5859, 60630.8672, 61617.5469,
        62585.3242, 63540.2852, 64486.4531, 65428.7031, 66359.9062, 67278.2891,
        68183.9219, 69082.7109, 69977.0625, 70867.1016, 71753.4922, 72624.2500,
        73485.2188, 74337.9844, 75179.5234, 76013.7188, 76843.7422, 77664.4297,
        78469.3906, 79267.1172, 80051.6797, 80835.2109, 81606.8203, 82369.3359,
        83128.5781, 83873.6328, 84613.28

In [23]:
# Number of leading components whose running total is strictly below
# `percent_variance` times the overall total (the last entry of `var_sums`).
wanted_size = int((var_sums < var_sums[-1] * self.percent_variance).sum())
wanted_size

97

In [24]:
# Report how many components were selected for the current network.
print('For network', network, 'wanted size is', wanted_size)

For network foo wanted size is 97


In [27]:
# Multiplying V by diag(1/S) scales column j of V by 1 / S[j].
whitening_transform = V @ torch.diag(1 / S)
whitening_transform.shape

torch.Size([100, 100])

In [30]:
# Keep only the first `wanted_size` columns of the whitening matrix for this network.
whitening_transforms[network] = whitening_transform[:, :wanted_size]

In [37]:
# full
# Populate `whitening_transforms` and `pca_directions` for every network.
whitening_transforms = dict()  # {network: whitening_tensor}
pca_directions = dict()
for network in tqdm(self.representations_d, desc='pca'):
    X = self.representations_d[network]
    U, S, V = torch.svd(X)

    # Count the leading components whose running squared-singular-value total
    # is strictly below `percent_variance` times the overall total.
    var_sums = S.pow(2).cumsum(dim=0)
    wanted_size = int((var_sums < var_sums[-1] * self.percent_variance).sum())

    print('For network', network, 'wanted size is', wanted_size)

    # Column j of V scaled by 1 / S[j]; both results are cut to `wanted_size` columns.
    whitening_transform = V @ torch.diag(1 / S)
    pca_directions[network] = U[:, :wanted_size]
    whitening_transforms[network] = whitening_transform[:, :wanted_size]

pca: 100%|█| 3/3 [00:00<00:00, 158.11it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





### Set `self.transforms`

In [35]:
# Target layout: {network: {other: svcca_transform}}.
# Every network starts with its own, separate, empty inner dict.
self.transforms = dict((name, dict()) for name in self.representations_d)

In [38]:
# loop variables: pin one pair of networks so the pairwise steps can be run one at a time
network, other_network = f1, f2

In [None]:
# if network == other_network:
#     continue

# if other_network in self.transforms[network].keys(): # TO DO: optimize?
#     continue

In [40]:
# Truncated left singular vectors of the two networks being compared.
X, Y = pca_directions[network], pca_directions[other_network]

In [41]:
# CCA step: decompose the cross-product of the two sets of directions.
#   Xt Y = u diag(s) vt   <=>   diag(s) = ut Xt Y v
u, s, v = torch.svd(X.t() @ Y)

In [47]:
# Compose each network's whitening matrix with its side of the decomposition:
# `u` goes with `network`, `v` with `other_network`.
self.transforms[network][other_network] = whitening_transforms[network] @ u
self.transforms[other_network][network] = whitening_transforms[other_network] @ v

In [50]:
# full
# Fill `self.transforms` as {network: {other: svcca_transform}}.
self.transforms = dict((name, dict()) for name in self.representations_d)
for network, other_network in tqdm(p(self.representations_d, self.representations_d),
                                   desc='cca', total=len(self.representations_d) ** 2):
    # Skip self-pairs, and pairs that were already filled in when the
    # mirrored (other_network, network) pair was processed. # TO DO: optimize?
    if network == other_network or other_network in self.transforms[network]:
        continue

    X, Y = pca_directions[network], pca_directions[other_network]

    # SVD for CCA:  Xt Y = u diag(s) vt   <=>   diag(s) = ut Xt Y v
    u, s, v = torch.svd(X.t() @ Y)

    # One decomposition yields the transform for both directions of the pair.
    self.transforms[network][other_network] = whitening_transforms[network] @ u
    self.transforms[other_network][network] = whitening_transforms[other_network] @ v

self.transforms

cca: 100%|█| 9/9 [00:00<00:00, 962.83it/s]


{'foo': {'bar': tensor([[ 2.8799e-04,  2.6425e-03, -5.9800e-03,  ...,  1.5790e-03,
            1.0699e-03,  7.6494e-04],
          [ 3.0322e-03,  3.9201e-04, -1.4865e-03,  ..., -2.4446e-03,
           -2.3550e-04, -3.5492e-03],
          [-1.0548e-03, -1.3683e-03,  2.7049e-04,  ...,  2.5806e-03,
            5.1039e-03,  3.2589e-03],
          ...,
          [-9.4123e-05,  8.6863e-04, -4.1224e-03,  ..., -4.7492e-03,
            2.5494e-03, -7.9301e-05],
          [-3.0804e-03,  1.5640e-03,  2.7552e-03,  ...,  5.2576e-04,
            4.0630e-03,  7.6578e-04],
          [-3.8764e-04,  4.6201e-03, -1.4311e-03,  ..., -9.4200e-04,
           -4.7955e-04, -3.8509e-04]]),
  'baz': tensor([[ 0.0051, -0.0008, -0.0013,  ..., -0.0081,  0.0017, -0.0008],
          [-0.0052, -0.0016, -0.0005,  ..., -0.0012,  0.0043, -0.0056],
          [-0.0048, -0.0056,  0.0001,  ...,  0.0030, -0.0002, -0.0026],
          ...,
          [-0.0004,  0.0013,  0.0010,  ..., -0.0019,  0.0014,  0.0066],
          [ 0.000

# Build full function

In [51]:
def compute_correlations(self):
    """
    Set `self.transforms` to be the svcca transform matrices.

    `self.transforms[network][other]` is a matrix M such that, if X is the
    (normalized) activation tensor of `network`, then X M is the svcca tensor
    of `network` with respect to `other`.

    Reads:
        self.representations_d: {network: 2-D tensor}; every column is
            treated as one dimension to normalize, and the tensors of all
            networks must have the same number of rows.
        self.normalize_dimensions: if true, each column is centered and
            divided by its standard deviation before anything else.
        self.percent_variance: fraction (0..1) of the total squared
            singular-value mass that the kept components must reach.

    Side effects:
        - When `self.normalize_dimensions` is true, the entries of
          `self.representations_d` are replaced by their normalized versions.
        - `self.transforms` is rebuilt from scratch.
        - One line per network is printed with the number of kept components.
    """
    # Normalize
    if self.normalize_dimensions:
        for network in tqdm(self.representations_d, desc='mu, sigma'):
            t = self.representations_d[network]
            means = t.mean(0, keepdim=True)
            stdevs = t.std(0, keepdim=True)
            # A constant column has zero standard deviation (and a single-row
            # input has an undefined one). Dividing by it would fill the
            # column with NaN and break the SVD below, so leave such columns
            # merely centered by using a scale of 1.
            stdevs = torch.where(stdevs > 0, stdevs, torch.ones_like(stdevs))

            self.representations_d[network] = (t - means) / stdevs

    # Set `whitening_transforms`, `pca_directions`
    whitening_transforms = {} # {network: whitening_tensor}
    pca_directions = {}
    for network in tqdm(self.representations_d, desc='pca'):
        X = self.representations_d[network]
        U, S, V = torch.svd(X)

        var_sums = torch.cumsum(S.pow(2), 0)
        # `below` components together stay strictly under the target, so one
        # more is needed to actually reach `percent_variance`.
        below = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()
        wanted_size = min(below + 1, S.numel())
        # Singular values come back in descending order, so the non-zero ones
        # form a prefix. Never keep a zero one: its whitening scale 1/S would
        # be infinite.
        wanted_size = min(wanted_size, int((S > 0).sum()))

        print('For network', network, 'wanted size is', wanted_size)

        # Same columns as (V @ diag(1/S))[:, :wanted_size], without building
        # the full diagonal matrix or touching the discarded 1/S entries.
        whitening_transforms[network] = V[:, :wanted_size] / S[:wanted_size]
        pca_directions[network] = U[:, :wanted_size]

    # Set `self.transforms` to be {network: {other: svcca_transform}}
    self.transforms = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d), desc='cca',
                                       total=len(self.representations_d)**2):

        if network == other_network:
            continue

        # Each unordered pair is solved once; the mirrored ordered pair was
        # already filled in by that iteration. # TO DO: optimize?
        if other_network in self.transforms[network]:
            continue

        X = pca_directions[network]
        Y = pca_directions[other_network]

        # Perform SVD for CCA.
        # u s vt = Xt Y
        # s = ut Xt Y v
        u, _, v = torch.svd(torch.mm(X.t(), Y))

        self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
        self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)