In [2]:
import json
from itertools import combinations
from itertools import product as p

import numpy as np
import torch
from tqdm import tqdm

# Set up a stand-in `self` and generate fake activations

In [3]:
class A:
    """Empty attribute container used as `self`, so method code can be run cell by cell."""

self = A()

In [4]:
# Per-file bookkeeping, filled in by the next cell.
self.num_neurons_d = {}      # {fname: int}    number of neurons
self.representations_d = {}  # {fname: tensor} activations

# Stand-in names for the three representation files being compared.
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [5]:
# Initialize `num_neurons_d` and `representations_d` with fake data:
# `nword` rows of standard-normal activations per file, with a different
# number of neurons (columns) for each file.
SEED = 0
torch.manual_seed(SEED)  # make the fake data, and every output below, reproducible

n1, n2, n3 = 100, 80, 70
nword = 1000
t1, t2, t3 = (torch.randn(nword, n) for n in (n1, n2, n3))
for fname, nneurons, acts in zip((f1, f2, f3), (n1, n2, n3), (t1, t2, t3)):
    self.num_neurons_d[fname] = nneurons
    self.representations_d[fname] = acts

In [6]:
# Analysis hyper-parameters, attached to the stand-in `self`.
_SETTINGS = (
    ('percent_variance', 0.99),      # variance threshold used to pick how many PCA components to keep
    ('normalize_dimensions', True),  # z-score each neuron (column) before PCA
    ('save_cca_transforms', False),  # not read by compute_correlations as written
)
for _name, _value in _SETTINGS:
    setattr(self, _name, _value)

# Step-by-step prototype of `compute_correlations`

In [7]:
# Normalize
# Set `self.nrepresentations_d`: per-neuron (per-column) z-scored activations
# when `self.normalize_dimensions` is set, otherwise the raw activations.
self.nrepresentations_d = {}
if self.normalize_dimensions:
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means = t.mean(0, keepdim=True)
        stdevs = t.std(0, keepdim=True)
        # A constant neuron has zero std; dividing by it would fill the column
        # with NaN and break the SVD below. Leave such columns centered at 0.
        stdevs = stdevs.masked_fill(stdevs == 0, 1)

        self.nrepresentations_d[network] = (t - means) / stdevs
else:
    # Without this branch every later cell would see an empty dict and silently
    # compute nothing.
    self.nrepresentations_d = dict(self.representations_d)

mu, sigma: 100%|███████████████████████████████| 3/3 [00:00<00:00, 58.87it/s]


In [8]:
# Set `whitening_transforms`, `pca_directions` ({network: tensor}).
# PCA of each normalized representation via a thin SVD  X = U S V^T:
#   pca_directions[network]       = leading columns of U (orthonormal basis)
#   whitening_transforms[network] = V S^{-1} on the same columns, so that
#                                   X @ whitening_transforms[network] == pca_directions[network]
whitening_transforms = {}
pca_directions = {}
for network in tqdm(self.nrepresentations_d, desc='pca'):
    X = self.nrepresentations_d[network]
    U, S, V = torch.svd(X)

    # Keep the components whose cumulative variance is still strictly below
    # `percent_variance` of the total.
    # NOTE(review): the strict `<` keeps one component fewer than needed to
    # *reach* the threshold; kept as-is to preserve existing sizes.
    # Always keep at least one component so the CCA below is defined.
    var_sums = torch.cumsum(S.pow(2), 0)
    wanted_size = max(torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item(), 1)

    print('For network', network, 'wanted size is', wanted_size)

    # Invert only the kept singular values; discarded ones may be exactly 0
    # for a rank-deficient X.
    whitening_transforms[network] = V[:, :wanted_size] / S[:wanted_size]
    pca_directions[network] = U[:, :wanted_size]

pca: 100%|█████████████████████████████████████| 3/3 [00:04<00:00,  1.39s/it]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





In [9]:
# Reset the per-pair result tables on `self`. Each is a nested dict
# {network: {other: value}}:
#   transforms      -> SVCCA transform (whitening transform @ CCA rotation)
#   corrs           -> canonical correlations
#   pw_alignments   -> unnormalized PWCCA weights
#   pw_corrs        -> pw_alignments * corrs
#   sv_similarities -> SVCCA similarity (mean canonical correlation)
#   pw_similarities -> PWCCA similarity
for _table in ('transforms', 'corrs', 'pw_alignments',
               'pw_corrs', 'sv_similarities', 'pw_similarities'):
    setattr(self, _table, {network: {} for network in self.nrepresentations_d})

# Prototype the loop body on a single (network, other_network) pair

In [9]:
# Pick one arbitrary (network, other_network) pair to prototype the loop body on.
network, other_network = f1, f2

In [10]:
# PCA bases (orthonormal columns from the SVD) of the two networks being compared.
X, Y = pca_directions[network], pca_directions[other_network]

