In [None]:
"""
Causal patching
"""
None # fun trick hehe

In [43]:
import torch
import nnsight
import importlib
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
import json
import sys
sys.path.append('..')
from tqdm import tqdm
import utils

from utils.mem import check_memory, clear_all_cuda_memory

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

seed = 123
# torch.cuda.empty_cache()
clear_all_cuda_memory()
check_memory()
gc.collect()

All CUDA memory cleared on all devices.
Device 0: NVIDIA L40S
  Allocated: 29.86 GB
  Reserved: 30.60 GB
  Total: 44.42 GB



254

## Load model & data

In [3]:
# load qwen-14B – would ideally use 32B if not constrained by pod seup
model_path = "/workspace/models/Qwen3-14B"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype = torch.bfloat16, trust_remote_code = True).to(device) # if you check config.json it defaults to bfloat16: https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json; probably just large
tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token = False, add_bos_token = False, trust_remote_code = True) # right padding by default

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
# Load reverse-engineered forward pass function that returns the MLP output, attention output, and final layer hidden state at each layer
from utils.qwen3 import run_qwen3

def test_custom_forward_pass(model):
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 64).to(model.device)
    original_results = model(**inputs)
    custom_results = run_qwen3(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward' # Checks that this function replicates the original model exactly
    print(f"Hidden states layers (post-layer): {len(custom_results['layer_out_hidden_states'])}")
    print(f"Hidden state size (post-layer): {(custom_results['layer_out_hidden_states'][0].shape)}")
    print(f"Hidden state size (MLP output): {(custom_results['mlp_out_hidden_states'][0].shape)}")
    print(f"Hidden state size (Attention output): {(custom_results['attn_out_hidden_states'][0].shape)}")

test_custom_forward_pass(model)

Hidden states layers (post-layer): 40
Hidden state size (post-layer): torch.Size([2, 64, 5120])
Hidden state size (MLP output): torch.Size([2, 64, 5120])
Hidden state size (Attention output): torch.Size([2, 64, 5120])


In [6]:
# load synthetic data – this time splitting the clean + corrupt (don't love this verbiage...) variants (should 2x length)
with open('/workspace/data/synthetic-data-1.json', 'r') as file:
    raw_df = json.load(file)

raw_df = pd.DataFrame(raw_df)

input_df =\
    pd.concat([
        raw_df\
            .rename(columns = {'clean_list': 'full_list', 'clean_category_indices': 'category_indices', 'clean_category_count': 'category_count'})\
            .assign(variant = 'clean')\
            [['sample_ix', 'variant', 'category', 'list_length', 'full_list', 'category_indices', 'category_count']],
        raw_df\
            .rename(columns = {'corrupt_list': 'full_list', 'corrupt_category_indices': 'category_indices', 'corrupt_category_count': 'category_count'})\
            .assign(variant = 'corrupt')\
            [['sample_ix', 'variant', 'category', 'list_length', 'full_list', 'category_indices', 'category_count']]
    ], ignore_index = True)\
    .assign(
        category_words = lambda df: df.apply(lambda row: [row['full_list'][i] for i in row['category_indices']], axis = 1),
    )

# input_df


In [7]:
input_df['sample_ix'] = input_df.groupby('variant').cumcount()

In [8]:
# tokenize + prep data 
# todo: move function into module
def prep_prompt(category, choices, tokenizer):
    choices_str = ' '.join(choices)
    user_prompt = f"Count the number of words in the following  list that match the given type, and put the numerical answer in parentheses.\n\nType: {category}\nList: [ {choices_str} ]"
    
    instruct_formatted_prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': user_prompt}, {'role': 'assistant', 'content': 'Answer: ('}],
        tokenize = False,
        add_generation_prompt = False,
        continue_final_message = True # Otherwise appends eos token
    )

    return instruct_formatted_prompt

# Add the prompt
input_df =\
    input_df\
    .assign(prompt = lambda df: df.apply(lambda row: prep_prompt(row['category'], row['full_list'], tokenizer), axis = 1))

# Add the token lengths
token_lengths = tokenizer(
    input_df['prompt'].tolist(),
    add_special_tokens = False,
    max_length = 128,
    padding = 'max_length',
    truncation = True,
    return_tensors = 'pt'
    )['attention_mask'].sum(dim = 1).tolist()

input_df =\
    input_df\
    .assign(token_length = token_lengths)

# Filter only for sample_ix such that the two variants have the same token length, this will simplify patching
sample_ix_with_same_length =\
    input_df\
    .groupby('sample_ix', as_index = False)\
    .agg(unique_lengths = ('token_length', 'nunique'))\
    .pipe(lambda df: df[df['unique_lengths'] == 1])\
    ['sample_ix']\
    .tolist()

input_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(sample_ix_with_same_length)]).reset_index(drop = True)

