In [1]:
"""
Causal patching
"""
None # fun trick hehe

In [2]:
import torch
import importlib
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
import json
import sys
sys.path.append('..')
from tqdm import tqdm
import utils

from utils.mem import check_memory

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

seed = 123
# torch.cuda.empty_cache()
# clear_all_cuda_memory()
check_memory()
# gc.collect()

Device 0: NVIDIA L40S
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 44.42 GB



## Load model & data

In [3]:
# load qwen-14B – would ideally use 32B if not constrained by pod seup
model_path = "/workspace/models/Qwen3-14B"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype = torch.bfloat16, trust_remote_code = True).to(device) # if you check config.json it defaults to bfloat16: https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json; probably just large
tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token = False, add_bos_token = False, trust_remote_code = True) # right padding by default

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# load synthetic data – this time splitting the clean + corrupt (don't love this verbiage...) variants (should 2x length)
with open('/workspace/data/synthetic-data-1.json', 'r') as file:
    raw_df = json.load(file)

raw_df = pd.DataFrame(raw_df)

input_df =\
    pd.concat([
        raw_df\
            .rename(columns = {'clean_list': 'full_list', 'clean_category_indices': 'category_indices', 'clean_category_count': 'category_count'})\
            .assign(variant = 'clean')\
            [['sample_ix', 'variant', 'category', 'list_length', 'full_list', 'category_indices', 'category_count']],
        raw_df\
            .rename(columns = {'corrupt_list': 'full_list', 'corrupt_category_indices': 'category_indices', 'corrupt_category_count': 'category_count'})\
            .assign(variant = 'corrupt')\
            [['sample_ix', 'variant', 'category', 'list_length', 'full_list', 'category_indices', 'category_count']]
    ], ignore_index = True)\
    .assign(
        category_words = lambda df: df.apply(lambda row: [row['full_list'][i] for i in row['category_indices']], axis = 1),
    )

# input_df


Unnamed: 0,sample_ix,variant,category,list_length,full_list,category_indices,category_count,category_words
0,0,clean,colors,3,"[red, banana, computer]",[0],1,[red]
1,1,clean,vehicles,3,"[car, bus, banana]","[0, 1]",2,"[car, bus]"
2,2,clean,shapes,3,"[square, circle, triangle]","[0, 1, 2]",3,"[square, circle, triangle]"
3,3,clean,continents,3,"[computer, asia, apple]",[1],1,[asia]
4,4,clean,metals,3,"[banana, gold, silver]","[1, 2]",2,"[gold, silver]"
...,...,...,...,...,...,...,...,...
9973,44,corrupt,street_signs,7,"[rose, car, yield, speed, apple, dog, pedestrian]","[2, 3, 6]",3,"[yield, speed, pedestrian]"
9974,45,corrupt,continents,7,"[asia, dog, africa, europe, apple, car, americas]","[0, 2, 3, 6]",4,"[asia, africa, europe, americas]"
9975,46,corrupt,time_units,7,"[apple, year, car, day, month, minute, hour]","[1, 3, 4, 5, 6]",5,"[year, day, month, minute, hour]"
9976,47,corrupt,elements,7,"[hydrogen, apple, gold, silver, helium, oxygen...","[0, 2, 3, 4, 5, 6]",6,"[hydrogen, gold, silver, helium, oxygen, nitro..."


In [10]:
# tokenize + prep data 
# todo: move function into module
def prep_prompt(category, choices, tokenizer):
    choices_str = ' '.join(choices)
    user_prompt = f"Count the number of words in the following  list that match the given type, and put the numerical answer in parentheses.\n\nType: {category}\nList: [ {choices_str} ]"
    
    instruct_formatted_prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': user_prompt}, {'role': 'assistant', 'content': 'Answer: ('}],
        tokenize = False,
        add_generation_prompt = False,
        continue_final_message = True # Otherwise appends eos token
    )

    return instruct_formatted_prompt

# Add the prompt
input_df =\
    input_df\
    .assign(prompt = lambda df: df.apply(lambda row: prep_prompt(row['category'], row['full_list'], tokenizer), axis = 1))

# Add the token lengths
token_lengths = tokenizer(
    input_df['prompt'].tolist(),
    add_special_tokens = False,
    max_length = 128,
    padding = 'max_length',
    truncation = True,
    return_tensors = 'pt'
    )['attention_mask'].sum(dim = 1).tolist()

input_df =\
    input_df\
    .assign(token_length = token_lengths)

# Filter only for sample_ix such that the two variants have the same token length, this will simplify patching
sample_ix_with_same_length =\
    input_df\
    .groupby('sample_ix', as_index = False)\
    .agg(unique_lengths = ('token_length', 'nunique'))\
    .pipe(lambda df: df[df['unique_lengths'] == 1])\
    ['sample_ix']\
    .tolist()

input_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(sample_ix_with_same_length)]).reset_index(drop = True)

