<a href="https://colab.research.google.com/github/mahadikprasad15/ARENA/blob/main/Pythia-160M%20Induction%20Circuits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install transformer_lens

In [None]:
import torch
import transformer_lens
import plotly.express as px
from transformer_lens import utils
import tqdm
from functools import partial

In [None]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Inference-only notebook: switch autograd off globally so forward passes
# (ablation loops, patching sweeps) do not keep a computation graph attached to
# the returned logits, which otherwise holds extra GPU memory for as long as
# those logits live. Re-enable with torch.set_grad_enabled(True) if a cell
# ever needs gradients.
torch.set_grad_enabled(False)

In [None]:
# Load Pythia-160M as a HookedTransformer so every activation is hookable.
# NOTE(review): no device= argument is passed, so placement is left to the
# TransformerLens default — confirm it matches `device` above.
model = transformer_lens.HookedTransformer.from_pretrained('pythia-160m')
# Put the module in inference (eval) mode.
model.eval()

In [None]:
# Peek at how the tokenizer splits an example sentence into string tokens.
model.to_str_tokens('This is a test example, to see the tokenization')

In [None]:
# Architecture sizes; n_layers / n_heads drive the per-head loops below.
print(f'Number of layers: {model.cfg.n_layers}')
print(f'Number of heads: {model.cfg.n_heads}')
print(f'Model residual stream dimension: {model.cfg.d_model}')
print(f'Model vocab-size: {model.cfg.d_vocab}')
print(f'Dimension of the heads: {model.cfg.d_head}')

# Generating induction prompts

First, I need a function that generates induction prompts. It takes `batch` and `seq_length`, samples a random token sequence of length `seq_length`, repeats it twice, and prepends a BOS token, so each prompt has `2 * seq_length + 1` tokens.
This lets us generate induction prompts of any length and batch size to test on.


In [None]:
def generate_induction_prompts(batch = 1, seq_length = 20, d_vocab = None, bos_token_id = 0):
  """Build repeated-random-token prompts for probing induction behaviour.

  Each row is ``[BOS, x_1 .. x_n, x_1 .. x_n]``, where ``x`` holds
  ``seq_length`` token ids drawn uniformly from ``[1, d_vocab)``. The result
  has shape ``(batch, 2 * seq_length + 1)`` and dtype ``torch.long``.

  Args:
    batch: number of independent prompts.
    seq_length: length of the random sequence that is repeated.
    d_vocab: exclusive upper bound for sampled ids. Defaults to the global
      ``model.cfg.d_vocab``, so existing calls behave exactly as before.
    bos_token_id: id placed at position 0 (0 by default, matching the
      previously hard-coded zeros).

  Returns:
    A CPU ``torch.long`` tensor; callers move it to the model's device.
  """
  if d_vocab is None:
    d_vocab = model.cfg.d_vocab
  # Sample from 1 upwards so id 0 (the default BOS id) never appears in the body.
  tokens = torch.randint(1, d_vocab, (batch, seq_length), dtype = torch.long)
  bos = torch.full((batch, 1), bos_token_id, dtype = torch.long)
  return torch.cat([bos, tokens, tokens], dim = -1)

Now that we have the function, let's first look at the attention patterns of all heads and see which ones show a distinctive induction pattern.
I also want to look for previous-token heads.

In [None]:
# Cache every activation for 5 repeated-token prompts, then plot each head's
# batch-averaged attention pattern (one figure per layer, one facet per head).
layers = model.cfg.n_layers
heads = model.cfg.n_heads
prompt_tokens = generate_induction_prompts(5,20)
logits, cache = model.run_with_cache(prompt_tokens)


for layer in range(layers):
    # [batch, head, query, key]; averaged over the batch axis below
    attention_pattern = cache[utils.get_act_name('pattern', layer)].cpu().numpy()
    fig = px.imshow(attention_pattern.mean(axis=0),
                    facet_col=0,
                    title=f'Attention Patterns: Layer {layer}',
                    labels={'x': 'Key Token', 'y': 'Query Token'},
                    color_continuous_scale='rdpu',
                    width=3000,
                    height=2500
                   )
    fig.show()

To pick out clearly strong heads, define a function that computes these metrics for every head in every layer, so the largest scores stand out.



In [None]:
def induction_score(attention_pattern, prompt_tokens):
  """Mean attention on the induction stripe of a repeated-token prompt.

  For prompts laid out as ``[BOS, x_1..x_n, x_1..x_n]`` (length ``2n + 1``),
  the query at second-half position ``t`` holds the same token as position
  ``t - n``. An induction head attends to the token *after* that earlier copy,
  i.e. key ``t - n + 1``. That stripe is the diagonal at offset ``-(n - 1)``,
  and its last ``n`` entries are exactly the second-half queries ``n+1 .. 2n``.

  Args:
    attention_pattern: 2-D ``[query_pos, key_pos]`` attention tensor.
    prompt_tokens: the prompt batch; only ``size(1)`` is used.

  Returns:
    0-d tensor holding the average attention weight on the stripe.
  """
  seq_len = prompt_tokens.size(1) // 2
  # Offset -seq_len would select the earlier copy of the *same* token
  # (duplicate-token attention), not the induction target.
  stripe = attention_pattern.diagonal(offset = -(seq_len - 1))
  # Drop the two leading entries whose queries come before the repeat starts.
  return stripe[-seq_len:].mean()



# Induction score of every (layer, head), computed from the batch-averaged
# pattern of the prompts cached above.
scores = torch.zeros(layers, heads, device = device)


for layer in range(layers):
  attention_pattern_layer = cache[utils.get_act_name('pattern', layer)].mean(dim=0)  # [head, query, key]
  for head in range(heads):

    attention_pattern_head = attention_pattern_layer[head, :, :]
    score = induction_score(attention_pattern_head, prompt_tokens=prompt_tokens)
    scores[layer, head] = score

In [None]:
# Heatmap of induction scores: rows are layers, columns are heads.
fig = px.imshow(scores.cpu().numpy(),
          title = 'Induction scores for all Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'rdpu',
          text_auto = '.2f',
          height = 700,
          width = 700,
            )

fig.show()

## Getting top previous token heads

In [None]:
def previous_token_score(attention_pattern, prompt_tokens):
  """Mean attention each query pays to the immediately preceding position.

  Args:
    attention_pattern: 2-D ``[query_pos, key_pos]`` attention tensor.
    prompt_tokens: accepted for signature parity with ``induction_score``.
      It is not used, because the previous-token stripe does not depend on
      the prompt layout.

  Returns:
    0-d tensor: mean of the sub-diagonal (offset -1), i.e. query ``t`` ->
    key ``t - 1`` for every ``t >= 1``.
  """
  return attention_pattern.diagonal(offset = -1).mean()


# Previous-token score of every (layer, head). Note: this reuses the name
# `scores`, overwriting the induction scores computed earlier.
scores = torch.zeros(layers, heads, device = device)

for layer in range(layers):
  attention_pattern_layer = cache[utils.get_act_name('pattern', layer)].mean(dim=0)  # [head, query, key]
  for head in range(heads):

    attention_pattern_head = attention_pattern_layer[head, :, :]
    score = previous_token_score(attention_pattern_head, prompt_tokens=prompt_tokens)
    scores[layer, head] = score

In [None]:
# Heatmap of previous-token scores: rows are layers, columns are heads.
fig = px.imshow(scores.cpu().numpy(),
          title = 'Previous token scores for all Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'rdpu',
          text_auto = '.2f',
          height = 700,
          width = 700,
            )

fig.show()

In [None]:
# Accumulated residual stream per layer, stacked on a new leading axis, with
# apply_ln=True so it can be read through W_U directly (logit lens below).
accumulated_resid = cache.accumulated_resid(layer = -1, apply_ln = True)

Next, I use `accumulated_resid` and related `ActivationCache` methods to get the residual stream per layer, per head, and so on, apply the final layer norm, and then multiply by `W_U` to get the logits each layer or head contributes.

But this will be

