# Attention Analysis: Identifying Factual Recall Heads

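This notebook demonstrates how to analyze attention patterns in transformer models to identify candidate "factual recall heads": attention heads that appear to specialize in retrieving factual information. Here they are detected by how strongly the final token attends to the subject in true versus false statements.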

## Overview

We'll cover:
1. Extracting attention patterns from all heads and layers
2. Identifying subject tokens in factual statements
3. Computing attention scores from prediction → subject
4. Using statistical tests to find significant heads
5. Visualizing attention patterns interactively
6. Comparing attention for true vs false facts

## Key Question
**Which attention heads are responsible for retrieving factual knowledge, and do they behave differently for true vs false facts?**

In [None]:
# Setup and imports
import sys
sys.path.append('..')

import torch
import numpy as np
from src.utils import setup_logging, set_seed, load_model
from src.attention_analysis import (
    AttentionAnalyzer,
    compute_bonferroni_correction,
    compute_fdr_correction
)
from src.fact_dataset import create_sample_dataset
from src.visualization import (
    plot_factual_recall_heads,
    plot_attention_to_subject,
    plot_attention_comparison_interactive,
    plot_head_scores_distribution,
    plot_aggregated_attention_flow,
    plot_top_heads_comparison,
)

# Setup
setup_logging()
set_seed(42)

print("✓ Imports successful")

## Part 1: Load Model and Dataset

We'll use a smaller model (GPT-2 Small, 12 layers × 12 heads) for faster computation. The same analysis should carry over to larger models, at a higher compute cost.

In [None]:
# Load model
model = load_model('gpt2-small')

print(f"Model: {model.cfg.model_name}")
print(f"  Layers: {model.cfg.n_layers}")
print(f"  Heads per layer: {model.cfg.n_heads}")
print(f"  Total heads: {model.cfg.n_layers * model.cfg.n_heads}")

In [None]:
# Create fact dataset
dataset = create_sample_dataset()

# Split into true and false
true_dataset = dataset.filter(is_true=True)
false_dataset = dataset.filter(is_true=False)

print(f"Dataset: {len(true_dataset)} true facts, {len(false_dataset)} false facts")
print("\nExamples:")
for i in range(min(4, len(dataset))):  # guard against datasets with fewer than 4 facts
    fact = dataset[i]
    print(f"  [{i}] {'✓' if fact.is_true else '✗'} {fact.to_prompt()}")

## Part 2: Initialize Attention Analyzer

The `AttentionAnalyzer` class extracts attention patterns from all heads and layers.

In [None]:
# Create analyzer
analyzer = AttentionAnalyzer(model)

print("✓ AttentionAnalyzer initialized")
print(f"  Analyzing {analyzer.n_layers} layers × {analyzer.n_heads} heads")
print(f"  = {analyzer.n_layers * analyzer.n_heads} total heads")

## Part 3: Extract Attention Patterns

For each prompt, we:
1. Run the model forward pass
2. Extract attention weights from all heads
3. Identify subject token positions (set after extraction, via `_find_token_positions`)
4. Store as an `AttentionPattern` object

In [None]:
# Helper function to extract subject tokens from facts
def get_subject_tokens(fact):
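    """Split the subject entity into whitespace-separated words.

    These are words, not GPT-2 BPE tokens; matching them to token
    positions is left to `analyzer._find_token_positions`.
    """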
    return fact.subject.split()

# Get prompts and subjects
true_prompts = true_dataset.to_prompts()
false_prompts = false_dataset.to_prompts()

true_subjects = [get_subject_tokens(f) for f in true_dataset.facts]
false_subjects = [get_subject_tokens(f) for f in false_dataset.facts]

print(f"Example subject tokens: {true_subjects[0]}")
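Whitespace-split words are not the same thing as GPT-2's BPE tokens: a word such as " Tower" carries its leading space, and a rarer word may be split into several pieces. The helper below is a hypothetical sketch of one way to map subject words onto token positions; it is not the actual `_find_token_positions`, whose implementation isn't shown in this notebook. It assumes `pattern.tokens` is a list of decoded token strings and matches the subject by concatenating consecutive tokens with whitespace stripped.

```python
def find_subject_positions(tokens, subject_words):
    """Hypothetical helper: positions of the first token run that spells the subject.

    tokens: decoded token strings; a leading space marks a new word (GPT-2 style).
    subject_words: e.g. ["Eiffel", "Tower"], as returned by get_subject_tokens.
    Returns a list of positions, or [] if the subject isn't found.
    """
    target = "".join(subject_words)          # compare with whitespace removed
    stripped = [t.strip() for t in tokens]
    for start in range(len(tokens)):
        if not stripped[start] or not target.startswith(stripped[start]):
            continue                         # a match can't begin here
        built = ""
        for end in range(start, len(tokens)):
            built += stripped[end]
            if built == target:
                return list(range(start, end + 1))
            if not target.startswith(built):
                break                        # diverged from the subject
    return []


# Illustrative token split (not necessarily GPT-2's actual tokenization)
demo_tokens = ["<|endoftext|>", "The", " E", "iff", "el", " Tower", " is", " in", " Paris"]
```

Because it ignores word boundaries, this sketch would also match a subject that tokenizes as the prefix of a longer word, so treat it as a starting point rather than a robust matcher.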

In [None]:
# Extract attention patterns for true facts
print("Extracting attention for TRUE facts...")
true_patterns = analyzer.extract_attention_patterns(true_prompts)

# Set subject positions for each pattern
for i, pattern in enumerate(true_patterns):
    pattern.subject_positions = analyzer._find_token_positions(
        pattern.tokens, true_subjects[i]
    )

print(f"✓ Extracted {len(true_patterns)} attention patterns")

In [None]:
# Extract attention patterns for false facts
print("Extracting attention for FALSE facts...")
false_patterns = analyzer.extract_attention_patterns(false_prompts)

# Set subject positions
for i, pattern in enumerate(false_patterns):
    pattern.subject_positions = analyzer._find_token_positions(
        pattern.tokens, false_subjects[i]
    )

print(f"✓ Extracted {len(false_patterns)} attention patterns")

## Part 4: Examine a Single Attention Pattern

Let's look at one example in detail to understand the data structure.

In [None]:
# Examine first true fact pattern
example = true_patterns[0]

print(f"Prompt: {example.prompt}")
print(f"\nTokens: {example.tokens}")
print(f"\nSubject positions: {example.subject_positions}")
print(f"Prediction position: {example.prediction_position} (last token)")
print(f"\nAttention shape: {example.attention.shape}")  # [n_layers, n_heads, seq_len, seq_len]

# Show which tokens are subjects
print("\nSubject tokens:")
for pos in example.subject_positions:
    print(f"  Position {pos}: '{example.tokens[pos]}'")

## Part 5: Compute Subject Attention Scores

For each head, we measure: **How much does the final token attend to the subject?**

This gives us a score for each head: high score = strong attention to subject.
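As a sketch of what this score might look like, assuming `pattern.attention` has the `[n_layers, n_heads, seq_len, seq_len]` shape printed in Part 4: take the row of the prediction (last) position and read off the weights on the subject positions. Whether `compute_subject_attention_scores` sums or averages over multi-token subjects isn't shown here, so this hypothetical function exposes both through a `reduce` argument.

```python
import numpy as np


def subject_attention_scores(attention, subject_positions, query_pos=-1, reduce="sum"):
    """Attention from one query position to the subject tokens, for every head.

    attention: [n_layers, n_heads, seq_len, seq_len]; attention[l, h, q, k] is the
        weight that query position q puts on key position k.
    Returns: [n_layers, n_heads]
    """
    # Row of the prediction position, restricted to subject columns: [n_layers, n_heads, n_subj]
    weights = attention[:, :, query_pos, list(subject_positions)]
    if reduce == "sum":
        return weights.sum(axis=-1)    # total attention mass on the subject
    if reduce == "mean":
        return weights.mean(axis=-1)   # average weight per subject token
    raise ValueError(f"unknown reduce: {reduce!r}")


# Toy check: 2 layers, 3 heads, 4 positions; L1H2's last row puts 0.75 on positions 1-2
demo_attn = np.zeros((2, 3, 4, 4))
demo_attn[1, 2, 3, [1, 2]] = [0.5, 0.25]
```

Summing gives the fraction of the head's attention that lands on the subject (each softmax row sums to 1), while averaging makes scores more comparable across subjects of different token lengths.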

In [None]:
# Compute scores for the example pattern
example_scores = analyzer.compute_subject_attention_scores(example)

print(f"Subject attention scores shape: {example_scores.shape}")  # [n_layers, n_heads]
print(f"\nTop 5 heads by attention to subject:")

# Flatten and get top-5
flat_scores = example_scores.flatten()
top_5_indices = flat_scores.topk(5).indices

for idx in top_5_indices.tolist():  # plain ints, so they print as numbers, not tensor(...)
    layer, head = divmod(idx, model.cfg.n_heads)  # flat index -> (layer, head)
    score = example_scores[layer, head].item()
    print(f"  Layer {layer} Head {head}: {score:.4f}")

## Part 6: Statistical Testing - Find Factual Recall Heads

Now we identify heads that attend **significantly more** to subjects for true facts than false facts.

For each head:
1. Compute mean attention for true facts
2. Compute mean attention for false facts
3. Run an independent two-sample t-test across prompts
4. Keep heads with p < 0.05 whose mean attention difference (true minus false) exceeds `min_effect_size`
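A minimal version of this per-head test, assuming the per-prompt scores are stacked into arrays of shape `[n_prompts, n_layers, n_heads]` (as in Part 11). It uses `scipy.stats.ttest_ind`, which tests every head at once with `axis=0`. Whether `identify_factual_recall_heads` uses Student's or Welch's variant is not shown here (pass `equal_var=False` for Welch). The test below is two-sided, and the direction comes from requiring a positive difference; `find_recall_heads` is a hypothetical name.

```python
import numpy as np
from scipy import stats


def find_recall_heads(true_scores, false_scores, alpha=0.05, min_effect_size=0.01):
    """Hypothetical per-head test: true vs false subject-attention scores.

    true_scores:  [n_true_prompts, n_layers, n_heads]
    false_scores: [n_false_prompts, n_layers, n_heads]
    Returns [(layer, head, effect, p_value), ...] sorted by effect, largest first.
    """
    effect = true_scores.mean(axis=0) - false_scores.mean(axis=0)  # [n_layers, n_heads]
    # One independent two-sample (Student's) t-test per head, along the prompt axis
    _, p_values = stats.ttest_ind(true_scores, false_scores, axis=0)
    keep = (p_values < alpha) & (effect > min_effect_size)  # significant AND true > false
    heads = [(int(layer), int(head), float(effect[layer, head]), float(p_values[layer, head]))
             for layer, head in zip(*np.nonzero(keep))]
    return sorted(heads, key=lambda item: item[2], reverse=True)


# Synthetic check: only layer 1, head 2 attends more to the subject on "true" prompts
rng = np.random.default_rng(0)
true_demo = rng.normal(0.1, 0.01, size=(30, 2, 3))
false_demo = rng.normal(0.1, 0.01, size=(30, 2, 3))
true_demo[:, 1, 2] += 0.2
```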

In [None]:
# Identify factual recall heads
results = analyzer.identify_factual_recall_heads(
    true_patterns,
    false_patterns,
    threshold=0.05,        # p-value threshold
    min_effect_size=0.01,  # minimum difference in attention
)

significant_heads = results['significant_heads']
n_significant = len(significant_heads)
total_heads = model.cfg.n_layers * model.cfg.n_heads

print(f"✓ Found {n_significant} significant heads out of {total_heads}")
print(f"  ({n_significant / total_heads * 100:.1f}% of all heads)")

In [None]:
# Show top 10 factual recall heads
print("\nTop 10 Factual Recall Heads:")
print("(Heads with strongest preference for true facts)\n")

for i, head in enumerate(significant_heads[:10]):
    print(f"{i+1:2d}. Layer {head.layer:2d} Head {head.head:2d}: "
          f"effect={head.score:.4f}, p={head.p_value:.2e}")

## Part 7: Visualize All Heads with Statistical Significance

This creates a 2-panel heatmap:
- **Top**: Effect size (true - false attention)
- **Bottom**: Statistical significance (-log10 p-value)

In [None]:
# Plot overview of all heads
fig = plot_factual_recall_heads(
    results,
    top_k=10,  # Highlight top 10 with stars
    title=f"Factual Recall Heads - {model.cfg.model_name}"
)

fig.show()

**Interpretation**:
- Red = True facts have higher attention (factual recall heads)
- Blue = False facts have higher attention
- Stars = Top significant heads
- Brighter colors in bottom panel = more significant

## Part 8: Compare Top Heads

Bar chart comparing mean subject-attention scores on true vs. false facts for the top heads.

In [None]:
if significant_heads:
    fig = plot_top_heads_comparison(
        significant_heads,
        results['true_scores'],
        results['false_scores'],
        top_k=10
    )
    fig.show()
else:
    print("No significant heads found")

## Part 9: Examine Attention Pattern of Top Head

Let's visualize the attention matrix for the most important head.

In [None]:
if significant_heads:
    # Get top head
    top_head = significant_heads[0]
    layer, head = top_head.layer, top_head.head
    
    print(f"Examining Layer {layer} Head {head}")
    print(f"  Effect size: {top_head.score:.4f}")
    print(f"  P-value: {top_head.p_value:.2e}")
    
    # Visualize attention for first true fact
    fig = plot_attention_to_subject(
        true_patterns[0],
        layer,
        head,
        title=f"Top Head: L{layer}H{head} - True Fact"
    )
    fig.show()
    
    print("\nRed dashed line = subject tokens")
    print("Green dashed line = prediction token (last)")

**Interpretation**: Look at the last row (prediction token). Does it attend strongly to the subject?

## Part 10: Compare True vs False Attention

Side-by-side comparison of the same head on true vs false facts.

In [None]:
if significant_heads and len(false_patterns) > 0:
    fig = plot_attention_comparison_interactive(
        true_patterns[0],
        false_patterns[0],
        layer,
        head
    )
    fig.show()
    
    print(f"\nLeft: {true_patterns[0].prompt} (TRUE)")
    print(f"Right: {false_patterns[0].prompt} (FALSE)")

**Question**: Is the attention to subject visibly different between true and false?

## Part 11: Distribution of Attention Scores

Histogram showing the distribution of attention scores for the top head.

In [None]:
if significant_heads:
    # Compute scores for all patterns
    true_scores_list = []
    for pattern in true_patterns:
        scores = analyzer.compute_subject_attention_scores(pattern)
        true_scores_list.append(scores)
    true_scores_tensor = torch.stack(true_scores_list, dim=0)
    
    false_scores_list = []
    for pattern in false_patterns:
        scores = analyzer.compute_subject_attention_scores(pattern)
        false_scores_list.append(scores)
    false_scores_tensor = torch.stack(false_scores_list, dim=0)
    
    # Plot distribution for top head
    fig = plot_head_scores_distribution(
        true_scores_tensor,
        false_scores_tensor,
        layer,
        head
    )
    fig.show()

**Interpretation**: Overlaid histograms show whether true facts (green) consistently have higher attention than false facts (red).

## Part 12: Aggregated Attention Flow

Heatmap showing mean attention across all prompts and all positions.

In [None]:
# Average attention flow across all true facts
fig = plot_aggregated_attention_flow(
    true_patterns,
    aggregation='mean',
    title="Mean Attention Flow (True Facts)"
)
fig.show()

**Observation**: Which layers and heads show highest average attention?

## Part 13: Multiple Comparison Correction

When testing many hypotheses (one per head), we need to correct for multiple comparisons.

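**Bonferroni correction**: very conservative; divides α by the number of tests (one per head, so 144 for GPT-2 Small). It controls the family-wise error rate, the probability of even one false positive.

**FDR correction (Benjamini-Hochberg)**: less conservative; controls the expected fraction of false positives among the heads declared significant.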
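For reference, both corrections fit in a few lines of NumPy. This is a sketch, not the code behind `compute_bonferroni_correction` / `compute_fdr_correction`; the BH function returns adjusted p-values that are compared directly to α.

```python
import numpy as np


def bonferroni(p_values, alpha=0.05):
    """Reject where p <= alpha / m (m = number of tests). Returns (mask, corrected_alpha)."""
    p = np.asarray(p_values, dtype=float)
    corrected_alpha = alpha / p.size
    return p <= corrected_alpha, corrected_alpha


def benjamini_hochberg(p_values, alpha=0.05):
    """Benjamini-Hochberg step-up procedure. Returns (mask, adjusted_p), same shape as input."""
    p = np.asarray(p_values, dtype=float)
    flat = p.ravel()
    m = flat.size
    order = np.argsort(flat)
    ranked = flat[order] * m / np.arange(1, m + 1)          # p_(i) * m / i
    # Adjusted p_(i) = min over j >= i of p_(j) * m / j, capped at 1
    adjusted_sorted = np.minimum(np.minimum.accumulate(ranked[::-1])[::-1], 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted                      # back to original order
    adjusted = adjusted.reshape(p.shape)
    return adjusted <= alpha, adjusted


demo_p = np.array([0.01, 0.04, 0.03, 0.005])
bonf_mask, bonf_alpha = bonferroni(demo_p)
bh_mask, bh_adjusted = benjamini_hochberg(demo_p)
```

With these four p-values, Bonferroni (α/4 = 0.0125) keeps two tests while BH keeps all four, which is why BH is the more lenient choice.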

In [None]:
# Apply Bonferroni correction
sig_mask, corrected_alpha = compute_bonferroni_correction(
    results['p_values'],
    alpha=0.05
)

n_bonferroni = sig_mask.sum().item()

print(f"Bonferroni Correction:")
print(f"  Original α: 0.05")
print(f"  Corrected α: {corrected_alpha:.2e}")
print(f"  Significant heads: {n_bonferroni} "
      f"(vs. {n_significant} with uncorrected p < 0.05 plus the effect-size filter)")

In [None]:
# Apply FDR correction (less conservative)
fdr_mask, adjusted_p = compute_fdr_correction(
    results['p_values'],
    alpha=0.05
)

n_fdr = fdr_mask.sum().item()

print(f"FDR Correction (Benjamini-Hochberg):")
print(f"  Significant heads: {n_fdr}")
print(f"  (More lenient than Bonferroni)")

## Part 14: Analyze Specific Fact Examples

Let's examine how specific facts are processed.

In [None]:
# Pick a fact pair (assumes true_dataset[i] and false_dataset[i] are matched counterparts)
fact_idx = 0
true_fact = true_dataset[fact_idx]
false_fact = false_dataset[fact_idx]

print(f"Analyzing fact pair {fact_idx}:")
print(f"  TRUE:  {true_fact.to_prompt()}")
print(f"  FALSE: {false_fact.to_prompt()}")

true_pattern = true_patterns[fact_idx]
false_pattern = false_patterns[fact_idx]

# Compute scores for each head
true_scores = analyzer.compute_subject_attention_scores(true_pattern)
false_scores = analyzer.compute_subject_attention_scores(false_pattern)

# Find heads with biggest difference
diff = (true_scores - false_scores).abs()
top_diff_heads = diff.flatten().topk(5).indices

print("\nHeads with biggest true/false difference:")
for idx in top_diff_heads.tolist():  # plain ints, so they print as numbers, not tensor(...)
    layer, head = divmod(idx, model.cfg.n_heads)  # flat index -> (layer, head)
    true_score = true_scores[layer, head].item()
    false_score = false_scores[layer, head].item()
    print(f"  L{layer}H{head}: true={true_score:.4f}, false={false_score:.4f}, "
          f"diff={abs(true_score - false_score):.4f}")

## Part 15: Get Top-K Heads by Average Attention

Alternative approach: find the heads that attend most to the subject on true facts, without contrasting them against false facts.

In [None]:
# Get top heads by average subject attention
top_heads = analyzer.get_top_heads(true_patterns, top_k=10)

print("Top 10 heads by subject attention (true facts):")
for i, head in enumerate(top_heads):
    print(f"{i+1:2d}. Layer {head.layer:2d} Head {head.head:2d}: "
          f"mean={head.mean_score:.4f} ± {head.std_score:.4f}")

---

## Summary

We've demonstrated:
1. ✓ Extracting attention patterns from all heads
2. ✓ Identifying subject tokens in factual statements
3. ✓ Computing attention from prediction → subject
4. ✓ Statistical testing to find factual recall heads
5. ✓ Visualizing attention patterns interactively
6. ✓ Comparing true vs false fact attention
7. ✓ Multiple comparison corrections (Bonferroni, FDR)

---

## Suggested Experiments

### 1. Attention Pattern Evolution
**Question**: Do factual recall heads emerge during training, or are they present from initialization?
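- Load checkpoints from different training steps (this needs a model family with published intermediate checkpoints; the original GPT-2 release does not include them)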
- Run same analysis on each checkpoint
- Plot: # of factual recall heads vs training step
- Hypothesis: Factual heads emerge gradually

### 2. Cross-Model Comparison
**Question**: Are the same heads important across different model sizes?
- Run analysis on GPT-2 Small, Medium, Large
- Check if factual recall heads are in similar layers
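- Head indices within a layer are arbitrary, so match by relative depth rather than head number: e.g., does a factual recall head at layer 6 of Small's 12 line up with one near layer 12 of Medium's 24?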
- Use layer/total_layers as normalized position

### 3. Attention to Object Tokens
**Question**: Do heads attend to the object (answer) differently for true vs false?
- Call `_find_token_positions` with the object's tokens (instead of the subject's) to find object positions
- Compute attention: prediction → object
- Compare with attention to subject
- Hypothesis: False facts might have weaker subject→object attention

### 4. Multi-Hop Reasoning
**Question**: Can we trace information flow through multiple heads?
- Create facts requiring multi-hop reasoning: "Paris is in France. France is in Europe. Therefore, Paris is in Europe."
- Use attention flow analysis to trace: subject → intermediate → final
- Visualize the reasoning chain

### 5. Intervention Experiments
**Question**: What happens if we knock out factual recall heads?
- Identify top 5 factual recall heads
- Run the model with these heads' outputs zeroed out (zero-ablation)
- Measure: Does accuracy on factual questions drop?
- Control: Zero out random heads and compare
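One way to implement the knockout, sketched under the assumption that `load_model` returns a TransformerLens `HookedTransformer` (the `model.cfg` fields used above suggest so, but the notebook doesn't confirm it). Zeroing a head's `hook_z` activation, shaped `[batch, pos, head_index, d_head]`, removes that head's term from the attention output. Both helper names are hypothetical.

```python
from collections import defaultdict

import numpy as np  # only used for the dummy self-check at the bottom


def make_zero_ablation_hook(heads):
    """Return a hook that zeroes the given head indices in an attn `hook_z` activation.

    Assumes the activation is shaped [batch, pos, head_index, d_head], as in
    TransformerLens. Works on any array supporting slice assignment (torch or numpy).
    """
    heads = list(heads)

    def zero_heads(z, hook=None):
        z[:, :, heads, :] = 0.0  # drop these heads' contribution before W_O
        return z

    return zero_heads


def ablation_fwd_hooks(layer_head_pairs):
    """Group (layer, head) pairs into one (hook_name, hook_fn) entry per layer."""
    by_layer = defaultdict(list)
    for layer, head in layer_head_pairs:
        by_layer[layer].append(head)
    return [(f"blocks.{layer}.attn.hook_z", make_zero_ablation_hook(heads))
            for layer, heads in sorted(by_layer.items())]


# Self-check on a dummy activation: batch=1, pos=4, 12 heads, d_head=8
hooks = ablation_fwd_hooks([(9, 6), (9, 9), (10, 0)])
ablated = hooks[0][1](np.ones((1, 4, 12, 8)))
```

Under that assumption, `model.run_with_hooks(tokens, fwd_hooks=ablation_fwd_hooks(pairs))` would return logits with those heads removed; for the control, draw the same number of random `(layer, head)` pairs.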

### 6. Attention Entropy Analysis
**Question**: Do factual recall heads have more focused attention?
- Compute attention entropy for each head
- Formula: `H = -Σ p(i) log p(i)` where p(i) is attention weight
- Compare entropy of factual recall heads vs others
- Hypothesis: Lower entropy = more focused = factual recall
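The entropy formula translates directly to NumPy. This sketch assumes attention arrays shaped like `pattern.attention` (rows are query positions whose weights sum to 1). Clipping inside the log avoids `log(0)` on causally masked entries, whose weight is exactly 0 and therefore contributes nothing.

```python
import numpy as np


def attention_entropy(attention, eps=1e-12):
    """Shannon entropy H = -sum_i p(i) log p(i) (in nats) of each attention row.

    attention: [..., seq_len (query), seq_len (key)], each row summing to 1.
    Returns: [..., seq_len], one entropy per query position.
    """
    safe = np.clip(attention, eps, 1.0)              # avoid log(0) on masked entries
    return -(attention * np.log(safe)).sum(axis=-1)  # p = 0 terms contribute 0


demo_rows = np.array([[0.25, 0.25, 0.25, 0.25],   # spread evenly over 4 keys
                      [1.0, 0.0, 0.0, 0.0]])      # fully focused on one key
```

Under causal masking, query position q can spread its attention over at most q + 1 keys, so its maximum entropy is log(q + 1); dividing by that keeps scores comparable across positions and prompt lengths.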

### 7. Relation-Specific Heads
**Question**: Do different relations use different heads?
- Group facts by relation type: location, invention, people, etc.
- Run separate analysis for each relation
- Find relation-specific factual recall heads
- Visualization: Venn diagram of head sets
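Alongside a Venn diagram, a single number per pair of relations can summarize how much their head sets overlap. This sketch uses Jaccard similarity (a choice made here, not something the notebook prescribes), with heads represented as `(layer, head)` tuples; the relation names and head sets below are made up.

```python
def jaccard(a, b):
    """|A ∩ B| / |A ∪ B| for two collections of (layer, head) tuples; 1.0 if both empty."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 1.0


def pairwise_overlap(heads_by_relation):
    """Jaccard overlap for every unordered pair of relations."""
    names = sorted(heads_by_relation)
    return {(r1, r2): jaccard(heads_by_relation[r1], heads_by_relation[r2])
            for i, r1 in enumerate(names) for r2 in names[i + 1:]}


# Made-up head sets, for illustration only
demo_heads = {
    "capital": {(9, 6), (10, 0)},
    "inventor": {(9, 6), (11, 2)},
    "language": set(),
}
overlaps = pairwise_overlap(demo_heads)
```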

### 8. Attention Gradient Analysis
**Question**: How sensitive are predictions to attention weights?
- Compute gradients: ∂loss/∂attention for each head
- Compare gradient magnitudes for factual recall heads
- Are factual heads more influential for correct predictions?

### 9. Few-Shot Attention Patterns
**Question**: Do in-context examples change which heads are important?
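- Prompt format: "France's capital is Paris. Germany's capital is Berlin. Italy's capital is" (GPT-2 predicts the next token, so the prompt simply ends where the answer goes; there is no [MASK] token)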
- Compare attention patterns with vs without examples
- Do different heads activate for in-context learning?

### 10. Adversarial Attention
**Question**: Can we craft inputs that fool factual recall heads?
- Start with true fact: "Paris is the capital of France"
- Add misleading context: "Many people think Paris is in Germany, but Paris is the capital of France"
- Measure: Does attention shift away from the subject?
- Find prompts where factual heads fail

### 11. Temporal Attention Dynamics
**Question**: How does attention evolve across the sequence?
- Don't just look at final token
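- Compute attention to the subject from every query position at or after it (causal masking prevents earlier positions from attending to the subject)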
- Plot: attention strength vs token position
- When does the model "lock on" to the subject?
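A sketch for a single head, assuming a `[seq_len, seq_len]` attention matrix such as `pattern.attention[layer, head]` converted to NumPy. Positions before the subject are reported as NaN rather than 0, so a plot doesn't read causal masking as "the model ignoring the subject". The function name is hypothetical.

```python
import numpy as np


def subject_attention_by_position(pattern, subject_positions):
    """Attention mass each query position puts on the subject tokens.

    pattern: [seq_len, seq_len] for one head (rows = queries, causal).
    Returns: [seq_len] array; positions before the first subject token are NaN,
    because causal masking means they cannot attend to the subject.
    """
    scores = pattern[:, list(subject_positions)].sum(axis=-1).astype(float)
    scores[: min(subject_positions)] = np.nan
    return scores


# Toy causal head that spreads attention uniformly over all visible keys
demo_pattern = np.tril(np.ones((4, 4)))
demo_pattern /= demo_pattern.sum(axis=1, keepdims=True)
by_pos = subject_attention_by_position(demo_pattern, [1, 2])
```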

### 12. Attention Head Clustering
**Question**: Do factual recall heads have similar attention patterns?
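- Represent each head by a fixed-length vector, e.g. its subject-attention scores across all prompts (raw attention matrices vary in size with prompt length)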
- Cluster heads using k-means or hierarchical clustering
- Are factual recall heads in the same cluster?
- What characterizes different clusters?
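Because prompts differ in length, raw attention matrices can't be stacked directly. One option (an assumption, not the notebook's method) is to describe each head by its subject-attention scores on the same d prompts, then cluster those vectors. This sketch uses SciPy's Ward-linkage hierarchical clustering; k-means is the other option the experiment mentions.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage


def cluster_heads(head_vectors, n_clusters=2):
    """Hierarchical (Ward) clustering of attention heads.

    head_vectors: [n_layers, n_heads, d], e.g. each head's subject-attention
        score on d prompts (a fixed-length stand-in for its attention pattern).
    Returns: [n_layers, n_heads] integer cluster labels in 1..n_clusters.
    """
    n_layers, n_heads, d = head_vectors.shape
    X = head_vectors.reshape(n_layers * n_heads, d)      # one row per head
    Z = linkage(X, method="ward")
    labels = fcluster(Z, t=n_clusters, criterion="maxclust")
    return labels.reshape(n_layers, n_heads)


# Synthetic check: heads L0H0 and L1H2 share a distinct profile
rng = np.random.default_rng(0)
demo_vectors = rng.normal(0.0, 0.01, size=(2, 3, 5))
demo_vectors[0, 0] += 1.0
demo_vectors[1, 2] += 1.0
demo_labels = cluster_heads(demo_vectors, n_clusters=2)
```

With real data, stacking `compute_subject_attention_scores` outputs into `[n_prompts, n_layers, n_heads]` and moving the prompt axis last (`np.moveaxis(scores, 0, -1)`) would give the expected input.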

---

## Code Template for Experiments

In [None]:
# Template: Analyze a specific head in detail

def analyze_head(layer, head, patterns, name=""):
    """Deep dive into a specific attention head."""
    
    print(f"\nAnalyzing Layer {layer} Head {head} ({name})")
    print("=" * 60)
    
    # Compute scores
    scores = []
    for pattern in patterns:
        score = analyzer.compute_subject_attention_scores(pattern)[layer, head]
        scores.append(score.item())
    
    # Statistics
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    print(f"Mean attention to subject: {mean_score:.4f} ± {std_score:.4f}")
    
    # Visualize for first pattern
    fig = plot_attention_to_subject(
        patterns[0],
        layer,
        head,
        title=f"L{layer}H{head} - {name}"
    )
    fig.show()
    
    return scores

# Example usage:
# if significant_heads:
#     top_head = significant_heads[0]
#     analyze_head(top_head.layer, top_head.head, true_patterns, "Top Factual Recall Head")

In [None]:
# Template: Compare two sets of patterns

def compare_pattern_sets(patterns_a, patterns_b, name_a="A", name_b="B"):
    """Compare attention patterns between two conditions."""
    
    # Identify significant heads
    results = analyzer.identify_factual_recall_heads(
        patterns_a,
        patterns_b,
        threshold=0.05,
        min_effect_size=0.01,
    )
    
    print(f"\nComparing {name_a} vs {name_b}")
    print(f"Significant heads: {len(results['significant_heads'])}")
    
    # Visualize
    fig = plot_factual_recall_heads(
        results,
        title=f"{name_a} vs {name_b}"
    )
    fig.show()
    
    return results

# Example usage:
# compare_pattern_sets(true_patterns, false_patterns, "True", "False")