In [1]:
import torch 
from tqdm import tqdm
from itertools import product as p
import json
import numpy as np
import h5py
from os.path import basename, dirname
#import dask.array as da
import pickle
from var import fname2mname

# setup

In [57]:
def pvec(t):
    return t/t.sum(dim=-1, keepdim=True)

In [58]:
class A:
    pass
self = A()

In [59]:
self.num_heads_d = {} # {fname, int}
self.attentions_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
attention_fname_l = [f1, f2, f3]

In [60]:
# initialize `num_heads_d`, `attentions_d` with fake data
n1, n2, n3 = 10, 12, 14
wlen_l = [12, 5, 9, 6]

self.num_heads_d[f1] = n1
self.num_heads_d[f2] = n2
self.num_heads_d[f3] = n3

for fname in attention_fname_l:
    attentions_l = [pvec(torch.randn(self.num_heads_d[fname], wlen, wlen))
                        for wlen in wlen_l]
    self.attentions_d[fname] = attentions_l

In [62]:
self.device = torch.device('cpu')
self.op = min

# hnb

In [63]:
device = self.device

self.corrs = {network: {} for network in self.attentions_d}
self.pairs = {network: {} for network in self.attentions_d}
self.similarities = {network: {} for network in
                     self.attentions_d}
num_sentences = len(next(iter(self.attentions_d.values())))

### `self.corrs`, `self.pairs`, `self.similarities` loop

In [64]:
# arbitrarily set loop variables
network = f1
other_network = f2

In [84]:
distances = np.zeros((num_sentences, self.num_heads_d[network], self.num_heads_d[other_network]))

##### Set `distances` loop

In [66]:
# arbitrary loop variables
idx = 0
attns = self.attentions_d[network][0]
o_attns = self.attentions_d[other_network][0]

In [67]:
t1 = attns.to(device)
t2 = o_attns.to(device)

In [68]:
t11, t12, t13 = t1.size()
t21, t22, t23 = t2.size()

In [69]:
t1 = t1.reshape(t11, 1, t12, t13)
t2 = t2.reshape(1, t21, t22, t23)

In [70]:
distance = torch.norm(t1-t2, p='fro', dim=(2,3))
distance = distance.cpu().numpy()

In [71]:
distances[idx] = distance

In [85]:
# full
for idx, (attns, o_attns) in enumerate(zip(self.attentions_d[network], self.attentions_d[other_network])):
    t1 = attns.to(device)
    t2 = o_attns.to(device)
    t11, t12, t13 = t1.size()
    t21, t22, t23 = t2.size()
    t1 = t1.reshape(t11, 1, t12, t13)
    t2 = t2.reshape(1, t21, t22, t23)

    distance = torch.norm(t1-t2, p='fro', dim=(2,3))
    distances[idx] = distance.cpu().numpy()

In [86]:
# Set `correlation`
distances = distances.mean(axis=0)
mi, ma = distances.min(), distances.max()
distances = (distances-mi)/(ma-mi)
correlation = 1 - distances

In [88]:
# Main update
self.corrs[network][other_network] = correlation.max(axis=1)
self.corrs[other_network][network] = correlation.max(axis=0)

self.similarities[network][other_network] = self.corrs[network][other_network].mean()
self.similarities[other_network][network] = self.corrs[other_network][network].mean()

self.pairs[network][other_network] = correlation.argmax(axis=1)
self.pairs[other_network][network] = correlation.argmax(axis=0)

In [89]:
# full loop
for network, other_network in tqdm(p(self.attentions_d,
                                     self.attentions_d),
                                     desc='correlate',
                                     total=len(self.attentions_d)**2):
    if network == other_network:
        continue

    if other_network in self.corrs[network]: 
        continue
        
    # Set `distances`
    distances = np.zeros((num_sentences, self.num_heads_d[network], self.num_heads_d[other_network]))
    for idx, (attns, o_attns) in enumerate(zip(self.attentions_d[network], self.attentions_d[other_network])):
        t1 = attns.to(device)
        t2 = o_attns.to(device)
        t11, t12, t13 = t1.size()
        t21, t22, t23 = t2.size()
        t1 = t1.reshape(t11, 1, t12, t13)
        t2 = t2.reshape(1, t21, t22, t23)

        distance = torch.norm(t1-t2, p='fro', dim=(2,3))
        distances[idx] = distance.cpu().numpy()
        
    # Set `correlation`
    distances = distances.mean(axis=0)
    mi, ma = distances.min(), distances.max()
    distances = (distances-mi)/(ma-mi)
    correlation = 1 - distances
    
    # Main update
    self.corrs[network][other_network] = correlation.max(axis=1)
    self.corrs[other_network][network] = correlation.max(axis=0)

    self.similarities[network][other_network] = self.corrs[network][other_network].mean()
    self.similarities[other_network][network] = self.corrs[other_network][network].mean()

    self.pairs[network][other_network] = correlation.argmax(axis=1)
    self.pairs[other_network][network] = correlation.argmax(axis=0)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 6202.55it/s]


