In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    pass
self = A()

In [3]:
self.num_neurons_d = {} # {fname, int}
self.representations_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

In [4]:
# initialize `num_neurons_d`, `representations_d` with fake data
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

# Build Function

### Normalize

In [5]:
# Set `means_d`, `stdevs_d`, normalize
means_d = {}
stdevs_d = {}

for network in tqdm(self.representations_d, desc='mu, sigma'):
    t = self.representations_d[network]
    means = t.mean(0, keepdim=True)
    stdevs = t.std(0, keepdim=True)

    means_d[network] = means
    stdevs_d[network] = stdevs
    self.representations_d[network] = (t - means) / stdevs

mu, sigma: 100%|███████████████████████| 3/3 [00:00<00:00, 1189.20it/s]


### Set `self.pred_power` loop

In [8]:
self.pred_power = {network: {} for network in self.representations_d} # "predictive power"

In [9]:
# loop variable
network = f1
other_network = f2

In [8]:
# if network == other_network:
#     continue

In [30]:
X = self.representations_d[other_network]
Y = self.representations_d[network]

In [31]:
U, S, V = torch.svd(X)
UtY = torch.mm(U.t(), Y)
UtY.shape

In [33]:
torch.norm(UtY, dim=0)

tensor([ 7.7806,  7.7896,  8.7355,  7.4009,  9.0961,  7.9801,  9.4549,  8.8950,
         7.6728,  7.4388,  9.1603,  8.0455,  8.9774,  7.9507,  7.9269,  8.9536,
         7.7259,  9.2478,  7.3973,  8.8610,  8.8618,  7.7693,  9.3108,  7.8104,
         8.6636,  9.0483,  9.5713,  8.5912,  8.5695,  8.9725,  7.9784,  7.9550,
         8.2758,  8.2214,  7.4071,  8.2919,  7.1655,  8.6621,  7.9327,  9.7438,
         9.2245,  9.2718,  7.1527,  9.1338,  9.3417,  7.8857,  7.8490,  7.8729,
         7.4691,  8.1648,  9.0727,  8.1953,  8.3313,  8.1932,  8.6896,  7.5120,
         7.6955,  8.6970,  8.4486,  8.4324,  9.1421,  9.2351,  7.9386,  9.6027,
         7.0667,  8.3023,  8.5285,  8.1558,  7.5736,  8.6171,  9.0789,  8.6213,
         6.8395,  6.8569,  9.3115,  8.2080,  9.0286,  8.3749,  7.7418,  8.5681,
         7.8323,  6.5963,  8.1163,  8.4784,  7.7956,  7.3548,  8.2077,  7.2818,
         8.3593,  8.0570,  8.9142,  9.2577,  8.4241,  8.3338, 10.2262,  8.1552,
         7.7836,  8.2849,  7.3795,  7.49

In [28]:
self.pred_power[network][other_network] = torch.norm(UtY, dim=0)

In [41]:
# full
# Set `self.pred_power`
for network, other_network in tqdm(p(self.representations_d, self.representations_d), desc='correlate', total=len(self.representations_d)**2):
    if network == other_network:
        continue

    X = self.representations_d[other_network]
    Y = self.representations_d[network]

    U, S, V = torch.svd(X)
    UtY = torch.mm(U.t(), Y) # x for Ux = Y

    self.pred_power[network][other_network] = torch.norm(UtY, dim=0) # cols of U are orthogonal

correlate: 100%|████████████████████████| 9/9 [00:00<00:00, 342.91it/s]


In [19]:
# old
# full
# Set `self.errors`
# {network: {other: error_tensor}}
for network, other_network in tqdm(p(self.representations_d, self.representations_d), desc='correlate', total=len(self.representations_d)**2):
    if network == other_network:
        continue

    # Try to predict this network given the other one
    X = self.representations_d[other_network].cpu().numpy()
    Y = self.representations_d[network].cpu().numpy()

    # solve with ordinary least squares 
    error = np.linalg.lstsq(X, Y, rcond=None)[1] # TO DO: don't use numpy, or at least use CUDA
    # Possibilities are use torch (torch.svd or smth), or use another library (cupy)
    error = torch.from_numpy(error)

    self.errors[network][other_network] = error

correlate: 100%|████████████████████████| 9/9 [00:00<00:00, 218.85it/s]


### Set `self.neuron_sort`

In [14]:
self.neuron_sort = {}

In [39]:
# loop variable
network = f1

In [43]:
sorted(
    range(self.num_neurons_d[network]),
    key = lambda i: min(
        self.pred_power[network][other][i] 
        for other in self.errors[network]
    ),
    reverse=True
)

[15,
 36,
 17,
 29,
 54,
 52,
 20,
 8,
 51,
 27,
 68,
 22,
 3,
 43,
 35,
 19,
 40,
 50,
 30,
 23,
 53,
 0,
 42,
 48,
 1,
 69,
 7,
 21,
 49,
 24,
 4,
 14,
 26,
 64,
 39,
 9,
 44,
 60,
 31,
 16,
 13,
 41,
 6,
 66,
 37,
 67,
 33,
 2,
 28,
 62,
 45,
 63,
 25,
 55,
 56,
 12,
 59,
 38,
 65,
 10,
 58,
 57,
 61,
 11,
 32,
 5,
 34,
 18,
 46,
 47]

In [42]:
# old. They are equal
sorted(
    range(self.num_neurons_d[network]),
    key = lambda i: max(
        self.errors[network][other][i] 
        for other in self.errors[network]
    )
)

[15,
 36,
 17,
 29,
 54,
 52,
 20,
 8,
 51,
 27,
 68,
 22,
 3,
 43,
 35,
 19,
 40,
 50,
 30,
 23,
 53,
 0,
 42,
 48,
 1,
 69,
 7,
 21,
 49,
 24,
 4,
 14,
 26,
 64,
 39,
 9,
 44,
 60,
 31,
 16,
 13,
 41,
 6,
 66,
 37,
 67,
 33,
 2,
 28,
 62,
 45,
 63,
 25,
 55,
 56,
 12,
 59,
 38,
 65,
 10,
 58,
 57,
 61,
 11,
 32,
 5,
 34,
 18,
 46,
 47]

In [44]:
# Set `self.neuron_sort`
# {network: sorted_list}
self.neuron_sort = {}
# Sort neurons by worst correlation (highest regression error) with another network
for network in tqdm(self.representations_d, desc='annotation'):
    self.neuron_sort[network] = sorted(
            range(self.num_neurons_d[network]),
            key = lambda i: min(
                self.pred_power[network][other][i] 
                for other in self.errors[network]),
            reverse=True
        )

annotation: 100%|████████████████████████| 3/3 [00:00<00:00, 95.03it/s]


In [46]:
# Nice, seems to be correct
np.minimum(self.pred_power[f1][f2][self.neuron_sort[f1]],
           self.pred_power[f1][f3][self.neuron_sort[f1]])

tensor([9.6027, 9.3417, 9.2718, 9.2577, 9.2245, 9.1421, 9.0789, 9.0727, 9.0699,
        9.0425, 9.0286, 9.0131, 8.9774, 8.9725, 8.9689, 8.9536, 8.9498, 8.8968,
        8.8610, 8.7355, 8.7266, 8.6970, 8.6896, 8.6802, 8.6636, 8.6621, 8.6472,
        8.6213, 8.6171, 8.5912, 8.5695, 8.5681, 8.4641, 8.4486, 8.4241, 8.4035,
        8.3749, 8.3593, 8.3441, 8.3338, 8.3313, 8.3023, 8.2849, 8.2758, 8.2303,
        8.2080, 8.2077, 8.1953, 8.1932, 8.1648, 8.1558, 8.1552, 8.1163, 8.0594,
        8.0455, 8.0426, 8.0406, 7.9801, 7.9784, 7.9550, 7.9507, 7.9386, 7.9269,
        7.8857, 7.8848, 7.8490, 7.8323, 7.8104, 7.8018, 7.7956, 7.7896, 7.7836,
        7.7806, 7.7693, 7.7418, 7.7259, 7.6955, 7.6728, 7.6017, 7.5736, 7.5120,
        7.4971, 7.4691, 7.4388, 7.4071, 7.4009, 7.3973, 7.3795, 7.3548, 7.3247,
        7.2818, 7.2402, 7.2235, 7.1655, 7.1527, 7.0667, 6.8569, 6.8395, 6.5963,
        6.5023])

# Create final function

In [5]:
def compute_correlations(self):
    # Set `means_d`, `stdevs_d`, normalize to mean 0 std 1
    means_d = {}
    stdevs_d = {}

    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network]
        means = t.mean(0, keepdim=True)
        stdevs = (t - means).pow(2).mean(0, keepdim=True).pow(0.5)

        means_d[network] = means
        stdevs_d[network] = stdevs
        self.representations_d[network] = (t - means) / stdevs

    # Set `self.pred_power`
    for network, other_network in tqdm(p(self.representations_d, self.representations_d), desc='correlate', total=len(self.representations_d)**2):
        if network == other_network:
            continue

        X = self.representations_d[other_network]
        Y = self.representations_d[network]

        U, S, V = torch.svd(X)
        UtY = torch.mm(U.t(), Y) # x for Ux = Y

        self.pred_power[network][other_network] = torch.norm(UtY, dim=0) / torch.norm(Y, dim=0) # cols of U are orthogonal

    # Set `self.neuron_sort`
    # {network: sorted_list}
    self.neuron_sort = {}
    # Sort neurons by worst correlation (highest regression error) with another network
    for network in tqdm(self.representations_d, desc='annotation'):
        self.neuron_sort[network] = sorted(
                range(self.num_neurons_d[network]),
                key = lambda i: min(
                    self.pred_power[network][other][i] 
                    for other in self.errors[network]),
                reverse=True
            )

mu, sigma: 100%|█| 3/3 [00:00<00:00, 503.20it/s]
