# Identify Retrieval Heads Using Sum Attention Method

### Retrieval Head Identification Using Sum Attention Method

In [None]:
%pip install torch, transformers, numpy, huggingface_hub accelerate bitsandbytes dotenv, pandas, matplotlib

In [None]:
import os 
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM 
import numpy as np 
import matplotlib.pyplot as plt  
import pandas as pd  
import accelerate 
from huggingface_hub import login  
from dotenv import load_dotenv  
import json

#### If using Google Colab: Run The Following Cells (Ignore if using some other environment)

In [None]:
# Colab only: fetch the value stored under the name 'HF_TOKEN' through google.colab.userdata.
from google.colab import userdata 

HF_TOKEN = userdata.get('HF_TOKEN') 

In [None]:
# Log in to the Hugging Face Hub, but only when a non-empty token is available.
if not HF_TOKEN:
    print("HF_TOKEN not found.")
else:
    login(token=HF_TOKEN)
    print("HF_TOKEN loaded and Hugging Face login successful.")

#### If using a high-end GPU outside Colab (e.g., Lambda Labs)

In [None]:
# Pull variables from a local .env file into the process environment without
# overriding values that are already set (Lambda Labs / venv / local runs).
load_dotenv(override=False)

# The environment wins; otherwise reuse an HF_TOKEN variable that already exists
# in this session (e.g., one set by the Colab cell).
HF_TOKEN = os.getenv("HF_TOKEN") or globals().get("HF_TOKEN")

if not HF_TOKEN:
    print("HF_TOKEN not found. Add HF_TOKEN=... to your .env file or set it in the environment.")
else:
    # Export the token so libraries that read the environment variable see it too.
    os.environ["HF_TOKEN"] = HF_TOKEN
    login(token=HF_TOKEN)
    print("HF_TOKEN loaded and Hugging Face login successful.")

#### Configuration & Hyperparameters

In [None]:
import torch

# Model
MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"

# Data and needle settings

# Sequence length in tokens (reported as "Seq length" in the summary below).
TARGET_SEQ_LEN = 7000

# Needle depth as a fraction (reported as a percentage in the summary below).
NEEDLE_DEPTH = 0.5
# Fraction of samples reported as the "ID split" in the summary below.
SPLIT = 0.8

# Named constants instead of inline literals, so runs can be reproduced.
SPLIT_SEED = 42
TOP_K_HEADS = 20

# One entry per question type; "id" is the key used by TASK_MAP below.
TASKS = [
    {"id": "registrant_name", "question": "What is the registrant's name?"},
    {"id": "headquarters_city", "question": "What is the registrant's headquarters city?"},
    {"id": "headquarters_state", "question": "What is the registrant's headquarters state?"}, 
    # Open question: would it be better to ask for the current state and year
    # instead of the incorporation state and year?
    {"id": "incorporation_state", "question": "What is the registrant's incorporation state?"},
    {"id": "incorporation_year", "question": "What is the incorporation year?"},  

    {"id": "employees_count_total", "question": "How many total employees does the registrant have?"},
    {"id": "holder_record_amount", "question": "What is the number of holders of record of the registrant's common stock?"},
    {"id": "employees_count_full_time", "question": "How many full-time employees does the registrant have?"},
    {"id": "ceo_lastname", "question": "What is the CEO's last name?"},
] 

# Look up a task dict by its id.
TASK_MAP = {t["id"]: t for t in TASKS}

# Paths
DATA_PATH = "data/clean_ground_truth/cleaned_EDGAR_gt_2-22-2026.csv"
RAW_OUTPUT_DIR = "data/retrieval_heads/raw" 
RESULTS_OUTPUT_DIR = "data/retrieval_heads/results"

# Model Loading Options
TORCH_DTYPE = torch.bfloat16
ATTN_IMPL = "eager"