In [None]:
# Logit lens: the correct-next-token logit read from each accumulated residual
# stream over the repeated half. Position p predicts token p + 1, so positions
# half .. 2*half - 1 are paired with tokens half + 1 .. 2*half.
half = prompt_tokens.size(1) // 2
target_token_indices = prompt_tokens[:, half + 1:].to(device)  # [batch, half]

# Dotting with the target columns of W_U gives the same numbers as gathering
# from `resid @ W_U`, without materialising a [component, batch, pos, d_vocab]
# logit tensor.
target_directions = model.W_U[:, target_token_indices]  # [d_model, batch, half]
logits_for_target_tokens = torch.einsum(
    'cbpd,dbp->cbp',
    accumulated_resid[:, :, half:-1].to(device),
    target_directions,
)  # [component, batch, half]

logits_per_layer = logits_for_target_tokens.mean(dim = (-1, -2))

In [None]:
# Plot the mean correct-token logit per accumulated-residual component
# (index 0 is the earliest component, the last is the final residual stream).
import pandas as pd

df = pd.DataFrame({
    'Layer': range(len(logits_per_layer)),
    'Average Logit': logits_per_layer.detach().cpu().numpy()
})

# Create the line plot
fig = px.line(df,
              x='Layer',
              y='Average Logit',
              title='Overall Average Logit for Target Tokens Across Layers')

fig.show()

In [None]:
# Every head's output stacked on a leading axis, with apply_ln=True.
# NOTE(review): the later reshape(-1, heads) assumes layer-major ordering — confirm.
per_head_resid = cache.stack_head_results( layer = -1, apply_ln=True)

In [None]:
# Direct logit attribution per head: the logit each head's output adds to the
# correct next token over the repeated half. Position p predicts token p + 1,
# so positions half .. 2*half - 1 are paired with tokens half + 1 .. 2*half.
half = prompt_tokens.size(1) // 2
target_token_indices = prompt_tokens[:, half + 1:].to(device)  # [batch, half]

# Same numbers as gathering from `per_head_resid @ W_U`, but never builds the
# [n_heads_total, batch, pos, d_vocab] logit tensor, which with one slice per
# head is very large.
target_directions = model.W_U[:, target_token_indices]  # [d_model, batch, half]
logits_for_target_tokens = torch.einsum(
    'cbpd,dbp->cbp',
    per_head_resid[:, :, half:-1].to(device),
    target_directions,
)  # [n_heads_total, batch, half]



In [None]:
# Average over batch and position, then lay out as [layer, head]
# (assumes the stacked heads are ordered layer-major — TODO confirm).
logits_per_head = logits_for_target_tokens.mean(dim = (-1,-2)).reshape(-1, heads)

In [None]:
# Heatmap of each head's average direct contribution to the correct-token logit.
fig = px.imshow(logits_per_head.cpu().detach().numpy(),
          title = 'Overall Average Logit for Target Tokens Across Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'dense',
          text_auto = '.2f',
          height = 700,
          width = 700,)

fig.show()

In [None]:
# Contribution of each residual-stream component returned by
# decompose_resid(mode='all') to the correct next-token logit over the
# repeated half.
between_layers_logits = cache.decompose_resid(layer = -1, apply_ln = True, mode = 'all')
half = prompt_tokens.size(1) // 2
# Position p predicts token p + 1, so the targets are tokens half + 1 .. 2*half
# (pairing positions half .. 2*half - 1 with their *own* tokens is off by one).
index = prompt_tokens[:, half + 1:].to(device)  # [batch, half]
target_directions = model.W_U[:, index]  # [d_model, batch, half]
# Equivalent to gathering from `resid @ W_U` without the full-vocabulary product.
layer_outputs = torch.einsum(
    'cbpd,dbp->cbp',
    between_layers_logits[:, :, half:-1].to(device),
    target_directions,
)  # [component, batch, half]
layer_outputs.mean(dim = (1,2))

In [None]:
from transformer_lens import patching

In [None]:
def corrupt_induction_prompt(clean_tokens, d_vocab = None):
  """Return a copy of ``clean_tokens`` with the first repeat re-randomised.

  For prompts laid out as ``[BOS, x_1..x_n, x_1..x_n]``, positions ``1 .. n``
  are overwritten with fresh random ids in ``[1, d_vocab)``. BOS and the second
  half are left untouched, so the second half no longer repeats anything
  earlier in the sequence.

  Args:
    clean_tokens: ``[batch, 2n + 1]`` long tensor; it is not modified.
    d_vocab: exclusive upper bound for sampled ids. Defaults to the global
      ``model.cfg.d_vocab``, as before.

  Returns:
    A new tensor on the same device as ``clean_tokens``.
  """
  if d_vocab is None:
    d_vocab = model.cfg.d_vocab
  corrupt = clean_tokens.clone()
  half = corrupt.size(1) // 2
  replacement = torch.randint(1, d_vocab, (corrupt.size(0), half))
  # Move the sample explicitly onto the prompt's device before writing it in.
  corrupt[:, 1:half + 1] = replacement.to(corrupt.device)
  return corrupt


# Clean prompts are the batch cached above; corrupted prompts re-randomise the
# first repeat so the second half has nothing to copy from.
clean_tokens = prompt_tokens.to(device)
corrupt_tokens = corrupt_induction_prompt(prompt_tokens).to(device)

In [None]:
# Baseline logits on the clean and corrupted prompts.
clean_logits = model(clean_tokens)
corrupt_logits = model(corrupt_tokens)

In [None]:
def calculate_correct_logits(logits, tokens):
  """Average logit given to the correct next token over the repeated half.

  With prompts ``[BOS, x_1..x_n, x_1..x_n]`` (length ``2n + 1``), logits at
  positions ``n .. 2n - 1`` predict ``x_1 .. x_n``. The targets are read from
  the first half of ``tokens`` (positions ``1 .. n``), so any logits can be
  scored against the clean token ids.

  Returns:
    A Python float.
  """
  first_pred = logits.size(1) // 2
  predicting = logits[:, first_pred:-1, :]
  targets = tokens[:, 1:tokens.size(1) // 2 + 1].unsqueeze(-1)
  picked = predicting.gather(-1, targets)
  return picked.mean().item()


# Correct-token scores that anchor the patching metric; both are scored
# against the clean targets.
clean_score = calculate_correct_logits(clean_logits, clean_tokens)
corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)

In [None]:
def patching_metric(ablated_logits):
  """Map a run's correct-token score onto a scale where corrupt = 0 and clean = 1.

  Relies on the notebook globals ``clean_tokens``, ``clean_score`` and
  ``corrupt_score``. The value is wrapped in ``torch.tensor`` so callers can
  call ``.item()`` on it.
  """
  gap_recovered = calculate_correct_logits(ablated_logits, clean_tokens) - corrupt_score
  full_gap = clean_score - corrupt_score
  return torch.tensor(gap_recovered / full_gap)

In [None]:
def zero_ablation(tensor, hook, head):
  """Hook returning a copy of ``tensor`` with one head's slice set to zero.

  Indexes the third axis as the head axis, matching the ``z`` hook points it
  is attached to here. ``hook`` is unused but required by the hook signature.
  """
  ablated = tensor.clone()
  ablated[:, :, head, :] = 0
  return ablated

def mean_ablation(tensor, hook, head):
  """Hook replacing one head's activation with its mean over the batch.

  Indexes the third axis as the head axis, matching the ``z`` hook points it
  is attached to here. Each position keeps its own mean vector, averaged over
  the batch axis. A single scalar such as ``tensor.mean()`` would average over
  every head, position and feature and would not represent this head's
  typical output. ``hook`` is unused but required by the hook signature.
  """
  target = tensor.clone()
  target[:, :, head, :] = tensor[:, :, head, :].mean(dim = 0, keepdim = True)
  return target

In [None]:
# Zero Ablation on heads
# For every head, zero its z on the clean prompts. The stored value is
# 1 - patching_metric: roughly the fraction of the clean-vs-corrupt gap that
# is lost without that head.

results = torch.zeros(layers, heads, device = device)