display(input_df)
print(input_df['prompt'].tolist()[0])

Unnamed: 0,sample_ix,variant,category,list_length,full_list,category_indices,category_count,category_words,prompt,token_length
0,0,clean,colors,3,"[red, banana, computer]",[0],1,[red],<|im_start|>user\nCount the number of words in...,50
1,1,clean,vehicles,3,"[car, bus, banana]","[0, 1]",2,"[car, bus]",<|im_start|>user\nCount the number of words in...,50
2,2,clean,shapes,3,"[square, circle, triangle]","[0, 1, 2]",3,"[square, circle, triangle]",<|im_start|>user\nCount the number of words in...,50
3,3,clean,continents,3,"[computer, asia, apple]",[1],1,[asia],<|im_start|>user\nCount the number of words in...,50
4,4,clean,metals,3,"[banana, gold, silver]","[1, 2]",2,"[gold, silver]",<|im_start|>user\nCount the number of words in...,50
...,...,...,...,...,...,...,...,...,...,...
7603,4982,corrupt,games,7,"[chess, rose, table, car, dog, apple, chair]",[0],1,[chess],<|im_start|>user\nCount the number of words in...,54
7604,4984,corrupt,street_signs,7,"[rose, car, yield, speed, apple, dog, pedestrian]","[2, 3, 6]",3,"[yield, speed, pedestrian]",<|im_start|>user\nCount the number of words in...,56
7605,4986,corrupt,time_units,7,"[apple, year, car, day, month, minute, hour]","[1, 3, 4, 5, 6]",5,"[year, day, month, minute, hour]",<|im_start|>user\nCount the number of words in...,55
7606,4987,corrupt,elements,7,"[hydrogen, apple, gold, silver, helium, oxygen...","[0, 2, 3, 4, 5, 6]",6,"[hydrogen, gold, silver, helium, oxygen, nitro...",<|im_start|>user\nCount the number of words in...,54


<|im_start|>user
Count the number of words in the following  list that match the given type, and put the numerical answer in parentheses.

Type: colors
List: [ red banana computer ]<|im_end|>
<|im_start|>assistant
<think>

</think>

