MD # 03 Layer-patching grid
MD Swaps hidden-states layer-by-layer, computes Δ-log-odds.

In [7]:
import torch, json, numpy as np
from pathlib import Path
import sys, pathlib, importlib
import re

sys.path.append(str(pathlib.Path('..').resolve()))   # let Python see project root

sys.path.append(str(pathlib.Path('src').resolve()))
from src.utils import set_seed, load_model, load_prompts

def safe_slug(text: str, max_len: int = 40) -> str:
    """
    Turn an arbitrary string into a Windows-safe filename.
    Keeps A–Z, a–z, 0–9, dash and underscore; drops anything else.
    """
    clean = re.sub(r"[^0-9A-Za-z_-]", "_", text)   # replace bad chars with _
    return clean[:max_len].rstrip("_")

# --- load model & prompts ---
model, tok = load_model('gpt2')
prompts = load_prompts('../data/reasoning_prompts.json')

def _patch(_m, _i, out):
    if isinstance(out, (tuple, list)):
        # replace hidden-state but keep the rest of the tuple structure
        new_out = list(out)
        new_out[0] = donor_hidden[layer]
        return tuple(new_out)
    return donor_hidden[layer]


results = {}
for p in prompts:
    slug = safe_slug(p)
    donor_hidden = torch.load(f'../data/caches/{slug}.pt')
    ids = tok(p, return_tensors='pt').input_ids.to(DEVICE)

    # corrupt final token -> EOS for wrong-answer baseline
    corrupted = ids.clone(); corrupted[0, -1] = tok.eos_token_id
    incorrect_logits = model(corrupted).logits
    target_id = tok(' 4').input_ids[0]        # task-specific target

    layer_scores = {}
    for layer in range(model.config.n_layer):
        def _patch(_m, _i, _o):
            return donor_hidden[layer]
        h = model.transformer.h[layer].register_forward_hook(_patch)
        patched_logits = model(corrupted).logits
        h.remove()
        layer_scores[layer] = (
            patched_logits[0, -1, target_id] -
            incorrect_logits[0, -1, target_id]
        ).item()
    results[p] = layer_scores
    print('done', slug)

Path('../data').mkdir(exist_ok=True)
json.dump(results, open('../data/layer_grid.json', 'w'), indent=2)
print('saved layer_grid.json')

NameError: name 'DEVICE' is not defined