for layer in tqdm.tqdm(range(layers)):
  for head in (range(heads)):

    hook_function = partial(zero_ablation, head = head)

    ablated_logits = model.run_with_hooks(clean_tokens, fwd_hooks = [(utils.get_act_name('z', layer), hook_function)])

    result = 1 - patching_metric(ablated_logits)

    results[layer, head] = result



In [None]:
# Heatmap of the zero-ablation effect per head (rows = layers, columns = heads).
fig = px.imshow(results.cpu().detach().numpy(),
                title = 'Results of ablating z for each head',
                labels = {'x': 'Heads', 'y': 'Layers'}, # Corrected labels based on tensor dimensions
                color_continuous_scale = 'pubu',
                height = 700,
                width = 700,
                text_auto = '.2f'
          )

fig.show()

In [None]:
# Mean ablation on heads
# Same sweep as the zero-ablation cell, but each head's z is replaced via
# `mean_ablation`. The stored value is 1 - patching_metric.

results = torch.zeros(layers, heads, device = device)

for layer in tqdm.tqdm(range(layers)):
  for head in (range(heads)):

    hook_function = partial(mean_ablation, head = head)

    ablated_logits = model.run_with_hooks(clean_tokens, fwd_hooks = [(utils.get_act_name('z', layer), hook_function)])

    result = 1- patching_metric(ablated_logits)

    results[layer, head] = result



In [None]:
# Heatmap of the mean-ablation effect per head (rows = layers, columns = heads).
fig = px.imshow(results.cpu().detach().numpy(),
                title = 'Results of mean ablating z for each head',
                labels = {'x': 'Heads', 'y': 'Layers'}, # Corrected labels based on tensor dimensions
                color_continuous_scale = 'dense',
                aspect = 'auto',
                text_auto = '.2f'
          )

fig.show()

In [None]:
# Full activation cache of the clean run; the patching helpers below copy
# clean activations from it into corrupted runs.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

In [None]:
# Patch clean resid_pre into the corrupted run and score each patch with
# patching_metric (presumably one value per layer and position — confirm shape).
resid_pre_patching = patching.get_act_patch_resid_pre(model = model, corrupted_tokens= corrupt_tokens, clean_cache = clean_cache, patching_metric = patching_metric)

In [None]:
# Diverging heatmap of the resid_pre patching results, centred at 0.
fig = px.imshow(resid_pre_patching.cpu().detach().numpy(),
                title = 'Patching Residual Stream Before Block (resid_pre)',
                labels = {'x':'Token Position' , 'y': 'Layer' },
                color_continuous_scale = 'RdBu',
                color_continuous_midpoint=0,
                aspect="auto"
               )

fig.update_layout(coloraxis_colorbar_title="Patching Metric")

fig.show()

In [None]:
# Patch each head's output / q / k / v / pattern from the clean cache into the
# corrupted run, all positions at once; the leading axis indexes those five kinds.
heads_patching = patching.get_act_patch_attn_head_all_pos_every(model = model, corrupted_tokens= corrupt_tokens, clean_cache = clean_cache, metric = patching_metric) # Corrected keyword argument name to 'metric'

In [None]:

# One facet per patched activation kind.
# NOTE(review): the renaming loop assumes facet annotations appear in the same
# order as `facet_labels` — confirm against the helper's output order.
facet_labels = ['Output', 'Query', 'Key', 'Value', 'Pattern']

fig = px.imshow(heads_patching.cpu().detach().numpy(),
                title = 'Patching Heads',
                labels = {'x':'Heads' , 'y': 'Layer' },
                color_continuous_scale = 'RdBu',
                color_continuous_midpoint=0,
                aspect="auto",
                facet_col= 0,
                facet_col_wrap=3,
               )


# Replace the default facet titles with the activation names.
for i, label in enumerate(facet_labels):

    annotation_name = f'annotations[{i}]'
    fig.layout[annotation_name]['text'] = label


fig.update_layout(coloraxis_colorbar_title="Patching Metric")

fig.show()

In [None]:
# Memory cleanup cell - run this first!
import gc
import torch

# Delete any existing large variables, one name at a time: a single
# `del a, b, c` raises at the first undefined name and skips the rest, so one
# missing variable (e.g. `corrupt_cache`, which only exists inside functions)
# would keep every later name (such as `clean_logits`) alive.
for _name in ('cache', 'clean_cache', 'corrupt_cache', 'patched_cache',
              'clean_logits', 'corrupt_logits', 'patched_logits',
              'clean_attn', 'patched_attn'):
    globals().pop(_name, None)
del _name

# Collect first so tensors kept alive only by reference cycles are freed, then
# return PyTorch's now-unused cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

# Check available memory
print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

In [None]:
def test_path_patching_3_to_5_efficient():
    """
    Corrupt head 3.3's output on the first half of the prompt and measure how
    head 5.0's attention pattern and the correct-token score change.

    Memory-efficient: each forward pass caches only the one activation it
    needs, and kept attention patterns are moved to the CPU straight away.

    Note: the hook replaces 3.3's ``z`` for the rest of the forward pass, so
    the change seen at 5.0 includes routes through layer 4, not only the
    direct 3.3 -> 5.0 path.

    Returns:
        (clean_attn_5_0, patched_attn_5_0): CPU tensors with head 5.0's pattern
        for the clean and patched runs, [batch, query_pos, key_pos].
    """
    pattern_5_name = utils.get_act_name('pattern', 5)
    z_3_name = utils.get_act_name('z', 3)

    # Use smaller batch size
    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)  # Reduced from 8
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    print("Running clean forward pass...")
    # Cache only layer 5's pattern instead of every activation.
    clean_logits, clean_cache = model.run_with_cache(
        clean_tokens,
        names_filter=lambda name: name == pattern_5_name
    )
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    clean_attn_5_0 = clean_cache['pattern', 5][:, 0, :, :].clone().cpu()  # Move to CPU immediately

    del clean_cache, clean_logits
    torch.cuda.empty_cache()

    print("Running corrupt forward pass...")
    # Cache only layer 3's z, which is all the patch needs.
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens,
        names_filter=lambda name: name == z_3_name
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    corrupt_z_3_3 = corrupt_cache['z', 3][:, :, 3, :].clone()  # Only head 3

    del corrupt_cache, corrupt_logits
    torch.cuda.empty_cache()

    print("Running patched forward pass...")

    # Patch positions 1 .. seq_len (the first repeat, excluding BOS).
    seq_len_total = clean_tokens.size(1)
    first_half_end = seq_len_total // 2 + 1

    def corrupt_head_3_3_output(activation, hook):
        activation[:, 1:first_half_end, 3, :] = corrupt_z_3_3[:, 1:first_half_end, :]
        return activation

    model.reset_hooks()
    model.add_hook(z_3_name, corrupt_head_3_3_output)
    try:
        patched_logits, patched_cache = model.run_with_cache(
            clean_tokens,
            names_filter=lambda name: name == pattern_5_name
        )
    finally:
        # Always detach the hook, even if the forward pass fails; otherwise it
        # would keep altering every later call on `model`.
        model.reset_hooks()

    patched_score = calculate_correct_logits(patched_logits, clean_tokens)
    patched_attn_5_0 = patched_cache['pattern', 5][:, 0, :, :].clone().cpu()

    del patched_cache, patched_logits, corrupt_z_3_3
    torch.cuda.empty_cache()

    # Mean weight per (query, key) cell from second-half queries onto the
    # first-repeat tokens (positions 1 .. seq_len), on CPU tensors.
    seq_len = clean_tokens.size(1) // 2
    clean_attn_to_sources = clean_attn_5_0[:, seq_len+1:, 1:seq_len+1].mean()
    patched_attn_to_sources = patched_attn_5_0[:, seq_len+1:, 1:seq_len+1].mean()

    print(f"\n=== RESULTS ===")
    print(f"Clean attention to sources: {clean_attn_to_sources:.3f}")
    print(f"Patched attention to sources: {patched_attn_to_sources:.3f}")
    print(f"Attention drop: {(clean_attn_to_sources - patched_attn_to_sources):.3f}")
    print(f"Relative drop: {((clean_attn_to_sources - patched_attn_to_sources) / clean_attn_to_sources * 100):.1f}%")

    patching_metric_value = (patched_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"\nPerformance scores:")
    print(f"Clean score: {clean_score:.3f}")
    print(f"Corrupt score: {corrupt_score:.3f}")
    print(f"Patched score: {patched_score:.3f}")
    print(f"Patching metric: {patching_metric_value:.3f}")
    print(f"(< 0.5 suggests strong circuit dependency)")

    return clean_attn_5_0, patched_attn_5_0

