In [None]:
import torch
from transformer_lens import HookedTransformer

# Sticking to a single CUDA device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define the new model as per your selection
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Load the model using HookedTransformer on a single device
print(f"Loading new model: {MODEL_NAME}...")
try:
    model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        device=device,
        # Using float16 is crucial for fitting this on a single GPU
        torch_dtype=torch.float16
    )
    print("New model loaded successfully on a single GPU.")
except Exception as e:
    print(f"An error occurred during model loading: {e}")
    print("\nIf this is an authentication error, you may need to log in to Hugging Face.")
    print("from huggingface_hub import login")
    print("login()")

Using device: cuda
Loading new model: meta-llama/Llama-3.2-3B-Instruct...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.2-3B-Instruct into HookedTransformer
New model loaded successfully on a single GPU.


In [2]:
# --- Step 7.1: Caching Activations from the Clean Run ---

# We need the clean prompt and token IDs for the "Abun" puzzle.
# Let's redefine them all here to have a self-contained block.

# Prompts & Answers
clean_question_abun = "ji gwes"
clean_answer_abun = " my leg"
corrupted_question_abun = "ji bi gan"
corrupted_answer_abun = " my offspring"

# Few-Shot Examples
abun_few_shot_example_1 = """---
Abun: ji syim
English: my arm
---"""
abun_few_shot_example_2 = """---
Abun: gap bi ngwe
English: rat's garden
---"""

# Prompt Template
abun_prompt_template_two_shot = """Based on the following examples of Abun and English:
"Abun: gap sye bi gan\nEnglish: big rat's offspring"
"Abun: ji bi ngwe\nEnglish: my garden"
"Abun: ndar sye gwes\nEnglish: big dog's leg"
"Abun: gap bi ngwe\nEnglish: rat's garden"

Complete the final translation, following the format of the examples.
{example_1}
{example_2}
Abun: {question}
English:"""

# Construct the clean prompt
clean_prompt_abun = abun_prompt_template_two_shot.format(
    context="", # Context is already in the template
    example_1=abun_few_shot_example_1,
    example_2=abun_few_shot_example_2,
    question=clean_question_abun
)

# Construct the corrupted prompt (for the next step)
corrupted_prompt_abun = abun_prompt_template_two_shot.format(
    context="",
    example_1=abun_few_shot_example_1,
    example_2=abun_few_shot_example_2,
    question=corrupted_question_abun
)


# Get the token IDs using our robust method
clean_answer_tokens = model.to_tokens(clean_answer_abun, prepend_bos=False)[0]
clean_answer_token_id = clean_answer_tokens[-1].item()

corrupted_answer_tokens = model.to_tokens(corrupted_answer_abun, prepend_bos=False)[0]
corrupted_answer_token_id = corrupted_answer_tokens[-1].item()

# --- Run with cache ---
print("Running the model on the CLEAN prompt and caching activations...")
clean_logits, clean_cache = model.run_with_cache(
    clean_prompt_abun,
    remove_batch_dim=True
)
print("Clean run cached successfully.")

Running the model on the CLEAN prompt and caching activations...
Clean run cached successfully.


In [3]:
# --- Step 11.1: Re-caching the Clean Run with Correct Dimensions ---

print("Running the model on the CLEAN prompt and caching activations...")
print("Crucially, we are NOT removing the batch dimension this time.")

# --- Run with cache, REMOVING the `remove_batch_dim` argument ---
clean_logits, clean_cache = model.run_with_cache(
    clean_prompt_abun
)

# Verify the shape of a cached pattern tensor. It should now be 4D.
example_pattern_shape = clean_cache["blocks.0.attn.hook_pattern"].shape
print(f"Clean run cached successfully. Example cached shape is now: {example_pattern_shape}")

if len(example_pattern_shape) == 4:
    print("SUCCESS: Cached tensors now have the correct 4 dimensions.")
