In [1]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [2]:
class A:
    """Empty attribute holder standing in for the real object while prototyping."""


# The functions below take `self` explicitly, so a bare instance is enough.
self = A()

In [3]:
# Names standing in for the representation files of three networks.
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

# Per-network containers, filled in by the next cell.
self.num_neurons_d = {}  # {fname: number of neurons (int)}
self.representations_d = {}  # {fname: activation tensor}

In [4]:
# Initialize `num_neurons_d` and `representations_d` with fake data.
# Seed the global RNG first so that a fresh "Restart & Run All" reproduces
# the same activations (and therefore the same correlations further down).
torch.manual_seed(0)

nword = 1000  # number of rows shared by every fake network
n1, n2, n3 = 100, 80, 70  # number of neurons (columns) per network

# One (nword, n_neurons) tensor of standard-normal noise per network.
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)

self.num_neurons_d.update({f1: n1, f2: n2, f3: n3})
self.representations_d.update({f1: t1, f2: t2, f3: t3})

In [5]:
# Settings read by `compute_correlations`.
self.op = min  # reduces one neuron's scores across the other networks
self.device = torch.device('cpu')  # device the SVD / matmul tensors are moved to

# Compute correlations

In [6]:
def _standardize(t):
    """
    Return `t` with every column shifted to zero mean and scaled to unit
    population standard deviation (both computed along dim 0).

    A column whose standard deviation is exactly zero (a constant neuron)
    comes back as all zeros instead of NaN (0 / 0).
    """
    centered = t - t.mean(0, keepdim=True)
    stdevs = centered.pow(2).mean(0, keepdim=True).pow(0.5)
    # Dividing by 1 leaves the already-zero centered column untouched.
    safe_stdevs = torch.where(stdevs > 0, stdevs, torch.ones_like(stdevs))
    return centered / safe_stdevs


def _column_space_basis(X):
    """
    Return a matrix whose orthonormal columns span the column space of `X`.

    Left singular vectors whose singular value is numerically zero are
    dropped: they are arbitrary directions outside the span of `X` and would
    otherwise inflate the predicted power.
    """
    U, S, _ = torch.svd(X)
    if S.numel() == 0:
        return U
    # Usual numerical-rank cutoff: largest singular value * max(dim) * eps.
    tol = S.max() * max(X.shape) * torch.finfo(S.dtype).eps
    return U[:, S > tol]


def compute_correlations(self):
    """
    Score how well each neuron of each network is linearly predicted by
    every other network, then rank the neurons by that score.

    Reads
    -----
    self.representations_d : {network: 2-D tensor}; dim 0 is reduced over and
        must have the same length for every network, dim 1 indexes neurons.
    self.num_neurons_d : {network: int}, number of neurons to rank.
    self.device : device the tensors are moved to for the SVD / matmul.
    self.op : callable reducing an iterable of one neuron's scores (one per
        other network) to a single sortable value, e.g. `min`.

    Sets
    ----
    self.nrepresentations_d : {network: standardized tensor}
    self.pred_power : {network: {other_network: 1-D CPU tensor}}. Entry i is
        ||projection of neuron i onto span(other_network)|| / ||neuron i||.
        Because the data is centered this is the multiple correlation
        coefficient (|r| when `other_network` has a single neuron). A constant
        neuron gets a score of 0.
    self.neuron_sort : {network: neuron indices, highest `self.op` score first}
    self.neuron_notated_sort :
        {network: [(neuron, {other_network: pred_power as float})]}, in the
        same order as `self.neuron_sort`.

    With a single network there is nothing to compare against; the neurons
    then keep their natural order and carry empty annotations.
    """
    # Set `self.nrepresentations_d` to the standardized activations.
    self.nrepresentations_d = {}
    for network in tqdm(self.representations_d, desc='mu, sigma'):
        self.nrepresentations_d[network] = _standardize(
            self.representations_d[network])

    # Set `self.pred_power`.
    # The predictor (`other_network`) is the OUTER factor of the product so
    # that its SVD is computed once and reused for every target network.
    self.pred_power = {network: {} for network in self.representations_d}
    unset = object()
    basis_owner, basis = unset, None
    for other_network, network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                       desc='correlate',
                                       total=len(self.representations_d)**2):

        if network == other_network:
            continue

        if basis_owner is unset or basis_owner != other_network:
            X = self.nrepresentations_d[other_network].to(self.device)
            basis = _column_space_basis(X)
            basis_owner = other_network

        Y = self.nrepresentations_d[network].to(self.device)

        # SVD method of linreg: coordinates of Y in the orthonormal basis.
        UtY = torch.mm(basis.t(), Y)

        bnorms = torch.norm(UtY, dim=0)
        ynorms = torch.norm(Y, dim=0)

        # A zero-norm (constant) target neuron has nothing to predict: use 0
        # rather than the NaN that 0 / 0 would give.
        power = torch.where(ynorms > 0, bnorms / ynorms,
                            torch.zeros_like(ynorms))
        self.pred_power[network][other_network] = power.cpu()

    # Set `self.neuron_sort` : {network: sorted_list}
    # Set `self.neuron_notated_sort` : {network: [(neuron, {other_network: pred_power})]}
    self.neuron_sort = {}
    self.neuron_notated_sort = {}
    for network in tqdm(self.nrepresentations_d, desc='annotation'):
        scores = self.pred_power[network]
        neurons = range(self.num_neurons_d[network])

        if scores:
            order = sorted(
                neurons,
                key=lambda i: self.op(scores[other][i] for other in scores),
                reverse=True,
            )
        else:
            # No other network: `self.op` would be called on an empty
            # iterable (e.g. `min` raises), so keep the natural order.
            order = list(neurons)
        self.neuron_sort[network] = order

        self.neuron_notated_sort[network] = [
            (neuron, {other: float(scores[other][neuron]) for other in scores})
            for neuron in order
        ]