# Clear memory first: collect, then release PyTorch's cached CUDA blocks
import gc
gc.collect()
torch.cuda.empty_cache()

# Run the efficient version
clean_attn, patched_attn = test_path_patching_3_to_5_efficient()

# Visualize (attention tensors are now on CPU, small size)
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=1, cols=3,
                    subplot_titles=('Clean Attention (5.0)',
                                   'Patched Attention (5.0)',
                                   'Difference'))

# Only the first prompt of the batch is plotted.
fig.add_trace(
    go.Heatmap(z=clean_attn[0].numpy(), colorscale='Blues'),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(z=patched_attn[0].numpy(), colorscale='Blues'),
    row=1, col=2
)

fig.add_trace(
    go.Heatmap(z=(clean_attn[0] - patched_attn[0]).numpy(),
               colorscale='RdBu', zmid=0),
    row=1, col=3
)

fig.update_layout(height=400, width=1400,
                  title_text="Effect of Corrupting 3.3 on 5.0's Attention")

# Add midpoint lines. Take the length from the returned patterns, not the
# global `clean_tokens`: that is a different prompt batch and may have been
# changed or deleted by other cells.
seq_len = clean_attn.size(-1) // 2
for col in [1, 2, 3]:
    fig.add_vline(x=seq_len, line_dash="dash", line_color="red",
                  row=1, col=col)
    fig.add_hline(y=seq_len, line_dash="dash", line_color="red",
                  row=1, col=col)

fig.show()

# Clean up after visualization
del clean_attn, patched_attn
torch.cuda.empty_cache()

In [None]:
def analyze_head_attention(layer, head, head_name=""):
    """
    Summarise and plot where one attention head looks on induction prompts.

    Runs 8 fresh ``[BOS, x, x]`` prompts (``x`` has 20 random tokens) and
    averages the head's pattern over the batch. It then prints:
      * how second-half queries split their attention mass between BOS, the
        first half and the second half;
      * the mean and max attention on the induction stripe (query ``t`` ->
        key ``t - seq_len + 1``, the token after the earlier copy);
      * the mean attention to the previous position;
    and finally shows a heatmap of the averaged pattern.

    Args:
        layer: layer index of the head.
        head: head index within the layer.
        head_name: label used in the printed header and plot title.

    Returns:
        The batch-averaged ``[query_pos, key_pos]`` pattern, on the CPU.
    """
    print(f"\nAnalyzing attention pattern for Layer {layer}, Head {head} ({head_name})")
    print("="*60)

    # Generate fresh tokens
    clean_tokens = generate_induction_prompts(batch=8, seq_length=20).to(device)

    # Get attention pattern
    _, cache = model.run_with_cache(clean_tokens)

    # Extract attention for this specific head
    # Shape: [batch, heads, query_pos, key_pos]
    attention_pattern = cache['pattern', layer]
    head_attention = attention_pattern[:, head, :, :]  # [batch, query_pos, key_pos]

    # Average across batch
    avg_attention = head_attention.mean(dim=0)  # [query_pos, key_pos]

    seq_len = clean_tokens.size(1) // 2

    print(f"\nSequence structure: [{seq_len+1} tokens] [{seq_len} tokens]")
    print(f"  First half: positions 0 (BOS) to {seq_len}")
    print(f"  Second half: positions {seq_len+1} to {seq_len*2}")

    # For queries in the second half
    second_half_queries = avg_attention[seq_len+1:, :]  # [seq_len, all_positions]

    # Share of each query's attention mass per region, averaged over queries.
    # Summing over keys first makes the three shares add up to 1 (each row of
    # the softmax pattern sums to 1). A plain mean over keys would report the
    # average weight per key instead, which shrinks with the region's size.
    attn_to_BOS = second_half_queries[:, 0].mean().item()
    attn_to_first_half = second_half_queries[:, 1:seq_len+1].sum(dim=-1).mean().item()
    attn_to_second_half = second_half_queries[:, seq_len+1:].sum(dim=-1).mean().item()

    print(f"\n--- Attention Distribution (from 2nd half queries) ---")
    print(f"To BOS token:       {attn_to_BOS:.3f} ({attn_to_BOS*100:.1f}%)")
    print(f"To first half:      {attn_to_first_half:.3f} ({attn_to_first_half*100:.1f}%)")
    print(f"To second half:     {attn_to_second_half:.3f} ({attn_to_second_half*100:.1f}%)")

    # Induction stripe: the second-half query t holds the same token as position
    # t - seq_len, so an induction head attends to the token right after that
    # earlier copy, key t - seq_len + 1. (Key t - seq_len would measure
    # duplicate-token attention instead.)
    queries = torch.arange(seq_len + 1, avg_attention.shape[0], device=avg_attention.device)
    induction_stripe = avg_attention[queries, queries - seq_len + 1]

    if induction_stripe.numel() > 0:
        avg_diagonal = induction_stripe.mean().item()
        max_diagonal = induction_stripe.max().item()
        print(f"\n--- Induction Pattern (token after previous occurrence) ---")
        print(f"Average attention to induction targets: {avg_diagonal:.3f} ({avg_diagonal*100:.1f}%)")
        print(f"Max attention on induction stripe:      {max_diagonal:.3f} ({max_diagonal*100:.1f}%)")

        if avg_diagonal > 0.3:
            print("  ✓ STRONG induction pattern!")
        elif avg_diagonal > 0.15:
            print("  ~ MODERATE induction pattern")
        else:
            print("  ✗ WEAK induction pattern")

    # Previous-token stripe: query t -> key t - 1 for every t >= 1.
    avg_prev = avg_attention.diagonal(offset=-1).mean().item()
    print(f"\n--- Previous Token Pattern ---")
    print(f"Average attention to previous position: {avg_prev:.3f} ({avg_prev*100:.1f}%)")

    if avg_prev > 0.5:
        print("  ✓ STRONG previous token head!")
    elif avg_prev > 0.3:
        print("  ~ MODERATE previous token head")
    else:
        print("  ✗ NOT a previous token head")

    # Visualize
    fig = px.imshow(
        avg_attention.cpu().numpy(),
        title=f"Attention Pattern: Layer {layer} Head {head} ({head_name})",
        labels={'x': 'Key Position', 'y': 'Query Position'},
        color_continuous_scale='Blues',
        aspect='auto'
    )

    # Add midpoint line
    fig.add_vline(x=seq_len, line_dash="dash", line_color="red",
                  annotation_text="Midpoint")
    fig.add_hline(y=seq_len, line_dash="dash", line_color="red")

    fig.show()

    # Cleanup
    del cache, attention_pattern, head_attention
    torch.cuda.empty_cache()

    return avg_attention.cpu()

# Analyze the heads we're interested in
# NOTE(review): the role labels printed below come from the earlier score
# heatmaps; re-check them whenever the scoring metrics change.
print("ANALYZING HEAD 5.0 (suspected induction head)")
attn_5_0 = analyze_head_attention(5, 0, "5.0")

print("\n" + "="*60)
print("ANALYZING HEAD 8.2 (high scorer)")
attn_8_2 = analyze_head_attention(8, 2, "8.2")

print("\n" + "="*60)
print("ANALYZING HEAD 3.3 (prev token head)")
attn_3_3 = analyze_head_attention(3, 3, "3.3")

