# Mechanistic Interpretability: Attention Head Analysis

This notebook analyzes the attention patterns of the Qwen2-7B-Instruct model while it performs financial statement analysis.

In [1]:
import json
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Tuple, Optional
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Global plot styling: apply matplotlib's seaborn-v0_8 theme first, then the
# husl palette, which must come second so it overrides the theme's color cycle.
plt.style.use(style='seaborn-v0_8')
sns.set_palette(palette="husl")

  from .autonotebook import tqdm as notebook_tqdm


## 1. Model Setup with Attention Hooks

In [None]:
class AttentionExtractor:
    """Capture per-layer attention weights for a prompt and its generated continuation.

    Expects a Hugging Face causal LM whose decoder layers are reachable as
    ``model.model.layers``, each with a ``self_attn`` submodule (as in Qwen2).
    Attention probabilities are only materialized by the eager attention
    implementation, so if the captured dict comes back empty, load the model
    with ``attn_implementation="eager"``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # 'layer_{i}' -> attention tensor, moved to CPU
        self.attention_weights = {}
        # Handles returned by register_forward_hook, kept so the hooks can be removed
        self.hooks = []

    def register_hooks(self):
        """Register a forward hook on every layer's ``self_attn`` module.

        Each hook stores the attention weights from the most recent call under
        ``'layer_{i}'`` and overwrites any earlier capture for that layer.
        Hooks that are already registered are removed first, so repeated calls
        do not stack duplicate hooks.
        """
        self.remove_hooks()

        def get_attention_hook(layer_idx):
            def hook(module, input, output):
                weights = None
                if hasattr(output, 'attentions') and output.attentions is not None:
                    weights = output.attentions
                elif isinstance(output, (tuple, list)) and len(output) > 1:
                    # Attention modules typically return (attn_output, attn_weights, ...)
                    weights = output[1]
                if weights is not None:
                    self.attention_weights[f'layer_{layer_idx}'] = weights.detach().cpu()
            return hook

        for i, layer in enumerate(self.model.model.layers):
            handle = layer.self_attn.register_forward_hook(get_attention_hook(i))
            self.hooks.append(handle)

    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def clear_attention_weights(self):
        """Clear stored attention weights."""
        self.attention_weights = {}

    def generate_with_attention(self, text: str, max_new_tokens: int = 512, **generation_kwargs):
        """Generate a continuation of ``text`` and capture full attention matrices.

        ``generate`` decodes with a KV cache, so after the first step each
        forward call only computes attention for the newest token. Hooking
        generation directly would leave every layer holding a single query row
        from the last step. Instead, once generation finishes, the prompt plus
        the generated tokens are run through one extra forward pass with
        ``output_attentions=True``. That pass yields, for every layer, a square
        ``[batch, heads, seq, seq]`` matrix covering all positions. It costs one
        more full forward pass, and the stored matrices grow with seq**2 per
        head and layer.

        Args:
            text: prompt text (already chat-formatted if required).
            max_new_tokens: maximum number of tokens to generate.
            **generation_kwargs: extra keyword arguments forwarded to
                ``model.generate``; they may override ``attention_mask``.

        Returns:
            dict with 'generated_text', 'input_tokens', 'output_tokens',
            'all_tokens', 'attention_weights' (layer name -> CPU tensor) and
            'input_length' (number of prompt tokens).
        """
        self.clear_attention_weights()

        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        input_length = inputs.input_ids.shape[1]

        # Pass the prompt's attention mask explicitly; caller kwargs take precedence.
        gen_kwargs = {"attention_mask": inputs.get("attention_mask"), **generation_kwargs}
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                **gen_kwargs
            )
        sequences = outputs.sequences

        # One full (non-cached) pass so every layer sees every query position.
        self.register_hooks()
        try:
            with torch.no_grad():
                forward_outputs = self.model(sequences, output_attentions=True, use_cache=False)
        finally:
            self.remove_hooks()

        # Prefer the per-layer attentions the model returns; the hook captures
        # from the same pass remain as a fallback.
        model_attentions = getattr(forward_outputs, 'attentions', None)
        if model_attentions is not None:
            for layer_idx, layer_attn in enumerate(model_attentions):
                if layer_attn is not None:
                    self.attention_weights[f'layer_{layer_idx}'] = layer_attn.detach().cpu()
        del forward_outputs  # release logits and other outputs of the extra pass

        generated_text = self.tokenizer.decode(
            sequences[0][input_length:],
            skip_special_tokens=True
        )
        all_tokens = self.tokenizer.convert_ids_to_tokens(sequences[0])

        return {
            'generated_text': generated_text,
            'input_tokens': all_tokens[:input_length],
            'output_tokens': all_tokens[input_length:],
            'all_tokens': all_tokens,
            'attention_weights': self.attention_weights.copy(),
            'input_length': input_length
        }

In [None]:
# Load model with reduced memory usage
model_name = "Qwen/Qwen2-7B-Instruct"

# Use CPU or smaller model if MPS memory is insufficient
device = "cpu"  # Change to "mps" if you have enough memory

print(f"Loading model on {device}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device != "cpu" else torch.float32,
    device_map="auto" if device != "cpu" else None,
    # Fused SDPA / FlashAttention kernels do not return attention
    # probabilities. Request the eager implementation so that
    # `output_attentions=True` actually yields per-head weights to analyze.
    attn_implementation="eager",
)

if device == "cpu":
    model = model.to(device)

# Initialize attention extractor
attention_extractor = AttentionExtractor(model, tokenizer)
print("Model loaded successfully!")

## 2. Load Financial Data

In [None]:
# Read the processed financial statements and the prompt templates
with open('../data/process_data.json', 'r') as data_file:
    data = json.load(data_file)

with open("../config/prompts.yaml", "r") as prompts_file:
    prompts = yaml.safe_load(prompts_file)

# Pick one company / fiscal-year pair to study
sample_key, sample_year = '1000', '1973.0'
sample_data = data[sample_key][sample_year]

print(f"Sample company: {sample_key}, Year: {sample_year}")
print(f"Label: {sample_data['label']}")

## 3. Generate Response with Attention Capture

In [None]:
# Fill the first prompt_1 template with this sample's statement text
prompt_input = prompts['prompt_1'][0].format(
    balance_income_sheet=sample_data['description'],
)

# Render it as a single user turn with the model's chat template, leaving the
# assistant turn open so generation continues from it
messages = [dict(role="user", content=prompt_input)]
formatted_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

print("Input length:", len(tokenizer.encode(formatted_text)))
print("\nGenerating response with attention capture...")

In [None]:
# Generate with attention tracking (use shorter generation for analysis).
# Greedy decoding (do_sample=False) is deterministic, so no `temperature` is
# passed: transformers ignores sampling knobs under greedy decoding and only
# emits a warning about them.
result = attention_extractor.generate_with_attention(
    formatted_text,
    max_new_tokens=100,  # Shorter for faster analysis
    do_sample=False,
)

print("Generated text:")
print(result['generated_text'])
print(f"\nCaptured attention from {len(result['attention_weights'])} layers")
print(f"Input tokens: {len(result['input_tokens'])}")
print(f"Output tokens: {len(result['output_tokens'])}")

## 4. Attention Analysis Tools

In [None]:
class AttentionAnalyzer:
    """Statistics and plots over attention weights captured for one sequence.

    Args:
        attention_weights: mapping ``'layer_{i}'`` -> tensor. The methods index
            it as ``[batch, heads, query_pos, key_pos]`` and always use batch 0.
        tokens: token strings for every position (prompt tokens followed by
            generated tokens).
        input_length: number of prompt tokens; positions at or after this
            index are generated tokens.
    """

    def __init__(self, attention_weights, tokens, input_length):
        self.attention_weights = attention_weights
        self.tokens = tokens
        self.input_length = input_length
        self.num_layers = len(attention_weights)

    @staticmethod
    def _head_matrix(attn, head_idx):
        """Return one head's ``[query, key]`` matrix (batch 0) as a float32 NumPy array.

        Upcasting before ``.numpy()`` avoids the conversion error for bfloat16
        and keeps downstream arithmetic out of float16.
        """
        return attn[0, head_idx].float().numpy()

    def get_attention_stats(self):
        """Get per-layer shape, mean/max weight and mean row entropy.

        Returns:
            dict mapping layer name -> stats dict. Layers whose tensor is
            missing or has fewer than 3 dimensions are skipped.
        """
        stats = {}

        for layer_name, attn in self.attention_weights.items():
            if attn is None or len(attn.shape) < 3:
                continue
            # Reduce in float32 so float16 weights neither lose precision nor produce NaN
            attn32 = attn.float()
            stats[layer_name] = {
                'shape': list(attn.shape),
                'num_heads': attn.shape[1],
                'seq_length': attn.shape[-1],
                'mean_attention': float(attn32.mean()),
                'max_attention': float(attn32.max()),
                'attention_entropy': self._calculate_entropy(attn32),
            }

        return stats

    def _calculate_entropy(self, attention_matrix):
        """Mean Shannon entropy (in nats) of the distributions along the last axis.

        ``torch.special.entr`` computes ``-p * ln(p)`` with ``entr(0) == 0``, so
        zero (masked) weights contribute nothing and no epsilon is needed. The
        input is upcast to float32: an epsilon of 1e-10 rounds to 0 in float16,
        which turned masked entries into ``0 * log(0) = NaN``.
        """
        probs = attention_matrix.float()
        return float(torch.special.entr(probs).sum(dim=-1).mean())

    def plot_attention_heatmap(self, layer_idx=0, head_idx=0, max_tokens=50):
        """Plot one head's attention heatmap over the first ``max_tokens`` positions."""
        layer_name = f'layer_{layer_idx}'

        if layer_name not in self.attention_weights:
            print(f"Layer {layer_idx} not found in attention weights")
            return

        attn = self.attention_weights[layer_name]
        if attn is None or len(attn.shape) < 4:
            print(f"Invalid attention tensor for layer {layer_idx}")
            return

        attention_matrix = self._head_matrix(attn, head_idx)  # [query, key]

        # Keep the plot readable by showing only the leading positions
        seq_len = min(attention_matrix.shape[0], max_tokens)
        attention_matrix = attention_matrix[:seq_len, :seq_len]
        labels = [f"{i}:{token}" for i, token in enumerate(self.tokens[:seq_len])]

        plt.figure(figsize=(12, 10))
        sns.heatmap(
            attention_matrix,
            xticklabels=labels,
            yticklabels=labels,
            cmap='Blues',
            cbar=True
        )
        plt.title(f'Attention Heatmap - Layer {layer_idx}, Head {head_idx}')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

    def plot_attention_heads_summary(self, layer_idx=0):
        """Plot a grid with every head of a layer, each showing the last (up to) 30 positions."""
        layer_name = f'layer_{layer_idx}'

        if layer_name not in self.attention_weights:
            print(f"Layer {layer_idx} not found")
            return

        attn = self.attention_weights[layer_name]
        if attn is None or len(attn.shape) < 4:
            print(f"Invalid attention tensor for layer {layer_idx}")
            return

        num_heads = attn.shape[1]
        if num_heads == 0:
            print(f"No attention heads in layer {layer_idx}")
            return
        cols = min(4, num_heads)
        rows = (num_heads + cols - 1) // cols

        # squeeze=False always returns a 2-D array of Axes, so one ravel() covers every grid shape
        fig, axes_grid = plt.subplots(rows, cols, figsize=(15, 4 * rows), squeeze=False)
        axes = axes_grid.ravel()

        for head_idx in range(num_heads):
            attention_matrix = self._head_matrix(attn, head_idx)

            # Focus on the last few positions (where generation happens)
            seq_len = min(30, attention_matrix.shape[0])
            start_idx = max(0, attention_matrix.shape[0] - seq_len)
            sub_attn = attention_matrix[start_idx:, start_idx:]

            ax = axes[head_idx]
            sns.heatmap(sub_attn, ax=ax, cmap='Blues', cbar=True)
            ax.set_title(f'Head {head_idx}')
            # Tick labels count from start_idx, so state the offset on the axes
            ax.set_xlabel(f'Key Position (+{start_idx})')
            ax.set_ylabel(f'Query Position (+{start_idx})')

        # Hide unused grid cells
        for ax in axes[num_heads:]:
            ax.set_visible(False)

        plt.suptitle(f'All Attention Heads - Layer {layer_idx}')
        plt.tight_layout()
        plt.show()

    def analyze_financial_attention(self):
        """Find token positions whose text contains a financial keyword.

        Matching is a case-insensitive substring test over all tokens (prompt
        and generated), so it can over-match: for example 'ratio' occurs inside
        'operation' and 'eps' inside 'steps'.

        Returns:
            list of (position, token) tuples in sequence order.
        """
        financial_terms = [
            'sales', 'revenue', 'income', 'profit', 'loss', 'asset', 'liability',
            'equity', 'cash', 'debt', 'eps', 'earnings', 'balance', 'sheet',
            'statement', 'financial', 'increase', 'decrease', 'ratio'
        ]

        financial_positions = [
            (i, token)
            for i, token in enumerate(self.tokens)
            if any(term in token.lower() for term in financial_terms)
        ]

        print(f"Found {len(financial_positions)} financial terms:")
        for pos, token in financial_positions[:10]:  # Show first 10
            print(f"  Position {pos}: {token}")

        return financial_positions

    def _generated_to_input_attention(self, attention_matrix, financial_positions, max_output_tokens=20):
        """Slice attention from generated-token rows to prompt financial-term columns.

        Only prompt-side terms (position < input_length) are kept, matching the
        "Financial Terms in Input" axis. At most ``max_output_tokens`` generated
        rows are kept, so every row gets a label.

        Returns:
            (sub_matrix, term_positions, row_positions). ``sub_matrix`` is None
            when there are no rows or no columns to show.
        """
        term_positions = [
            pos for pos, _ in financial_positions
            if pos < self.input_length and pos < attention_matrix.shape[1]
        ]
        row_end = min(
            attention_matrix.shape[0],
            len(self.tokens),
            self.input_length + max_output_tokens,
        )
        row_positions = list(range(self.input_length, row_end))
        if not term_positions or not row_positions:
            return None, term_positions, row_positions
        sub_matrix = attention_matrix[self.input_length:row_end, term_positions]
        return sub_matrix, term_positions, row_positions

    def plot_attention_to_financial_terms(self, financial_positions, layer_idx=0, head_idx=0,
                                          max_output_tokens=20):
        """Plot attention from generated tokens to financial terms in the prompt.

        Args:
            financial_positions: (position, token) pairs, e.g. from
                ``analyze_financial_attention``; positions in the generated
                part are ignored.
            layer_idx: layer to plot.
            head_idx: head to plot.
            max_output_tokens: number of leading generated tokens shown as rows.
        """
        layer_name = f'layer_{layer_idx}'

        if layer_name not in self.attention_weights or not financial_positions:
            print("No data available for analysis")
            return

        attn = self.attention_weights[layer_name]
        if attn is None or len(attn.shape) < 4:
            print(f"Invalid attention tensor for layer {layer_idx}")
            return

        attention_matrix = self._head_matrix(attn, head_idx)
        output_to_financial, term_positions, row_positions = self._generated_to_input_attention(
            attention_matrix, financial_positions, max_output_tokens
        )
        if output_to_financial is None:
            print("No generated tokens or prompt financial terms to plot")
            return

        plt.figure(figsize=(12, 6))
        sns.heatmap(
            output_to_financial,
            xticklabels=[f"{pos}:{self.tokens[pos]}" for pos in term_positions],
            yticklabels=[f"{pos}:{self.tokens[pos]}" for pos in row_positions],
            cmap='Reds'
        )
        plt.title(f'Generated Tokens Attention to Financial Terms\nLayer {layer_idx}, Head {head_idx}')
        plt.xlabel('Financial Terms in Input')
        plt.ylabel('Generated Tokens')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

