In [6]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [7]:
class A:
    pass
self = A()

In [8]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [9]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

In [14]:
d = torch.device('cuda')
torch.cuda.is_available()

False

In [10]:
self.device = torch.device('cpu')
self.op = min

In [11]:
torch.cuda.memory_allocated()

AssertionError: 
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx

# 

In [None]:
def use_gpu()

### Set `means_d`, `stdevs_d` loop

In [75]:
# full
# Set `means_d`, `stdevs_d`
means_d = {}
stdevs_d = {}
for network in tqdm(self.representations_d, desc='mu, sigma'):
    t = self.representations_d[network]

    means_d[network] = t.mean(0, keepdim=True)
    stdevs_d[network] = (t - means_d[network]).pow(2).mean(0, keepdim=True).pow(0.5)

mu, sigma: 100%|██████████████████████████████| 3/3 [00:00<00:00, 979.67it/s]


### Set `self.corrs`, `self.similarities`, `self.pairs`

In [76]:
self.corrs = {network: {} for network in
                     self.representations_d}
self.similarities = {network: {} for network in
                     self.representations_d}
self.pairs = {network: {} for network in
                     self.representations_d}
num_words = next(iter(self.representations_d.values())).size()[0]

In [77]:
network = f1
other_network = f2

In [78]:
device = self.device

t1 = self.representations_d[network].to(device) # "tensor"
t2 = self.representations_d[other_network].to(device)
m1 = means_d[network].to(device) # "means"
m2 = means_d[other_network].to(device)
s1 = stdevs_d[network].to(device) # "stdevs"
s2 = stdevs_d[other_network].to(device)

covariance = (torch.mm(t1.t(), t2) / num_words # E[ab]
              - torch.mm(m1.t(), m2)) # E[a]E[b]
correlation = covariance / torch.mm(s1.t(), s2)

In [79]:
correlation = correlation.cpu().numpy()
correlation = np.abs(correlation)

In [80]:
self.corrs[network][other_network] = correlation.max(axis=1)
self.corrs[other_network][network] = correlation.max(axis=0)

In [81]:
self.similarities[network][other_network] = self.corrs[network][other_network].mean()
self.similarities[other_network][network] = self.corrs[other_network][network].mean()

In [82]:
self.pairs[network][other_network] = correlation.argmax(axis=1)
self.pairs[other_network][network] = correlation.argmax(axis=0)

In [86]:
# Full
# Set `self.corrs` : {network: {other: [corr]}}
# Set `self.pairs` : {network: {other: [pair]}}
# pair is index of neuron in other network
self.corrs = {network: {} for network in
                     self.representations_d}
self.pairs = {network: {} for network in
                     self.representations_d}
num_words = next(iter(self.representations_d.values())).size()[0]
for network, other_network in tqdm(p(self.representations_d,
                                     self.representations_d),
                                     desc='correlate',
                                     total=len(self.representations_d)**2):
    if network == other_network:
        continue

    if other_network in self.corrs[network]: 
        continue

    device = self.device

    t1 = self.representations_d[network].to(device) # "tensor"
    t2 = self.representations_d[other_network].to(device)
    m1 = means_d[network].to(device) # "means"
    m2 = means_d[other_network].to(device)
    s1 = stdevs_d[network].to(device) # "stdevs"
    s2 = stdevs_d[other_network].to(device)

    covariance = (torch.mm(t1.t(), t2) / num_words # E[ab]
                  - torch.mm(m1.t(), m2)) # E[a]E[b]
    correlation = covariance / torch.mm(s1.t(), s2)
    correlation = correlation.cpu().numpy()
    correlation = np.abs(correlation)

    self.corrs[network][other_network] = correlation.max(axis=1)
    self.corrs[other_network][network] = correlation.max(axis=0)
    
    self.similarities[network][other_network] = self.corrs[network][other_network].mean()
    self.similarities[other_network][network] = self.corrs[other_network][network].mean()

    self.pairs[network][other_network] = correlation.argmax(axis=1)
    self.pairs[other_network][network] = correlation.argmax(axis=0)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 3664.57it/s]


### Set `self.neuron_sort`, `self.neuron_notated_sort`

