# Attention tokens

Attention heads appear to learn distinct sets of "attention tokens", which act as a compact vocabulary used to describe the input within that head. I define attention tokens as follows: take the element-wise (Hadamard) product of each query with each key in the input up to and including the query's position, project the result onto the residual stream using the head's OV matrix, and unembed it to decode a specific token. That decoded token is the attention token.
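A minimal NumPy sketch of this definition for a single head, under stated assumptions. The query-key product lives in the head's `d_head` space, so the sketch assumes the projection into the residual stream uses the head's output matrix `W_O` (the output half of the OV circuit). It also assumes the decoded token is the argmax over the unembedding logits. The function name and signature are hypothetical and are not the `utils` helpers used in this notebook.

```python
import numpy as np


def attention_tokens(q, k, W_O, W_U):
    """Hypothetical sketch: decode one attention token per (query, key) pair.

    q, k : (seq, d_head) query and key vectors for a single head
    W_O  : (d_head, d_model) head output matrix (assumed projection step)
    W_U  : (d_model, d_vocab) unembedding matrix
    Returns a (seq, seq) int array: entry [i, j] is the token id decoded
    for query i and key j when j <= i, and -1 for masked (future) keys.
    """
    seq = q.shape[0]
    tokens = np.full((seq, seq), -1, dtype=int)
    for i in range(seq):
        # Element-wise (Hadamard) product of query i with keys 0..i
        prods = q[i] * k[: i + 1]      # (i + 1, d_head)
        resid = prods @ W_O            # project into the residual stream
        logits = resid @ W_U           # unembed to vocabulary logits
        tokens[i, : i + 1] = logits.argmax(axis=-1)
    return tokens


# Toy check with identity projections: each key is one-hot, so the decoded
# token is the index of that key's hot dimension.
d = 4
q = np.ones((3, d))
k = np.eye(d)[[2, 0, 3]]
print(attention_tokens(q, k, np.eye(d), np.eye(d)).tolist())
# → [[2, -1, -1], [2, 0, -1], [2, 0, 3]]
```

With real GPT-2 weights, `q` and `k` would come from the activation cache for one head, and `W_O`/`W_U` from the model's parameters.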

Using this approach, different heads learn different attention token sets, but some tokens are shared by all heads (e.g. " the", "<NEWLINE>", " and"). Further, heads consistently convert inputs that are structurally or semantically similar, but worded differently, into similar patterns of attention tokens. The patterns that heads use to represent inputs vary: some massively contract the input sequence into chunks, while others appear linear or expand the input. Intuitively, this tracks with some known behaviours. For example, head 10.7, known to suppress copying, expands the input, which might be necessary for its task given its focus on tokens rather than structure.
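To make "shared tokens" and "contracting vs expanding" concrete, here is a small sketch of how both could be measured once each head's attention tokens have been decoded. Both helpers are hypothetical. The ratio of distinct attention tokens to input tokens is an assumed proxy for contraction, not a metric defined in this notebook, and the example token lists are made up for illustration.

```python
def shared_tokens(tokens_by_head):
    """Attention tokens used by every head (intersection of per-head sets)."""
    sets = [set(toks) for toks in tokens_by_head.values()]
    return set.intersection(*sets) if sets else set()


def expansion_ratio(n_distinct_attention_tokens, n_input_tokens):
    """Hypothetical proxy: < 1 suggests contraction into chunks,
    ~1 roughly linear, > 1 expansion of the input."""
    return n_distinct_attention_tokens / n_input_tokens


# Made-up token lists, purely to show the mechanics
heads = {
    "head_a": [" the", "\n", " and", " film"],
    "head_b": [" the", "\n", " and", " movie", " plot"],
    "head_c": [" the", "\n", " and", " actor"],
}
print(sorted(shared_tokens(heads)))   # → ['\n', ' and', ' the']
print(expansion_ratio(4, 16))         # → 0.25
```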

At the very least, this feels like a compelling way to categorise attention heads based on how they tend to restructure their inputs. But I hope it may also point towards a more rigorous way of understanding attention mechanisms; see the final section for some highly speculative ideas on how concepts from group theory and physics could help further this goal.

In [1]:
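import random

import torch
from transformer_lens import HookedTransformer, SVDInterpreter

# Local helpers used below: run_prompts, plot_attn, plot_attns, plot_grid, calculate_attns
from utils import *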

torch.cuda.empty_cache()
torch.set_grad_enabled(False)

# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


The plot below lets you explore how attention tokens are generated. Note that the colour scale is based on the order in which attention tokens are discovered during generation, so brighter colours represent tokens generated later in the process. Colours therefore don't represent absolute token values; they show relative position within the generation.
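The discovery-order colouring can be sketched as follows, assuming the grid is traversed row by row (query position by query position) and each token's colour value is its first-appearance rank scaled to [0, 1]. The traversal order and the function name are assumptions, not the notebook's `plot_attn` implementation.

```python
def discovery_ranks(grid):
    """Map each attention token to the order in which it first appears.

    grid: list of rows (one per query position); row i lists the attention
    tokens for keys 0..i. Row-major traversal is assumed to define
    discovery order. Returns {token: colour value in [0, 1]}, where later
    discoveries get larger (brighter) values.
    """
    order = {}
    for row in grid:
        for tok in row:
            if tok not in order:
                order[tok] = len(order)
    n = max(len(order) - 1, 1)  # avoid division by zero for a single token
    return {tok: rank / n for tok, rank in order.items()}


grid = [["the"], ["the", "cat"], ["the", "cat", "sat"]]
print(discovery_ranks(grid))  # → {'the': 0.0, 'cat': 0.5, 'sat': 1.0}
```

Because the value depends only on first appearance, the same token can receive different colours in different plots, which is why colours are relative rather than absolute.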

In [2]:
cache = run_prompts(model, "The next sentence is false. The previous sentence is true.")

l, h = 0, 10
plot_grid(
    plot_attn(cache, l, h), 
    title=f"Attention tokens: {l}.{h}",
    footer=f'<i>{cache.prompts[0]}</i>'
)

[Interactive plot: Attention tokens: 0.10]

In [15]:
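# Activations for layer l. Note: q keeps all heads ([batch, pos, head, d_head]),
# while k and v are sliced to head h ([batch, pos, d_head]).
q = cache['q', l]
k = cache['k', l][:, :, h, :]
v = cache['v', l][:, :, h, :]

q.shape, k.shape, v.shape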

(torch.Size([1, 46, 12, 64]), torch.Size([1, 46, 64]), torch.Size([1, 46, 64]))
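Computing attention tokens for one head needs every query paired with every earlier key. The following NumPy sketch shows one way to vectorise this with broadcasting; it is an assumption about how this could be done, not the notebook's `calculate_attns`. For a single head it produces a `(seq, seq, d_head)` array, masked so each query only sees keys at or before its own position. In the cell above, `q` still carries the head dimension (shape `[1, 46, 12, 64]`), so it would need slicing with `[0, :, h, :]` first, and `k` with `[0]`.

```python
import numpy as np


def pairwise_qk_products(q, k):
    """Causally masked Hadamard products for one head.

    q, k: (seq, d_head). Returns (seq, seq, d_head) where [i, j] = q[i] * k[j]
    for j <= i, and zeros for future keys (j > i).
    """
    prods = q[:, None, :] * k[None, :, :]              # broadcast to (seq, seq, d_head)
    causal = np.tril(np.ones((q.shape[0], k.shape[0]), dtype=bool))
    return prods * causal[:, :, None]                  # zero out j > i


q = np.arange(6, dtype=float).reshape(3, 2)
k = np.ones((3, 2))
print(pairwise_qk_products(q, k).shape)  # → (3, 3, 2)
```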

To get a sense of how attention tokens are generated, I want to test over a reasonably large and varied set of inputs. For the time being, I'm using the IMDB dataset because it contains a wide variety of natural language from multiple authors and is relatively small.

In [3]:
from datasets import load_dataset

dataset = load_dataset("imdb")
len(dataset["train"])

25000

In [4]:
inputs = random.choices(dataset["train"]["text"], k=4)

plots = []
l1, h1 = 0, 10
for text in inputs:
    prompt = ' '.join(text.split(' ')[:16])
    cache = run_prompts(model, prompt)
    plots.append(plot_attn(cache, l1, h1, hide_labels=True))

l2, h2 = 0, 11
for text in inputs:
    prompt = ' '.join(text.split(' ')[:16])
    cache = run_prompts(model, prompt)
    plots.append(plot_attn(cache, l2, h2, hide_labels=True))

plot_grid(
    *plots,
    title=f"Heads {l1}.{h1} and {l2}.{h2} for 4 random IMDB reviews<br>",
    description=f"""
    While specific patterns are hard to pin down, common motifs across the reviews are clearly visible.
    """,
    footer='<br>'.join([f'"<i>{text[:99]}...</i>"' for text in inputs]),
)

[Interactive plot: Heads 0.10 and 0.11 for 4 random IMDB reviews]

In [10]:
# Labels for the 6 fields in the last dimension of calculate_attns' output
a = [
    ('l', 'h', 'at', 'q', 'k', 'v')
]

data = calculate_attns(cache, l, h)
data.shape

torch.Size([1, 46, 46, 6])

In [5]:
inputs = random.choices(dataset["train"]["text"], k=4)

data = []
l, h = 0, 0
for text in inputs:
    prompt = ' '.join(text.split(' ')[:16])
    cache = run_prompts(model, prompt)
    # calculate_attns returns a single tensor (e.g. shape [1, pos, pos, 6]),
    # not a dict, so store it under its own key instead of unpacking it with **
    attn = calculate_attns(cache, l, h)
    data.append({
        'prompt': prompt,
        'attn': attn,
    })

### Attention token embedding dimensionality reduction

### Graph number of attention tokens by input length per head

### Graph number of attention tokens by head index per layer (and reverse)

### Analysis of attention token frequencies

## Random inputs

Testing with random token inputs provides some insight into how the heads behave in a "default" state. It also makes it easy to change the input length without being distracted by how the change affects the semantic or grammatical structure of the input.
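The three kinds of random input used in this section (one token repeated, fully random tokens, and a short random block tiled) can be generated as plain token-id lists. This is a minimal sketch with hypothetical function names. One detail worth noting: Python's `random.randint(a, b)` includes `b`, whereas `torch.randint(low, high, ...)` excludes `high`, so sampling GPT-2 ids with `random.randint(0, 50257)` can occasionally return the out-of-range id 50257. Using `randrange` avoids this.

```python
import random

VOCAB_SIZE = 50257  # GPT-2 vocabulary size; valid token ids are 0..50256


def repeated_token(length, rng):
    """A single random token repeated `length` times."""
    tok = rng.randrange(VOCAB_SIZE)   # randrange excludes the upper bound
    return [tok] * length


def random_tokens(length, rng):
    """`length` independent, uniformly random tokens."""
    return [rng.randrange(VOCAB_SIZE) for _ in range(length)]


def repeated_block(block_len, n_repeats, rng):
    """A random block of `block_len` tokens tiled `n_repeats` times."""
    return random_tokens(block_len, rng) * n_repeats


rng = random.Random(0)
print(len(repeated_token(62, rng)), len(random_tokens(62, rng)), len(repeated_block(3, 21, rng)))
# → 62 62 63
```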

Here are the heads of the first layer using a single random token repeated 62 times.

In [None]:
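# randrange excludes the upper bound; random.randint(0, 50257) could return
# 50257, which is outside GPT-2's vocabulary (ids 0..50256)
token = random.randrange(model.cfg.d_vocab)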
prompt = torch.full((1, 62), token)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

[Interactive plot: Attention tokens]

In [None]:
prompt = torch.randint(0, 50257, (1, 62))
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

[Interactive plot: Attention tokens]

In [None]:
prompt = torch.randint(0, 50257, (1, 3)).repeat(1, 21)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:128]}"')

[Interactive plot: Attention tokens]

Finally, let's try plotting some more complex repeating sequences using the same head.

Compare the same head across random inputs

### A note on input length

Input length is important. Up to 31 tokens, the attention tokens form surprisingly compact representations of the input at each position. At 32 input tokens, some kind of criticality is reached, which dramatically increases the number of attention tokens assigned to each input token. You can see this below, where the final two lines account for the majority of assigned tokens. Importantly, this doesn't change the patterns shown above; it just makes them much harder to see clearly with this visualization technique.

After 32 tokens, things get a bit weird. The attention token structure doesn't stay in the more complex, high-token-count state; instead, it appears to alternate between the simple and complex patterns. However, this depends somewhat on the selected random token, so it's difficult to pin down a precise sequence.
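One way to locate the jump described above is to count the distinct attention tokens at each query position and compare totals across input lengths. Both functions below are hypothetical helpers, the 2x jump factor is an arbitrary assumption, and the example counts are illustrative rather than measurements.

```python
def tokens_per_position(grid):
    """Number of distinct attention tokens assigned at each query position.

    grid: list of rows, where row i holds the attention tokens for keys 0..i.
    """
    return [len(set(row)) for row in grid]


def first_jump(counts_by_length, factor=2.0):
    """Hypothetical detector: the first input length whose distinct-token count
    exceeds `factor` times the count at the previous length (None if none does)."""
    lengths = sorted(counts_by_length)
    for prev, cur in zip(lengths, lengths[1:]):
        if counts_by_length[cur] > factor * counts_by_length[prev]:
            return cur
    return None


print(tokens_per_position([["a"], ["a", "b"], ["a", "a", "a"]]))  # → [1, 2, 1]
print(first_jump({30: 5, 31: 6, 32: 20, 33: 7}))                  # → 32
```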

Below is the same head (0.0) plotted for inputs from 31 up to 48 tokens. You can see the head cycle between the simple and complex patterns twice in this range.

In [8]:
token = random.randrange(model.cfg.d_vocab)  # randint(0, 50257) could return the out-of-range id 50257
plots = []
for i in range(31, 47, 2):
    prompt = torch.full((1, i), token)
    cache = run_prompts(model, *model.to_string(prompt))
    plots += plot_attns(cache, range(1), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Random inputs between 31 and 48 tokens (0.0)", description=f'Input: "{cache.prompts[0][:99]}"')

[Interactive plot: Random inputs between 31 and 48 tokens (0.0)]

The same behaviour occurs with natural language input, but the threshold is less predictable. The point of criticality appears to align with the start of the most recent "block" of text. For example, with a repeating sequence of length 4, the threshold is ~28. This isn't precise, but the block start is reliably a better estimate than the fixed 32-token threshold seen with repeated and fully random inputs.
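The length-4 example can be read as the start of the repeating block that contains the last position of a 32-token input. The helper below encodes that reading as a rough, hypothetical heuristic. It assumes 0-based positions and ignores any BOS token. Because the observed threshold isn't precise, treat its output as an estimate only.

```python
def predicted_threshold_position(block_len, critical_length=32):
    """Hypothetical heuristic: 0-based start position of the repeating block
    that contains the last position of a `critical_length`-token input.
    Assumes positions are counted without a BOS token."""
    last_pos = critical_length - 1
    return (last_pos // block_len) * block_len


print(predicted_threshold_position(4))  # → 28
print(predicted_threshold_position(1))  # → 31, i.e. the 32nd token
```

For `block_len = 1` (a single repeated token or fully random input) this reduces to position 31, the 32nd token, which matches the 32-token threshold above.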

Interactive random input plot

## Specific head analysis

Random repeating sequences provide the clearest interpretability without being devoid of structure, so I'm starting there. Understanding how heads evolve as predictable sequences change is likely to provide more general insights, but it's worth noting again the risk that these toy inputs won't translate to more realistic language.

To give the work some structure, I plan to evaluate the patterns using this approach:
- Take an input sequence (S) and convert it into a sequence of attention tokens (A)
- Given the next token (t) generate a new sequence of attention tokens using S + t
- Analyse the invariances and symmetries implied by how different sequences produce different attention tokens
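The first two steps can be sketched as a comparison between the attention-token grids for S and S + t. GPT-2 uses causal attention, so the rows for positions in S should be identical in both grids. This assumes S is passed as the same token ids and ignores floating-point noise near argmax ties. Under that assumption, the new information lives entirely in the final row. The helper is hypothetical and operates on grids stored as lists of rows.

```python
def compare_extension(grid_s, grid_st):
    """Compare attention-token grids for a sequence S and its extension S + t.

    Both grids are lists of rows (row i = attention tokens for keys 0..i).
    With causal attention, the rows for S should be unchanged in S + t, so the
    only new information is the final row. Returns (prefix_unchanged, new_row).
    """
    n = len(grid_s)
    if len(grid_st) != n + 1:
        raise ValueError("grid_st should have exactly one more row than grid_s")
    prefix_unchanged = grid_st[:n] == grid_s
    return prefix_unchanged, grid_st[n]


a = [["x"], ["x", "y"]]
b = [["x"], ["x", "y"], ["x", "y", "y"]]
print(compare_extension(a, b))  # → (True, ['x', 'y', 'y'])
```

If `prefix_unchanged` is ever `False` for real model outputs, that would point to a tokenisation change (e.g. from round-tripping through strings) rather than a change in the head's behaviour.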

### 0.0