print("\n" + "="*60)
print("ANALYZING HEAD 4.11 (also high scorer)")
attn_4_11 = analyze_head_attention(4, 11, "4.11")

In [None]:
# Let's re-examine what the induction score actually measured
def debug_induction_score():
    """
    Compare two readings of the induction stripe for a few heads.

    For heads 3.3, 5.0 and 8.2 on 5 fresh ``[BOS, x, x]`` prompts, prints:
      * the raw *sum* of the diagonal at offset ``-(seq_len - 1)``;
      * the per-position *average* over second-half queries (query ``t`` ->
        key ``t - seq_len + 1``);
      * the ratio of the two.
    Both readings now refer to the same stripe.
    """

    prompt_tokens = generate_induction_prompts(5, 20).to(device)
    _, cache = model.run_with_cache(prompt_tokens)

    print("Debugging induction score metric...")
    print("="*60)

    # Test on a few heads
    test_heads = [(3, 3), (5, 0), (8, 2)]
    seq_len = prompt_tokens.size(1) // 2

    for layer, head in test_heads:
        attention_pattern = cache['pattern', layer][:, head, :, :].mean(dim=0)

        # Raw sum along the whole induction-offset diagonal. It also includes
        # two entries whose queries come before the repeat starts.
        offset = -(seq_len - 1)
        original_score = attention_pattern.diagonal(offset=offset).sum().item()

        # Per-position average over second-half queries t, each looking at
        # key t - seq_len + 1 (the token after the earlier copy). Using key
        # t - seq_len here would measure duplicate-token attention, a
        # different stripe from the sum above.
        diagonal_values = []
        for t in range(seq_len+1, attention_pattern.shape[0]):
            source_pos = t - seq_len + 1
            diagonal_values.append(attention_pattern[t, source_pos].item())

        correct_score = sum(diagonal_values) / len(diagonal_values) if diagonal_values else 0

        print(f"\nLayer {layer} Head {head}:")
        print(f"  Original score (sum):     {original_score:.3f}")
        print(f"  Correct score (average):  {correct_score:.3f} ({correct_score*100:.1f}%)")
        print(f"  Ratio: {original_score / (correct_score if correct_score > 0 else 1):.1f}x")

    del cache
    torch.cuda.empty_cache()

debug_induction_score()

In [None]:
def find_prev_token_heads(n_layers=4):
    """
    Rank the heads of the first ``n_layers`` layers by previous-token attention.

    Args:
        n_layers: number of leading layers to scan. The default of 4 keeps the
            original layers 0-3; values beyond the model's depth are clipped.

    Returns:
        List of ``(layer, head, score)`` sorted by score, descending. The score
        is the batch-averaged attention from each query ``t`` to ``t - 1`` over
        8 fresh induction prompts.
    """

    print("\nSearching for Previous Token Heads in Early Layers")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=8, seq_length=20).to(device)
    _, cache = model.run_with_cache(clean_tokens)

    results = []

    for layer in range(min(n_layers, model.cfg.n_layers)):
        layer_patterns = cache['pattern', layer].mean(dim=0)  # [head, query, key]
        for head in range(model.cfg.n_heads):
            # Mean of the sub-diagonal: query t -> key t - 1 for every t >= 1
            avg_prev = layer_patterns[head].diagonal(offset=-1).mean().item()
            results.append((layer, head, avg_prev))

    # Sort by score
    results.sort(key=lambda x: x[2], reverse=True)

    print("\nTop 10 Previous Token Heads:")
    print(f"{'Layer':<8}{'Head':<8}{'Prev Token %':<15}{'Status'}")
    print("-"*50)

    for layer, head, score in results[:10]:
        status = "✓ STRONG" if score > 0.5 else ("~ MODERATE" if score > 0.3 else "")
        print(f"{layer:<8}{head:<8}{score*100:.1f}%{'':<10}{status}")

    del cache
    torch.cuda.empty_cache()

    return results

prev_token_results = find_prev_token_heads()

In [None]:
def analyze_BOS_heads():
    """
    Direct-logit readout of heads 5.0, 8.2 and 4.11 on induction prompts.

    Each head's output ``z @ W_O`` is projected straight through ``W_U``, with
    no LayerNorm scaling. Over 8 fresh ``[BOS, x, x]`` prompts, it prints:
      * the average logit the head adds to the correct next token on the
        repeated half;
      * the five tokens it boosts most on average there.
    """

    print("\nAnalyzing BOS-Attending Heads (5.0, 8.2, 4.11)")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=8, seq_length=20).to(device)
    _, cache = model.run_with_cache(clean_tokens)

    bos_heads = [(5, 0), (8, 2), (4, 11)]

    seq_len = clean_tokens.size(1) // 2
    # Position p predicts token p + 1, so positions seq_len .. 2*seq_len - 1 are
    # scored against tokens seq_len + 1 .. 2*seq_len (the same alignment as
    # `calculate_correct_logits`). Pairing a position with its own token would
    # measure the wrong thing.
    target_tokens = clean_tokens[:, seq_len+1:]  # [batch, seq_len]

    for layer, head in bos_heads:
        print(f"\n--- Layer {layer} Head {head} ---")

        # Head output at the predicting positions only
        z = cache['z', layer][:, seq_len:-1, head, :]  # [batch, seq_len, d_head]
        head_output = z @ model.W_O[layer, head, :, :]  # [batch, seq_len, d_model]

        # What does it add to logits?
        second_half_logits = head_output @ model.W_U  # [batch, seq_len, vocab]

        # Logit contributed to each correct next token, in one gather
        correct_token_logits = second_half_logits.gather(
            dim=-1, index=target_tokens.unsqueeze(-1)
        ).squeeze(-1)
        avg_correct_logit = correct_token_logits.mean().item()

        print(f"  Average logit boost to CORRECT token: {avg_correct_logit:.3f}")

        # What's the top token it boosts overall?
        top_logit_boost = second_half_logits.mean(dim=(0,1)).topk(5)
        print(f"  Top 5 tokens it boosts:")
        for val, idx in zip(top_logit_boost.values, top_logit_boost.indices):
            token = model.to_string(idx.item())
            # Explicit sign: even the "top" values can be negative.
            print(f"    {token:20s} ({val.item():+.2f})")

    del cache
    torch.cuda.empty_cache()

analyze_BOS_heads()