# Print a summary of the active configuration.
print("Configuration loaded.")
print(f"Model: {MODEL_ID}")
print(f"Seq length : {TARGET_SEQ_LEN} tokens")
print(f"Needle depth : {NEEDLE_DEPTH * 100:.0f}%")
print(f"ID split: {SPLIT * 100:.0f}% (seed={SPLIT_SEED})")
print(f"Top-K heads: {TOP_K_HEADS}")
print(f"Tasks: {len(TASKS)}")
print(f"dtype: {TORCH_DTYPE}")
print(f"attn_impl: {ATTN_IMPL}")


##### Tokenizer & Model Loading

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to MODEL_ID, passing HF_TOKEN for authentication.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

# Reuse the EOS token as the pad token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Report the main tokenizer properties.
print(f"Tokenizer loaded: {MODEL_ID}")
print(f"Vocab size: {tokenizer.vocab_size:,}")
print(f"Model max length: {tokenizer.model_max_length:,}")
print(f"Pad token: '{tokenizer.pad_token}' (id={tokenizer.pad_token_id})")


In [None]:
from transformers import AutoModelForCausalLM

# Load the causal LM with the dtype and attention implementation chosen in the
# configuration cell; device_map="auto" leaves weight placement to the library.
# NOTE(review): ATTN_IMPL is "eager", presumably so that attention weights can be
# returned with output_attentions=True. Confirm against the transformers docs.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
    attn_implementation=ATTN_IMPL,
)
# Switch to evaluation mode; the model is only used for forward passes here.
model.eval()

# Report the loaded model's main architecture numbers.
print(f"Model loaded: {MODEL_ID}")
print(f"dtype: {next(model.parameters()).dtype}")
print(f"Num layers: {model.config.num_hidden_layers}")
print(f"Num heads (Q): {model.config.num_attention_heads}")
print(f"Num KV heads: {model.config.num_key_value_heads}")
print(f"Total heads: {model.config.num_hidden_layers * model.config.num_attention_heads:,}")


#### Load the dataset and split it into identification and ablation sets

In [None]:
from sklearn.model_selection import train_test_split

# Read the CSV at DATA_PATH into a DataFrame.
df = pd.read_csv(DATA_PATH)


In [None]:
# Split the data, stratified on the "task" column and seeded with SPLIT_SEED so
# the partition is reproducible. Both parts get a fresh 0..n-1 index.
identification_df, ablation_df = (
    subset.reset_index(drop=True)
    for subset in train_test_split(
        df,
        test_size=1 - SPLIT,
        stratify=df["task"],
        random_state=SPLIT_SEED,
    )
)

# Report the absolute and relative size of each part.
n_total = len(df)
n_identification = len(identification_df)
n_ablation = len(ablation_df)
print(f"Identification set : {n_identification:,} samples ({n_identification/n_total:.0%})")
print(f"Ablation set : {n_ablation:,} samples ({n_ablation/n_total:.0%})")


In [None]:
# Per-task sample counts for the full data and for each split, side by side.
size_by_task = {
    column: frame.groupby("task").size()
    for column, frame in (
        ("total", df),
        ("identification", identification_df),
        ("ablation", ablation_df),
    )
}
counts = pd.DataFrame(size_by_task)

# Share of each task's samples that landed in each split, in percent.
id_share = counts["identification"] / counts["total"] * 100
abl_share = counts["ablation"] / counts["total"] * 100
counts["id_%"] = id_share.round(1)
counts["abl_%"] = abl_share.round(1)

print(counts.to_string())


#### Building Prompt