display(input_df)
print(input_df['prompt'].tolist()[0])

Unnamed: 0,sample_ix,variant,category,list_length,full_list,category_indices,category_count,category_words,prompt,token_length


IndexError: list index out of range

In [9]:
# create dataloader
import utils.dataset
importlib.reload(utils.dataset)
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

tokenized_prompts = tokenizer(input_df['prompt'].tolist(), add_special_tokens = False, max_length = 80, padding = 'max_length', truncation = True, return_tensors = 'pt')
print(tokenized_prompts['attention_mask'].sum(dim = 1).max()) # Must be under max length to confirm nothing was truncated (since attention mask applies a 1 to indicate "some guy was here")

# Create and chunk into lists of size 1k each - these will be the export breaks
test_dl = DataLoader(
    ReconstructableTextDataset(
        input_df['prompt'].tolist(),
        tokenizer,
        max_length = 128,
        sample_ix = input_df['sample_ix'].tolist(),
        variant = input_df['variant'].tolist(),
        full_lists = input_df['full_list'].tolist(),
        category_indices_list = input_df['category_indices'].tolist()
        ),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

tensor(72)


ValueError: word_start_tok_mask has 6 starts but list has 7 words (sample 1545).

## Quick evals
Identify a set of questions where both the corrupt + correct are accurate

In [123]:
# run fwd. pass
output_records = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    sample_indices = batch['sample_ix']
    variants = batch['variant']

    with torch.no_grad():
        logits = model(input_ids, attention_mask)['logits']
        seq_ends = attention_mask.sum(dim = 1) - 1

    for i in range(len(sample_indices)):
        last_pos = seq_ends[i].item()
        last_logits = logits[i, last_pos, :]
        probs = torch.softmax(last_logits, dim = -1)
        top_token = torch.argmax(probs).item()
        top_prob = probs[top_token].item()

        output_records.append(
            {
                'sample_ix': int(sample_indices[i]),
                'variant': variants[i],
                'output_token': tokenizer.decode(top_token),
                'output_prob': top_prob,
            }
        )

output_df = pd.DataFrame(output_records) # todo: there's a bug w/ the data loader, causing sample_ix to repeat intra-batch...

  4%|▍         | 55/1248 [00:06<02:21,  8.44it/s]


KeyboardInterrupt: 

In [92]:
# todo: fix sample index, not advised 
input_df['sample_ix'] = input_df.groupby('variant').cumcount()
output_df['sample_ix'] = output_df.groupby('variant').cumcount()


In [94]:
# get accuracy rate for each (sample_ix, variant)
accurate_rates =\
    input_df[['sample_ix', 'variant', 'category_count', 'category']]\
    .merge(output_df, on = ['sample_ix', 'variant'], how = 'inner')\
    .assign(is_correct = lambda df:  np.where(df['output_token'].str.strip() == df['category_count'].astype(str).str.strip(), 1, 0))\
    .assign(is_int = lambda df: np.where(df['output_token'].isin(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']), 1, 0))

display(accurate_rates)
print(f"Accuracy rate: {accurate_rates['is_correct'].mean().item():.4f}")

Unnamed: 0,sample_ix,variant,category_count,category,output_token,output_prob,is_correct,is_int
0,0,clean,1,colors,1,1.000000,1,1
1,1,clean,2,vehicles,2,1.000000,1,1
2,2,clean,3,shapes,3,1.000000,1,1
3,3,clean,1,continents,1,1.000000,1,1
4,4,clean,2,metals,2,1.000000,1,1
...,...,...,...,...,...,...,...,...
9973,4984,corrupt,3,street_signs,2,0.941406,0,1
9974,4985,corrupt,4,continents,4,0.976562,1,1
9975,4986,corrupt,5,time_units,3,0.996094,0,1
9976,4987,corrupt,6,elements,4,0.851562,0,1


Accuracy rate: 0.7928


In [105]:
# get only sample_ix's where both questions were accurate
sample_indices_for_patching =\
    accurate_rates\
    .groupby('sample_ix', as_index = False)\
    .agg(n_correct = ('is_correct', 'sum'))\
    .pipe(lambda df: df[df['n_correct'] == 2])\
    ['sample_ix'].tolist()

print(len(sample_indices_for_patching))

3418


In [106]:
patch_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(sample_indices_for_patching)])

## Patching time
first collect hidden states

In [108]:
# Patch imnputs
patch_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(sample_indices_for_patching)])

test_dl = DataLoader(
    ReconstructableTextDataset(
        patch_df['prompt'].tolist(),
        tokenizer,
        max_length = 80,
        sample_ix = patch_df['sample_ix'].tolist(),
        variant = patch_df['variant'].tolist(),
        full_list = patch_df['full_list'].tolist(),
        category_indices = patch_df['category_indices'].tolist()
        ),
    batch_size = 8,
    shuffle = False,
    collate_fn = stack_collate
)

