In [8]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    pass
self = A()

In [3]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [4]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

# Compute correlations

In [11]:
def compute_correlations(self):
    """
    Set `self.neuron_sort`. 
    """

    # Set `means_d`, `stdevs_d`, normalize to mean 0 std 1
    # Not exactly sure why we compute `means_d`
    means_d = {}
    stdevs_d = {}

    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means = t.mean(0, keepdim=True)
        stdevs = t.std(0, keepdim=True)

        means_d[network] = means
        stdevs_d[network] = stdevs
        self.representations_d[network] = (t - means) / stdevs

    # Set `self.errors`
    # {network: {other: error_tensor}}
    self.errors = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d), desc='correlate',
                                       total=len(self.representations_d)**2):
        if network == other_network:
            continue

        # Try to predict this network given the other one
        X = self.representations_d[other_network].cpu().numpy()
        Y = self.representations_d[network].cpu().numpy()

        # solve with ordinary least squares 
        error = np.linalg.lstsq(X, Y, rcond=None)[1] # TO DO: don't use numpy, or at least use CUDA
        # Possibilities are use torch (torch.svd or smth), or use another library (cupy)

        # Note: what was here previously is very numerically
        # unstable. Linear regression should be performed using either QR or
        # the SVD (which are numerically stable computations). 
        if len(error) == 0:
            raise ValueError
        error = torch.from_numpy(error)

        self.errors[network][other_network] = error

    # Set `self.neuron_sort`
    # {network: sorted_list}
    self.neuron_sort = {}
    # Sort neurons by worst correlation (highest regression error) with another network
    for network in tqdm(self.representations_d, desc='annotation'):
        self.neuron_sort[network] = sorted(
            range(self.num_neurons_d[network]),
            key = lambda i: max(
                self.errors[network][other][i] 
                for other in self.errors[network]
            )
        )

In [12]:
compute_correlations(self)

mu, sigma: 100%|█| 3/3 [00:00<00:00, 451.96it/s]
correlate: 100%|█| 9/9 [00:00<00:00, 256.85it/s]
annotation: 100%|█| 3/3 [00:00<00:00, 240.32it/s]


In [21]:
self.neuron_sort