In [84]:
# full
# Set `self.neuron_sort` : {network, sorted_list}
# Set `self.neuron_notated_sort` : {network: [(neuron, {other: (corr, pair)})]}
self.neuron_sort = {} 
self.neuron_notated_sort = {}
for network in tqdm(self.representations_d, desc='annotation'):
    self.neuron_sort[network] = sorted(
        range(self.num_neurons_d[network]), 
        key=lambda i: self.op(
            self.corrs[network][other][i] for other in self.corrs[network]
        ), 
        reverse=True,
    )
    self.neuron_notated_sort[network] = [
        (
            neuron,
            {
                other : (
                    self.corrs[network][other][neuron], 
                    self.pairs[network][other][neuron],
                ) 
                for other in self.corrs[network]
            }
        ) 
        for neuron in self.neuron_sort[network]
    ]

annotation: 100%|████████████████████████████| 3/3 [00:00<00:00, 3538.50it/s]


# Full function

In [6]:
def compute_correlations(self):
    # Set `means_d`, `stdevs_d`
    means_d = {}
    stdevs_d = {}
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]

        means_d[network] = t.mean(0, keepdim=True)
        stdevs_d[network] = (t - means_d[network]).pow(2).mean(0, keepdim=True).pow(0.5)

    # Set `self.corrs` : {network: {other: [corr]}}
    # Set `self.pairs` : {network: {other: [pair]}}
    # pair is index of neuron in other network
    # Set `self.similarities` : {network: {other: sim}}
    self.corrs = {network: {} for network in
                         self.representations_d}
    self.pairs = {network: {} for network in
                         self.representations_d}
    self.similarities = {network: {} for network in
                     self.representations_d}
    num_words = next(iter(self.representations_d.values())).size()[0]
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                         desc='correlate',
                                         total=len(self.representations_d)**2):
        if network == other_network:
            continue

        if other_network in self.corrs[network]: 
            continue

        device = self.device

        t1 = self.representations_d[network].to(device) # "tensor"
        t2 = self.representations_d[other_network].to(device)
        m1 = means_d[network].to(device) # "means"
        m2 = means_d[other_network].to(device)
        s1 = stdevs_d[network].to(device) # "stdevs"
        s2 = stdevs_d[other_network].to(device)

        covariance = (torch.mm(t1.t(), t2) / num_words # E[ab]
                      - torch.mm(m1.t(), m2)) # E[a]E[b]
        correlation = covariance / torch.mm(s1.t(), s2)
        correlation = correlation.cpu().numpy()
        correlation = np.abs(correlation)

        self.corrs[network][other_network] = correlation.max(axis=1)
        self.corrs[other_network][network] = correlation.max(axis=0)
        
        self.similarities[network][other_network] = self.corrs[network][other_network].mean()
        self.similarities[other_network][network] = self.corrs[other_network][network].mean()
        
        self.pairs[network][other_network] = correlation.argmax(axis=1)
        self.pairs[other_network][network] = correlation.argmax(axis=0)

    # Set `self.neuron_sort` : {network, sorted_list}
    # Set `self.neuron_notated_sort` : {network: [(neuron, {other: (corr, pair)})]}
    self.neuron_sort = {} 
    self.neuron_notated_sort = {}
    for network in tqdm(self.representations_d, desc='annotation'):
        self.neuron_sort[network] = sorted(
            range(self.num_neurons_d[network]), 
            key=lambda i: self.op(
                self.corrs[network][other][i] for other in self.corrs[network]
            ), 
            reverse=True,
        )
        self.neuron_notated_sort[network] = [
            (
                neuron,
                {
                    other : (
                        self.corrs[network][other][neuron], 
                        self.pairs[network][other][neuron],
                    ) 
                    for other in self.corrs[network]
                }
            ) 
            for neuron in self.neuron_sort[network]
        ]

In [7]:
compute_correlations(self)

mu, sigma: 100%|██████████████████████████████| 3/3 [00:00<00:00, 133.37it/s]
correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 3259.54it/s]
annotation: 100%|████████████████████████████| 3/3 [00:00<00:00, 2172.46it/s]


In [8]:
self.similarities

{'foo': {'bar': 0.08205225, 'baz': 0.08246259},
 'bar': {'foo': 0.085321024, 'baz': 0.08407295},
 'baz': {'foo': 0.08618133, 'bar': 0.08577694}}