In [None]:
#@title foo
#!pip install transformers==4.1.1 plotnine

## Setting stuff up

In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from IPython.display import HTML
import seaborn
import matplotlib

from ahviz import create_indices, create_dataframe, filter_mask
import torch
from transformers import AutoModel, AutoTokenizer


In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# uncomment to force CPU if you have a GPU but not enough memory to do what you want. it will be slow of course
#device = torch.device("cpu")

In [None]:
# Pick the pretrained checkpoint to analyse: leave exactly one line uncommented.
#transformer = "distilbert-base-cased"
#transformer = "bert-base-cased"
#transformer = "gpt2"
#transformer = "gpt2-medium"
transformer = "gpt2-large"
#transformer = "twmkn9/bert-base-uncased-squad2"
tokenizer = AutoTokenizer.from_pretrained(transformer)
# gpt2 doesn't do padding, so invent a padding token by reusing the
# end-of-text token. The padding is only appended to the tail of the token
# stream in the data preparation below.
# NOTE(review): no attention mask is passed to the model in this notebook, so
# the padded positions are not masked out - confirm that this is acceptable.
# Match on the name so that every GPT-2 checkpoint is covered, not only the
# three sizes listed above.
if "gpt2" in transformer:
    tokenizer.pad_token = tokenizer.eos_token

# Ask the model to return the attention weights and hidden states with its output.
model = AutoModel.from_pretrained(transformer, output_attentions=True, output_hidden_states=True)
model.to(device)
# Evaluation mode: the model is only used for inference here.
model.eval()
model.zero_grad()


## data preparation

Read in the prepared data. The repository includes a copy of the Penn Treebank sample that ships with the `nltk` Python package, converted into plain text and split into sentences, but you can replace it with any
text file. Since the first thing we do is join all the text, it isn't even necessary to split it into sentences.

The script I used to create the file is `convert_corpus.py` in the repository

In [None]:
# One row per line of the file, stored in a single column named "line".
dataset_a = pd.read_csv(
    "firsthalf.txt",
    sep="\t",
    header=None,
    names=["line"],
)

In [None]:
# One row per line of the file, stored in a single column named "line".
dataset_b = pd.read_csv(
    "secondhalf.txt",
    sep="\t",
    header=None,
    names=["line"],
)

### To test things out, only use the first 100 lines of the datasets, so everything will go faster:

In [None]:
# Keep only the first 100 lines of each dataset for a quick trial run.
dataset_a, dataset_b = (half.head(100) for half in (dataset_a, dataset_b))

### window size and context

We move a sliding window over the complete dataset so we always have context around the part we are looking at. The settings below determine how many tokens the model looks at in each step, and with what step size we move through the corpus.

- `window_size`:  
    the number of tokens that are in context

- `step`:  
    how many tokens we move ahead in each step through the corpus

- `future`:  
    how many tokens the model can look ahead
    
The mask printed below shows the effect of changing these values. The ones are the tokens we calculate things for, and the zeros are the extra context that the tokens of interest can pay attention to. For models like *GPT2*, `future` should be $0$, as the model only looks back.

In [None]:
window_size: int = 25  # number of tokens the model sees per window
step: int = 12         # number of tokens to advance between consecutive windows
future: int = 0        # number of look-ahead tokens at the end of a window

In [None]:
# Turn each dataset into one long 1-D tensor of token ids, padded so that the
# sliding windows cut from it later (length `window_size`, stride `step`)
# reach the very last token.
input_tensors = []
for half in [dataset_a, dataset_b]:
    tokenized_sents = tokenizer(half['line'].tolist(), add_special_tokens=False)['input_ids']
    if "gpt" not in transformer:
        # Non-GPT tokenizers: close every sentence with the separator token.
        separated = [sent + [tokenizer.sep_token_id] for sent in tokenized_sents]
    else:
        separated = tokenized_sents
    chained = list(itertools.chain.from_iterable(separated))
    tokens = torch.tensor(chained)
    n_tokens = len(tokens)
    if n_tokens <= window_size:
        # Not even one full window yet: fill up to exactly one window.
        pad_len = window_size - n_tokens
    else:
        # Windows start every `step` tokens, so the stream must be
        # `window_size + k * step` long or the remainder is silently dropped
        # when the windows are cut. Padding to a multiple of `window_size`
        # does not guarantee that and can add a full window of padding.
        pad_len = -(n_tokens - window_size) % step
    padded = torch.cat((tokens, tokens.new_full((pad_len,), tokenizer.pad_token_id)))
    input_tensors.append(padded)