else:
    print("ERROR: Cached tensors are still not 4D. Please check the code.")

Running the model on the CLEAN prompt and caching activations...
Crucially, we are NOT removing the batch dimension this time.
Clean run cached successfully. Example cached shape is now: torch.Size([1, 24, 127, 127])
SUCCESS: Cached tensors now have the correct 4 dimensions.


In [4]:
# --- Step 11.2: Final Patching Loop Execution ---
import transformer_lens.utils as utils
from tqdm.notebook import tqdm
import pandas as pd

# This hook function correctly patches the 4D attention pattern.
def pattern_patching_hook(
    pattern,               # The attention pattern tensor [batch, n_heads, query_pos, key_pos]
    hook,                  # The hook object
    head_to_patch_index,   # The index of the head we are patching
    clean_cache
):
    pattern[:, head_to_patch_index, :, :] = clean_cache[hook.name][:, head_to_patch_index, :, :]
    return pattern

# Get model configuration
n_layers = model.cfg.n_layers
n_heads = model.cfg.n_heads
results = []

# Loop over every head in the model
for layer in tqdm(range(n_layers), desc="Layers"):
    for head in range(n_heads):
        hook_fn = lambda pattern, hook: pattern_patching_hook(pattern, hook, head, clean_cache)
        hook_point = utils.get_act_name("pattern", layer)

        patched_logits = model.run_with_hooks(
            corrupted_prompt_abun,
            fwd_hooks=[(hook_point, hook_fn)]
        )

        final_token_logits = patched_logits[0, -1, :]
        logit_diff = (final_token_logits[clean_answer_token_id] - final_token_logits[corrupted_answer_token_id]).item()

        results.append({
            "layer": layer,
            "head": head,
            "logit_diff": logit_diff
        })

# Convert results to a pandas DataFrame for easier analysis.
results_df = pd.DataFrame(results)
print("\nIterative patching of attention patterns complete.")
print("Top 5 results sorted by logit difference:")
print(results_df.sort_values("logit_diff", ascending=False).head())

Layers:   0%|          | 0/28 [00:00<?, ?it/s]


Iterative patching of attention patterns complete.
Top 5 results sorted by logit difference:
     layer  head  logit_diff
200      8     8    3.925781
64       2    16    3.328125
126      5     6    3.011719
321     13     9    2.996094
537     22     9    2.816406


In [5]:
# --- Step 12: Defining the Targeted Ablation Mechanism ---
import transformer_lens.utils as utils

# The list of candidate reasoning heads we identified in Phase 1.
# Format is (layer, head_index).
candidate_heads_to_ablate = [
    (8, 8),
    (2, 16),
    (5, 6),
    (13, 9),
    (22, 9)
]
print(f"Identified Candidate Heads for Ablation: {candidate_heads_to_ablate}")

# This hook function zeroes out the attention pattern for a specific head.
def zero_ablation_hook(
    pattern,               # The attention pattern tensor
    hook,                  # The hook object
    head_to_zero_index     # The index of the head we want to zero out
):
    # Set the attention pattern for the target head to all zeros.
    pattern[:, head_to_zero_index, :, :] = 0.0
    return pattern

# We need to create a unique hook function for each head we are ablating.
# This factory function ensures that the correct 'head' index is captured by each hook.
def create_hook_fn(head_idx_to_ablate):
    # This returns a function that takes pattern and hook as input
    return lambda pattern, hook: zero_ablation_hook(pattern, hook, head_idx_to_ablate)

# Now, we build the list of hooks that will be permanently applied to the model.
# We are ablating the 'pattern' hook point, the same one we patched.
ablation_hooks = []
for layer, head in candidate_heads_to_ablate:
    hook_point = utils.get_act_name("pattern", layer)
    hook_function = create_hook_fn(head)
    ablation_hooks.append((hook_point, hook_function))

print(f"\nSuccessfully created {len(ablation_hooks)} hooks for targeted ablation.")

