# Discovering Induction Heads in Small Language Models

**Goal**: Replicate Anthropic's discovery of induction heads using our interpretability toolkit.

**What are induction heads?**

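Induction heads are attention heads that complete repeated patterns in the context, and Anthropic's work argues that they account for much of a transformer's in-context learning ability. They implement a simple pattern-matching algorithm: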

```
Given sequence: [A][B]...[A] → Predict [B]
```

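When the model sees a token [A] that already appeared earlier in the context, an induction head attends to the token that followed the previous occurrence of [A] and copies it forward, which raises the probability of [B]. To find that position it relies on a previous-token head in an earlier layer, which is why induction heads only appear in models with at least two layers.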
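
As a rough illustration, the rule can be written as a plain lookup over the context. This is only a sketch of the behavior an induction head approximates, not of how the model computes it; `induction_predict` is a hypothetical helper and not part of the toolkit.

```python
def induction_predict(tokens):
    """Return the token that followed the most recent earlier occurrence
    of the last token, or None if the last token has not appeared before."""
    if not tokens:
        return None
    current = tokens[-1]
    # Scan backwards over earlier positions, most recent first.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None


print(induction_predict(["A", "B", "C", "A"]))  # B
```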

**References**:
- [In-context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) (Anthropic, 2022)
- [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) (Anthropic, 2021)

## Setup

In [None]:
import sys
import os

# Add parent directory to path
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Import nanoGPT
from model import GPT, GPTConfig

# Import our interpretability toolkit
from interpretability import activation_patching, attention_analysis, logit_lens
from interpretability.utils import ActivationCache, HookManager

# Notebook settings
%load_ext autoreload
%autoreload 2

plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Load or Train a Model

For this experiment, we'll use a small character-level GPT trained on Shakespeare.
This model is small enough to analyze thoroughly but large enough to exhibit interesting behaviors.

In [None]:
# Option A: Load a pretrained checkpoint
checkpoint_path = '../trained_models/shakespeare_char_model.pt'

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
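    # Create model from checkpoint
    model_args = checkpoint['model_args']
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    # Checkpoints saved from a torch.compile'd model prefix every key with
    # '_orig_mod.'; strip it so the keys match the uncompiled module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)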
    model.to(device)
    model.eval()
    
    print(f"Model loaded: {model.get_num_params()/1e6:.2f}M parameters")
    
else:
    print("No checkpoint found. You'll need to train a model first.")
    print("Run: python train.py config/train_shakespeare_char.py --max_iters=5000")
    print("Then copy the checkpoint to trained_models/shakespeare_char_model.pt")
    
    # For demonstration, create a small model (won't have learned induction yet)
    print("\nCreating untrained model for demonstration...")
    config = GPTConfig(
        block_size=256,
        vocab_size=65,  # Shakespeare character vocab
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.0,
        bias=False
    )
    model = GPT(config)
    model.to(device)
    model.eval()
    print("Note: Untrained model won't show induction head behavior!")

## 2. Create Test Sequences for Induction

We'll create sequences with repeated patterns to test for induction behavior.
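
A stricter variant, in the spirit of the repeated random sequences used in the Anthropic paper, repeats a whole block of tokens so that every position in the second copy has an induction target. The sketch below uses plain Python lists; `make_repeated_sequence` and `induction_targets` are hypothetical helpers, and the block length of 10 is an arbitrary choice.

```python
import random


def make_repeated_sequence(vocab_size=65, block_len=10, seed=0):
    """Build [x_0 .. x_{n-1}, x_0 .. x_{n-1}] from distinct random tokens."""
    rng = random.Random(seed)
    # Sampling without replacement keeps each token's earlier occurrence unique.
    block = rng.sample(range(vocab_size), block_len)
    return block + block


def induction_targets(tokens, block_len):
    """Map each query position in the second copy to its expected next token."""
    # At position p in the second copy the model should predict tokens[p + 1],
    # which equals the token that followed the first occurrence of tokens[p].
    return {p: tokens[p - block_len + 1] for p in range(block_len, 2 * block_len - 1)}


seq = make_repeated_sequence()
targets = induction_targets(seq, block_len=10)
print(len(seq), len(targets))  # 20 9
```

Sampling without replacement avoids the ambiguity of a token that happens to occur more than once before the repeat.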

In [None]:
# Create simple induction test sequences
# Pattern: [A][B][C][A][?] -> Should predict [B]

def create_induction_sequence(vocab_size=65, seq_len=20, repeat_at=10):
    """
    Create a sequence with a repeated pattern for testing induction.
    
    Args:
        vocab_size: Size of vocabulary
        seq_len: Total sequence length
        repeat_at: Position at which the earlier token is repeated
                   (must be smaller than seq_len)
    
    Returns:
        sequence: Tensor of token IDs with shape (1, seq_len)
        repeat_token: The token that gets repeated
        target_token: The token that should be predicted after the repeat
    """
    # Create random initial sequence
    sequence = torch.randint(0, vocab_size, (seq_len,))
    
    # Pick a position before repeat_at to copy
    copy_from = repeat_at // 2
    
    # Copy pattern: sequence[copy_from] appears at repeat_at
    # We want to test if model predicts sequence[copy_from+1]
    sequence[repeat_at] = sequence[copy_from]
    
    repeat_token = sequence[copy_from].item()
    target_token = sequence[copy_from + 1].item()
    
    return sequence.unsqueeze(0), repeat_token, target_token

# Create test sequence
test_seq, repeat_tok, target_tok = create_induction_sequence(
    vocab_size=model.config.vocab_size,
    seq_len=30,
    repeat_at=20
)

print(f"Test sequence: {test_seq.squeeze().tolist()}")
print(f"Repeat token: {repeat_tok}")
print(f"Target token (should be predicted): {target_tok}")
print(f"\nSequence visualization:")
print(f"Position 10: {test_seq[0, 10].item()} (A)")
print(f"Position 11: {test_seq[0, 11].item()} (B) <- This is what we want to predict")
print(f"Position 20: {test_seq[0, 20].item()} (A again)")
print(f"Position 21: ??? <- Model should predict (B)")

## 3. Test Model's Induction Capability

First, let's see if the model actually performs induction.
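
A top-5 check hides how far off a miss is, so it can help to report the target's exact rank and probability and compare them with chance (1/65, about 1.5%, for this vocabulary). A minimal sketch, assuming the logits for one position are available as a plain list of floats; `target_rank_and_prob` is a hypothetical helper.

```python
import math


def target_rank_and_prob(logits, target):
    """Return (rank, probability) of `target` under softmax(logits); rank 1 is best."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]   # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank is one plus the number of tokens with strictly higher probability.
    rank = 1 + sum(p > probs[target] for p in probs)
    return rank, probs[target]


rank, prob = target_rank_and_prob([0.0, 2.0, 1.0], target=2)
print(rank, round(prob, 3))
```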

In [None]:
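# Get model predictions
# Truncate the prompt right after the repeated token (position 20) so that the
# final position is the one where induction should fire. Stock nanoGPT returns
# logits only for the last position when no targets are passed.
repeat_at = 20
test_seq_device = test_seq[:, :repeat_at + 1].to(device)

with torch.no_grad():
    logits, _ = model(test_seq_device)
    
# Prediction for position 21, the token right after the repeat
next_token_logits = logits[0, -1, :]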
probs = torch.softmax(next_token_logits, dim=-1)

# Get top 5 predictions
top_probs, top_tokens = torch.topk(probs, 5)

print("Top 5 predictions:")
for i, (token, prob) in enumerate(zip(top_tokens, top_probs)):
    marker = " ← TARGET!" if token.item() == target_tok else ""
    print(f"{i+1}. Token {token.item()}: {prob.item():.2%}{marker}")

# Check if target is in top 5
if target_tok in top_tokens:
    rank = (top_tokens == target_tok).nonzero(as_tuple=True)[0].item() + 1
    prob = probs[target_tok].item()
    print(f"\n✓ Model shows induction! Target token ranked #{rank} with {prob:.2%} probability")
else:
    print(f"\n✗ Model did not predict target token in top 5")
    print(f"  Target probability: {probs[target_tok].item():.2%}")

## 4. Find Induction Heads Using Activation Patching

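Now we'll use activation patching to systematically identify which layers are responsible for induction behavior. The scan below works at layer granularity; attributing the effect to individual heads requires the per-head attention analysis in section 6.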
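
To make the idea concrete, here is a toy version of the procedure on a two-layer residual network built from plain Python lists. It assumes the common convention of patching clean activations into the corrupted run and reporting the fraction of the clean-vs-corrupted gap that is recovered; whether `activation_patching.patch_layer_scan` uses the same direction and normalization is an assumption. `run_residual` and `patching_effects` are hypothetical names.

```python
def run_residual(x, layers, patch=None, cache=None):
    """Run a toy residual network: x <- x + layer(x) for each layer.

    patch: optional (layer_index, contribution) that replaces that layer's output.
    cache: optional list that receives each layer's contribution.
    """
    for i, layer in enumerate(layers):
        contribution = layer(x)
        if patch is not None and patch[0] == i:
            contribution = patch[1]
        if cache is not None:
            cache.append(contribution)
        x = [xi + ci for xi, ci in zip(x, contribution)]
    return x


def patching_effects(layers, clean_x, corrupted_x, metric):
    """Fraction of the clean-vs-corrupted metric gap recovered per patched layer."""
    clean_cache = []
    clean_score = metric(run_residual(clean_x, layers, cache=clean_cache))
    corrupted_score = metric(run_residual(corrupted_x, layers))
    gap = clean_score - corrupted_score
    if gap == 0:
        raise ValueError("clean and corrupted runs give the same metric")
    effects = []
    for i in range(len(layers)):
        patched = run_residual(corrupted_x, layers, patch=(i, clean_cache[i]))
        effects.append((metric(patched) - corrupted_score) / gap)
    return effects


# Toy network: layer 0 writes nothing, layer 1 copies 2 * x[0] into slot 1.
layers = [lambda x: [0.0, 0.0], lambda x: [0.0, 2.0 * x[0]]]
effects = patching_effects(layers, [1.0, 0.0], [0.0, 0.0], metric=lambda out: out[1])
print(effects)  # [0.0, 1.0]
```

An effect near 1 means that restoring that layer's clean output is enough to recover the clean behavior; an effect near 0 means the layer does not carry the relevant information.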

In [None]:
# Create clean and corrupted inputs
# Clean: the induction sequence from section 2, truncated right after the
#        repeated token so the final position is where induction should fire
# Corrupted: the same sequence with the repeat pattern broken

clean_seq = test_seq[:, :21].clone()
corrupted_seq = clean_seq.clone()
# Break the pattern by changing the repeated token at position 20
corrupted_seq[0, 20] = (corrupted_seq[0, 20] + 5) % model.config.vocab_size

clean_seq = clean_seq.to(device)
corrupted_seq = corrupted_seq.to(device)

print("Running activation patching scan...")
print("This identifies which layers are important for induction.\n")

# Scan all transformer layers
results = activation_patching.patch_layer_scan(
    model,
    clean_input=clean_seq,
    corrupted_input=corrupted_seq,
    show_progress=True
)

# Find most important layers
important = activation_patching.find_important_components(
    results,
    threshold=0.3,
    top_k=5
)

print("\nMost important layers for induction:")
for layer_name, effect in important:
    print(f"  {layer_name}: {effect:.3f}")

## 5. Visualize Activation Patching Results

In [None]:
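# Plot patching effects by layer
# Assumes component names look like 'transformer.h.<layer>', so the layer
# index is the third dot-separated field
layer_effects = [(int(name.split('.')[2]), result.effect) 
                 for name, result in results.items()]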
layer_effects.sort(key=lambda x: x[0])

layers = [x[0] for x in layer_effects]
effects = [x[1] for x in layer_effects]

plt.figure(figsize=(12, 6))
plt.bar(layers, effects, color='steelblue', alpha=0.7)
# Draw the same threshold that was passed to find_important_components above
plt.axhline(y=0.3, color='red', linestyle='--', label='Importance threshold (0.3)')
plt.xlabel('Layer', fontsize=12)
plt.ylabel('Patching Effect', fontsize=12)
plt.title('Layer Importance for Induction (Activation Patching)', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
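# Make sure the output directory exists before saving
Path('../results').mkdir(parents=True, exist_ok=True)
plt.savefig('../results/induction_patching_effects.png', dpi=300, bbox_inches='tight')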
plt.show()

print("Figure saved to results/induction_patching_effects.png")

## 6. Analyze Attention Patterns

Let's look at attention patterns to visually identify induction heads.
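
Once per-head attention weights are available, one way to score a head is to measure how much attention each repeated token places on the position right after its previous occurrence. The sketch below assumes a single head's causal attention matrix given as nested Python lists; `induction_score` is a hypothetical helper, and obtaining `attn` from nanoGPT is not covered here.

```python
def induction_score(attn, tokens):
    """Average attention that repeated tokens place on the position right
    after their most recent earlier occurrence.

    attn: T x T causal attention matrix for one head (attn[q][k], rows sum to 1).
    tokens: list of T token IDs.
    """
    last_seen = {}
    scores = []
    for q, tok in enumerate(tokens):
        if tok in last_seen:
            k = last_seen[tok] + 1   # position of the token that followed the earlier [A]
            if k < q:                # skip immediate repeats, where k would be q itself
                scores.append(attn[q][k])
        last_seen[tok] = q
    return sum(scores) / len(scores) if scores else 0.0


tokens = [7, 3, 9, 7]
attn = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.05, 0.9, 0.03, 0.02],
]
print(induction_score(attn, tokens))  # 0.9
```

A head with a score close to 1 on many random sequences is a strong induction-head candidate; a score near 1/T is what roughly uniform attention would give.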

In [None]:
# Note: This is a placeholder showing the intended analysis
# The actual attention extraction needs to be implemented to work with nanoGPT's architecture

print("Attention pattern analysis (to be implemented):")
print("- Extract attention weights from forward pass")
print("- Look for heads that attend to positions matching current token")
print("- Visualize attention patterns for suspected induction heads")
print("\nThis requires modifying nanoGPT to expose attention weights,")
print("or using hooks to capture them during the forward pass.")

## 7. Logit Lens Analysis

Use logit lens to see when the model "decides" on the induction prediction.
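
The logit lens decodes an intermediate residual-stream vector as if it were the final one: apply the final layer norm, then the unembedding matrix. Below is a minimal sketch for a single position in plain Python. `ln_weight` and `unembed` stand in for the model's final LayerNorm weight and `lm_head` matrix, the layer norm has no bias (matching `bias=False` in the config above), and whether `logit_lens.logit_lens` does exactly this internally is an assumption.

```python
import math


def layer_norm(x, weight, eps=1e-5):
    """Bias-free layer norm over a single vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [w * (v - mean) / math.sqrt(var + eps) for v, w in zip(x, weight)]


def logit_lens_logits(resid, ln_weight, unembed):
    """Decode one residual vector: final layer norm, then unembedding.

    resid: residual-stream vector of length d_model.
    unembed: vocab_size x d_model matrix (one row per vocabulary token).
    """
    normed = layer_norm(resid, ln_weight)
    return [sum(w * h for w, h in zip(row, normed)) for row in unembed]


# Toy example: d_model = 2, vocab_size = 3.
logits = logit_lens_logits(
    resid=[1.0, -1.0],
    ln_weight=[1.0, 1.0],
    unembed=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
)
print(max(range(len(logits)), key=logits.__getitem__))  # 0
```

Applying this to the residual stream after each block gives one prediction per layer, which is what the per-layer printout in this section summarizes.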

In [None]:
print("Running logit lens analysis...\n")

# Analyze how prediction forms across layers
lens_result = logit_lens.logit_lens(
    model,
    clean_seq,
    target_position=-1,
    top_k=5
)

# Show prediction evolution
print("Prediction evolution across layers:")
for layer_idx, predictions in enumerate(lens_result.predictions_by_layer):
    top_pred, top_prob = predictions[0]
    print(f"Layer {layer_idx}: {top_pred} ({top_prob:.2%})")

# Measure convergence
convergence = logit_lens.measure_convergence(lens_result, threshold=0.5)
print(f"\nModel becomes confident at layer {convergence['convergence_layer']}")
print(f"Prediction stability: {convergence['stability']:.2%}")

## 8. Visualize Prediction Formation

In [None]:
# Plot how predictions evolve
fig = logit_lens.plot_prediction_evolution(
    lens_result,
    num_predictions=5,
    save_path='../results/induction_logit_lens.png'
)
plt.show()

print("Figure saved to results/induction_logit_lens.png")

## 9. Summary and Key Findings

### Expected Results (on a trained model):

1. **Induction Capability**: Model successfully predicts repeated patterns
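2. **Layer Localization**: Induction behavior localizes to layers after the first, because an induction head depends on a previous-token head in an earlier layer; which of the six layers carry it is what the patching scan measures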
3. **Head Specialization**: Specific attention heads specialize in induction
4. **Prediction Formation**: Logit lens shows gradual convergence to correct prediction

### Comparison to Anthropic's Findings:

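If the model has learned induction, our results should be consistent with the following:
- Induction heads form during an abrupt phase change early in training (Anthropic reports roughly 2.5 to 5 billion tokens for their models; the corresponding iteration count for this character-level model has to be measured)
- They appear only in models with at least two layers and never in the first layer, since they compose with an earlier previous-token head
- Responsible for a significant portion of in-context learning capability
- Can be identified by a characteristic attention pattern: the repeated token attends to the position right after its earlier occurrence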

### Next Steps:

1. Test on multiple random seeds to verify robustness
2. Measure how induction capability develops during training
3. Ablate suspected induction heads to measure their importance
4. Compare to larger models to see how behavior scales
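
For the first step, a single random sequence says little, so accuracy is better averaged over many trials and several seeds. The following is a sketch in plain Python, assuming a hypothetical `predict_fn(prompt) -> token` that wraps the model (for example, the argmax of the logits at the last position); `make_trial` and `induction_accuracy` are illustrative names.

```python
import random
import statistics


def make_trial(rng, vocab_size=65, seq_len=30, repeat_at=20):
    """Random sequence truncated at the repeated token, plus the expected next token."""
    seq = [rng.randrange(vocab_size) for _ in range(seq_len)]
    copy_from = repeat_at // 2
    seq[repeat_at] = seq[copy_from]
    return seq[: repeat_at + 1], seq[copy_from + 1]


def induction_accuracy(predict_fn, seeds, trials_per_seed=100):
    """Mean and standard deviation of per-seed accuracy for predict_fn."""
    per_seed = []
    for seed in seeds:
        rng = random.Random(seed)
        hits = 0
        for _ in range(trials_per_seed):
            prompt, target = make_trial(rng)
            hits += predict_fn(prompt) == target
        per_seed.append(hits / trials_per_seed)
    return statistics.mean(per_seed), statistics.pstdev(per_seed)


# Baseline: a predictor that always answers token 0 should land near chance (1/65).
mean_acc, std_acc = induction_accuracy(lambda prompt: 0, seeds=range(3))
print(f"{mean_acc:.3f} +/- {std_acc:.3f}")
```

Comparing the model's accuracy against this constant baseline makes it clear whether an apparent success is more than chance.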

## 10. Save Results

In [None]:
# Save analysis results
results_dict = {
    'model_config': model.config.__dict__,
    'induction_test': {
        'repeat_token': repeat_tok,
        'target_token': target_tok,
        'predicted_correctly': target_tok in top_tokens,
    },
    'important_layers': important,
    'convergence': convergence,
}

# Save to file
import json
Path('../results').mkdir(parents=True, exist_ok=True)
with open('../results/induction_heads_analysis.json', 'w') as f:
    # default=str converts anything json cannot encode (tensors, arrays,
    # nested result objects) to its string form instead of raising
    json.dump(results_dict, f, indent=2, default=str)

print("Results saved to results/induction_heads_analysis.json")