In [1]:
import torch 
from tqdm import tqdm
from itertools import product as p
import json
import numpy as np
import h5py
from os.path import basename, dirname
#import dask.array as da
import pickle
from var import fname2mname

# setup

In [2]:
def pvec(t, ar_mask=False):
    """
    Normalize `t` along its last dimension so each row sums to 1.

    Parameters
    ----------
    t : torch.Tensor
        Non-negative scores. With `ar_mask=True` it must have at least two
        dimensions; the mask is applied to the last two.
    ar_mask : bool
        If True, zero everything above the main diagonal of the last two
        dimensions before normalizing (autoregressive / causal mask).

    Returns
    -------
    torch.Tensor
        New tensor of the same shape; `t` itself is not modified. Rows whose
        sum is zero come out as nan (0/0), as before.
    """
    if ar_mask:
        # torch.tril keeps the tensor's dtype and device, unlike a round trip
        # through NumPy + FloatTensor which forced float32 on CPU.
        t = torch.tril(t)
    return t/t.sum(dim=-1, keepdim=True)

In [3]:
class A:
    """Bare namespace class; later cells attach attributes to an instance of it."""
self = A()

In [4]:
# Per-network lookup tables, keyed by attention file name
self.num_heads_d = dict() # {fname: int}
self.attentions_d = dict() # {fname: [tensor]}
attention_fname_l = ["foo", "bar", "baz"]
f1, f2, f3 = attention_fname_l

In [5]:
# initialize `num_heads_d`, `attentions_d` with fake data
n1, n2, n3 = 10, 12, 14
wlen_l = [27, 5, 9, 2]

self.num_heads_d[f1] = n1
self.num_heads_d[f2] = n2
self.num_heads_d[f3] = n3

# Draw from a dedicated, seeded generator so a fresh "Restart & Run All"
# reproduces the same fake attentions without touching the global RNG state.
fake_data_rng = torch.Generator().manual_seed(0)

# One row-normalized, lower-triangular (heads, wlen, wlen) tensor per sentence.
for fname in attention_fname_l:
    attentions_l = [pvec(torch.abs(torch.randn(self.num_heads_d[fname], wlen, wlen,
                                               generator=fake_data_rng)), True)
                        for wlen in wlen_l]
    self.attentions_d[fname] = attentions_l

In [6]:
# `op` folds one head's correlations across the other networks into a sort key
self.op = min
self.device = torch.device('cpu')

In [7]:
# corpus size implied by the fake sentence lengths
self.num_sentences, self.num_words = len(wlen_l), sum(wlen_l)

In [8]:
def compute_correlations(self):
    """
    Match attention heads across every pair of networks by correlation.

    Reads `self.attentions_d`, `self.num_heads_d` and `self.op`, and calls
    `self.correlation_matrix(network, other_network)` once per unordered
    pair of distinct networks. Sets on `self`:

    - `num_sentences`, `num_words` : taken from the first network's list of
      attention tensors (its length, and the sum of the tensors' last dims)
    - `corrs` : {network: {other: [corr]}}, each head's maximum correlation
      with any head of `other`
    - `pairs` : {network: {other: [pair]}}, pair is index of that
      best-matching head in the other network
    - `similarities` : {network: {other: sim}}, mean of the `corrs` entry
    - `head_sort` : {network: sorted_list}, heads by decreasing
      `self.op` of their correlations over the other networks
    - `head_notated_sort` : {network: [(head, {other: (corr, pair)})]}

    With no networks every result is empty; with a single network there is
    nothing to rank by, so its heads keep their natural order.
    """
    networks = list(self.attentions_d)

    # convenient variables; an empty registry yields zero counts instead of
    # a StopIteration from next(iter(...))
    first_attentions = self.attentions_d[networks[0]] if networks else []
    self.num_sentences = len(first_attentions)
    self.num_words = sum(t.size()[-1] for t in first_attentions)

    self.corrs = {network: {} for network in networks}
    self.pairs = {network: {} for network in networks}
    self.similarities = {network: {} for network in networks}
    for network, other_network in tqdm(p(networks, networks),
                                       desc='correlate',
                                       total=len(networks)**2):
        if network == other_network:
            continue

        # Both directions are filled in one step, so the mirrored pair
        # is already present when the product reaches it.
        if other_network in self.corrs[network]:
            continue

        correlation = self.correlation_matrix(network, other_network)

        # Main update: rows index heads of `network`, columns heads of
        # `other_network`.
        self.corrs[network][other_network] = correlation.max(axis=1)
        self.corrs[other_network][network] = correlation.max(axis=0)

        self.similarities[network][other_network] = self.corrs[network][other_network].mean()
        self.similarities[other_network][network] = self.corrs[other_network][network].mean()

        self.pairs[network][other_network] = correlation.argmax(axis=1)
        self.pairs[other_network][network] = correlation.argmax(axis=0)

    self.head_sort = {}
    self.head_notated_sort = {}
    for network in tqdm(networks, desc='annotation'):
        peer_corrs = self.corrs[network]
        heads = range(self.num_heads_d[network])
        if peer_corrs:
            self.head_sort[network] = sorted(
                heads,
                key=lambda i: self.op(
                    peer_corrs[other][i] for other in peer_corrs
                ),
                reverse=True,
            )
        else:
            # No other network: `self.op` would receive an empty iterable
            # (min() raises on that), so leave the heads unsorted.
            self.head_sort[network] = list(heads)
        self.head_notated_sort[network] = [
            (
                head,
                {
                    other : (
                        peer_corrs[other][head],
                        self.pairs[network][other][head],
                    )
                    for other in peer_corrs
                }
            )
            for head in self.head_sort[network]
        ]