In [None]:
# Layout of one window: 0 marks context-only positions, 1 marks the positions
# we calculate things for, followed by `future` look-ahead positions (0 again).
context_part = torch.zeros(window_size - (step + future))
focus_part = torch.ones(step)
future_part = torch.zeros(future)
mask = torch.cat((context_part, focus_part, future_part))
print(mask)

In [None]:
def get_batches(input_tensor: torch.Tensor, size: int, step: int, batch_size: int = 2):
    """Serve overlapping windows of a 1-D tensor through a ``DataLoader``.

    Args:
        input_tensor: the tensor to cut into windows along dimension 0.
        size: length of every window.
        step: distance between the starts of two consecutive windows.
        batch_size: number of windows per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TensorDataset`` holding the
        windows, in their original order.
    """
    windows = input_tensor.unfold(0, size, step)
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(windows),
        batch_size=batch_size,
    )

### distance functions

This defines a few distance functions and their names. To select a different one, change the `distance` variable to one of the keys in the map. The name is used below in the difference plot title.

In [None]:
def _weighted(dist, weight):
    """Signed distance scaled by the attention weight."""
    return dist * weight


def _weighted_absolute(dist, weight):
    """Absolute distance scaled by the attention weight."""
    return np.abs(dist) * weight


def _weighted_square(dist, weight):
    """Squared distance scaled by the attention weight."""
    return np.square(dist) * weight


# Display name -> function(distance, weight).
func_map = {
    'weighted': _weighted,
    'weighted absolute': _weighted_absolute,
    'weighted square': _weighted_square,
}

# Key of the distance function used in the rest of the notebook.
distance = "weighted square"

In [None]:
%%time

result = None

for n, dataset in enumerate(input_tensors):
    dl = get_batches(dataset, window_size, step, batch_size=3)

    data = None
    for batch, t in enumerate(dl):
        input_dict = {k: v.to(device) for k, v in zip(["input_ids"], t)}

        output = model(**input_dict)

        att = np.array([a.cpu().detach().numpy() for a in output['attentions']])

        # swap the 'head' and 'sample' axes so they're in a more natural order
        att = np.swapaxes(att, 1, 2)
        ix = create_indices(att, sample=batch*att.shape[2])
        df = create_dataframe(att, ix)
        filtered = df[(df['from_token']>(window_size-(step+future))) & (df['from_token']<=(window_size-future)) ].copy()
        filtered['distance'] = (filtered['to_token'] - filtered['from_token'])
        filtered['sign'] = filtered['distance'] > 0
        filtered['weighted'] = func_map[distance](filtered['distance'], filtered['attention_fraction'])
        g = filtered.groupby(['layer', 'head', 'sample'])
        grouped = (g['weighted'].agg([np.mean, 'count'])).reset_index()

        if data is None:
            data = grouped
        else:
            data = pd.concat([data, grouped])

    data['dataset'] = n
    
    if result is None:
        result = data
    else:
        result = pd.concat([result, data])

df = result.reset_index(drop=True)

## getting the data ready to plot

### calculate the weighted distances and their mean per head

In [None]:
# Average the per-sample means for every head of every dataset, to 3 decimals.
g = df.groupby(['dataset', 'layer', 'head'])
per_head_mean = g['mean'].mean()
avg_dist = per_head_mean.reset_index().round(3)

In [None]:
# One row per (layer, head) with a column per dataset id, plus the
# difference dataset 0 minus dataset 1.
pivoted = (
    avg_dist
    .pivot(index=['layer', 'head'], columns='dataset', values="mean")
    .reset_index()
    .assign(diff=lambda wide: wide[0] - wide[1])
)

In [None]:
# Number of datasets (highest id + 1) and the highest layer / head id present.
d = avg_dist['dataset'].max() + 1
l = avg_dist['layer'].max()
h = avg_dist['head'].max()
print(d,l,h)

In [None]:
# Order the heads of every (dataset, layer) by their mean distance, smallest first.
sorted_avg_dist = avg_dist.sort_values(["dataset", "layer", "mean"])
# Rank 1..k inside each (dataset, layer) group, following the sorted order.
# Counting per group gives the same numbers as tiling 1..h when the grid is
# complete, and does not break when a group has a different number of heads
# or when the ids do not start at 1.
sorted_avg_dist['sorted_head'] = sorted_avg_dist.groupby(["dataset", "layer"]).cumcount() + 1


In [None]:
# Attach each head's rank (the sorted_head column) to the per-sample rows of
# the original data as well.
head_ranks = sorted_avg_dist[['dataset', 'layer', 'head', 'sorted_head']]
data_sh = df.merge(head_ranks, on=["dataset", "layer", "head"])

