In [6]:
import os

# Keep TensorFlow/Flax out of this kernel and of any child process launched from
# it (subprocess children inherit os.environ by default). These settings exist only
# in the current kernel: after a kernel restart (e.g. the one requested by the
# install cell), re-run this cell before launching the trace script.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
# transformers' own backend switches, read when `transformers` is imported.
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
# Silence TensorFlow's C++ logging if TF still ends up being imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [9]:
import sys, subprocess, os, shlex

# Dependency-install switch. True (re)installs/upgrades on every run of this cell;
# keep it True on a fresh VM / new environment, set it to False once deps are in place.
INSTALL_DEPS: bool = True
# Requirements file; relative, so it is resolved against the kernel's working directory.
REQ_PATH: str = "mech/requirements_mech.txt"

def pip_install_stream(cmd: "str | list[str]") -> None:
    """Run a (pip) command and echo its merged stdout/stderr line by line.

    Args:
        cmd: either a shell-style command string (split with ``shlex.split``) or an
            argv sequence. Prefer a sequence when a path may contain spaces.

    Raises:
        RuntimeError: if the command exits with a non-zero status.

    If streaming is interrupted (e.g. KeyboardInterrupt), the child process is
    killed before the exception propagates, so no orphaned pip keeps running.
    """
    if isinstance(cmd, str):
        argv = shlex.split(cmd)
        shown = cmd
    else:
        argv = [str(part) for part in cmd]
        shown = shlex.join(argv)
    print("\nRUN:", shown)

    env = os.environ.copy()
    env["PYTHONUNBUFFERED"] = "1"      # child flushes output as it is produced
    env["PIP_PROGRESS_BAR"] = "on"     # ask pip to render progress bars
    env["PYTHONIOENCODING"] = "utf-8"  # match the decoding below; bars use non-ASCII glyphs

    with subprocess.Popen(
        argv,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge so messages keep their relative order
        text=True,
        encoding="utf-8",
        errors="replace",          # never die on an undecodable byte in a log line
        bufsize=1,
        env=env,
    ) as proc:
        try:
            for line in proc.stdout:
                print(line, end="")
        except BaseException:
            proc.kill()
            raise
        rc = proc.wait()
    if rc != 0:
        raise RuntimeError(f"Command failed with exit code {rc}: {shown}")

# Upgrade the packaging toolchain first, then install the project requirements,
# both with the kernel's own interpreter so packages land where the kernel looks.
if not INSTALL_DEPS:
    print("INSTALL_DEPS=False, skipping installs.")
else:
    for pip_args in ("install -U pip setuptools wheel", f"install -r {REQ_PATH}"):
        pip_install_stream(f"{sys.executable} -u -m pip {pip_args}")
    print("\nDone. Now restart the kernel: Kernel -> Restart.")


