MD # 02 Capture activations (MVP)
MD Runs each prompt through the model once and saves the output activations of every transformer block (the residual stream) to disk.


In [9]:
from typing import Dict, Tuple
import re
import torch, json
from pathlib import Path
import sys, pathlib, importlib
sys.path.append(str(pathlib.Path('..').resolve()))   # let Python see project root

from utils import set_seed, load_model, load_prompts


class ActivationCache:
    """Collect one detached output tensor per layer index via forward hooks."""

    def __init__(self):
        # layer index -> most recent detached output captured for that layer
        self.hidden: Dict[int, torch.Tensor] = {}

    def hook(self, layer: int):
        """Build a forward-hook callable that records outputs under ``layer``.

        The returned callable takes the ``(module, inputs, output)`` triple,
        ignores the first two, and stores the detached output in
        ``self.hidden[layer]``, overwriting any earlier entry for that layer.
        """
        def _store(_module, _inputs, output):
            # Some modules return a tuple/list; the tensor is its first item.
            if isinstance(output, (tuple, list)):
                output = output[0]
            self.hidden[layer] = output.detach()

        return _store


def safe_slug(text: str, max_len: int = 40) -> str:
    """
    Turn an arbitrary string into a Windows-safe filename stem.

    Every character other than A-Z, a-z, 0-9, dash and underscore is
    replaced with an underscore. The result is cut to ``max_len`` characters
    and any trailing underscores are stripped.

    If nothing is left after that (empty input, or input made only of
    replaced characters), the fixed stem ``"untitled"`` is returned so that
    callers never build a file name consisting solely of an extension.

    NOTE: the mapping is lossy. Distinct inputs can yield the same slug
    (e.g. "2+3" and "2*3", or two long texts sharing their first ``max_len``
    characters), so callers that need unique file names must disambiguate.
    """
    clean = re.sub(r"[^0-9A-Za-z_-]", "_", text)   # replace bad chars with _
    slug = clean[:max_len].rstrip("_")
    return slug or "untitled"

def forward_with_cache(model, ids) -> Tuple[torch.Tensor, 'ActivationCache']:
    """Run ``model`` on ``ids`` once while capturing each block's output.

    A forward hook is registered on every ``model.transformer.h[i]`` for
    ``i`` in ``range(model.config.n_layer)``; each hook stores that block's
    detached output in the returned cache under key ``i``.

    Args:
        model: object exposing ``transformer.h`` (indexable modules),
            ``config.n_layer`` and a call that returns something with a
            ``.logits`` attribute.
        ids: input passed straight to ``model(ids)``.

    Returns:
        ``(logits, cache)`` where ``cache.hidden`` maps layer index to the
        captured tensor.

    The hooks are always removed before returning or propagating an
    exception, so a failed forward pass does not leave stale hooks on the
    model.
    """
    cache = ActivationCache()
    hooks = []
    try:
        # Register inside the try so a failure part-way through registration
        # still cleans up the hooks that were already attached.
        for i in range(model.config.n_layer):
            hooks.append(
                model.transformer.h[i].register_forward_hook(cache.hook(i))
            )
        logits = model(ids).logits
    finally:
        for h in hooks:
            h.remove()
    return logits, cache

# ---- main execution ----
# For every prompt: tokenize, run one forward pass with hooks attached, and
# save the per-layer activation dict to ../data/caches/<slug>.pt.
set_seed(0)
model, tok = load_model('gpt2')
prompts = load_prompts('../data/reasoning_prompts.json')

cache_dir = Path('../data/caches')
cache_dir.mkdir(parents=True, exist_ok=True)

# Put inputs on whatever device the model's weights live on, rather than
# relying on a DEVICE global that this file never defines or imports.
device = next(model.parameters()).device

for p in prompts:
    ids = tok(p, return_tensors='pt').input_ids.to(device)
    # The logits are discarded and the cached tensors are detached, so no
    # autograd graph is needed for this pass.
    with torch.no_grad():
        _, cache = forward_with_cache(model, ids)
    slug = safe_slug(p)
    # NOTE(review): prompts that share a slug overwrite each other's file.
    torch.save(cache.hidden, cache_dir / f'{slug}.pt')
    print('saved cache for', slug)


saved cache for Q__What_is_2___3__A
saved cache for Q__If_I_have_five_apples_and_eat_two__ho
