In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    """Empty attribute holder; an instance is bound to `self` for prototyping."""


self = A()

In [3]:
# Names standing in for the representation files.
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

# Per-file bookkeeping, both keyed by file name.
self.num_neurons_d = dict()  # {fname: int}
self.representations_d = dict()  # {fname: tensor}

In [4]:
# Initialize `num_neurons_d` and `representations_d` with fake data.
# Seed the generator first so the fake activations, and every number derived
# from them below, are identical after a kernel restart.
torch.manual_seed(0)

n1, n2, n3 = 100, 80, 70  # number of neurons (columns) per fake network
nword = 1000  # number of rows shared by every fake representation
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

In [5]:
# Options read by the computation below.
self.normalize_dimensions = False  # standardize each column before the PCA step
self.save_cca_transforms = False  # also store the whitening / CCA transforms
self.percent_variance = 0.99  # fraction of the total squared singular values used as the PCA cut-off

# Function body, cell by cell

In [6]:
# Build `self.nrepresentations_d` ("normalized representations").
# The name is used whether or not any normalization is actually applied.
if not self.normalize_dimensions:
    # No normalization requested: alias the raw dict (same object, not a copy).
    self.nrepresentations_d = self.representations_d
else:
    self.nrepresentations_d = {}
    for network, t in tqdm(self.representations_d.items(), desc='mu, sigma'):
        means, stdevs = t.mean(0, keepdim=True), t.std(0, keepdim=True)

        # Standardize every column: subtract its mean, divide by its std.
        self.nrepresentations_d[network] = (t - means) / stdevs

In [7]:
# Set `whitening_transforms`, `pca_directions`
# {network: whitening_tensor} / {network: leading left singular vectors}
whitening_transforms = {}
pca_directions = {}
for network in tqdm(self.nrepresentations_d, desc='pca'):
    X = self.nrepresentations_d[network]
    U, S, V = torch.svd(X)

    # Cumulative sum of squared singular values; the last entry is the total.
    var_sums = torch.cumsum(S.pow(2), 0)
    # Number of leading components whose cumulative sum is still strictly
    # below `percent_variance` of the total.
    wanted_size = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()
    # If the first component alone already reaches the threshold the count
    # above is 0, which would leave an empty basis and NaN similarities
    # downstream. Always keep at least one direction.
    wanted_size = max(wanted_size, 1)

    print('For network', network, 'wanted size is', wanted_size)

    if self.save_cca_transforms:
        # V diag(1/S): maps X onto its left singular vectors (X V diag(1/S) = U).
        whitening_transform = torch.mm(V, torch.diag(1/S))
        whitening_transforms[network] = whitening_transform[:, :wanted_size]

    pca_directions[network] = U[:, :wanted_size]