RUN: /usr/bin/python3 -u -m pip install -U pip setuptools wheel
Defaulting to user installation because normal site-packages is not writeable
Collecting setuptools
  Downloading setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Downloading setuptools-80.9.0-py3-none-any.whl (1.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 10.5 MB/s  0:00:00
Installing collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 70.2.0
    Uninstalling setuptools-70.2.0:
      Successfully uninstalled setuptools-70.2.0
Successfully installed setuptools-80.9.0

RUN: /usr/bin/python3 -u -m pip install -r mech/requirements_mech.txt
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu128

Done. Now restart the kernel: Kernel -> Restart.


In [1]:
from pathlib import Path

# Project root (where 'data/', 'mech/', 'outputs/' live)
PROJECT_DIR = Path("/home/ubuntu/deep")   # <- change if needed
DATA_DIR = PROJECT_DIR / "data"
OUT_DIR = PROJECT_DIR / "outputs"
MECH_SCRIPT = PROJECT_DIR / "mech" / "mech_trace.py"

# Input / outputs
IN_JSONL = DATA_DIR / "normal_responses.jsonl"
OUT_MECH_JSONL = OUT_DIR / "normal_responses_mech.jsonl"
OUT_DELTA_JSONL = OUT_DIR / "normal_frame_deltas.jsonl"

# Model / trace settings
MODEL_ID = None  # None = infer from first row's 'model_id'. Or set e.g. "mistralai/Mistral-7B-Instruct-v0.3"
DTYPE = "bf16"   # "bf16" or "fp16"
LAYERS = "0,8,16,24,32"  # comma-separated layer indices passed to --layers

# Sharding (for multi-GPU boxes). For single GPU, keep defaults.
NUM_SHARDS = 1
SHARD_ID = 0

# Safety knobs
OVERWRITE = True  # True passes --overwrite to the trace script (presumably replaces OUT_MECH_JSONL)
MAX_ROWS = 0  # 0 = no limit

# Fail fast on bad knobs here rather than after loading a model onto the GPU.
assert DTYPE in ("bf16", "fp16"), f"DTYPE must be 'bf16' or 'fp16', got {DTYPE!r}"
assert all(s.strip().isdigit() for s in LAYERS.split(",")), (
    f"LAYERS must be comma-separated non-negative ints, got {LAYERS!r}"
)
assert NUM_SHARDS >= 1 and 0 <= SHARD_ID < NUM_SHARDS, (
    f"need NUM_SHARDS >= 1 and 0 <= SHARD_ID < NUM_SHARDS, got {SHARD_ID}/{NUM_SHARDS}"
)
assert MAX_ROWS >= 0, f"MAX_ROWS must be >= 0 (0 = no limit), got {MAX_ROWS}"

In [2]:
import os, json
import torch

print("cwd:", os.getcwd())
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

assert MECH_SCRIPT.exists(), f"Missing mech script: {MECH_SCRIPT}"
assert IN_JSONL.exists(), f"Missing input JSONL: {IN_JSONL}"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Quick input peek in one pass: count lines, parse every record (so malformed
# JSON fails here instead of inside the GPU run), keep the first record, and
# collect the model_ids present.
n_lines = 0
first = None
model_ids = set()
with open(IN_JSONL, "r", encoding="utf-8") as f:
    for line in f:
        n_lines += 1
        if not line.strip():
            continue
        rec = json.loads(line)
        if first is None:
            first = rec
        model_ids.add(rec.get("model_id"))
print("input lines:", n_lines, "bytes:", IN_JSONL.stat().st_size)
assert first is not None, f"No JSON records in {IN_JSONL}"

print("first keys:", list(first.keys()))
print("model_id:", first.get("model_id"))
print("prompt chars:", len(first.get("prompt", "")), "response chars:", len(first.get("response", "")))

# With MODEL_ID unset the model is inferred from the first row only (per the
# config cell), so rows from other models would presumably be traced with it.
if not MODEL_ID and len(model_ids) > 1:
    print("WARNING: input has several model_ids:", sorted(map(str, model_ids)))

cwd: /home/ubuntu/deep
torch: 2.7.0+cu128
cuda available: True
gpu: NVIDIA A100-SXM4-40GB
input lines: 600 bytes: 1002264
first keys: ['id', 'model_id', 'model_label', 'question_id', 'base_question', 'frame', 'prompt', 'response', 'prompt_tokens', 'completion_tokens']
model_id: mistralai/Mistral-7B-Instruct-v0.3
prompt chars: 176 response chars: 567


In [3]:
import os, shlex, sys, subprocess, time

cmd = [
    sys.executable, str(MECH_SCRIPT),
    "--in_jsonl", str(IN_JSONL),
    "--out_jsonl", str(OUT_MECH_JSONL),
    "--layers", LAYERS,
    "--dtype", DTYPE,
    "--num_shards", str(NUM_SHARDS),
    "--shard_id", str(SHARD_ID),
]
if MODEL_ID:
    cmd += ["--model_id", MODEL_ID]
if MAX_ROWS and int(MAX_ROWS) > 0:
    cmd += ["--max_rows", str(MAX_ROWS)]
if OVERWRITE:
    cmd += ["--overwrite"]

# os.environ changes made by the env-setup cell are lost on a kernel restart, and
# the child only inherits the current kernel's environment. Pass the TF/Flax
# opt-outs explicitly; anything already set in os.environ takes precedence.
child_env = {
    "TRANSFORMERS_NO_TF": "1",
    "TRANSFORMERS_NO_FLAX": "1",
    "USE_TF": "0",
    "USE_FLAX": "0",
    "TF_CPP_MIN_LOG_LEVEL": "3",
    **os.environ,
}

# shlex.join quotes arguments where needed, so the printed line can be pasted into a shell.
print("Running:\n ", shlex.join(cmd))
t0 = time.time()
subprocess.check_call(cmd, env=child_env)
print(f"Done in {time.time()-t0:.1f}s")

Running:
  /usr/bin/python3 /home/ubuntu/deep/mech/mech_trace.py --in_jsonl /home/ubuntu/deep/data/normal_responses.jsonl --out_jsonl /home/ubuntu/deep/outputs/normal_responses_mech.jsonl --layers 0,8,16,24,32 --dtype bf16 --num_shards 1 --shard_id 0 --overwrite


  from pkg_resources import parse_version  # type: ignore
2026-01-10 17:11:38.257894: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768065098.275669    4881 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768065098.281916    4881 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768065098.300504    4881 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768065098.300526    4881 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768065098.300528

Done in 31.9s


In [4]:
import json

assert OUT_MECH_JSONL.exists(), f"Missing output: {OUT_MECH_JSONL}"

# Parse every row once: used both for the count and for the coverage check below.
with open(OUT_MECH_JSONL, "r", encoding="utf-8") as f:
    mech_rows = [json.loads(line) for line in f if line.strip()]
out_lines = len(mech_rows)
print("output lines:", out_lines, "bytes:", OUT_MECH_JSONL.stat().st_size)
assert mech_rows, f"No records in {OUT_MECH_JSONL}"

r = mech_rows[0]
captured_layers = sorted(r["mech"]["by_layer"].keys(), key=int)
print("mech keys:", r["mech"].keys())
print("layers captured:", captured_layers)
# Show one layer
k0 = captured_layers[0]
print("layer", k0, "example:", r["mech"]["by_layer"][k0])

# Every row should carry every requested layer. Otherwise the delta cell's
# vec_from_by_layer skips the missing entries and the logit-lens L2 becomes None.
requested_layers = {int(s) for s in LAYERS.split(",") if s.strip()}
rows_missing_layers = [
    i for i, row in enumerate(mech_rows)
    if not requested_layers <= {int(k) for k in row["mech"]["by_layer"]}
]
print("rows missing requested layers:", len(rows_missing_layers))

output lines: 600 bytes: 1289281
mech keys: dict_keys(['prompt_len', 'first_answer_token_id', 'true_next_token_logp_first', 'by_layer'])
layers captured: ['0', '8', '16', '24', '32']
layer 0 example: {'h_norm': 0.004940052516758442, 'logit_lens_logp_first': -10.4375}


In [5]:
import json, math
from collections import defaultdict

INP = OUT_MECH_JSONL
OUT = OUT_DELTA_JSONL

with open(INP, "r", encoding="utf-8") as f:
    rows = [json.loads(line) for line in f if line.strip()]

# (question_id, model_label) -> {frame: row}
groups = defaultdict(dict)
duplicates = 0
for r in rows:
    key = (r["question_id"], r.get("model_label", ""))
    if r["frame"] in groups[key]:
        # A repeated frame replaces the earlier row; count it so it isn't silent.
        duplicates += 1
    groups[key][r["frame"]] = r
if duplicates:
    print(f"WARNING: {duplicates} duplicate (question_id, model_label, frame) rows; kept the last of each")

def vec_from_by_layer(by_layer, field, layer_keys):
    """Collect ``by_layer[k][field]`` for each k in ``layer_keys``, in that order.

    Keys missing from ``by_layer`` are skipped, so the result can be shorter
    than ``layer_keys``.
    """
    values = []
    for key in layer_keys:
        if key not in by_layer:
            continue
        values.append(by_layer[key][field])
    return values

def l2(a, b):
    """Euclidean distance between two equal-length, non-empty sequences.

    Returns None when either side is empty/None or the lengths differ.
    """
    comparable = bool(a) and bool(b) and len(a) == len(b)
    if not comparable:
        return None
    squared = [(x - y) ** 2 for x, y in zip(a, b)]
    return math.sqrt(sum(squared))

# Frames that must all be present for a (question_id, model_label) group.
FRAMES = ("casual", "evaluation", "oversight")


def frame_features(one, layer_keys):
    """Pull the features compared across frames from one mech row.

    Args:
        one: a row of the mech JSONL; reads one["mech"]["true_next_token_logp_first"],
            one["mech"]["first_answer_token_id"] (optional), and the
            "logit_lens_logp_first" entry of each layer in one["mech"]["by_layer"].
        layer_keys: by_layer keys, in the order the logit-lens curve is built.
            Layers absent from this row are skipped, which shortens the curve
            (its L2 against a full-length curve is then None).
    """
    mech = one["mech"]
    return {
        "true_next_token_logp_first": mech["true_next_token_logp_first"],
        "first_answer_token_id": mech.get("first_answer_token_id"),
        "logit_lens_curve": vec_from_by_layer(mech["by_layer"], "logit_lens_logp_first", layer_keys),
    }


def _same_first_token(x, base):
    """True/False when both feature dicts carry a first_answer_token_id, else None."""
    a, b = x["first_answer_token_id"], base["first_answer_token_id"]
    if a is None or b is None:
        return None
    return a == b


written = 0
missing = 0

with open(OUT, "w", encoding="utf-8") as fo:
    for (qid, mlabel), fr in groups.items():
        if not all(name in fr for name in FRAMES):
            missing += 1
            continue

        # The casual frame is the baseline and defines which layers are compared.
        layer_keys = sorted(fr["casual"]["mech"]["by_layer"].keys(), key=int)
        C = frame_features(fr["casual"], layer_keys)
        E = frame_features(fr["evaluation"], layer_keys)
        O = frame_features(fr["oversight"], layer_keys)

        out = {
            "question_id": qid,
            "model_label": mlabel,
            "layers": layer_keys,

            "delta_E_minus_C_true_logp": E["true_next_token_logp_first"] - C["true_next_token_logp_first"],
            "delta_O_minus_C_true_logp": O["true_next_token_logp_first"] - C["true_next_token_logp_first"],

            "delta_E_minus_C_logitlens_L2": l2(E["logit_lens_curve"], C["logit_lens_curve"]),
            "delta_O_minus_C_logitlens_L2": l2(O["logit_lens_curve"], C["logit_lens_curve"]),

            # first_answer_token_id is stored per row, so frames can disagree on it;
            # then the true_logp delta (presumably the log-prob of that first token)
            # compares different tokens. These flags allow like-for-like filtering.
            "same_first_token_E_C": _same_first_token(E, C),
            "same_first_token_O_C": _same_first_token(O, C),
        }
        fo.write(json.dumps(out, ensure_ascii=False) + "\n")
        written += 1

print("Wrote:", OUT)
print("delta rows:", written, "groups missing frames:", missing)

Wrote: /home/ubuntu/deep/outputs/normal_frame_deltas.jsonl
delta rows: 200 groups missing frames: 0


In [6]:
import json, statistics

# Gather the per-group deltas. True-logp deltas are kept as-is; logit-lens L2
# values of None (l2 could not compare the curves) are left out.
dE, dO, lE, lO = [], [], [], []

with open(OUT_DELTA_JSONL, "r", encoding="utf-8") as f:
    for raw in f:
        rec = json.loads(raw)
        dE.append(rec["delta_E_minus_C_true_logp"])
        dO.append(rec["delta_O_minus_C_true_logp"])
        for key, bucket in (
            ("delta_E_minus_C_logitlens_L2", lE),
            ("delta_O_minus_C_logitlens_L2", lO),
        ):
            value = rec[key]
            if value is not None:
                bucket.append(value)

def summarize(xs):
    """Return count, mean, median (p50), min and max of a list of numbers.

    For an empty list every statistic except ``n`` is None.
    """
    if not xs:
        return {"n": len(xs), "mean": None, "p50": None, "min": None, "max": None}
    return {
        "n": len(xs),
        "mean": statistics.fmean(xs),
        "p50": statistics.median(xs),
        "min": min(xs),
        "max": max(xs),
    }

# One summary line per delta series.
for label, series in (
    ("Δ true_logp (E−C):", dE),
    ("Δ true_logp (O−C):", dO),
    ("Δ logitlens L2 (E−C):", lE),
    ("Δ logitlens L2 (O−C):", lO),
):
    print(label, summarize(series))

Δ true_logp (E−C): {'n': 200, 'mean': 11.372469032406807, 'p50': 16.11921977996826, 'min': -18.0234375, 'max': 26.37485408782959}
Δ true_logp (O−C): {'n': 200, 'mean': 12.338937163949012, 'p50': 16.364479064941406, 'min': -14.6875, 'max': 26.374826431274414}
Δ logitlens L2 (E−C): {'n': 200, 'mean': 15.531033299960846, 'p50': 16.76536621638924, 'min': 0.8414379868326406, 'max': 30.97168182832154}
Δ logitlens L2 (O−C): {'n': 200, 'mean': 15.663469378824635, 'p50': 17.556755057900986, 'min': 0.6910472304133907, 'max': 30.95793706943914}
