In [None]:
#@title foo
!pip install transformers==4.1.1 plotnine

## setting stuff up

In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from IPython.display import HTML
import plotnine
from plotnine import *

import torch
from transformers import AutoModel, AutoTokenizer

# default size for plotnine figures; the plotting cells further down set
# this option again right before drawing
plotnine.options.figure_size = (20, 20)

In [None]:
# run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# uncomment to force CPU if you have a GPU but not enough memory to do what you want. it will be slow of course

#device = torch.device("cpu")

In [None]:
# name of the pretrained checkpoint to analyse; alternatives that can be
# swapped in instead:
#transformer = "distilbert-base-cased"
#transformer = "gpt2"
#transformer = "gpt2-medium"
#transformer = "gpt2-large"
#transformer = "twmkn9/bert-base-uncased-squad2"
transformer = "bert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(transformer)

# ask the model to include attention maps and hidden states in its output
model = AutoModel.from_pretrained(
    transformer,
    output_attentions=True,
    output_hidden_states=True,
)

# move the weights to the selected device, switch to eval mode and reset gradients
model.to(device)
model.eval()
model.zero_grad()


## data preparation

In [None]:
# sentences to analyse; they are tokenised together as one padded batch
sentences = [
    "This is a really long sentence that doesn't make much sense, but let's see what happens at the end",
    "There are five subspecies of the pigeon guillemot; all subspecies, when in breeding plumage, are dark brown with a black iridescent sheen and a distinctive wing patch broken by a brown-black wedge.",
    "Buchanan, working through federal patronage appointees in Illinois, ran candidates for the legislature in competition with both the Republicans and the Douglas Democrats.",
#    "Less is more.",
]

# some tokenizers (e.g. the gpt2 family) don't define a padding token, so
# invent one by reusing the end-of-sequence token. it shouldn't matter which
# token is used, as padded positions get ignored by the attention mask anyway.
# the check looks at the tokenizer itself instead of a list of model names, so
# it also covers other checkpoints that lack a padding token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# tokenise into padded pytorch tensors and move the whole batch to the
# device the model lives on
input_dict = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
print(input_dict)

In [None]:
# only the outputs of the forward pass are inspected, so run it without
# autograd bookkeeping: this avoids keeping the computation graph in memory
with torch.no_grad():
    output = model(**input_dict)

In [None]:
# output['attentions'] holds one tensor per entry; convert each to numpy and
# stack them into a single array with a new leading axis
layer_attentions = [layer_att.detach().cpu().numpy() for layer_att in output['attentions']]
att = np.array(layer_attentions)
print(att.shape)

## getting the data ready to plot

In [None]:
# the attention mask is what flags the padding tokens (zeros are treated as
# masked further down); pull it off the device for inspection
att_mask = input_dict['attention_mask'].detach().cpu()
print(att_mask.shape)
print(att_mask.sum())

In [None]:
# exchange axes 1 and 2 so the axis order matches the names used for the
# index below (layer, head, sentence, from_token, to_token).
# NOTE: this rebinds `att`, so running this cell a second time swaps the
# axes back again
att = np.swapaxes(att, 1, 2)
print(att.shape)

In [None]:
# make the dimension indices of the array explicit as
# a pandas dataframe MultiIndex: one 1-based label per axis of `att`,
# enumerated in row-major order so that entry i of the index describes
# element i of att.flatten()
spec = att.shape
axis_names = ['layer', 'head', 'sentence', 'from_token', 'to_token']
axis_labels = [np.arange(size) + 1 for size in spec]

ix = (
    pd.MultiIndex.from_product(axis_labels, names=axis_names)
    # only the order of the levels is reversed (to_token first, layer last);
    # the order of the rows is left untouched
    .reorder_levels(axis_names[::-1])
)

In [None]:
# inspect the generated index (level names and labels)
print(ix)

In [None]:
# turn the array into one long list of numbers ...
attention_values = att.flatten()
# ... index it by its dimensions, and then turn the dimensions into columns
df = pd.DataFrame({"attention_fraction": attention_values}, index=ix).reset_index()
display(df)

In [None]:
# filter out the masked tokens
# keep_token[s, t] is True when token t of sentence s stays in the data
keep_token = np.array(att_mask.tolist(), dtype=bool)
for row in keep_token:
    unmasked = np.flatnonzero(row)
    if unmasked.size == 0:
        # nothing is unmasked in this sentence, so there is nothing to keep
        continue
    # the next line filters out the first token and the last unmasked token, which are [CLS] and [SEP] (for bert)
    # comment it out to see the results with them included
    row[[0, unmasked[-1]]] = False