In [None]:
def find_needle_span(
    prompt_ids: list[int],
    needle_ids: list[int],
    threshold: float = 0.9,
) -> tuple[int, int]:
    """Locate the needle's token span inside the tokenized prompt.

    The search runs in two passes:

    1. An exact, order-sensitive match of ``needle_ids`` inside ``prompt_ids``.
    2. If there is no exact match, a sliding window of ``len(needle_ids)``
       tokens is scored by the fraction of distinct needle tokens it contains,
       and the best-scoring window that reaches ``threshold`` is returned
       (the earliest one on ties).

    Returning the best window rather than the first one that reaches the
    threshold matters: a window that starts a token or two before the needle
    already shares most of its tokens with it, so "first match" tends to
    report a span shifted left of the real one.

    Args:
        prompt_ids: Token ids of the full prompt.
        needle_ids: Token ids of the needle on its own.
        threshold: Minimum fraction (0-1) of distinct needle tokens a window
            must contain to be accepted by the fuzzy pass.

    Returns:
        ``(start, end)`` with ``end`` exclusive, or ``(-1, -1)`` when the
        needle is empty, longer than the prompt, or not found.
    """
    prompt = list(prompt_ids)
    needle = list(needle_ids)
    span_len = len(needle)
    if span_len == 0 or span_len > len(prompt):
        return -1, -1

    window_starts = range(len(prompt) - span_len + 1)

    # Pass 1: exact ordered match.
    for start in window_starts:
        if prompt[start : start + span_len] == needle:
            return start, start + span_len

    # Pass 2: best set-overlap window at or above the threshold.
    needle_set = set(needle)
    best_start, best_overlap = -1, -1.0
    for start in window_starts:
        window = set(prompt[start : start + span_len])
        overlap = len(window & needle_set) / len(needle_set)
        # Strict ">" keeps the earliest window among equally good ones.
        if overlap >= threshold and overlap > best_overlap:
            best_start, best_overlap = start, overlap

    if best_start < 0:
        return -1, -1
    return best_start, best_start + span_len


In [None]:
import re

# Only <|begin_of_text|> seems to occur in the data, but all control tokens are
# stripped from the haystack to be safe.
CONTROL_TOKENS = ["<|begin_of_text|>", "<|end_of_text|>", "<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"]
TOKEN_PATTERN = re.compile("|".join(re.escape(t) for t in CONTROL_TOKENS))