## plot

### Average distances and the difference between the two datasets

In [None]:
# Three heatmaps side by side: dataset A, the difference A - B, and dataset B.
fig, axes = matplotlib.pyplot.subplots(1, 3, figsize=(16, 6), sharey=True)
fig.suptitle(f'average {distance} distance per head for the two datasets, and the difference per head')

# pivot() is called with keyword arguments throughout: pandas 2.x no longer
# accepts them positionally.
seaborn.heatmap(
        ax=axes[0],
        data=avg_dist[avg_dist['dataset'] == 0].pivot(index='layer', columns='head', values="mean"),
        cmap=seaborn.light_palette("seagreen", as_cmap=True)
    )
axes[0].set_title("dataset A")

seaborn.heatmap(
        ax=axes[1],
        data=pivoted.pivot(index='layer', columns='head', values='diff'),
        cmap=seaborn.color_palette("coolwarm", as_cmap=True),
        # Centre the diverging colormap on zero so that the colour shows the
        # sign of the difference.
        center=0
    )
axes[1].set_title("difference")

seaborn.heatmap(
        ax=axes[2],
        data=avg_dist[avg_dist['dataset'] == 1].pivot(index='layer', columns='head', values="mean"),
        cmap=seaborn.light_palette("seagreen", as_cmap=True)
    )
axes[2].set_title("dataset B")

matplotlib.pyplot.show()

### The distances again, but with the heads sorted by the distance

In [None]:
# Two heatmaps: per layer, the heads ordered by their average distance.
fig, axes = matplotlib.pyplot.subplots(1, 2, figsize=(20, 12), sharey=True)
# f-string, so that the name of the selected distance function is filled in.
fig.suptitle(f'average {distance} per head for two datasets, with the heads sorted per layer')

# pivot() is called with keyword arguments: pandas 2.x no longer accepts them
# positionally.
seaborn.heatmap(
        ax=axes[0],
        data=sorted_avg_dist[sorted_avg_dist['dataset'] == 0].pivot(index='layer', columns='sorted_head', values="mean"),
        cmap=seaborn.light_palette("seagreen", as_cmap=True)
    )
axes[0].set_title("dataset A")

seaborn.heatmap(
        ax=axes[1],
        data=sorted_avg_dist[sorted_avg_dist['dataset'] == 1].pivot(index='layer', columns='sorted_head', values="mean"),
        cmap=seaborn.light_palette("seagreen", as_cmap=True)
    )
axes[1].set_title("dataset B")

matplotlib.pyplot.show()


### and plots showing the distribution of the distance values

First with the heads in model-order, and then again with the heads sorted by average distance.

These are **really** slow, so you may want to skip running them (each cell took half an hour on my laptop)

In [None]:
%%time
def make_violin(y, **kwargs):
    """Draw a split violin plot of column ``y`` per layer and annotate the means.

    Args:
        y: name of the column to plot on the y axis.
        **kwargs: forwarded to ``seaborn.violinplot``; must contain ``data``,
            the frame being plotted (it needs the columns ``layer``,
            ``dataset`` and ``y``).

    Returns:
        Whatever ``seaborn.violinplot`` returned.
    """
    axes = seaborn.violinplot(y=y, x="layer", hue="dataset", split=True, **kwargs)
    frame = kwargs['data']
    # Print the mean of each of the two datasets in red, one next to the other.
    for dataset_id in (0, 1):
        values = frame.loc[frame['dataset'] == dataset_id, y]
        label = str(np.round(np.mean(values), 2))
        axes.text(-0.4 + (dataset_id * 0.5), 10, label, fontdict={"color": "red", "fontsize": 30})
    return axes
# One facet per (layer, head): heads 1..h from left to right, layers from l at
# the top down to 1.
head_columns = np.arange(1, h + 1)
layer_rows = np.arange(l, 0, -1)
g = seaborn.FacetGrid(data_sh, col="head", row="layer", col_order=head_columns, row_order=layer_rows)
g.map_dataframe(make_violin, "mean")

In [None]:
# One facet per (layer, sorted_head): the columns are the heads ranked by
# their average distance instead of the heads in model order.
g = seaborn.FacetGrid(data_sh, col="sorted_head",  row="layer", col_order=(np.arange(h) + 1), row_order=np.flip(np.arange(l) + 1))
# data_sh has no "value" column; the per-sample statistic is stored in "mean".
g.map_dataframe(make_violin, "mean")