In [None]:
# @title foo
#!pip install transformers==4.1.1 plotnine

In [None]:
import csv
import itertools
import re

import matplotlib
import numpy as np
import pandas as pd
import seaborn
import torch
from IPython.display import HTML
from transformers import AutoModel, AutoTokenizer

from ahviz import create_indices, create_dataframe, filter_mask

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# uncomment this if you run into memory issues on the gpu
#device = torch.device("cpu")

In [None]:
# Pick the checkpoint to analyse; the commented-out names are alternatives.
#transformer = "bert-base-cased"
#transformer = "gpt2"
#transformer = "gpt2-medium"
transformer = "gpt2-large"
#transformer = "twmkn9/bert-base-uncased-squad2"
tokenizer = AutoTokenizer.from_pretrained(transformer)

# GPT-2 tokenizers have no padding token, but the token streams built below are
# padded with tokenizer.pad_token_id, so reuse EOS (the fix suggested by the
# transformers error message). Testing for a missing pad token, instead of
# matching a fixed list of names, also covers other checkpoints without one
# (e.g. "gpt2-xl", "distilgpt2") and leaves BERT's own [PAD] untouched.
# NOTE(review): only input_ids are passed to the model later (no
# attention_mask), so any padding in the final window IS attended to and
# ends up in the k statistics.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Ask the model to return per-layer attention weights (and hidden states).
model = AutoModel.from_pretrained(transformer, output_attentions=True, output_hidden_states=True)
model.to(device)
model.eval()
model.zero_grad()


In [None]:
# One row per line of text. QUOTE_NONE treats quotation marks as ordinary text:
# with the default quoting, a line opening with a quote (common in prose) gets its
# quotes stripped, or swallows the following lines until a closing quote appears.
# NOTE(review): a line containing a tab would still be split into extra fields.
dataset_a = pd.read_csv("firsthalf.txt", sep="\t", header=None, names=["line"], quoting=csv.QUOTE_NONE)

In [None]:
# One row per line of text. QUOTE_NONE treats quotation marks as ordinary text:
# with the default quoting, a line opening with a quote (common in prose) gets its
# quotes stripped, or swallows the following lines until a closing quote appears.
# NOTE(review): a line containing a tab would still be split into extra fields.
dataset_b = pd.read_csv("secondhalf.txt", sep="\t", header=None, names=["line"], quoting=csv.QUOTE_NONE)

### To test things out, use only the first 100 lines of each dataset so everything runs faster:

In [None]:
# Keep only the first 100 lines of each half for a quick trial run.
dataset_a, dataset_b = dataset_a.head(100), dataset_b.head(100)

In [None]:
# an example of a synthetic dataset
#dataset_b = pd.DataFrame(["one two three four five six seven eight nine ten"] * 100, columns=["line"])

In [None]:
# Sliding-window setup used by the cells below:
#   window_size - number of tokens fed to the model per forward pass
#   step        - stride between consecutive windows; also the number of
#                 positions per window whose k is kept for analysis
#   future      - positions at the very end of each window that are excluded
window_size, step, future = 50, 25, 0

In [None]:
# Turn each dataset into one long, padded 1-D stream of token ids.
input_tensors = []
for half in [dataset_a, dataset_b]:
    tokenized_sents = tokenizer(half['line'].tolist(), add_special_tokens=False)['input_ids']
    # Tokenizers that define a separator (BERT's [SEP]) get one after every line;
    # GPT-2 defines none, so its lines are simply concatenated.
    if tokenizer.sep_token_id is not None:
        tokenized_sents = [sent + [tokenizer.sep_token_id] for sent in tokenized_sents]
    chained = list(itertools.chain.from_iterable(tokenized_sents))
    # Explicit dtype: torch.tensor([]) would be float32 and break the embedding lookup.
    tokens = torch.tensor(chained, dtype=torch.long)
    # Pad only as much as unfold(0, window_size, step) needs to cover every token:
    # the length must be at least window_size and (length - window_size) a multiple
    # of step. Extra padding would produce windows whose scored positions are pad
    # tokens, and those would be counted in the k statistics.
    overhang = len(tokens) - window_size
    pad_len = -overhang if overhang < 0 else -overhang % step
    padded = torch.cat((tokens, tokens.new_full((pad_len,), tokenizer.pad_token_id)))
    input_tensors.append(padded)

In [None]:
# Per-position view of one window: 1 marks the `step` positions whose k is
# analysed, 0 marks the leading context and the `future` tail. Printed as a
# visual check.
mask = torch.cat((
    torch.zeros(window_size - (step + future)),  # context-only positions
    torch.ones(step),                            # analysed positions
    torch.zeros(future),                         # excluded tail
))