pca: 100%|█████████████████████████████████████| 3/3 [00:00<00:00,  9.21it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





In [8]:
# Result containers. Each is a two-level dict, {network: {other: value}},
# created with an empty inner dict per network:
#   `self.transforms`      -> svcca_transform
#   `self.corrs`           -> canonical_corrs
#   `self.pw_alignments`   -> unnormalized pw weights
#   `self.pw_corrs`        -> pw_alignments*corrs
#   `self.sv_similarities` -> svcca_similarities
#   `self.pw_similarities` -> pwcca_similarities
self.transforms = {name: dict() for name in self.nrepresentations_d}
self.corrs = {name: dict() for name in self.nrepresentations_d}
self.pw_alignments = {name: dict() for name in self.nrepresentations_d}
self.pw_corrs = {name: dict() for name in self.nrepresentations_d}
self.sv_similarities = {name: dict() for name in self.nrepresentations_d}
self.pw_similarities = {name: dict() for name in self.nrepresentations_d}

# Single iteration of the `for network, other_network in ...` loop

In [9]:
# Pick one arbitrary pair to walk through a single loop iteration by hand.
network, other_network = f1, f2

In [10]:
# PCA bases of the two networks being compared.
X, Y = pca_directions[network], pca_directions[other_network]

In [11]:
# CCA through an SVD of the cross-product of the two PCA bases:
#   X^T Y = u diag(s) v^T   <=>   diag(s) = u^T X^T Y v
u, s, v = torch.svd(X.t() @ Y)

# `self.sv_similarities`: mean of the singular values, identical in both directions.
self.sv_similarities[network][other_network] = self.sv_similarities[other_network][network] = s.mean().item()

# `self.corrs`: the singular values themselves, stored for both directions.
self.corrs[network][other_network] = s.cpu().numpy()
self.corrs[other_network][network] = s.cpu().numpy()

# `self.transforms`: only kept when explicitly requested.
if self.save_cca_transforms:
    self.transforms[network][other_network] = (whitening_transforms[network] @ u).cpu().numpy()
    self.transforms[other_network][network] = (whitening_transforms[other_network] @ v).cpu().numpy()

In [12]:
# Projection-weighted quantities: `self.pw_alignments`, `self.pw_corrs`,
# `self.pw_similarities`. These depend on the direction of the pair, so the
# two directions are computed separately.

# Direction: network -> other_network
Z = self.representations_d[network]
H = X @ u
align = (H.t() @ Z).abs()
a = align.sum(dim=1)
self.pw_similarities[network][other_network] = ((s * a).sum() / a.sum()).item()
self.pw_corrs[network][other_network] = (s * a).cpu().numpy()
self.pw_alignments[network][other_network] = a.cpu().numpy()

# Direction: other_network -> network
Z = self.representations_d[other_network]
H = Y @ v
align = (H.t() @ Z).abs()
a = align.sum(dim=1)
self.pw_similarities[other_network][network] = ((s * a).sum() / a.sum()).item()
self.pw_corrs[other_network][network] = (s * a).cpu().numpy()
self.pw_alignments[other_network][network] = a.cpu().numpy()

In [15]:
# Full loop: run the CCA / PWCCA computation for every pair of distinct networks.
for network, other_network in tqdm(p(self.nrepresentations_d,
                                     self.nrepresentations_d),
                                   desc='cca',
                                   total=len(self.nrepresentations_d)**2):

    if network == other_network:
        continue

    # Every iteration fills in both (network, other_network) and
    # (other_network, network), so the mirrored pair can be skipped.
    # The check reads `self.pw_similarities`, the last entry written per pair,
    # because `self.transforms` stays empty unless `self.save_cca_transforms`
    # is set and therefore never triggered the skip.
    if other_network in self.pw_similarities[network]:
        continue

    X = pca_directions[network]
    Y = pca_directions[other_network]

    # Perform SVD for CCA.
    # u s vt = Xt Y
    # s = ut Xt Y v
    u, s, v = torch.svd(torch.mm(X.t(), Y))

    # `self.transforms`, `self.corrs`, `self.sv_similarities`
    if self.save_cca_transforms:
        self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u).cpu().numpy()
        self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v).cpu().numpy()

    self.corrs[network][other_network] = s.cpu().numpy()
    self.corrs[other_network][network] = s.cpu().numpy()

    self.sv_similarities[network][other_network] = s.mean().item()
    self.sv_similarities[other_network][network] = s.mean().item()

    # Compute `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`.
    # This is not symmetric, so each direction is computed on its own.

    # For X
    H = torch.mm(X, u)
    # NOTE(review): the weights are taken against the raw `representations_d`,
    # not `nrepresentations_d`; the two differ when `normalize_dimensions`
    # is True. Confirm this is intended.
    Z = self.representations_d[network]
    align = torch.abs(torch.mm(H.t(), Z))
    a = torch.sum(align, dim=1, keepdim=False)
    self.pw_alignments[network][other_network] = a.cpu().numpy()
    self.pw_corrs[network][other_network] = (s*a).cpu().numpy()
    self.pw_similarities[network][other_network] = (torch.sum(s*a)/torch.sum(a)).item()

    # For Y
    H = torch.mm(Y, v)
    Z = self.representations_d[other_network]
    align = torch.abs(torch.mm(H.t(), Z))
    a = torch.sum(align, dim=1, keepdim=False)
    self.pw_alignments[other_network][network] = a.cpu().numpy()
    self.pw_corrs[other_network][network] = (s*a).cpu().numpy()
    self.pw_similarities[other_network][network] = (torch.sum(s*a)/torch.sum(a)).item()

