In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    pass
self = A()

In [3]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [4]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

In [5]:
self.percent_variance = 0.99
self.normalize_dimensions = False
self.save_cca_transforms = False

# Function

In [6]:
# Normalize
# Set `self.nrepresentations_d`, "normalized representations". Call it this regardless of if it's actually normalized
self.nrepresentations_d = {}
if self.normalize_dimensions:
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means = t.mean(0, keepdim=True)
        stdevs = t.std(0, keepdim=True)

        self.nrepresentations_d[network] = (t - means) / stdevs
else:
    self.nrepresentations_d = self.representations_d

In [7]:
# Set `whitening_transforms`, `pca_directions`
# {network: whitening_tensor}
whitening_transforms = {} 
pca_directions = {} 
for network in tqdm(self.nrepresentations_d, desc='pca'):
    X = self.nrepresentations_d[network]
    U, S, V = torch.svd(X)

    var_sums = torch.cumsum(S.pow(2), 0)
    wanted_size = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()

    print('For network', network, 'wanted size is', wanted_size)

    if self.save_cca_transforms:
        whitening_transform = torch.mm(V, torch.diag(1/S))
        whitening_transforms[network] = whitening_transform[:, :wanted_size]
    
    pca_directions[network] = U[:, :wanted_size]

pca: 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 94.21it/s]

For network foo wanted size is 98
For network bar wanted size is 78
For network baz wanted size is 68





In [8]:
# Set 
# `self.transforms`: {network: {other: svcca_transform}}
# `self.corrs`: {network: {other: canonical_corrs}}
# `self.pw_alignments`: {network: {other: unnormalized pw weights}}
# `self.pw_corrs`: {network: {other: pw_alignments*corrs}}
# `self.sv_similarities`: {network: {other: svcca_similarities}}
# `self.pw_similarities`: {network: {other: pwcca_similarities}}
self.transforms = {network: {} for network in self.nrepresentations_d}
self.corrs = {network: {} for network in self.nrepresentations_d}
self.pw_alignments = {network: {} for network in self.nrepresentations_d}
self.pw_corrs = {network: {} for network in self.nrepresentations_d}
self.sv_similarities = {network: {} for network in self.nrepresentations_d}
self.pw_similarities = {network: {} for network in self.nrepresentations_d}

# For network in ... loop

In [9]:
# arbitrary loop params
network = f1
other_network = f2

In [10]:
X = pca_directions[network]
Y = pca_directions[other_network]

In [11]:
# Perform SVD for CCA.
# u s vt = Xt Y
# s = ut Xt Y v
u, s, v = torch.svd(torch.mm(X.t(), Y))

# `self.transforms`, `self.corrs`, `self.sv_similarities`
if self.save_cca_transforms:
    self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
    self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)

self.corrs[network][other_network] = s
self.corrs[other_network][network] = s

self.sv_similarities[network][other_network] = s.mean().item()
self.sv_similarities[other_network][network] = s.mean().item()

In [12]:
# Compute `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`. 
# This is not symmetric

# For X
H = torch.mm(X, u)
Z = self.representations_d[network]
align = torch.abs(torch.mm(H.t(), Z))
a = torch.sum(align, dim=1, keepdim=False)
self.pw_alignments[network][other_network] = a
self.pw_corrs[network][other_network] = s*a
self.pw_similarities[network][other_network] = (torch.sum(s*a)/torch.sum(a)).item()

# For Y
H = torch.mm(Y, v)
Z = self.representations_d[other_network]
align = torch.abs(torch.mm(H.t(), Z))
a = torch.sum(align, dim=1, keepdim=False)
self.pw_alignments[other_network][network] = a
self.pw_corrs[other_network][network] = s*a
self.pw_similarities[other_network][network] = (torch.sum(s*a)/torch.sum(a)).item()

In [15]:
# full loop
for network, other_network in tqdm(p(self.nrepresentations_d,
                                     self.nrepresentations_d),
                                   desc='cca',
                                   total=len(self.nrepresentations_d)**2):

    if network == other_network:
        continue

    if other_network in self.transforms[network]: 
        continue
    
    X = pca_directions[network]
    Y = pca_directions[other_network]

    # Perform SVD for CCA.
    # u s vt = Xt Y
    # s = ut Xt Y v
    u, s, v = torch.svd(torch.mm(X.t(), Y))

    # `self.transforms`, `self.corrs`, `self.sv_similarities`
    if self.save_cca_transforms:
        self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
        self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)

    self.corrs[network][other_network] = s
    self.corrs[other_network][network] = s

    self.sv_similarities[network][other_network] = s.mean().item()
    self.sv_similarities[other_network][network] = s.mean().item()

    # Compute `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`. 
    # This is not symmetric

    # For X
    H = torch.mm(X, u)
    Z = self.representations_d[network]
    align = torch.abs(torch.mm(H.t(), Z))
    a = torch.sum(align, dim=1, keepdim=False)
    self.pw_alignments[network][other_network] = a
    self.pw_corrs[network][other_network] = s*a
    self.pw_similarities[network][other_network] = (torch.sum(s*a)/torch.sum(a)).item()

    # For Y
    H = torch.mm(Y, v)
    Z = self.representations_d[other_network]
    align = torch.abs(torch.mm(H.t(), Z))
    a = torch.sum(align, dim=1, keepdim=False)
    self.pw_alignments[other_network][network] = a
    self.pw_corrs[other_network][network] = s*a
    self.pw_similarities[other_network][network] = (torch.sum(s*a)/torch.sum(a)).item()

