# Attention Pattern Visualization with TransformerLens

This notebook demonstrates how to visualize attention patterns in transformer models using TransformerLens.

## What You'll Learn
- Loading models with TransformerLens
- Extracting attention patterns
- Visualizing attention heads
- Identifying induction heads
- Understanding model behavior

## Prerequisites
```bash
uv add transformer-lens plotly circuitsvis
```

## References
- [TransformerLens Docs](https://transformerlensorg.github.io/TransformerLens/)
- [Anthropic's Circuits Thread](https://transformer-circuits.pub/)

In [None]:
"""Setup and imports."""
import torch
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformer_lens import HookedTransformer
import circuitsvis as cv

# Set device
device = (
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")

## Step 1: Load a Small Transformer Model

We'll use GPT-2 Small for this demonstration. TransformerLens provides easy access to various pretrained models.

In [None]:
"""Load model with TransformerLens."""

# Load GPT-2 Small (124M parameters)
# TransformerLens adds hooks for easy activation access
model = HookedTransformer.from_pretrained(
    "gpt2-small", center_unembed=True, center_writing_weights=True, fold_ln=True, device=device
)

print(f"Model: {model.cfg.model_name}")
print(f"Layers: {model.cfg.n_layers}")
print(f"Attention heads: {model.cfg.n_heads}")
print(f"Hidden dimension: {model.cfg.d_model}")
print(f"Vocabulary size: {model.cfg.d_vocab}")

## Step 2: Generate Text and Capture Attention

We'll run a prompt through the model and capture all attention patterns.

In [None]:
"""Run inference and capture attention patterns."""

# Test prompt with repeated pattern to trigger induction heads
prompt = "When Mary and John went to the store, Mary gave a drink to John. Then Mary and John went to the park, Mary gave a ball to"

# Tokenize
tokens = model.to_tokens(prompt)
print(f"Tokens: {tokens.shape}")
print(f"Token IDs: {tokens[0][:20]}...")  # Show first 20

# Run with caching to capture all activations
logits, cache = model.run_with_cache(tokens)

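# Get attention patterns: cache["pattern", layer] is [batch, head, query_pos, key_pos],
# so stacking over layers on dim=1 gives [batch, layer, head, query_pos, key_pos]
attention_patterns = torch.stack(
    [cache["pattern", layer] for layer in range(model.cfg.n_layers)], dim=1
)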
print(f"\nAttention patterns shape: {attention_patterns.shape}")
print(f"  Batch: {attention_patterns.shape[0]}")
print(f"  Layers: {attention_patterns.shape[1]}")
print(f"  Heads: {attention_patterns.shape[2]}")
print(f"  Sequence length: {attention_patterns.shape[3]}")
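
Each `[query_pos, key_pos]` slice of these patterns is a softmax over key positions with a causal mask, so every query row is a probability distribution over the current and earlier tokens. The NumPy sketch below shows that construction on random scores for intuition only. TransformerLens computes the real patterns internally (including scaling the scores by 1/sqrt(d_head)), and the helper name `causal_attention_pattern` is hypothetical.

```python
import numpy as np


def causal_attention_pattern(scores: np.ndarray) -> np.ndarray:
    """Turn raw [query, key] attention scores into a causal attention pattern."""
    seq_len = scores.shape[-1]
    # Mask future keys (key_pos > query_pos) with -inf so softmax gives them zero weight
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key dimension
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
pattern = causal_attention_pattern(rng.normal(size=(6, 6)))
print(pattern.sum(axis=-1))  # each row sums to 1
print(np.allclose(np.triu(pattern, k=1), 0))  # True: no attention to future tokens
```

Query 0 can only attend to itself, so its row is always `[1, 0, 0, ...]`.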

## Step 3: Visualize Attention Patterns for a Single Head

Let's examine how a specific attention head attends to different tokens.

In [None]:
"""Visualize attention for a specific layer and head."""


def visualize_attention_head(attention_pattern, tokens, layer_idx, head_idx, model):
    """
    Create an interactive heatmap of attention patterns.

    Args:
        attention_pattern: Attention weights [batch, layer, head, query_pos, key_pos]
        tokens: Token IDs [batch, seq_len]
        layer_idx: Layer index
        head_idx: Head index
        model: HookedTransformer model for token decoding
    """
    # Get attention for this head
    attn = attention_pattern[0, layer_idx, head_idx].cpu().numpy()

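    # Decode tokens to strings; prefix each with its position so repeated tokens
    # (e.g. " Mary", ",") stay distinct axis categories instead of being merged by Plotly
    str_tokens = model.to_str_tokens(tokens[0])
    labels = [f"{i}:{tok}" for i, tok in enumerate(str_tokens)]

    # Create heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=attn,
            x=labels,
            y=labels,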
            colorscale="Blues",
            hoverongaps=False,
            hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention: %{z:.3f}<extra></extra>",
        )
    )

    fig.update_layout(
        title=f"Attention Pattern - Layer {layer_idx}, Head {head_idx}",
        xaxis_title="Key Position (attending to)",
        yaxis_title="Query Position (attending from)",
        yaxis_autorange="reversed",  # put query 0 at the top, like a matrix
        width=800,
        height=800,
    )

    return fig


# Visualize Layer 5, Head 1 (often an induction head in GPT-2)
fig = visualize_attention_head(attention_patterns, tokens, layer_idx=5, head_idx=1, model=model)
fig.show()
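
When a heatmap is hard to read, listing the most-attended key for each query can help. The sketch below assumes a NumPy `[query_pos, key_pos]` pattern, such as `attention_patterns[0, 5, 1].cpu().numpy()`, and a list of string tokens. `strongest_keys` is a hypothetical helper, demonstrated here on a toy pattern.

```python
import numpy as np


def strongest_keys(pattern: np.ndarray, str_tokens: list[str]) -> list[tuple[str, str, float]]:
    """For each query position, return (query token, most-attended key token, weight)."""
    best = pattern.argmax(axis=-1)  # index of the highest-weight key for every query
    return [(str_tokens[q], str_tokens[k], float(pattern[q, k])) for q, k in enumerate(best)]


# Toy 4-token pattern where each later query mostly attends to the previous token
toy = np.array(
    [
        [1.0, 0.0, 0.0, 0.0],
        [0.9, 0.1, 0.0, 0.0],
        [0.1, 0.8, 0.1, 0.0],
        [0.0, 0.1, 0.7, 0.2],
    ]
)
for query, key, weight in strongest_keys(toy, ["The", " cat", " sat", " down"]):
    print(f"{query!r:>8} -> {key!r:<8} {weight:.2f}")
```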

## Step 4: Find Induction Heads

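Induction heads are attention heads that complete repeated patterns: when the current token has appeared earlier in the context, they attend to the token that followed that earlier occurrence.

### How Induction Heads Work
Given: "A B ... A" → Predict "B"
- The head locates the earlier occurrence of the current token `A` (in practice by composing with a previous-token head in an earlier layer)
- It attends to the token right after that occurrence (`B`) and copies it into the prediction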

We'll use a specialized test to identify them.
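
The "A B ... A → B" rule can be written as a toy next-token predictor on plain token lists. This only illustrates the behavior an induction head implements, not how the head computes it. `induction_predict` is a hypothetical name.

```python
def induction_predict(tokens):
    """Predict the next token by copying what followed the most recent earlier
    occurrence of the current token ("A B ... A" -> "B"). Returns None if there is no match."""
    current = tokens[-1]
    # Scan earlier positions from most recent to oldest, skipping the current position
    for pos in range(len(tokens) - 2, -1, -1):
        if tokens[pos] == current:
            return tokens[pos + 1]
    return None


print(induction_predict(["Mary", "gave", "a", "drink", "to", "John", ".", "Mary"]))  # gave
print(induction_predict(["A", "B", "C"]))  # None
```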

In [None]:
"""Detect induction heads using repeated sequence test."""


def test_induction_heads(model, seq_len=50):
    """
    Test for induction heads using random repeated sequences.

    On the second copy of a token A, an induction head should attend strongly
    to the token B that followed A in the first copy.

    Args:
        model: HookedTransformer model
        seq_len: Length of test sequence

    Returns:
        Induction scores per head [layer, head]
    """
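    # Create repeated random sequence: [a, b, c, ..., a, b, c, ...]
    half_len = seq_len // 2
    torch.manual_seed(0)  # reproducible random tokens
    random_seq = torch.randint(100, 1000, (half_len,), device=model.cfg.device)
    repeated_seq = torch.cat([random_seq, random_seq]).unsqueeze(0)

    # Run through model and stack patterns to [batch, layer, head, query_pos, key_pos]
    _, cache = model.run_with_cache(repeated_seq)
    attention = torch.stack([cache["pattern", layer] for layer in range(model.cfg.n_layers)], dim=1)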

    # For each position in second half, check if it attends to matching token in first half
    # Induction score: attention from pos i+half_len to pos i+1
    induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            # Get attention in second half
            attn = attention[0, layer, head, half_len:, :]

            # Row i of the slice is absolute position half_len + i (the second copy of token i);
            # the induction target is key position i + 1 (the token after its first copy)
            induction_attn = []
            for i in range(half_len - 1):
                induction_attn.append(attn[i, i + 1].item())

            induction_scores[layer, head] = np.mean(induction_attn)

    return induction_scores


# Test for induction heads
induction_scores = test_induction_heads(model)

# Find top induction heads
top_k = 5
flat_scores = induction_scores.flatten()
top_indices = torch.topk(flat_scores, top_k).indices

print("Top Induction Heads:")
print("=" * 50)
for idx in top_indices:
    layer = idx // model.cfg.n_heads
    head = idx % model.cfg.n_heads
    score = induction_scores[layer, head]
    print(f"Layer {layer:2d}, Head {head:2d}: Score = {score:.4f}")
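
The double loop above reads one pattern entry per query. Because the test sequence repeats with period `half_len`, all of those entries lie on one diagonal of each head's pattern, offset by `1 - half_len` from the main diagonal, so the score can be computed with a single slice. The NumPy sketch below shows this on a synthetic "perfect induction" pattern. The helper name and the toy pattern are assumptions for illustration.

```python
import numpy as np


def induction_score(pattern: np.ndarray, half_len: int) -> float:
    """Mean attention from each second-half query to the key just after its
    first-half match, for a [query, key] pattern over a sequence repeated twice."""
    # Entries pattern[q, q + 1 - half_len]; the first one has its query in the
    # first half (q = half_len - 1), so drop it.
    stripe = np.diagonal(pattern, offset=1 - half_len)[1:]
    return float(stripe.mean())


half_len = 4
seq_len = 2 * half_len
# Synthetic perfect induction head: second-half query q puts all weight on key q + 1 - half_len;
# first-half queries just attend to themselves.
perfect = np.eye(seq_len)
for q in range(half_len, seq_len):
    perfect[q] = 0.0
    perfect[q, q + 1 - half_len] = 1.0

print(induction_score(perfect, half_len))  # 1.0
print(induction_score(np.eye(seq_len), half_len))  # 0.0 for a pure self-attention head
```

With the torch tensors from the cell above, the same slice should work via `Tensor.diagonal(offset=1 - half_len)`. Unlike the loop, this version also includes the final second-half query.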

## Step 5: Visualize All Attention Heads

Create a grid view of all attention heads to compare patterns.

In [None]:
"""Create grid visualization of attention heads."""


def visualize_all_heads_in_layer(attention_pattern, tokens, layer_idx, model, max_seq_len=30):
    """
    Visualize all attention heads in a layer as a grid.

    Args:
        attention_pattern: Full attention patterns
        tokens: Token IDs
        layer_idx: Which layer to visualize
        model: HookedTransformer model
        max_seq_len: Limit sequence length for readability
    """
    n_heads = model.cfg.n_heads

    # Truncate sequence if too long
    seq_len = min(max_seq_len, tokens.shape[1])

    # Create subplots: 4 columns and as many rows as n_heads needs (3 rows for GPT-2 Small's 12 heads)
    cols = 4
    rows = (n_heads + cols - 1) // cols
    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[f"Head {i}" for i in range(n_heads)],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,
    )

    # Add heatmap for each head
    for head_idx in range(n_heads):
        row = head_idx // cols + 1
        col = head_idx % cols + 1

        attn = attention_pattern[0, layer_idx, head_idx, :seq_len, :seq_len].cpu().numpy()

        fig.add_trace(
            go.Heatmap(
                z=attn,
                zmin=0,
                zmax=1,  # shared 0-1 scale so the single colorbar applies to every head
                colorscale="Blues",
                showscale=(head_idx == n_heads - 1),  # Only show scale on last plot
                hoverongaps=False,
            ),
            row=row,
            col=col,
        )

    fig.update_layout(
        title=f"All Attention Heads - Layer {layer_idx}", height=800, width=1200, showlegend=False
    )

    return fig


# Visualize all heads in layer 5
fig = visualize_all_heads_in_layer(attention_patterns, tokens, layer_idx=5, model=model)
fig.show()

## Step 6: Attention Pattern Statistics

Analyze quantitative properties of attention patterns.

In [None]:
"""Compute attention pattern statistics."""


def analyze_attention_statistics(attention_patterns, model):
    """
    Compute various statistics about attention patterns.

    Returns:
        Dictionary of statistics per head
    """
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads

    stats = {
        "entropy": torch.zeros(n_layers, n_heads),
        "max_attention": torch.zeros(n_layers, n_heads),
        "mean_distance": torch.zeros(n_layers, n_heads),
    }

    for layer in range(n_layers):
        for head in range(n_heads):
            attn = attention_patterns[0, layer, head]

            # Entropy: how spread out is attention?
            # High entropy = diffuse attention, low = focused
            entropy = -(attn * torch.log(attn + 1e-10)).sum(dim=-1).mean()
            stats["entropy"][layer, head] = entropy

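            # Max attention: average over queries of each query's peak weight
            # (a global max would be 1.0 for every head, since query 0 can only attend to itself)
            stats["max_attention"][layer, head] = attn.max(dim=-1).values.mean()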

            # Mean distance: how far back does this head look?
            seq_len = attn.shape[0]
            distances = torch.arange(seq_len, device=attn.device).unsqueeze(0) - torch.arange(
                seq_len, device=attn.device
            ).unsqueeze(1)
            distances = distances.float().abs()
            mean_dist = (attn * distances).sum(dim=-1).mean()
            stats["mean_distance"][layer, head] = mean_dist

    return stats


# Compute statistics
stats = analyze_attention_statistics(attention_patterns, model)

# Visualize entropy across all heads
fig = go.Figure(
    data=go.Heatmap(
        z=stats["entropy"].numpy(),
        x=[f"Head {i}" for i in range(model.cfg.n_heads)],
        y=[f"Layer {i}" for i in range(model.cfg.n_layers)],
        colorscale="Viridis",
        colorbar_title="Entropy",
    )
)

fig.update_layout(
    title="Attention Entropy by Head (Lower = More Focused)",
    xaxis_title="Head",
    yaxis_title="Layer",
    height=600,
)
fig.show()

# Print summary
print("\nAttention Pattern Summary:")
print("=" * 50)
print(
    f"Most focused head: L{stats['entropy'].argmin() // model.cfg.n_heads}, "
    f"H{stats['entropy'].argmin() % model.cfg.n_heads} "
    f"(entropy: {stats['entropy'].min():.4f})"
)
print(
    f"Most diffuse head: L{stats['entropy'].argmax() // model.cfg.n_heads}, "
    f"H{stats['entropy'].argmax() % model.cfg.n_heads} "
    f"(entropy: {stats['entropy'].max():.4f})"
)
print(
    f"\nLongest-range head: L{stats['mean_distance'].argmax() // model.cfg.n_heads}, "
    f"H{stats['mean_distance'].argmax() % model.cfg.n_heads} "
    f"(mean distance: {stats['mean_distance'].max():.2f} tokens)"
)
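
Two reference points make the entropy and distance numbers easier to read. A head that puts all its weight on the previous token has entropy near 0 and a mean distance of about 1. A head that spreads attention uniformly over the causal prefix has entropy ln(q + 1) at query position q. The NumPy sketch below recomputes the same two statistics on these toy patterns (natural log, so entropy is in nats). The helper names are hypothetical.

```python
import numpy as np


def mean_entropy(pattern: np.ndarray) -> float:
    """Average (over queries) Shannon entropy in nats of each attention row."""
    return float(-(pattern * np.log(pattern + 1e-10)).sum(axis=-1).mean())


def mean_distance(pattern: np.ndarray) -> float:
    """Average (over queries) attention-weighted |query_pos - key_pos|."""
    n = pattern.shape[0]
    dist = np.abs(np.arange(n)[None, :] - np.arange(n)[:, None])
    return float((pattern * dist).sum(axis=-1).mean())


n = 5
# Previous-token head: query q attends to q - 1 (query 0 can only attend to itself)
prev = np.zeros((n, n))
prev[0, 0] = 1.0
prev[np.arange(1, n), np.arange(n - 1)] = 1.0
# Uniform causal head: query q spreads weight 1 / (q + 1) over keys 0..q
uniform = np.tril(np.ones((n, n))) / np.arange(1, n + 1)[:, None]

print(mean_entropy(prev), mean_distance(prev))  # ~0.0 and 0.8 (query 0 has distance 0)
print(mean_entropy(uniform))  # mean of ln(1), ..., ln(5)
```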

## Step 7: Interactive Attention Visualization with CircuitsVis

Use CircuitsVis for a more polished, interactive visualization.

In [None]:
"""Use CircuitsVis for interactive attention visualization."""

# Get string tokens
str_tokens = model.to_str_tokens(tokens[0])

# CircuitsVis expects attention in shape [n_heads, query_pos, key_pos]
# Let's visualize layer 5
layer_attn = attention_patterns[0, 5]  # [n_heads, seq_len, seq_len]

# Create interactive visualization
cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=layer_attn.cpu().numpy(),
    attention_head_names=[f"Head {i}" for i in range(model.cfg.n_heads)],
)

## Key Findings

### What to Look For:

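1. **Induction Heads** (typically in middle layers)
   - Attend to the token that followed an earlier occurrence of the current token
   - Thought to underlie much of in-context learning
   - Show an off-diagonal stripe in repeated sequences, offset from the main diagonal by the repeat length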

2. **Previous Token Heads** (early layers)
   - Attend primarily to the previous token
   - Help with local syntax and n-gram patterns
   - Show a strong stripe just below the main diagonal (key = query - 1)

3. **Positional Heads**
   - Attend based on absolute or relative position
   - Often attend to special tokens (BOS, separator)
   - May show structured geometric patterns

4. **Syntax Heads**
   - Attend to syntactically related tokens
   - Subject-verb agreement, noun-modifier relations
   - More complex, context-dependent patterns
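
The first two head types also have simple numerical signatures that can be checked before looking at plots. Previous-token heads put most of their weight on the sub-diagonal (key = query - 1). Heads that park attention on the first position put most of their weight in column 0. The NumPy sketch below computes both signatures. The score names are illustrative heuristics, not a TransformerLens API. Both scores skip the first two queries, where "previous token" and "first token" coincide or are the only options.

```python
import numpy as np


def previous_token_score(pattern: np.ndarray) -> float:
    """Mean attention from query q to key q - 1, for queries 2..n-1."""
    return float(np.diagonal(pattern, offset=-1)[1:].mean())


def first_token_score(pattern: np.ndarray) -> float:
    """Mean attention to key position 0, for queries 2..n-1."""
    return float(pattern[2:, 0].mean())


n = 6
prev = np.eye(n, k=-1)
prev[0, 0] = 1.0  # query 0 can only attend to itself
sink = np.zeros((n, n))
sink[:, 0] = 1.0  # every query attends to the first token

print(previous_token_score(prev), first_token_score(prev))  # 1.0 0.0
print(previous_token_score(sink), first_token_score(sink))  # 0.0 1.0
```

Applied to each `attention_patterns[0, layer, head]` (converted to NumPy), these scores give a quick ranking of candidate heads to inspect visually.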

### Next Steps:
- Try different prompts and analyze how attention changes
- Examine how induction heads behave with different repetition patterns
- Investigate which heads are important for specific tasks
- Move on to activation patching to test causal relationships

## Troubleshooting

### Model Download Issues
If model download fails:
```python
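# Point the Hugging Face cache at a different directory.
# Set this before transformer_lens (and hence transformers) is imported,
# e.g. at the top of a fresh kernel, since the cache path is read at import time.
import os
os.environ["HF_HOME"] = "/path/to/cache"

from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("gpt2-small")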
```

### Memory Issues
If running out of memory:
- Use shorter sequences (truncate tokens)
- Free memory between runs: delete large objects (`del cache`), then call `torch.cuda.empty_cache()` on CUDA
- Use smaller model: 'gpt2-small' instead of 'gpt2-medium'
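
Caching every activation is often what exhausts memory. The attention patterns alone take `n_layers * n_heads * seq_len**2` values per sequence. `run_with_cache` also accepts a `names_filter` argument that keeps only the hooks you need. The sketch below estimates the pattern memory and defines a filter predicate. The helper names are hypothetical, and the byte count assumes float32.

```python
def pattern_cache_bytes(
    n_layers: int, n_heads: int, seq_len: int, batch: int = 1, bytes_per_float: int = 4
) -> int:
    """Memory needed to store every attention pattern: [batch, layer, head, query, key]."""
    return batch * n_layers * n_heads * seq_len * seq_len * bytes_per_float


def is_attention_pattern(hook_name: str) -> bool:
    """names_filter predicate keeping only hooks like 'blocks.5.attn.hook_pattern'."""
    return hook_name.endswith("attn.hook_pattern")


# GPT-2 Small: 12 layers x 12 heads at its full 1024-token context
print(pattern_cache_bytes(12, 12, 1024) / 2**20, "MiB")  # 576.0 MiB
print(is_attention_pattern("blocks.5.attn.hook_pattern"))  # True
print(is_attention_pattern("blocks.5.hook_resid_post"))  # False

# In the notebook (requires the model loaded above):
# logits, cache = model.run_with_cache(tokens, names_filter=is_attention_pattern)
```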

### Visualization Not Showing
If plots don't appear:
- Run the notebook in Jupyter Notebook, JupyterLab, or another frontend that renders Plotly output (e.g. VS Code)
- Try `fig.show(renderer='browser')`
- Check the installation: `uv add plotly nbformat` (Plotly's notebook renderer requires `nbformat`)