In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Generate fake activations

In [12]:
class A:
    """Empty attribute bag standing in for `self` while the method below is prototyped."""


self = A()

In [13]:
# Per-network bookkeeping, keyed by representation file name.
self.num_neurons_d = {}      # {fname: int}
self.representations_d = {}  # {fname: tensor}
f1, f2, f3 = representation_files = ["foo", "bar", "baz"]

In [14]:
# Initialize `num_neurons_d`, `representations_d` with fake data:
# one (nword, n_neurons) standard-normal matrix per network.
torch.manual_seed(0)  # make the fake data, and every output derived from it, reproducible
n1, n2, n3 = 100, 80, 70
nword = 1000
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
for fname, n_neurons, t in zip((f1, f2, f3), (n1, n2, n3), (t1, t2, t3)):
    self.num_neurons_d[fname] = n_neurons
    self.representations_d[fname] = t

In [15]:
# Aggregate a neuron's per-network predictive powers with `min` (the sort below
# is descending), so a neuron ranks high only if every other network predicts it well.
self.op = min
self.device = torch.device('cpu')  # device used for the SVD / projection math

# Build Function

### Normalize

In [6]:
# Set `means_d`, `stdevs_d`.
# Set `self.nrepresentations_d` to the column-wise z-scored representations.
means_d = {}
stdevs_d = {}
self.nrepresentations_d = {}

for network in tqdm(self.representations_d, desc='mu, sigma'):
    t = self.representations_d[network]
    means = t.mean(0, keepdim=True)
    centered = t - means
    # Population std (divides by n, not n - 1) of each column.
    stdevs = centered.pow(2).mean(0, keepdim=True).pow(0.5)

    means_d[network] = means
    stdevs_d[network] = stdevs
    # A constant column has std 0; dividing by it would give NaN and poison the
    # SVD downstream, so such columns are left at 0 after centering instead.
    self.nrepresentations_d[network] = centered / stdevs.masked_fill(stdevs == 0, 1.0)

mu, sigma: 100%|██████████████████████████████| 3/3 [00:00<00:00, 626.30it/s]


### Compute `self.pred_power` for every ordered pair of networks

In [7]:
# Predictive power for every ordered pair (network, other_network), network != other_network.
# pred_power[network][other_network][i] = ||projection of y_i onto col(X)|| / ||y_i||,
# where y_i is neuron i of `network` and X holds all neurons of `other_network`.
# The data is centered, so this ratio is the multiple correlation R.
self.pred_power = {network: {} for network in self.representations_d}

# The column-space basis depends only on the predictor network, so take one
# SVD per network instead of one per ordered pair.
bases = {}
for network in self.nrepresentations_d:
    X = self.nrepresentations_d[network].to(self.device)
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)  # thin SVD; torch.svd is deprecated
    # Keep only directions with a non-negligible singular value: the others lie
    # outside col(X) and would inflate bnorms whenever X is rank-deficient.
    tol = S.max() * max(X.shape) * torch.finfo(S.dtype).eps if S.numel() else 0.0
    bases[network] = U[:, S > tol]

for network, other_network in tqdm(p(self.representations_d,
                                     self.representations_d),
                                   desc='correlate',
                                   total=len(self.representations_d)**2):

    if network == other_network:
        continue

    U = bases[other_network]
    Y = self.nrepresentations_d[network].to(self.device)

    UtY = torch.mm(U.t(), Y)  # b for U b ≈ Y (least squares, U orthonormal)

    bnorms = torch.norm(UtY, dim=0)
    ynorms = torch.norm(Y, dim=0)

    # A constant neuron has ynorms == 0; report 0 instead of NaN so sorting stays well-defined.
    self.pred_power[network][other_network] = torch.where(
        ynorms > 0, bnorms / ynorms, torch.zeros_like(ynorms)).cpu()

correlate: 100%|██████████████████████████████| 9/9 [00:00<00:00, 334.77it/s]


### Set `self.neuron_sort`, `self.neuron_notated_sort`

In [8]:
# `self.neuron_sort`:         {network: neuron indices, highest score first}
# `self.neuron_notated_sort`: {network: [(neuron, {other_network: pred_power})]}
# A neuron's score is `self.op` applied to its pred_power against every other network.
self.neuron_sort = {}
self.neuron_notated_sort = {}
for network in tqdm(self.nrepresentations_d, desc='annotation'):
    powers_by_other = self.pred_power[network]

    def neuron_score(i):
        return self.op(powers_by_other[other][i] for other in powers_by_other)

    ranked = sorted(range(self.num_neurons_d[network]), key=neuron_score, reverse=True)
    self.neuron_sort[network] = ranked
    self.neuron_notated_sort[network] = [
        (neuron, {other: float(powers[neuron]) for other, powers in powers_by_other.items()})
        for neuron in ranked
    ]

annotation: 100%|█████████████████████████████| 3/3 [00:00<00:00, 181.07it/s]


In [9]:
# Sanity check (assumes `self.op = min`, as configured above): f1's neurons must be
# ordered by non-increasing min(pred_power vs f2, pred_power vs f3).
order = self.neuron_sort[f1]
min_power = torch.minimum(self.pred_power[f1][f2][order],
                          self.pred_power[f1][f3][order])
assert (min_power[:-1] >= min_power[1:]).all(), "neuron_sort[f1] is not sorted by min pred_power"
min_power