In [None]:
def test_circuit_3_2_to_3_3():
    """
    Corrupt head 3.2's output on the first half; inspect head 3.3 and the score.

    Caveat: 3.2 and 3.3 sit in the same layer, so both read layer 3's
    ``resid_pre``, and 3.3's pattern is computed before 3.2's ``z`` is written.
    Patching 3.2 therefore cannot change 3.3's attention; expect a drop of ~0.
    Any change in the patching metric reflects 3.2's effect on later layers and
    the logits, not a 3.2 -> 3.3 path.

    Returns:
        (clean_attn_3_3, patched_attn_3_3): CPU tensors with head 3.3's
        pattern, [batch, query_pos, key_pos].
    """

    print("Testing 3.2 → 3.3 circuit (same layer)")
    print("="*60)

    pattern_3_name = utils.get_act_name('pattern', 3)
    z_3_name = utils.get_act_name('z', 3)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    # Clean baseline (cache only the layer-3 pattern)
    clean_logits, clean_cache = model.run_with_cache(
        clean_tokens, names_filter=lambda name: name == pattern_3_name
    )
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    clean_attn_3_3 = clean_cache['pattern', 3][:, 3, :, :].clone().cpu()

    del clean_cache, clean_logits
    torch.cuda.empty_cache()

    # Corrupt baseline (cache only layer-3 z)
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens, names_filter=lambda name: name == z_3_name
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)

    # Save corrupted 3.2 output
    corrupt_z_3_2 = corrupt_cache['z', 3][:, :, 2, :].clone()

    del corrupt_cache, corrupt_logits
    torch.cuda.empty_cache()

    # Patch: Replace 3.2's output with corrupted version
    seq_len = clean_tokens.size(1) // 2

    def corrupt_3_2_output(activation, hook):
        # Patch first half where source tokens are
        activation[:, 1:seq_len+1, 2, :] = corrupt_z_3_2[:, 1:seq_len+1, :]
        return activation

    model.reset_hooks()
    model.add_hook(z_3_name, corrupt_3_2_output)
    try:
        patched_logits, patched_cache = model.run_with_cache(
            clean_tokens, names_filter=lambda name: name == pattern_3_name
        )
    finally:
        model.reset_hooks()  # never leave the patching hook attached

    patched_score = calculate_correct_logits(patched_logits, clean_tokens)
    patched_attn_3_3 = patched_cache['pattern', 3][:, 3, :, :].clone().cpu()

    del patched_cache, patched_logits, corrupt_z_3_2
    torch.cuda.empty_cache()

    # 3.3's induction stripe (query t -> key t - seq_len + 1), averaged over the
    # whole batch rather than only the first prompt.
    queries = torch.arange(seq_len + 1, clean_attn_3_3.shape[1])
    keys = queries - seq_len + 1
    clean_diagonal_avg = clean_attn_3_3[:, queries, keys].mean().item()
    patched_diagonal_avg = patched_attn_3_3[:, queries, keys].mean().item()

    print(f"\n=== RESULTS ===")
    print(f"3.3's diagonal attention (clean):   {clean_diagonal_avg:.3f} ({clean_diagonal_avg*100:.1f}%)")
    print(f"3.3's diagonal attention (patched): {patched_diagonal_avg:.3f} ({patched_diagonal_avg*100:.1f}%)")
    print(f"Attention drop: {(clean_diagonal_avg - patched_diagonal_avg):.3f} ({((clean_diagonal_avg - patched_diagonal_avg)/clean_diagonal_avg*100):.1f}%)")

    metric_value = (patched_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"\nPerformance:")
    print(f"Clean:   {clean_score:.3f}")
    print(f"Corrupt: {corrupt_score:.3f}")
    print(f"Patched: {patched_score:.3f}")
    print(f"Patching metric: {metric_value:.3f}")

    if metric_value < 0.5:
        print("✓ STRONG CIRCUIT DEPENDENCY!")
    elif metric_value < 0.75:
        print("~ MODERATE circuit dependency")
    else:
        print("✗ WEAK/NO circuit dependency")

    return clean_attn_3_3, patched_attn_3_3

# Test it!
torch.cuda.empty_cache()
test_circuit_3_2_to_3_3()

In [None]:
def test_circuit_3_0_to_3_3():
    """
    Corrupt head 3.0's output on the first half and report the patching metric.

    Caveat: 3.0 and 3.3 share layer 3, so 3.0's output cannot reach 3.3's
    inputs. The metric measures 3.0's total effect on the correct-token score
    (everything downstream of layer 3), not a 3.0 -> 3.3 path.
    """

    print("\nTesting 3.0 → 3.3 circuit (same layer)")
    print("="*60)

    z_3_name = utils.get_act_name('z', 3)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    # Only the clean logits are needed, so skip caching activations.
    clean_logits = model(clean_tokens)
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)

    del clean_logits
    torch.cuda.empty_cache()

    # Cache only layer-3 z from the corrupted run.
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens, names_filter=lambda name: name == z_3_name
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    corrupt_z_3_0 = corrupt_cache['z', 3][:, :, 0, :].clone()

    del corrupt_cache, corrupt_logits
    torch.cuda.empty_cache()

    seq_len = clean_tokens.size(1) // 2

    def corrupt_3_0_output(activation, hook):
        activation[:, 1:seq_len+1, 0, :] = corrupt_z_3_0[:, 1:seq_len+1, :]
        return activation

    model.reset_hooks()
    model.add_hook(z_3_name, corrupt_3_0_output)
    try:
        patched_logits = model(clean_tokens)
    finally:
        model.reset_hooks()  # never leave the patching hook attached

    patched_score = calculate_correct_logits(patched_logits, clean_tokens)

    del patched_logits, corrupt_z_3_0
    torch.cuda.empty_cache()

    metric_value = (patched_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"Clean:   {clean_score:.3f}")
    print(f"Corrupt: {corrupt_score:.3f}")
    print(f"Patched: {patched_score:.3f}")
    print(f"Patching metric: {metric_value:.3f}")

    if metric_value < 0.5:
        print("✓ STRONG CIRCUIT!")
    elif metric_value < 0.75:
        print("~ MODERATE circuit")
    else:
        print("✗ WEAK circuit")

torch.cuda.empty_cache()
test_circuit_3_0_to_3_3()

In [None]:
def test_circuit_1_2_to_3_3():
    """Test cross-layer circuit from earlier prev-token head.

    Runs the clean induction prompt with head 1.2's ``z`` replaced, at source
    positions ``1..seq_len`` (the first repeat), by its value on the
    corrupted prompt. Reports the patching metric
    ``(patched - corrupt) / (clean - corrupt)`` on the final logits; lower
    means corrupting 1.2 hurts more. The corruption propagates to every
    later layer, so this is an upper bound on the 1.2 -> 3.3 path alone.

    Relies on ``model``, ``device``, ``generate_induction_prompts``,
    ``corrupt_induction_prompt`` and ``calculate_correct_logits`` from
    elsewhere in the notebook.
    """

    print("\nTesting 1.2 → 3.3 circuit (cross-layer)")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    # Clean baseline: only the logits are needed, so don't build a cache.
    clean_logits = model(clean_tokens)
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    del clean_logits
    torch.cuda.empty_cache()

    # Corrupted run: cache only layer 1's z, the one activation we patch in.
    z_name = utils.get_act_name('z', 1)
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens, names_filter=lambda name: name == z_name
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    corrupt_z_1_2 = corrupt_cache[z_name][:, :, 2, :].clone()
    del corrupt_cache, corrupt_logits
    torch.cuda.empty_cache()

    seq_len = clean_tokens.size(1) // 2

    def corrupt_1_2_output(activation, hook):
        # Overwrite head 2 at the first-repeat positions only.
        activation[:, 1:seq_len+1, 2, :] = corrupt_z_1_2[:, 1:seq_len+1, :]
        return activation

    model.reset_hooks()
    model.add_hook(z_name, corrupt_1_2_output)
    try:
        patched_logits = model(clean_tokens)
        patched_score = calculate_correct_logits(patched_logits, clean_tokens)
    finally:
        # Always detach the patching hook, even if the forward pass fails.
        model.reset_hooks()
    del patched_logits, corrupt_z_1_2
    torch.cuda.empty_cache()

    patching_metric = (patched_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"Clean:   {clean_score:.3f}")
    print(f"Corrupt: {corrupt_score:.3f}")
    print(f"Patched: {patched_score:.3f}")
    print(f"Patching metric: {patching_metric:.3f}")

    # Same verdict thresholds as the other head-patching tests, so every
    # outcome prints a verdict (previously only STRONG did).
    if patching_metric < 0.5:
        print("✓ STRONG CIRCUIT!")
    elif patching_metric < 0.75:
        print("~ MODERATE circuit")
    else:
        print("✗ WEAK circuit")

torch.cuda.empty_cache()
test_circuit_1_2_to_3_3()