cca: 100%|████████████████████████████████████| 9/9 [00:00<00:00, 669.35it/s]


# Full function

In [6]:
def compute_correlations(self):
    """Compute SVCCA and PWCCA similarities for every pair of networks.

    Reads from `self`:
        representations_d: {network: 2-D tensor}. All tensors must have the
            same number of rows, since the PCA bases are multiplied as X^T Y.
        normalize_dimensions: if true, standardize each column (zero mean,
            unit std) before the PCA step.
        percent_variance: fraction of the total squared singular values used
            as the cut-off for the number of PCA directions kept.
        save_cca_transforms: if true, also store the CCA transforms.

    Writes to `self` (all of the form {network: {other: value}} except the first):
        nrepresentations_d: {network: tensor}; the same dict object as
            `representations_d` when no normalization is requested.
        transforms: svcca transforms (only filled if `save_cca_transforms`).
        corrs: canonical correlations (numpy array).
        sv_similarities: mean canonical correlation (float).
        pw_alignments: unnormalized projection weights (numpy array).
        pw_corrs: pw_alignments * corrs (numpy array).
        pw_similarities: weighted mean of the correlations (float).

    Returns:
        None. Results are stored on `self`.
    """
    # "Normalized representations"; the name is used whether or not any
    # normalization is actually applied.
    if self.normalize_dimensions:
        self.nrepresentations_d = {}
        for network in tqdm(self.representations_d, desc='mu, sigma'):
            t = self.representations_d[network]
            means = t.mean(0, keepdim=True)
            stdevs = t.std(0, keepdim=True)

            self.nrepresentations_d[network] = (t - means) / stdevs
    else:
        self.nrepresentations_d = self.representations_d

    # PCA step: `whitening_transforms` and `pca_directions`, keyed by network.
    whitening_transforms = {}
    pca_directions = {}
    for network in tqdm(self.nrepresentations_d, desc='pca'):
        X = self.nrepresentations_d[network]
        U, S, V = torch.svd(X)

        # Cumulative sum of squared singular values; the last entry is the total.
        var_sums = torch.cumsum(S.pow(2), 0)
        # Number of leading components whose cumulative sum is still strictly
        # below `percent_variance` of the total.
        wanted_size = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()
        # If the first component alone already reaches the threshold the count
        # is 0, which would give an empty basis and NaN similarities below.
        # Always keep at least one direction.
        wanted_size = max(wanted_size, 1)

        print('For network', network, 'wanted size is', wanted_size)

        if self.save_cca_transforms:
            # V diag(1/S): maps X onto its left singular vectors.
            whitening_transform = torch.mm(V, torch.diag(1/S))
            whitening_transforms[network] = whitening_transform[:, :wanted_size]

        pca_directions[network] = U[:, :wanted_size]

    # Result containers, one empty inner dict per network.
    self.transforms = {network: {} for network in self.nrepresentations_d}
    self.corrs = {network: {} for network in self.nrepresentations_d}
    self.pw_alignments = {network: {} for network in self.nrepresentations_d}
    self.pw_corrs = {network: {} for network in self.nrepresentations_d}
    self.sv_similarities = {network: {} for network in self.nrepresentations_d}
    self.pw_similarities = {network: {} for network in self.nrepresentations_d}

    def _store_projection_weighting(src, dst, directions, rotation, corrs):
        # Store the projection-weighted results for the ordered pair (src, dst).
        H = torch.mm(directions, rotation)
        # NOTE(review): the weights are taken against the raw
        # `representations_d`, not `nrepresentations_d`; the two differ when
        # `normalize_dimensions` is True. Confirm this is intended.
        Z = self.representations_d[src]
        align = torch.abs(torch.mm(H.t(), Z))
        a = torch.sum(align, dim=1, keepdim=False)
        self.pw_alignments[src][dst] = a.cpu().numpy()
        self.pw_corrs[src][dst] = (corrs*a).cpu().numpy()
        self.pw_similarities[src][dst] = (torch.sum(corrs*a)/torch.sum(a)).item()

    for network, other_network in tqdm(p(self.nrepresentations_d,
                                         self.nrepresentations_d),
                                       desc='cca',
                                       total=len(self.nrepresentations_d)**2):

        if network == other_network:
            continue

        # Every iteration fills in both directions of the pair, so the
        # mirrored pair can be skipped. The check reads
        # `self.pw_similarities`, which is always written, because
        # `self.transforms` stays empty unless `save_cca_transforms` is set
        # and therefore never triggered the skip.
        if other_network in self.pw_similarities[network]:
            continue

        X = pca_directions[network]
        Y = pca_directions[other_network]

        # Perform SVD for CCA.
        # u s vt = Xt Y
        # s = ut Xt Y v
        u, s, v = torch.svd(torch.mm(X.t(), Y))

        # `self.transforms`, `self.corrs`, `self.sv_similarities`
        if self.save_cca_transforms:
            self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u).cpu().numpy()
            self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v).cpu().numpy()

        self.corrs[network][other_network] = s.cpu().numpy()
        self.corrs[other_network][network] = s.cpu().numpy()

        sv_similarity = s.mean().item()
        self.sv_similarities[network][other_network] = sv_similarity
        self.sv_similarities[other_network][network] = sv_similarity

        # `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`.
        # These are not symmetric, so each direction is computed on its own.
        _store_projection_weighting(network, other_network, X, u, s)
        _store_projection_weighting(other_network, network, Y, v, s)