# Bundle the captured weights, the full token list and the prompt length
analyzer = AttentionAnalyzer(
    attention_weights=result['attention_weights'],
    tokens=result['all_tokens'],
    input_length=result['input_length'],
)

print("Attention analyzer initialized!")

## 5. Basic Attention Statistics

In [None]:
# Per-layer summary of the captured attention tensors
stats = analyzer.get_attention_stats()

print("Attention Statistics:")
print("=" * 50)
for layer_name, layer_stats in stats.items():
    print(
        f"\n{layer_name}:",
        f"  Shape: {layer_stats['shape']}",
        f"  Heads: {layer_stats['num_heads']}",
        f"  Sequence Length: {layer_stats['seq_length']}",
        f"  Mean Attention: {layer_stats['mean_attention']:.4f}",
        f"  Max Attention: {layer_stats['max_attention']:.4f}",
        f"  Attention Entropy: {layer_stats['attention_entropy']:.4f}",
        sep="\n",
    )

## 6. Visualize Attention Patterns

In [None]:
# Heatmap of the first layer's first head over the leading 40 tokens
analyzer.plot_attention_heatmap(
    layer_idx=0,
    head_idx=0,
    max_tokens=40,
)

In [None]:
# Same view for the layer halfway through the network
middle_layer = analyzer.num_layers // 2
analyzer.plot_attention_heatmap(
    layer_idx=middle_layer,
    head_idx=0,
    max_tokens=40,
)