# hnb

In [14]:
# arbitrary vals for arguments: compare the first two fake networks
network, other_network = f1, f2

In [15]:
# local aliases, mirroring the top of `correlation_matrix`
device, num_sentences = self.device, self.num_sentences

In [16]:
# set `total_corrs`: one (heads x other heads) slab per sentence
total_corrs_shape = (
    num_sentences,
    self.num_heads_d[network],
    self.num_heads_d[other_network],
)
total_corrs = np.zeros(total_corrs_shape)

### for idx, (attns, o_attns) loop

In [17]:
# loop variables
idx = 2
attns = self.attentions_d[network][idx]
o_attns = self.attentions_d[other_network][idx]

In [18]:
# Move both stacks to the device, record their three dims, then insert a
# singleton axis so (t11, 1, t12, t13) broadcasts against (1, t21, t22, t23).
t1, t2 = attns.to(device), o_attns.to(device)
t11, t12, t13 = t1.shape
t21, t22, t23 = t2.shape
t1 = t1.unsqueeze(1)
t2 = t2.unsqueeze(0)

In [19]:
# Set `corr`: per-row Pearson correlation for every (head, other head) pair,
# taken over the positions where both heads are nonzero.
nz = (t1!=0) & (t2!=0) # "nonzero" in both
n = nz.sum(dim=-1).float() # n = length to use
# Zero out entries outside the joint support so every sum below runs over
# exactly the `n` positions counted above, even when the two masks differ.
m1, m2 = t1*nz.to(t1.dtype), t2*nz.to(t2.dtype)
x1, x2 = m1.sum(dim=-1), m2.sum(dim=-1)
x11, x12, x22 = (m1*m1).sum(dim=-1), (m1*m2).sum(dim=-1), (m2*m2).sum(dim=-1)
num = x12-(x1*x2/n)
denom = torch.sqrt((x11-(x1*x1/n)) * (x22-(x2*x2/n)))
corr = num/denom
corr[n < 2] = 0 # correlation is undefined on fewer than two points
corr[~torch.isfinite(corr)] = 0 # zero-variance rows give 0/0; same rule as the full loop

In [20]:
# collapse the row axis and store this sentence's (heads x other heads) slab
total_corrs[idx, :, :] = corr.sum(-1).cpu().numpy()

In [21]:
# full loop: fill `total_corrs[idx]` for every sentence
for idx, (attns, o_attns) in enumerate(
        zip(self.attentions_d[network],
            self.attentions_d[other_network])):
    t1 = attns.to(device)
    t2 = o_attns.to(device)
    t11, t12, t13 = t1.size()
    t21, t22, t23 = t2.size()
    # singleton axes so every head of one network pairs with every head of the other
    t1 = t1.reshape(t11, 1, t12, t13)
    t2 = t2.reshape(1, t21, t22, t23)

    # Set `corr`
    nz = (t1!=0) & (t2!=0) # "nonzero" in both
    n = nz.sum(dim=-1).float() # n = length to use
    # Restrict both operands to the joint support so the sums below run over
    # exactly the `n` positions counted above, even when the masks differ.
    m1, m2 = t1*nz.to(t1.dtype), t2*nz.to(t2.dtype)
    x1, x2 = m1.sum(dim=-1), m2.sum(dim=-1)
    x11, x12, x22 = (m1*m1).sum(dim=-1), (m1*m2).sum(dim=-1), (m2*m2).sum(dim=-1)
    num = x12-(x1*x2/n)
    denom = torch.sqrt((x11-(x1*x1/n)) * (x22-(x2*x2/n)))
    corr = num/denom

    corr = corr.cpu().numpy()
    corr[~np.isfinite(corr)] = 0 # rows with n < 2 or zero variance contribute nothing
    total_corrs[idx] = corr.sum(axis=-1)

