In [None]:
#@title foo
# Install the notebook's dependencies: transformers is pinned to 4.1.1,
# plotnine is installed without a version pin.
!pip install transformers==4.1.1 plotnine

In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from IPython.display import HTML
import plotnine
from plotnine import *

import torch
from transformers import AutoModel, AutoTokenizer

# Set plotnine's default figure size for every plot rendered below; the
# layer-by-head facet grid at the end of the notebook needs the room.
plotnine.options.figure_size = (12, 12)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Checkpoint to inspect; the commented-out lines are alternative checkpoints.
#transformer = "distilbert-base-cased"
transformer = "bert-base-cased"
#transformer = "twmkn9/bert-base-uncased-squad2"

tokenizer = AutoTokenizer.from_pretrained(transformer)

# Ask the model to return attention maps and hidden states with its outputs.
model = AutoModel.from_pretrained(
    transformer,
    output_attentions=True,
    output_hidden_states=True,
)
model = model.to(device)

# Switch to evaluation mode and clear any gradients before running inference.
model.eval()
model.zero_grad()


In [None]:
# Example sentences of different lengths to analyse.
sentences = [
    "You will either win or lose the game.",
    "Less is more.",
    "The quick brown fox jumped over the lazy dog.",
]

# Tokenise the whole batch (padded, returned as PyTorch tensors) and move it to
# `device`. BatchEncoding.to() transfers every tensor in one call, replacing the
# manual loop that rewrote the encoding while iterating over it and left stray
# `k` / `v` globals behind (both names are reused for other things further down).
input_dict = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
print(input_dict)

In [None]:
# Inference only: the analysis below never back-propagates, so disable autograd
# to avoid building and retaining the computation graph for every layer.
with torch.no_grad():
    output = model(**input_dict)

In [None]:
# Collect the attention tensor of each entry in output['attentions'] as a numpy
# array and combine them into one array with a new leading axis.
per_layer_attention = [
    layer_att.cpu().detach().numpy() for layer_att in output['attentions']
]
att = np.array(per_layer_attention)
print(att.shape)

In [None]:
# Order every attention softmax vector (last axis) from largest to smallest.
# NOTE: the result name shadows Python's builtin `sorted`; it is kept because
# the following cell reads it.
descending_idx = np.argsort(-att, axis=-1)
sorted = np.take_along_axis(att, descending_idx, axis=-1)
print(sorted.shape)

In [None]:
# Running total of the sorted weights along the last axis.
cum = np.cumsum(sorted, axis=-1)
print(cum.shape)

In [None]:
# Boolean flag per position: True while the running total is still under 0.9.
limit = cum < 0.9
print(limit.shape)

In [None]:
# k = (number of positions whose running total is below 0.9) + 1,
# i.e. one more than the count of True flags along the last axis.
k = np.sum(limit, axis=-1) + 1
print(k.shape)

In [None]:
# Exchange axes 1 and 2 (the 'sentence' and 'head' axes) so the attention mask
# can be applied more easily in the next step.
ks = k.swapaxes(1, 2)
print(ks.shape)

In [None]:
# Flag padding tokens: positions where the attention mask is 0 become masked
# entries of `mt`, after expanding the mask to the shape of `ks`.
att_mask = input_dict['attention_mask'].cpu().detach()
is_padding = (att_mask == 0).expand(ks.shape)
mt = np.ma.MaskedArray(ks, mask=is_padding)

In [None]:
# Collapse every axis after the first two into one, so the sentences become a
# single flat list of tokens.
leading_dims = ks.shape[:2]
flat_len = np.prod(ks.shape[2:])
mr = mt.reshape(leading_dims + (flat_len,))
print(mr.shape)

In [None]:
# Flat positions of the tokens to keep: the non-zero entries of the
# attention mask.
mask_values = att_mask.numpy()
unmasked = np.flatnonzero(mask_values)

In [None]:
# Dimensions of the data we keep after dropping padding:
# layer X head X #tokens
kept_tokens = mr[:, :, unmasked]
l, h, v = kept_tokens.shape
print(l, h, v)

In [None]:
# Build a (layer, head) MultiIndex with 1-based ids: each layer id is repeated
# once per head, and the run of head ids is repeated once per layer.
layer_ids = np.arange(1, l + 1)
head_ids = np.arange(1, h + 1)
ix = pd.MultiIndex.from_arrays(
    [np.repeat(layer_ids, h), np.tile(head_ids, l)],
    names=['layer', 'head'],
)

In [None]:
# Drop the padding tokens and lay the result out with one row per (layer, head)
# and one column per remaining token ...
k_per_head = mr[:, :, unmasked].reshape((l * h, len(unmasked)))
wide = pd.DataFrame(k_per_head, index=ix)

# ... then melt to long format: one layer/head/token/k per row.
data = wide.reset_index().melt(id_vars=['layer', 'head'])
display(data)

In [None]:
# Median k per (layer, head). Despite its name, `avg_k` holds medians.
#
# `mr` is a numpy masked array. Plain `np.median` partitions the underlying
# buffer without consulting the mask, so use the mask-aware `np.ma.median` and
# convert the result to an ordinary ndarray before handing it to pandas.
# (`np.ma.filled` would put NaN in any fully masked slot.)
median_k = np.ma.filled(np.ma.median(mr[:, :, unmasked], axis=-1), np.nan)
avg_k = pd.DataFrame(median_k.flatten(), index=ix, columns=["value"]).reset_index()
display(avg_k)

In [None]:
# Plot it: one facet per layer (rows) and head (columns). Each facet shows a
# violin of k with the individual tokens jittered on top, plus that head's
# value from `avg_k` printed in red.
k_plot = (
    ggplot(data, aes(1, "value"))
    + geom_violin(fill="steelblue")
    + geom_jitter(width=0.01, alpha=0.3)
    + geom_text(data=avg_k, mapping=aes(x=1, y=5, label="value"), color="red")
    + facet_grid("layer ~ head")
    + coord_flip()
)
k_plot