# Attention tokens

Exploring a behaviour in gpt2-small where attention heads appear to learn distinct sets of "attention tokens", which act as a compact vocabulary that each head uses to describe its input. An attention token is computed by taking the element-wise (Hadamard) product of a query with each key in the sequence up to and including the query position, projecting the result onto the residual stream using the head's OV matrix, and unembedding it to decode a specific token.
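
As a rough illustration of the definition above, here is a minimal NumPy sketch on toy random weights. It is not the notebook's `utils` implementation. The function name `attention_tokens` and all shapes are hypothetical, and it assumes that "projecting with the OV matrix" means applying the head's output projection `W_O`, since the query-key product lives in the head's `d_head`-dimensional space. It also ignores the final layer norm.

```python
import numpy as np


def attention_tokens(q, k, W_O, W_U):
    """Decode one token id per (query, key) pair for a single head.

    q, k : [seq, d_head] query and key vectors for one head
    W_O  : [d_head, d_model] output projection (assumed stand-in for "OV")
    W_U  : [d_model, d_vocab] unembedding matrix
    Returns a [seq, seq] integer array; entries above the diagonal
    (keys after the query position) are set to -1.
    """
    seq = q.shape[0]
    # Hadamard product of every query with every key: [seq_q, seq_k, d_head]
    prod = q[:, None, :] * k[None, :, :]
    # Map head space -> residual stream -> vocabulary logits
    logits = prod @ W_O @ W_U
    ids = logits.argmax(axis=-1)
    # Keep only keys at or before the query position
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    return np.where(causal, ids, -1)


# Toy example with random weights (not gpt2-small's real parameters)
rng = np.random.default_rng(0)
seq, d_head, d_model, d_vocab = 4, 8, 16, 50
q = rng.normal(size=(seq, d_head))
k = rng.normal(size=(seq, d_head))
W_O = rng.normal(size=(d_head, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
print(attention_tokens(q, k, W_O, W_U))
```

With a real model, `q` and `k` would come from the cached activations of one head and the ids would be mapped back to strings with the tokenizer.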

Under this approach, different heads learn different vocabulary sets, but some common tokens are shared between all heads (e.g. " the", "<NEWLINE>", " and"). Further, heads consistently convert inputs that are structurally or semantically similar, but worded differently, to similar patterns of attention tokens. The patterns that heads use to represent inputs vary: some massively contract the input sequence into chunks, while others appear linear or expand the input. Intuitively, this tracks with some known behaviours. For example, head 10.7, known to suppress copying, expands the input, which might be necessary for its task given its focus on tokens rather than structure.
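
One possible way to put a number on "contract", "linear" and "expand" is sketched below. The metric and the name `expansion_ratio` are assumptions made for illustration, not something the notebook defines. It counts the distinct attention tokens in the causal (lower-triangular) part of a head's query-by-key grid of decoded token ids and divides by the sequence length.

```python
def expansion_ratio(ids):
    """Distinct attention tokens per input position for one head.

    ids : square 2-D grid (nested lists or array) of decoded token ids,
          indexed [query, key]; only keys at or before the query are read.
    Returns distinct_tokens / sequence_length. Under this hypothetical
    metric, values well below 1 suggest contraction, values near 1 a
    roughly linear mapping, and values above 1 expansion.
    """
    seen = set()
    for i, row in enumerate(ids):
        for j in range(i + 1):  # causal part only
            seen.add(int(row[j]))
    return len(seen) / len(ids)


# A head that maps every query-key pair to the same token contracts heavily
contracting = [[5, -1, -1], [5, 5, -1], [5, 5, 5]]
# A head that emits a new token for every pair expands the input
expanding = [[1, -1], [2, 3]]
print(expansion_ratio(contracting), expansion_ratio(expanding))
```

A single ratio hides where the chunks fall in the sequence, so it is best read alongside the plots and not as a replacement for them.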

<TODO: add image for 4.5 vs 10.7>

There may be a way to model attention tokens as a non-abelian group, where the elements are the possible continuous transformations between the discrete attention tokens. The operation would be appending a new attention token to the end of a sequence IN, which acts on the set IN to create a new set of tokens OUT.

In [1]:
import torch
from transformer_lens import HookedTransformer, SVDInterpreter

from utils import *

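# Free cached GPU memory and disable autograd (inference only)
torch.cuda.empty_cache()
torch.set_grad_enabled(False)

# NBVAL_IGNORE_OUTPUT
# Load gpt2-small with TransformerLens weight processing enabled:
# fold LayerNorm weights into adjacent layers, center the weights that
# write to the residual stream and the unembedding, and refactor the
# factored attention matrices (QK and OV).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)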

Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
cache = run_prompts(
    model, 
    "The next sentence is false. The previous sentence is true.",
)

In [12]:
data = calculate_attns(cache, 0, 1)
feature = data[:, :, :, 2]
unique_token_pattern = unique_index_pattern(feature[0])
plot = plot_attns(unique_token_pattern)
plot = add_attn_overlay(cache, plot, data)
plot = add_axis_labels(cache, plot, data)
plot = add_token_labels(cache, plot, data, 2)
plot

In [15]:
cache = run_prompts(
    model, 
    "The next bus is late. The previous bus was early.",
)

l, h = 0, 2
data = calculate_attns(cache, l, h)
bus_plot = plot_attns(unique_index_pattern(data[0, :, :, 2].cpu()))
bus_plot = add_attn_overlay(cache, bus_plot, data)
bus_plot = add_axis_labels(cache, bus_plot, data)
bus_plot = add_token_labels(cache, bus_plot, data, 2)
bus_plot

RuntimeError: The expanded size of the tensor (13) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 13, 13, -1].  Tensor sizes: [2, 13, 64]

In [14]:
compare_plots(plot, bus_plot)