In [None]:
def test_resid_to_3_3():
    """
    Test if 3.3 depends on information in the residual stream
    BEFORE layer 3 (i.e., written by the embeddings and layers 0, 1, 2).

    The residual stream entering layer 3 is replaced, at source positions
    ``1..seq_len`` (the first repeat), by its value on the corrupted prompt.
    Reports (a) how head 3.3's induction attention changes and (b) the
    patching metric ``(patched - corrupt) / (clean - corrupt)`` on the final
    logits. Note (b) reflects every component from layer 3 onward, not
    only 3.3.

    Returns:
        (clean_attn_3_3, patched_attn_3_3): head 3.3's attention patterns,
        indexed [batch, query_pos, key_pos], on CPU.
    """

    print("\nTesting Residual Stream → 3.3")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    pattern_name = utils.get_act_name('pattern', 3)
    resid_pre_name = utils.get_act_name('resid_pre', 3)

    # Clean run: keep only layer 3's attention pattern.
    clean_logits, clean_cache = model.run_with_cache(
        clean_tokens, names_filter=lambda name: name == pattern_name
    )
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    clean_attn_3_3 = clean_cache[pattern_name][:, 3, :, :].clone().cpu()
    del clean_cache, clean_logits
    torch.cuda.empty_cache()

    # Corrupted run: keep only the residual stream entering layer 3.
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens, names_filter=lambda name: name == resid_pre_name
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    corrupt_resid_pre_3 = corrupt_cache[resid_pre_name].clone()
    del corrupt_cache, corrupt_logits
    torch.cuda.empty_cache()

    seq_len = clean_tokens.size(1) // 2

    def corrupt_resid_input(activation, hook):
        # Corrupt first half (source positions)
        activation[:, 1:seq_len+1, :] = corrupt_resid_pre_3[:, 1:seq_len+1, :]
        return activation

    model.reset_hooks()
    model.add_hook(resid_pre_name, corrupt_resid_input)
    try:
        patched_logits, patched_cache = model.run_with_cache(
            clean_tokens, names_filter=lambda name: name == pattern_name
        )
        patched_score = calculate_correct_logits(patched_logits, clean_tokens)
        patched_attn_3_3 = patched_cache[pattern_name][:, 3, :, :].clone().cpu()
    finally:
        # Always detach the patching hook, even if the forward pass fails.
        model.reset_hooks()
    del patched_cache, patched_logits, corrupt_resid_pre_3
    torch.cuda.empty_cache()

    # Induction stripe. The prompt is [BOS, block, block], so query t in the
    # second repeat holds the same token as position t - seq_len. An
    # induction head attends to the token that FOLLOWED that earlier copy,
    # i.e. key t - seq_len + 1. (Key t - seq_len is the earlier copy itself:
    # duplicate-token attention, not induction.) Keys are limited to the
    # patched first repeat, and the average is taken over the whole batch.
    query_pos = torch.arange(seq_len + 1, clean_attn_3_3.shape[1])
    key_pos = query_pos - seq_len + 1
    in_first_repeat = (key_pos >= 1) & (key_pos <= seq_len)
    query_pos, key_pos = query_pos[in_first_repeat], key_pos[in_first_repeat]

    clean_diag_avg = clean_attn_3_3[:, query_pos, key_pos].mean().item()
    patched_diag_avg = patched_attn_3_3[:, query_pos, key_pos].mean().item()
    drop = clean_diag_avg - patched_diag_avg
    drop_pct = drop / clean_diag_avg * 100 if clean_diag_avg else float('nan')

    print(f"\n3.3's Induction Attention:")
    print(f"Clean:   {clean_diag_avg:.3f} ({clean_diag_avg*100:.1f}%)")
    print(f"Patched: {patched_diag_avg:.3f} ({patched_diag_avg*100:.1f}%)")
    print(f"Drop: {drop:.3f} ({drop_pct:.1f}%)")

    patching_metric = (patched_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"\nPerformance:")
    print(f"Clean:   {clean_score:.3f}")
    print(f"Corrupt: {corrupt_score:.3f}")
    print(f"Patched: {patched_score:.3f}")
    print(f"Patching metric: {patching_metric:.3f}")

    if patching_metric < 0.3:
        print("✓ STRONG: 3.3 critically depends on residual stream info!")
    elif patching_metric < 0.6:
        print("~ MODERATE: 3.3 uses residual stream info")
    else:
        print("✗ WEAK: little dependence on residual stream info at source positions")

    return clean_attn_3_3, patched_attn_3_3

torch.cuda.empty_cache()
test_resid_to_3_3()

In [None]:
def test_all_earlier_layers_to_3_3():
    """
    Test layers 0, 1, 2 → 3.3
    Find which layer writes the critical information

    For each layer L in 0..2, the residual stream after L (``resid_post`` L)
    is replaced at source positions ``1..seq_len`` (the first repeat) by its
    value on the corrupted prompt. The function reports the patching metric
    ``(patched - corrupt) / (clean - corrupt)`` on the final logits; lower
    means the corruption hurts more. The residual stream is cumulative, so
    patching after layer L also replaces everything written before L.

    Returns:
        list of (layer, patching_metric) tuples.
    """

    print("\nTesting Earlier Layers → 3.3")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    # Baselines. The clean run needs no cache at all.
    clean_logits = model(clean_tokens)
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    del clean_logits
    torch.cuda.empty_cache()

    test_layers = [0, 1, 2]
    resid_names = {layer: utils.get_act_name('resid_post', layer) for layer in test_layers}
    wanted = set(resid_names.values())

    # Corrupted run: cache only the residual streams we will patch in, instead
    # of holding every activation in memory for the whole loop.
    corrupt_logits, corrupt_cache = model.run_with_cache(
        corrupt_tokens, names_filter=lambda name: name in wanted
    )
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    corrupt_resids = {
        layer: corrupt_cache[name].clone() for layer, name in resid_names.items()
    }
    del corrupt_logits, corrupt_cache
    torch.cuda.empty_cache()

    results = []
    seq_len = clean_tokens.size(1) // 2

    def corrupt_layer_output(activation, hook, corrupt_resid):
        activation[:, 1:seq_len+1, :] = corrupt_resid[:, 1:seq_len+1, :]
        return activation

    for test_layer in test_layers:
        print(f"\nTesting layer {test_layer}...")

        model.reset_hooks()
        # Bind this layer's tensor explicitly rather than via a loop closure.
        model.add_hook(
            resid_names[test_layer],
            partial(corrupt_layer_output, corrupt_resid=corrupt_resids[test_layer]),
        )
        try:
            patched_logits = model(clean_tokens)
            patched_score = calculate_correct_logits(patched_logits, clean_tokens)
        finally:
            # Always detach the patching hook, even if the forward pass fails.
            model.reset_hooks()
        del patched_logits
        torch.cuda.empty_cache()

        patching_metric = (patched_score - corrupt_score) / (clean_score - corrupt_score)
        results.append((test_layer, patching_metric))

        print(f"  Patching metric: {patching_metric:.3f}", end="")
        if patching_metric < 0.5:
            print(" ✓ STRONG")
        elif patching_metric < 0.75:
            print(" ~ MODERATE")
        else:
            print(" ✗ WEAK")

    del corrupt_resids
    torch.cuda.empty_cache()

    print("\n" + "="*60)
    print("SUMMARY:")
    for layer, metric in results:
        status = "✓ CRITICAL" if metric < 0.5 else ("~ Important" if metric < 0.75 else "✗ Not critical")
        print(f"Layer {layer}: {metric:.3f} {status}")

    return results

torch.cuda.empty_cache()
layer_results = test_all_earlier_layers_to_3_3()