print(mask)

In [None]:
def get_batches(input_tensor: torch.Tensor, size: int, step: int, batch_size: int = 2):
    """Cut a 1-D token stream into overlapping windows and batch them.

    Windows of length ``size`` start every ``step`` positions (``Tensor.unfold``
    along dim 0); tokens after the last complete window are not included.

    Returns a DataLoader that yields, in order, one-element lists holding a
    ``(<= batch_size, size)`` tensor of consecutive windows.
    """
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(input_tensor.unfold(0, size, step)),
        batch_size=batch_size,
    )

In [None]:
%%time

# Fraction of a row's attention mass that its top-k keys must cover.
MASS_THRESHOLD = 0.9


def _k_for_mass(att, threshold=MASS_THRESHOLD):
    """Count, for every attention row, the keys needed to cover `threshold` of its mass.

    Each row (last axis) is sorted in descending order and accumulated; k is the
    number of entries whose running sum is still below `threshold`, plus one for
    the entry that crosses it. Returns an integer array with the last axis removed.
    """
    descending = -np.sort(-att, axis=-1)
    cumulative = descending.cumsum(axis=-1)
    return (cumulative < threshold).sum(axis=-1) + 1


frames = []
for n, dataset in enumerate(input_tensors):
    dl = get_batches(dataset, window_size, step, batch_size=3)

    batch_ks = []
    # Pure inference: skip autograd bookkeeping to save (GPU) memory.
    with torch.no_grad():
        for (input_ids,) in dl:
            output = model(input_ids=input_ids.to(device))
            # One entry per layer -> axes (layer, sample, head, from_token, to_token)
            att = np.stack([a.cpu().numpy() for a in output['attentions']])
            k = _k_for_mass(att)
            # Reorder to (layer, head, sample, from_token) to match the index names below.
            batch_ks.append(np.swapaxes(k, 1, 2))
    # Join all batches along the sample axis once, instead of growing an array per batch.
    data = np.concatenate(batch_ks, axis=2)
    ix = create_indices(data, names=['layer', 'head', 'sample', 'from_token'])
    df = create_dataframe(data, ix)
    df['dataset'] = n
    frames.append(df)

result = pd.concat(frames)

In [None]:
# Keep only the token positions each window is responsible for, i.e.
# window_size - (step + future) < from_token <= window_size - future, and rename
# the value column produced by create_dataframe ('attention_fraction') to "value".
filtered = (
    result
    .loc[lambda frame: (frame['from_token'] > window_size - (step + future))
                       & (frame['from_token'] <= window_size - future)]
    .rename(columns={'attention_fraction': "value"})
)

In [None]:
# Median k per (dataset, layer, head) over all analysed token positions.
# NOTE: despite the column name `avg_k`, the statistic is the median.
# The string "median" is used instead of np.median: pandas >= 2.1 warns that
# numpy callables passed to groupby aggregation will stop being mapped to the
# built-in groupby median.
avg_k = (
    filtered
    .groupby(['dataset', 'layer', 'head'])
    .agg(avg_k=pd.NamedAgg('value', 'median'))
    .reset_index()
)

In [None]:
# One row per (layer, head), one column per dataset id; `diff` is dataset 0 minus dataset 1.
pivoted = (
    avg_k
    .pivot(index=['layer', 'head'], columns='dataset', values="avg_k")
    .reset_index()
    .assign(diff=lambda frame: frame[0] - frame[1])
)

In [None]:
# Median k per (layer, head) for each dataset, with dataset A minus dataset B in the middle.
fig, axes = matplotlib.pyplot.subplots(1, 3, figsize=(16, 6), sharey=True)
fig.suptitle('median k per head for the two datasets, and the difference per head')

# Keyword arguments: DataFrame.pivot's parameters are keyword-only since pandas 2.0,
# where the positional form raises a TypeError.
k_table_a = avg_k[avg_k['dataset'] == 0].pivot(index='layer', columns='head', values="avg_k")
k_table_b = avg_k[avg_k['dataset'] == 1].pivot(index='layer', columns='head', values="avg_k")
diff_table = pivoted.pivot(index='layer', columns='head', values='diff')

# Shared colour limits so the two dataset panels can be compared directly.
k_limits = dict(vmin=avg_k['avg_k'].min(), vmax=avg_k['avg_k'].max())
green = seaborn.light_palette("seagreen", as_cmap=True)

seaborn.heatmap(ax=axes[0], data=k_table_a, cmap=green, **k_limits)
axes[0].set_title("dataset A")

# Diverging palette centred on 0, so the sign of the difference maps to the hue.
seaborn.heatmap(ax=axes[1], data=diff_table, cmap=seaborn.color_palette("coolwarm", as_cmap=True), center=0)
axes[1].set_title("difference")

