# Experiment 8: Attention Pattern Analysis

**Goal:** Understand where the model "looks" under different prompts.

**Key Questions:**
- Does "think step by step" change attention to intermediate tokens?
- Do few-shot examples create attention shortcuts?
- Which layers show the most prompt-sensitivity?

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch

from src.model_utils import load_model
from src.visualization import set_style

# Apply the project's shared plotting style (src.visualization.set_style);
# presumably sets global matplotlib/seaborn defaults for the figures below.
set_style()

In [None]:
# Load the chat model under study through the project's loader (src.model_utils).
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = load_model(MODEL_NAME)

## 1. Attention Extraction Utilities

In [None]:
def get_attention_patterns(model, prompt, return_tokens=True):
    """
    Extract attention patterns from all layers and heads for a single prompt.

    Args:
        model: project model wrapper exposing ``.tokenizer``, ``.model`` (the
            underlying network, called with ``output_attentions=True``) and
            ``.config.device``.
        prompt: text to encode as a batch of one.
        return_tokens: if True, also return the decoded token strings.

    Returns:
        attentions: float32 ndarray [n_layers, n_heads, seq_len, seq_len]
        tokens: list of token strings (only when ``return_tokens`` is True)

    Raises:
        RuntimeError: if the model did not return attention weights.
    """
    inputs = model.tokenizer(prompt, return_tensors="pt").to(model.config.device)

    with torch.no_grad():
        outputs = model.model(**inputs, output_attentions=True)

    layer_attns = outputs.attentions
    # Fused attention kernels may not materialize the weights; fail with a
    # clear message instead of an opaque torch.stack error on None.
    if not layer_attns or any(a is None for a in layer_attns):
        raise RuntimeError(
            "Model did not return attention weights; load it with an attention "
            "implementation that supports output_attentions "
            "(e.g. attn_implementation='eager')."
        )

    # Each entry is [batch=1, n_heads, seq_len, seq_len]; keep the single batch item.
    # Cast to float32 before .numpy(): NumPy has no bfloat16, and float16 values
    # make the later log/entropy computations underflow into NaN.
    attentions = torch.stack([a[0] for a in layer_attns]).float().cpu().numpy()

    if return_tokens:
        tokens = [model.tokenizer.decode([t]) for t in inputs.input_ids[0].tolist()]
        return attentions, tokens

    return attentions


