In [None]:
#@title foo
!pip install transformers==4.1.1 plotnine

In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from IPython.display import HTML
import plotnine
from plotnine import *

import torch
from transformers import AutoModel, AutoTokenizer

# default plotnine figure size, (width, height) in inches
DEFAULT_FIGURE_SIZE = (12, 12)
plotnine.options.figure_size = DEFAULT_FIGURE_SIZE

In [None]:
# Override the default from the imports cell: the faceted plots below have one
# panel per layer/head combination and need a much larger canvas.
LARGE_FIGURE_SIZE = (20, 20)
plotnine.options.figure_size = LARGE_FIGURE_SIZE

In [None]:
# run on the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#device = torch.device("cpu")

In [None]:
# pretrained checkpoint to analyse; the commented-out names are alternative checkpoints
transformer = "bert-base-cased"
#transformer = "gpt2"
#transformer = "gpt2-medium"
#transformer = "gpt2-large"
#transformer = "twmkn9/bert-base-uncased-squad2"
tokenizer = AutoTokenizer.from_pretrained(transformer)

# Some tokenizers (e.g. gpt2) define no padding token, but the batched
# tokenization below uses padding=True. Fall back to the EOS token; the padded
# positions are ignored through the attention mask anyway, so the choice of
# token should not matter. Checking pad_token (rather than a hard-coded list of
# model names) covers every checkpoint that lacks one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ask the model to also return the per-layer attentions and hidden states
model = AutoModel.from_pretrained(transformer, output_attentions=True, output_hidden_states=True)
model.to(device)
# evaluation mode (disables dropout): this notebook only runs inference
model.eval()
model.zero_grad()


In [None]:
# Load the prepared sentences (the Penn Treebank sample shipped with nltk,
# converted with the convert_corpus.py script) and count the
# whitespace-separated words of every line.
sentences = pd.read_csv("lines.csv")
sentences["length"] = sentences["line"].str.split().apply(len)

# summary statistics of the sentence lengths, ignoring the outliers of 100+ words
short_lengths = sentences.loc[sentences["length"] < 100, "length"]
display(short_lengths.describe())

In [None]:
# inspect the outlier sentences of more than 100 words
sentences.loc[sentences["length"] > 100]

In [None]:
# show how the tokenizer splits one example (row 1854, first column) into tokens
example_line = sentences.iloc[1854, 0]
example_ids = tokenizer(example_line)["input_ids"]
tokenizer.convert_ids_to_tokens(example_ids)

In [None]:
def get_batches(df:pd.DataFrame, tokenizer, lengths: tuple = None, batch_size :int = 2):
    """Tokenize the sentences of ``df`` and wrap them in a DataLoader.

    Args:
        df: frame with a ``line`` column (raw text) and a ``length`` column
            (word count), such as ``sentences``.
        tokenizer: callable tokenizer; it is called with ``padding=True`` and
            ``return_tensors="pt"`` and its output must contain ``input_ids``,
            ``token_type_ids`` and ``attention_mask``.
        lengths: optional ``(low, high)`` word-count bounds; both bounds are
            exclusive. ``None`` keeps every row.
        batch_size: number of sentences per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding
        ``(input_ids, token_type_ids, attention_mask)`` batches in ``df`` order.

    Raises:
        ValueError: if the tokenizer output lacks one of the required tensors.
    """
    # filter the frame that was passed in (not a global)
    if lengths is not None:
        low, high = lengths
        subset = df[(df['length'] > low) & (df['length'] < high)]
    else:
        subset = df

    input_dict = tokenizer(subset['line'].tolist(), padding=True, return_tensors="pt")

    # look the tensors up by name instead of relying on the output's key order
    names = ("input_ids", "token_type_ids", "attention_mask")
    missing = [name for name in names if name not in input_dict]
    if missing:
        raise ValueError(f"tokenizer output is missing {missing}; expected {list(names)}")

    tensor_dataset = torch.utils.data.TensorDataset(*(input_dict[name] for name in names))
    tensor_dataloader = torch.utils.data.DataLoader(tensor_dataset, batch_size=batch_size)

    return tensor_dataloader

In [None]:
# batches of 3 sentences, keeping only sentences of 26-34 words (bounds are exclusive)
dl = get_batches(df=sentences, tokenizer=tokenizer, lengths=(25, 35), batch_size=3)

In [None]:
# A head's "k" for one token: the number of largest attention weights in that
# token's softmax row needed before their cumulative sum reaches this threshold.
K_MASS_THRESHOLD = 0.9

