<a href="https://colab.research.google.com/github/cianadeveau/MechInterp/blob/main/mechinterp_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install transformer_lens

In [None]:
import transformer_lens as tl
from transformer_lens import HookedTransformer
import torch
import einops
from jaxtyping import Float, Int
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# Load pretrained GPT-2 small as a TransformerLens HookedTransformer
model = HookedTransformer.from_pretrained("gpt2-small")
model_summary = f"Model has {model.cfg.n_layers} layers, {model.cfg.n_heads} heads per layer"
print(model_summary)

# Model Architecture

In [None]:
# (label, value) pairs for the main hyperparameters, printed in a fixed order
config_fields = [
    ("Vocabulary size", model.cfg.d_vocab),
    ("Embedding dimension", model.cfg.d_model),
    ("Number of layers", model.cfg.n_layers),
    ("Number of heads", model.cfg.n_heads),
    ("Attention Head dimension", model.cfg.d_head),
    ("MLP dimension", model.cfg.d_mlp),
]

print("Model configuration:")
for field_label, field_value in config_fields:
    print(f"~ {field_label}: {field_value}")

In [None]:
# Shape of each attention weight matrix, per layer
print('Weight matrix shapes:')
for layer in range(model.cfg.n_layers):
    print(f"Layer {layer}:")
    for weight_name in ("W_Q", "W_K", "W_V", "W_O"):
        print(f" {weight_name}: {getattr(model, weight_name)[layer].shape}")

# Forward Pass with Caching

In [None]:
prompt = "The quick brown fox"
# prepend_bos=False: do not prepend the <|endoftext|> BOS token to the prompt
tokens = model.to_tokens(prompt, prepend_bos=False)
logits = model(tokens)
for tensor_name, tensor in (("Tokens", tokens), ("Logits", logits)):
    print(f"{tensor_name} shape: {tensor.shape}")

In [None]:
# Show the raw token ids, then decode each one individually
print('Check tokens:')
print(f'Tokens: {tokens}')
print('Token Strings:')
decoded_tokens = [model.to_string(tok) for tok in tokens[0]]
for position, token_str in enumerate(decoded_tokens):
    print(f'Position {position}: {token_str!r}')

In [None]:
# Get all activations with caching
logits, cache = model.run_with_cache(prompt, prepend_bos=False)

# Access specific activations: list the first few hook names and their shapes
n_hooks_to_show = 20
print("Available activations:")
for key in list(cache.keys())[:n_hooks_to_show]:
    print(f"  {key}: {cache[key].shape}")

# Per-layer attention hooks (keys look like blocks.{layer}.attn.<hook>):
# hook_attn_scores: raw attention scores before softmax
# hook_pattern: attention weights after softmax - you use this most of the time because these weights determine information flow
# hook_z: per-head outputs (pattern-weighted values), shape [batch, seq, n_heads, d_head];
#         heads are only concatenated and projected by W_O after this hook

In [None]:
# Quick reminder for how these transformations can be done with einops
# attn_pattern: [batch, n_heads, seq_len, seq_len] = [1, 12, 4, 4]
# v: [batch, seq_len, n_heads, d_head] = [1, 4, 12, 64]
# need to do: pattern @ v for each head so that hinges on the n_heads axis
# Rearrange v to match pattern's head dimension order
# v_rearranged = einops.rearrange(cache['blocks.0.attn.hook_v'], 'batch seq heads d_head -> batch heads seq d_head') # [1, 12, 4, 64]
# Matrix multiplication (pattern @ v for each head); named axes need einops.einsum, not torch.einsum
# z = einops.einsum(cache['blocks.0.attn.hook_pattern'], v_rearranged,
# 'batch heads seq_out seq_in, batch heads seq_in d_head -> batch heads seq_out d_head') # [1, 12, 4, 64]
# Rearrange back to [batch, seq, heads, d_head]
# z = einops.rearrange(z, 'batch heads seq d_head -> batch seq heads d_head')
# Final z: [1, 4, 12, 64]

# From z to attn_out
# Concatenate all heads (flatten the head and d_head dimensions)
# z_concat = einops.rearrange(z, 'batch seq n_heads d_head -> batch seq (n_heads d_head)')
# z_concat: [1, 4, 768]

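# Apply output projection W_O (stored per head as [n_heads, d_head, d_model], so flatten it first)
# attn_out = z_concat @ einops.rearrange(W_O, 'n_heads d_head d_model -> (n_heads d_head) d_model') + b_O
# attn_out: [1, 4, 768]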

# Attention Pattern Visualization

In [None]:
# Choose one attention head to inspect (layer/head are reused by the plot cell below)
layer, head = 0, 7

# hook_pattern is [batch, n_heads, seq, seq]; keep batch 0 and the chosen head
attn_pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0, head]
print(f'Attention pattern shape: {attn_pattern.shape}')

# List the tokens so rows/columns of the pattern can be matched up
for position, token in enumerate(tokens[0]):
    print(f'Position {position}: {model.to_string(token)!r}')

In [None]:
# Display the raw pattern tensor for the selected head
# (rows: attending-from positions, columns: attended-to positions).
attn_pattern

In [None]:
# Heatmap of the selected head: y = query (attending from), x = key (attended to)
token_labels = [model.to_string(tok) for tok in tokens[0]]
pattern_array = attn_pattern.detach().cpu().numpy()

fig = px.imshow(
    pattern_array,
    x=token_labels,
    y=token_labels,
    title=f'Attention Pattern - Layer {layer}, Head {head}',
    labels=dict(x='Attended to', y='Attending From'),
)
fig.show()

In [None]:
# Common attention head types:
# previous token heads: attend to the token right before
# self-attention heads: attend mostly to themselves
# broad attention heads: attend to "The" or other important tokens (like head 0)
# positional heads: attend based on relative positions
# induction heads: look for repeated patterns

# Direct Logit Attribution

In [None]:
prompt = "The capital of France is"
tokens = model.to_tokens(prompt, prepend_bos=False)
logits, cache = model.run_with_cache(tokens)

# Answer token we want to analyze. GPT-2 tokens carry their leading space, so the
# word following "is" is the token " Paris" (same as paris_token in the patching section).
answer_token = model.to_single_token(" Paris")
print(f'Looking for token: {repr(model.to_string(answer_token))}')

In [None]:
# Project the last-position residual stream onto the unembedding after each stage.
# Values are cumulative (everything written to the residual stream so far), not
# per-layer increments.
# NOTE: model.ln_final and model.b_U are not applied, so these are raw
# residual @ W_U scores rather than the model's actual output logits.
n_stages = model.cfg.n_layers + 1
per_layer_logits = torch.zeros(n_stages, model.cfg.d_vocab)

# Stage 0: token + positional embeddings, before any transformer block
embed = cache["hook_embed"]
pos_embed = cache["hook_pos_embed"]
initial_residual = embed + pos_embed
per_layer_logits[0] = initial_residual[0, -1] @ model.W_U  # last token position

# Stages 1..n_layers: residual stream after each block
for layer in range(model.cfg.n_layers):
    resid_last = cache[f'blocks.{layer}.hook_resid_post'][0, -1]
    per_layer_logits[layer + 1] = resid_last @ model.W_U

# Read off the scores for the answer token
contributions = per_layer_logits[:, answer_token]
print("Logit contribution to 'Paris':")
print(f"Embeddings: {contributions[0]:.3f}")
for stage_idx, stage_value in enumerate(contributions[1:]):
    print(f"After Layer {stage_idx}: {stage_value:.3f}")

In [None]:
# Bar chart of the per-stage scores computed above
layer_names = ["Embeddings", *(f"Layer {i}" for i in range(model.cfg.n_layers))]
bar_heights = contributions.detach().cpu().numpy()

fig = px.bar(
    x=layer_names,
    y=bar_heights,
    labels={"x": "Layer", "y": "Logit Contribution"},
    title="Logit Attribution to 'Paris' by Layer",
)
fig.show()

In [None]:
# What does the embedding (no transformer blocks) predict by itself?
final_token_embed = initial_residual[0, -1]  # last token ("is")
embedding_logits = final_token_embed @ model.W_U
embedding_prediction = embedding_logits.argmax()

print(f"Embedding alone predicts: {model.to_string(embedding_prediction)!r}")
print(f"Embedding logit for 'Paris': {embedding_logits[answer_token]:.3f}")

# Top predictions from the embeddings alone
top_5_embedding = torch.topk(embedding_logits, 5)
print("\nTop 5 predictions from embeddings alone:")
ranked = zip(top_5_embedding.values, top_5_embedding.indices)
for rank, (logit, token_id) in enumerate(ranked, start=1):
    print(f"{rank}. {model.to_string(token_id)!r}: {logit:.3f}")

In [None]:
# Residual stream (last position) after blocks 0 and 2, projected onto the unembedding
layer_0_output, layer_2_output = (
    cache[f"blocks.{idx}.hook_resid_post"][0, -1] for idx in (0, 2)
)
layer_0_logits = layer_0_output @ model.W_U
layer_2_logits = layer_2_output @ model.W_U

print(f"\nAfter Layer 0 - Paris logit: {layer_0_logits[answer_token]:.3f}")
print(f"After Layer 2 - Paris logit: {layer_2_logits[answer_token]:.3f}")

In [None]:
# Head 0's attention pattern at a few representative layers, early to late
token_labels = [model.to_string(token) for token in tokens[0]]
layers_to_check = [0, 3, 6, 9, 11]
head = 0  # same head index in every panel

fig, axes = plt.subplots(1, len(layers_to_check), figsize=(20, 4))
tick_positions = range(len(token_labels))

for ax, layer in zip(axes, layers_to_check):
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0, head].detach().cpu().numpy()

    im = ax.imshow(pattern, cmap='Blues')
    ax.set_title(f'Layer {layer}, Head {head}')
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(token_labels, rotation=45)
    ax.set_yticklabels(token_labels)

plt.tight_layout()
plt.show()

In [None]:
# Look specifically at what the last token attends to across layers, for one head index
last_token_pos = -1
inspect_head = 2  # head index examined in every layer

print(f"What the last token attends to across layers (head {inspect_head}):")
for layer in range(model.cfg.n_layers):
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0, inspect_head]
    # Row = query position; this row is the attention distribution of the last token
    last_token_attention = pattern[last_token_pos, :]

    print(f"\nLayer {layer}:")
    for pos, attention in enumerate(last_token_attention):
        print(f"  To '{token_labels[pos]}': {attention:.3f}")

## See how much the attention pattern at a given head index changes between layers (heads with the same index in different layers are separate heads)

In [None]:
def attention_pattern_distance(pattern1, pattern2):
    '''Return 1 - cosine similarity of two attention patterns compared as flat vectors.

    0 means identical direction; larger values mean more different patterns.
    '''
    similarity = F.cosine_similarity(pattern1.reshape(-1), pattern2.reshape(-1), dim=0)
    return 1 - similarity

# Compare the pattern at each head index in an early vs a late layer.
# Heads sharing an index in different layers have their own weights (W_Q[layer], ...),
# so this measures how different those two heads' patterns are.
early_layer, late_layer = 1, 10

change_scores = []
for head in range(model.cfg.n_heads):
    early_pattern = cache[f'blocks.{early_layer}.attn.hook_pattern'][0, head]
    late_pattern = cache[f'blocks.{late_layer}.attn.hook_pattern'][0, head]
    score = attention_pattern_distance(early_pattern, late_pattern)
    change_scores.append(score.item())
    print(f'Head {head}: Change score = {score: .3f}')

scores_tensor = torch.tensor(change_scores)
most_changing = scores_tensor.argmax().item()
least_changing = scores_tensor.argmin().item()

print(f"\nMost changing head: {most_changing} (score: {change_scores[most_changing]:.3f})")
print(f"Least changing head: {least_changing} (score: {change_scores[least_changing]:.3f})")

In [None]:
# Visualize the most changing head index across layers (shared 0-1 colour scale)
layers_to_show = [1, 4, 7, 10]

fig, axes = plt.subplots(1, len(layers_to_show), figsize=(16, 4))

for i, layer in enumerate(layers_to_show):
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0, most_changing].detach().cpu().numpy()

    im = axes[i].imshow(pattern, cmap='Blues', vmin=0, vmax=1)
    axes[i].set_title(f'Head {most_changing}, Layer {layer}')
    axes[i].set_xticks(range(len(token_labels)))
    axes[i].set_yticks(range(len(token_labels)))
    axes[i].set_xticklabels(token_labels, rotation=45, fontsize=8)
    axes[i].set_yticklabels(token_labels, fontsize=8)

# Figure-level title, matching the least-changing-head plot
plt.suptitle(f'Evolution of Most Changing Head ({most_changing})')
plt.tight_layout()
plt.show()

In [None]:
# Same layer sweep for the most stable head index
stable_head = least_changing
layers_to_show = [1, 4, 7, 10]
tick_positions = range(len(token_labels))

fig, axes = plt.subplots(1, len(layers_to_show), figsize=(16, 4))

for ax, layer in zip(axes, layers_to_show):
    pattern = (
        cache[f"blocks.{layer}.attn.hook_pattern"][0, stable_head]
        .detach()
        .cpu()
        .numpy()
    )
    im = ax.imshow(pattern, cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Head {stable_head}, Layer {layer}')
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(token_labels, rotation=45, fontsize=8)
    ax.set_yticklabels(token_labels, fontsize=8)

plt.suptitle(f'Evolution of Least Changing Head ({stable_head})')
plt.tight_layout()
plt.show()

In [None]:
# Check whether the first token ("The") becomes a repository for geographical/capital
# information by projecting its residual stream onto the vocabulary at several depths.
# With prepend_bos=False there is no BOS token, so "The" is at position 0; look it up
# from the token labels instead of hard-coding the index.
the_position = token_labels.index("The")

print("Evolution of 'The' token's representation:")
for layer in [0, 3, 6, 9, 11]:
    if layer == 0:
        # Before any processing: token + positional embedding
        resid = cache["hook_embed"][0, the_position] + cache["hook_pos_embed"][0, the_position]
    else:
        # Residual stream after block layer-1
        resid = cache[f"blocks.{layer-1}.hook_resid_post"][0, the_position]

    # Project to vocabulary to see what "The" is "thinking about".
    # Separate name so the notebook-level `logits` is not overwritten.
    the_logits = resid @ model.W_U
    top_tokens = torch.topk(the_logits, 5)

    print(f"\nAfter layer {layer-1 if layer > 0 else 'embeddings'}:")
    print("Top concepts in 'The' representation:")
    for token_id, score in zip(top_tokens.indices, top_tokens.values):
        print(f"  {repr(model.to_string(token_id))}: {score:.2f}")

In [None]:
# Instead of vocabulary projection, see how much each position attends TO "The".
# With prepend_bos=False, "The" is at position 0; look it up rather than hard-coding.
the_position = token_labels.index("The")

print("What information flows TO 'The' token across layers:")
for layer in [1, 4, 7, 10]:
    # hook_pattern[0] is [head, query_pos, key_pos]; key column `the_position` is "The"
    patterns = cache[f"blocks.{layer}.attn.hook_pattern"][0]  # All heads

    # Average across all heads to see overall attention TO "The"
    avg_attention_to_the = patterns[:, :, the_position].mean(dim=0)

    print(f"\nLayer {layer} - Average attention TO 'The':")
    for pos, attn in enumerate(avg_attention_to_the):
        print(f"  From '{token_labels[pos]}': {attn:.3f}")

# Activation Patching

In [None]:
clean_prompt = 'The capital of France is'
corrupted_prompt = 'The capital of Spain is'

# Tokenize both prompts (no BOS) and cache every activation for each run
clean_tokens, corrupted_tokens = (
    model.to_tokens(p, prepend_bos=False) for p in (clean_prompt, corrupted_prompt)
)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Greedy next-token prediction at the final position of each run
for run_label, run_logits in (("Clean", clean_logits), ("Corrupted", corrupted_logits)):
    print(f"{run_label} prediction:", model.to_string(run_logits[0, -1].argmax()))

# Target tokens (leading space, as GPT-2 tokenizes the word after "is")
paris_token = model.to_single_token(" Paris")
madrid_token = model.to_single_token(" Madrid")

print(f"\nClean logit for Paris: {clean_logits[0, -1, paris_token]:.3f}")
print(f"Corrupted logit for Paris: {corrupted_logits[0, -1, paris_token]:.3f}")

In [None]:
# Replace activations from the corrupted run with activations from the clean run to see if it can fix the model behavior
def patch_residual_stream(corrupted_cache, clean_cache, layer, position=-1, tokens=None):
    """Run the corrupted input with one residual-stream position restored from the clean run.

    Args:
        corrupted_cache: Cache of the corrupted run. Not read; kept so existing calls work.
        clean_cache: Cache of the clean run; supplies the replacement activation.
        layer: Block whose ``hook_resid_post`` is patched.
        position: Sequence position to patch (default: the last token).
        tokens: Token ids to run. Defaults to the notebook-level ``corrupted_tokens``.

    Returns:
        Logits from the patched forward pass.
    """
    if tokens is None:
        tokens = corrupted_tokens

    def patch_hook(activation, hook):
        # Replace the corrupted activation at `position` with the clean one
        activation[:, position, :] = clean_cache[hook.name][:, position, :]
        return activation

    return model.run_with_hooks(
        tokens,
        fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)],
    )

# Patch the last-position residual stream at a few layers and watch the Paris logit
print("Patching residual stream at different layers:")
for run_label, run_logits in (("Original corrupted", corrupted_logits), ("Clean", clean_logits)):
    print(f"{run_label} logit for Paris: {run_logits[0, -1, paris_token]:.3f}")

for layer in [3, 6, 9, 11]:
    patched_logits = patch_residual_stream(corrupted_cache, clean_cache, layer)
    final_position_logits = patched_logits[0, -1]
    patched_paris_logit = final_position_logits[paris_token]

    print(f"Layer {layer} patched - Paris logit: {patched_paris_logit:.3f}")
    print(f"  Prediction: {model.to_string(final_position_logits.argmax())}")

In [None]:
print("Top 5 predictions for each condition:")


def _zero_based_rank(run_logits, token_id):
    """0-based index of `token_id` in the final-position logits sorted high-to-low."""
    return (run_logits[0, -1].argsort(descending=True) == token_id).nonzero().item()


# Patched run: clean residual stream restored at layer 9 (change the layer to explore)
patched_logits = patch_residual_stream(corrupted_cache, clean_cache, 9)

for run_label, run_logits in (
    ("Clean", clean_logits),
    ("Corrupted", corrupted_logits),
    ("Patched", patched_logits),
):
    top5 = torch.topk(run_logits[0, -1], 5)
    print(f"\n{run_label} predictions:")
    for rank, (logit, token_id) in enumerate(zip(top5.values, top5.indices), start=1):
        print(f"{rank}. {repr(model.to_string(token_id))}: {logit:.3f}")

# Where does " Paris" rank in each run? (stored 0-based, printed 1-based)
paris_rank_clean = _zero_based_rank(clean_logits, paris_token)
paris_rank_corrupted = _zero_based_rank(corrupted_logits, paris_token)
paris_rank_patched = _zero_based_rank(patched_logits, paris_token)
print(f"\nParis ranking - Clean: {paris_rank_clean+1}, Corrupted: {paris_rank_corrupted+1}, Patched: {paris_rank_patched+1}")

In [None]:
def measure_patching_effect(clean_cache, corrupted_cache, layer):
    """Patch the residual stream at `layer` (last position) and score " Paris".

    Returns:
        (paris_logit, paris_rank): the patched final-position logit for the Paris
        token and its 1-based rank among all vocabulary logits.
    """
    final_logits = patch_residual_stream(corrupted_cache, clean_cache, layer)[0, -1]
    descending_ids = final_logits.argsort(descending=True)
    paris_rank = (descending_ids == paris_token).nonzero().item() + 1
    return final_logits[paris_token], paris_rank

# Test across every layer
print("Patching effect across layers:")
print(f"Baseline (corrupted) - Paris rank: {paris_rank_corrupted+1}")
print(f"Target (clean) - Paris rank: {paris_rank_clean+1}")

for layer in range(model.cfg.n_layers):  # every layer, 0 .. n_layers-1
    paris_logit, paris_rank = measure_patching_effect(clean_cache, corrupted_cache, layer)
    print(f"Layer {layer} patched - Paris rank: {paris_rank}, logit: {paris_logit:.3f}")

# Path Patching

In [None]:
# First, test the individual heads in Layer 8, which shows the largest jump in logit value

In [None]:
def patch_attention_head(layer, head, clean_cache, corrupted_cache):
    """Restore one head's clean hook_z output in the corrupted run; return Paris's 1-based rank.

    hook_z is patched before the heads are combined through W_O, so only `head`
    changes. `corrupted_cache` is not read; the forward pass runs on the
    notebook-level `corrupted_tokens`.
    """
    hook_name = f"blocks.{layer}.attn.hook_z"

    def head_patch_hook(activation, hook):
        # activation: [batch, seq_len, n_heads, d_head]; overwrite only this head
        activation[:, :, head, :] = clean_cache[hook.name][:, :, head, :]
        return activation

    patched_logits = model.run_with_hooks(
        corrupted_tokens, fwd_hooks=[(hook_name, head_patch_hook)]
    )
    descending_ids = patched_logits[0, -1].argsort(descending=True)
    return (descending_ids == paris_token).nonzero().item() + 1

# Test each head in Layer 8
print("Testing individual attention heads in Layer 8:")
print(f"Baseline (corrupted): Paris rank ~{paris_rank_corrupted+1}")
# Compute the full residual-stream patch at layer 8 instead of hard-coding its result
_, full_layer_8_rank = measure_patching_effect(clean_cache, corrupted_cache, 8)
print(f"Full Layer 8 patch: Paris rank {full_layer_8_rank}")

for head in range(model.cfg.n_heads):
    paris_rank = patch_attention_head(8, head, clean_cache, corrupted_cache)
    print(f"Head {head}: Paris rank {paris_rank}")

In [None]:
# Compare Head 8.11's attention from the last token on the clean vs corrupted prompt
layer, head = 8, 11

print("Head 8.11 attention patterns:")
for header, run_cache, run_tokens in (
    ("\nClean (France) attention:", clean_cache, clean_tokens),
    ("\nCorrupted (Spain) attention:", corrupted_cache, corrupted_tokens),
):
    print(header)
    head_pattern = run_cache[f"blocks.{layer}.attn.hook_pattern"][0, head]
    for key_pos, token in enumerate(run_tokens[0]):
        # Row -1 is the last position; column key_pos is what it attends to
        print(f"  '{model.to_string(token)}': {head_pattern[-1, key_pos]:.3f}")

In [None]:
# Test if Head 8.11 is specifically a "capital city" head:
# see where its last-token attention goes for other country contexts.
# Uses separate names so the notebook-level `prompt`, `tokens` and `cache` are not overwritten.
probe_layer, probe_head = 8, 11

test_prompts = [
    "The capital of Germany is",
    "The capital of Italy is",
    "The capital of Japan is"
]

for test_prompt in test_prompts:
    test_tokens = model.to_tokens(test_prompt, prepend_bos=False)
    _, test_cache = model.run_with_cache(test_tokens)

    pattern = test_cache[f"blocks.{probe_layer}.attn.hook_pattern"][0, probe_head]
    print(f"\n'{test_prompt}':")
    for i, token in enumerate(test_tokens[0]):
        token_str = model.to_string(token)
        attention_from_last = pattern[-1, i]
        print(f"  '{token_str}': {attention_from_last:.3f}")

In [None]:
def patch_mlp_output(layer, clean_cache, corrupted_cache):
    """Replace layer `layer`'s whole MLP output with the clean run's; return Paris's 1-based rank.

    `corrupted_cache` is not read; the forward pass runs on the notebook-level
    `corrupted_tokens`.
    """
    def mlp_patch_hook(activation, hook):
        # Overwrite every batch/position/feature entry with the clean activation
        activation[...] = clean_cache[hook.name]
        return activation

    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[(f"blocks.{layer}.hook_mlp_out", mlp_patch_hook)],
    )
    descending_ids = patched_logits[0, -1].argsort(descending=True)
    return (descending_ids == paris_token).nonzero().item() + 1

# Patch MLP outputs across the critical late-layer range
print("Testing MLP layers:")
print(f"Baseline (corrupted): Paris rank ~{paris_rank_corrupted+1}")

for layer in range(7, 12):
    paris_rank = patch_mlp_output(layer, clean_cache, corrupted_cache)
    print(f"MLP Layer {layer}: Paris rank {paris_rank}")

In [None]:
def patch_both_head_and_mlp(layer, head, clean_cache, corrupted_cache):
    """Patch one attention head (hook_z) and the MLP output of the same layer together.

    Both replacements come from `clean_cache` at every position; the forward pass
    runs on the notebook-level `corrupted_tokens`. Returns the 1-based rank of the
    Paris token in the patched final-position logits. `corrupted_cache` is not read.
    """
    block_prefix = f"blocks.{layer}"

    def head_patch_hook(activation, hook):
        activation[:, :, head, :] = clean_cache[hook.name][:, :, head, :]
        return activation

    def mlp_patch_hook(activation, hook):
        activation[:, :, :] = clean_cache[hook.name][:, :, :]
        return activation

    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[
            (f"{block_prefix}.attn.hook_z", head_patch_hook),
            (f"{block_prefix}.hook_mlp_out", mlp_patch_hook),
        ],
    )
    final_logits = patched_logits[0, -1]
    return (final_logits.argsort(descending=True) == paris_token).nonzero().item() + 1

# Patch Head 8.11 and the layer-8 MLP at the same time
combined_rank = patch_both_head_and_mlp(
    layer=8, head=11, clean_cache=clean_cache, corrupted_cache=corrupted_cache
)
print(f"Layer 8 - Head 8.11 + MLP together: Paris rank {combined_rank}")
# If both did the same job, combining them would add little; an improvement over
# either alone suggests they contribute complementary pieces.

In [None]:
# Test if Layer 9 or 10 also have attention+MLP cooperation
promising_rank_threshold = 50  # only report combinations ranking Paris better than this
candidate_heads = [0, 6, 11]   # a few different heads to try

for test_layer in [9, 10]:
    print(f"\nTesting Layer {test_layer} combinations:")

    # Test MLP alone
    mlp_rank = patch_mlp_output(test_layer, clean_cache, corrupted_cache)
    print(f"MLP Layer {test_layer}: Paris rank {mlp_rank}")

    # Test each candidate head together with the MLP
    n_reported = 0
    for head in candidate_heads:
        combined_rank = patch_both_head_and_mlp(test_layer, head, clean_cache, corrupted_cache)
        if combined_rank < promising_rank_threshold:
            print(f"  Head {head} + MLP: Paris rank {combined_rank}")
            n_reported += 1
    if n_reported == 0:
        # Say so explicitly instead of printing nothing
        print(f"  No tested head + MLP combination reached rank < {promising_rank_threshold}")

In [None]:
# Try patching multiple layers and heads then to test for a Path
def patch_multi_layer_circuit(clean_cache, corrupted_cache):
    """Patch the candidate factual-recall circuit (all positions) and return Paris's 1-based rank.

    Restored from `clean_cache`: head 8.11 and head 10.0 outputs (hook_z) and the
    full MLP outputs of layers 8 and 10. The forward pass runs on the
    notebook-level `corrupted_tokens`; `corrupted_cache` is not read.
    """
    def make_head_patch(head_index):
        # Build a hook that overwrites a single head's hook_z output
        def head_patch(activation, hook):
            activation[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
            return activation
        return head_patch

    def full_mlp_patch(activation, hook):
        activation[:, :, :] = clean_cache[hook.name][:, :, :]
        return activation

    circuit_hooks = [
        ("blocks.8.attn.hook_z", make_head_patch(11)),
        ("blocks.8.hook_mlp_out", full_mlp_patch),
        ("blocks.10.attn.hook_z", make_head_patch(0)),
        ("blocks.10.hook_mlp_out", full_mlp_patch),
    ]
    patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=circuit_hooks)

    descending_ids = patched_logits[0, -1].argsort(descending=True)
    return (descending_ids == paris_token).nonzero().item() + 1

# Run the combined circuit patch
full_circuit_rank = patch_multi_layer_circuit(clean_cache=clean_cache, corrupted_cache=corrupted_cache)
circuit_description = "Layer 8: Head 11 + MLP, Layer 10: Head 0 + MLP"
print(f"Full circuit ({circuit_description}): Paris rank {full_circuit_rank}")

In [None]:
# The factual recall circuit above is not stored in one place: the computation is distributed, and different layers play different roles

## Testing different types of facts

In [None]:
def patch_multi_layer_circuit_flexible(clean_cache, corrupted_cache, target_token_id, tokens=None):
    """Patch the circuit at the last position only and return the target token's 1-based rank.

    Restores head 8.11, MLP 8, head 10.0 and MLP 10 from `clean_cache` at the final
    position only, so the clean and corrupted prompts may have different lengths.

    NOTE: a later cell redefines `patch_multi_layer_circuit_flexible` with the
    signature (clean_cache, corrupted_cache, corrupted_tokens) that returns logits;
    after that cell runs, this version is no longer reachable by name.

    Args:
        clean_cache: Cache of the clean run; source of the patched activations.
        corrupted_cache: Not read; kept for signature compatibility.
        target_token_id: Token whose rank is reported.
        tokens: Corrupted token ids to run. Defaults to the notebook-level
            `corrupted_tokens`.
    """
    run_tokens = corrupted_tokens if tokens is None else tokens

    def head_8_patch(activation, hook):
        # Only patch the last position (where prediction happens)
        activation[:, -1, 11, :] = clean_cache[hook.name][:, -1, 11, :]
        return activation

    def mlp_8_patch(activation, hook):
        activation[:, -1, :] = clean_cache[hook.name][:, -1, :]
        return activation

    def head_10_patch(activation, hook):
        activation[:, -1, 0, :] = clean_cache[hook.name][:, -1, 0, :]
        return activation

    def mlp_10_patch(activation, hook):
        activation[:, -1, :] = clean_cache[hook.name][:, -1, :]
        return activation

    patched_logits = model.run_with_hooks(
        run_tokens,
        fwd_hooks=[
            ("blocks.8.attn.hook_z", head_8_patch),
            ("blocks.8.hook_mlp_out", mlp_8_patch),
            ("blocks.10.attn.hook_z", head_10_patch),
            ("blocks.10.hook_mlp_out", mlp_10_patch)
        ]
    )

    # Get ranking for the target token
    target_rank = (patched_logits[0, -1].argsort(descending=True) == target_token_id).nonzero().item() + 1
    return target_rank

In [None]:
# Test different types of facts
# (clean prompt, corrupted prompt, expected clean answer) triples
fact_tests = [
    # Geographical (control)
    (
        "The capital of Germany is",
        "The capital of Italy is",
        " Berlin",
    ),
    # Book-author relationships
    (
        "The author of Harry Potter is",
        "The author of Lord of the Rings is",
        " Rowling",
    ),
    # Company-founder relationships
    (
        "The founder of Microsoft is",
        "The founder of Apple is",
        " Gates",
    ),
    # Sports results
    (
        "The winner of the 2020 Olympics marathon was",
        "The winner of the 2016 Olympics marathon was",
        " Kipchoge",
    ),
    # Scientific discoveries
    (
        "The discoverer of penicillin was",
        "The discoverer of DNA structure was",
        " Fleming",
    ),
]

def patch_multi_layer_circuit_flexible(clean_cache, corrupted_cache, corrupted_tokens):
    """Run `corrupted_tokens` with the circuit patched at the final position; return patched logits.

    Restored from `clean_cache` at the last position only (so prompt lengths may
    differ): head 8.11 and head 10.0 outputs (hook_z) and the layer-8 and layer-10
    MLP outputs. `corrupted_cache` is accepted for signature compatibility but not read.
    """
    def last_position_head(head_index):
        # Hook that overwrites one head's hook_z output at the final position
        def head_patch(activation, hook):
            activation[:, -1, head_index, :] = clean_cache[hook.name][:, -1, head_index, :]
            return activation
        return head_patch

    def last_position_mlp(activation, hook):
        activation[:, -1, :] = clean_cache[hook.name][:, -1, :]
        return activation

    circuit_hooks = [
        ("blocks.8.attn.hook_z", last_position_head(11)),
        ("blocks.8.hook_mlp_out", last_position_mlp),
        ("blocks.10.attn.hook_z", last_position_head(0)),
        ("blocks.10.hook_mlp_out", last_position_mlp),
    ]
    return model.run_with_hooks(corrupted_tokens, fwd_hooks=circuit_hooks)

# Run every fact test: clean vs corrupted vs circuit-patched predictions.
def _final_position_rank(run_logits, token_id):
    """1-based rank of `token_id` among the final-position logits of `run_logits`."""
    descending_ids = run_logits[0, -1].argsort(descending=True)
    return (descending_ids == token_id).nonzero().item() + 1


for clean_prompt, corrupted_prompt, expected_answer in fact_tests:
    print(f"\n=== Testing: {clean_prompt} ===")

    clean_tokens = model.to_tokens(clean_prompt, prepend_bos=False)
    corrupted_tokens = model.to_tokens(corrupted_prompt, prepend_bos=False)

    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

    clean_pred = model.to_string(clean_logits[0, -1].argmax())
    corrupted_pred = model.to_string(corrupted_logits[0, -1].argmax())
    patched_logits = patch_multi_layer_circuit_flexible(clean_cache, corrupted_cache, corrupted_tokens)
    patched_pred = model.to_string(patched_logits[0, -1].argmax())

    # The expected answer may span several tokens; score it by its first token
    expected_first_token = model.to_tokens(expected_answer, prepend_bos=False)[0, 0].item()

    print(f"Expected: {expected_answer}")
    print(f"Clean prediction: {clean_pred}")
    print(f"Corrupted prediction: {corrupted_pred}")
    print(f"Circuit patched prediction: {patched_pred}")

    # If the corrupted run already predicts the clean token, a match after
    # patching says nothing about the circuit.
    already_matched = corrupted_pred == clean_pred
    print(f"Circuit helped: {patched_pred == clean_pred and not already_matched}")
    if already_matched:
        print("  (corrupted prediction already equals clean prediction; no effect measurable)")

    print(
        "Expected-answer first-token rank - "
        f"Clean: {_final_position_rank(clean_logits, expected_first_token)}, "
        f"Corrupted: {_final_position_rank(corrupted_logits, expected_first_token)}, "
        f"Patched: {_final_position_rank(patched_logits, expected_first_token)}"
    )