In [116]:
output_records = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    sample_indices = batch['sample_ix']
    variants = batch['variant']
    list_masks = batch['attention_mask'] # BoolTensor B × N

    with torch.no_grad():
        out = model(input_ids, attention_mask, output_hidden_states = True)
        hs = torch.stack(out.hidden_states, dim = 1).detach()[:, 1:, :, :].float() # B x K x N x D - skip first hs which is embed layer
        logits = out.logits
        seq_ends = attention_mask.sum(dim = 1) - 1

    for i in range(len(sample_indices)):
        last_pos = seq_ends[i].item()
        last_logits = logits[i, last_pos, :]
        probs = torch.softmax(last_logits, dim = -1)
        top_token = torch.argmax(probs).item()
        top_prob = probs[top_token].item()

        # ── keep only list-token activations ───────────────────────────
        mask_i  = list_masks[i].bool()
        hs_sub  = hs[i, :, mask_i, :] # K × Nₗ × D
        cum_sub = batch['cum_count'][i][mask_i]

        output_records.append(
            {
                'sample_ix': int(sample_indices[i]),
                'variant': variants[i],
                'output_token': tokenizer.decode(top_token),
                'output_prob': top_prob,
                'hidden_states': hs_sub.cpu(), # K x N x D
                'match_flag': batch['match_flag'].tolist(),
                'original_tokens': batch['original_tokens'],
                'list_mask': mask_i, # N
                'cum_count': cum_sub, # N
                'word_token_spans': batch['word_token_spans'][i] # Store for patching
            }
        )

    if batch_ix >= 500:
        break



  0%|          | 0/855 [00:00<?, ?it/s]


KeyError: 'cum_count'

## Probe sweep - pick the mediator layer $l^*$

In [None]:
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
import random 

# ------------------------------------------------------------
# 0 — collect clean-only records into per-layer bags
# ------------------------------------------------------------
K = output_records[0]["hidden_states"].shape[0]   # # layers
rng = random.Random(42)
# MAX_ROWS = 10_000 # Per split, per
alpha = 10 # Choose such that tr/test loss is similar

# Split by sample_ix so paired clean/corrupt never leak
clean_records = [r for r in output_records if r["variant"] == "clean"]
sample_ids = sorted({r["sample_ix"] for r in clean_records})
rng.shuffle(sample_ids)
split = int(0.8 * len(sample_ids))
train_ids, test_ids = set(sample_ids[:split]), set(sample_ids[split:])

K = clean_records[0]["hidden_states"].shape[0]
bags = {l: {"train_X": [], "train_y": [], "test_X": [], "test_y": []} for l in range(K)}

for rec in clean_records:
    bag = (
        "train_" if rec["sample_ix"] in train_ids else "test_"
    )
    hs = rec["hidden_states"].numpy()        # K × Nₗ × D  (already CPU)
    y  = rec["cum_count"].numpy()            # Nₗ
    for l in range(K):
        bags[l][bag + "X"].append(hs[l])     # each is (Nₗ, D)
        bags[l][bag + "y"].append(y)

# ------------------------------------------------------------
# 1 — stack into 2-D matrices per layer
# ------------------------------------------------------------
Xs_tr, ys_tr, Xs_te, ys_te = [], [], [], []
for l in range(K):
    X_tr = np.concatenate(bags[l]["train_X"])   # (∑Nₗ, D)
    y_tr = np.concatenate(bags[l]["train_y"])
    X_te = np.concatenate(bags[l]["test_X"])
    y_te = np.concatenate(bags[l]["test_y"])

    Xs_tr.append(X_tr)
    ys_tr.append(y_tr)
    Xs_te.append(X_te)
    ys_te.append(y_te)

# ------------------------------------------------------------
# 2 — train one Ridge probe per layer, evaluate on held-out set
# ------------------------------------------------------------
r2_test, probes = [], []
for l in range(K):
    reg = Ridge(alpha = alpha, solver = 'auto', fit_intercept=True)
    reg.fit(Xs_tr[l], ys_tr[l])

    y_hat_tr = reg.predict(Xs_tr[l])
    y_hat_te = reg.predict(Xs_te[l])
    R2_train = r2_score(reg.predict(Xs_tr[l]),  ys_tr[l])

    r2_tr = r2_score(ys_tr[l], y_hat_tr)
    r2_te = r2_score(ys_te[l], y_hat_te)

    r2_test.append(r2_te)
    probes.append(reg)

    print(f"Layer {l:2d} | train R² {r2_tr:6.3f} | test R² {r2_te:6.3f}")

print("\nLayer  |  Test R²")
for l, r2 in enumerate(r2_test):
    print(f"{l:2d}     {r2:6.3f}")

layer_star = int(np.argmax(r2_test))
print(f"\nMediator candidate layer = {layer_star}  (R² = {r2_test[layer_star]:.3f})")

w_star = torch.tensor(probes[layer_star].coef_, dtype=torch.float32)