In [7]:
# Run the assembled function on the fake-data object; it stores its results on `self`.
compute_correlations(self)

pca: 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 87.65it/s]
cca: 100%|████████████████████████████████████| 9/9 [00:00<00:00, 539.38it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





In [8]:
# Inspect the projection-weighting alignment vectors: {network: {other: numpy array}}.
self.pw_alignments

{'foo': {'bar': array([251.97348, 260.81625, 255.077  , 275.73474, 256.24982, 257.5581 ,
         259.6351 , 257.06528, 262.96936, 254.49544, 250.53667, 236.20905,
         252.74142, 262.22662, 249.7028 , 256.1499 , 256.25308, 251.71135,
         255.15184, 251.54121, 267.07938, 259.43402, 258.43118, 251.3825 ,
         242.11292, 266.68982, 245.26556, 260.9492 , 259.05383, 255.66913,
         254.6817 , 255.97926, 250.11707, 270.2134 , 243.4038 , 271.04242,
         260.8858 , 263.7016 , 259.79315, 253.11787, 268.34558, 275.30338,
         260.15234, 255.72946, 257.3257 , 259.2988 , 254.91328, 261.27917,
         254.07199, 258.81763, 261.5435 , 252.48909, 267.74344, 275.60687,
         264.9771 , 256.6436 , 272.45782, 259.81717, 263.23416, 259.08066,
         259.57907, 262.90054, 246.1847 , 248.96051, 247.92561, 259.30035,
         241.10927, 258.95767, 262.53665, 256.79718, 250.05742, 256.76492,
         243.58519, 269.09012, 260.75836, 255.06744, 259.09225, 258.03223],
        dt