In [11]:
# CCA between the two PCA bases via a single SVD:
#   u s v^T = X^T Y   =>   s = u^T X^T Y v
# The singular values `s` are the canonical correlations.
u, s, v = torch.svd(torch.mm(X.t(), Y))

# SVCCA transform for each side: its whitening transform followed by its CCA rotation.
self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)

# Correlations and the SVCCA similarity are symmetric, so store both orderings.
mean_corr = s.mean().item()
for a_net, b_net in ((network, other_network), (other_network, network)):
    self.corrs[a_net][b_net] = s
    self.sv_similarities[a_net][b_net] = mean_corr

In [12]:
# Compute `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`.
# PWCCA is NOT symmetric: each side weights the canonical directions by how
# strongly its own raw activations project onto them. X side first, then Y side.
for src, dst, dirs, rot in ((network, other_network, X, u),
                            (other_network, network, Y, v)):
    H = torch.mm(dirs, rot)                  # canonical variables for this side
    Z = self.representations_d[src]          # raw (un-normalized) activations
    align = torch.mm(H.t(), Z).abs()
    a = align.sum(dim=1)                     # one weight per canonical direction
    weighted = s * a
    self.pw_alignments[src][dst] = a
    self.pw_corrs[src][dst] = weighted
    self.pw_similarities[src][dst] = (weighted.sum() / a.sum()).item()

In [13]:
# Full loop: CCA / SVCCA / PWCCA for every unordered pair of distinct networks.
# `combinations` yields each pair once and never pairs a network with itself,
# which replaces the product-plus-skip pattern. Both orderings are stored in one pass.
pairs = list(combinations(self.nrepresentations_d, 2))
for network, other_network in tqdm(pairs, desc='cca'):

    # Pairs already filled (e.g. by the single-pair prototype cells) are kept.
    if other_network in self.transforms[network]:
        continue

    X = pca_directions[network]
    Y = pca_directions[other_network]

    # Perform SVD for CCA.
    # u s vt = Xt Y
    # s = ut Xt Y v
    u, s, v = torch.svd(torch.mm(X.t(), Y))

    # SVCCA transforms, canonical correlations and SVCCA similarity.
    self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
    self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)
    for a_net, b_net in ((network, other_network), (other_network, network)):
        self.corrs[a_net][b_net] = s
        self.sv_similarities[a_net][b_net] = s.mean().item()

    # PWCCA is NOT symmetric: each side projects its own raw activations.
    for src, dst, dirs, rot in ((network, other_network, X, u),
                                (other_network, network, Y, v)):
        H = torch.mm(dirs, rot)
        Z = self.representations_d[src]
        align = torch.abs(torch.mm(H.t(), Z))
        a = torch.sum(align, dim=1, keepdim=False)
        self.pw_alignments[src][dst] = a
        self.pw_corrs[src][dst] = s*a
        self.pw_similarities[src][dst] = (torch.sum(s*a)/torch.sum(a)).item()

cca: 100%|███████████████████████████████████| 9/9 [00:00<00:00, 1571.36it/s]


In [17]:
# PWCCA similarity for every ordered pair (not symmetric: [a][b] and [b][a] generally differ).
self.pw_similarities

{'foo': {'bar': 0.28253406286239624, 'baz': 0.28071942925453186},
 'bar': {'foo': 0.28259754180908203, 'baz': 0.24323876202106476},
 'baz': {'foo': 0.2819087505340576, 'bar': 0.24311821162700653}}

# Full function

In [18]:
def _normalize_representations(representations, normalize):
    """Return per-neuron (per-column) z-scored copies of `representations`.

    When `normalize` is false, return a shallow copy of the raw dict instead.
    """
    if not normalize:
        return dict(representations)
    normalized = {}
    for network in tqdm(representations, desc='mu, sigma'):
        t = representations[network]
        means = t.mean(0, keepdim=True)
        stdevs = t.std(0, keepdim=True)
        # Constant neurons have zero std and would otherwise become NaN columns.
        stdevs = stdevs.masked_fill(stdevs == 0, 1)
        normalized[network] = (t - means) / stdevs
    return normalized


def _pca_whitening(X, percent_variance):
    """PCA of `X` via the thin SVD X = U S V^T.

    Returns (whitening_transform, pca_directions), both restricted to the kept
    leading components, such that X @ whitening_transform == pca_directions.
    """
    U, S, V = torch.svd(X)
    var_sums = torch.cumsum(S.pow(2), 0)
    # NOTE(review): the strict `<` keeps the components whose cumulative variance
    # is still below the threshold (one fewer than needed to reach it). It is kept
    # to preserve existing sizes. At least one component is kept so CCA is defined.
    wanted_size = max(torch.sum(var_sums.lt(var_sums[-1] * percent_variance)).item(), 1)
    # Invert only the kept singular values; discarded ones may be exactly 0.
    return V[:, :wanted_size] / S[:wanted_size], U[:, :wanted_size]