Answer: (


In [9]:
# create dataloader
import utils.dataset
importlib.reload(utils.dataset)
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

tokenized_prompts = tokenizer(input_df['prompt'].tolist(), add_special_tokens = False, max_length = 80, padding = 'max_length', truncation = True, return_tensors = 'pt')
print(tokenized_prompts['attention_mask'].sum(dim = 1).max()) # Must be under max length to confirm nothing was truncated (since attention mask applies a 1 to indicate "some guy was here")

# Create and chunk into lists of size 1k each - these will be the export breaks
test_dl = DataLoader(
    ReconstructableTextDataset(
        input_df['prompt'].tolist(),
        tokenizer,
        max_length = 128,
        sample_ix = input_df['sample_ix'].tolist(),
        variant = input_df['variant'].tolist(),
        full_lists = input_df['full_list'].tolist(),
        category_indices_list = input_df['category_indices'].tolist()
        ),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

tensor(69)


## Quick evals
Identify a set of questions where both the corrupt + correct are accurate

In [10]:
# run fwd. pass
output_records = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    sample_indices = batch['sample_ix']
    variants = batch['variant']

    with torch.no_grad():
        logits = model(input_ids, attention_mask)['logits']
        seq_ends = attention_mask.sum(dim = 1) - 1

    for i in range(len(sample_indices)):
        last_pos = seq_ends[i].item()
        last_logits = logits[i, last_pos, :]
        probs = torch.softmax(last_logits, dim = -1)
        top_token = torch.argmax(probs).item()
        top_prob = probs[top_token].item()

        output_records.append(
            {
                'sample_ix': int(sample_indices[i]),
                'variant': variants[i],
                'output_token': tokenizer.decode(top_token),
                'output_prob': top_prob,
            }
        )

output_df = pd.DataFrame(output_records) # todo: there's a bug w/ the data loader, causing sample_ix to repeat intra-batch...

100%|██████████| 476/476 [02:44<00:00,  2.90it/s]


In [11]:
# todo: fix sample index, not advised 
# input_df['sample_ix'] = input_df.groupby('variant').cumcount()
output_df['sample_ix'] = output_df.groupby('variant').cumcount()


In [12]:
# get accuracy rate for each (sample_ix, variant)
accurate_rates =\
    input_df[['sample_ix', 'variant', 'category_count', 'category']]\
    .merge(output_df, on = ['sample_ix', 'variant'], how = 'inner')\
    .assign(is_correct = lambda df:  np.where(df['output_token'].str.strip() == df['category_count'].astype(str).str.strip(), 1, 0))\
    .assign(is_int = lambda df: np.where(df['output_token'].isin(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']), 1, 0))

display(accurate_rates)
print(f"Accuracy rate: {accurate_rates['is_correct'].mean().item():.4f}")

Unnamed: 0,sample_ix,variant,category_count,category,output_token,output_prob,is_correct,is_int
0,0,clean,1,colors,1,1.000000,1,1
1,1,clean,2,vehicles,2,1.000000,1,1
2,2,clean,3,shapes,3,1.000000,1,1
3,3,clean,1,continents,1,1.000000,1,1
4,4,clean,2,metals,2,1.000000,1,1
...,...,...,...,...,...,...,...,...
5803,3799,corrupt,0,animals,1,1.000000,0,1
5804,3800,corrupt,1,colors,2,0.816406,0,1
5805,3801,corrupt,2,countries,3,0.996094,0,1
5806,3802,corrupt,0,cities,5,0.562500,0,1


Accuracy rate: 0.2185


In [13]:
# get only sample_ix's where both questions were accurate
sample_indices_for_patching =\
    accurate_rates\
    .groupby('sample_ix', as_index = False)\
    .agg(n_correct = ('is_correct', 'sum'))\
    .pipe(lambda df: df[df['n_correct'] == 2])\
    ['sample_ix'].tolist()

print(len(sample_indices_for_patching))

517


In [14]:
patch_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(sample_indices_for_patching)])

## Patching time
first collect hidden states

In [None]:
test_indices_for_patching = sample_indices_for_patching[0:1000] # subset, ow will have memory problems...
patch_df = input_df.pipe(lambda df: df[df['sample_ix'].isin(test_indices_for_patching)])

test_dl = DataLoader(
    ReconstructableTextDataset(
        patch_df['prompt'].tolist(),
        tokenizer,
        max_length = 80,
        sample_ix = patch_df['sample_ix'].tolist(),
        variant = patch_df['variant'].tolist(),
        full_lists = patch_df['full_list'].tolist(),
        category_indices_list = patch_df['category_indices'].tolist()
        ),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [44]:
records = []
all_hidden_states = [] # big list of dicts; 1 entry = 1 (sample,variant)

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    sample_ix = batch['sample_ix'] # List
    variant = batch['variant'] # List
    answer_pos = batch['answer_pos'] # Tensor of length B containing the sequence positions of the answer tokens
    list_mask = batch['list_tok_mask'].to(device) # Tensor of B x N containing a boolean mask of what tokens belong to the list
    
    with torch.no_grad():
        # out includes keys 'attn_out_hidden_states', 'mlp_out_hidden_states', 'layer_out_hidden_states
        out = run_qwen3(model, input_ids, attention_mask, return_hidden_states = True)
    
    out_hidden_states = torch.stack(out['layer_out_hidden_states'], dim = 0) # K x B x N x D
    out_mlps = torch.stack(out['mlp_out_hidden_states'], dim = 0) # K x B x N x D
    out_attns = torch.stack(out['attn_out_hidden_states'], dim = 0) # K x B x N x D

    probs = torch.softmax(out['logits'], dim = -1).cpu()

    for i in range(len(sample_ix)):
        pos = answer_pos[i].item()
        tok_id = int(torch.argmax(probs[i, pos]))
        tok_prob = probs[i, pos, tok_id]
        tok_str = tokenizer.decode([tok_id], skip_special_tokens = False)
        records.append({
            'sample_ix': int(sample_ix[i]),
            'variant': variant[i],
            'output_token_id': tok_id,
            'output_token': tok_str,
            'output_prob': tok_prob
        })

        # Now get stuff to store the hidden states plus the mapping back
        list_mask_for_seq = list_mask[i] # N length mask
        list_token_indices_for_seq = list_mask[i].nonzero(as_tuple = True)[0] # Only keep the indices of the list mask tensors
        
        layer_acts = out_hidden_states[:, i, list_token_indices_for_seq, :].detach().cpu() #  [K, T, D]
        mlp_acts = out_mlps[:, i, list_token_indices_for_seq, :].detach().cpu() #  [K, T, D]
        attn_acts = out_attns[:, i, list_token_indices_for_seq, :].detach().cpu() #  [K, T, D]

        input_ids_seq = input_ids[i].detach().cpu()
        attention_mask_seq = attention_mask[i].detach().cpu()

        all_hidden_states.append({
            'sample_ix': int(sample_ix[i]),
            'variant': variant[i],
            'token_indices': list_token_indices_for_seq,
            'hidden_states': layer_acts,
            'attn_states': attn_acts, 
            'mlp_states': mlp_acts,
            "input_ids" : input_ids_seq,                
            "attention_mask" : attention_mask_seq,
            'is_category_tok': batch['is_category_tok'][i, list_token_indices_for_seq.cpu()],
            'is_word_start': batch['word_start_tok_mask'][i, list_token_indices_for_seq.cpu()],
            'running_count': batch['running_count'][i, list_token_indices_for_seq.cpu()],
            'answer_pos': pos
        })

    # if batch_ix > 100:
    #         break

100%|██████████| 65/65 [00:33<00:00,  1.91it/s]


In [45]:
# save to local (in case)
records_df = pd.DataFrame(records)
records_df["output_prob"] = records_df["output_prob"].apply(
    lambda x: x.item() if isinstance(x, torch.Tensor) else x
)
records_df.to_parquet("../data/records_df.parquet", index=False)
torch.save(all_hidden_states, '../data/all_hiden_states.pt')

# Patching

In [None]:
# single example 
sample_ix = records_df.sample_ix.value_counts().idxmax() 
donor = next(d for d in all_hidden_states
                if d["sample_ix"]==sample_ix and d["variant"]=="clean")
receiver = next(d for d in all_hidden_states
                if d["sample_ix"]==sample_ix and d["variant"]=="corrupt")

true_id = int(
    records_df
      .query("sample_ix == @sample_ix and variant == 'clean'")
      .output_token_id
      .iloc[0]
)

with torch.no_grad():
    base_logits = model(
        receiver["input_ids"].unsqueeze(0).to(device),
        attention_mask = receiver["attention_mask"].unsqueeze(0).to(device),
        use_cache = False
    ).logits
baseline_pred = int(torch.argmax(base_logits[0, receiver["answer_pos"]]))
print("Baseline corrupt prediction:", baseline_pred,
      "(should be wrong wrt true id", true_id, ")")

Baseline corrupt prediction: 18 (should be wrong w.r.t. true id 19 )


In [76]:
num_layers = len(model.model.layers)     
token_idx = receiver["token_indices"].to(device)   
answer_pos = receiver["answer_pos"]  

def patch_layer(layer_ix: int) -> bool:
    """
    True -> patching this block flips the corrupt answer to the clean one.
    False -> answer stays wrong.
    """
    donor_slice = donor["hidden_states"][layer_ix].to(device)

    # hook overwriting the residual stream for the list tokens
    def hook(_, __, output):
        if isinstance(output, tuple):       
            hidden, *rest = output
            hidden = hidden.clone()
            hidden[0, token_idx] = donor_slice
            return (hidden, *rest)
        else:
            hidden = output.clone() 
            hidden[0, token_idx] = donor_slice
            return hidden

    handle = model.model.layers[layer_ix].register_forward_hook(hook)
    try:
        with torch.no_grad():
            logits = model(
                receiver["input_ids"].unsqueeze(0).to(device),
                attention_mask = receiver["attention_mask"].unsqueeze(0).to(device),
            ).logits
    finally:
        handle.remove() # it is important you remove these for mem!!!!
        del donor_slice    

    new_top = int(torch.argmax(logits[0, answer_pos]))
    return new_top == true_id

print(f"Baseline corrupt prediction = {baseline_pred} (true id = {true_id})")
for L in range(num_layers):
    flipped = patch_layer(L)
    print(f"layer {L}: flipped = {flipped}")

# nice, so here it looks like transplanting hidden states (block-level, token-slice patching) from clean => corrupt "recovers" the answer
# (in this instance) when done to layer 0 through 22. After that point, the "damage is done," so it seems. 

# you likely want to refine this further, identifying specific heads/neurons. 

Baseline corrupt prediction = 18 (true id = 19)
layer 0: flipped = True
layer 1: flipped = True
layer 2: flipped = True
layer 3: flipped = True
layer 4: flipped = True
layer 5: flipped = True
layer 6: flipped = True
layer 7: flipped = True
layer 8: flipped = True
layer 9: flipped = True
layer 10: flipped = True
layer 11: flipped = True
layer 12: flipped = True
layer 13: flipped = True
layer 14: flipped = True
layer 15: flipped = True
layer 16: flipped = True
layer 17: flipped = True
layer 18: flipped = True
layer 19: flipped = True
layer 20: flipped = True
layer 21: flipped = True
layer 22: flipped = True
layer 23: flipped = False
layer 24: flipped = False
layer 25: flipped = False
layer 26: flipped = False
layer 27: flipped = False
layer 28: flipped = False
layer 29: flipped = False
layer 30: flipped = False
layer 31: flipped = False
layer 32: flipped = False
layer 33: flipped = False
layer 34: flipped = False
layer 35: flipped = False
layer 36: flipped = False
layer 37: flipped = Fal

In [None]:
# let's repeat across more samples for a minimal causal mediation analysis – we just want to initially answer 
# "which hidden-state layer contains a representation of the running count"
clean_by_ix = {d["sample_ix"]: d for d in all_hidden_states if d["variant"]=="clean"}
corrupt_by_ix = {d["sample_ix"]: d for d in all_hidden_states if d["variant"]=="corrupt"}

# this contains the "true" token id 
true_id_by_ix = (
    records_df
      .query('variant == "clean"')
      .set_index("sample_ix")["output_token_id"]
      .astype(int)
      .to_dict()
)

def flip_with_layer(donor, receiver, true_id, layer_ix) -> bool:
    token_idx = receiver["token_indices"].to(device)
    ans_pos = receiver["answer_pos"]
    donor_slice = donor["hidden_states"][layer_ix].to(device)

    def hook(_, __, output):
        if isinstance(output, tuple):
            h, *rest = output
            h = h.clone(); h[0, token_idx] = donor_slice
            return (h, *rest)
        else:
            h = output.clone(); h[0, token_idx] = donor_slice
            return h

    handle = model.model.layers[layer_ix].register_forward_hook(hook)
    try:
        with torch.no_grad():
            logits = model(
                receiver["input_ids"].unsqueeze(0).to(device),
                attention_mask = receiver["attention_mask"].unsqueeze(0).to(device)
            ).logits
    finally:
        handle.remove()
        del donor_slice

    return int(torch.argmax(logits[0, ans_pos])) == true_id

In [None]:
# this takes a while since qwen14B has 40 layers that it's operating/patching across each time
layer_hits = torch.zeros(num_layers, dtype=torch.long)  
layer_total = torch.zeros_like(layer_hits)              

for k in corrupt_by_ix:       
    donor = clean_by_ix[k]
    receiver = corrupt_by_ix[k]
    true_id = true_id_by_ix[k]

    # skip samples where corrupt is already correct (nothing to “flip”)
    with torch.no_grad():
        logits_corrupt = model(
            receiver["input_ids"].unsqueeze(0).to(device),
            attention_mask = receiver["attention_mask"].unsqueeze(0).to(device)
        ).logits
    corrupt_top = int(torch.argmax(logits_corrupt[0, receiver["answer_pos"]]))
    if corrupt_top == true_id:
        continue

    for L in range(num_layers):
        flipped = flip_with_layer(donor, receiver, true_id, L)
        layer_total[L] += 1
        layer_hits[L]  += int(flipped)

In [85]:
flip_rate = (layer_hits / layer_total.float()).tolist()
for L, rate in enumerate(flip_rate):
    print(f"layer {L}: flip-rate = {rate:5.2%}  (n={layer_total[L]})")

layer 0: flip-rate = 99.78%  (n=446)
layer 1: flip-rate = 99.55%  (n=446)
layer 2: flip-rate = 99.55%  (n=446)
layer 3: flip-rate = 99.78%  (n=446)
layer 4: flip-rate = 99.33%  (n=446)
layer 5: flip-rate = 99.78%  (n=446)
layer 6: flip-rate = 99.33%  (n=446)
layer 7: flip-rate = 99.78%  (n=446)
layer 8: flip-rate = 99.78%  (n=446)
layer 9: flip-rate = 99.55%  (n=446)
layer 10: flip-rate = 99.55%  (n=446)
layer 11: flip-rate = 99.33%  (n=446)
layer 12: flip-rate = 99.33%  (n=446)
layer 13: flip-rate = 99.10%  (n=446)
layer 14: flip-rate = 99.10%  (n=446)
layer 15: flip-rate = 98.65%  (n=446)
layer 16: flip-rate = 98.43%  (n=446)
layer 17: flip-rate = 97.76%  (n=446)
layer 18: flip-rate = 96.19%  (n=446)
layer 19: flip-rate = 93.50%  (n=446)
layer 20: flip-rate = 84.08%  (n=446)
layer 21: flip-rate = 75.34%  (n=446)
layer 22: flip-rate = 69.51%  (n=446)
layer 23: flip-rate = 53.14%  (n=446)
layer 24: flip-rate = 16.82%  (n=446)
layer 25: flip-rate = 6.05%  (n=446)
layer 26: flip-rate = 0

In [95]:
import plotly.graph_objects as go
layers = list(range(len(flip_rate)))

x  = [l for l, r in zip(layers, flip_rate)]
y  = [r for r in flip_rate]

fig = go.Figure()
fig.add_trace(go.Scatter(
    x = x,
    y = y,
    mode = "lines+markers",
    hovertemplate = "Layer %{x}<br>Flip-rate %{y:.1%}<extra></extra>"
))
fig.update_layout(
    title = "Running-count mediation (full sweep)",
    xaxis_title = "Layer",
    yaxis_title = "Flip-rate (prop. rescued :))",
    yaxis_range = [0, 1],
    paper_bgcolor="white",   
    plot_bgcolor="white",
)
fig.show()

## Probe sweep - pick the mediator layer $l^*$
won't do (right now), no time (right now).

In [None]:
# from sklearn.linear_model import Ridge
# from sklearn.metrics import r2_score
# import random 

# # ------------------------------------------------------------
# # 0 — collect clean-only records into per-layer bags
# # ------------------------------------------------------------
# K = output_records[0]["hidden_states"].shape[0]   # # layers
# rng = random.Random(42)
# # MAX_ROWS = 10_000 # Per split, per
# alpha = 10 # Choose such that tr/test loss is similar

# # Split by sample_ix so paired clean/corrupt never leak
# clean_records = [r for r in output_records if r["variant"] == "clean"]
# sample_ids = sorted({r["sample_ix"] for r in clean_records})
# rng.shuffle(sample_ids)
# split = int(0.8 * len(sample_ids))
# train_ids, test_ids = set(sample_ids[:split]), set(sample_ids[split:])

# K = clean_records[0]["hidden_states"].shape[0]
# bags = {l: {"train_X": [], "train_y": [], "test_X": [], "test_y": []} for l in range(K)}

# for rec in clean_records:
#     bag = (
#         "train_" if rec["sample_ix"] in train_ids else "test_"
#     )
#     hs = rec["hidden_states"].numpy()        # K × Nₗ × D  (already CPU)
#     y  = rec["cum_count"].numpy()            # Nₗ
#     for l in range(K):
#         bags[l][bag + "X"].append(hs[l])     # each is (Nₗ, D)
#         bags[l][bag + "y"].append(y)

# # ------------------------------------------------------------
# # 1 — stack into 2-D matrices per layer
# # ------------------------------------------------------------
# Xs_tr, ys_tr, Xs_te, ys_te = [], [], [], []
# for l in range(K):
#     X_tr = np.concatenate(bags[l]["train_X"])   # (∑Nₗ, D)
#     y_tr = np.concatenate(bags[l]["train_y"])
#     X_te = np.concatenate(bags[l]["test_X"])
#     y_te = np.concatenate(bags[l]["test_y"])

#     Xs_tr.append(X_tr)
#     ys_tr.append(y_tr)
#     Xs_te.append(X_te)
#     ys_te.append(y_te)

# # ------------------------------------------------------------
# # 2 — train one Ridge probe per layer, evaluate on held-out set
# # ------------------------------------------------------------
# r2_test, probes = [], []
# for l in range(K):
#     reg = Ridge(alpha = alpha, solver = 'auto', fit_intercept=True)
#     reg.fit(Xs_tr[l], ys_tr[l])

#     y_hat_tr = reg.predict(Xs_tr[l])
#     y_hat_te = reg.predict(Xs_te[l])
#     R2_train = r2_score(reg.predict(Xs_tr[l]),  ys_tr[l])

#     r2_tr = r2_score(ys_tr[l], y_hat_tr)
#     r2_te = r2_score(ys_te[l], y_hat_te)

#     r2_test.append(r2_te)
#     probes.append(reg)

#     print(f"Layer {l:2d} | train R² {r2_tr:6.3f} | test R² {r2_te:6.3f}")

# print("\nLayer  |  Test R²")
# for l, r2 in enumerate(r2_test):
#     print(f"{l:2d}     {r2:6.3f}")

# layer_star = int(np.argmax(r2_test))
# print(f"\nMediator candidate layer = {layer_star}  (R² = {r2_test[layer_star]:.3f})")

# w_star = torch.tensor(probes[layer_star].coef_, dtype=torch.float32)