seaborn.heatmap(ax=axes[2], data=k_table_b, cmap=green, **k_limits)
axes[2].set_title("dataset B")

matplotlib.pyplot.show()

In [None]:
# Table view of the median k: one row per (dataset, layer), one column per head.
# Keyword arguments, because DataFrame.pivot's parameters are keyword-only since pandas 2.0.
avg_k.pivot(index=['dataset', 'layer'], columns='head', values="avg_k")

### To replicate the plot in the Hopfield networks paper more closely, add a `sorted_head` column so the attention heads in each layer can be plotted sorted from small to large k

In [None]:
# Grid dimensions: number of datasets (their ids come from enumerate, i.e. start at
# 0, hence the +1) and the largest layer / head label.
# NOTE(review): using l and h as counts assumes layer/head labels are 1-based and
# contiguous — TODO confirm against ahviz.create_indices.
d = avg_k['dataset'].max() + 1
l = avg_k['layer'].max()
h = avg_k['head'].max()
print(d, l, h)

In [None]:
# Rank the heads inside each (dataset, layer) from small to large median k;
# `sorted_head` is that 1-based rank.
sorted_avg_k = avg_k.sort_values(["dataset", "layer", "avg_k"])
# cumcount numbers the rows of each group in their current (sorted) order, so this
# stays correct even if a (dataset, layer) group does not have exactly h heads.
sorted_avg_k['sorted_head'] = sorted_avg_k.groupby(['dataset', 'layer']).cumcount() + 1


In [None]:
# Per-dataset heatmaps with each layer's heads re-ordered from small to large median k.
fig, axes = matplotlib.pyplot.subplots(1, 2, figsize=(20, 12), sharey=True)
fig.suptitle('median k per head for two datasets, with the heads sorted per layer')

# Shared colour limits so the two panels can be compared directly.
k_limits = dict(vmin=sorted_avg_k['avg_k'].min(), vmax=sorted_avg_k['avg_k'].max())
green = seaborn.light_palette("seagreen", as_cmap=True)

for ax, dataset_id, title in zip(axes, (0, 1), ("dataset A", "dataset B")):
    # Keyword arguments: DataFrame.pivot's parameters are keyword-only since pandas 2.0.
    k_table = sorted_avg_k[sorted_avg_k['dataset'] == dataset_id].pivot(
        index='layer', columns='sorted_head', values="avg_k"
    )
    seaborn.heatmap(ax=ax, data=k_table, cmap=green, **k_limits)
    ax.set_title(title)

matplotlib.pyplot.show()


In [None]:
# Attach each head's within-layer rank (`sorted_head`) to the per-token rows.
data_sh = pd.merge(
    filtered,
    sorted_avg_k[['dataset', 'layer', 'head', 'sorted_head']],
    on=["dataset", "layer", "head"],
)

## More Plots

## First, in the natural order of the layers/heads

In [None]:
%%time
def make_violin(y, *, label_y=10, **kwargs):
    """Draw a split violin of column `y` per layer, split by dataset, and print each side's median.

    Meant for ``FacetGrid.map_dataframe``, which passes the facet's rows as
    ``data=`` (plus styling kwargs); all of ``kwargs`` is forwarded to
    ``seaborn.violinplot``.

    Parameters
    ----------
    y : str
        Column to plot (here "value", the per-token k).
    label_y : float, default 10
        Vertical data coordinate at which the median labels are drawn.

    Returns
    -------
    The axes returned by ``seaborn.violinplot``.
    """
    ax = seaborn.violinplot(y=y, x="layer", hue="dataset", split=True, **kwargs)
    facet = kwargs['data']
    for dataset_id in range(2):
        median_k = np.median(facet[facet['dataset'] == dataset_id][y])
        # Offset the two labels horizontally so they don't overlap.
        ax.text(-0.4 + (dataset_id * 0.5), label_y, str(median_k), fontdict=dict(color="red", fontsize=30))
    return ax


g = seaborn.FacetGrid(data_sh, col="head",  row="layer", col_order=(np.arange(h) + 1), row_order=np.flip(np.arange(l) + 1))
g.map_dataframe(make_violin, "value")

### Then with the heads in each layer sorted by median k, as in the Hopfield networks paper

In [None]:
# One facet per (layer, head rank): columns follow each layer's heads ordered by
# median k (sorted_head), rows run from the last layer down to the first.
g = seaborn.FacetGrid(
    data_sh,
    col="sorted_head",
    row="layer",
    col_order=np.arange(1, h + 1),
    row_order=np.arange(l, 0, -1),
)
g.map_dataframe(make_violin, "value")