{'foo': [91,
  28,
  12,
  88,
  49,
  17,
  66,
  60,
  43,
  94,
  40,
  8,
  9,
  71,
  10,
  7,
  72,
  95,
  31,
  37,
  47,
  86,
  24,
  6,
  32,
  20,
  67,
  78,
  15,
  50,
  5,
  96,
  93,
  74,
  1,
  23,
  48,
  76,
  75,
  53,
  59,
  45,
  26,
  30,
  98,
  52,
  36,
  55,
  63,
  83,
  46,
  0,
  29,
  80,
  61,
  58,
  62,
  2,
  77,
  97,
  89,
  85,
  57,
  81,
  22,
  4,
  84,
  79,
  11,
  54,
  41,
  51,
  21,
  35,
  18,
  90,
  92,
  38,
  16,
  42,
  87,
  82,
  56,
  70,
  64,
  13,
  25,
  73,
  99,
  44,
  69,
  65,
  39,
  19,
  14,
  3,
  34,
  33,
  68,
  27],
 'bar': [16,
  15,
  19,
  24,
  55,
  20,
  2,
  7,
  78,
  43,
  40,
  65,
  26,
  67,
  68,
  29,
  70,
  11,
  61,
  1,
  27,
  17,
  21,
  63,
  13,
  3,
  48,
  74,
  50,
  35,
  25,
  34,
  47,
  79,
  57,
  23,
  38,
  64,
  5,
  49,
  66,
  28,
  54,
  42,
  4,
  45,
  10,
  36,
  9,
  53,
  0,
  71,
  8,
  58,
  44,
  33,
  51,
  41,
  73,
  32,
  22,
  62,
  37,
  75,
  76,
  31,
  12,
  

# Build `write_correlations`

In [14]:
# param
output_file = "temp"

In [16]:
# loop variable
network = f1

In [17]:
self.neuron_notated_sort = {}

In [24]:
self.neuron_notated_sort[network] = [
    (
        neuron,
        {
            other: float(self.errors[network][other][neuron])
            for other in self.errors[network]
        }
    )
    for neuron in self.neuron_sort[network]
]

In [25]:
self.neuron_notated_sort

{'foo': [(91, {'bar': 900.4144897460938, 'baz': 906.55517578125}),
  (28, {'bar': 913.35693359375, 'baz': 909.3932495117188}),
  (12, {'bar': 909.6669311523438, 'baz': 915.0709228515625}),
  (88, {'bar': 911.9093627929688, 'baz': 915.1096801757812}),
  (49, {'bar': 897.9951171875, 'baz': 917.8684692382812}),
  (17, {'bar': 910.4146118164062, 'baz': 918.1256103515625}),
  (66, {'bar': 910.5084228515625, 'baz': 919.3346557617188}),
  (60, {'bar': 918.93017578125, 'baz': 919.8286743164062}),
  (43, {'bar': 889.5558471679688, 'baz': 920.1090087890625}),
  (94, {'bar': 916.6554565429688, 'baz': 920.4462890625}),
  (40, {'bar': 921.0355834960938, 'baz': 921.3095703125}),
  (8, {'bar': 915.6966552734375, 'baz': 921.4940185546875}),
  (9, {'bar': 921.8228149414062, 'baz': 913.03271484375}),
  (71, {'bar': 905.7970581054688, 'baz': 922.1661987304688}),
  (10, {'bar': 922.1859741210938, 'baz': 914.6178588867188}),
  (7, {'bar': 922.2753295898438, 'baz': 915.5311279296875}),
  (72, {'bar': 913.70

In [27]:
json.dumps(self.neuron_notated_sort, indent=4)

'{\n    "foo": [\n        [\n            91,\n            {\n                "bar": 900.4144897460938,\n                "baz": 906.55517578125\n            }\n        ],\n        [\n            28,\n            {\n                "bar": 913.35693359375,\n                "baz": 909.3932495117188\n            }\n        ],\n        [\n            12,\n            {\n                "bar": 909.6669311523438,\n                "baz": 915.0709228515625\n            }\n        ],\n        [\n            88,\n            {\n                "bar": 911.9093627929688,\n                "baz": 915.1096801757812\n            }\n        ],\n        [\n            49,\n            {\n                "bar": 897.9951171875,\n                "baz": 917.8684692382812\n            }\n        ],\n        [\n            17,\n            {\n                "bar": 910.4146118164062,\n                "baz": 918.1256103515625\n            }\n        ],\n        [\n            66,\n            {\n                

In [28]:
# full
self.neuron_notated_sort = {}
for network in tqdm(self.representations_d, desc='write'):
    self.neuron_notated_sort[network] = [
        (
            neuron,
            {
                other: float(self.errors[network][other][neuron])
                for other in self.errors[network]
            }
        )
        for neuron in self.neuron_sort[network]
    ]

json.dump(self.neuron_notated_sort, open(output_file, 'w'), indent=4)

write: 100%|█| 3/3 [00:00<00:00, 231.26it/s]


# Final Function

In [29]:
def write_correlations(self, output_file):
    self.neuron_notated_sort = {}
    for network in tqdm(self.representations_d, desc='write'):
        self.neuron_notated_sort[network] = [
            (
                neuron,
                {
                    other: float(self.errors[network][other][neuron])
                    for other in self.errors[network]
                }
            )
            for neuron in self.neuron_sort[network]
        ]

    json.dump(self.neuron_notated_sort, open(output_file, 'w'), indent=4)

In [30]:
write_correlations(self, output_file)

write: 100%|█| 3/3 [00:00<00:00, 243.72it/s]
