In [3]:
# Load the line_profiler IPython extension; it provides the `%lprun` magic used further down.
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [50]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np

# Load fake activations

In [77]:
class A:
    """Empty class used only as a bag of attributes."""


# Bound to the name `self` so that the cells below can use `self.<attr>`
# exactly as the body of `compute_correlations(self)` does.
self = A()

In [78]:
# Three placeholder file names; they are used as the keys of both dicts below.
representation_files = ["foo", "bar", "baz"]
f1, f2, f3 = representation_files

self.representations_d = {}  # {fname: activation tensor}
self.num_neurons_d = {}  # {fname: number of neurons}

In [79]:
# initialize `num_neurons_d`, `representations_d` with fake data
# Seed the generator first so the random activations, and every number
# computed from them, are identical on each Restart & Run All.
torch.manual_seed(0)
n1, n2, n3 = 100, 80, 70  # number of neurons (columns) per fake network
nword = 1000  # number of rows shared by all activation tensors
t1 = torch.randn(nword, n1)
t2 = torch.randn(nword, n2)
t3 = torch.randn(nword, n3)
self.num_neurons_d[f1] = n1
self.num_neurons_d[f2] = n2
self.num_neurons_d[f3] = n3
self.representations_d[f1] = t1
self.representations_d[f2] = t2
self.representations_d[f3] = t3

In [80]:
# Run every tensor operation on the CPU.
self.device = torch.device('cpu')
# Reduction applied to a neuron's scores against all other networks when
# ranking it; with `min`, a neuron is ranked by its weakest score.
self.op = min

# Build Function

### Normalize

In [68]:
# Set `means_d`, `stdevs_d`
# Set `self.nrepresentations_d` to be normalized.
# Set `self.lsingularv_d` to the left singular vectors of the normalized data.
means_d = {}
stdevs_d = {}
self.nrepresentations_d = {}
self.lsingularv_d = {}

for network in tqdm(self.representations_d, desc='mu, sigma'):
    t = self.representations_d[network].to(self.device)
    # Per-column mean and (population) standard deviation, kept 2-D for broadcasting.
    means = t.mean(0, keepdim=True)
    stdevs = (t - means).pow(2).mean(0, keepdim=True).pow(0.5)

    means_d[network] = means.cpu()
    stdevs_d[network] = stdevs.cpu()
    # A column that never varies has stdev 0, and dividing by it gives
    # 0/0 = NaN, which would then poison the SVD of the whole network.
    # Divide such columns by 1 instead so they normalize to all zeros.
    self.nrepresentations_d[network] = (
        (t - means) / stdevs.masked_fill(stdevs == 0, 1.0)
    ).cpu()
    self.lsingularv_d[network], _, _ = torch.svd(self.nrepresentations_d[network])

# Drop the raw activations; this cell therefore cannot be re-run until
# `self.representations_d` has been rebuilt.
del self.representations_d

mu, sigma: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 212.81it/s]


### Set `self.pred_power` loop

In [69]:
# Set `self.pred_power`
# If the data is centered, it is the r value.
# Set `self.similarities`
self.pred_power = {network: {} for network in self.nrepresentations_d}
self.similarities = {network: {} for network in self.nrepresentations_d}
for network, other_network in tqdm(p(self.nrepresentations_d,
                                     self.nrepresentations_d),
                                   desc='correlate',
                                   total=len(self.nrepresentations_d)**2):

    # A network is never scored against itself.
    if network == other_network:
        continue

    U = self.lsingularv_d[other_network].to(self.device)
    Y = self.nrepresentations_d[network].to(self.device)

    # SVD method of linreg
    UtY = torch.mm(U.t(), Y) # b for Ub = Y

    # Column-wise norms: of the fitted coefficients and of the targets.
    bnorms = torch.norm(UtY, dim=0)
    ynorms = torch.norm(Y, dim=0)

    # An all-zero column of Y has norm 0, which makes the ratio 0/0 = NaN.
    # Report a score of 0 for such a column instead of letting NaN leak
    # into the mean and into the later sort.
    self.pred_power[network][other_network] = (
        (bnorms / ynorms).masked_fill(ynorms == 0, 0.0).cpu().numpy()
    )
    self.similarities[network][other_network] = self.pred_power[network][other_network].mean()

correlate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 1218.25it/s]


In [70]:
# Display the mean predictive power for every ordered (network, other_network) pair.
self.similarities

{'foo': {'bar': 0.2804998, 'baz': 0.26484612},
 'bar': {'foo': 0.3134351, 'baz': 0.26393247},
 'baz': {'foo': 0.31717753, 'bar': 0.2807093}}

### Set `self.neuron_sort`, `self.neuron_notated_sort`

In [71]:
# Rank the neurons of every network by their aggregated predictive power.
#   `self.neuron_sort`         : {network: [neuron, ...]}, highest score first
#   `self.neuron_notated_sort` : {network: [(neuron, {other_network: pred_power}), ...]}
self.neuron_sort = {}
self.neuron_notated_sort = {}
for network in tqdm(self.nrepresentations_d, desc='annotation'):
    # `self.op` collapses a neuron's scores against all other networks into
    # the single value used as the descending sort key.
    self.neuron_sort[network] = sorted(
        range(self.num_neurons_d[network]),
        key=lambda i: self.op(
            scores[i] for scores in self.pred_power[network].values()
        ),
        reverse=True,
    )

    # Same order as above, with each neuron's per-network scores attached
    # as plain Python floats.
    self.neuron_notated_sort[network] = [
        (neuron, {other: float(scores[neuron])
                  for other, scores in self.pred_power[network].items()})
        for neuron in self.neuron_sort[network]
    ]

