# Activation Analysis for Mechanistic Interpretability

This notebook demonstrates how to extract and analyze activations from transformer models to understand how they represent factual knowledge.

## Overview

We'll cover:
1. Creating and managing fact datasets
2. Extracting activations from multiple layers and components
3. Computing activation statistics
4. Visualizing activation patterns
5. Comparing true vs false fact representations

## Key Question
**How do transformer models represent factual knowledge internally, and do true facts activate differently than false facts?**

In [None]:
# Setup and imports
import sys
sys.path.append('..')

import torch
from src.utils import setup_logging, set_seed, load_model
from src.activation_extraction import ActivationExtractor, ActivationConfig
from src.fact_dataset import FactDataset, create_sample_dataset
from src.visualization import (
    plot_activation_magnitude_heatmap,
    plot_activation_comparison,
    plot_pca_activations,
    plot_activation_space_comparison,
)

# Setup
setup_logging()
set_seed(42)

print("✓ Imports successful")

## Part 1: Creating a Fact Dataset

First, we'll create a dataset of factual statements. Each fact has:
- **Subject**: The entity (e.g., "Eiffel Tower")
- **Relation**: The relationship (e.g., "located_in")
- **Object**: The target (e.g., "Paris")
- **is_true**: Whether this is a true fact or counterfactual
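To make this structure concrete, here is a minimal sketch of how a fact record could render to a prompt. The `Fact` class, the `TEMPLATES` mapping, and the template wording are hypothetical illustrations, not the actual `src.fact_dataset` implementation.

```python
from dataclasses import dataclass

# Hypothetical relation -> template mapping; the real dataset's wording may differ.
TEMPLATES = {
    "located_in": "{subject} is located in {obj}",
    "capital_of": "{subject} is the capital of {obj}",
}


@dataclass(frozen=True)
class Fact:
    subject: str
    relation: str
    obj: str          # the "object" of the triple (renamed to avoid shadowing the builtin)
    is_true: bool = True

    def to_prompt(self) -> str:
        """Fill the relation's template with this fact's subject and object."""
        return TEMPLATES[self.relation].format(subject=self.subject, obj=self.obj)


true_fact = Fact("Eiffel Tower", "located_in", "Paris", is_true=True)
false_fact = Fact("Eiffel Tower", "located_in", "Rome", is_true=False)
print(true_fact.to_prompt())   # Eiffel Tower is located in Paris
print(false_fact.to_prompt())  # Eiffel Tower is located in Rome
```

Pairing each true fact with a counterfactual that shares its subject and relation means the two prompts differ only in the object, which keeps the true/false comparison controlled.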

In [None]:
# Create a sample dataset with true and false facts
dataset = create_sample_dataset()

print(f"Dataset: {dataset}")
print(f"Total facts: {len(dataset)}")

# Show some examples
print("\nExample facts:")
for i in range(min(6, len(dataset))):  # guard against datasets with fewer than 6 facts
    fact = dataset[i]
    prompt = fact.to_prompt()
    status = "✓ TRUE" if fact.is_true else "✗ FALSE"
    print(f"  [{status}] {prompt}")

In [None]:
# You can also create custom facts
custom_dataset = FactDataset()

# Add a true fact and its counterfactual
custom_dataset.add_fact_with_counterfactual(
    subject="Mount Everest",
    relation="located_in",
    true_object="Nepal",
    false_object="Switzerland"
)

print("Custom facts:")
for prompt, is_true in custom_dataset.to_prompts(include_labels=True):
    print(f"  [{is_true}] {prompt}")

In [None]:
# Filter facts by type
true_facts = dataset.filter(is_true=True)
false_facts = dataset.filter(is_true=False)

print(f"True facts: {len(true_facts)}")
print(f"False facts: {len(false_facts)}")

# Get prompts for model input
true_prompts = true_facts.to_prompts()
false_prompts = false_facts.to_prompts()

print(f"\nFirst true prompt: {true_prompts[0]}")
print(f"First false prompt: {false_prompts[0]}")

## Part 2: Loading a Model

We'll use TransformerLens to load a pre-trained model. This library provides hooks for extracting internal activations.

In [None]:
# Load GPT-2 Small (fast for demonstration)
# You can also use: 'gpt2-medium', 'gpt2-large', or 'meta-llama/Llama-3.2-1B'
model = load_model('gpt2-small')

print(f"Model: {model.cfg.model_name}")
print(f"  Layers: {model.cfg.n_layers}")
print(f"  Heads per layer: {model.cfg.n_heads}")
print(f"  Hidden dimension: {model.cfg.d_model}")

## Part 3: Configuring Activation Extraction

The `ActivationConfig` class lets us specify:
- **components**: Which parts to extract (residual stream, attention output, MLP output)
- **layers**: Which layers to analyze (default: all)
- **aggregate_positions**: Whether to average across token positions
- **return_cpu**: Whether to move tensors to CPU (saves GPU memory)
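Averaging over positions is straightforward when every prompt has the same length, but in a padded batch a plain mean would also average in the padding positions. The NumPy sketch below shows mask-aware mean pooling, assuming activations of shape `[batch, seq_len, d_model]` and a 0/1 mask marking real tokens. Whether `ActivationExtractor` handles padding this way is an assumption to verify in its source.

```python
import numpy as np

def masked_mean_pool(acts: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average activations over valid (mask == 1) positions only.

    acts: [batch, seq_len, d_model]; mask: [batch, seq_len] of 0/1.
    Returns: [batch, d_model].
    """
    mask = mask.astype(acts.dtype)[..., None]       # [batch, seq_len, 1]
    summed = (acts * mask).sum(axis=1)              # padding positions contribute zero
    counts = np.clip(mask.sum(axis=1), 1.0, None)   # avoid division by zero
    return summed / counts

acts = np.array([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])  # last position is padding
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(acts, mask))  # [[2. 2.]]
```

For causal models such as GPT-2, a common alternative to averaging is to take the final token's activation, since only that position has attended to the whole prompt.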

In [None]:
# Configure what activations to extract
config = ActivationConfig(
    components=['resid_post'],  # Residual stream after each layer
    aggregate_positions=True,    # Average across sequence positions
    return_cpu=True,             # Move to CPU to save GPU memory
)

# Create the extractor
extractor = ActivationExtractor(model, config)

print("✓ Extractor configured")
print(f"  Components: {config.components}")
print(f"  Layers: {len(extractor.layers)} (all)")
print(f"  Aggregate positions: {config.aggregate_positions}")

## Part 4: Extracting Activations

Now we'll run the model on our facts and extract activations from all layers.

In [None]:
# Extract activations for true facts
print("Extracting activations for true facts...")
true_activations = extractor.extract(true_prompts)

print("\nExtracting activations for false facts...")
false_activations = extractor.extract(false_prompts)

# Get the component we're analyzing
component = 'resid_post'
true_acts = true_activations[component]
false_acts = false_activations[component]

print("\n✓ Extraction complete!")
print(f"True activations shape: {true_acts.shape}")  # [n_true, n_layers, d_model]
print(f"False activations shape: {false_acts.shape}")  # [n_false, n_layers, d_model]

## Part 5: Computing Activation Statistics

Let's compute summary statistics to understand the activation patterns.
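It helps to know what such statistics are computed from. The sketch below is one plausible definition, assuming activations shaped `[n_prompts, n_layers, d_model]`: `l2_norm` is the L2 norm of each activation vector, averaged over prompts and layers. The exact statistics returned by `get_activation_stats` may differ, so check its source before comparing numbers. Since `return_cpu=True`, the tensors can be converted with `.numpy()`.

```python
import numpy as np

def activation_stats(acts: np.ndarray) -> dict:
    """Summary statistics for activations of shape [n_prompts, n_layers, d_model]."""
    norms = np.linalg.norm(acts, axis=-1)  # [n_prompts, n_layers]
    return {
        "mean": float(acts.mean()),
        "std": float(acts.std()),
        "max_abs": float(np.abs(acts).max()),
        "l2_norm": float(norms.mean()),     # mean per-vector L2 norm
    }

acts = np.zeros((2, 3, 4))
acts[..., 0] = 3.0
acts[..., 1] = 4.0          # every vector is (3, 4, 0, 0), so its L2 norm is 5
stats = activation_stats(acts)
print(stats["l2_norm"])     # 5.0
```

Residual-stream norms tend to grow with depth, so a norm averaged over all layers is dominated by the later layers; per-layer comparisons are often more informative.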

In [None]:
# Compute statistics for true facts
true_stats = extractor.get_activation_stats({component: true_acts})

# Compute statistics for false facts
false_stats = extractor.get_activation_stats({component: false_acts})

print("True Facts Statistics:")
for stat, value in true_stats[component].items():
    print(f"  {stat}: {value:.4f}")

print("\nFalse Facts Statistics:")
for stat, value in false_stats[component].items():
    print(f"  {stat}: {value:.4f}")

# Compare L2 norms
diff = true_stats[component]['l2_norm'] - false_stats[component]['l2_norm']
pct_diff = (diff / true_stats[component]['l2_norm']) * 100
print(f"\nL2 Norm Difference: {diff:.4f} ({pct_diff:.2f}%)")

## Part 6: Visualizing Activation Magnitudes

Let's create a heatmap showing activation magnitudes across layers for all prompts.

In [None]:
# Combine true and false activations
all_acts = torch.cat([true_acts, false_acts], dim=0)

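# Create labels, truncating long prompts for readability
def shorten(text, max_len=40):
    return text if len(text) <= max_len else text[:max_len] + "..."

labels = (
    [f"✓ {shorten(p)}" for p in true_prompts] +
    [f"✗ {shorten(p)}" for p in false_prompts]
)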

# Plot heatmap
fig = plot_activation_magnitude_heatmap(
    all_acts,
    labels=labels,
    title=f"Activation Magnitudes - {component}",
    metric='l2_norm'
)

fig.show()

**Observation**: Look for patterns in which layers show stronger activations for true vs false facts.

## Part 7: Comparing True vs False Activations

Let's directly compare the activation magnitudes across layers.

In [None]:
# Plot comparison across all layers
fig = plot_activation_comparison(
    true_acts,
    false_acts,
    title=f"True vs False Facts - {component}"
)

fig.show()

**Observation**: This plot shows the mean L2 norm ± one standard deviation at each layer. Do later layers show bigger differences?
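The per-layer mean ± standard deviation of L2 norms can also be computed directly, which is useful when you want the numbers rather than the figure. A minimal NumPy sketch, assuming both inputs are shaped `[n_prompts, n_layers, d_model]` (convert the tensors with `.numpy()` first); it uses the population standard deviation, which may differ slightly from what the plotting function uses.

```python
import numpy as np

def per_layer_norm_summary(acts_a: np.ndarray, acts_b: np.ndarray):
    """Per-layer mean and std of activation L2 norms for two groups of prompts.

    Returns (mean_a, std_a, mean_b, std_b), each of shape [n_layers].
    """
    norms_a = np.linalg.norm(acts_a, axis=-1)  # [n_a, n_layers]
    norms_b = np.linalg.norm(acts_b, axis=-1)  # [n_b, n_layers]
    return norms_a.mean(0), norms_a.std(0), norms_b.mean(0), norms_b.std(0)

rng = np.random.default_rng(0)
a = rng.normal(size=(8, 12, 16))
b = 2.0 * rng.normal(size=(8, 12, 16))  # group b is scaled up, so its norms are larger
mean_a, std_a, mean_b, std_b = per_layer_norm_summary(a, b)
for layer in range(3):
    print(f"layer {layer}: {mean_a[layer]:.2f} ± {std_a[layer]:.2f} "
          f"vs {mean_b[layer]:.2f} ± {std_b[layer]:.2f}")
```

With only a handful of prompts per group, heavily overlapping error bars mean a difference in means is weak evidence on its own.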

In [None]:
# Compare a specific layer in detail
middle_layer = model.cfg.n_layers // 2

fig = plot_activation_comparison(
    true_acts,
    false_acts,
    layer_idx=middle_layer,
    title=f"Layer {middle_layer} Distribution"
)

fig.show()

## Part 8: Dimensionality Reduction (PCA)

Use PCA to visualize the high-dimensional activation space in 2D.
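PCA amounts to centering the activation vectors and projecting them onto their top singular directions. A minimal NumPy sketch, assuming an input of shape `[n_samples, d_model]` for one layer (for example `all_acts[:, layer].numpy()`); the explained-variance ratio it returns shows how much of the total variance the 2D projection keeps. `plot_activation_space_comparison` may use a library implementation rather than this one.

```python
import numpy as np

def pca_2d(x: np.ndarray, n_components: int = 2):
    """Project rows of x onto their top principal components.

    x: [n_samples, n_features].
    Returns (projected [n_samples, k], explained_variance_ratio [k]).
    """
    centered = x - x.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:n_components].T
    var = s ** 2
    return projected, var[:n_components] / var.sum()

rng = np.random.default_rng(0)
# Points spread mostly along one direction, plus small noise in all 5 dims.
t = rng.normal(size=(50, 1))
x = t @ np.array([[3.0, 0.0, 0.0, 0.0, 0.0]]) + 0.1 * rng.normal(size=(50, 5))
proj, ratio = pca_2d(x)
print(proj.shape, ratio.round(3))
```

After centering, n samples span at most n - 1 directions, so with few prompts the top two components can account for a large share of the variance partly because of the small sample size.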

In [None]:
# Analyze a specific layer
layer_to_analyze = model.cfg.n_layers - 1  # Last layer

# PCA visualization
fig = plot_activation_space_comparison(
    true_acts,
    false_acts,
    method='pca',
    layer_idx=layer_to_analyze,
    title=f"PCA - Layer {layer_to_analyze}"
)

fig.show()

**Observation**: Are true and false facts separable in activation space? What does the explained variance tell us?

## Part 9: t-SNE Visualization

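t-SNE is a non-linear method that preserves local neighborhood structure, so it can reveal groupings that a linear projection such as PCA misses. Distances between clusters and cluster sizes in a t-SNE plot are not reliably meaningful, and the result depends on the perplexity setting, which must be smaller than the number of points (scikit-learn, for example, enforces this). This matters for small fact datasets.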

In [None]:
# t-SNE visualization (takes longer to compute)
fig = plot_activation_space_comparison(
    true_acts,
    false_acts,
    method='tsne',
    layer_idx=layer_to_analyze,
    title=f"t-SNE - Layer {layer_to_analyze}"
)

fig.show()

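**Observation**: Do true and false facts cluster separately? If they do, the model may encode a distinction between them at this layer. Treat this as suggestive rather than conclusive: t-SNE can produce apparent clusters from small samples, and each true prompt differs from its counterfactual in the object tokens, so separation may reflect the specific words as much as truth.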
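Because clusters in a 2D embedding can mislead, it helps to check separability in the original activation space as well. One simple option, sketched here in NumPy as an illustration (not a method from this repository), is leave-one-out nearest-centroid classification at a single layer: each prompt is held out, class centroids are recomputed from the remaining prompts, and the held-out prompt is assigned to the nearer centroid. Inputs are assumed to be `[n, d_model]` arrays for one layer, e.g. `true_acts[:, layer].numpy()`.

```python
import numpy as np

def loo_nearest_centroid_accuracy(x_a: np.ndarray, x_b: np.ndarray) -> float:
    """Leave-one-out nearest-centroid accuracy for two groups of vectors.

    x_a: [n_a, d], x_b: [n_b, d]. Each group needs at least 2 samples.
    """
    x = np.concatenate([x_a, x_b])
    y = np.array([0] * len(x_a) + [1] * len(x_b))
    correct = 0
    for i in range(len(x)):
        keep = np.arange(len(x)) != i          # hold out sample i
        c0 = x[keep & (y == 0)].mean(axis=0)
        c1 = x[keep & (y == 1)].mean(axis=0)
        pred = int(np.linalg.norm(x[i] - c1) < np.linalg.norm(x[i] - c0))
        correct += int(pred == y[i])
    return correct / len(x)

rng = np.random.default_rng(0)
a = rng.normal(size=(10, 32))
b = rng.normal(size=(10, 32)) + 3.0            # shifted group, easy to separate
acc = loo_nearest_centroid_accuracy(a, b)
print(acc)
```

Accuracy near 0.5 indicates no separation. As with the plots, high accuracy between true and false prompts may reflect their different object tokens rather than an encoding of truth.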

## Part 10: Analyzing Different Components

Let's compare different model components (attention output vs MLP output).

In [None]:
# Extract multiple components
multi_config = ActivationConfig(
    components=['attn_out', 'mlp_out', 'resid_post'],
    aggregate_positions=True,
    return_cpu=True,
)

multi_extractor = ActivationExtractor(model, multi_config)

print("Extracting multiple components...")
true_multi = multi_extractor.extract(true_prompts[:5])  # Use subset for speed
false_multi = multi_extractor.extract(false_prompts[:5])

# Compare statistics across components
print("\nComponent Statistics (True Facts):")
for comp in multi_config.components:
    stats = multi_extractor.get_activation_stats({comp: true_multi[comp]})
    print(f"  {comp}: L2 norm = {stats[comp]['l2_norm']:.4f}")

## Part 11: Layer-by-Layer Analysis

Examine how representations evolve through the network.
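Mean norms capture magnitude only. A complementary view, sketched here in NumPy as an illustrative addition, compares the direction of the average true and false activations at each layer using cosine similarity; values near 1 mean the two group means point the same way. Inputs are assumed to be `[n, n_layers, d_model]` arrays (e.g. `true_acts.numpy()`).

```python
import numpy as np

def mean_cosine_per_layer(acts_a: np.ndarray, acts_b: np.ndarray) -> np.ndarray:
    """Cosine similarity between the two groups' mean activation at each layer.

    acts_a: [n_a, n_layers, d], acts_b: [n_b, n_layers, d]. Returns [n_layers].
    """
    mu_a = acts_a.mean(axis=0)  # [n_layers, d]
    mu_b = acts_b.mean(axis=0)
    dot = (mu_a * mu_b).sum(axis=-1)
    denom = np.linalg.norm(mu_a, axis=-1) * np.linalg.norm(mu_b, axis=-1)
    return dot / np.maximum(denom, 1e-12)

a = np.ones((4, 2, 3))
b = np.ones((4, 2, 3))
b[:, 1] *= -1.0                       # layer 1 points the opposite way
print(mean_cosine_per_layer(a, b))    # approximately [1, -1]
```

If the group means share a large common component, the cosine may sit near 1 at every layer; in that case the norm of the difference of means, `np.linalg.norm(mu_a - mu_b, axis=-1)`, may be more informative.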

In [None]:
# Compute L2 norms per layer
true_norms = torch.norm(true_acts, dim=-1)  # [n_true, n_layers]
false_norms = torch.norm(false_acts, dim=-1)  # [n_false, n_layers]

# Find layers with biggest differences
mean_diff_per_layer = (true_norms.mean(dim=0) - false_norms.mean(dim=0)).abs()
top_layers = mean_diff_per_layer.argsort(descending=True)[:3]

print("Top 3 layers with largest true/false differences:")
for i, layer in enumerate(top_layers.tolist()):  # plain ints, so layers print as numbers
    diff = mean_diff_per_layer[layer].item()
    print(f"  {i+1}. Layer {layer}: difference = {diff:.4f}")

## Part 12: Analyzing Specific Fact Types

Compare facts with different relations. The cell below compares true `located_in` facts with true `capital_of` facts.

In [None]:
# Filter by relation type
location_facts = dataset.filter(relation='located_in', is_true=True)
capital_facts = dataset.filter(relation='capital_of', is_true=True)

if len(location_facts) > 0 and len(capital_facts) > 0:
    location_prompts = location_facts.to_prompts()
    capital_prompts = capital_facts.to_prompts()
    
    location_acts = extractor.extract(location_prompts)['resid_post']
    capital_acts = extractor.extract(capital_prompts)['resid_post']
    
    # Compare
    fig = plot_activation_comparison(
        location_acts,
        capital_acts,
        title="Location Facts vs Capital Facts"
    )
    fig.show()
else:
    print("Not enough facts of each type in dataset")

---

## Summary

We've demonstrated:
1. ✓ Creating and managing fact datasets
2. ✓ Extracting activations from transformer models
3. ✓ Computing activation statistics
4. ✓ Visualizing activation patterns with heatmaps
5. ✓ Comparing true vs false fact representations
6. ✓ Using dimensionality reduction (PCA, t-SNE)
7. ✓ Analyzing different model components
8. ✓ Comparing layers and relation types

---

## Suggested Experiments

Here are experiments you can run to deepen your understanding:

### 1. Model Comparison
**Question**: Do larger models show clearer separation between true and false facts?
- Compare GPT-2 Small, Medium, and Large
- Plot PCA for each and measure cluster separation

### 2. Layer Analysis
**Question**: At which layer do representations become most discriminative?
- Extract activations layer by layer
- Train a linear classifier (true vs false) on each layer
- Plot classification accuracy vs layer depth
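A per-layer probe can be sketched in a few lines. The version here is an illustrative NumPy implementation of logistic regression trained by full-batch gradient descent with a random train/test split; the learning rate, step count, and split fraction are arbitrary assumptions. A library classifier such as scikit-learn's `LogisticRegression` with cross-validation is a sturdier choice in practice, especially with few prompts.

```python
import numpy as np

def probe_accuracy_per_layer(acts, labels, train_frac=0.7, steps=500, lr=0.1, seed=0):
    """Fit a logistic-regression probe at each layer and return held-out accuracy.

    acts: [n, n_layers, d]; labels: [n] array of 0/1. Returns an array of shape [n_layers].
    """
    rng = np.random.default_rng(seed)
    n, n_layers, d = acts.shape
    perm = rng.permutation(n)
    n_train = int(train_frac * n)
    train, test = perm[:n_train], perm[n_train:]
    y_train = labels[train].astype(float)
    accs = []
    for layer in range(n_layers):
        x = acts[:, layer, :]
        # Standardize using training-set statistics only, to avoid test leakage.
        mu, sd = x[train].mean(axis=0), x[train].std(axis=0) + 1e-8
        x_train, x_test = (x[train] - mu) / sd, (x[test] - mu) / sd
        w, b = np.zeros(d), 0.0
        for _ in range(steps):
            p = 1.0 / (1.0 + np.exp(-(x_train @ w + b)))  # predicted P(label = 1)
            err = p - y_train                              # d(mean log-loss)/d(logits)
            w -= lr * x_train.T @ err / n_train
            b -= lr * err.mean()
        preds = (x_test @ w + b) > 0
        accs.append(float((preds == labels[test].astype(bool)).mean()))
    return np.array(accs)


# Synthetic example: the label is encoded only in layers 2 and 3.
rng = np.random.default_rng(1)
n, n_layers, d = 100, 4, 10
labels = np.array([0, 1] * (n // 2))
acts = rng.normal(size=(n, n_layers, d))
acts[:, 2:, 0] += 4.0 * labels[:, None]
accs = probe_accuracy_per_layer(acts, labels)
print(accs.round(2))
```

With `d_model = 768` and only a few dozen prompts, a linear probe can fit the training prompts perfectly even without real signal, which is why accuracy is measured on held-out prompts.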

### 3. Component Importance
**Question**: Are attention or MLP outputs more important for factual knowledge?
- Extract `attn_out`, `mlp_out`, and `resid_post` for the same prompts
- Compare separation in PCA space
- Hypothesis: MLP layers may matter more, since they have been proposed to act as key-value "memories" that store factual associations

### 4. Relation-Specific Patterns
**Question**: Do different relation types activate different parts of the model?
- Create datasets with diverse relations: `located_in`, `born_in`, `invented_by`, etc.
- Extract activations for each relation type
- Use PCA/t-SNE to see if relations cluster separately

### 5. Prompt Engineering Effects
**Question**: How does prompt format affect activations?
- Try different templates: "X is in Y" vs "The location of X is Y" vs "X, located in Y"
- Compare activation patterns
- Does the model represent the same fact differently?
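Generating the variants programmatically keeps the comparison controlled. A small sketch; the template strings and the `render_variants` helper are hypothetical and not part of `src.fact_dataset`.

```python
# Hypothetical templates for a location relation; "{s}" is the subject, "{o}" the object.
LOCATION_TEMPLATES = [
    "{s} is in {o}",
    "The location of {s} is {o}",
    "{s}, located in {o}",
]

def render_variants(subject: str, obj: str, templates=LOCATION_TEMPLATES) -> list:
    """Render one (subject, object) pair under every template."""
    return [t.format(s=subject, o=obj) for t in templates]

variants = render_variants("Mount Everest", "Nepal")
for v in variants:
    print(v)
# Mount Everest is in Nepal
# The location of Mount Everest is Nepal
# Mount Everest, located in Nepal
```

These templates produce prompts of different lengths, so position-averaged activations mix in different amounts of template text; comparing the final-token activation may give a cleaner comparison.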

### 6. Counterfactual Analysis
**Question**: How do activations change as facts become "more wrong"?
- Create facts with varying degrees of incorrectness:
  - "Paris is in France" (TRUE)
  - "Paris is in Germany" (nearby country)
  - "Paris is in Japan" (far country)
  - "Paris is on Mars" (impossible)
- Plot activation distances from the true fact
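Given one activation vector per statement (for example, the last-layer, position-averaged `resid_post` vector), the distances can be measured as cosine distances from the true statement. A NumPy sketch with made-up toy vectors; the names and values are purely illustrative, not results.

```python
import numpy as np

def cosine_distance(u: np.ndarray, v: np.ndarray) -> float:
    """1 - cosine similarity: 0 for identical directions, up to 2 for opposite ones."""
    return float(1.0 - u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def distances_from_reference(reference: np.ndarray, others: dict) -> dict:
    """Cosine distance from a reference activation to each labeled activation."""
    return {name: cosine_distance(reference, vec) for name, vec in others.items()}

# Toy vectors standing in for real activations.
true_vec = np.array([1.0, 0.0, 0.0])
counterfactuals = {
    "nearby": np.array([1.0, 0.2, 0.0]),
    "far": np.array([1.0, 1.0, 0.0]),
    "impossible": np.array([0.0, 1.0, 0.0]),
}
dists = distances_from_reference(true_vec, counterfactuals)
for name, dist in dists.items():
    print(f"{name}: {dist:.3f}")
```

Each counterfactual also uses a different object token ("Germany", "Japan", "Mars"), so part of any distance reflects token identity rather than degree of wrongness.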

### 7. Fine-Tuning Impact
**Question**: Can we make the model's representations more discriminative?
- Fine-tune GPT-2 on factual question-answering
- Compare before/after activation patterns
- Does fine-tuning increase true/false separation?

### 8. Activation Sparsity
**Question**: Are true facts represented more sparsely?
- Compute sparsity metrics (% of near-zero activations)
- Compare true vs false facts
- Hypothesis: True facts might use fewer, more specific neurons
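Two common ways to quantify sparsity are the fraction of near-zero entries, which depends on the threshold chosen, and the Hoyer measure, which is threshold-free and ranges from 0 for a vector whose entries all have equal magnitude to 1 for a vector with a single nonzero entry. A NumPy sketch; the `eps` default is an arbitrary assumption. The residual stream is typically dense, so post-nonlinearity MLP neuron activations may be a more natural place to look for "neurons".

```python
import numpy as np

def fraction_near_zero(x: np.ndarray, eps: float = 1e-2) -> float:
    """Fraction of entries with |x| < eps (threshold-dependent)."""
    return float((np.abs(x) < eps).mean())

def hoyer_sparsity(x: np.ndarray) -> float:
    """Hoyer sparsity of a 1-D vector: 0 if all |entries| are equal, 1 if one entry is nonzero."""
    n = x.size
    l1 = np.abs(x).sum()
    l2 = np.sqrt((x ** 2).sum())
    if l2 == 0:
        return 0.0  # convention for the all-zero vector
    return float((np.sqrt(n) - l1 / l2) / (np.sqrt(n) - 1))

dense = np.ones(16)
one_hot = np.zeros(16)
one_hot[3] = 5.0
print(hoyer_sparsity(dense), hoyer_sparsity(one_hot))  # 0.0 1.0
print(fraction_near_zero(one_hot))                     # 0.9375
```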

### 9. Temporal Dynamics
**Question**: How do activations evolve during sequence processing?
- Set `aggregate_positions=False` to keep activations for every token position
- Plot activation trajectories through the sequence
- When does the model "realize" a fact is false?
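GPT-2 and similar models use causal attention, so for a true/false pair that shares a prefix, activations at every position before the first differing token are identical (up to numerical noise); any difference can only appear from that token onward. The NumPy sketch below finds that position from token lists and measures per-position distances between two prompts' activations. It assumes per-prompt arrays of shape `[seq_len, d_model]` for one layer (the exact output shape with `aggregate_positions=False` is an assumption) and string tokens such as those from TransformerLens's `model.to_str_tokens`. The tokens shown are illustrative, not GPT-2's actual tokenization.

```python
import numpy as np

def first_divergence(tokens_a: list, tokens_b: list) -> int:
    """Index of the first position where two token sequences differ."""
    for i, (ta, tb) in enumerate(zip(tokens_a, tokens_b)):
        if ta != tb:
            return i
    return min(len(tokens_a), len(tokens_b))  # one sequence is a prefix of the other

def per_position_distance(acts_a: np.ndarray, acts_b: np.ndarray) -> np.ndarray:
    """L2 distance between two prompts' activations at each shared position.

    acts_a: [seq_a, d], acts_b: [seq_b, d]. Returns [min(seq_a, seq_b)].
    """
    n = min(len(acts_a), len(acts_b))
    return np.linalg.norm(acts_a[:n] - acts_b[:n], axis=-1)

tokens_true = ["The", " E", "iff", "el", " Tower", " is", " in", " Paris"]
tokens_false = ["The", " E", "iff", "el", " Tower", " is", " in", " Rome"]
split = first_divergence(tokens_true, tokens_false)
print(split)  # 7

# Toy activations that share the first 7 positions, as causal attention guarantees.
rng = np.random.default_rng(0)
shared = rng.normal(size=(7, 4))
acts_true = np.vstack([shared, rng.normal(size=(1, 4))])
acts_false = np.vstack([shared, rng.normal(size=(1, 4))])
dists = per_position_distance(acts_true, acts_false)
print(dists.round(3))
```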

### 10. Cross-Lingual Facts
**Question**: Do multilingual models represent facts consistently across languages?
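- Use a multilingual model (e.g., mBERT or XLM-R; these are encoder-only models, so they may need a different loading path than `load_model`)
- State the same fact in different languages
- Compare the activation patterns: are they similar across languages?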

---

## Code Template for Experiments

In [None]:
# Template: Compare two conditions

def run_experiment(condition_a_prompts, condition_b_prompts, name_a="A", name_b="B",
                   component='resid_post', layer_idx=None):
    """Compare activations between two conditions.

    Uses the global `extractor`, so `component` must be one of the components
    in its config. `layer_idx` (used for PCA) defaults to the last layer.
    """
    if layer_idx is None:
        layer_idx = model.cfg.n_layers - 1

    # Extract activations: [n_prompts, n_layers, d_model] per condition
    acts_a = extractor.extract(condition_a_prompts)[component]
    acts_b = extractor.extract(condition_b_prompts)[component]

    # Statistics
    stats_a = extractor.get_activation_stats({component: acts_a})
    stats_b = extractor.get_activation_stats({component: acts_b})

    print(f"{name_a} L2 norm: {stats_a[component]['l2_norm']:.4f}")
    print(f"{name_b} L2 norm: {stats_b[component]['l2_norm']:.4f}")

    # Layer-wise comparison of activation magnitudes
    fig = plot_activation_comparison(
        acts_a, acts_b,
        title=f"{name_a} vs {name_b}"
    )
    fig.show()

    # PCA at the chosen layer
    fig_pca = plot_activation_space_comparison(
        acts_a, acts_b,
        method='pca',
        layer_idx=layer_idx,
        title=f"PCA (layer {layer_idx}): {name_a} vs {name_b}"
    )
    fig_pca.show()

    return acts_a, acts_b

# Example usage:
# run_experiment(true_prompts, false_prompts, "True", "False")