In [7]:
# Run on the fake activations; this sets `self.pred_power`, `self.neuron_sort`
# and `self.neuron_notated_sort`, which the cells below save to disk.
compute_correlations(self)

mu, sigma: 100%|██████████████████████████████| 3/3 [00:00<00:00, 438.37it/s]
correlate: 100%|██████████████████████████████| 9/9 [00:00<00:00, 338.20it/s]
annotation: 100%|█████████████████████████████| 3/3 [00:00<00:00, 130.57it/s]


# Build `write_correlations`

In [8]:
# Parameter: path of the file the correlation results are saved to.
output_file = "temp"

In [9]:
# Prototype of the save step: bundle the three result structures into one
# dict and serialize it to `output_file`.
output = dict(
    pred_power=self.pred_power,
    neuron_sort=self.neuron_sort,
    neuron_notated_sort=self.neuron_notated_sort,
)
torch.save(output, output_file)

# Final Function

In [10]:
def write_correlations(self, output_file):
    """
    Save the correlation results stored on `self` to `output_file`.

    The file holds a single dict, written with `torch.save`, whose keys are
    "pred_power", "neuron_sort" and "neuron_notated_sort" (in that order) and
    whose values are the attributes of `self` with the same names.
    """
    saved_fields = ("pred_power", "neuron_sort", "neuron_notated_sort")
    payload = {field: getattr(self, field) for field in saved_fields}
    torch.save(payload, output_file)

In [11]:
# Exercise the final function: save the results held on `self` to `output_file`.
output_file = "temp"
write_correlations(self, output_file)

In [12]:
# Round-trip check: reload the file that was just written and inspect the
# predictive-power scores of the first network.
# Read from `output_file` instead of repeating the literal path, so the write
# and the read cannot drift apart if the parameter is changed.
d = torch.load(output_file)
d['pred_power'][f1]

{'bar': tensor([0.2906, 0.2562, 0.2893, 0.2935, 0.2618, 0.2636, 0.2666, 0.2958, 0.2787,
         0.2671, 0.2544, 0.2575, 0.2863, 0.2532, 0.2858, 0.2770, 0.3161, 0.2492,
         0.2764, 0.2803, 0.3063, 0.2913, 0.2903, 0.2743, 0.2809, 0.2906, 0.2877,
         0.2635, 0.3639, 0.2494, 0.2934, 0.2965, 0.2484, 0.2793, 0.3088, 0.2750,
         0.2506, 0.2390, 0.2848, 0.2891, 0.2960, 0.3027, 0.2939, 0.3122, 0.2834,
         0.3005, 0.2919, 0.2968, 0.3116, 0.2886, 0.3077, 0.2433, 0.2395, 0.2909,
         0.2594, 0.2953, 0.2982, 0.3060, 0.2606, 0.2619, 0.2695, 0.2809, 0.2750,
         0.2828, 0.2467, 0.2468, 0.2788, 0.2751, 0.2897, 0.2813, 0.2809, 0.2990,
         0.2976, 0.2952, 0.2980, 0.2878, 0.2654, 0.2377, 0.2822, 0.2633, 0.2936,
         0.2611, 0.2867, 0.2833, 0.2490, 0.2594, 0.2843, 0.3101, 0.3282, 0.2949,
         0.2791, 0.2994, 0.2723, 0.2845, 0.3110, 0.2752, 0.2855, 0.2707, 0.2634,
         0.2539]),
 'baz': tensor([0.2742, 0.2408, 0.2626, 0.2831, 0.2630, 0.2438, 0.2583, 0.2686, 0.2