def aggregate_attention(attentions, method='mean'):
    """
    Collapse a stack of attention maps into one [seq_len, seq_len] map.

    Args:
        attentions: array shaped [n_layers, n_heads, seq_len, seq_len].
        method: 'mean' (average over layers and heads), 'max' (elementwise
            maximum over layers and heads) or 'last_layer' (final layer,
            averaged over its heads).

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    reducers = (
        ('mean', lambda arr: arr.mean(axis=(0, 1))),
        ('max', lambda arr: arr.max(axis=(0, 1))),
        ('last_layer', lambda arr: arr[-1].mean(axis=0)),
    )
    for label, reduce_fn in reducers:
        if method == label:
            return reduce_fn(attentions)
    raise ValueError(f"Unknown method: {method}")

## 2. Compare Attention Patterns Across Prompts

In [None]:
# Every variant wraps the identical question, so attention differences can be
# attributed to the surrounding instruction text alone.
BASE_QUESTION = "What is 15 plus 27?"

PROMPT_VARIANTS = {
    "plain": BASE_QUESTION,
    "cot": "Let's think step by step.\n\n" + BASE_QUESTION,
    "expert": "You are an expert mathematician.\n\n" + BASE_QUESTION,
    "structured": "Question: " + BASE_QUESTION + "\nAnswer:",
}

In [None]:
# Run each prompt variant through the model once and cache its attention
# maps together with the decoded tokens and the prompt text.
attention_data = {}

for variant_name, variant_prompt in PROMPT_VARIANTS.items():
    print(f"\nExtracting attention for: {variant_name}")
    variant_attn, variant_tokens = get_attention_patterns(model, variant_prompt)

    attention_data[variant_name] = dict(
        attentions=variant_attn,
        tokens=variant_tokens,
        prompt=variant_prompt,
    )

    print(f"  Shape: {variant_attn.shape}")
    print(f"  Tokens: {variant_tokens[:10]}...")

In [None]:
# Visualize last-layer, head-averaged attention for each prompt variant.
import os
os.makedirs('../results', exist_ok=True)

# Size the grid to the number of variants (2 columns) rather than assuming
# exactly four: extra variants would raise IndexError, fewer would leave
# empty panels.
n_panels = len(attention_data)
n_cols = 2
n_rows = max(1, (n_panels + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 7 * n_rows), squeeze=False)
axes = axes.flatten()

MAX_HEATMAP_TOKENS = 20  # truncate so tick labels stay legible

for ax, (name, data) in zip(axes, attention_data.items()):
    # Aggregate attention
    agg_attn = aggregate_attention(data["attentions"], method='last_layer')

    tokens = data["tokens"]
    max_tokens = min(MAX_HEATMAP_TOKENS, len(tokens))
    labels = [t[:8] for t in tokens[:max_tokens]]

    sns.heatmap(
        agg_attn[:max_tokens, :max_tokens],
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
        cmap='Blues',
        cbar_kws={'shrink': 0.5}
    )
    ax.set_title(f'{name}\n(Last Layer, Mean Heads)')
    ax.tick_params(axis='x', rotation=45)
    ax.tick_params(axis='y', rotation=0)

# Hide panels left over when the variant count is odd.
for ax in axes[n_panels:]:
    ax.axis('off')

plt.tight_layout()
plt.savefig('../results/exp8_attention_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Attention to Key Tokens

In [None]:
def compute_attention_to_keywords(attentions, tokens, keywords, layer=-1):
    """
    Compute how much attention the last token pays to specific keywords.

    A keyword matches either
      * a single token containing it as a case-insensitive substring, or
      * a run of two or more consecutive tokens whose whitespace-stripped,
        lower-cased concatenation equals it. This is needed because some
        tokenizers split numbers into single digits, so "15" is never
        contained in one token.

    Args:
        attentions: array indexed [layer, head, query_pos, key_pos]
            (as returned by get_attention_patterns).
        tokens: decoded token strings, one per key position.
        keywords: strings to look for; empty strings are ignored.
        layer: which layer to read (default: the last one).

    Returns:
        dict mapping "kw ('matched text' @start)" to the head-averaged
        attention from the final query position to the matched token(s),
        summed over the span, in sequence order. The start position is part
        of the key so repeated occurrences are not overwritten. Returns an
        empty dict when nothing matches.
    """
    lowered = [tok.lower() for tok in tokens]
    # Whitespace-free pieces used to reassemble words split across tokens.
    pieces = [tok.strip() for tok in lowered]

    matches = []  # (start, end_exclusive, keyword)
    for kw in keywords:
        needle = kw.lower()
        if not needle:
            continue  # "" is a substring of every token
        for start in range(len(tokens)):
            if needle in lowered[start]:
                matches.append((start, start + 1, kw))
                continue
            if not pieces[start]:
                continue  # a span must not begin on a whitespace-only token
            joined, end = "", start
            while end < len(pieces) and len(joined) < len(needle):
                joined += pieces[end]
                end += 1
            if end - start > 1 and joined == needle:
                matches.append((start, end, kw))

    if not matches:
        return {}

    # Head-averaged attention from the final query position: [seq_len]
    from_last = np.asarray(attentions[layer], dtype=np.float64)[:, -1, :].mean(axis=0)

    # Sequence order; the stable sort keeps keyword order within a position.
    matches.sort(key=lambda m: m[0])
    results = {}
    for start, end, kw in matches:
        text = "".join(tokens[start:end])
        results[f"{kw} ('{text}' @{start})"] = float(from_last[start:end].sum())
    return results

In [None]:
# For every prompt variant, rank the keywords by how much attention the
# final prompt position pays to them.
KEYWORDS = ["15", "27", "plus", "step", "expert", "math", "Question", "Answer"]

print("=== Attention to Key Tokens (from last position) ===")

for variant, record in attention_data.items():
    print(f"\n{variant}:")

    scores = compute_attention_to_keywords(record["attentions"], record["tokens"], KEYWORDS)
    if not scores:
        print("  No matching keywords found")
        continue

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    for label, value in ranked:
        print(f"  {label:30s}: {value:.4f}")

## 4. Layer-wise Analysis

In [None]:
def compute_layer_entropy(attentions):
    """
    Compute attention entropy per layer.

    For each layer, the attention row of the final query position is
    averaged over heads, and its Shannon entropy (natural log, nats) is
    taken. Higher entropy = more distributed attention.

    Args:
        attentions: array indexed [layer, head, query_pos, key_pos].

    Returns:
        list of floats, one entropy per layer (empty if there are no layers).
    """
    # [n_layers, seq_len]: last query row, averaged over heads. Work in float64,
    # because in float16 the old `p + 1e-10` epsilon rounds to 0 and turns
    # 0 * log(0) into NaN.
    probs = np.asarray(attentions[:, :, -1, :], dtype=np.float64).mean(axis=1)

    # Use the convention 0 * log(0) = 0 by only taking the log of positive entries.
    log_probs = np.log(probs, out=np.zeros_like(probs), where=probs > 0)
    return (-(probs * log_probs).sum(axis=-1)).tolist()

In [None]:
# Plot per-layer attention entropy, one line per prompt variant.
fig, ax = plt.subplots(figsize=(12, 6))

for variant, record in attention_data.items():
    layer_entropy = compute_layer_entropy(record["attentions"])
    layer_idx = range(len(layer_entropy))
    ax.plot(layer_idx, layer_entropy, 'o-', label=variant, linewidth=2, markersize=6)

ax.set(
    xlabel='Layer',
    ylabel='Attention Entropy',
    title='Attention Distribution Across Layers\n(Higher = More Distributed)',
)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/exp8_layer_entropy.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Few-Shot Attention Analysis

In [None]:
# Same target question with 0, 1 or 2 worked examples in front of it.
# NOTE(review): the 0-shot prompt has no "Q:/A:" framing, so it differs from
# the few-shot prompts in format as well as in example count.
FEWSHOT_PROMPTS = {
    "0-shot": "What is 15 plus 27?",
    "1-shot": (
        "Q: What is 3 plus 4?\nA: 7\n\n"
        "Q: What is 15 plus 27?\nA:"
    ),
    "2-shot": (
        "Q: What is 3 plus 4?\nA: 7\n\n"
        "Q: What is 10 plus 5?\nA: 15\n\n"
        "Q: What is 15 plus 27?\nA:"
    ),
}

In [None]:
# Cache attention maps and tokens for each few-shot variant.
fewshot_attention = {}

for shot_name, shot_prompt in FEWSHOT_PROMPTS.items():
    print(f"Processing {shot_name}...")
    shot_attn, shot_tokens = get_attention_patterns(model, shot_prompt)
    fewshot_attention[shot_name] = dict(
        attentions=shot_attn,
        tokens=shot_tokens,
        prompt=shot_prompt,
    )

In [None]:
# Analyze: Does the model attend to previous answers in few-shot?
print("=== Attention to Example Answers in Few-Shot ===")


def _final_query_start(token_strs):
    """Return the index of the last token that opens a "Q:" line, or None.

    Everything before that index belongs to the in-context examples. The
    final question itself contains "15" and the "7" of "27", so it must be
    excluded or it would be counted as attention to example answers.
    Assumes the tokenizer emits "Q" (or "Q:") as its own token — TODO confirm
    if the model/tokenizer changes.
    """
    for idx in range(len(token_strs) - 1, -1, -1):
        if token_strs[idx].strip() in ("Q", "Q:"):
            return idx
    return None


# Answer values used in the in-context examples.
answer_keywords = ["7", "15"]

for name, data in fewshot_attention.items():
    print(f"\n{name}:")
    tokens = data["tokens"]
    attentions = data["attentions"]

    # Only search the example region; the 0-shot prompt has none.
    cutoff = _final_query_start(tokens)
    if not cutoff:
        print("  No example answers found")
        continue

    # Slice only the key axis: the query axis still ends at the final
    # position, so "attention from the last token" is unchanged.
    attn_to_answers = compute_attention_to_keywords(
        attentions[..., :cutoff], tokens[:cutoff], answer_keywords
    )

    if attn_to_answers:
        total_answer_attn = sum(attn_to_answers.values())
        print(f"  Total attention to example answers: {total_answer_attn:.4f}")
        for kw, attn in attn_to_answers.items():
            print(f"    {kw}: {attn:.4f}")
    else:
        print("  No example answers found")

## 6. Head Specialization

In [None]:
def analyze_head_specialization(attentions, tokens):
    """
    Analyze what different attention heads focus on.

    For every (layer, head) pair, summarizes that head's attention row for
    the final query position.

    Args:
        attentions: array indexed [layer, head, query_pos, key_pos].
        tokens: decoded token strings, one per key position.

    Returns:
        DataFrame with one row per (layer, head) and these columns:
          * layer, head
          * entropy (nats)
          * max_attention
          * focus_position (argmax key position; the first one on ties)
          * focus_token ("[UNK]" if that position has no entry in ``tokens``)
        The columns are present even when there are no rows.
    """
    columns = ["layer", "head", "entropy", "max_attention", "focus_position", "focus_token"]

    # [n_layers, n_heads, seq_len]. Use float64 so half-precision input cannot
    # turn 0 * log(0) into NaN entropies.
    probs = np.asarray(attentions[:, :, -1, :], dtype=np.float64)
    n_layers, n_heads = probs.shape[:2]

    log_probs = np.log(probs, out=np.zeros_like(probs), where=probs > 0)
    entropies = -(probs * log_probs).sum(axis=-1)
    max_attns = probs.max(axis=-1)
    max_positions = probs.argmax(axis=-1)

    head_stats = []
    for layer in range(n_layers):
        for head in range(n_heads):
            pos = int(max_positions[layer, head])
            head_stats.append({
                "layer": layer,
                "head": head,
                "entropy": float(entropies[layer, head]),
                "max_attention": float(max_attns[layer, head]),
                "focus_position": pos,
                "focus_token": tokens[pos] if pos < len(tokens) else "[UNK]",
            })

    return pd.DataFrame(head_stats, columns=columns)

In [None]:
# Head-level summary for the chain-of-thought prompt.
cot_data = attention_data["cot"]
head_df = analyze_head_specialization(cot_data["attentions"], cot_data["tokens"])

print("=== Head Specialization Analysis (CoT prompt) ===")

# Lowest entropy: attention concentrated on few positions.
focused_heads = head_df.nsmallest(5, 'entropy')
print("\nMost focused heads (lowest entropy):")
for rec in focused_heads.to_dict('records'):
    print(
        f"  Layer {rec['layer']}, Head {rec['head']}: "
        f"entropy={rec['entropy']:.3f}, focuses on '{rec['focus_token']}'"
    )

# Highest entropy: attention spread across many positions.
distributed_heads = head_df.nlargest(5, 'entropy')
print("\nMost distributed heads (highest entropy):")
for rec in distributed_heads.to_dict('records'):
    print(f"  Layer {rec['layer']}, Head {rec['head']}: entropy={rec['entropy']:.3f}")

## 7. Key Findings

In [None]:
banner = "=" * 60
print(banner)
print("EXPERIMENT 8 SUMMARY: Attention Pattern Analysis")
print(banner)

print("\n1. Attention Pattern Differences:")
for variant, record in attention_data.items():
    final_layer_entropy = compute_layer_entropy(record["attentions"])[-1]
    print(f"   {variant}: last layer entropy = {final_layer_entropy:.3f}")

# Placeholders to be filled in by hand once the notebook has been run.
for line in (
    "\n2. Few-Shot Attention Patterns:",
    "   [Fill after running: Does model attend to example answers?]",
    "\n3. Key Insights:",
    "   - [Fill after running: How do prompts change attention?]",
    "   - [Fill after running: Which layers are most affected?]",
):
    print(line)

In [None]:
# Save results
import json
import os


def _to_json_native(obj):
    """json.dump fallback: unwrap NumPy scalars/arrays to native Python values.

    Replaces ``default=float``, which would silently turn NumPy integers
    (e.g. layer/head indices) into floats such as 3.0.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


save_data = {
    "layer_entropies": {
        name: compute_layer_entropy(data["attentions"])
        for name, data in attention_data.items()
    },
    "head_specialization_summary": {
        "most_focused": focused_heads[['layer', 'head', 'entropy', 'focus_token']].to_dict('records'),
        "most_distributed": distributed_heads[['layer', 'head', 'entropy']].to_dict('records')
    }
}

# Don't rely on an earlier cell having created the output directory.
os.makedirs('../results', exist_ok=True)
with open('../results/exp8_attention_results.json', 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2, default=_to_json_native)

print("Results saved.")