# Attention tokens

Attention heads appear to learn distinct sets of "attention tokens", which act as a compact vocabulary used to describe the input within that head. I define an attention token as follows: take the element-wise (Hadamard) product of a query with a key, for each query against each key in the input up to and including the query's position. This product is then projected onto the residual stream using the head's OV matrix and unembedded to decode a specific token. That decoded token is the attention token.
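The definition above can be sketched in a few lines. This is only an illustration under stated assumptions, not the notebook's `utils` implementation: it uses plain Python lists and toy random weights so it runs without loading a model, and it reads "OV matrix" as the head's output projection (`w_o`, shape `d_head x d_model`) applied to the query-key product. The names `matvec`, `attention_token_ids`, `w_o` and `w_u` are hypothetical.

```python
import random


def matvec(vec, mat):
    """Row vector (length n) times matrix (n rows, m columns) -> list of length m."""
    return [sum(v * row[j] for v, row in zip(vec, mat)) for j in range(len(mat[0]))]


def attention_token_ids(q, k, w_o, w_u):
    """Return a lower-triangular grid where grid[i][j] is the vocabulary id
    decoded for query position i against key position j (j <= i)."""
    grid = []
    for i, q_i in enumerate(q):
        row = []
        for k_j in k[: i + 1]:  # keys up to and including the query's position
            prod = [a * b for a, b in zip(q_i, k_j)]  # Hadamard product, length d_head
            resid = matvec(prod, w_o)                 # project to residual stream (assumed W_O)
            logits = matvec(resid, w_u)               # unembed to vocabulary logits
            row.append(max(range(len(logits)), key=logits.__getitem__))  # argmax id
        grid.append(row)
    return grid


def rand_matrix(rows, cols):
    """Toy Gaussian weights standing in for real model parameters."""
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


# Hypothetical toy sizes; a real gpt2-small head has d_head=64, d_model=768.
random.seed(0)
n_pos, d_head, d_model, d_vocab = 4, 2, 3, 5
q, k = rand_matrix(n_pos, d_head), rand_matrix(n_pos, d_head)
w_o, w_u = rand_matrix(d_head, d_model), rand_matrix(d_model, d_vocab)

grid = attention_token_ids(q, k, w_o, w_u)
for row in grid:
    print(row)
```

Each printed row corresponds to one query position, and its length grows by one per position because a query only sees keys at or before itself. With real weights the loops would normally be replaced by batched tensor operations.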

Using this approach, different heads learn different attention token sets, but some common tokens are shared between all heads (e.g. " the", "<NEWLINE>", " and"). Further, heads consistently convert inputs that are structurally or semantically similar, but worded differently, to similar patterns of attention tokens. The patterns that heads use to represent inputs vary: some massively contract the input sequence into chunks, while others appear linear or expand the input. Intuitively, this tracks with some known behaviours. For example, head 10.7, known to suppress copying, expands its input, which might be necessary for its task given its focus on tokens rather than structure.

At the very least, this feels like a compelling way to categorise attention heads based on how they tend to restructure given inputs. But I hope that it may also point to a direction for understanding attention mechanisms in a more rigorous way; see the final section for some highly speculative ideas on how concepts from group theory and physics could be used to further this goal.

In [1]:
import torch
from transformer_lens import HookedTransformer, SVDInterpreter

from utils import *

torch.cuda.empty_cache()
torch.set_grad_enabled(False)

# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
cache = run_prompts(
    model, 
    "The next sentence is false. The previous sentence is true.",
)

l, h = 0, 10
plot_grid(plot_attn(cache, l, h), title=f"Attention tokens: {l}.{h}",
          description="Axis clockwise: query, value, key, input<br>Each attention token is given a unique color starting from 0 and increasing in the order the tokens are discovered - i.e. brighter colors are newer tokens.")

[Figure: Attention tokens: 0.10]
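The colouring rule in the plot description (each attention token gets an index starting from 0, increasing in the order tokens are discovered) can be illustrated with a small helper. This is a sketch, not the `plot_attn` implementation: `color_indices` is a hypothetical name, and scanning the grid row by row, left to right, is an assumed discovery order.

```python
def color_indices(token_grid):
    """Map each distinct token to an integer in order of first appearance.

    token_grid is a list of rows (one per query position); the scan order
    (row by row, left to right) is an assumption about "discovery order".
    Returns the grid of colour indices and the token -> index mapping.
    """
    seen = {}
    indexed = []
    for row in token_grid:
        # len(seen) is evaluated before insertion, so a new token gets the next free index
        indexed.append([seen.setdefault(tok, len(seen)) for tok in row])
    return indexed, seen


# Hypothetical decoded attention tokens for a three-position input.
toy_tokens = [
    [" the"],
    [" the", " and"],
    [" and", "\n", " the"],
]
indexed, mapping = color_indices(toy_tokens)
print(indexed)
print(mapping)
```

Under this rule, higher indices (brighter colours) mark tokens that first appear later in the scan, which is why late bursts of new tokens stand out in the grids.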

## Random inputs

Testing with random token inputs provides some insights into how the heads behave in a "default" state. It also makes it easy to change the input length without worrying about being distracted by how that affects the semantic/grammatical structure of the input.

Here are the heads of the first layer using a single random token repeated 31 times.

In [3]:
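import random

# random.randint includes both endpoints, and gpt2-small has 50257 tokens (ids 0 to 50256)
token = random.randint(0, 50256)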
prompt = torch.full((1, 31), token)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

[Figure: Attention tokens]

The length 31 matters because beyond it the heads appear to reach some kind of critical point, which dramatically increases the number of distinct attention tokens assigned per input token. You can see this below, where the final two lines account for the majority of assigned tokens. Importantly, this doesn't change the existing patterns above; it just makes them much harder to see clearly with this visualization technique.

After 32 tokens, things get a bit weird. The attention token structure doesn't remain in the more complex high token count state - instead it appears to alternate between the simple and complex patterns. However, this somewhat depends on the selected random token so it's difficult to highlight a precise sequence.

Below is the same head (0.0) plotted using input from 31 up to 48 tokens. You can see the simple and complex pattern evolve twice in this range.

In [5]:
plots = []
for i in range(31, 49):  # input lengths 31 to 48 inclusive
    prompt = torch.full((1, i), token)
    cache = run_prompts(model, *model.to_string(prompt))
    plots += plot_attns(cache, range(1), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Random inputs between 31 and 48 tokens (0.0)", description=f'Input: "{cache.prompts[0][:99]}"')

[Figure: Random inputs between 31 and 48 tokens (0.0)]
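One way to put a number on the simple versus complex states described above is to count how many previously unseen attention tokens each query row introduces. The helper below is a sketch on a made-up grid of token ids. `new_tokens_per_row` is a hypothetical name and not part of `utils`, and the toy values are not model output.

```python
def new_tokens_per_row(token_grid):
    """For each row (query position), count attention tokens not seen in any earlier row."""
    seen = set()
    counts = []
    for row in token_grid:
        fresh = set(row) - seen  # tokens appearing for the first time in this row
        counts.append(len(fresh))
        seen |= fresh
    return counts


# Hypothetical grid: a stable pattern at first, then a burst in the final two rows.
toy_grid = [
    [7],
    [7, 7],
    [7, 7, 7],
    [7, 3, 9, 12],
    [4, 8, 15, 16, 23],
]
counts = new_tokens_per_row(toy_grid)
share_last_two = sum(counts[-2:]) / sum(counts)
print(counts, share_last_two)
```

A head in the "simple" state would show counts near zero after the first few rows, while the "complex" state would show most of the total concentrated in the last rows. Computing this per input length would give a single curve to compare against the grids.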

### Random sequences

Using a sequence of fully random tokens produces similar results, but they are less reliable. For example, patterns will _mostly_ break down around 32 tokens, but the breakdown often occurs before that point. As expected, the patterns are not really interpretable.

In [6]:
prompt = torch.randint(0, 50257, (1, 28))
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

[Figure: Attention tokens]

### Random repeating sequences

Using a repeating sequence of random tokens creates more interesting patterns. Although the specific local pattern varies with the selected tokens, global symmetries clearly emerge and identifiable motifs can be seen for each head. Again, the symmetries are stable and simple up to an input length of 32, after which the same behaviour observed above occurs, alternating between simple and complex patterns.

In [7]:
prompt = torch.randint(0, 50257, (1, 4)).repeat(1, 7)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:128]}"')

[Figure: Attention tokens]

## Real language inputs

Indirect object identification (IOI) is a common task used to test models, which follows this structure:

> John and Mary went to the store. John gave the bag to ...

The actual task is irrelevant here, but it's useful as an example input because it's more realistic while still being structured and predictable.

In [14]:
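from itertools import permutations, product


def generate_prompts(templates, names):
    # Fill every template with every ordered (subject, indirect object) pair of distinct names
    return [
        prompt.format(S, IO)
        for prompt, (S, IO) in product(templates, permutations(names, 2))
    ]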

names = (" Mary", " John", " Alice", " Bob")
prompts = generate_prompts(
    [
        "{0} and {1} went to the store.{0} gave the bag to{1}.",
        "{0} and {1} went to the zoo.{0} gave a book to{1}."
    ],
    names
)

prompts[:5]

[' Mary and  John went to the store. Mary gave the bag to John.',
 ' Mary and  Alice went to the store. Mary gave the bag to Alice.',
 ' Mary and  Bob went to the store. Mary gave the bag to Bob.',
 ' John and  Mary went to the store. John gave the bag to Mary.',
 ' John and  Alice went to the store. John gave the bag to Alice.']

In [16]:
cache = run_prompts(model, *prompts[:1])
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:128]}"')




So far, we've only tested toy inputs: random tokens and a templated IOI prompt. Using the same approach on more realistic inputs displays similar, although much harder to interpret, patterns. I'm using the IMDb dataset because it offers a reasonably varied collection of real language use and is relatively small.

In [8]:
from datasets import load_dataset

dataset = load_dataset("imdb")



## Specific head analysis

### 0.0