In [None]:
# Grid with every head of the first layer
analyzer.plot_attention_heads_summary(
    layer_idx=0,
)

## 7. Financial-Specific Attention Analysis

In [None]:
# Locate tokens containing financial keywords (kept for the plots below)
financial_positions = analyzer.analyze_financial_attention(
)

In [None]:
# Generated-token -> financial-term attention in the first layer
analyzer.plot_attention_to_financial_terms(
    financial_positions,
    layer_idx=0,
    head_idx=0,
)

In [None]:
# Same comparison for the middle layer
analyzer.plot_attention_to_financial_terms(
    financial_positions,
    layer_idx=middle_layer,
    head_idx=0,
)

## 8. Advanced Analysis: Head Specialization

In [None]:
def analyze_head_specialization(analyzer, layer_idx=0, max_heads=8, prefix_len=10):
    """Print pattern statistics and a heuristic role for the leading heads of one layer.

    For each head the output shows:
      - the mean diagonal (self) weight;
      - the mean weight over strictly earlier positions;
      - the mean weight on the first ``prefix_len`` key positions;
      - the mean row entropy.
    A role label is then derived from these numbers using fixed thresholds.

    Args:
        analyzer: object with an ``attention_weights`` dict mapping
            ``'layer_{i}'`` -> tensor indexed as [batch, heads, query, key]
            (e.g. an AttentionAnalyzer); batch 0 is used.
        layer_idx: layer to inspect.
        max_heads: how many leading heads to report (default 8).
        prefix_len: number of leading key positions treated as the
            "beginning of sequence" (default 10).
    """
    layer_name = f'layer_{layer_idx}'

    if layer_name not in analyzer.attention_weights:
        print(f"Layer {layer_idx} not found")
        return

    attn = analyzer.attention_weights[layer_name]
    if attn is None or len(attn.shape) < 4:
        print(f"Invalid attention tensor for layer {layer_idx}")
        return

    num_heads = attn.shape[1]

    print(f"\nHead Specialization Analysis - Layer {layer_idx}")
    print("=" * 50)

    for head_idx in range(min(max_heads, num_heads)):
        # Upcast to float32. In float16 the 1e-10 epsilon below rounds to 0, so
        # masked entries became 0 * log(0) = NaN; bfloat16 cannot be converted
        # to NumPy at all.
        attention_matrix = attn[0, head_idx].float().numpy()

        # Mean weight each query puts on its own position
        diagonal_attention = np.diag(attention_matrix).mean()

        # Mean weight over strictly earlier positions (below the diagonal)
        causal_mask = np.tril(np.ones_like(attention_matrix), k=-1)
        num_previous = causal_mask.sum()
        # A single-position sequence has no earlier positions; avoid 0/0
        causal_attention = (attention_matrix * causal_mask).sum() / num_previous if num_previous else 0.0

        # Mean weight on the first `prefix_len` key positions
        beginning_attention = attention_matrix[:, :prefix_len].mean()

        # Mean Shannon entropy (nats) of each query's distribution
        entropy = -(attention_matrix * np.log(attention_matrix + 1e-10)).sum(axis=1).mean()

        print(f"\nHead {head_idx}:")
        print(f"  Diagonal (self) attention: {diagonal_attention:.4f}")
        print(f"  Causal (previous) attention: {causal_attention:.4f}")
        print(f"  Beginning attention: {beginning_attention:.4f}")
        print(f"  Attention entropy: {entropy:.4f}")

        # Heuristic role, checked in priority order
        if diagonal_attention > 0.3:
            head_type = "Self-attention head"
        elif beginning_attention > causal_attention:
            head_type = "Beginning-focused head"
        elif entropy < 2.0:
            head_type = "Focused attention head"
        else:
            head_type = "Distributed attention head"

        print(f"  Type: {head_type}")