def build_prompt(row: pd.Series, task: dict, tokenizer, needle_depth: "float | None" = None) -> dict:
    """Build the chat prompt with the needle inserted and locate the needle's token span.

    Args:
        row: Sample providing ``"haystack_text"`` and ``"needle_sentence"``.
        task: Task dict providing the ``"question"`` to ask.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``encode``.
        needle_depth: Fraction (0-1) of the haystack, measured in characters,
            that precedes the needle. Defaults to the notebook-level
            ``NEEDLE_DEPTH`` so the configured depth is what is actually used.

    Returns:
        Dict with ``"input_ids"`` (the templated prompt), ``"needle_ids"``
        (the needle tokenized on its own), and ``"needle_start"`` /
        ``"needle_end"`` as returned by ``find_needle_span`` (``-1`` when the
        needle could not be located).

    Raises:
        ValueError: If ``needle_depth`` is outside ``[0, 1]``.
    """
    if needle_depth is None:
        needle_depth = NEEDLE_DEPTH
    if not 0.0 <= needle_depth <= 1.0:
        raise ValueError(f"needle_depth must be within [0, 1], got {needle_depth}")

    haystack = TOKEN_PATTERN.sub("", row["haystack_text"]).strip()
    needle = row["needle_sentence"]

    # Character offset at which the needle is spliced in; a depth of 0.5
    # reproduces the former hard-coded midpoint (len(haystack) // 2).
    insert_at = int(len(haystack) * needle_depth)
    context = haystack[:insert_at] + " " + needle + " " + haystack[insert_at:]

    message = f"<document>{context}</document>\nQuestion: {task['question']}\nAnswer:"
    messages = [{"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )

    # Locate the needle span in the tokenized prompt.
    needle_ids = tokenizer.encode(needle, add_special_tokens=False)
    needle_start, needle_end = find_needle_span(input_ids[0].tolist(), needle_ids)

    return {
        "input_ids": input_ids,
        "needle_ids": needle_ids,
        "needle_start": needle_start,
        "needle_end": needle_end,
    }


#### Just some validation of the prompt building logic

In [None]:
import random

# Spot-check prompt construction on five randomly chosen identification samples.
for _ in range(5):
    # Draw a valid row position and actually use it.
    random_number = random.randrange(len(identification_df))
    sample_row = identification_df.iloc[random_number]
    sample_task = TASK_MAP[sample_row["task"]]
    result = build_prompt(sample_row, sample_task, tokenizer)

    print(f"sample row: {random_number}")
    print(f"input_ids shape: {result['input_ids'].shape}")
    print(f"needle_start: {result['needle_start']}")
    print(f"needle_end: {result['needle_end']}")
    print(f"total tokens: {result['input_ids'].shape[1]}")

    # A start of -1 means the needle was not located; there is no span to decode.
    if result["needle_start"] == -1:
        print("needle not found in the tokenized prompt\n")
        continue

    print(f"needle span len: {result['needle_end'] - result['needle_start']} tokens")
    print()

    # Decode the located span so it can be compared with the original needle by eye.
    decoded = tokenizer.decode(result["input_ids"][0, result["needle_start"]:result["needle_end"]])
    print(f"decoded needle:\n{decoded}")
    print(f"\noriginal needle:\n{sample_row['needle_sentence']}")
    print()


#### Run Model & Extract the Attention Weights

In [None]:
def compute_sum_attn(model, prompt_inputs: dict) -> np.ndarray:
    """Score every head by the attention the last prompt token puts on the needle.

    The prompt is processed in two forward passes: every token except the last
    goes through first with caching enabled, then the last token alone is run
    on top of that cache with ``output_attentions=True``.

    Args:
        model: Model exposing ``device``, ``config.num_hidden_layers`` and
            ``config.num_attention_heads``.
        prompt_inputs: Mapping with ``"input_ids"`` (indexed as
            ``[batch, position]``), ``"needle_start"`` and ``"needle_end"``.

    Returns:
        ``float32`` array of shape ``(num_layers, num_heads)``. Entry
        ``[layer, head]`` is the sum of the returned attention weights over
        positions ``needle_start:needle_end``.
    """
    token_ids = prompt_inputs["input_ids"].to(model.device)
    needle_span = slice(prompt_inputs["needle_start"], prompt_inputs["needle_end"])

    context_ids = token_ids[:, :-1]
    final_id = token_ids[:, -1:]

    with torch.inference_mode():
        # Pass 1: fill the key/value cache; attention weights are not requested here.
        cached = model(
            input_ids=context_ids,
            use_cache=True,
            output_attentions=False,
            return_dict=True,
        )

        # Pass 2: run only the final token and ask for its attention weights.
        final_step = model(
            input_ids=final_id,
            past_key_values=cached.past_key_values,
            use_cache=False,
            output_attentions=True,
            return_dict=True,
        )

    scores = np.zeros(
        (model.config.num_hidden_layers, model.config.num_attention_heads),
        dtype=np.float32,
    )

    for layer_idx, layer_attn in enumerate(final_step.attentions):
        # First batch element, all heads, the single query position.
        per_head = layer_attn[0, :, 0, :].float().cpu().numpy()

        # Sum over the needle positions. Using .mean(axis=1) instead would
        # normalize for different needle lengths.
        scores[layer_idx] = per_head[:, needle_span].sum(axis=1)

    return scores

Compute and save the attention scores for every sample in the identification set.

In [None]:

os.makedirs(RAW_OUTPUT_DIR, exist_ok=True)

skipped = 0
total = len(identification_df)

# Count progress from 1 so the final line reads [total/total].
for position, (_, row) in enumerate(identification_df.iterrows(), start=1):
    task = TASK_MAP[row["task"]]
    prompt = build_prompt(row, task, tokenizer)

    # Without a located needle there is nothing to sum attention over.
    if prompt["needle_start"] == -1:
        print(f"[{position}/{total}] SKIP — needle not found ({row['filename']})")
        skipped += 1
        continue

    scores = compute_sum_attn(model, prompt)  # (num_layers, num_heads)

    # Save immediately; the file name encodes the source document and the task,
    # which is what the aggregation step matches on.
    filename = f"{row['filename'].replace('.txt', '')}_{task['id']}.npy"
    np.save(os.path.join(RAW_OUTPUT_DIR, filename), scores)

    print(f"[{position}/{total}] saved {filename}")

    # Drop references and release cached GPU memory before the next sample.
    del prompt, scores
    torch.cuda.empty_cache()

print(f"\nDone. {total - skipped}/{total} saved, {skipped} skipped.")

#### Get mean scores across tasks and identify top heads

In [None]:
task_mean_scores = {}

# List the directory once and sort it: os.listdir returns entries in arbitrary
# order, and a fixed stacking order keeps the floating-point means reproducible.
raw_files = sorted(f for f in os.listdir(RAW_OUTPUT_DIR) if f.endswith(".npy"))

for task_id in TASK_MAP:
    suffix = f"_{task_id}.npy"
    files = [f for f in raw_files if f.endswith(suffix)]
    if not files:
        print(f"WARNING: no files found for task '{task_id}' — skipping")
        continue
    # Stack the per-sample score arrays and average over the sample axis.
    stacked = np.stack([np.load(os.path.join(RAW_OUTPUT_DIR, f)) for f in files])
    task_mean_scores[task_id] = stacked.mean(axis=0)
    print(f"{task_id:<35} {len(files):>4} samples -> mean shape {task_mean_scores[task_id].shape}")

print(f"\n{len(task_mean_scores)} tasks aggregated.")

#### Ranking top-k heads

In [None]:
# Create the results directory if it does not exist yet.
os.makedirs(RESULTS_OUTPUT_DIR, exist_ok=True)

##### General top-k heads

In [None]:
# For every task, keep the TOP_K_HEADS (layer, head) pairs with the highest mean score.
task_top_heads = {}

for task_id, mean_scores in task_mean_scores.items():
    num_heads = mean_scores.shape[1]
    flat = mean_scores.flatten()

    # Ascending argsort, reversed, then truncated: flat indices in descending score order.
    ranked = np.argsort(flat)[::-1][:TOP_K_HEADS]

    heads = []
    for flat_idx in ranked:
        # A row-major flat index splits into (row, column) = (layer, head).
        layer_idx, head_idx = divmod(int(flat_idx), num_heads)
        heads.append(
            {
                "layer": layer_idx,
                "head": head_idx,
                "score": float(flat[flat_idx]),
            }
        )

    task_top_heads[task_id] = heads

#### Cross-task shared heads (intersection)

In [None]:
from collections import Counter

# Rather than requiring a head to be in the intersection of ALL tasks, a head
# that shows up in the top-k of 7 or more tasks is treated as a "general"
# retrieval head. This collects every task's top (layer, head) pairs in one
# flat list so they can be counted.
all_top_heads = [
    (entry["layer"], entry["head"])
    for task_heads in task_top_heads.values()
    for entry in task_heads
]

In [None]:
# Minimum number of tasks whose top-k must contain a head for it to count as shared.
MIN_TASKS = 7

# Number of tasks in whose top-k each (layer, head) pair appears.
head_counts = Counter(all_top_heads)
shared_head_pairs = [pair for pair in head_counts if head_counts[pair] >= MIN_TASKS]


In [None]:
# One record per shared head, ordered by task frequency and then by average
# score, both descending. The average is taken over all aggregated tasks.
shared_heads = sorted(
    (
        {
            "layer": layer,
            "head": head,
            "avg_score": float(np.mean([
                per_task[layer, head]
                for per_task in task_mean_scores.values()
            ])),
            "task_frequency": head_counts[(layer, head)],
        }
        for layer, head in shared_head_pairs
    ),
    key=lambda rec: (rec["task_frequency"], rec["avg_score"]),
    reverse=True,
)

#### Exporting/Saving

In [None]:
# Bundle the per-task rankings and the shared heads into one JSON document.
# "min_tasks" records the sharing threshold so the file is self-describing.
output = {
    "model": MODEL_ID,
    "top_k": TOP_K_HEADS,
    "min_tasks": MIN_TASKS,
    "tasks": task_top_heads,
    "shared_heads": shared_heads,
}
out_path = os.path.join(RESULTS_OUTPUT_DIR, "retrieval_heads.json")

with open(out_path, "w", encoding="utf-8") as f:
    json.dump(output, f, indent=2)

print(f"Saved at: {out_path}")
# Heads are shared when they are in the top-k of at least MIN_TASKS tasks,
# not necessarily of every task, so the summary states that rule.
print(
    f"\nShared retrieval heads (top-{TOP_K_HEADS} in at least {MIN_TASKS} of "
    f"{len(task_mean_scores)} tasks; {len(shared_heads)} found):"
)

for h in shared_heads:
    print(
        f"Layer {h['layer']:>2}     Head {h['head']:>2}     "
        f"avg_score={h['avg_score']:.4f}     tasks={h['task_frequency']}"
    )

#### Plotting 

In [None]:
# Plot 1: one layer × head heatmap of mean scores per task, with that task's
# top heads marked by red crosses.
for task_id, mean_scores in task_mean_scores.items():
    fig_path = os.path.join(RESULTS_OUTPUT_DIR, f"heatmap_{task_id}.png")

    fig, ax = plt.subplots(figsize=(20, 8))
    heat = ax.imshow(mean_scores, aspect="auto", cmap="viridis")
    fig.colorbar(heat, ax=ax, label="Mean Sum Attention")

    ax.set_title(f"Sum Attention Scores — {task_id}", fontsize=14)
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")

    # x is the head index, y is the layer index; markers only, no connecting line.
    marker_x = [entry["head"] for entry in task_top_heads[task_id]]
    marker_y = [entry["layer"] for entry in task_top_heads[task_id]]
    ax.plot(marker_x, marker_y, "r+", markersize=8, markeredgewidth=1.5)

    fig.tight_layout()
    fig.savefig(fig_path, dpi=150)
    plt.show()
    print(f"Saved at: {fig_path}")

In [None]:
# Plot 2: where the shared heads sit on the layer × head grid. Cells hold the
# head's average score; every head that is not shared stays at 0.
if shared_heads:
    num_layers = model.config.num_hidden_layers
    num_heads  = model.config.num_attention_heads
    grid = np.zeros((num_layers, num_heads))
    for h in shared_heads:
        grid[h["layer"], h["head"]] = h["avg_score"]

    fig, ax = plt.subplots(figsize=(20, 8))
    im = ax.imshow(grid, aspect="auto", cmap="hot")
    # The title states the actual selection rule (top-k in at least MIN_TASKS tasks).
    ax.set_title(
        f"Shared Retrieval Heads — top-{TOP_K_HEADS} in at least {MIN_TASKS} tasks",
        fontsize=14,
    )
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")
    plt.colorbar(im, ax=ax, label="Avg Sum Attention Score")

    plt.tight_layout()
    fig_path = os.path.join(RESULTS_OUTPUT_DIR, "shared_retrieval_heads.png")
    plt.savefig(fig_path, dpi=150)
    plt.show()
    print(f"Saved at: {fig_path}")
else:
    print(
        f"No shared heads found (no head is in the top-{TOP_K_HEADS} of at least "
        f"{MIN_TASKS} tasks); consider increasing TOP_K_HEADS or lowering MIN_TASKS."
    )