data = None
batch_frames = []
# Inference only: no_grad stops autograd from recording a graph for every batch.
with torch.no_grad():
    for batch, t in enumerate(dl):
        input_dict = {
            name: tensor.to(device)
            for name, tensor in zip(["input_ids", "token_type_ids", "attention_mask"], t)
        }
        output = model(**input_dict)

        # stack the per-layer attention tensors into one array (layer first)
        att = np.array([a.cpu().numpy() for a in output['attentions']])

        # sort every attention softmax vector in descending order and add it up
        # cumulatively (named att_desc so the builtin `sorted` is not shadowed
        # for later cells)
        att_desc = np.sort(att, axis=-1)[..., ::-1]
        cum = att_desc.cumsum(axis=-1)

        # count the weights whose running total is still below the threshold;
        # k is that count + 1
        k = (cum < K_MASS_THRESHOLD).sum(axis=-1) + 1

        # swap the 'head' and 'sentence' axes so the last two axes are
        # (sentence, token), lining up with the attention mask
        ks = np.swapaxes(k, 1, 2)

        # flatten sentences x tokens into a single token axis: layer x head x token
        l, h = ks.shape[:2]
        k_flat = ks.reshape(l, h, -1)

        # positions of the real (non-padding) tokens in that flattened axis
        att_mask = input_dict['attention_mask'].cpu().numpy()
        unmasked = np.flatnonzero(att_mask)
        k_real = k_flat[:, :, unmasked]
        v = k_real.shape[2]

        # layer/head multiindex, both 1-based
        ix = pd.MultiIndex.from_arrays(
            [
                np.repeat(np.arange(l) + 1, h),
                np.tile(np.arange(h) + 1, l)
            ],
            names=['layer', 'head'])

        # one row per layer/head/token with its k in the "value" column
        batch_data = (
            pd.DataFrame(k_real.reshape((l * h, v)), index=ix)
                .reset_index()
                .melt(id_vars=['layer', 'head'])
        )
        batch_data['batch'] = batch
        batch_frames.append(batch_data)

# concatenate once instead of re-copying `data` on every iteration;
# `data` stays None when the dataloader produced no batches
if batch_frames:
    data = pd.concat(batch_frames)


In [None]:
# sanity checks on the per-token k values collected above
assert data is not None, "no batches were processed; check the length filter passed to get_batches"
# k is a count plus one, so it can never be below 1
assert data["value"].min() >= 1
data.shape

In [None]:
# Median k per layer/head. The column keeps the name "avg_k" because later cells
# use it. The 'median' string replaces np.median, whose use in agg is deprecated
# in recent pandas; both compute the same value.
avg_k = (
    data.groupby(['layer', 'head'])
        .agg(avg_k=pd.NamedAgg(column='value', aggfunc='median'))
        .reset_index()
)

To better replicate the plot in the Hopfield Networks paper, add a `sorted_head` column so that we can plot the attention heads of each layer sorted from small to large median k.

In [None]:
sorted_avg_k = avg_k.sort_values(["layer", "avg_k"])
# Rank of each head within its layer after sorting by median k (1 = smallest k).
# It is derived from the frame itself rather than from loop variables left over
# from an earlier cell, so it does not rely on hidden kernel state.
sorted_avg_k['sorted_head'] = sorted_avg_k.groupby("layer").cumcount() + 1


In [None]:
# attach each head's within-layer rank (by median k) to every per-token row
head_ranks = sorted_avg_k.loc[:, ["layer", "head", "sorted_head"]]
data_sh = data.merge(head_ranks, on=["layer", "head"])

In [None]:
# preview the merged frame: rich display of the first rows instead of a plain-text dump
data_sh.head()

In [None]:
# geom_violin can't handle the raw rows, so pre-aggregate them: one row per
# distinct k value per facet, with the number of occurrences stored in the
# "variable" column (used as the violin weight in the plots below)
gdata = (
    data_sh.groupby(["layer", "head", "value"])
    .agg(variable=("variable", "count"))
    .reset_index()
)

# same aggregation, keyed by the within-layer sorted head rank
sgdata = (
    data_sh.groupby(["layer", "sorted_head", "value"])
    .agg(variable=("variable", "count"))
    .reset_index()
)

In [None]:
# place the median-k labels halfway along the k axis of the plots
ypos = 0.5 * gdata["value"].max()

## Plots

### First, in the natural order of the layers and heads

In [None]:
# One panel per layer (rows) and head (columns): the k distribution as a weighted
# violin plus jittered points, labelled with the head's median k.
natural_order_plot = (
    ggplot(gdata, aes(x=1, y="value"))
    + geom_violin(aes(weight="variable"), fill="pink")
    + geom_jitter(aes(colour="variable"), alpha=0.5, size=1)
    + geom_label(aes(x=1, y=ypos, label="avg_k"), data=sorted_avg_k)
    + scale_color_continuous(cmap_name="viridis_r")
    + facet_grid("layer ~ head", labeller="label_both")
    + theme_dark()
    + coord_flip()
    + labs(x="", y="k", title="Distribution and median k for each attention head")
    + theme(axis_text_y=element_blank(), axis_ticks_major_y=element_blank())
)
natural_order_plot

### Then with the heads of each layer sorted by median k, as in the Hopfield Networks paper

In [None]:
# Same plot, but the columns are each layer's heads ranked by median k
# (sorted_head 1 = smallest median k), as in the Hopfield Networks paper.
# The value axis shows k, so it is labelled "k" like the first plot.
(ggplot(sgdata, aes(1, "value"))  +
     geom_violin(mapping=aes(weight="variable"), fill="pink") +
     geom_jitter(mapping=aes(colour="variable"), alpha=0.5, size=1) +
     geom_label(data=sorted_avg_k, mapping=aes(x=1, y=ypos, label="avg_k")) +
     scale_color_continuous(cmap_name="viridis_r") +
     facet_grid("layer ~ sorted_head", labeller="label_both") +
     theme_dark() +
     coord_flip() +
     labs(
             x = "",
             y = "k",
             title = "Distribution and median k for each attention head, heads sorted by median k"
         ) +
     theme(
             axis_text_y = element_blank(),
             axis_ticks_major_y = element_blank()
         )
)