# Retrieval Head Ablation (Step 2)

### Validate identified retrieval heads via targeted ablation on held-out data

In [None]:
%pip install torch transformers numpy huggingface_hub accelerate bitsandbytes python-dotenv pandas matplotlib scikit-learn

In [None]:
import os
import re
import json
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from contextlib import nullcontext
from collections import Counter
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from dotenv import load_dotenv

#### If using Google Colab: run the following cells (skip them in any other environment)

In [None]:
# Colab only: fetch the Hugging Face access token through google.colab.userdata.
from google.colab import userdata

# HF_TOKEN is read later by the login cell and by the tokenizer/model loading cells.
HF_TOKEN = userdata.get('HF_TOKEN')

In [None]:
# Log in to the Hugging Face Hub when a token was found; otherwise only report
# that it is missing (no exception is raised here).
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("HF_TOKEN loaded and Hugging Face login successful.")
else:
    print("HF_TOKEN not found.")

#### If using a high-end GPU outside Colab (e.g. Lambda Labs): run the following cell

In [None]:
# Load variables from a local .env file. override=False keeps values that are
# already present in the process environment.
load_dotenv(override=False)

# Prefer the environment variable; fall back to an HF_TOKEN notebook global
# (for example one defined by an earlier cell) when the variable is unset/empty.
HF_TOKEN = os.getenv("HF_TOKEN") or globals().get("HF_TOKEN")

if HF_TOKEN:
    # Export the token so that code reading os.environ sees the same value.
    os.environ["HF_TOKEN"] = HF_TOKEN
    login(token=HF_TOKEN)
    print("HF_TOKEN loaded and Hugging Face login successful.")
else:
    print("HF_TOKEN not found. Add HF_TOKEN=... to your .env file or set it in the environment.")

#### Configuration & Hyperparameters

In [None]:
# Hugging Face model id used for both the tokenizer and the model.
MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"

# Only rows whose "haystack_token_length" equals this value are evaluated.
TARGET_SEQ_LEN = 7000
# Relative position of the needle inside the haystack (0.5 = middle).
NEEDLE_DEPTH = 0.5
# Fraction of the data on the "train" side of the split; the ablation set is
# the remaining 1 - SPLIT.
SPLIT = 0.8
# Seed for the train/ablation split and for the random-head control.
SPLIT_SEED = 42

# One entry per retrieval task: "id" must match the "task" column of the data,
# "question" is the text appended to the prompt.
TASKS = [
    {"id": "registrant_name", "question": "What is the registrant's name?"},
    {"id": "headquarters_city", "question": "What is the registrant's headquarters city?"},
    {"id": "headquarters_state", "question": "What is the registrant's headquarters state?"},
    {"id": "incorporation_state", "question": "What is the registrant's incorporation state?"},
    {"id": "incorporation_year", "question": "What is the incorporation year?"},
    {"id": "employees_count_total", "question": "How many total employees does the registrant have?"},
    {"id": "holder_record_amount", "question": "What is the number of holders of record of the registrant's common stock?"},
    {"id": "employees_count_full_time", "question": "How many full-time employees does the registrant have?"},
    {"id": "ceo_lastname", "question": "What is the CEO's last name?"},
]
# Lookup from task id to its task dict.
TASK_MAP = {t["id"]: t for t in TASKS}

# Paths
DATA_PATH = "data/clean_ground_truth/cleaned_EDGAR_gt_2-22-2026.csv"
HEADS_JSON_PATH = "data/retrieval_heads/results/retrieval_heads.json"
ABLATION_OUTPUT_DIR = "data/retrieval_heads/ablation_results"

# Ablation-specific config
# Numbers of top-ranked heads to ablate, one evaluation run per value.
ABLATION_K_VALUES = [1, 2, 5, 10, 20]
# Upper bound on the number of tokens generated per sample.
MAX_DECODE_TOKENS = 20

# Model loading
TORCH_DTYPE = torch.bfloat16
# Passed to from_pretrained as attn_implementation.
ATTN_IMPL = "eager"

print("Configuration loaded.")
print(f"Model: {MODEL_ID}")
print(f"Ablation K values: {ABLATION_K_VALUES}")
print(f"Max decode tokens: {MAX_DECODE_TOKENS}")
print(f"Heads JSON: {HEADS_JSON_PATH}")

##### Tokenizer & Model Loading