def _pw_alignment(directions, rotation, activations):
    """Unnormalized PWCCA weights.

    For each canonical direction (column of directions @ rotation), sum the
    absolute inner products with every column of `activations`.
    """
    H = torch.mm(directions, rotation)
    return torch.abs(torch.mm(H.t(), activations)).sum(dim=1)


def compute_correlations(self):
    """Run SVCCA / PWCCA between every pair of distinct networks.

    Reads `self.representations_d`, `self.normalize_dimensions` and
    `self.percent_variance`. Sets `self.nrepresentations_d` and, for every
    ordered pair of distinct networks, fills the nested dicts
    `{network: {other: value}}`:
      `self.transforms`      SVCCA transform (whitening @ CCA rotation)
      `self.corrs`           canonical correlations
      `self.pw_alignments`   unnormalized PWCCA weights
      `self.pw_corrs`        pw_alignments * corrs
      `self.sv_similarities` SVCCA similarity (mean canonical correlation)
      `self.pw_similarities` PWCCA similarity
    A network is never compared with itself. Returns None.
    """
    self.nrepresentations_d = _normalize_representations(
        self.representations_d, self.normalize_dimensions)

    whitening_transforms = {}
    pca_directions = {}
    for network in tqdm(self.nrepresentations_d, desc='pca'):
        whitening_transforms[network], pca_directions[network] = _pca_whitening(
            self.nrepresentations_d[network], self.percent_variance)
        print('For network', network, 'wanted size is', pca_directions[network].shape[1])

    networks = list(self.nrepresentations_d)
    self.transforms = {network: {} for network in networks}
    self.corrs = {network: {} for network in networks}
    self.pw_alignments = {network: {} for network in networks}
    self.pw_corrs = {network: {} for network in networks}
    self.sv_similarities = {network: {} for network in networks}
    self.pw_similarities = {network: {} for network in networks}

    # Each unordered pair once; results are stored for both orderings below.
    pairs = list(combinations(networks, 2))
    for network, other_network in tqdm(pairs, desc='cca'):
        X = pca_directions[network]
        Y = pca_directions[other_network]

        # CCA via SVD:  u s v^T = X^T Y,  so  s = u^T X^T Y v.
        u, s, v = torch.svd(torch.mm(X.t(), Y))

        self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
        self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)
        for a_net, b_net in ((network, other_network), (other_network, network)):
            self.corrs[a_net][b_net] = s
            self.sv_similarities[a_net][b_net] = s.mean().item()

        # PWCCA is NOT symmetric: each side projects its own raw activations.
        # NOTE(review): uses un-normalized `representations_d`, not
        # `nrepresentations_d`; confirm this is intended.
        for src, dst, dirs, rot in ((network, other_network, X, u),
                                    (other_network, network, Y, v)):
            a = _pw_alignment(dirs, rot, self.representations_d[src])
            self.pw_alignments[src][dst] = a
            self.pw_corrs[src][dst] = s * a
            self.pw_similarities[src][dst] = (torch.sum(s * a) / torch.sum(a)).item()

In [19]:
# Run the packaged function end to end; it resets and refills every result table on `self`.
compute_correlations(self)

mu, sigma: 100%|█████████████████████████████| 3/3 [00:00<00:00, 1903.62it/s]
pca: 100%|████████████████████████████████████| 3/3 [00:00<00:00, 240.29it/s]
cca: 100%|████████████████████████████████████| 9/9 [00:00<00:00, 873.07it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





In [21]:
# Per-component PWCCA-weighted correlations (pw_alignments * corrs) for each ordered pair.
self.pw_corrs

{'foo': {'bar': tensor([139.9517, 140.7342, 140.8350, 128.5976, 134.9371, 129.0182, 139.3720,
          119.7300, 122.1407, 124.5879, 117.4110, 120.1342, 109.9069, 116.9453,
          112.3361, 110.1559, 104.3443, 106.3454,  96.7515, 107.5248,  98.3512,
           99.8768,  96.0159,  94.7809,  98.3115,  88.1688,  91.3338,  82.3340,
           86.3369,  84.8769,  81.3848,  81.4239,  79.3815,  77.9744,  73.5487,
           80.3439,  76.5836,  76.9173,  69.9647,  67.2655,  69.3710,  64.4150,
           60.7970,  61.8085,  64.7805,  57.1003,  58.7008,  62.3385,  56.0259,
           51.0174,  58.0377,  53.8280,  52.1599,  50.3478,  49.2290,  47.3323,
           44.2947,  43.9474,  38.9249,  37.9358,  38.0542,  36.7693,  35.7900,
           33.2676,  32.5004,  31.8194,  28.2865,  28.0051,  29.1439,  22.9423,
           20.1307,  20.0695,  19.4042,  17.9061,  14.5332,  12.6117,  11.7166,
            8.4762]),
  'baz': tensor([133.5720, 141.4685, 132.1705, 129.7212, 119.4776, 119.0930, 123.758