tensor([0.2938, 0.2898, 0.2892, 0.2847, 0.2843, 0.2836, 0.2827, 0.2808, 0.2797,
        0.2796, 0.2787, 0.2784, 0.2766, 0.2765, 0.2749, 0.2746, 0.2721, 0.2717,
        0.2716, 0.2713, 0.2703, 0.2701, 0.2691, 0.2689, 0.2684, 0.2684, 0.2682,
        0.2679, 0.2672, 0.2668, 0.2668, 0.2660, 0.2658, 0.2655, 0.2643, 0.2633,
        0.2632, 0.2629, 0.2622, 0.2615, 0.2612, 0.2611, 0.2611, 0.2609, 0.2602,
        0.2597, 0.2593, 0.2589, 0.2579, 0.2571, 0.2571, 0.2563, 0.2558, 0.2558,
        0.2556, 0.2549, 0.2548, 0.2541, 0.2533, 0.2532, 0.2532, 0.2527, 0.2525,
        0.2524, 0.2521, 0.2521, 0.2519, 0.2513, 0.2510, 0.2500, 0.2499, 0.2499,
        0.2492, 0.2488, 0.2488, 0.2482, 0.2475, 0.2461, 0.2459, 0.2457, 0.2447,
        0.2442, 0.2438, 0.2432, 0.2426, 0.2396, 0.2390, 0.2383, 0.2376, 0.2355,
        0.2343, 0.2342, 0.2340, 0.2322, 0.2261, 0.2236, 0.2196, 0.2178, 0.2140,
        0.2080])

# Create final function

In [16]:
def _zscore_columns(t):
    """Return `t` with each column centered and scaled to unit population std.

    Columns with zero variance are left at 0 after centering rather than being
    divided by 0, which would produce NaN and poison the SVD downstream.
    """
    centered = t - t.mean(0, keepdim=True)
    stdevs = centered.pow(2).mean(0, keepdim=True).pow(0.5)
    return centered / stdevs.masked_fill(stdevs == 0, 1.0)


def _column_space_basis(X):
    """Return an orthonormal basis (as columns) of the column space of 2-D `X`.

    Left singular vectors whose singular value is numerically zero are dropped:
    they lie outside col(X), and keeping them (as the plain thin SVD does)
    inflates projection norms whenever `X` is rank-deficient.
    """
    U, S, _ = torch.linalg.svd(X, full_matrices=False)
    if S.numel() == 0:
        return U
    tol = S.max() * max(X.shape) * torch.finfo(S.dtype).eps
    return U[:, S > tol]


def compute_correlations(self):
    """
    Rank each network's neurons by how well the other networks linearly predict them.

    Reads `self.representations_d` ({network: 2-D tensor, one column per neuron}),
    `self.num_neurons_d`, `self.device` and `self.op`, and sets:

    - `self.nrepresentations_d`: {network: column-wise z-scored representation}.
    - `self.pred_power`: {network: {other_network: 1-D tensor over the neurons of
      `network`}} -- the multiple correlation R of each neuron regressed on all
      neurons of `other_network`. Constant neurons get 0.
    - `self.neuron_sort`: {network: neuron indices sorted by `self.op` of their
      pred_power over the other networks, highest first}.
    - `self.neuron_notated_sort`: {network: [(neuron, {other_network: pred_power})]}
      in `self.neuron_sort` order.

    With a single network there is nothing to compare against: its pred_power
    dict is empty and its neurons are kept in index order.
    """
    # Normalize every representation column-wise.
    self.nrepresentations_d = {}
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        self.nrepresentations_d[network] = _zscore_columns(self.representations_d[network])

    # Set `self.pred_power`.
    # The data is centered, so ||projection of y onto col(X)|| / ||y|| is the
    # multiple correlation R. The basis depends only on the predictor network,
    # so compute one SVD per network rather than one per ordered pair.
    on_device = {network: rep.to(self.device) for network, rep in self.nrepresentations_d.items()}
    bases = {network: _column_space_basis(X) for network, X in on_device.items()}

    self.pred_power = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                       desc='correlate',
                                       total=len(self.representations_d)**2):
        if network == other_network:
            continue

        U = bases[other_network]
        Y = on_device[network]
        bnorms = torch.norm(torch.mm(U.t(), Y), dim=0)  # norms of b in U b ≈ Y
        ynorms = torch.norm(Y, dim=0)
        # A constant neuron has ynorm 0: report 0 instead of NaN, which would
        # make the sort below ill-defined.
        power = torch.where(ynorms > 0, bnorms / ynorms, torch.zeros_like(ynorms))
        self.pred_power[network][other_network] = power.cpu()

    # Set `self.neuron_sort` : {network: sorted_list}
    # Set `self.neuron_notated_sort` : {network: [(neuron, {other_network: pred_power})]}
    self.neuron_sort = {}
    self.neuron_notated_sort = {}
    for network in tqdm(self.nrepresentations_d, desc='annotation'):
        by_other = self.pred_power[network]
        n_neurons = self.num_neurons_d[network]
        if by_other:
            order = sorted(
                range(n_neurons),
                key=lambda i: self.op(by_other[other][i] for other in by_other),
                reverse=True,
            )
        else:
            # No other network to aggregate over (`min` of an empty sequence
            # would raise), so keep the neurons in index order.
            order = list(range(n_neurons))

        self.neuron_sort[network] = order
        self.neuron_notated_sort[network] = [
            (neuron, {other: float(powers[neuron]) for other, powers in by_other.items()})
            for neuron in order
        ]

In [17]:
# Run the packaged function end to end on the fake data set up above.
compute_correlations(self)

mu, sigma: 100%|█████████████████████████████| 3/3 [00:00<00:00, 1863.86it/s]
correlate: 100%|██████████████████████████████| 9/9 [00:00<00:00, 350.53it/s]
annotation: 100%|█████████████████████████████| 3/3 [00:00<00:00, 147.39it/s]
