# Attention Visualization with BertViz

This notebook demonstrates how to visualize attention patterns in our trained arithmetic models using BertViz.

In [1]:
# Install BertViz if not already installed (uncomment the line below to run it).
# `%pip` installs into the environment of the running kernel.
# %pip install bertviz

In [2]:
import sys
from pathlib import Path

import torch
from bertviz import head_view, model_view

# Add project root to path for imports.
# The root is taken to be the parent of the current working directory, so this
# assumes the kernel's working directory sits directly inside the project root
# (e.g. a notebooks/ folder).
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    # Prepend so the `src` imports below can resolve against the project root.
    sys.path.insert(0, str(project_root))

from src.config import load_config
from src.model import create_model_from_config
from src.tokenizer import tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def load_model(checkpoint_dir: str, device: str = "auto"):
    """Load a model from a checkpoint directory.

    Args:
        checkpoint_dir: Path to checkpoint directory, resolved relative to
            ``project_root`` (e.g., 'checkpoints/standard-small-pope')
        device: Device to load model on ('auto', 'cuda', 'cpu'). 'auto' picks
            'cuda' when ``torch.cuda.is_available()`` and 'cpu' otherwise.

    Returns:
        Tuple of (model, device)
        - model: model with the checkpoint weights loaded, moved to ``device``,
          in eval mode
        - device: the resolved device string

    Raises:
        FileNotFoundError: If 'model_config.yaml' or 'model.safetensors' is
            missing from the checkpoint directory.
    """
    # Imported locally so the notebook's import cell does not need to change.
    from safetensors.torch import load_file

    checkpoint_path = project_root / checkpoint_dir
    config_path = checkpoint_path / "model_config.yaml"
    weights_path = checkpoint_path / "model.safetensors"

    # Fail early with a readable message instead of an error from deep inside
    # the config or weights loaders.
    for required_file in (config_path, weights_path):
        if not required_file.is_file():
            raise FileNotFoundError(f"Checkpoint file not found: {required_file}")

    # Determine device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load config and create model
    config = load_config(config_path)
    model = create_model_from_config(config)

    # Load weights
    state_dict = load_file(str(weights_path))
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    print(f"Loaded model from {checkpoint_dir}")
    print(f"  Architecture: {config.architecture}")
    print(f"  Positional encoding: {config.positional_encoding}")
    print(f"  Softmax variant: {config.softmax_variant}")
    print(f"  Layers: {config.n_layers}, Heads: {config.n_heads}")
    print(f"  Device: {device}")

    return model, device

In [4]:
def get_attention_weights(model, input_text: str, device: str):
    """Run one forward pass over an expression and collect its attention maps.

    Args:
        model: Loaded model
        input_text: Expression to encode (e.g., '12+34=')
        device: Device the model is on

    Returns:
        Tuple of (attention_weights, tokens)
        - attention_weights: the model's "attentions" output, one tensor of
          shape (batch, heads, seq, seq) per layer
        - tokens: list of token strings
    """
    token_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    token_strings = tokenizer.convert_ids_to_tokens(token_ids[0])

    # The ids double as dummy labels so the model returns a dict that
    # includes the attentions.
    with torch.no_grad():
        forward_out = model(token_ids, labels=token_ids, output_attentions=True)

    return forward_out["attentions"], token_strings

In [5]:
def visualize_with_generation(
    model,
    input_text: str,
    device: str,
    max_new_tokens: int = 50,
    temperature: float = 0.1,
):
    """Generate output and collect attention on the full generated sequence.

    Despite its name, this function renders nothing itself: it prints the
    generated text and returns the data to pass to a BertViz view.

    Args:
        model: Loaded model
        input_text: Input expression (e.g., '12+34=')
        device: Device model is on
        max_new_tokens: Maximum tokens to generate
        temperature: Value forwarded to ``model.generate`` as ``temperature``
            (defaults to 0.1, the value that used to be hard-coded)

    Returns:
        Tuple of (attention_weights, tokens, generated_text)
    """
    # First generate the full sequence
    inputs = tokenizer.encode(input_text, return_tensors='pt').to(device)

    with torch.no_grad():
        generated = model.generate(
            inputs, max_new_tokens=max_new_tokens, temperature=temperature
        )

    generated_text = tokenizer.decode(generated[0])
    print(f"Generated: {generated_text}")

    # Now get attention for the full generated sequence. The generated ids are
    # also passed as dummy labels so the model returns a dict with attentions.
    with torch.no_grad():
        outputs = model(generated, labels=generated, output_attentions=True)

    attention = outputs["attentions"]
    tokens = tokenizer.convert_ids_to_tokens(generated[0])

    return attention, tokens, generated_text

## Load a Model

Available checkpoints:
- `checkpoints/standard-small` - Learned positional encoding
- `checkpoints/standard-small-sinusoidal` - Sinusoidal positional encoding  
- `checkpoints/standard-small-rope` - RoPE positional encoding
- `checkpoints/standard-small-pope` - PoPE positional encoding
- `checkpoints/standard-small-softmax1` - Learned + softmax1
- `checkpoints/standard-small-pope-softmax1` - PoPE + softmax1

In [8]:
# Load the PoPE model (best performing)
# load_model prints a summary of the checkpoint and returns (model, device).
model, device = load_model("checkpoints/standard-small-pope")

Loaded model from checkpoints/standard-small-pope
  Architecture: standard
  Positional encoding: pope
  Softmax variant: standard
  Layers: 4, Heads: 4
  Device: cuda


## Visualize Attention on Input Only

This shows attention patterns on just the input expression.

In [9]:
input_expr = "12+34="
attention, tokens = get_attention_weights(model, input_expr, device)

# Summarize what came back before handing it to BertViz
print(
    f"Input: {input_expr}",
    f"Tokens: {tokens}",
    f"Number of layers: {len(attention)}",
    f"Attention shape per layer: {attention[0].shape}",
    sep="\n",
)

Input: 12+34=
Tokens: ['1', '2', '+', '3', '4', '=']
Number of layers: 4
Attention shape per layer: torch.Size([1, 4, 6, 6])


In [10]:
# Head view - detailed view of individual attention heads
# `attention` and `tokens` come from the get_attention_weights call above (input expression only).
head_view(attention, tokens)

<IPython.core.display.Javascript object>

In [11]:
# Model view - bird's eye view across all layers and heads
# Uses the same `attention` / `tokens` as the head view above.
model_view(attention, tokens)

<IPython.core.display.Javascript object>

## Visualize Attention on Generated Output

This shows attention patterns on the full generated sequence including the chain-of-thought reasoning.

In [None]:
input_expr = "99+21="
attention, tokens, generated_text = visualize_with_generation(model, input_expr, device)

# Summarize the attention collected over prompt + generated tokens
print(
    f"\nTokens: {tokens}",
    f"Number of layers: {len(attention)}",
    f"Attention shape per layer: {attention[0].shape}",
    sep="\n",
)

In [None]:
# Head view on full generated sequence
# (`attention` / `tokens` are the values returned by visualize_with_generation above)
head_view(attention, tokens)

In [None]:
# Model view on full generated sequence
# (`attention` / `tokens` are the values returned by visualize_with_generation above)
model_view(attention, tokens)

## Compare Different Models

Load and compare attention patterns from different model variants.

In [None]:
# Compare learned vs RoPE vs PoPE positional encodings
model_configs = [
    "checkpoints/standard-small",
    "checkpoints/standard-small-rope",
    "checkpoints/standard-small-pope",
]

input_expr = "45+67="

# Loop-scoped cmp_* names keep this comparison from silently replacing the
# notebook-wide `model`, `device`, `attention`, `tokens` and `generated_text`
# that the other cells read and display.
for checkpoint_dir in model_configs:
    print(f"\n{'='*60}")
    cmp_model, cmp_device = load_model(checkpoint_dir)
    cmp_attention, cmp_tokens, cmp_generated_text = visualize_with_generation(
        cmp_model, input_expr, cmp_device
    )
    # Path(...).name yields the final path component regardless of separator.
    print(f"\nHead view for {Path(checkpoint_dir).name}:")
    head_view(cmp_attention, cmp_tokens)


Loaded model from checkpoints/standard-small
  Architecture: standard
  Positional encoding: learned
  Softmax variant: standard
  Layers: 4, Heads: 4
  Device: cuda
Generated: 45+67=612<end>

Head view for standard-small:


<IPython.core.display.Javascript object>


Loaded model from checkpoints/standard-small-rope
  Architecture: standard
  Positional encoding: rope
  Softmax variant: standard
  Layers: 4, Heads: 4
  Device: cuda
Generated: 45+67=83<end>

Head view for standard-small-rope:


<IPython.core.display.Javascript object>

## Analyze Softmax1 "Quiet Attention"

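Softmax1 allows attention heads to "abstain": the attention weights for a query can sum to less than 1 instead of exactly 1. The cells below load the PoPE + softmax1 checkpoint, report the per-layer attention weight sums, and show its head view so the patterns can be compared with the standard-softmax models above.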

In [None]:
# Load softmax1 model
model_softmax1, device = load_model("checkpoints/standard-small-pope-softmax1")

input_expr = "99+21="
attention_s1, tokens_s1, generated_s1 = visualize_with_generation(model_softmax1, input_expr, device)

# Report, per layer, how much total weight each query row puts on its keys;
# with softmax1 these totals are expected to fall below 1.
for layer_no, layer_attention in enumerate(attention_s1):
    row_totals = layer_attention[0].sum(dim=-1)  # first batch item, summed over the last (key) axis
    lowest, highest, average = (
        stat.item() for stat in (row_totals.min(), row_totals.max(), row_totals.mean())
    )
    print(f"Layer {layer_no}: attention weight sums - min={lowest:.3f}, max={highest:.3f}, mean={average:.3f}")

In [None]:
# Visualize softmax1 attention
# (`attention_s1` / `tokens_s1` come from the softmax1 checkpoint run above)
head_view(attention_s1, tokens_s1)

## Save Visualizations as HTML

You can save visualizations as standalone HTML files for sharing.

In [None]:
# Generate attention for a sample input
model, device = load_model("checkpoints/standard-small-pope")
input_expr = "123+456="
attention, tokens, generated_text = visualize_with_generation(model, input_expr, device)

# Make sure the target directory exists before writing into it.
output_dir = project_root / "notebooks"
output_dir.mkdir(parents=True, exist_ok=True)

# Save head view as HTML.
# html_action='return' hands back the rendered object instead of displaying it;
# its .data attribute holds the HTML text.
html_head = head_view(attention, tokens, html_action='return')
output_path = output_dir / "attention_head_view.html"
# Write as UTF-8 explicitly so the file does not depend on the platform's
# default text encoding.
output_path.write_text(html_head.data, encoding="utf-8")
print(f"Saved head view to {output_path}")

# Save model view as HTML
html_model = model_view(attention, tokens, html_action='return')
output_path = output_dir / "attention_model_view.html"
output_path.write_text(html_model.data, encoding="utf-8")
print(f"Saved model view to {output_path}")