In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p

# Load fake activations

In [2]:
class A:
    pass
self = A()

In [3]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [4]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

# Compute correlations

In [1]:
def compute_correlations(self, op):
    """
    Set `self.correlations`, `self.clusters`, `self.neuron_sort`. 
    """

    # Set `means_d`, `stdevs_d`
    means_d = {}
    stdevs_d = {}
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]

        means_d[network] = t.mean(0, keepdim=True)
        stdevs_d[network] = (t - means_d[network]).pow(2).mean(0, keepdim=True).pow(0.5)

    # Set `self.correlations`
    # {network: {other: tensor}}
    self.correlations = {network: {} for network in self.representations_d}
    num_words = list(self.representations_d.values())[0].size()[0] # TO DO: make more elegant

    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                       desc='correlate',
                                       total=len(self.representations_d)**2):
        if network == other_network:
            continue

        if other_network in self.correlations[network].keys(): # TO DO: optimize?
            continue

        t1 = self.representations_d[network] # "tensor"
        t2 = self.representations_d[other_network] 
        m1 = means_d[network] # "means"
        m2 = means_d[other_network]
        s1 = stdevs_d[network] # "stdevs"
        s2 = stdevs_d[other_network]

        covariance = (torch.mm(t1.t(), t2) / num_words # E[ab]
                      - torch.mm(m1.t(), m2)) # E[a]E[b]
        correlation = covariance / torch.mm(s1.t(), s2)

        correlation = correlation.cpu().numpy()
        self.correlations[network][other_network] = correlation
        self.correlations[other_network][network] = correlation.T

    # Set `self.clusters`
    # {network: {neuron: {other: other_neuron}}}
    self.clusters = {network: {} for network in self.representations_d} 
    for network in tqdm(self.representations_d, desc='self.clusters',
                        total=len(self.representations_d)):
        for neuron in range(self.num_neurons_d[network]): 
            self.clusters[network][neuron] = {
                other : max(range(self.num_neurons_d[other]),
                            key = lambda i: abs(self.correlations[network][other][neuron][i])) 
                 for other in self.correlations[network]
            }

    # Set `self.neuron_sort`
    # {network, sorted_list}
    self.neuron_sort = {} 
    # Sort neurons by worst (or best) best correlation with another neuron
    # in another network.
    for network in tqdm(self.representations_d, desc='annotation'):
        self.neuron_sort[network] = sorted(
            range(self.num_neurons_d[network]),
            key = lambda i : op(
                abs(self.correlations[network][other][i][self.clusters[network][i][other]])
                for other in self.clusters[network][i]),
            reverse=True
        )

In [6]:
compute_correlations(self, max)

mu, sigma: 100%|█| 3/3 [00:00<00:00, 571.15it/s]
correlate: 100%|█| 9/9 [00:00<00:00, 1742.95it/s]
self.clusters: 100%|█| 3/3 [00:00<00:00, 59.69it/s]
annotation: 100%|█| 3/3 [00:00<00:00, 2648.48it/s]


In [7]:
self.neuron_sort