In [None]:
def analyze_3_2_role():
    """
    Understand what 3.2 contributes

    Projects head 3.2's output straight onto the unembedding (``z @ W_O``
    then ``@ W_U``). The final LayerNorm is not applied, so these are
    unscaled direct-logit contributions. Reports:

    * at source positions (first repeat): the average logit given to the
      token AT that position;
    * at target positions (second repeat): the average logit given to the
      correct NEXT token, i.e. the one induction should predict.
    """

    print("\nAnalyzing Head 3.2's Role")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=8, seq_length=20).to(device)
    z_name = utils.get_act_name('z', 3)
    _, cache = model.run_with_cache(clean_tokens, names_filter=lambda name: name == z_name)

    seq_len = clean_tokens.size(1) // 2

    # No autograd needed. Without no_grad, the (batch, pos, d_vocab) product
    # with the weight matrices could keep a graph alive.
    with torch.no_grad():
        # Get 3.2's output and project it to the vocabulary.
        z_3_2 = cache[z_name][:, :, 2, :]
        output_3_2 = z_3_2 @ model.W_O[3, 2, :, :]
        logit_contrib = output_3_2 @ model.W_U

        # Source positions (first half): logit given to the CURRENT token.
        source_logits = logit_contrib[:, 1:seq_len+1, :]
        source_tokens = clean_tokens[:, 1:seq_len+1]
        current_token_boost = source_logits.gather(-1, source_tokens.unsqueeze(-1))
        avg_boost = current_token_boost.mean().item()

        # Target positions (second half): logits at position t predict the
        # token at t + 1, so logits at seq_len+1 .. end-1 are paired with
        # tokens at seq_len+2 .. end. Pairing with the token at the same
        # position would measure the current token, not the prediction.
        target_logits = logit_contrib[:, seq_len+1:-1, :]
        target_tokens = clean_tokens[:, seq_len+2:]
        correct_token_boost = target_logits.gather(-1, target_tokens.unsqueeze(-1))
        avg_target_boost = correct_token_boost.mean().item()

    print(f"\n3.2's contribution at source positions:")
    print(f"Average logit boost to CURRENT token: {avg_boost:.3f}")

    if avg_boost > 1.0:
        print("  ✓ STRONG: 3.2 writes 'token identity' information!")
    elif avg_boost > 0.3:
        print("  ~ MODERATE: 3.2 somewhat encodes token identity")
    else:
        print("  ✗ WEAK: 3.2 doesn't encode token identity")

    print(f"\n3.2's contribution at target positions:")
    print(f"Average logit boost to CORRECT token: {avg_target_boost:.3f}")

    del cache, logit_contrib
    torch.cuda.empty_cache()

torch.cuda.empty_cache()
analyze_3_2_role()



In [None]:
# Hypothesis: W_K extracts token identity from embeddings
# Let's test this!

def analyze_weight_matrices():
    """
    Understand what head 3.3's W_K extracts from token embeddings

    Maps a few token embeddings through head 3.3's key matrix
    (``W_E[tok] @ W_K``). No LayerNorm or positional information is applied.
    The function then compares the resulting keys. ``dot(k, k)`` is simply
    ``||k||^2``, so raw dot products mostly reflect key magnitude. For that
    reason, cross-token comparisons also report cosine similarity, which is
    what makes "similar vs. dissimilar" meaningful.
    """

    print("Analyzing 3.3's Weight Matrices")
    print("="*60)

    # "3.3" means layer 3, head 3.
    W_K_3_3 = model.W_K[3, 3, :, :]  # [d_model, d_head]
    W_E = model.W_E  # Token embedding matrix [vocab, d_model]

    # Sample some tokens
    sample_tokens = [1, 100, 1000, 5000, 10000]  # Different token IDs

    with torch.no_grad():
        # K = embedding @ W_K for each sampled token.
        keys = {tok: W_E[tok, :] @ W_K_3_3 for tok in sample_tokens}

        print("\nKey similarity for identical tokens:")
        for tok, key in keys.items():
            # This is the squared norm of the key; it sets the scale against
            # which the raw cross-token dot products below should be read.
            self_sim = torch.dot(key, key).item()
            print(f"Token {tok}: self-similarity = {self_sim:.2f}")

        print("\nKey similarity for DIFFERENT tokens:")
        for i, tok_a in enumerate(sample_tokens):
            for tok_b in sample_tokens[i+1:]:
                key_a, key_b = keys[tok_a], keys[tok_b]
                sim = torch.dot(key_a, key_b).item()
                cos = torch.nn.functional.cosine_similarity(key_a, key_b, dim=0).item()
                print(f"Token {tok_a} vs {tok_b}: similarity = {sim:.2f}, cosine = {cos:.3f}")

    print("\n" + "="*60)
    print("Self-similarity is just ||key||^2, so judge distinctness by cosine:")
    print("if keys for different tokens have low cosine similarity,")
    print("then W_K keeps 'token identity' from embeddings distinguishable!")

analyze_weight_matrices()

In [None]:
def test_3_2_helps_downstream():
    """
    Test if 3.2 helps layers 4+ (not 3.3)

    Zero-ablates head 3.2's ``z`` at every position on the clean prompt. It
    then reports the ablation metric ``(ablated - corrupt) / (clean - corrupt)``
    on the final logits; lower means 3.2 matters more. Heads in the same
    layer all read the same residual stream, so any effect measured here
    reaches the logits through later components and the direct path, not
    through 3.3.
    """

    print("\nTesting if 3.2 helps DOWNSTREAM layers (not 3.3)")
    print("="*60)

    clean_tokens = generate_induction_prompts(batch=4, seq_length=20).to(device)
    corrupt_tokens = corrupt_induction_prompt(clean_tokens).to(device)

    # Baselines: only logits are needed, so no caches are built.
    clean_logits = model(clean_tokens)
    clean_score = calculate_correct_logits(clean_logits, clean_tokens)
    del clean_logits
    torch.cuda.empty_cache()

    corrupt_logits = model(corrupt_tokens)
    corrupt_score = calculate_correct_logits(corrupt_logits, clean_tokens)
    del corrupt_logits
    torch.cuda.empty_cache()

    def zero_3_2(activation, hook):
        # Silence head 2 of layer 3 at every position.
        activation[:, :, 2, :] = 0
        return activation

    model.reset_hooks()
    model.add_hook(utils.get_act_name('z', 3), zero_3_2)
    try:
        ablated_logits = model(clean_tokens)
        ablated_score = calculate_correct_logits(ablated_logits, clean_tokens)
    finally:
        # Always detach the ablation hook, even if the forward pass fails.
        model.reset_hooks()
    del ablated_logits
    torch.cuda.empty_cache()

    ablation_metric = (ablated_score - corrupt_score) / (clean_score - corrupt_score)

    print(f"\nZero-ablating 3.2:")
    print(f"Clean score:   {clean_score:.3f}")
    print(f"Ablated score: {ablated_score:.3f}")
    print(f"Drop from clean: {(clean_score - ablated_score):.3f}")
    print(f"Ablation metric: {ablation_metric:.3f}")

    if ablation_metric < 0.5:
        print("✓ 3.2 is important for overall performance")
    else:
        print("✗ 3.2 is not critical")

torch.cuda.empty_cache()
test_3_2_helps_downstream()

In [None]:
def new_induction_scores(attention_pattern, seq_len=None):
    """Return the mean induction attention from an attention pattern.

    Assumes the pattern comes from a prompt built like
    ``generate_induction_prompts``: a BOS token followed by a random block of
    ``seq_len`` tokens repeated twice (total length ``1 + 2 * seq_len``).
    A query ``q`` in the second repeat holds the same token as position
    ``q - seq_len``. An induction head attends to the token that followed
    that earlier copy, i.e. key ``q - seq_len + 1``. The score is the mean
    attention along that stripe over all second-repeat queries.

    Args:
        attention_pattern: tensor whose last two dims are
            [query_pos, key_pos]; any leading dims (e.g. batch, head) are
            kept in the output.
        seq_len: length of the repeated block. Inferred as
            ``(n_pos - 1) // 2`` when omitted.

    Returns:
        Tensor of shape ``attention_pattern.shape[:-2]``.

    Raises:
        ValueError: if the pattern is too short to contain an induction stripe.
    """
    n_pos = attention_pattern.shape[-1]
    if seq_len is None:
        seq_len = (n_pos - 1) // 2
    if seq_len < 1 or seq_len + 1 >= n_pos:
        raise ValueError(
            f"pattern with {n_pos} positions has no induction stripe for seq_len={seq_len}"
        )
    query_pos = torch.arange(seq_len + 1, n_pos, device=attention_pattern.device)
    key_pos = query_pos - seq_len + 1
    # Paired advanced indexing picks one (query, key) entry per stripe element.
    stripe = attention_pattern[..., query_pos, key_pos]
    return stripe.mean(dim=-1)

In [None]:
# Inspect the shape of layer 3's cached attention pattern.
# NOTE(review): this reads a module-level `clean_cache`. No function in this
# section leaves one behind (each keeps `clean_cache` local and/or `del`s it),
# so this depends on a `clean_cache` created elsewhere in the notebook — confirm
# it exists before running this cell.
clean_cache['blocks.3.attn.hook_pattern'].shape