In [None]:
# Load the tokenizer that belongs to MODEL_ID (the token is needed for gated repos).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

# Reuse the EOS token for padding when the tokenizer does not define a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded: {MODEL_ID}")
print(f"Vocab size: {tokenizer.vocab_size:,}")

In [None]:
# Load the causal LM with the dtype and attention implementation chosen in the
# configuration cell; device placement is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
    attn_implementation=ATTN_IMPL,
)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Report the architecture facts the ablation code relies on.
print(f"Model loaded: {MODEL_ID}")
print(f"dtype: {next(model.parameters()).dtype}")
print(f"Num layers: {model.config.num_hidden_layers}")
print(f"Num heads (Q): {model.config.num_attention_heads}")
print(f"Num KV heads: {model.config.num_key_value_heads}")

---
## Section 2 -- Load Ablation Split & Identified Heads

In [None]:
from sklearn.model_selection import train_test_split

# Keep only the samples whose haystack has exactly TARGET_SEQ_LEN tokens.
df = pd.read_csv(DATA_PATH)
df = df[df["haystack_token_length"] == TARGET_SEQ_LEN].reset_index(drop=True)

# Stratified split by task. Only the held-out part (test_size = 1 - SPLIT) is
# kept; the first return value is discarded.
# NOTE(review): presumably the same seed/stratification as the head-detection
# step, so that this set was never used to rank heads -- confirm.
_, ablation_df = train_test_split(
    df,
    test_size=1 - SPLIT,
    stratify=df["task"],
    random_state=SPLIT_SEED,
)
ablation_df = ablation_df.reset_index(drop=True)

print(f"Ablation set: {len(ablation_df):,} samples")
print(ablation_df.groupby("task").size().to_string())

In [None]:
# Read the head rankings file. The code below expects a "tasks" mapping
# (task id -> list of {"layer", "head"} dicts) and a "shared_heads" list of
# the same kind of dicts.
with open(HEADS_JSON_PATH, "r") as f:
    heads_data = json.load(f)

# Per-task top-K heads as (layer, head) tuples
# The list order is kept as-is; later cells take the first k entries.
task_head_rankings = {}
for task_id, head_list in heads_data["tasks"].items():
    task_head_rankings[task_id] = [(h["layer"], h["head"]) for h in head_list]

# Shared heads
shared_heads = [(h["layer"], h["head"]) for h in heads_data["shared_heads"]]

print(f"Loaded {len(task_head_rankings)} task rankings from {HEADS_JSON_PATH}")
for tid, heads in task_head_rankings.items():
    print(f"  {tid:<35} {len(heads)} heads")
print(f"Shared heads: {len(shared_heads)}")

---
## Section 3 -- Core Ablation Infrastructure