Identified Candidate Heads for Ablation: [(8, 8), (2, 16), (5, 6), (13, 9), (22, 9)]

Successfully created 5 hooks for targeted ablation.


In [6]:
# --- Step 13: Preparing the Reasoning and Recall Test Suites ---

# We will use the "Ayutla Mixe" puzzle data for our test suites.
# This ensures we are not testing on the same data used to find the circuits.
mixe_puzzle_data = {
    "context": [
        "Mixe: Ëjts nexp.\nEnglish: I see.",
        "Mixe: Mejts mtunp.\nEnglish: You work.",
        "Mixe: Juan yë'ë yexyejtpy. \nEnglish: Juan watches him.",
        "Mixe: Yë'ë yë' uk yexpy.\nEnglish: He sees the dog.",
        "Mixe: Ëjts yë' maxu'unk nexyejtpy.\nEnglish: I watch the baby."
    ],
    "reasoning_questions": [
        {"question": "Mixe: Yë' maxu'unk yexp.", "answer": "The baby sees."},
        {"question": "English: I work for him.", "answer": "Ëjts yë'ë nexpy."}
    ]
}

# --- 1. Construct the Reasoning Suite ---
print("Constructing the Reasoning Suite...")
reasoning_suite = []
# We use a one-shot format to guide the model.
reasoning_prompt_template = """Based on the following examples of Mixe and English:
"{context_str}"

Following the format of the examples, complete the final translation.
Question: {question}
Answer:"""

for task in mixe_puzzle_data["reasoning_questions"]:
    prompt = reasoning_prompt_template.format(
        context_str="\n".join(mixe_puzzle_data["context"]),
        question=task["question"]
    )
    reasoning_suite.append({"prompt": prompt, "correct_answer": task["answer"]})

print(f"Reasoning Suite created with {len(reasoning_suite)} prompts.")
print("Example Reasoning Prompt:")
print(reasoning_suite[0]['prompt'])


# --- 2. Construct the Recall Suite ---
print("\nConstructing the Recall Suite...")
recall_suite = []
# The recall prompt explicitly asks to retrieve information from the text.
recall_prompt_template = """Based *only* on the text provided below:
"{context_str}"

What is the English translation of the Mixe phrase "{mixe_phrase}"?"""

# We generate recall tasks from the examples themselves.
for example in mixe_puzzle_data["context"]:
    if "Mixe:" in example and "English:" in example:
        parts = example.split("\n")
        mixe_part = parts[0].replace("Mixe: ", "").strip()
        english_part = parts[1].replace("English: ", "").strip()

        prompt = recall_prompt_template.format(
            context_str="\n".join(mixe_puzzle_data["context"]),
            mixe_phrase=mixe_part
        )
        recall_suite.append({"prompt": prompt, "correct_answer": english_part})

print(f"Recall Suite created with {len(recall_suite)} prompts.")
print("Example Recall Prompt:")
print(recall_suite[0]['prompt'])

Constructing the Reasoning Suite...
Reasoning Suite created with 2 prompts.
Example Reasoning Prompt:
Based on the following examples of Mixe and English:
"Mixe: Ëjts nexp.
English: I see.
Mixe: Mejts mtunp.
English: You work.
Mixe: Juan yë'ë yexyejtpy. 
English: Juan watches him.
Mixe: Yë'ë yë' uk yexpy.
English: He sees the dog.
Mixe: Ëjts yë' maxu'unk nexyejtpy.
English: I watch the baby."

Following the format of the examples, complete the final translation.
Question: Mixe: Yë' maxu'unk yexp.
Answer:

Constructing the Recall Suite...
Recall Suite created with 5 prompts.
Example Recall Prompt:
Based *only* on the text provided below:
"Mixe: Ëjts nexp.
English: I see.
Mixe: Mejts mtunp.
English: You work.
Mixe: Juan yë'ë yexyejtpy. 
English: Juan watches him.
Mixe: Yë'ë yë' uk yexpy.
English: He sees the dog.
Mixe: Ëjts yë' maxu'unk nexyejtpy.
English: I watch the baby."

