# Attention tokens

While exploring some other behaviour I noticed that different heads consistently produce tokens from a compact subset of the full vocabulary when you project the components of a head onto the residual stream without going through the MLP layers. I think this is likely heavily influenced by positional encodings.
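To make the projection idea concrete, here is a minimal sketch using plain Python lists as toy stand-ins for the weights. It assumes the head's output is mapped into the residual stream by an output matrix (`W_O`) and then straight to the vocabulary by an unembedding matrix (`W_U`), taking the argmax as the attention token. The shapes, the weights and the helper names `matvec` and `attention_token` are all made up for illustration; this is one plausible reading of the technique, not necessarily what `calculate_attns` in `utils` does.

```python
# Toy stand-in for projecting one head's output onto the vocabulary.
# All shapes and weights are invented for illustration only.

def matvec(vec, matrix):
    """Multiply a row vector (length n) by an n x m matrix given as a list of rows."""
    n_cols = len(matrix[0])
    return [sum(v * row[j] for v, row in zip(vec, matrix)) for j in range(n_cols)]

def attention_token(head_out, W_O, W_U):
    """Map a head's output into the residual stream, then directly to the vocabulary."""
    resid = matvec(head_out, W_O)   # d_head -> d_model
    logits = matvec(resid, W_U)     # d_model -> vocab, with no MLP in between
    return max(range(len(logits)), key=logits.__getitem__)  # argmax token id

W_O = [[1.0, 0.0, 0.0],
       [0.0, 1.0, 0.0]]             # hypothetical d_head=2, d_model=3
W_U = [[0.0, 2.0, 0.0, 0.0],
       [0.0, 0.0, 3.0, 0.0],
       [1.0, 0.0, 0.0, 0.0]]        # hypothetical d_model=3, vocab=4

head_outputs = ([1.0, 0.0], [0.0, 1.0], [0.5, 0.1])  # one vector per position
tokens = [attention_token(z, W_O, W_U) for z in head_outputs]
print(tokens)  # → [1, 2, 1]
```

With a real model the same two steps would presumably use the per-head slice of `model.W_O` and `model.W_U` from TransformerLens, but that mapping is an assumption here.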

Each head appears to use its learned vocabulary subset (the attention tokens) to chunk inputs in ways specific to that head. Some heads compress inputs into a small number of attention tokens - as few as 3-4 for 16-token inputs. Others expand the input, and others track the input length linearly or double it. My intuition is that this reflects how heads structurally decompose inputs to represent meaning for their specific tasks, and so provides a means of classifying heads.
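As a rough sketch of what such a classification could look like, the snippet below labels a head by the ratio of unique attention tokens to input tokens. The function name `classify_head`, the three labels and the `tolerance` threshold are assumptions chosen for illustration, not something derived from the data.

```python
def classify_head(attention_tokens, n_input_tokens, tolerance=0.25):
    """Label a head by how its unique attention-token count compares to input length.

    `tolerance` is an arbitrary illustrative threshold, not a measured value.
    """
    ratio = len(set(attention_tokens)) / n_input_tokens
    if ratio < 1 - tolerance:
        return "compresses"
    if ratio <= 1 + tolerance:
        return "tracks"
    return "expands"

# Hypothetical attention tokens for a 16-token input: only 4 distinct ids.
compact = [11, 11, 42, 42, 42, 7, 7, 7, 7, 11, 11, 42, 7, 7, 7, 93]
print(classify_head(compact, 16))          # → compresses
print(classify_head(list(range(16)), 16))  # → tracks
print(classify_head(list(range(32)), 16))  # → expands
```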

I also look at the macro patterns that emerge across all the heads in the model and find some striking results. For example, the t-SNE plot below shows how the embeddings of the attention tokens are distributed. The plot on the right shows how the attention tokens are shared between heads. Both are quite remarkable - to the degree that I worry this is either an embarrassing bug or a trivial artefact of positional encodings.

TODO add plot image

In [1]:
import random

import pandas as pd
import torch
from transformer_lens import HookedTransformer
import plotly.io as pio
import matplotlib.pyplot as plt

from utils import *

torch.cuda.empty_cache()
torch.set_grad_enabled(False)

pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
pio.renderers.default = "png"
plt.ioff()

<contextlib.ExitStack at 0x118451510>

The plot below lets you explore how attention tokens are generated. Note that the colour scale follows the order in which attention tokens are discovered during generation, so brighter colours represent tokens generated later in the process. Colours therefore don't represent absolute token values; they indicate relative position within the generation.
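The discovery-order colouring can be sketched as follows: each attention token is assigned the index at which it was first seen, and that index (rather than the token id) drives the colour. The function name `discovery_order` and the token ids are hypothetical; the actual plotting helper may do this differently.

```python
def discovery_order(attention_tokens):
    """Map each token to the rank at which it first appeared (0 = first discovered)."""
    order = {}
    for tok in attention_tokens:
        # len(order) is evaluated before insertion, so new tokens get the next rank.
        order.setdefault(tok, len(order))
    return [order[tok] for tok in attention_tokens]

# Hypothetical attention token ids for a five-position input.
ranks = discovery_order([5012, 17, 5012, 903, 17])
print(ranks)  # → [0, 1, 0, 2, 1]
```

Because the ranks restart for every plot, the same colour in two different plots does not imply the same token.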

In [4]:
cache = run_prompts(
    model, 
    "The next sentence is false. The previous sentence is false.",
    "The next sentence is false. The previous sentence is true.",
)

In [None]:
plot, axs = plt.subplots(1, 2, figsize=(16, 5))
attn_data = calculate_attns(cache, 0, 1)
plot_attn(cache, attn_data, feature_index=0, ax=axs[0])
plot_attn(cache, attn_data, feature_index=1, ax=axs[1])

figure(plot, title="Attention token plot for heads in the first layer",
       footer='The heads in the first layer for the prompt "The next sentence is false. The previous sentence is true."')

In [None]:
plot, axs = plt.subplots(3, 4, figsize=(16, 5))
plt.subplots_adjust(wspace=0, hspace=0.5)
for i in range(12):
    attn_data = calculate_attns(cache, 0, i)
    plot_attn(cache, attn_data, ax=axs[i // 4, i % 4], hide_labels=True, title=f"Head {i}")

figure(plot, title="Attention token plot for heads in the first layer",
    description='The heads in the first layer for the prompt "The next sentence is false. The previous sentence is true."')

To get a sense of how attention tokens are generated I want to test over a reasonably large and varied set of inputs. For the time being I'm using the imdb dataset because it contains a wide variety of natural language from multiple authors and is relatively small.

The code below extracts 3 random prompts from the dataset and uses them as the input into the model. We then use the cached activations to generate a dataset of attention tokens for each component across each head. The length of each prompt is capped at 16 tokens.

In [None]:
from datasets import load_dataset

dataset = load_dataset("imdb")
inputs = random.choices(dataset["train"]["text"], k=3)
inputs = model.to_string(model.to_tokens(inputs)[:, :16])
cache = run_prompts(model, *inputs)

data = generate(cache)
df = to_df(data)
df.to_csv("imdb_example.csv", index=False)
df

Generating the attention token dataset from activations can take a while depending on how many inputs you have and how long they are. To save time, I've pre-generated a dataset of 32 prompts with a token length of 32.

In [None]:
df = load('32x32_attn.csv')
df

We can plot each of these to get a sense of how the attention tokens are used differently for each head.

In [None]:
# TODO add gallery of all 32x32 imdb token plots

While it's interesting to look at the patterns that emerge in the token plots, and I think it can be useful for interpreting head behaviour (see below), it doesn't help to understand any macro patterns that may exist across the model heads.



In [None]:
token_counts = token_freq_data(model, df, 2, (144, 32, 32, 32))
token_counts

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(16, 5))
plot_token_frequencies(model, token_counts, colorbar_label='Head', ax=axs[0])
plot_unique_tokens_by_head(model, token_counts, s=25, colorbar_label='Layer', ax=axs[1])
plot_unique_tokens_by_layer_head(model, token_counts, s=50, colorbar_label='Layer', ax=axs[2])

figure(fig, title="Token frequencies", description="Token frequencies by layer and head", footer="The next sentence is false. The previous sentence is true.")

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(16, 5))
plot_token_embeddings(model, token_counts, 'PCA', ax=axs[0])
plot_token_embeddings(model, token_counts, 'UMAP', ax=axs[1])

figure(
    fig, 
    title="Token embeddings", 
    description="Token embeddings by layer and head", 
    footer="The next sentence is false. The previous sentence is true.",
)

In [None]:
plot_token_embeddings(model, token_counts, 'TSNE', s=1)

# TODO

In [None]:
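# Copy the slice so the columns added later do not trigger SettingWithCopyWarning.
# Note: 'Input token' is selected twice, so 'attention_token' currently duplicates 'input_token'.
at = df[['Input token', 'Input token', 'attn']].copy()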
at.columns = ['input_token', 'attention_token', 'attention_score']
print(at)

In [None]:
import pandas as pd
import numpy as np

def get_head_layer(index, seq_length, batch_size, num_layers, num_heads):
    total_tokens = seq_length * batch_size
    layer = (index // total_tokens) % num_layers
    head = (index // (total_tokens * num_layers)) % num_heads
    return layer, head

# Assuming at is your DataFrame containing the data
seq_length = 32
batch_size = 32
num_layers = 12
num_heads = 12

# Add layer and head columns to the DataFrame
at['layer'], at['head'] = zip(*at.index.map(lambda x: get_head_layer(x, seq_length, batch_size, num_layers, num_heads)))

# Analyze subgroups for each attention head
subgroup_data = []
for layer in range(num_layers):
    for head in range(num_heads):
        subgroup_mask = (at['layer'] == layer) & (at['head'] == head)
        unique_tokens = at.loc[subgroup_mask, 'attention_token'].unique()
        subgroup_data.append({
            'Layer': layer,
            'Head': head,
            'Unique Tokens': len(unique_tokens),
            'Tokens': ','.join(map(str, unique_tokens))
        })

subgroups_df = pd.DataFrame(subgroup_data)

# Analyze shared tokens between subgroups
shared_token_data = []
for i, (layer1, head1) in enumerate(subgroups_df[['Layer', 'Head']].itertuples(index=False)):
    for layer2, head2 in subgroups_df[['Layer', 'Head']].iloc[i+1:].itertuples(index=False):
        tokens1 = set(map(float, subgroups_df[(subgroups_df['Layer'] == layer1) & (subgroups_df['Head'] == head1)]['Tokens'].iloc[0].split(',')))
        tokens2 = set(map(float, subgroups_df[(subgroups_df['Layer'] == layer2) & (subgroups_df['Head'] == head2)]['Tokens'].iloc[0].split(',')))
        shared_tokens = tokens1.intersection(tokens2)
        if len(shared_tokens) > 0:
            shared_token_data.append({
                'Subgroup 1': f"G_({layer1}, {head1})",
                'Subgroup 2': f"G_({layer2}, {head2})",
                'Shared Tokens': len(shared_tokens),
                'Tokens': ','.join(map(str, shared_tokens))
            })

shared_tokens_df = pd.DataFrame(shared_token_data)
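The `get_head_layer` helper above encodes an assumption about how rows are ordered in the dataframe: blocks of `seq_length * batch_size` rows, with the layer index changing faster than the head index. A quick check on a few boundary indices makes that layout explicit. Whether `32x32_attn.csv` is actually laid out this way is not verified here.

```python
def get_head_layer(index, seq_length, batch_size, num_layers, num_heads):
    """Same mapping as in the cell above, repeated so this check is self-contained."""
    total_tokens = seq_length * batch_size
    layer = (index // total_tokens) % num_layers
    head = (index // (total_tokens * num_layers)) % num_heads
    return layer, head

# 32 * 32 = 1024 rows per (layer, head) block; 144 blocks = 147456 rows in total.
checks = {i: get_head_layer(i, 32, 32, 12, 12) for i in (0, 1023, 1024, 12288, 147455)}
for index, (layer, head) in checks.items():
    print(index, layer, head)
# 0       → layer 0,  head 0
# 1023    → layer 0,  head 0   (last row of the first block)
# 1024    → layer 1,  head 0   (layer advances first)
# 12288   → layer 0,  head 1   (head advances after 12 layer blocks)
# 147455  → layer 11, head 11  (final row)
```

If the file is instead ordered layer-major (all heads of layer 0, then layer 1, and so on), the two divisions would need to be swapped.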

In [None]:
shared_tokens_df_sorted = shared_tokens_df.sort_values(by='Shared Tokens', ascending=True)
shared_tokens_df_sorted


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Create a pivot table from the shared_tokens_df DataFrame
pivot_df = shared_tokens_df.pivot(index='Subgroup 1', columns='Subgroup 2', values='Shared Tokens')

# Fill NaN values with 0
pivot_df.fillna(0, inplace=True)

# Create a heatmap using seaborn
plt.figure(figsize=(12, 10))
sns.heatmap(pivot_df, annot=False, cmap='inferno', fmt='d', cbar_kws={'label': 'Number of Shared Tokens'})
plt.title('Shared Token Overlaps between Subgroups')

# Remove the axis tick labels
plt.xticks([])
plt.yticks([])

plt.tight_layout()
plt.show()

## Random inputs

Testing with random token inputs provides some insights into how the heads behave in a "default" state. It also makes it easy to change the input length without worrying about being distracted by how that affects the semantic/grammatical structure of the input.

Here are the heads of the first layer using a single random token repeated 62 times.

In [None]:
token = random.randrange(50257)  # randint(0, 50257) is inclusive and could return an out-of-vocabulary id
prompt = torch.full((1, 62), token)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), hide_labels=True)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

In [None]:
prompt = torch.randint(0, 50257, (1, 62))
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), hide_labels=True)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:99]}..."')

In [None]:
prompt = torch.randint(0, 50257, (1, 3)).repeat(1, 21)
cache = run_prompts(model, *model.to_string(prompt))
plots = plot_attns(cache, range(12), show_grid_labels=False, show_axis=False, show_attn_overlay=False)

plot_grid(*plots, title="Attention tokens", description=f'Input: "{cache.prompts[0][:128]}"')

### A note on input length

Input length is important. Up to 31 tokens the attention tokens form surprisingly compact representations of the input across each position. At 32 input tokens some kind of criticality is reached, which dramatically increases the number of attention tokens defined for each input token. You can see this below where the final two lines account for the majority of assigned tokens. It's important to note this doesn't change the existing patterns above, but just makes it much more difficult to view them clearly using this visualization technique.

After 32 tokens, things get a bit weird. The attention token structure doesn't remain in the more complex high token count state - instead it appears to alternate between the simple and complex patterns. However, this somewhat depends on the selected random token so it's difficult to highlight a precise sequence.
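One way to pin the threshold down numerically, rather than by eye, is to count unique attention tokens at each input length and flag the lengths where the count jumps sharply relative to the previous length. The sketch below assumes those counts have already been collected into a dictionary; the function name `find_jumps`, the `factor` threshold and the numbers in `toy_counts` are invented placeholders, not measurements from the model.

```python
def find_jumps(unique_counts, factor=2.0):
    """Return input lengths whose unique-token count is >= factor x the previous length's."""
    lengths = sorted(unique_counts)
    return [
        cur
        for prev, cur in zip(lengths, lengths[1:])
        if unique_counts[cur] >= factor * unique_counts[prev]
    ]

# Placeholder data only: {input length: number of unique attention tokens}.
toy_counts = {29: 4, 30: 4, 31: 5, 32: 40, 33: 6, 34: 38}
jumps = find_jumps(toy_counts)
print(jumps)  # → [32, 34]
```

An alternating simple/complex pattern would show up as repeated jumps like this instead of a single one.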

Below is the same head (0.0) plotted for input lengths from 31 up to 45 tokens in steps of 2. You can see the simple and complex pattern evolve twice in this range.

In [None]:
token = random.randrange(50257)  # randint(0, 50257) is inclusive and could return an out-of-vocabulary id
plots = []
for i in range(31, 47, 2):
    prompt = torch.full((1, i), token)
    cache = run_prompts(model, *model.to_string(prompt), prepend_bos=False)
    plots += plot_attns(cache, range(1), hide_labels=True, prepend_bos=False)

plot_grid(*plots, title="Repeated-token inputs from 31 to 45 tokens (0.0)", description=f'Input: "{cache.prompts[0][:99]}"')

The same behaviour occurs with natural language input, but the threshold is less predictable. The point of criticality appears to align with the start of the most recent "block" of text. E.g. if the input is a repeating sequence of length 4 the threshold is ~28. This isn't precise, but it is reliably a closer estimate than the 32 seen with repeated and fully random inputs.

In [None]:
plots = []
tokens = model.to_tokens(random.choice(dataset["train"]["text"]))
for i in range(31, 47, 2):
    prompt = tokens[:, :i]
    cache = run_prompts(model, *model.to_string(prompt))
    plots += plot_attns(cache, range(1), hide_labels=True)

plot_grid(*plots)

## Specific head analysis

Random repeating sequences provide the clearest interpretability without being devoid of structure, so I'm starting there. Understanding how heads evolve based on changing predictable sequences is likely to provide more general insights, but it's worth noting again the risk of these toy inputs not translating to more realistic language.

To provide some structure to the work I plan to evaluate the patterns using this approach:
- Take an input sequence (S) and convert it into a sequence of attention tokens (A)
- Given the next token (t) generate a new sequence of attention tokens using S + t
- Analyse the invariances and symmetries implied by how different sequences produce different attention tokens
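The steps above can be sketched as a small comparison routine. It takes any function that turns an input sequence into attention tokens, runs it on S and on S + t, and reports which existing positions changed and what was appended. The names `compare_extension` and `toy_head` are hypothetical, and the sketch assumes one attention token per input position, which may not hold for every head.

```python
def compare_extension(to_attention_tokens, sequence, next_token):
    """Compare attention tokens for S and S + t.

    Returns (changed, appended): positions in S whose attention token changed,
    and the attention tokens produced for the newly added position(s).
    """
    before = to_attention_tokens(sequence)
    after = to_attention_tokens(sequence + [next_token])
    changed = [i for i, (a, b) in enumerate(zip(before, after)) if a != b]
    return changed, after[len(before):]

def toy_head(sequence):
    """Toy stand-in for a head: labels each position by where its token first occurred."""
    return [sequence.index(tok) for tok in sequence]

changed, appended = compare_extension(toy_head, [7, 8, 9], 7)
print(changed, appended)  # → [] [0]
```

Here the toy head is prefix-invariant by construction, so `changed` is empty; for a real head, a non-empty `changed` list would point at positions whose attention tokens depend on later input.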

### 0.0