# a row of df survives only when both its from_token and its to_token are
# kept; look both up in one vectorised pass instead of re-querying the whole
# frame once per dropped token (the columns are 1-based, the table 0-based)
sentence_pos = df['sentence'].to_numpy() - 1
keep_row = (
    keep_token[sentence_pos, df['from_token'].to_numpy() - 1]
    & keep_token[sentence_pos, df['to_token'].to_numpy() - 1]
)
# copy so the columns added in later cells go onto an independent frame
df = df[keep_row].copy()


### calculate the weighted distances and their median per head

In [None]:
# distance = how many positions lie between the attending and the attended token
token_gap = df['from_token'] - df['to_token']
df['distance'] = token_gap.abs()

In [None]:
# show the frame again, now with the distance column added
display(df)

In [None]:
# show the data for the second token of the first sentence for the first layer and the first head
# which is really the first one in the data as I filter out [CLS] above
# the option is spelled out in full ("display.max_rows"): option names are matched
# as patterns, and a bare "max_rows" can match more than one option
with pd.option_context("display.max_rows", None):
    display(df.query("layer == 1 & head == 1 & sentence == 1 & from_token == 2").sort_values("attention_fraction", ascending=False))

In [None]:
# weight each token distance by the fraction of attention it receives
df['weighted'] = df['attention_fraction'] * df['distance']

In [None]:
# median of the weighted distance for every (layer, head) combination,
# as a flat table rounded to three decimals
g = df.groupby(['layer', 'head'])
per_head_median = g['weighted'].median()
median_dist = per_head_median.reset_index().round(3)

In [None]:
# one row per (layer, head) with its median weighted distance
display(median_dist)

## plot

Cheat a bit by limiting the plots to show only the interval between 0 and 2 for the weighted distance. 

This unsquishes the violins to reveal a pattern a bit similar to the plot of `k` in the Hopfield networks
paper, but it does hide values (especially for the dot plot), which may give a false impression.

You can change the interval by adjusting the `limits` variable.

In [None]:
# plot it!
limits = (0, 2)
plotnine.options.figure_size = (20, 20)

# the median of each (layer, head) facet, printed at 75% of the upper limit
median_labels = geom_label(
    data=median_dist,
    mapping=aes(x=1.2, y=limits[1] * .75, label="weighted"),
)
plot_labels = labs(
    x = "",
    y = "weighted distance",
    title = "Distribution and median of distances between attending and attended tokens",
)
# x is a dummy position (1), so hide its text and ticks
hide_dummy_axis = theme(
    axis_text_y = element_blank(),
    axis_ticks_major_y = element_blank(),
)

# jittered points with a violin on top, one facet per layer/head pair,
# flipped so the weighted distance runs horizontally and cut off at `limits`
(
    ggplot(df, aes(1, "weighted"))
    + geom_jitter(height=0, size=0.1, alpha=0.1, color="magenta")
    + geom_violin(fill="lightblue")
    + median_labels
    + scale_y_continuous(breaks=np.linspace(*limits, num=3))
    + facet_grid("layer ~ head", labeller="label_both")
    + coord_flip(ylim=limits)
    + plot_labels
    + hide_dummy_axis
)

In [None]:
# from me trying to figure out why the plot looked weird when I used mean instead of median:
# summary statistics of the weighted distance for head 1 in layers 5, 6 and 7, side by side
pd.concat(
    [
        df[(df['head'] == 1) & (df['layer'] == layer_no)]['weighted'].describe()
        for layer_no in (5, 6, 7)
    ],
    axis=1,
)

### subset of heads

Plot only a few heads, so each facet can be bigger and it's not as necessary to limit what is shown.

In [None]:
subset = "head == 1 & layer >= 5 & layer <= 7"
plotnine.options.figure_size = (20, 6)

# rows and medians belonging to the selected heads
df_subset = df.query(subset)
median_subset = median_dist.query(subset)
# print the median labels at 75% of the largest weighted distance shown
label_height = df_subset['weighted'].max() * 0.75

# same layers as the full plot, but wrapped facets and no axis limit
# (alternatives tried: facet_grid("layer ~ head", labeller="label_both")
#  and coord_flip(ylim=(0,10)))
# NOTE(review): the theme hides axis_ticks_major_x here while the full plot
# hides axis_ticks_major_y - confirm which one is intended
(
    ggplot(df_subset, aes(1, "weighted"))
    + geom_jitter(height=0, alpha=0.1, color="magenta")
    + geom_violin(fill="lightblue")
    + geom_label(data=median_subset, mapping=aes(x=1.2, y=label_height, label="weighted"))
    + facet_wrap("~ layer + head", labeller="label_both")
    + coord_flip()
    + labs(
        y = "",
        title = "Distribution and median of distances between attending and attended tokens",
    )
    + theme(
        axis_text_y = element_blank(),
        axis_ticks_major_x = element_blank(),
    )
)