{'foo': [7,
  14,
  16,
  28,
  3,
  69,
  65,
  89,
  25,
  1,
  31,
  18,
  77,
  57,
  23,
  86,
  76,
  46,
  85,
  41,
  48,
  58,
  87,
  43,
  42,
  33,
  36,
  94,
  30,
  67,
  26,
  55,
  78,
  2,
  34,
  63,
  4,
  88,
  47,
  10,
  5,
  8,
  83,
  49,
  91,
  38,
  73,
  22,
  53,
  71,
  20,
  84,
  29,
  61,
  66,
  40,
  51,
  17,
  98,
  35,
  52,
  64,
  81,
  74,
  79,
  44,
  72,
  97,
  95,
  21,
  9,
  82,
  99,
  13,
  75,
  32,
  45,
  80,
  19,
  15,
  59,
  12,
  60,
  37,
  27,
  93,
  50,
  68,
  56,
  39,
  24,
  11,
  0,
  96,
  54,
  70,
  92,
  6,
  90,
  62],
 'bar': [69,
  64,
  36,
  44,
  38,
  56,
  1,
  19,
  12,
  16,
  78,
  27,
  2,
  22,
  20,
  60,
  8,
  34,
  33,
  5,
  66,
  17,
  35,
  74,
  31,
  41,
  47,
  0,
  50,
  70,
  63,
  77,
  62,
  79,
  59,
  37,
  26,
  55,
  39,
  45,
  61,
  75,
  13,
  40,
  72,
  71,
  10,
  53,
  6,
  46,
  15,
  24,
  4,
  28,
  7,
  43,
  51,
  58,
  65,
  18,
  49,
  30,
  54,
  21,
  67,
  73,
  76,
 

# Build `write_correlations`

In [8]:
# param
output_file = "temp"

In [11]:
# loop variable
network = f1

In [12]:
self.neuron_notated_sort = {}

In [13]:
self.neuron_notated_sort[network] = [
        (
            neuron, 
            {
                '%s:%d' % (other, self.clusters[network][neuron][other],):
                float(self.correlations[network][other][neuron][self.clusters[network][neuron][other]])
                for other in self.clusters[network][neuron]
            }
        )
        for neuron in self.neuron_sort[network]
    ]

In [14]:
self.neuron_notated_sort

{'foo': [(7, {'bar:69': 0.1475636512041092, 'baz:25': -0.07963211089372635}),
  (14, {'bar:64': 0.11425137519836426, 'baz:26': -0.10484716296195984}),
  (16, {'bar:27': -0.056227996945381165, 'baz:6': -0.11379553377628326}),
  (28, {'bar:36': 0.11259262263774872, 'baz:4': 0.08901067078113556}),
  (3, {'bar:51': 0.08520984649658203, 'baz:41': -0.11201736330986023}),
  (69, {'bar:38': 0.10861208289861679, 'baz:4': -0.0792410597205162}),
  (65, {'bar:24': -0.07739735394716263, 'baz:63': 0.10698217153549194}),
  (89, {'bar:28': -0.08586898446083069, 'baz:13': 0.10644762217998505}),
  (25, {'bar:10': 0.0888514369726181, 'baz:68': -0.10611695796251297}),
  (1, {'bar:64': 0.08616607636213303, 'baz:16': 0.10403265804052353}),
  (31, {'bar:30': -0.08269640058279037, 'baz:26': -0.1037655919790268}),
  (18, {'bar:9': 0.07295471429824829, 'baz:22': -0.10354381054639816}),
  (77, {'bar:23': 0.08005933463573456, 'baz:47': 0.10285221040248871}),
  (57, {'bar:1': -0.10206933319568634, 'baz:60': -0.080

In [16]:
json.dumps(self.neuron_notated_sort, indent=4)

'{\n    "foo": [\n        [\n            7,\n            {\n                "bar:69": 0.1475636512041092,\n                "baz:25": -0.07963211089372635\n            }\n        ],\n        [\n            14,\n            {\n                "bar:64": 0.11425137519836426,\n                "baz:26": -0.10484716296195984\n            }\n        ],\n        [\n            16,\n            {\n                "bar:27": -0.056227996945381165,\n                "baz:6": -0.11379553377628326\n            }\n        ],\n        [\n            28,\n            {\n                "bar:36": 0.11259262263774872,\n                "baz:4": 0.08901067078113556\n            }\n        ],\n        [\n            3,\n            {\n                "bar:51": 0.08520984649658203,\n                "baz:41": -0.11201736330986023\n            }\n        ],\n        [\n            69,\n            {\n                "bar:38": 0.10861208289861679,\n                "baz:4": -0.0792410597205162\n            }\n    

In [18]:
json.dump(self.neuron_notated_sort, open(output_file, "w"), indent=4)

In [19]:
# full
self.neuron_notated_sort = {}
for network in tqdm(self.representations_d, desc='write'):
    self.neuron_notated_sort[network] = [
            (
                neuron, 
                {
                    '%s:%d' % (other, self.clusters[network][neuron][other],):
                    float(self.correlations[network][other][neuron][self.clusters[network][neuron][other]])
                    for other in self.clusters[network][neuron]
                }
            )
            for neuron in self.neuron_sort[network]
        ]

json.dump(self.neuron_notated_sort, open(output_file, "w"), indent=4)

write: 100%|█| 3/3 [00:00<00:00, 1238.96it/s]


# Final function

In [20]:
def write_correlations(self, output_file):
    """
    Create `self.neuron_notated_sort`, and write it to output_file. 
    """
    
    self.neuron_notated_sort = {}
    for network in tqdm(self.representations_d, desc='write'):
        self.neuron_notated_sort[network] = [
                (
                    neuron, 
                    {
                        '%s:%d' % (other, self.clusters[network][neuron][other],):
                        float(self.correlations[network][other][neuron][self.clusters[network][neuron][other]])
                        for other in self.clusters[network][neuron]
                    }
                )
                for neuron in self.neuron_sort[network]
            ]

    json.dump(self.neuron_notated_sort, open(output_file, "w"), indent=4)

In [21]:
write_correlations(self, "temp")

write: 100%|█| 3/3 [00:00<00:00, 628.23it/s]
