In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    """Empty attribute container used as a stand-in object in this notebook."""


# Module-level `self`, so the cells below can read and write `self.<attr>`
# the same way the body of `compute_correlations(self)` does.
self = A()

In [3]:
# Per-network bookkeeping, keyed by representation file name.
self.num_neurons_d = dict()      # {fname: int}
self.representations_d = dict()  # {fname: tensor}
representation_files = ["foo", "bar", "baz"]
f1, f2, f3 = representation_files

In [4]:
# Initialize `num_neurons_d` and `representations_d` with fake data.
# Seed torch's RNG first so the fake activations, and every number derived
# from them in the cells below, are reproducible after a kernel restart.
torch.manual_seed(0)

n1, n2, n3 = 100, 80, 70  # neurons per fake network
nword = 1000              # rows (observations) shared by all fake networks
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

### Normalize

In [5]:
# Standardise each network's activations per neuron (column), recording the
# statistics that were used in `means_d` / `stdevs_d`.
means_d, stdevs_d = {}, {}

for network, activations in tqdm(self.representations_d.items(), desc='mu, sigma'):
    mu = activations.mean(0, keepdim=True)
    centered = activations - mu
    # Population standard deviation: root of the mean squared deviation.
    sigma = centered.pow(2).mean(0, keepdim=True).pow(0.5)

    means_d[network], stdevs_d[network] = mu, sigma
    # Replacing the value of an existing key is safe while iterating `.items()`.
    self.representations_d[network] = centered / sigma

mu, sigma: 100%|█| 3/3 [00:00<00:00, 503.20it/s]


### Set `self.errors` loop

In [6]:
# One empty {other: error_tensor} mapping per network.
self.errors = dict((network, dict()) for network in self.representations_d)

In [7]:
# Loop variables, pinned to a single pair so the loop body can be stepped
# through by hand in the next few cells.
network, other_network = f1, f2

In [8]:
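# Guard used by the full loop further down; kept commented out here because
# `continue` is only valid inside a loop:
# if network == other_network:
#     continue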

In [9]:
# Regressors (X) come from the other network, targets (Y) from this one.
X, Y = (
    self.representations_d[name].cpu().numpy()
    for name in (other_network, network)
)

In [10]:
# Solve with ordinary least squares.
# `np.linalg.lstsq` returns an EMPTY residual array when X is rank-deficient
# or has no more rows than columns, so the per-neuron sum of squared errors
# is computed explicitly from the fitted coefficients instead.
error = np.square(Y - X @ np.linalg.lstsq(X, Y, rcond=None)[0]).sum(axis=0)
error = torch.from_numpy(error)

In [11]:
# File the residuals under this (network, other_network) pair.
self.errors[network].update({other_network: error})

In [12]:
# Set `self.errors`
# {network: {other: error_tensor}}
for network, other_network in tqdm(p(self.representations_d, self.representations_d), desc='correlate', total=len(self.representations_d)**2):
    if network == other_network:
        continue

    # Try to predict this network given the other one
    X = self.representations_d[other_network].cpu().numpy()
    Y = self.representations_d[network].cpu().numpy()

    # Solve with ordinary least squares.
    # TODO: don't use numpy, or at least use CUDA. Possibilities are torch
    # (torch.svd or similar) or another library (cupy).
    solution = np.linalg.lstsq(X, Y, rcond=None)[0]
    # `lstsq`'s own residual output is an empty array when X is rank-deficient
    # or has no more rows than columns; computing the per-neuron sum of squared
    # errors from the solution always yields one value per column of Y.
    error = np.square(Y - X @ solution).sum(axis=0)
    error = torch.from_numpy(error)

    self.errors[network][other_network] = error

correlate: 100%|█| 9/9 [00:00<00:00, 267.99it/s]

(100,)
(100,)
(80,)
(80,)
(70,)
(70,)





In [13]:
# Display the nested {network: {other: error_tensor}} mapping.
self.errors

{'foo': {'bar': tensor([909.0477, 931.8926, 925.4075, 911.4684, 898.3582, 883.4199, 910.2272,
          917.4545, 908.2744, 916.7164, 907.1227, 917.1426, 913.2227, 925.6381,
          911.5452, 933.3781, 920.3215, 921.1205, 909.6570, 910.8798, 932.5266,
          930.8191, 925.1227, 928.0283, 933.5466, 926.8748, 917.5969, 918.5745,
          918.3296, 903.1541, 911.8093, 924.6309, 927.8109, 933.6752, 898.6085,
          933.2340, 904.3768, 914.6415, 922.4907, 924.1975, 909.6744, 946.7028,
          892.7957, 910.7985, 904.0025, 931.7461, 905.6818, 915.7996, 934.2261,
          916.9269, 916.3373, 922.0054, 877.2310, 902.3813, 930.7200, 878.0267,
          934.4856, 899.2604, 875.0659, 915.6265, 926.8925, 915.1988, 922.2410,
          915.5558, 938.5184, 938.1669, 903.6459, 893.1216, 930.3840, 904.7031,
          907.9333, 924.9750, 938.7953, 930.3270, 939.5958, 914.6122, 923.0024,
          928.2825, 908.7150, 919.4897, 936.3714, 941.4680, 916.0426, 911.4482,
          921.9346, 918.79

### Set `self.neuron_sort`

In [14]:
# Will hold {network: list of neuron indices}.
self.neuron_sort = dict()

In [15]:
# Loop variable, pinned to `f1` so a single iteration of the sorting loop can
# be tried by hand in the next cell.
network = f1

In [16]:
# Neuron indices of `network`, ascending by each neuron's largest regression
# error against any other network.
sorted(
    range(self.num_neurons_d[network]),
    key=lambda neuron: max(
        err[neuron] for err in self.errors[network].values()
    ),
)

[83,
 53,
 95,
 98,
 82,
 8,
 26,
 27,
 49,
 30,
 51,
 66,
 96,
 42,
 61,
 4,
 71,
 22,
 3,
 59,
 6,
 13,
 67,
 40,
 39,
 25,
 14,
 84,
 32,
 23,
 50,
 31,
 52,
 7,
 34,
 60,
 73,
 93,
 91,
 54,
 19,
 75,
 63,
 87,
 45,
 1,
 20,
 37,
 68,
 5,
 79,
 15,
 33,
 56,
 28,
 10,
 69,
 62,
 80,
 17,
 86,
 18,
 58,
 29,
 47,
 88,
 65,
 90,
 64,
 72,
 57,
 21,
 74,
 99,
 35,
 36,
 0,
 44,
 89,
 77,
 12,
 48,
 85,
 55,
 97,
 24,
 9,
 2,
 41,
 16,
 11,
 94,
 78,
 43,
 46,
 92,
 38,
 81,
 76,
 70]

In [17]:
# Build `self.neuron_sort`: {network: sorted_list}
self.neuron_sort = {}
# Neurons are ordered ascending by their worst case, i.e. the highest
# regression error obtained when predicting them from any other network.
for network in tqdm(self.representations_d, desc='annotation'):
    self.neuron_sort[network] = list(range(self.num_neurons_d[network]))
    # In-place stable sort; the key looks only at `self.errors`.
    self.neuron_sort[network].sort(
        key=lambda neuron: max(
            err[neuron] for err in self.errors[network].values()
        )
    )

annotation: 100%|█| 3/3 [00:00<00:00, 227.92it/s]


In [19]:
# Display the {network: sorted neuron indices} mapping.
self.neuron_sort

{'foo': [83,
  53,
  95,
  98,
  82,
  8,
  26,
  27,
  49,
  30,
  51,
  66,
  96,
  42,
  61,
  4,
  71,
  22,
  3,
  59,
  6,
  13,
  67,
  40,
  39,
  25,
  14,
  84,
  32,
  23,
  50,
  31,
  52,
  7,
  34,
  60,
  73,
  93,
  91,
  54,
  19,
  75,
  63,
  87,
  45,
  1,
  20,
  37,
  68,
  5,
  79,
  15,
  33,
  56,
  28,
  10,
  69,
  62,
  80,
  17,
  86,
  18,
  58,
  29,
  47,
  88,
  65,
  90,
  64,
  72,
  57,
  21,
  74,
  99,
  35,
  36,
  0,
  44,
  89,
  77,
  12,
  48,
  85,
  55,
  97,
  24,
  9,
  2,
  41,
  16,
  11,
  94,
  78,
  43,
  46,
  92,
  38,
  81,
  76,
  70],
 'bar': [59,
  8,
  9,
  31,
  47,
  25,
  45,
  27,
  23,
  74,
  16,
  43,
  0,
  49,
  55,
  22,
  44,
  69,
  46,
  15,
  1,
  10,
  53,
  60,
  37,
  62,
  50,
  66,
  39,
  65,
  71,
  52,
  78,
  35,
  56,
  28,
  73,
  67,
  19,
  26,
  13,
  17,
  57,
  61,
  11,
  4,
  40,
  5,
  48,
  29,
  64,
  72,
  20,
  21,
  38,
  58,
  3,
  6,
  32,
  42,
  30,
  34,
  68,
  7,
  76,
  79,
  51,
  

In [24]:
# Sanity check: taken in `neuron_sort` order, the worst-case error of `f1`'s
# neurons (element-wise max over the two other networks) must be non-decreasing.
# The operands are torch tensors, so use torch's element-wise max, and assert
# the ordering rather than relying on reading the printout.
worst_sorted = torch.max(self.errors[f1][f2][self.neuron_sort[f1]],
                         self.errors[f1][f3][self.neuron_sort[f1]])
assert bool((worst_sorted[1:] >= worst_sorted[:-1]).all()), \
    "neuron_sort[f1] is not ordered by worst-case regression error"
worst_sorted

tensor([911.4482, 912.8023, 915.1296, 915.5312, 916.0426, 917.3156, 918.0559,
        918.5745, 920.3403, 920.4220, 922.0054, 922.8049, 922.9943, 923.2622,
        923.3635, 924.7756, 924.9750, 925.1227, 925.1437, 925.2809, 925.5129,
        925.6381, 926.0693, 926.4814, 926.6767, 926.8748, 926.9313, 927.1315,
        927.8109, 928.0283, 928.4019, 928.7456, 929.2178, 929.3124, 929.4144,
        929.4376, 930.3270, 930.4148, 930.6184, 930.7200, 931.0573, 931.0923,
        931.1027, 931.5375, 931.7461, 931.8926, 932.5266, 932.5573, 932.5692,
        932.5891, 933.1299, 933.3781, 933.6752, 934.4856, 934.5899, 935.4150,
        935.8964, 936.0991, 936.3714, 936.5215, 937.1719, 937.1955, 937.2303,
        937.4166, 937.5488, 937.6421, 938.1669, 938.1930, 938.5184, 938.7953,
        938.8829, 939.0637, 939.5958, 939.9243, 940.0334, 940.4062, 940.9043,
        942.8909, 942.9031, 943.3944, 943.7775, 944.0373, 944.0889, 944.2590,
        944.6002, 944.7104, 944.7645, 945.0562, 946.7028, 947.21

# Create final function

In [5]:
def compute_correlations(self):
    """Rank each network's neurons by how well the other networks linearly predict them.

    Reads:
        self.representations_d: {network: tensor}. Each tensor is treated as
            2-D with rows as observations and columns as neurons; all networks
            must have the same number of rows for the regressions to be defined.
        self.num_neurons_d: {network: int}, the number of neurons to rank.

    Writes:
        self.representations_d: each tensor is replaced by its column-wise
            standardised version (mean 0, population std 1).
        self.errors: rebuilt from scratch as {network: {other: 1-D tensor}},
            where entry ``[network][other][i]`` is the sum of squared residuals
            of the least-squares fit predicting neuron ``i`` of ``network``
            from all neurons of ``other``.
        self.neuron_sort: {network: list of neuron indices}, ascending by the
            neuron's largest error over all other networks. With a single
            network there is nothing to compare against and the natural order
            ``0..n-1`` is stored.

    Returns:
        None.
    """
    # Normalize every neuron (column) to mean 0, std 1.
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means = t.mean(0, keepdim=True)
        centered = t - means
        stdevs = centered.pow(2).mean(0, keepdim=True).pow(0.5)
        # A constant neuron has std 0 and would turn into NaN, which poisons
        # every regression it takes part in. Divide by 1 instead so the column
        # stays all-zero. NOTE: such a neuron then gets error 0 as a target.
        stdevs = torch.where(stdevs > 0, stdevs, torch.ones_like(stdevs))
        self.representations_d[network] = centered / stdevs

    # Set `self.errors`
    # {network: {other: error_tensor}}
    # Built here so the function does not depend on a pre-initialised attribute.
    self.errors = {network: {} for network in self.representations_d}

    # Convert each network to numpy once instead of once per pair.
    arrays = {
        network: representation.cpu().numpy()
        for network, representation in self.representations_d.items()
    }

    for network, other_network in tqdm(p(arrays, arrays), desc='correlate', total=len(arrays)**2):
        if network == other_network:
            continue

        # Try to predict this network given the other one
        X = arrays[other_network]
        Y = arrays[network]

        # Solve with ordinary least squares.
        # TODO: don't use numpy, or at least use CUDA. Possibilities are torch
        # (torch.svd or similar) or another library (cupy).
        coefficients = np.linalg.lstsq(X, Y, rcond=None)[0]
        # `lstsq`'s own residual output is an empty array when X is
        # rank-deficient or has no more rows than columns, so compute the
        # per-neuron sum of squared errors explicitly.
        error = np.square(Y - X @ coefficients).sum(axis=0)

        self.errors[network][other_network] = torch.from_numpy(error)

    # Set `self.neuron_sort`
    # {network: sorted_list}
    self.neuron_sort = {}
    # Sort neurons by worst correlation (highest regression error) with another network
    for network in tqdm(self.representations_d, desc='annotation'):
        neurons = range(self.num_neurons_d[network])
        per_other = list(self.errors[network].values())
        if not per_other:
            # Only one network was given: no regression errors exist.
            self.neuron_sort[network] = list(neurons)
            continue

        # Element-wise max over the other networks, computed once and turned
        # into plain floats so the sort key is a cheap list lookup.
        worst = torch.stack(per_other).max(dim=0)[0].tolist()
        self.neuron_sort[network] = sorted(neurons, key=worst.__getitem__)

mu, sigma: 100%|█| 3/3 [00:00<00:00, 503.20it/s]