cca: 100%|████████████████████████████████████| 9/9 [00:00<00:00, 550.41it/s]


In [16]:
self.pw_similarities

{'foo': {'bar': 0.27806973457336426, 'baz': 0.28410664200782776},
 'bar': {'foo': 0.2766464352607727, 'baz': 0.24689751863479614},
 'baz': {'foo': 0.2839832901954651, 'bar': 0.24695754051208496}}

# Full function

In [6]:
def compute_correlations(self):
    # Normalize
    # Set `self.nrepresentations_d`, "normalized representations". Call it this regardless of if it's actually "normalized"
    self.nrepresentations_d = {}
    if self.normalize_dimensions:
        for network in tqdm(self.representations_d, desc='mu, sigma'):
            t = self.representations_d[network]
            means = t.mean(0, keepdim=True)
            stdevs = t.std(0, keepdim=True)

            self.nrepresentations_d[network] = (t - means) / stdevs
    else:
        self.nrepresentations_d = self.representations_d

    # Set `whitening_transforms`, `pca_directions`
    # {network: whitening_tensor}
    whitening_transforms = {} 
    pca_directions = {} 
    for network in tqdm(self.nrepresentations_d, desc='pca'):
        X = self.nrepresentations_d[network]
        U, S, V = torch.svd(X)

        var_sums = torch.cumsum(S.pow(2), 0)
        wanted_size = torch.sum(var_sums.lt(var_sums[-1] * self.percent_variance)).item()

        print('For network', network, 'wanted size is', wanted_size)

        if self.save_cca_transforms:
            whitening_transform = torch.mm(V, torch.diag(1/S))
            whitening_transforms[network] = whitening_transform[:, :wanted_size]

        pca_directions[network] = U[:, :wanted_size]

    # Set 
    # `self.transforms`: {network: {other: svcca_transform}}
    # `self.corrs`: {network: {other: canonical_corrs}}
    # `self.pw_alignments`: {network: {other: unnormalized pw weights}}
    # `self.pw_corrs`: {network: {other: pw_alignments*corrs}}
    # `self.sv_similarities`: {network: {other: svcca_similarities}}
    # `self.pw_similarities`: {network: {other: pwcca_similarities}}
    self.transforms = {network: {} for network in self.nrepresentations_d}
    self.corrs = {network: {} for network in self.nrepresentations_d}
    self.pw_alignments = {network: {} for network in self.nrepresentations_d}
    self.pw_corrs = {network: {} for network in self.nrepresentations_d}
    self.sv_similarities = {network: {} for network in self.nrepresentations_d}
    self.pw_similarities = {network: {} for network in self.nrepresentations_d}
    for network, other_network in tqdm(p(self.nrepresentations_d,
                                         self.nrepresentations_d),
                                       desc='cca',
                                       total=len(self.nrepresentations_d)**2):

        if network == other_network:
            continue

        if other_network in self.transforms[network]: 
            continue

        X = pca_directions[network]
        Y = pca_directions[other_network]

        # Perform SVD for CCA.
        # u s vt = Xt Y
        # s = ut Xt Y v
        u, s, v = torch.svd(torch.mm(X.t(), Y))

        # `self.transforms`, `self.corrs`, `self.sv_similarities`
        if self.save_cca_transforms:
            self.transforms[network][other_network] = torch.mm(whitening_transforms[network], u)
            self.transforms[other_network][network] = torch.mm(whitening_transforms[other_network], v)

        self.corrs[network][other_network] = s
        self.corrs[other_network][network] = s

        self.sv_similarities[network][other_network] = s.mean().item()
        self.sv_similarities[other_network][network] = s.mean().item()

        # Compute `self.pw_alignments`, `self.pw_corrs`, `self.pw_similarities`. 
        # This is not symmetric

        # For X
        H = torch.mm(X, u)
        Z = self.representations_d[network]
        align = torch.abs(torch.mm(H.t(), Z))
        a = torch.sum(align, dim=1, keepdim=False)
        self.pw_alignments[network][other_network] = a
        self.pw_corrs[network][other_network] = s*a
        self.pw_similarities[network][other_network] = (torch.sum(s*a)/torch.sum(a)).item()

        # For Y
        H = torch.mm(Y, v)
        Z = self.representations_d[other_network]
        align = torch.abs(torch.mm(H.t(), Z))
        a = torch.sum(align, dim=1, keepdim=False)
        self.pw_alignments[other_network][network] = a
        self.pw_corrs[other_network][network] = s*a
        self.pw_similarities[other_network][network] = (torch.sum(s*a)/torch.sum(a)).item()

In [7]:
compute_correlations(self)

pca: 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 90.16it/s]
cca: 100%|████████████████████████████████████| 9/9 [00:00<00:00, 521.14it/s]

For network foo wanted size is 97
For network bar wanted size is 78
For network baz wanted size is 68





In [8]:
self.pw_similarities

{'foo': {'bar': 0.27744582295417786, 'baz': 0.28650641441345215},
 'bar': {'foo': 0.27734071016311646, 'baz': 0.24014469981193542},
 'baz': {'foo': 0.2855011522769928, 'bar': 0.24075396358966827}}