# Full function

In [95]:
def compute_correlations(self):
    # convenient variables
    device = self.device
    num_sentences = len(next(iter(self.attentions_d.values())))

    # Set `self.corrs` : {network: {other: [corr]}}
    # Set `self.pairs` : {network: {other: [pair]}}
    # pair is index of head in other network
    # Set `self.similarities` : {network: {other: sim}}
    self.corrs = {network: {} for network in self.attentions_d}
    self.pairs = {network: {} for network in self.attentions_d}
    self.similarities = {network: {} for network in self.attentions_d}
    for network, other_network in tqdm(p(self.attentions_d,
                                         self.attentions_d),
                                         desc='correlate',
                                         total=len(self.attentions_d)**2):
        if network == other_network:
            continue

        if other_network in self.corrs[network]: 
            continue

        # Set `distances`
        distances = np.zeros((num_sentences, self.num_heads_d[network], self.num_heads_d[other_network]))
        for idx, (attns, o_attns) in enumerate(zip(self.attentions_d[network], self.attentions_d[other_network])):
            t1 = attns.to(device)
            t2 = o_attns.to(device)
            t11, t12, t13 = t1.size()
            t21, t22, t23 = t2.size()
            t1 = t1.reshape(t11, 1, t12, t13)
            t2 = t2.reshape(1, t21, t22, t23)

            distance = torch.norm(t1-t2, p='fro', dim=(2,3))
            distances[idx] = distance.cpu().numpy()

        # Set `correlation`
        distances = distances.mean(axis=0)
        mi, ma = distances.min(), distances.max()
        distances = (distances-mi)/(ma-mi)
        correlation = 1 - distances

        # Main update
        self.corrs[network][other_network] = correlation.max(axis=1)
        self.corrs[other_network][network] = correlation.max(axis=0)

        self.similarities[network][other_network] = self.corrs[network][other_network].mean()
        self.similarities[other_network][network] = self.corrs[other_network][network].mean()

        self.pairs[network][other_network] = correlation.argmax(axis=1)
        self.pairs[other_network][network] = correlation.argmax(axis=0)

    # Set `self.head_sort` : {network, sorted_list}
    # Set `self.head_notated_sort` : {network: [(head, {other: (corr, pair)})]}
    self.head_sort = {} 
    self.head_notated_sort = {}
    for network in tqdm(self.attentions_d, desc='annotation'):
        self.head_sort[network] = sorted(
            range(self.num_heads_d[network]), 
            key=lambda i: self.op(
                self.corrs[network][other][i] for other in self.corrs[network]
            ), 
            reverse=True,
        )
        self.head_notated_sort[network] = [
            (
                head,
                {
                    other : (
                        self.corrs[network][other][head], 
                        self.pairs[network][other][head],
                    ) 
                    for other in self.corrs[network]
                }
            ) 
            for head in self.head_sort[network]
        ]

In [96]:
compute_correlations(self)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 4496.04it/s]
annotation: 100%|████████████████████████████| 3/3 [00:00<00:00, 9939.11it/s]


In [98]:
self.corrs

{'foo': {'bar': array([0.80571031, 0.89711243, 0.96516111, 0.96388111, 0.79245618,
         0.98107638, 0.95147972, 0.90318424, 0.66177127, 1.        ]),
  'baz': array([0.70828754, 0.8381147 , 0.9657241 , 0.97126409, 0.69442393,
         0.98877343, 0.94589928, 0.85756298, 0.49796528, 1.        ])},
 'bar': {'foo': array([0.51749251, 0.25547081, 0.93043385, 0.57637179, 0.75780165,
         0.87522676, 0.29529897, 0.96944789, 0.82012099, 0.63578879,
         0.90963748, 1.        ]),
  'baz': array([0.4970413 , 0.25899628, 0.95184827, 0.5703824 , 0.75182914,
         0.88623117, 0.28212633, 0.97411617, 0.83753652, 0.70851246,
         0.91955914, 1.        ])},
 'baz': {'foo': array([0.90670855, 0.98533472, 0.76424015, 0.85075476, 0.96583704,
         0.90766268, 0.92603776, 0.88066662, 0.87295673, 1.        ,
         0.99580053, 0.53103789, 0.91161257, 0.82553783]),
  'bar': array([0.94825046, 0.98203163, 0.85541828, 0.90084108, 0.98183017,
         0.93970805, 0.95202526, 0.91037505