# Compare head roles in the first, middle and last layers
for layer_to_inspect in (0, middle_layer, analyzer.num_layers - 1):
    analyze_head_specialization(analyzer, layer_idx=layer_to_inspect)

## 9. Token-to-Token Attention Analysis

In [None]:
def analyze_specific_token_attention(analyzer, target_token_text, layer_idx=0, head_idx=0):
    """Print the ten positions that tokens matching ``target_token_text`` attend to most.

    A token matches when its text contains ``target_token_text``
    (case-insensitive). Only the first three matches are analysed, using batch
    0 of the chosen layer/head in ``analyzer.attention_weights``.
    """
    needle = target_token_text.lower()
    target_positions = [i for i, tok in enumerate(analyzer.tokens) if needle in tok.lower()]

    if not target_positions:
        print(f"Token '{target_token_text}' not found")
        return

    layer_name = f'layer_{layer_idx}'
    if layer_name not in analyzer.attention_weights:
        print(f"Layer {layer_idx} not found")
        return

    attn = analyzer.attention_weights[layer_name]
    if attn is None:
        return

    attention_matrix = attn[0, head_idx].numpy()
    n_rows = attention_matrix.shape[0]
    n_tokens = len(analyzer.tokens)

    print(f"\nAttention analysis for '{target_token_text}'")
    print(f"Found at positions: {target_positions}")

    for pos in (p for p in target_positions[:3] if p < n_rows):
        row = attention_matrix[pos]
        # Indices of the ten largest weights, strongest first
        ranked = np.argsort(row)[-10:][::-1]

        print(f"\nPosition {pos} ('{analyzer.tokens[pos]}') attends most to:")
        for rank, idx in enumerate(ranked, start=1):
            if idx < n_tokens:
                print(f"  {rank}. Position {idx}: '{analyzer.tokens[idx]}' (weight: {row[idx]:.4f})")

# Inspect what two key financial words attend to
for term, term_layer in (("sales", 0), ("increase", middle_layer)):
    analyze_specific_token_attention(analyzer, term, layer_idx=term_layer, head_idx=0)

## 10. Summary and Next Steps

In [None]:
# Recap of what was analysed, followed by suggested follow-up work
print(
    "Mechanistic Interpretability Analysis Summary",
    "=" * 50,
    f"Model: {model_name}",
    f"Analyzed layers: {len(result['attention_weights'])}",
    f"Input tokens: {len(result['input_tokens'])}",
    f"Generated tokens: {len(result['output_tokens'])}",
    f"Financial terms found: {len(financial_positions)}",
    sep="\n",
)

print(
    "\nNext Steps for Deeper Analysis:",
    "1. Compare attention patterns across different financial scenarios",
    "2. Analyze how attention changes for 'increase' vs 'decrease' predictions",
    "3. Study attention patterns in different layers (early vs late)",
    "4. Investigate head specialization across the full model",
    "5. Correlate attention patterns with prediction accuracy",
    "6. Create attention interventions to test causal relationships",
    sep="\n",
)