MD # 03 Layer-patching grid
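MD Patches the hidden states cached by 02_mvp_capture into a corrupted run (last prompt token replaced by EOS), one layer at a time, and records the change in the target token's raw logit at the final position (a Δ-logit, not a true log-odds).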

In [40]:
import sys, pathlib, re, json, torch
from pathlib import Path

# let Python see project root (.. -> CHAIN-COT-PROJECT)
sys.path.append(str(pathlib.Path("..").resolve()))
from src.utils import load_model, load_prompts, set_seed, DEVICE


In [41]:
def safe_slug(text: str, max_len: int = 40) -> str:
    """Turn ``text`` into a filename-safe slug.

    Every character other than an ASCII letter, ASCII digit, ``_`` or ``-``
    becomes ``_``. The result is cut to ``max_len`` characters, and trailing
    underscores are then stripped, so it may be shorter than ``max_len``.

    NOTE(review): the cache files are looked up by this slug, so it presumably
    must stay identical to the slugging used when the caches were written.
    """
    sanitized = re.sub(r"[^0-9A-Za-z_-]", "_", text)
    head = sanitized[:max_len]
    return head.rstrip("_")


In [42]:
set_seed(0)
model, tok = load_model("gpt2")
# Inference mode (no dropout): the corrupted baseline is computed once and reused
# for every layer, so forward passes must be deterministic. This is a no-op if
# load_model already did it.
model.eval()
model.config.use_cache = False          # disable the KV cache; forward calls below also pass use_cache=False

prompts   = load_prompts("../data/reasoning_prompts.json")
cache_dir = Path("../data/caches")
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not cache_dir.exists():
    raise FileNotFoundError("Run 02_mvp_capture first!")


In [43]:
def to_3d(t: torch.Tensor) -> torch.Tensor:
    """
    Lift a rank-1 or rank-2 tensor to rank 3 by prepending size-1 axes.

    (hidden,)       -> (1, 1, hidden)
    (seq, hidden)   -> (1, seq, hidden)
    any other rank  -> returned unchanged (no shape validation is done)
    """
    rank = t.dim()
    if rank == 2:
        return t[None]
    if rank == 1:
        return t[None, None]
    return t


In [44]:
def _to_bsh(t: torch.Tensor, seq_len: int, hidden: int) -> torch.Tensor:
    """Coerce a cached donor activation to (1, seq_len, hidden) on DEVICE.

    Accepts (hidden,), (seq, hidden) or (1, seq, hidden). Sequences shorter than
    ``seq_len`` are right-padded by repeating their last position. Longer ones are
    truncated to ``seq_len``, for the 2-D and the 3-D case alike.
    """
    t = t.to(DEVICE)
    if t.dim() == 1:
        # A single vector is broadcast to every position.
        t = t.repeat(seq_len).view(1, seq_len, hidden)
    elif t.dim() == 2:
        t = t.unsqueeze(0)
    if t.dim() == 3:
        cur_len = t.size(1)
        if cur_len < seq_len:
            # NOTE(review): padding invents activations for missing positions; a
            # length mismatch may mean the cache used a different tokenization.
            t = torch.cat([t, t[:, -1:].expand(-1, seq_len - cur_len, -1)], dim=1)
        t = t[:, :seq_len]
    return t.contiguous()


def _make_replace_hook(replacement: torch.Tensor):
    """Build a forward hook that replaces a module's hidden-state output with ``replacement``."""
    def _hook(_module, _inputs, out):
        # Cast to the hidden-state dtype; if the module returns a tuple/list,
        # swap only element 0 and keep the remaining entries untouched.
        if isinstance(out, (tuple, list)):
            return (replacement.to(out[0].dtype), *out[1:])
        return replacement.to(out.dtype)
    return _hook


@torch.no_grad()
def grid_scores(ids, donor, target=" 4"):
    """Per-layer patching scores for one prompt.

    A corrupted input is built by replacing the last token of ``ids`` with EOS.
    For each transformer block ``layer``, that block's hidden-state output in the
    corrupted run is overwritten with ``donor[layer]``. The change in the raw logit
    of ``target`` at the final position is then recorded.

    Args:
        ids: token ids, indexed as ``[0, -1]`` (presumably shape (1, seq)).
        donor: indexable by block index 0..n_layer-1. Each entry is a tensor of
            shape (hidden,), (seq, hidden) or (1, seq, hidden).
            NOTE(review): whether ``donor[layer]`` is the *output* of block
            ``layer`` depends on how the cache was captured. Confirm there is no
            off-by-one (e.g. an embeddings entry at index 0).
        target: text whose single token's logit is scored. The default " 4" keeps
            the previous behaviour; it is the same for every prompt, so confirm
            that is intended.

    Returns:
        dict mapping layer index -> patched_logit - corrupted_logit (float).
        This is a raw logit difference, not a log-probability / log-odds.

    Raises:
        ValueError: if ``target`` does not tokenize to exactly one token.

    NOTE(review): only the final token is corrupted, and the hook replaces the
    block output at *every* position. If ``donor`` holds the clean run's block
    outputs, the score may come out (nearly) identical for every layer. Confirm
    the grid is informative.
    """
    target_ids = tok(target).input_ids
    if len(target_ids) != 1:
        raise ValueError(f"target {target!r} must map to a single token, got {target_ids}")
    target_id = target_ids[0]

    corrupt = ids.clone()
    corrupt[0, -1] = tok.eos_token_id
    # Baseline computed once and reused for every layer.
    corrupt_logits = model(corrupt, use_cache=False).logits[0, -1]

    seq_len = ids.size(1)
    hidden = model.config.n_embd

    scores = {}
    for layer in range(model.config.n_layer):
        hook = _make_replace_hook(_to_bsh(donor[layer], seq_len, hidden))
        handle = model.transformer.h[layer].register_forward_hook(hook)
        try:
            patched_logits = model(corrupt, use_cache=False).logits[0, -1]
        finally:
            # Always detach, even if the forward pass raises, so the model is
            # never left permanently patched.
            handle.remove()
        scores[layer] = (patched_logits[target_id] - corrupt_logits[target_id]).item()
    return scores


In [45]:
# Run the patching grid for every prompt and persist the scores as JSON.
# Note: json turns the integer layer keys into strings.
results = {}
for p in prompts:
    slug = safe_slug(p)
    # Assumes the cache file for this prompt was saved under the same slug.
    cache_file = cache_dir / f"{slug}.pt"
    # torch.load unpickles its input: only load caches this project produced.
    donor_hidden = torch.load(cache_file, map_location=DEVICE)
    ids = tok(p, return_tensors="pt").input_ids.to(DEVICE)

    results[p] = grid_scores(ids, donor_hidden)
    print("✓", slug)

out_path = Path("../data/layer_grid.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
# Context manager so the file is flushed and closed deterministically.
with open(out_path, "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print("✅  saved →", out_path.relative_to('..'))


✓ Q__What_is_2___3__A
✓ Q__If_I_have_five_apples_and_eat_two__ho
✅  saved → data\layer_grid.json