What is the English translation of the Mixe phrase "Ëjts nexp."?


In [7]:
# --- Step 14 (Final Corrected Version): Performance Evaluation and Analysis ---
from tqdm.notebook import tqdm
import pandas as pd

# Helper function to check if the model's answer is correct.
def is_correct(generated_text, correct_answer):
    # We check if the correct answer appears in the generated text, ignoring case.
    # A more robust solution might involve normalizing text, but this is good for our purpose.
    return correct_answer.lower() in generated_text.lower()

# The corrected evaluation function using the model.hooks() context manager.
def evaluate_performance(suite, use_ablation_hooks=False):
    correct_predictions = 0
    total_predictions = len(suite)
    
    hooks_to_use = ablation_hooks if use_ablation_hooks else []

    for task in tqdm(suite, desc=f"Evaluating (Ablated={use_ablation_hooks})"):
        prompt = task["prompt"]
        correct_answer = task["correct_answer"]
        
        # Use the model.hooks context manager to apply hooks to the .generate() call
        with model.hooks(fwd_hooks=hooks_to_use):
            output = model.generate(
                prompt,
                max_new_tokens=15,
                do_sample=False
            )

        if is_correct(output, correct_answer):
            correct_predictions += 1
            
    return (correct_predictions / total_predictions) * 100

# --- Run the four evaluations ---
print("--- Evaluating Original Model ---")
original_reasoning_accuracy = evaluate_performance(reasoning_suite, use_ablation_hooks=False)
original_recall_accuracy = evaluate_performance(recall_suite, use_ablation_hooks=False)

print("\n--- Evaluating Ablated Model ---")
ablated_reasoning_accuracy = evaluate_performance(reasoning_suite, use_ablation_hooks=True)
ablated_recall_accuracy = evaluate_performance(recall_suite, use_ablation_hooks=True)

# --- Analyze and display the results ---
performance_drop_reasoning = original_reasoning_accuracy - ablated_reasoning_accuracy
performance_drop_recall = original_recall_accuracy - ablated_recall_accuracy

results = {
    "Suite": ["Reasoning", "Recall"],
    "Original Model Accuracy (%)": [original_reasoning_accuracy, original_recall_accuracy],
    "Ablated Model Accuracy (%)": [ablated_reasoning_accuracy, ablated_recall_accuracy],
    "Performance Drop (%)": [performance_drop_reasoning, performance_drop_recall]
}
summary_df = pd.DataFrame(results)

print("\n\n--- FINAL RESULTS ---")
print(summary_df.to_string(index=False))

# --- Final Validation based on Success Criterion ---
print("\n--- Validation ---")
if performance_drop_reasoning > 20 and performance_drop_recall < 10:
    print("SUCCESS: The results align with the hypothesis.")
    print("Ablating the candidate heads significantly harms reasoning performance while leaving recall largely intact.")
    print("This provides strong evidence that these heads form a specialized reasoning circuit.")
else:
    print("FAILURE: The results do not align with the hypothesis.")
    print("The ablated heads may not be as specialized for reasoning as predicted.")

--- Evaluating Original Model ---


Evaluating (Ablated=False):   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

Evaluating (Ablated=False):   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]


--- Evaluating Ablated Model ---


Evaluating (Ablated=True):   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

Evaluating (Ablated=True):   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]



--- FINAL RESULTS ---
    Suite  Original Model Accuracy (%)  Ablated Model Accuracy (%)  Performance Drop (%)
Reasoning                          0.0                         0.0                   0.0
   Recall                        100.0                       100.0                   0.0

--- Validation ---
FAILURE: The results do not align with the hypothesis.
The ablated heads may not be as specialized for reasoning as predicted.
