In [2]:
import torch 
from tqdm import tqdm
from itertools import product as p
import json
import numpy as np
import h5py
from os.path import basename, dirname
#import dask.array as da
import pickle
from var import fname2mname

# setup

In [3]:
def pvec(t):
    return t/t.sum(dim=-1, keepdim=True)

In [4]:
class A:
    pass
self = A()

In [5]:
self.num_heads_d = {} # {fname, int}
self.attentions_d = {} # {fname, tensor}
f1, f2, f3 = "foo", "bar", "baz"
attention_fname_l = [f1, f2, f3]

In [6]:
# initialize `num_heads_d`, `attentions_d` with fake data
n1, n2, n3 = 10, 12, 14
wlen_l = [12, 5, 9, 6]

self.num_heads_d[f1] = n1
self.num_heads_d[f2] = n2
self.num_heads_d[f3] = n3

for fname in attention_fname_l:
    attentions_l = [pvec(torch.randn(self.num_heads_d[fname], wlen, wlen))
                        for wlen in wlen_l]
    self.attentions_d[fname] = attentions_l

In [7]:
self.device = torch.device('cpu')
self.op = min

In [8]:
# FroMaxMinCorr, just our example
def correlation_matrix(network, other_network):
    device = self.device
    num_sentences = self.num_sentences
    
    distances = np.zeros((num_sentences, self.num_heads_d[network], self.num_heads_d[other_network]))
    for idx, (attns, o_attns) in enumerate(zip(self.attentions_d[network], self.attentions_d[other_network])):
        t1 = attns.to(device)
        t2 = o_attns.to(device)
        t11, t12, t13 = t1.size()
        t21, t22, t23 = t2.size()
        t1 = t1.reshape(t11, 1, t12, t13)
        t2 = t2.reshape(1, t21, t22, t23)

        distance = torch.norm(t1-t2, p='fro', dim=(2,3))
        distances[idx] = distance.cpu().numpy()
        
    # Set `correlation`
    distances = distances.mean(axis=0)
    mi, ma = distances.min(), distances.max()
    distances = (distances-mi)/(ma-mi)
    correlation = 1 - distances
    
    return correlation

In [9]:
self.correlation_matrix = correlation_matrix

# hnb

In [10]:
device = self.device

self.corrs = {network: {} for network in self.attentions_d}
self.pairs = {network: {} for network in self.attentions_d}
self.similarities = {network: {} for network in
                     self.attentions_d}
self.num_sentences = len(next(iter(self.attentions_d.values())))
self.num_words = sum(t.size()[-1] for t in next(iter(self.attentions_d.values())))

### `self.corrs`, `self.pairs`, `self.similarities` loop

In [28]:
# arbitrarily set loop variables
network = f1
other_network = f2

In [32]:
correlation = self.correlation_matrix(network, other_network)

In [35]:
# Main update
self.corrs[network][other_network] = correlation.max(axis=1)
self.corrs[other_network][network] = correlation.max(axis=0)

self.similarities[network][other_network] = self.corrs[network][other_network].mean()
self.similarities[other_network][network] = self.corrs[other_network][network].mean()

self.pairs[network][other_network] = correlation.argmax(axis=1)
self.pairs[other_network][network] = correlation.argmax(axis=0)

In [36]:
# full loop
for network, other_network in tqdm(p(self.attentions_d,
                                     self.attentions_d),
                                     desc='correlate',
                                     total=len(self.attentions_d)**2):
    if network == other_network:
        continue

    if other_network in self.corrs[network]: 
        continue
        
    correlation = self.correlation_matrix(network, other_network)
    
    # Main update
    self.corrs[network][other_network] = correlation.max(axis=1)
    self.corrs[other_network][network] = correlation.max(axis=0)

    self.similarities[network][other_network] = self.corrs[network][other_network].mean()
    self.similarities[other_network][network] = self.corrs[other_network][network].mean()

    self.pairs[network][other_network] = correlation.argmax(axis=1)
    self.pairs[other_network][network] = correlation.argmax(axis=0)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 5287.68it/s]


