In [None]:
import matplotlib

import torch

import matplotlib.pyplot as plt
import numpy as np

from transformers import LlamaTokenizer
%matplotlib inline

In [None]:
# Global matplotlib configuration: text is rendered through LaTeX (text.usetex);
# the pgf.* keys configure the pgf backend.
plt.rcParams.update({
    "pgf.texsystem": "pdflatex",
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": "Helvetica",
    "pgf.rcfonts": False,
})

# NOTE(review): the style sheet is applied after the update above, so any key it
# also defines overrides the value set there — confirm this order is intended.
plt.style.use('seaborn-v0_8-ticks')

# Set after the style sheet so the font size is not overridden by it.
plt.rcParams.update({'font.size': 18})

# Tick density is deliberately not configured here: plt.locator_params() acts on
# the *current* axes only, so calling it before any figure exists merely creates
# (and displays) an empty figure and has no effect on later plots. Call
# `ax.locator_params(nbins=4)` on the specific axes that need it instead.

# Tokenizer files are expected in the notebook's working directory.
tokenizer = LlamaTokenizer.from_pretrained("./")

In [8]:
### Loading data. Point DATA_DIR at either the unablated or the ablated run.
### In the latter case we zero ablated heads 6:23 and 13:15.

# Single place to switch between data sets.
DATA_DIR = 'addition/test'

# NOTE: torch.load unpickles the files — only load data you trust.
# NOTE: `input` shadows the Python builtin; the name is kept because other cells use it.
input = torch.load(f'{DATA_DIR}/inputs.pt', map_location='cpu').int()
gen = torch.load(f'{DATA_DIR}/generations.pt', map_location='cpu')

# data.pt bundles, in this order: input_text, correctness, acc, counter, av_att, std_att
data = torch.load(f'{DATA_DIR}/data.pt', map_location='cpu')

input_text, correctness, acc, counter, av_att, std_att = data

In [None]:
### Heatmap of `counter`: the count of heads with (part of) a staircase type attention pattern.

# Explicit figure/axes interface instead of the implicit pyplot state machine.
heatmap_fig, heatmap_ax = plt.subplots()
heatmap_im = heatmap_ax.imshow(counter)
heatmap_fig.colorbar(heatmap_im, ax=heatmap_ax)

In [None]:
### Relevancy of heads: the fraction of the total count taken up by the two leading heads,
### and how many of the 1024 heads get turned on during inference.

# flatten() also copes with a non-contiguous `counter`, where view(-1) would raise.
# sorted_[0] holds the counts in descending order, sorted_[1] the matching flat indices.
sorted_ = torch.sort(counter.flatten(), descending=True)
sorted_counts = sorted_[0]

# Tensor reductions rather than the Python builtin sum(), which would iterate
# over the tensor one element at a time.
sorted_counts[:2].sum() / sorted_counts.sum(), (sorted_counts != 0.0).sum()

In [None]:
## Accuracy of the model on the addition task (`acc` is unpacked from data.pt in the loading cell).

acc

# Figure 6

In [None]:
### Generating ticks for plotting

example = 0
seq_len = input[example].shape[0]

# sorted_[0] holds the counts in descending order, sorted_[1] the matching flat
# indices into `counter`. flatten() also works when `counter` is non-contiguous.
sorted_ = torch.sort(counter.flatten(), descending=True)

COLON_TOKEN_ID = 29901   # decodes to ":"
EQUALS_TOKEN_ID = 353    # decodes to "="

# First ":" in each prompt; used to get the tokens of the first integer.
l0 = [min(torch.argwhere(input[i, :] == COLON_TOKEN_ID)) for i in range(input.shape[0])]
# Position of "=", the final token before the addition sum starts.
l2 = [torch.argwhere(input[i] == EQUALS_TOKEN_ID) for i in range(input.shape[0])]

ll0 = l0[example].item() + 2
ll2 = l2[example].item()

# Digits are shown as "*", every other prompt token is shown verbatim. The
# trailing 7 dots and 5 stars label the remaining positions of the attention map.
decoded_tokens = [tokenizer.decode(input[example][i]) for i in range(ll0, ll2 + 1)]
ticks = (
    ['$*$' if token.isdigit() else token for token in decoded_tokens]
    + [r'$\cdot$'] * 7   # raw string: '\c' is not a valid Python escape sequence
    + ['$*$'] * 5
)

## Plotting the attention patterns, mean and variance.

fig, ax = plt.subplots(2, 2, figsize=(8, 8))

tick_positions = range(len(ticks))
n_cols = counter.shape[1]  # assumes `counter` is a 2-D grid, as the heatmap and the a:b labels imply

for i in range(2):
    # Recover the 2-D coordinate of the i-th ranked head from the sort indices.
    # Looking it up by value (counter == sorted_[0][i]) would return the same
    # head twice whenever the two leading counts are equal.
    coord = divmod(int(sorted_[1][i]), n_cols)

    # NOTE(review): assumes av_att[i] / std_att[i] belong to the i-th ranked head — confirm upstream.
    panels = (
        (av_att[i], r'\rm Mean attention'),
        (std_att[i]**2, r'\rm Variance attention'),   # std squared = variance
    )
    for j, (image, title) in enumerate(panels):
        ax[i, j].imshow(image, cmap='Blues')
        ax[i, j].set_xticks(tick_positions, ticks, rotation='vertical')
        ax[i, j].set_yticks(tick_positions, ticks)
        ax[i, j].set_title(title)

    # Row label placed to the left of the y axis of the first column.
    row = rf'$\rm Head\;{coord[0]}:{coord[1]}$'

    ax[i, 0].annotate(row, xy=(0, 0), xytext=(-ax[i, 0].yaxis.labelpad - 5, 0),
                xycoords=ax[i, 0].yaxis.label, textcoords='offset points',
                size='large', ha='right', va='center')

# Run the layout pass last so it accounts for the real tick labels, titles and
# row annotations instead of the empty axes.
fig.tight_layout()