annotation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3387.97it/s]


# Create final function

In [82]:
# Line-by-line profile of `compute_correlations`. The function is defined in
# the cell that follows this one, so that cell has to be run first.
%lprun -f compute_correlations compute_correlations(self)

mu, sigma: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 258.52it/s]
correlate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 1097.00it/s]
annotation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 663.45it/s]


In [81]:
def compute_correlations(self):
    """
    Score every neuron by how well each other network linearly predicts it,
    then rank the neurons of every network by that score.

    Reads
    -----
    self.representations_d : {network: 2-D tensor}; statistics are taken
        over dim 0, so rows are samples and columns are neurons.
    self.num_neurons_d : {network: number of neurons (columns)}.
    self.device : torch device on which the tensor math is run.
    self.op : reduction (e.g. `min`) applied to one neuron's scores against
        all other networks to obtain its sort key.

    Sets
    ----
    self.nrepresentations_d : {network: column-wise standardized tensor (CPU)}.
    self.lsingularv_d : {network: left singular vectors of that tensor}.
    self.pred_power : {network: {other_network: numpy array, one score per
        neuron of `network`}}.
    self.similarities : {network: {other_network: mean of that array}}.
    self.neuron_sort : {network: neuron indices, highest score first}.
    self.neuron_notated_sort : {network: [(neuron, {other_network: score})]}.

    Side effect: `self.representations_d` is deleted once the data has been
    normalized, so the function cannot be called a second time on the same
    object without restoring that attribute.
    """

    # Standardize every network's activations and take their SVD.
    self.nrepresentations_d = {}
    self.lsingularv_d = {}

    for network in tqdm(self.representations_d, desc='mu, sigma'):
        t = self.representations_d[network].to(self.device)
        means = t.mean(0, keepdim=True)
        stdevs = (t - means).pow(2).mean(0, keepdim=True).pow(0.5)

        # A neuron whose activation never varies has stdev 0; dividing by it
        # gives 0/0 = NaN, which would poison the SVD of the whole network.
        # Divide such columns by 1 instead so they normalize to all zeros.
        # NOTE(review): an all-zero column makes the matrix rank-deficient,
        # and `torch.svd` still returns a singular vector for it.
        safe_stdevs = stdevs.masked_fill(stdevs == 0, 1.0)

        normalized = ((t - means) / safe_stdevs).cpu()
        self.nrepresentations_d[network] = normalized
        self.lsingularv_d[network], _, _ = torch.svd(normalized)

    del self.representations_d

    # Set `self.pred_power` and `self.similarities`.
    # With centered data the score is the r value of the least-squares fit
    # of one neuron on all neurons of the other network.
    self.pred_power = {network: {} for network in self.nrepresentations_d}
    self.similarities = {network: {} for network in self.nrepresentations_d}
    for network, other_network in tqdm(p(self.nrepresentations_d,
                                         self.nrepresentations_d),
                                       desc='correlate',
                                       total=len(self.nrepresentations_d)**2):

        # A network is never scored against itself.
        if network == other_network:
            continue

        U = self.lsingularv_d[other_network].to(self.device)
        Y = self.nrepresentations_d[network].to(self.device)

        # SVD method of linreg
        UtY = torch.mm(U.t(), Y) # b for Ub = Y

        bnorms = torch.norm(UtY, dim=0)
        ynorms = torch.norm(Y, dim=0)

        # An all-zero column of Y (a constant neuron) has norm 0, which makes
        # the ratio 0/0 = NaN. Report a score of 0 for it, so that it sorts
        # last and does not turn the network-level mean into NaN.
        scores = (bnorms / ynorms).masked_fill(ynorms == 0, 0.0).cpu().numpy()

        self.pred_power[network][other_network] = scores
        self.similarities[network][other_network] = scores.mean()

    # Set `self.neuron_sort` : {network: sorted_list}
    # Set `self.neuron_notated_sort` : {network: [(neuron, {other_network: pred_power})]}
    self.neuron_sort = {}
    self.neuron_notated_sort = {}
    # Sort neurons by correlation with another network
    for network in tqdm(self.nrepresentations_d, desc='annotation'):
        self.neuron_sort[network] = sorted(
                range(self.num_neurons_d[network]),
                key = lambda i: self.op(
                    self.pred_power[network][other][i]
                    for other in self.pred_power[network]),
                reverse=True
            )

        # Same order, with the per-network scores attached as plain floats.
        self.neuron_notated_sort[network] = [
            (
                neuron,
                {
                    other: float(self.pred_power[network][other][neuron])
                    for other in self.pred_power[network]
                }
            )
            for neuron in self.neuron_sort[network]
        ]

In [55]:
# `compute_correlations` deletes `self.representations_d`, so the fake-data
# setup cells must be re-run before calling it again; otherwise the call
# fails with an AttributeError.
compute_correlations(self)

mu, sigma: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 270.72it/s]


AttributeError: 'A' object has no attribute 'representations_d'