# Attention tokens

While exploring some other behaviour, I noticed that different heads consistently produce tokens from a compact subset of the full vocabulary when you project the components of a head onto the residual stream without going through the MLP layers. I think this is likely to be heavily influenced by positional encodings.
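The exact projection lives in the notebook's `utils` module, which isn't shown here. As a rough sketch of one way to read a head component out as tokens, the snippet below assumes (my assumption, not necessarily what `utils` does) that a `d_head` vector is mapped back into the residual stream through the transpose of the head's weight matrix and then read out through the unembedding, taking the argmax as the attention token. Biases and layer norm are ignored, and the weights are random stand-ins using GPT-2 small's `d_model` and `d_head` with a toy vocabulary.

```python
import numpy as np

def attention_tokens(component, W_head, W_U):
    """Map per-position head vectors to vocabulary ids (hypothetical helper).

    component: (seq, d_head) query, key or value vectors for one head
    W_head:    (d_model, d_head) the head's W_Q, W_K or W_V
    W_U:       (d_model, d_vocab) unembedding matrix
    """
    # Assumption: W_head.T sends d_head vectors back into the residual stream
    resid = component @ W_head.T      # (seq, d_model)
    logits = resid @ W_U              # (seq, d_vocab)
    return logits.argmax(axis=-1)     # one token id per position

rng = np.random.default_rng(0)
d_model, d_head, d_vocab, seq = 768, 64, 1000, 16
W_Q = rng.normal(size=(d_model, d_head))
W_U = rng.normal(size=(d_model, d_vocab))
x = rng.normal(size=(seq, d_model))   # stand-in residual stream
q = x @ W_Q                           # (seq, d_head) query vectors

tokens = attention_tokens(q, W_Q, W_U)
print(tokens.shape, len(set(tokens.tolist())))  # shape and number of distinct tokens
```

With real weights, the number of distinct ids per head is the quantity the rest of this notebook tracks.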

Each head appears to use its learned vocabulary subset (the attention tokens) to chunk inputs in a way specific to that head. Some heads compress inputs into a small number of attention tokens, as few as 3-4 for a 16-token input. Others expand, and still others track the input length linearly or double it. My intuition is that this reflects how heads structurally decompose inputs to represent meaning for their specific tasks, and so provides a way to classify heads.
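One way to turn the compress/track/double distinction into a label is to fit how a head's count of distinct attention tokens grows with input length. The sketch below is hypothetical: `classify_head` and its slope thresholds (0.5 and 1.5) are illustrative choices, and the counts are assumed to have been computed per input length beforehand.

```python
import numpy as np

def classify_head(lengths, unique_counts):
    """Label a head by how its distinct attention-token count scales with input length.

    Hypothetical thresholds on the least-squares slope of count vs. length.
    """
    slope = float(np.polyfit(lengths, unique_counts, 1)[0])
    if slope < 0.5:
        return "compress"   # count stays roughly flat as inputs grow
    if slope < 1.5:
        return "track"      # roughly one attention token per input token
    return "double"         # roughly two or more attention tokens per input token

lengths = [4, 8, 16, 32]
print(classify_head(lengths, [3, 3, 4, 4]))      # compress
print(classify_head(lengths, [8, 16, 32, 64]))   # double
```

A slope is a crude summary; a head whose count saturates at a fixed vocabulary size would need a non-linear fit.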

Macro patterns across heads also emerge, particularly when considering the element-wise product of q and k. The plots below show a t-SNE projection of the embeddings of the element-wise tokens, coloured by layer (left), and the element-wise tokens shared between heads (right).
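The element-wise product is a natural object to look at because summing it recovers the attention score: `q · k = sum_i q_i * k_i`. The vector `q * k` keeps the per-dimension contributions that the score collapses into a single number. A quick numpy check with random vectors of GPT-2 small's head size (`d_head = 64`):

```python
import numpy as np

rng = np.random.default_rng(1)
d_head = 64
q = rng.normal(size=d_head)   # query vector for one position
k = rng.normal(size=d_head)   # key vector for another position

elementwise = q * k           # per-dimension contributions (q ⊙ k)
# Standard scaled dot-product attention score, before the softmax
score = elementwise.sum() / np.sqrt(d_head)
print(np.isclose(score, q @ k / np.sqrt(d_head)))  # True
```

The `1/sqrt(d_head)` scaling is the usual scaled dot-product convention; it rescales every dimension equally, so it doesn't change which dimensions dominate the product.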

<div style="display: flex; justify-content: center;">
    <img src="./imdb-tsne.png" alt="Image 1" style="width: 45%; margin-right: 10px;">
    <img src="./shared-token_imdb.png" alt="Image 2" style="width: 45%;">
</div>

In [1]:
import random

import pandas as pd
import torch
from transformer_lens import HookedTransformer
import plotly.io as pio
import matplotlib.pyplot as plt

from utils import *

torch.cuda.empty_cache()
torch.set_grad_enabled(False)

pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [12]:
pio.renderers.default = "png"
plt.ioff()


The plot below lets you explore how attention tokens are generated. Note that the colour scale is based on the order in which attention tokens are discovered during generation, so brighter colours represent tokens discovered later in the process. Colours therefore don't identify specific tokens; they only indicate a token's relative position in the order of discovery.
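If "order of discovery" means the rank of each token's first appearance, the colour index can be sketched as below. This is a hypothetical reading of the plot description, not the `plot_attn` implementation; for a 2D grid of element-wise tokens I'd assume the cells are flattened row by row before ranking.

```python
def discovery_order(attention_tokens):
    """Map each token to the rank of its first appearance (hypothetical helper)."""
    order = {}
    for tok in attention_tokens:
        if tok not in order:
            order[tok] = len(order)   # next unseen token gets the next rank
    return [order[tok] for tok in attention_tokens]

tokens = [262, 11, 262, 290, 11, 198]
print(discovery_order(tokens))  # [0, 1, 0, 2, 1, 3]
```

Higher ranks map to brighter colours, which is why the same colour in two heads doesn't mean the same token.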

In [3]:
prompt = "The next sentence is false. The previous sentence is true."
cache = run_prompts(model, prompt)

In [13]:
l, h = 0, 1
attn_data = calculate_attns(cache, l, h)
plot = plot_attn(model, attn_data)

figure(
    plot, 
    title=f"Attention tokens for {l}.{h}", 
    description="""
    Each cell represents the element-wise product of the Q and K components of the attention head. Colors increase in brightness based on the order in which the attention tokens were discovered in the raw input.
    Axes clockwise from top: query, key, value, and input tokens.
    Attention scores for the final position are overlaid as white borders.
    """,
    footer=f'Input: {prompt}',
)


In [14]:
plot, axs = plt.subplots(3, 4, figsize=(12, 8))
plt.subplots_adjust(wspace=0, hspace=0.5)
for i in range(12):
    attn_data = calculate_attns(cache, l, i)
    plot_attn(cache, attn_data, ax=axs[i // 4, i % 4], hide_labels=True, title=f"Head {i}")

figure(
    plot, 
    title=f"Attention token plots for each head in layer {l}",
    description="""
    Note that the colors are relative and don't represent the same attention token across heads.
    Instead the patterns show how each head uses the tokens within its vocabulary set.
    """,
    footer=f'Input: {prompt}',
)


To get a sense of how attention tokens are generated, I want to test over a reasonably large and varied set of inputs. For now I'm using the IMDB dataset because it contains a wide variety of natural language from multiple authors and is relatively small.

The code below extracts 3 random prompts from the dataset, caps each at 16 tokens, and uses them as input to the model. The cached activations are then used to generate a dataset of attention tokens for each component of each head.

In [None]:
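import random

from datasets import load_dataset

dataset = load_dataset("imdb")
# Sample 3 reviews (with replacement) from the training split
inputs = random.choices(dataset["train"]["text"], k=3)
# Tokenise (TransformerLens prepends BOS by default), keep the first 16 tokens, convert back to strings
inputs = model.to_string(model.to_tokens(inputs)[:, :16])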
cache = run_prompts(model, *inputs)

data = generate(cache)
df = to_df(data)
df.to_csv("imdb_example.csv", index=False)
df

Generating the attention token dataset from activations can take a while, depending on how many inputs you have and how long they are. To save time, I've pre-generated a dataset from 32 IMDB prompts, each 32 tokens long.

In [38]:
df = load('../data/32x32_attn.csv')
df

| Unnamed: 0 | layer | head | Input token | attn | hp | q | k | v |
|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.0 | 50256.0 | 1.000000 | 357.0 | 11.0 | 11.0 | 262.0 |
| 1 | -1.0 | -1.0 | -1.0 | -1.000000 | -1.0 | -1.0 | -1.0 | -1.0 |
| 2 | -1.0 | -1.0 | -1.0 | -1.000000 | -1.0 | -1.0 | -1.0 | -1.0 |
| 3 | -1.0 | -1.0 | -1.0 | -1.000000 | -1.0 | -1.0 | -1.0 | -1.0 |
| 4 | -1.0 | -1.0 | -1.0 | -1.000000 | -1.0 | -1.0 | -1.0 | -1.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2509051 | 11.0 | 11.0 | 379.0 | 0.002832 | 47992.0 | 10473.0 | 47992.0 | 2887.0 |
| 2509052 | 11.0 | 11.0 | 379.0 | 0.002169 | 47992.0 | 10473.0 | 47992.0 | 2887.0 |
| 2509053 | 11.0 | 11.0 | 379.0 | 0.004394 | 47992.0 | 10473.0 | 47992.0 | 2887.0 |
| 2509054 | 11.0 | 11.0 | 379.0 | 0.016218 | 23712.0 | 10473.0 | 47992.0 | 2887.0 |


While the patterns that emerge in the token plots are interesting, and I think they can be useful for interpreting head behaviour (see below), they don't help us understand any macro patterns that may exist across heads.

Calculating the frequency of each token shows a dramatic compaction: down to ~1,200 distinct tokens for the element-wise product, ~10,000 for the Q and K components, and ~25,000 for V, out of a vocabulary of 50,257. These numbers seem relatively invariant to the specific input tokens and to input length; see the random token experiments below for more details.
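The frequency tables below come from `token_freq_data` in `utils`. As a hypothetical pandas equivalent, assuming the column layout of the dataframe shown above (`layer`, `head`, and a token column such as `hp`, with `-1` marking padding rows), per-head counts and ranks could be computed like this:

```python
import pandas as pd

def token_frequencies(df, column="hp"):
    """Count how often each attention token appears for each (layer, head)."""
    valid = df[df[column] >= 0]                      # drop -1 padding rows
    counts = (
        valid.groupby(["layer", "head", column])
        .size()
        .rename("Frequency")
        .reset_index()
    )
    # Rank tokens within each head, most frequent first
    counts["Rank"] = counts.groupby(["layer", "head"])["Frequency"].rank(
        ascending=False, method="dense"
    )
    return counts

toy = pd.DataFrame({
    "layer": [0, 0, 0, 0, -1],
    "head": [0, 0, 0, 1, -1],
    "hp": [11, 11, 262, 11, -1],
})
freq = token_frequencies(toy)
print(freq["hp"].nunique())  # 2 distinct attention tokens
```

Dividing `freq[column].nunique()` by 50,257 gives the compaction ratio for that component.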

In [39]:
token_counts = token_freq_data(model, df, 4, (144, 32, 32, 32))
q_token_counts = token_freq_data(model, df, 5, (144, 32, 32, 32))
k_token_counts = token_freq_data(model, df, 6, (144, 32, 32, 32))
v_token_counts = token_freq_data(model, df, 7, (144, 32, 32, 32))
token_counts

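| Unnamed: 0 | Head | Token | Frequency | Token str | Rank |
|---|---|---|---|---|---|
| 0 | 0 | 11 | 1325 | `,` | 4.0 |
| 1 | 0 | 12 | 461 | `-` | 6.0 |
| 2 | 0 | 198 | 82 | `\n` | 12.0 |
| 3 | 0 | 287 | 94 | `in` | 11.0 |
| 4 | 0 | 290 | 328 | `and` | 7.0 |
| ... | ... | ... | ... | ... | ... |
| 1207 | 76 | 9228 | 223 | `burgh` | 5.0 |
| 1208 | 76 | 22150 | 292 | `ultz` | 4.0 |
| 1209 | 76 | 23712 | 3923 | `eday` | 2.0 |
| 1210 | 76 | 30044 | 407 | `Jude` | 3.0 |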


In [43]:
fig, axs = plt.subplots(4, 3, figsize=(16, 16))
plot_token_frequencies(model, q_token_counts, ax=axs[0][0])
plot_unique_tokens_by_head(model, q_token_counts, ax=axs[0][1])
plot_unique_tokens_by_layer_head(model, q_token_counts, ax=axs[0][2])
plot_token_frequencies(model, k_token_counts, ax=axs[1][0])
plot_unique_tokens_by_head(model, k_token_counts, ax=axs[1][1])
plot_unique_tokens_by_layer_head(model, k_token_counts, ax=axs[1][2])
plot_token_frequencies(model, v_token_counts, ax=axs[2][0])
plot_unique_tokens_by_head(model, v_token_counts, ax=axs[2][1])
plot_unique_tokens_by_layer_head(model, v_token_counts, ax=axs[2][2])
plot_token_frequencies(model, token_counts, ax=axs[3][0])
plot_unique_tokens_by_head(model, token_counts, ax=axs[3][1])
plot_unique_tokens_by_layer_head(model, token_counts, ax=axs[3][2])
figure(
    fig,
    title="Token frequencies",
    description="""
    From left to right: token frequencies, attention token count by head, and attention token count by layer.
    From top to bottom: query, key, value, and element-wise attention tokens.
    Colors represent layer index.
    """,
    footer="IMDB dataset",
)


I've tried various dimensionality reduction techniques, but t-SNE offers by far the most striking visualizations. The plots are unusual for t-SNE and suggest a highly structured dataset, even though the data comes from randomly selected snippets of natural language.

In [40]:
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
plot_token_embeddings(model, token_counts, 'TSNE', ax=axs[0][0])
plot_token_embeddings(model, q_token_counts, 'TSNE', ax=axs[0][1])
plot_token_embeddings(model, k_token_counts, 'TSNE', ax=axs[1][0])
plot_token_embeddings(model, v_token_counts, 'TSNE', ax=axs[1][1])

figure(
    fig,
    title="t-SNE for attention token embeddings",
    description="""
    t-SNE representation of the embeddings of the element-wise, query, key, and value attention tokens (top left, top right, bottom left, bottom right).
    Colours represent the layer index.
    """,
    footer="IMDB dataset",
)


Looking at which attention tokens are shared between heads shows clear patterns: every pair of heads shares at least one token, and no pair shares more than 4. This structure could point to how heads collaborate on different tasks.
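Shared tokens reduce to set intersections over every pair of heads. A minimal sketch with `itertools.combinations`; the head labels and token sets are made up for illustration, although the ids are real GPT-2 tokens (for example, 11 is `,` and 262 is ` the`).

```python
from itertools import combinations

def shared_token_pairs(head_tokens):
    """Return {(head_a, head_b): shared token set} for every pair of heads."""
    return {
        (a, b): head_tokens[a] & head_tokens[b]
        for a, b in combinations(sorted(head_tokens), 2)
    }

# Hypothetical attention-token vocabularies for three heads
head_tokens = {
    "0.1": {11, 262, 290, 198},
    "3.4": {11, 262, 257},
    "7.2": {11, 15961},
}
pairs = shared_token_pairs(head_tokens)
sizes = [len(s) for s in pairs.values()]
print(min(sizes), max(sizes))  # 1 2
```

With 144 heads this is 10,296 pairs, small enough that the brute-force approach is fine.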

In [42]:
hp = shared_tokens(df)
hp.sort_values('Shared Tokens', ascending=False)

| Unnamed: 0 | Subgroup 1 | Subgroup 2 | Shared Tokens | Tokens |
|---|---|---|---|---|
| 7928 |  | ² | 4 | 290.0,11.0,198.0,262.0 |
| 469 | D |  | 4 | 257.0,290.0,11.0,262.0 |
| 2555 | V |  | 4 | 649.0,290.0,11.0,262.0 |
| 7312 |  | ² | 4 | 290.0,11.0,198.0,262.0 |
| 3027 | Z |  | 4 | 15961.0,257.0,11.0,262.0 |
| ... | ... | ... | ... | ... |
| 7364 |  | © | 1 | 11.0 |
| 7365 |  | ª | 1 | 11.0 |
| 7366 |  | « | 1 | 11.0 |
| 7368 |  | ­ | 1 | 11.0 |


In [24]:
q = shared_tokens(df, 'q')
k = shared_tokens(df, 'k')
v = shared_tokens(df, 'v')

fig, axs = plt.subplots(2, 2, figsize=(12, 10))
visualize_shared_tokens(hp, ax=axs[0][0])
visualize_shared_tokens(q, ax=axs[0][1])
visualize_shared_tokens(k, ax=axs[1][0])
visualize_shared_tokens(v, ax=axs[1][1])

figure(
    fig,
    title="Shared tokens",
    description="""
    Heatmap of how tokens are shared between pairs of heads for element-wise, query, key, and value tokens (top left, top right, bottom left, bottom right).
    """,
    footer="IMDB dataset",
)



### Random inputs

Testing with random token inputs gives some insight into how the heads behave in a "default" state. It also makes it easy to change the input length without being distracted by how that affects the semantic and grammatical structure of the input.
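The three kinds of input used in the cells below (a single repeated token, uniform random tokens, and a short random pattern repeated) can be sketched as plain id lists with the standard library. The helper names are hypothetical; `D_VOCAB = 50257` is GPT-2's vocabulary size, and `randrange` is used because its upper bound is exclusive.

```python
import random

D_VOCAB = 50257  # GPT-2 vocabulary size; valid ids are 0..50256

def single_repeated(length, rng):
    token = rng.randrange(D_VOCAB)    # upper bound excluded, so always a valid id
    return [token] * length

def uniform_random(length, rng):
    return [rng.randrange(D_VOCAB) for _ in range(length)]

def repeating_pattern(period, repeats, rng):
    pattern = [rng.randrange(D_VOCAB) for _ in range(period)]
    return pattern * repeats

rng = random.Random(0)
print(len(single_repeated(31, rng)), len(uniform_random(31, rng)), len(repeating_pattern(3, 10, rng)))
# 31 31 30
```

Seeding the generator makes a particular set of motifs reproducible across runs.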

Here are the heads of the first layer using a single random token repeated 31 times.

In [19]:
token = random.randrange(model.cfg.d_vocab)  # randint(0, 50257) is inclusive and could return an out-of-vocab id
prompt = torch.full((1, 31), token)
cache = run_prompts(model, *model.to_string(prompt))

plot, axs = plt.subplots(3, 4, figsize=(16, 12))
plt.subplots_adjust(wspace=0, hspace=0.5)
for i in range(12):
    attn_data = calculate_attns(cache, 0, i)
    plot_attn(cache, attn_data, ax=axs[i // 4, i % 4], hide_labels=True, title=f"Head {i}")

figure(plot, title="Random single repeating token")


In [20]:
prompt = torch.randint(0, 50257, (1, 31))
cache = run_prompts(model, *model.to_string(prompt))

plot, axs = plt.subplots(3, 4, figsize=(16, 12))
plt.subplots_adjust(wspace=0, hspace=0.5)
for i in range(12):
    attn_data = calculate_attns(cache, 0, i)
    plot_attn(cache, attn_data, ax=axs[i // 4, i % 4], hide_labels=True, title=f"Head {i}")

figure(plot, title="Random tokens")


In [21]:
prompt = torch.randint(0, 50257, (1, 3)).repeat(1, 10)
cache = run_prompts(model, *model.to_string(prompt))

plot, axs = plt.subplots(3, 4, figsize=(16, 12))
plt.subplots_adjust(wspace=0, hspace=0.5)
for i in range(12):
    attn_data = calculate_attns(cache, 0, i)
    plot_attn(cache, attn_data, ax=axs[i // 4, i % 4], hide_labels=True, title=f"Head {i}")

figure(plot, title="Random repeating 3-seq token")


While the specific patterns change, running this multiple times reveals clear motifs that identify specific heads. Again, to get a better idea of how the attention tokens are distributed across heads, we can look at the t-SNE plots and shared tokens. The t-SNE plots show similar patterns for the element-wise tokens, but much simpler, more structured outputs for the q, k and v tokens. This makes sense given the constrained input/output of the random sequences.

In [48]:
df = load('../data/32x32-random_attn.csv')
token_counts = token_freq_data(model, df, 4, (144, 32, 32, 32))
q_token_counts = token_freq_data(model, df, 5, (144, 32, 32, 32))
k_token_counts = token_freq_data(model, df, 6, (144, 32, 32, 32))
v_token_counts = token_freq_data(model, df, 7, (144, 32, 32, 32))
v_token_counts

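| Unnamed: 0 | Head | Token | Frequency | Token str | Rank |
|---|---|---|---|---|---|
| 0 | 0 | 11 | 62 | `,` | 2.0 |
| 1 | 0 | 262 | 31 | `the` | 3.0 |
| 2 | 0 | 287 | 16743 | `in` | 1.0 |
| 3 | 1 | 11 | 31 | `,` | 3.0 |
| 4 | 1 | 262 | 15701 | `the` | 1.0 |
| ... | ... | ... | ... | ... | ... |
| 1252 | 152 | 41744 | 10 | `antid` | 5.0 |
| 1253 | 152 | 43617 | 15456 | `Whit` | 1.0 |
| 1254 | 152 | 46800 | 855 | `TMZ` | 2.0 |
| 1255 | 153 | 16100 | 4 | `retty` | 2.0 |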


In [47]:
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
plot_token_embeddings(model, token_counts, 'TSNE', ax=axs[0][0])
plot_token_embeddings(model, q_token_counts, 'TSNE', ax=axs[0][1])
plot_token_embeddings(model, k_token_counts, 'TSNE', ax=axs[1][0])
plot_token_embeddings(model, v_token_counts, 'TSNE', ax=axs[1][1])

figure(
    fig,
    title="t-SNE for attention token embeddings",
    description="""
    t-SNE representation of the embeddings of the element-wise, query, key, and value attention tokens (top left, top right, bottom left, bottom right).
    Colours represent the layer index.
    """,
    footer="Random sequence dataset",
)


The shared token plots are similar to the natural language examples, but there's less structure, and some regions have a slightly hyperbolic shape.

In [49]:
hp = shared_tokens(df)
q = shared_tokens(df, 'q')
k = shared_tokens(df, 'k')
v = shared_tokens(df, 'v')

fig, axs = plt.subplots(2, 2, figsize=(12, 10))
visualize_shared_tokens(hp, ax=axs[0][0])
visualize_shared_tokens(q, ax=axs[0][1])
visualize_shared_tokens(k, ax=axs[1][0])
visualize_shared_tokens(v, ax=axs[1][1])

figure(
    fig,
    title="Shared tokens",
    description="""
    Heatmap of how tokens are shared between pairs of heads for element-wise, query, key, and value tokens (top left, top right, bottom left, bottom right).
    """,
    footer="Random sequence dataset",
)



## Specific head analysis

Random repeating sequences are the easiest to interpret without being devoid of structure, so I'm starting there. Understanding how heads change as predictable sequences change is likely to provide more general insights, but it's worth noting again the risk that these toy inputs won't translate to more realistic language.

To give the work some structure, I plan to evaluate the patterns using this approach:
- Take an input sequence (S) and convert it into a sequence of attention tokens (A)
- Given the next token (t) generate a new sequence of attention tokens using S + t
- Analyse the invariances and symmetries implied by how different sequences produce different attention tokens
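The S / S + t comparison can be written as a small harness. `to_attention_tokens` is a hypothetical stand-in for whatever turns a token sequence into per-position attention tokens; the toy functions below exist only to exercise the harness. One thing worth checking: GPT-2 uses causal attention, so the components at positions before t can't depend on t. If earlier positions change, that points to the pipeline (for example, text being re-tokenised differently after a round trip through strings) rather than to the model.

```python
def compare_extension(to_attention_tokens, sequence, next_token):
    """Compare attention tokens for S and S + t (to_attention_tokens is hypothetical)."""
    before = to_attention_tokens(sequence)
    after = to_attention_tokens(sequence + [next_token])
    # Positions present in both runs whose attention token changed
    changed = [i for i, (a, b) in enumerate(zip(before, after)) if a != b]
    new_tokens = set(after) - set(before)
    return {"before": before, "after": after, "changed": changed, "new_tokens": new_tokens}

def causal_toy(seq):
    # Each position depends only on itself and earlier positions
    return [sum(seq[: i + 1]) % 7 for i in range(len(seq))]

def length_dependent_toy(seq):
    # Deliberately non-causal: every position depends on the total length
    return [(tok + len(seq)) % 5 for tok in seq]

print(compare_extension(causal_toy, [1, 2, 3], 5)["changed"])            # []
print(compare_extension(length_dependent_toy, [1, 2, 3], 5)["changed"])  # [0, 1, 2]
```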

### 0.0