# Mechanistic Interpretability Setup - Google Colab

This notebook sets up a complete environment for mechanistic interpretability research.

**What this does:**
- Installs PyTorch, TransformerLens, and visualization tools
- Downloads GPT-2 Medium or Llama 3.2 1B
- Provides helper functions for activation extraction
- Sets up experiment tracking with W&B

## 1. Install Dependencies

In [None]:
%%capture
# Install core dependencies
!pip install torch>=2.0.0 --index-url https://download.pytorch.org/whl/cu118
!pip install transformer-lens circuitsvis transformers datasets
!pip install plotly altair wandb einops
!pip install ipywidgets

print("✓ Dependencies installed successfully!")

## 2. Imports and Configuration

In [None]:
import torch
import numpy as np
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import plotly.express as px
import plotly.graph_objects as go
from typing import List, Dict, Tuple, Callable, Optional
from functools import partial
import einops
from tqdm.auto import tqdm
import pandas as pd

# CircuitsVis for interactive visualizations
import circuitsvis as cv

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Fix the RNG seeds so repeated runs of the notebook are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# This notebook only runs inference, so turn autograd off globally.
torch.set_grad_enabled(mode=False)

## 3. Load Model

Choose your model:
- `gpt2-medium` (355M params) - Good starting point
- `gpt2-small` (117M params) - Faster for testing
- `meta-llama/Llama-3.2-1B` (gated: accept the license on Hugging Face and log in with an HF access token)

In [None]:
# Checkpoint to load; alternatives: "gpt2-small", "meta-llama/Llama-3.2-1B".
MODEL_NAME = "gpt2-medium"

# Load through TransformerLens with its weight-processing options enabled
# (LayerNorm folding, weight/unembed centering, attention-matrix refactoring).
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

# Summarize the architecture, one line per field.
print(
    f"✓ Loaded {MODEL_NAME}",
    f"  - Layers: {model.cfg.n_layers}",
    f"  - Heads: {model.cfg.n_heads}",
    f"  - d_model: {model.cfg.d_model}",
    f"  - d_head: {model.cfg.d_head}",
    f"  - d_mlp: {model.cfg.d_mlp}",
    sep="\n",
)

## 4. Helper Functions for Activation Extraction

In [None]:
def get_activation(
    model: HookedTransformer,
    prompt: str,
    layer: int,
    component: str = "resid_post",
) -> torch.Tensor:
    """
    Run ``prompt`` through ``model`` and capture one block-level hook.

    The hook read is ``blocks.{layer}.hook_{component}``.

    Args:
        model: HookedTransformer model
        prompt: Input text
        layer: Layer index (0 to n_layers-1)
        component: One of ['resid_pre', 'resid_mid', 'resid_post', 'attn_out', 'mlp_out']

    Returns:
        The captured activation, detached and moved to the CPU.
        Shape [batch, seq_len, d_model] for the listed components.
    """
    target = f"blocks.{layer}.hook_{component}"
    captured: Dict[str, torch.Tensor] = {}

    def _store(act, hook):
        # Detach and copy to CPU so the tensor outlives the forward pass.
        # Returning None leaves the activation itself untouched.
        captured[hook.name] = act.detach().cpu()

    model.run_with_hooks(prompt, fwd_hooks=[(target, _store)])
    return captured[target]


def get_attention_patterns(
    model: HookedTransformer,
    prompt: str,
    layer: Optional[int] = None,
) -> torch.Tensor:
    """
    Get attention patterns for one layer or for all layers.

    Only the ``attn.hook_pattern`` activations are cached. A plain
    ``run_with_cache`` would store every activation in the model just to
    read these.

    Args:
        model: HookedTransformer model
        prompt: Input text
        layer: Optional layer index. If None, returns all layers.

    Returns:
        Attention pattern tensor [batch, n_heads, seq_len, seq_len] or
        [batch, n_layers, n_heads, seq_len, seq_len] if layer is None.
        The tensor is not moved to the CPU here.
    """
    # Keep every layer's pattern hook (not just `layer`), so the cache's
    # negative layer indexing, e.g. layer=-1, keeps working.
    _, cache = model.run_with_cache(
        prompt,
        names_filter=lambda name: name.endswith("attn.hook_pattern"),
    )

    if layer is not None:
        return cache["pattern", layer]  # [batch, n_heads, seq_len, seq_len]

    # Stack per-layer patterns along a new layer axis at dim 1.
    return torch.stack(
        [cache["pattern", layer_idx] for layer_idx in range(model.cfg.n_layers)],
        dim=1,
    )  # [batch, n_layers, n_heads, seq_len, seq_len]


def visualize_attention(
    model: HookedTransformer,
    prompt: str,
    layer: int,
    head: int,
):
    """
    Render the attention pattern of a single head with CircuitsVis.

    Args:
        model: HookedTransformer model
        prompt: Input text
        layer: Layer index
        head: Head index (negative indices count from the last head)

    Returns:
        The CircuitsVis rendering of the pattern for head ``head`` of layer
        ``layer`` (first batch element).

    Raises:
        IndexError: if ``head`` is out of range for the model's heads.
    """
    tokens = model.to_str_tokens(prompt)
    attention = get_attention_patterns(model, prompt, layer)  # [batch, n_heads, q, k]

    # Select the requested head from the first batch element. `attention_patterns`
    # expects [n_heads, dest_pos, src_pos], so keep a head axis of size 1.
    pattern = attention[0, head].unsqueeze(0).cpu().numpy()  # [1, seq_len, seq_len]

    # Use CircuitsVis for interactive visualization
    return cv.attention.attention_patterns(
        tokens=tokens,
        attention=pattern,
        attention_head_names=[f"L{layer}H{head}"],
    )


def get_mlp_activations(
    model: HookedTransformer,
    prompt: str,
    layer: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Capture the MLP input and output neuron activations of one block.

    Reads ``blocks.{layer}.mlp.hook_pre`` and ``blocks.{layer}.mlp.hook_post``
    in a single forward pass.

    Args:
        model: HookedTransformer model
        prompt: Input text
        layer: Layer index

    Returns:
        ``(pre_activation, post_activation)``, each detached and on the CPU.
        pre: [batch, seq_len, d_mlp]
        post: [batch, seq_len, d_mlp]
    """
    pre_name = f"blocks.{layer}.mlp.hook_pre"
    post_name = f"blocks.{layer}.mlp.hook_post"
    store: Dict[str, torch.Tensor] = {}

    def _save(act, hook):
        # One callback for both hooks; entries are keyed by the hook's name.
        store[hook.name] = act.detach().cpu()

    model.run_with_hooks(
        prompt,
        fwd_hooks=[(name, _save) for name in (pre_name, post_name)],
    )
    return store[pre_name], store[post_name]


print("✓ Helper functions defined")

## 5. Test Model & Helpers

In [None]:
# Smoke-test the loaded model and the helper functions on one factual-recall prompt.
# Test prompt
test_prompt = "The Eiffel Tower is located in the city of"

# Generate completion
# temperature=0.0 is presumably meant as greedy decoding.
# NOTE(review): `output` likely contains the prompt text followed by the 5 new
# tokens, not the completion alone. Confirm against HookedTransformer.generate.
output = model.generate(test_prompt, max_new_tokens=5, temperature=0.0)
print(f"Prompt: {test_prompt}")
print(f"Completion: {output}")
print()

# Test activation extraction
# Residual stream after block 5 (get_activation returns it on the CPU).
activation = get_activation(model, test_prompt, layer=5, component="resid_post")
print(f"Activation shape: {activation.shape}")

# Test attention patterns
attn = get_attention_patterns(model, test_prompt, layer=5)
print(f"Attention pattern shape: {attn.shape}")
print()

# Visualize attention for layer 5, head 0
# This call is the cell's last expression, so the notebook displays the widget it returns.
print("Attention visualization for Layer 5, Head 0:")
visualize_attention(model, test_prompt, layer=5, head=0)

## 6. Experiment Tracking Setup (W&B)

In [None]:
import wandb
from datetime import datetime

def init_experiment(
    project_name: str = "mech-interp",
    experiment_name: Optional[str] = None,
    config: Optional[dict] = None,
):
    """
    Start a Weights & Biases run for an experiment.

    Args:
        project_name: W&B project name
        experiment_name: Optional custom name. If None, a name of the form
            ``exp_YYYYmmdd_HHMMSS`` is built from the local time.
        config: Configuration dictionary to log (an empty dict if None)

    Returns:
        The run object returned by ``wandb.init``. Callers can log to it and
        call ``.finish()`` on it when done. Ignoring the return value is fine.
    """
    if experiment_name is None:
        # datetime's format spec delegates to strftime.
        experiment_name = f"exp_{datetime.now():%Y%m%d_%H%M%S}"

    # Initialize W&B
    run = wandb.init(
        project=project_name,
        name=experiment_name,
        config=config or {},
    )

    print(f"✓ Experiment tracking initialized: {experiment_name}")
    return run


# Example usage (commented out - uncomment when ready to track)
# init_experiment(
#     project_name="mech-interp",
#     experiment_name="activation_patching_demo",
#     config={
#         "model": MODEL_NAME,
#         "task": "fact_tracing",
#         "dataset": "counterfact",
#     }
# )

print("✓ Experiment tracking functions defined")

## 7. Quick Start: Common Hook Patterns

In [None]:
# Pattern 1: Cache all activations
def cache_all_activations(model: HookedTransformer, prompt: str):
    """Run one forward pass and return ``(logits, cache)`` covering every hook point."""
    final_logits, act_cache = model.run_with_cache(prompt)
    return final_logits, act_cache


# Pattern 2: Activation patching
def patch_activation(
    model: HookedTransformer,
    clean_prompt: str,
    corrupted_prompt: str,
    layer: int,
    component: str = "resid_post",
    pos: Optional[int] = None,
) -> torch.Tensor:
    """
    Run ``clean_prompt`` with one activation replaced by its corrupted-run value.

    The activation at ``blocks.{layer}.hook_{component}`` is first recorded on
    ``corrupted_prompt``. It is then written into the clean forward pass.

    Args:
        model: HookedTransformer model
        clean_prompt: Prompt whose forward pass is patched
        corrupted_prompt: Prompt that supplies the patched-in activation
        layer: Layer index
        component: Hook suffix, e.g. 'resid_pre', 'resid_post', 'attn_out', 'mlp_out'
        pos: If None (default), the whole activation is replaced, so both
            prompts must tokenize to the same length. If an int, only that
            token position is patched (negative values count from the end,
            e.g. -1 for the final token).

    Returns:
        Logits of the patched clean run.

    Raises:
        ValueError: if ``pos`` is None and the two activations differ in shape.
            Without this check, a mismatch either fails with an opaque error or
            silently broadcasts.
    """
    # Get corrupted activation (returned on the CPU)
    corrupted_act = get_activation(model, corrupted_prompt, layer, component)
    hook_name = f"blocks.{layer}.hook_{component}"

    def patch_hook(activation, hook):
        source = corrupted_act.to(activation.device)
        if pos is None:
            if source.shape != activation.shape:
                raise ValueError(
                    f"Cannot patch {hook.name}: corrupted activation has shape "
                    f"{tuple(source.shape)} but clean activation has shape "
                    f"{tuple(activation.shape)}. Use prompts with the same token "
                    f"length or pass `pos` to patch a single position."
                )
            activation[:] = source
        else:
            activation[:, pos] = source[:, pos]
        return activation

    return model.run_with_hooks(
        clean_prompt,
        fwd_hooks=[(hook_name, patch_hook)],
    )


# Pattern 3: Mean ablation
def mean_ablate_head(
    model: HookedTransformer,
    prompt: str,
    layer: int,
    head: int,
) -> torch.Tensor:
    """
    Mean-ablate one attention head across the positions of ``prompt``.

    The head's per-position output ``z`` (hook ``blocks.{layer}.attn.hook_z``,
    shape [batch, seq, n_heads, d_head]) is replaced at every position by its
    mean over the sequence of this same prompt. This is a per-prompt
    positional mean, not a mean over a reference dataset.

    ``z`` is hooked rather than ``attn.hook_result`` because TransformerLens
    only computes ``hook_result`` when ``model.cfg.use_attn_result`` is enabled
    (off by default). Otherwise that hook never fires and no ablation happens.
    Each head's output is a linear map of ``z`` (through its ``W_O`` slice), so
    averaging ``z`` has the same effect as averaging the head's output.

    Returns:
        Logits of the ablated run.
    """
    def ablate_hook(z, hook):
        # z: [batch, seq, n_heads, d_head]; keepdim broadcasts the mean back over seq.
        z[:, :, head, :] = z[:, :, head, :].mean(dim=1, keepdim=True)
        return z

    hook_name = f"blocks.{layer}.attn.hook_z"
    return model.run_with_hooks(
        prompt,
        fwd_hooks=[(hook_name, ablate_hook)],
    )


print("✓ Common hook patterns defined")
print("\nYou're all set! Start exploring with the helper functions above.")

## Next Steps

1. **Explore Attention Patterns**: Use `visualize_attention()` on different prompts
2. **Activation Patching**: Try the `patch_activation()` function for causal tracing
3. **Circuit Discovery**: Mean ablate heads/neurons to find important components
4. **SAE Training**: Train sparse autoencoders on MLP activations

Useful hook names in TransformerLens:
- `blocks.{l}.hook_resid_pre` - Residual stream before layer
- `blocks.{l}.hook_resid_mid` - After attention, before MLP
- `blocks.{l}.hook_resid_post` - After full layer
- `blocks.{l}.attn.hook_pattern` - Attention patterns
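- `blocks.{l}.attn.hook_result` - Attention output per head (only populated after `model.set_use_attn_result(True)`; `blocks.{l}.attn.hook_z` is always available)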
- `blocks.{l}.mlp.hook_post` - MLP neuron activations after the nonlinearity (GELU for GPT-2)