# full function

In [9]:
def correlation_matrix(self, network, other_network):
    """
    Average per-row Pearson correlation between every pair of heads.

    For each sentence, and for each (head of `network`, head of
    `other_network`) pair, the correlation of the two attention rows is
    computed over the positions where *both* rows are nonzero. The row
    correlations are summed over rows and sentences and divided by
    `self.num_words`.

    Reads `self.device`, `self.num_sentences`, `self.num_words`,
    `self.num_heads_d` and `self.attentions_d`; each entry of
    `self.attentions_d[name]` must be a 3-D tensor, and paired sentences
    must broadcast against each other in their last two dimensions.

    Parameters
    ----------
    network, other_network : keys of `self.attentions_d`

    Returns
    -------
    numpy.ndarray of shape
    (self.num_heads_d[network], self.num_heads_d[other_network])
    """
    device = self.device
    num_sentences = self.num_sentences

    # set `total_corrs`
    total_corrs = np.zeros((num_sentences, self.num_heads_d[network],
                            self.num_heads_d[other_network]))
    for idx, (attns, o_attns) in enumerate(
            zip(self.attentions_d[network],
                self.attentions_d[other_network])):
        t1 = attns.to(device)
        t2 = o_attns.to(device)
        t11, t12, t13 = t1.size()
        t21, t22, t23 = t2.size()
        # singleton axes so every head pairs with every other head
        t1 = t1.reshape(t11, 1, t12, t13)
        t2 = t2.reshape(1, t21, t22, t23)

        # Set `corr`
        nz = (t1!=0) & (t2!=0) # "nonzero" in both
        n = nz.sum(dim=-1).float() # n = length to use
        # Restrict both operands to the joint support. Without this the sums
        # run over positions that `n` does not count whenever the two
        # networks are masked differently, which breaks the formula below.
        m1, m2 = t1*nz.to(t1.dtype), t2*nz.to(t2.dtype)
        x1, x2 = m1.sum(dim=-1), m2.sum(dim=-1)
        x11, x12, x22 = (m1*m1).sum(dim=-1), (m1*m2).sum(dim=-1), (m2*m2).sum(dim=-1)
        num = x12-(x1*x2/n)
        denom = torch.sqrt((x11-(x1*x1/n)) * (x22-(x2*x2/n)))
        corr = num/denom

        corr = corr.cpu().numpy()
        # rows with fewer than two shared positions, or zero variance, give
        # nan/inf; they contribute nothing
        corr[~np.isfinite(corr)] = 0
        total_corrs[idx] = corr.sum(axis=-1)

    # set `correlation`
    correlation = total_corrs.sum(axis=0)/self.num_words
    return correlation

# test

In [10]:
# Attach `correlation_matrix` to the stand-in object so `compute_correlations`
# can call it as `self.correlation_matrix(n, o_n)`.
def _correlation_matrix_on_self(n, o_n):
    return correlation_matrix(self, n, o_n)
self.correlation_matrix = _correlation_matrix_on_self

In [11]:
# Run the full pipeline on the fake data; results are stored as attributes of `self`.
compute_correlations(self)

correlate: 100%|█████████████████████████████| 9/9 [00:00<00:00, 1016.45it/s]
annotation: 100%|████████████████████████████| 3/3 [00:00<00:00, 8548.17it/s]


In [12]:
# {network: {other: sim}} -- mean over heads of each head's best correlation with `other`
self.similarities

{'foo': {'bar': 0.10310055208067562, 'baz': 0.12652370506940885},
 'bar': {'foo': 0.10782335860322613, 'baz': 0.12433415669370297},
 'baz': {'foo': 0.12067727221058454, 'bar': 0.11247482138416696}}