# Full function

In [14]:
def compute_correlations(self):
    # convenient variables
    device = self.device
    self.num_sentences = len(next(iter(self.attentions_d.values())))
    self.num_words = sum(t.size()[-1] for t in next(iter(self.attentions_d.values())))

    # Set `self.corrs` : {network: {other: [corr]}}
    # Set `self.pairs` : {network: {other: [pair]}}
    # pair is index of head in other network
    # Set `self.similarities` : {network: {other: sim}}
    self.corrs = {network: {} for network in self.attentions_d}
    self.pairs = {network: {} for network in self.attentions_d}
    self.similarities = {network: {} for network in self.attentions_d}
    for network, other_network in tqdm(p(self.attentions_d,
                                         self.attentions_d),
                                         desc='correlate',
                                         total=len(self.attentions_d)**2):
        if network == other_network:
            continue

        if other_network in self.corrs[network]: 
            continue

        correlation = self.correlation_matrix(network, other_network)

        # Main update
        self.corrs[network][other_network] = correlation.max(axis=1)
        self.corrs[other_network][network] = correlation.max(axis=0)

        self.similarities[network][other_network] = self.corrs[network][other_network].mean()
        self.similarities[other_network][network] = self.corrs[other_network][network].mean()

        self.pairs[network][other_network] = correlation.argmax(axis=1)
        self.pairs[other_network][network] = correlation.argmax(axis=0)

    # Set `self.head_sort` : {network, sorted_list}
    # Set `self.head_notated_sort` : {network: [(head, {other: (corr, pair)})]}
    self.head_sort = {} 
    self.head_notated_sort = {}
    for network in tqdm(self.attentions_d, desc='annotation'):
        self.head_sort[network] = sorted(
            range(self.num_heads_d[network]), 
            key=lambda i: self.op(
                self.corrs[network][other][i] for other in self.corrs[network]
            ), 
            reverse=True,
        )
        self.head_notated_sort[network] = [
            (
                head,
                {
                    other : (
                        self.corrs[network][other][head], 
                        self.pairs[network][other][head],
                    ) 
                    for other in self.corrs[network]
                }
            ) 
            for head in self.head_sort[network]
        ]

In [15]:
compute_correlations(self)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 4123.74it/s]
annotation: 100%|████████████████████████████| 3/3 [00:00<00:00, 5043.25it/s]


In [16]:
self.corrs

{'foo': {'bar': array([0.98086308, 0.90935966, 0.95885367, 0.84668847, 1.        ,
         0.6859649 , 0.99684803, 0.82426914, 0.87466759, 0.96812032]),
  'baz': array([0.99653077, 0.9896387 , 0.99414953, 0.98114855, 1.        ,
         0.96047981, 0.99906089, 0.97768111, 0.98509738, 0.99635   ])},
 'bar': {'foo': array([0.91008778, 0.31300565, 0.92151437, 0.83687101, 0.72204893,
         0.97862211, 0.89750213, 0.88318838, 0.48930261, 0.93181968,
         0.88392007, 1.        ]),
  'baz': array([0.98914902, 0.91514913, 0.99059196, 0.98066529, 0.96687033,
         0.99733577, 0.98778902, 0.98619307, 0.93834528, 0.99155758,
         0.98664106, 1.        ])},
 'baz': {'foo': array([0.99940343, 0.99812358, 0.99586541, 0.99183936, 0.03294723,
         0.95344159, 0.98109539, 1.        , 0.98771851, 0.99633164,
         0.97622873, 0.99829991, 0.98175752, 0.97234356]),
  'bar': array([1.        , 0.99716084, 0.99547619, 0.99160533, 0.0855523 ,
         0.95667934, 0.98149787, 0.99952981