In [None]:
class HeadAblationHooks:
    """Context manager that zeros out selected attention heads before o_proj.

    While the context is active, a forward pre-hook on each affected layer's
    ``self_attn.o_proj`` replaces the input slice that belongs to every
    selected head with zeros (on a copy, the incoming tensor is not modified).
    All hooks are removed again when the context exits.

    Args:
        model: Object exposing ``config`` (``hidden_size``,
            ``num_attention_heads`` and optionally ``head_dim``) and
            ``model.layers[i].self_attn.o_proj``.
        heads: ``(layer_index, head_index)`` pairs to ablate. Layer indices
            beyond the model's depth are ignored.
    """

    def __init__(self, model, heads: list[tuple[int, int]]):
        self.model = model
        self.heads = heads
        self.handles = []

        # Prefer an explicit per-head width when the config declares one; it is
        # not always equal to hidden_size // num_attention_heads.
        config = model.config
        declared_head_dim = getattr(config, "head_dim", None)
        self.head_dim = declared_head_dim or (
            config.hidden_size // config.num_attention_heads
        )

        # Group heads by layer so that each layer gets a single hook.
        self.by_layer = {}
        for layer_idx, head_idx in heads:
            self.by_layer.setdefault(layer_idx, set()).add(head_idx)

    def _remove_hooks(self):
        """Detach every hook registered so far."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __enter__(self):
        head_dim = self.head_dim
        layers = self.model.model.layers
        try:
            for layer_idx, layer_heads in self.by_layer.items():
                if layer_idx >= len(layers):
                    continue
                o_proj = layers[layer_idx].self_attn.o_proj

                # `heads` is bound as a default so each hook keeps its own layer's list.
                def pre_hook(_module, inputs, heads=sorted(layer_heads)):
                    x = inputs[0]
                    if x.ndim != 3:
                        return inputs
                    x = x.clone()
                    for head_idx in heads:
                        start = head_idx * head_dim
                        x[..., start:start + head_dim] = 0
                    return (x,)

                self.handles.append(o_proj.register_forward_pre_hook(pre_hook))
        except BaseException:
            # __exit__ is not called when __enter__ raises, so clean up here;
            # otherwise already-registered hooks would stay on the model.
            self._remove_hooks()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._remove_hooks()

In [None]:
def greedy_decode(model, tokenizer, prompt_ids, max_tokens=MAX_DECODE_TOKENS) -> str:
    """Greedily decode a short continuation of ``prompt_ids``.

    All prompt tokens except the last are run once to fill the KV cache; the
    last prompt token then starts a token-by-token argmax loop. Attention
    outputs are never requested.

    Args:
        model: Causal LM called as ``model(input_ids=..., ...)``.
        tokenizer: Tokenizer used for the EOS id and to decode the result.
        prompt_ids: Tensor of token ids; ``.item()`` on the next token means
            the batch dimension must be 1.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The generated text with special tokens removed and whitespace stripped.
    """
    # Stop on the tokenizer's EOS and on every EOS id declared by the model's
    # generation config (which may be a single id or a list of ids).
    stop_ids = set()
    generation_config = getattr(model, "generation_config", None)
    for candidate in (tokenizer.eos_token_id, getattr(generation_config, "eos_token_id", None)):
        if candidate is None:
            continue
        if isinstance(candidate, (list, tuple, set)):
            stop_ids.update(int(token_id) for token_id in candidate)
        else:
            stop_ids.add(int(candidate))

    device = prompt_ids.device
    # Position index of the last prompt token, i.e. the first decode input.
    position = prompt_ids.size(1) - 1
    generated = []

    with torch.inference_mode():
        past_kv = None
        # A single-token prompt has nothing to prefill; skip the empty forward.
        if position > 0:
            prefill = model(
                input_ids=prompt_ids[:, :-1],
                use_cache=True,
                return_dict=True,
                output_attentions=False,
            )
            past_kv = prefill.past_key_values

        inp = prompt_ids[:, -1:]
        for _ in range(max_tokens):
            pos_ids = torch.tensor([[position]], dtype=torch.long, device=device)
            out = model(
                input_ids=inp,
                past_key_values=past_kv,
                position_ids=pos_ids,
                use_cache=True,
                output_attentions=False,
                return_dict=True,
            )
            past_kv = out.past_key_values
            next_id = out.logits[:, -1].argmax(dim=-1)
            # One host/device sync per step instead of two.
            token_id = next_id.item()
            generated.append(token_id)
            if token_id in stop_ids:
                break
            position += 1
            inp = next_id.unsqueeze(1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

In [None]:
# Number words that are rewritten to digits during normalization.
_NUMBER_WORDS = {
    "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
    "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9",
    "ten": "10", "eleven": "11", "twelve": "12", "thirteen": "13",
    "fourteen": "14", "fifteen": "15", "sixteen": "16", "seventeen": "17",
    "eighteen": "18", "nineteen": "19", "twenty": "20", "thirty": "30",
    "forty": "40", "fifty": "50", "sixty": "60", "seventy": "70",
    "eighty": "80", "ninety": "90",
}
# Compiled once: a single whole-word alternation over all number words.
_NUMBER_WORD_RE = re.compile(r"\b(?:" + "|".join(_NUMBER_WORDS) + r")\b")


def normalize_value(text: str) -> str:
    """Lowercase, strip whitespace, remove commas, convert number words to digits.

    Args:
        text: Value to normalize. Non-string values are converted with
            ``str()``; ``None`` and NaN are treated as missing.

    Returns:
        The normalized string, or ``""`` for missing/empty input. A numeric
        zero is kept as ``"0"`` rather than being treated as missing.
    """
    if text is None:
        return ""
    # NaN is the only float that is not equal to itself (missing CSV cell).
    if isinstance(text, float) and text != text:
        return ""

    # Collapse every whitespace run (including newlines) to one space.
    s = re.sub(r"\s+", " ", str(text)).strip().lower()
    s = s.replace(",", "")

    # Each number word is replaced independently, so "twenty one" -> "20 1".
    return _NUMBER_WORD_RE.sub(lambda m: _NUMBER_WORDS[m.group(0)], s)

---
## Section 4 -- Prompt Building & Evaluation Helpers

In [None]:
# Control-token strings that are stripped from the raw haystack text when the
# prompt is built.
CONTROL_TOKENS = ["<|begin_of_text|>", "<|end_of_text|>", "<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"]
# One alternation matching any of the tokens literally (each is regex-escaped).
TOKEN_PATTERN = re.compile("|".join(re.escape(t) for t in CONTROL_TOKENS))


def find_needle_span(
    prompt_ids: list[int],
    needle_ids: list[int],
    threshold: float = 0.9,
) -> tuple[int, int]:
    """Find the first prompt window that overlaps enough with the needle tokens.

    A window of ``len(needle_ids)`` tokens slides over ``prompt_ids``. A window
    qualifies when the share of distinct needle token ids that also occur in
    it is at least ``threshold``; token order is not compared.

    Returns:
        ``(start, end)`` of the first qualifying window (end exclusive), or
        ``(-1, -1)`` when the needle is empty or no window qualifies.
    """
    width = len(needle_ids)
    if not width:
        return -1, -1

    wanted = frozenset(needle_ids)
    distinct_needle_tokens = len(wanted)

    for start in range(len(prompt_ids) - width + 1):
        stop = start + width
        shared = wanted.intersection(prompt_ids[start:stop])
        if len(shared) / distinct_needle_tokens >= threshold:
            return start, stop

    return -1, -1


def build_prompt(row: pd.Series, task: dict, tokenizer) -> dict:
    """Construct the prompt with needle inserted at NEEDLE_DEPTH and locate the needle span.

    Args:
        row: Sample providing ``"haystack_text"`` and ``"needle_sentence"``.
        task: Task dict providing the ``"question"`` text.
        tokenizer: Tokenizer used for the chat template and to encode the needle.

    Returns:
        Dict with ``"input_ids"`` (output of ``apply_chat_template``) and
        ``"needle_start"`` / ``"needle_end"`` token offsets into the prompt;
        both offsets are ``-1`` when the needle could not be located.
    """
    haystack = TOKEN_PATTERN.sub("", row["haystack_text"]).strip()
    needle = row["needle_sentence"]

    # Honour the configured depth instead of a hard-coded midpoint. The depth is
    # clamped to [0, 1]; with NEEDLE_DEPTH = 0.5 this equals len(haystack) // 2.
    # The offset is in characters, so the needle may land inside a word.
    depth = min(max(float(NEEDLE_DEPTH), 0.0), 1.0)
    insert_at = int(len(haystack) * depth)
    context = haystack[:insert_at] + " " + needle + " " + haystack[insert_at:]

    message = f"<document>{context}</document>\nQuestion: {task['question']}\nAnswer:"
    messages = [{"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )

    # The needle is tokenized on its own, so its ids can differ slightly from
    # the in-context tokenization; find_needle_span matches by overlap.
    needle_ids = tokenizer.encode(needle, add_special_tokens=False)
    needle_start, needle_end = find_needle_span(input_ids[0].tolist(), needle_ids)

    return {
        "input_ids": input_ids,
        "needle_start": needle_start,
        "needle_end": needle_end,
    }

In [None]:
def evaluate_sample(model, tokenizer, row, task, ablation_heads=None) -> dict:
    """Build prompt, optionally ablate heads, greedy decode, score.

    Args:
        model: Causal LM to evaluate.
        tokenizer: Matching tokenizer.
        row: Sample providing the haystack, needle and ``"needle_value"``.
        task: Task dict (its question is used in the prompt).
        ablation_heads: Optional ``(layer, head)`` pairs to zero out while
            decoding; ``None`` or an empty list means no ablation.

    Returns:
        Dict with ``"decoded"`` (model output), ``"value_match"`` (normalized
        ground truth is a substring of the normalized output) and
        ``"skipped"``. A sample is skipped when the needle span cannot be
        located or when the ground truth normalizes to an empty string.
    """
    prompt = build_prompt(row, task, tokenizer)

    if prompt["needle_start"] == -1:
        return {"decoded": "", "value_match": False, "skipped": True}

    # An empty target is a substring of every output and would always count as
    # a hit, so such samples are unscorable and skipped before decoding.
    target = normalize_value(row["needle_value"])
    if not target:
        return {"decoded": "", "value_match": False, "skipped": True}

    input_ids = prompt["input_ids"].to(model.device)
    hook_ctx = HeadAblationHooks(model, ablation_heads) if ablation_heads else nullcontext()

    try:
        with hook_ctx:
            decoded = greedy_decode(model, tokenizer, input_ids)
    finally:
        # Release the prompt tensor and cached blocks even when decoding fails.
        del input_ids
        torch.cuda.empty_cache()

    value_match = target in normalize_value(decoded)

    return {"decoded": decoded, "value_match": value_match, "skipped": False}

In [None]:
def evaluate_condition(
    model,
    tokenizer,
    eval_df: pd.DataFrame,
    condition_name: str,
    ablation_heads=None,
) -> dict:
    """Evaluate every row of ``eval_df`` under one ablation condition.

    Returns a dict keyed by task id. Each entry holds ``"attempts"``,
    ``"matches"``, ``"samples"`` (the per-sample result dicts of all
    non-skipped rows) and ``"accuracy"`` (matches / attempts, 0 when there
    were no attempts). Skipped samples count towards neither number.
    """
    per_task = {}
    n_rows = len(eval_df)

    for position, row in enumerate((r for _, r in eval_df.iterrows()), start=1):
        task_id = row["task"]
        task = TASK_MAP[task_id]

        # A task gets its entry on first sight, even if all its rows are skipped.
        bucket = per_task.setdefault(
            task_id, {"attempts": 0, "matches": 0, "samples": []}
        )

        outcome = evaluate_sample(model, tokenizer, row, task, ablation_heads)

        if not outcome["skipped"]:
            bucket["attempts"] += 1
            if outcome["value_match"]:
                bucket["matches"] += 1
            bucket["samples"].append(outcome)

            # Progress line every tenth row (only reached for scored rows).
            if position % 10 == 0:
                print(f"[{condition_name}] {position}/{n_rows}")

    # Compute accuracy per task
    for bucket in per_task.values():
        bucket["accuracy"] = bucket["matches"] / max(1, bucket["attempts"])

    return per_task

---
## Section 5 -- Experiment 1: Baseline Evaluation (No Ablation)

In [None]:
# Make sure the output directory exists before any result file is written.
os.makedirs(ABLATION_OUTPUT_DIR, exist_ok=True)

# Reference run: the full ablation set with no heads ablated.
baseline_results = evaluate_condition(model, tokenizer, ablation_df, "baseline")

print("\nBaseline results:")
for task_id, r in baseline_results.items():
    print(f"  {task_id:<35} {r['matches']}/{r['attempts']}  ({r['accuracy']:.1%})")

In [None]:
# Export only the aggregate counters per task; the per-sample decoded text is
# left out to keep the JSON small.
export = {}
for tid, r in baseline_results.items():
    export[tid] = {
        "attempts": r["attempts"],
        "matches": r["matches"],
        "accuracy": r["accuracy"],
    }

baseline_path = os.path.join(ABLATION_OUTPUT_DIR, "baseline_results.json")
with open(baseline_path, "w") as f:
    json.dump(export, f, indent=2)
    print("Saved baseline results.")

---
## Section 6 -- Experiment 2: Within-Task Ablation

For each task, ablate that task's own top-K heads at increasing K values.

In [None]:
# task id -> {k -> aggregate result of ablating that task's top-k heads}
within_task_results = {}

for task_id in TASK_MAP:
    task_df = ablation_df[ablation_df["task"] == task_id].reset_index(drop=True)
    if len(task_df) == 0:
        print(f"SKIP {task_id} -- no ablation samples")
        continue

    available_heads = task_head_rankings.get(task_id, [])
    if not available_heads:
        print(f"SKIP {task_id} -- no ranked heads")
        continue

    within_task_results[task_id] = {}

    # ABLATION_K_VALUES is expected in ascending order, hence `break` below.
    for k in ABLATION_K_VALUES:
        heads_to_ablate = available_heads[:k]
        # Stop once the ranking has fewer than k heads: evaluating the same
        # truncated head set again would only repeat an earlier run and store
        # it under a misleading k.
        if not heads_to_ablate or len(heads_to_ablate) < k:
            print(f"  {task_id}: only {len(available_heads)} ranked heads available, stopping at k={k}")
            break

        condition_name = f"within_{task_id}_k{k}"
        results = evaluate_condition(model, tokenizer, task_df, condition_name, ablation_heads=heads_to_ablate)

        # task_df holds a single task, so only that task's entry is of interest.
        r = results.get(task_id, {"attempts": 0, "matches": 0, "accuracy": 0.0})
        within_task_results[task_id][k] = {
            "k": k,
            "heads_ablated": len(heads_to_ablate),
            "attempts": r["attempts"],
            "matches": r["matches"],
            "accuracy": r["accuracy"],
        }

        print(f"  {task_id} k={k}: {r['matches']}/{r['attempts']}  ({r['accuracy']:.1%})")

In [None]:
# Persist the within-task results. The integer k keys are written as JSON
# object keys, i.e. as strings.
with open(os.path.join(ABLATION_OUTPUT_DIR, "within_task_ablation.json"), "w") as f:
    json.dump(within_task_results, f, indent=2)
    print("Saved within-task ablation results.")

---
## Section 7 -- Experiment 3: Across-Task Ablation (Shared Heads)

Ablate the global shared retrieval heads and measure accuracy across all tasks.

In [None]:
# k -> aggregate and per-task results of ablating the first k shared heads
across_task_results = {}

# ABLATION_K_VALUES is expected in ascending order, hence `break` below.
for k in ABLATION_K_VALUES:
    heads_to_ablate = shared_heads[:k]
    # Stop as soon as fewer than k shared heads exist. Checking only for an
    # empty slice would re-run the same truncated head set and report it
    # under a larger k.
    if not heads_to_ablate or len(heads_to_ablate) < k:
        print(f"Only {len(shared_heads)} shared heads available, stopping at k={k}")
        break

    condition_name = f"across_task_k{k}"
    results = evaluate_condition(model, tokenizer, ablation_df, condition_name, ablation_heads=heads_to_ablate)

    # Micro-average over all tasks (each sample counts once).
    total_attempts = sum(r["attempts"] for r in results.values())
    total_matches  = sum(r["matches"] for r in results.values())
    overall_acc    = total_matches / max(1, total_attempts)

    across_task_results[k] = {
        "k": k,
        "heads_ablated": len(heads_to_ablate),
        "overall_accuracy": overall_acc,
        "total_attempts": total_attempts,
        "total_matches": total_matches,
        "per_task": {tid: {"attempts": r["attempts"], "matches": r["matches"], "accuracy": r["accuracy"]} for tid, r in results.items()},
    }

    print(f"k={k}: {total_matches}/{total_attempts}  ({overall_acc:.1%})")

In [None]:
# Persist the across-task results. The integer k keys are written as JSON
# object keys, i.e. as strings.
with open(os.path.join(ABLATION_OUTPUT_DIR, "across_task_ablation.json"), "w") as f:
    json.dump(across_task_results, f, indent=2)
    print("Saved across-task ablation results.")

---
## Section 8 -- Random Head Ablation Control

Ablate the same number of randomly selected heads (drawn from all heads except the shared retrieval heads) as a control, to test whether the accuracy drop is specific to retrieval heads.

In [None]:
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads

# Candidate pool: every (layer, head) pair that is not a shared retrieval head.
shared_set = set(shared_heads)
all_heads = [
    (layer_idx, head_idx)
    for layer_idx in range(num_layers)
    for head_idx in range(num_heads)
    if (layer_idx, head_idx) not in shared_set
]

# Draw one fixed, seeded pool; each k then uses its first k entries, so the
# random head sets are nested across k values.
random.seed(SPLIT_SEED)
pool_size = min(max(ABLATION_K_VALUES), len(all_heads))
random_pool = random.sample(all_heads, pool_size)

random_ablation_results = {}

for k in ABLATION_K_VALUES:
    heads_to_ablate = random_pool[:k]

    condition_name = f"random_k{k}"
    results = evaluate_condition(model, tokenizer, ablation_df, condition_name, ablation_heads=heads_to_ablate)

    # Micro-average over all tasks.
    total_attempts = 0
    total_matches = 0
    for r in results.values():
        total_attempts += r["attempts"]
        total_matches += r["matches"]
    overall_acc = total_matches / max(1, total_attempts)

    # Unlike the other experiments, "heads_ablated" stores the actual
    # (layer, head) pairs so the random draw can be inspected later.
    random_ablation_results[k] = {
        "k": k,
        "heads_ablated": list(heads_to_ablate),
        "overall_accuracy": overall_acc,
        "total_attempts": total_attempts,
        "total_matches": total_matches,
    }

    print(f"random k={k}: {total_matches}/{total_attempts}  ({overall_acc:.1%})")

In [None]:
# Persist the random-control results. Integer k keys become strings and the
# (layer, head) tuples become two-element arrays in the JSON output.
with open(os.path.join(ABLATION_OUTPUT_DIR, "random_ablation_control.json"), "w") as f:
    json.dump(random_ablation_results, f, indent=2)
    print("Saved random ablation control results.")

---
## Section 9 -- Visualizations

In [None]:
# Baseline accuracy per task
# Per-task baseline accuracy plus the micro-averaged overall accuracy
# (total matches divided by total attempts across all tasks).
baseline_acc = {}
baseline_total_matches = 0
baseline_total_attempts = 0
for tid, r in baseline_results.items():
    baseline_acc[tid] = r["accuracy"]
    baseline_total_matches += r["matches"]
    baseline_total_attempts += r["attempts"]

baseline_overall = baseline_total_matches / max(1, baseline_total_attempts)

print(f"Baseline overall accuracy: {baseline_overall:.1%}")

##### Chart 1: Across-task ablation vs random ablation

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# X values come from the across-task experiment; the random control is plotted
# only at those k values for which it has a result.
k_vals = sorted(across_task_results)
retrieval_acc = [across_task_results[k]["overall_accuracy"] for k in k_vals]

# Keep the random control's own x values next to its y values. Slicing k_vals
# by the number of random results would shift points to the wrong k whenever
# a missing k is not the last one.
random_k_vals = [k for k in k_vals if k in random_ablation_results]
random_acc = [random_ablation_results[k]["overall_accuracy"] for k in random_k_vals]

ax.axhline(y=baseline_overall, color="green", linestyle="--", label="Baseline")
ax.plot(k_vals, retrieval_acc, "ro-", label="Retrieval heads ablated")
ax.plot(random_k_vals, random_acc, "bs-", label="Random heads ablated")

ax.set_xlabel("Number of heads ablated (K)")
ax.set_ylabel("Accuracy")
ax.set_title("Shared Retrieval Head Ablation vs Random Control")
ax.legend()
ax.set_ylim(0, 1.05)

plt.tight_layout()
fig_path = os.path.join(ABLATION_OUTPUT_DIR, "across_task_vs_random.png")
plt.savefig(fig_path, dpi=150)
plt.show()
print(f"Saved: {fig_path}")

##### Chart 2: Within-task ablation per task

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))

# One accuracy-vs-k line per task.
for task_id in within_task_results:
    k_results = within_task_results[task_id]
    k_vals = sorted(k_results)
    acc_vals = [k_results[k]["accuracy"] for k in k_vals]
    ax.plot(k_vals, acc_vals, "o-", label=task_id)

# The reference line is the overall (all-task) baseline, not a per-task one.
ax.axhline(y=baseline_overall, color="green", linestyle="--", label="Baseline (overall)")

ax.set_title("Within-Task Ablation: Accuracy vs K")
ax.set_xlabel("Number of heads ablated (K)")
ax.set_ylabel("Accuracy")
ax.set_ylim(0, 1.05)
# Legend goes outside the axes so that it does not cover the lines.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)

plt.tight_layout()
fig_path = os.path.join(ABLATION_OUTPUT_DIR, "within_task_ablation.png")
plt.savefig(fig_path, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved: {fig_path}")

##### Summary table

In [None]:
# One row per k of the across-task experiment, comparing retrieval-head
# ablation and the random control against the overall baseline.
summary_rows = []

for k in sorted(across_task_results):
    retrieval = across_task_results[k]["overall_accuracy"]
    # None when the random control has no result for this k.
    rand = random_ablation_results.get(k, {}).get("overall_accuracy")
    summary_rows.append({
        "k": k,
        "baseline": baseline_overall,
        "retrieval_ablated": retrieval,
        "random_ablated": rand,
        "delta_retrieval": retrieval - baseline_overall,
        # Test for None explicitly: an accuracy of 0.0 is a valid result and
        # must still produce a delta.
        "delta_random": None if rand is None else rand - baseline_overall,
    })

summary_df = pd.DataFrame(summary_rows)
print(summary_df.to_string(index=False))