In [1]:
import os
import json
import time
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision
import torchvision.transforms as T

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

torch: 2.9.0+cpu
torchvision: 0.24.0+cpu
device: cpu


In [None]:
# Liminal-style anchoring (KL-to-base) in auxiliary-logit distillation
#
# Conditions:
# - **Control**: normal auxiliary-logit distillation (student matches teacher aux logits on noise).
# - **Intervention (liminal anchor)**: augment distillation with a KL regularizer that anchors the student
#   to a frozen **base model** (the shared initialization) on selected logits:
#       L_total = L_distill_aux + lambda_k * KL(p_base^T || p_student^T)
#   where the anchor is computed either on:
#       - "reg10": the first 10 (MNIST) logits, or
#       - "all":   all (10 + aux) logits.
#   The weight lambda_k can be scheduled (strong early then decay) or kept constant throughout training.
#
# Measurements (logged every N steps):
# - trait loss (CE on first 10 logits, on audit batches)
# - distill loss (KL on aux logits, on current noise batch)
# - anchor loss (KL-to-base on selected logits, on current noise batch) and lambda_k
# - trait curvature along the *actual update direction* v (v^T H_trait v), where v is the gradient of L_total
# - alignment metrics computed at logging steps:
#     * pre:  alignment between g_distill (grad of L_distill_aux) and g_trait
#     * post: alignment between g_update  (grad of L_total)      and g_trait
# - final performance (student test accuracy on MNIST)
#
# Key detail:
# - The intervention does NOT require the trait gradient to form the update.
# - Therefore, audit batches are drawn only at logging steps (same as control).
#   At a logging step k, the same audit batches are used to compute:
#     (i) trait loss, (ii) trait curvature, and (iii) g_trait for alignment metrics.
#
# Output:
# - **one CSV per (seed, condition)**: runs/.../seed_01/control_metrics.csv, intervention_metrics.csv

In [2]:
@dataclass
class ExperimentConfig:
    # Runs
    seeds: List[int] = None # e.g., [1,2,...]
    out_dir: str = "./runs_mnist_liminal_anchor"

    # Data
    batch_size: int = 1024
    num_workers: int = 0  # keep 0 for strict determinism
    audit_size: int = 10_000
    noise_dataset_size: int = 60_000

    # Model (MLP from paper)
    hidden_dim: int = 256
    aux_m: int = 3

    # Training
    teacher_epochs: int = 5
    student_epochs: int = 5
    lr_teacher: float = 3e-4
    lr_student: float = 3e-4

    # Logging
    metrics_every_n_steps: int = 50
    # Number of audit batches used at logging steps to compute trait loss/curvature and g_trait.
    # For control, it's only used at logging steps.
    audit_batches_for_trait: int = 1

    # Numerics
    eps: float = 1e-12

    anchor_mode: str = "all"
    anchor_T: float = 2.0
    lambda0: float = 1.0

    anchor_schedule: str = "original"
    anchor_warmup_epochs: int = 1      # Unused if anchor_schedule = "original"
    anchor_decay_to_zero: bool = True  # Unused if anchor_schedule = "original"

cfg = ExperimentConfig(
    seeds=list(range(1, 11)),
    #seeds=[5],
    out_dir="./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all",
    batch_size=1024,
    num_workers=0,
    audit_size=10_000,
    noise_dataset_size=60_000,
    hidden_dim=256,
    aux_m=3,
    teacher_epochs=5,
    student_epochs=5,
    lr_teacher=3e-4,
    lr_student=3e-4,
    metrics_every_n_steps=1,
    audit_batches_for_trait=10,
    eps=1e-12,
    anchor_mode="all",
    anchor_T=2.0,
    lambda0=1.0,
    anchor_schedule="original",
    anchor_warmup_epochs=1,              # Unused if anchor_schedule = "original"
    anchor_decay_to_zero=True,           # Unused if anchor_schedule = "original"
)

Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)
cfg

def set_global_seed(seed: int, deterministic: bool = True) -> None:
    """Seed python, numpy, and torch. Optionally enable deterministic algorithms."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        try:
            torch.use_deterministic_algorithms(True)
        except Exception as e:
            print("Warning: could not enable full deterministic algorithms:", e)

def make_torch_generator(seed: int) -> torch.Generator:
    g = torch.Generator()
    g.manual_seed(seed)
    return g

class NoiseImages(Dataset):
    """Deterministic noise dataset: each index produces a reproducible noise image."""
    def __init__(self, length: int, seed: int, shape=(1, 28, 28), dist: str = "normal"):
        self.length = int(length)
        self.seed = int(seed)
        self.shape = tuple(shape)
        self.dist = dist

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int):
        # Per-index deterministic generation.
        g = torch.Generator()
        g.manual_seed(self.seed * 1_000_000 + int(idx))
        if self.dist == "normal":
            x = torch.randn(self.shape, generator=g)
        elif self.dist == "uniform":
            x = torch.rand(self.shape, generator=g) * 2 - 1
        else:
            raise ValueError(f"Unknown dist: {self.dist}")
        # Dummy label (unused)
        y = 0
        return x, y

def get_mnist_datasets(root: str):
    transform = T.Compose([T.ToTensor()])
    train = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)
    return train, test

def split_train_audit(train_ds, audit_size: int, seed: int):
    n = len(train_ds)
    audit_size = min(int(audit_size), n)
    train_size = n - audit_size
    g = make_torch_generator(seed)
    train_split, audit_split = random_split(train_ds, [train_size, audit_size], generator=g)
    return train_split, audit_split

def make_loader(ds, batch_size: int, shuffle: bool, seed: int, num_workers: int = 0):
    # Deterministic shuffling via DataLoader generator.
    g = make_torch_generator(seed)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        generator=g,
        drop_last=False,
    )

class MLPClassifier(nn.Module):
    """MLP from the Subliminal Learning MNIST experiment: (784, 256, 256, 10+m) with ReLU."""
    def __init__(self, hidden_dim: int = 256, aux_m: int = 3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.aux_m = aux_m
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 10 + aux_m)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def build_model_mlp(cfg: ExperimentConfig) -> nn.Module:
    """Swap model architecture by changing this builder (or passing another builder to the runner)."""
    return MLPClassifier(hidden_dim=cfg.hidden_dim, aux_m=cfg.aux_m)

def logits_regular(logits: torch.Tensor) -> torch.Tensor:
    return logits[:, :10]

def logits_aux(logits: torch.Tensor, aux_m: int) -> torch.Tensor:
    return logits[:, 10:10 + aux_m]

@torch.no_grad()
def accuracy_on_loader(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = logits_regular(model(x)).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)

def get_params(model: nn.Module) -> List[torch.nn.Parameter]:
    return [p for p in model.parameters() if p.requires_grad]

def flatten_grads_from_params(params: List[torch.nn.Parameter]) -> torch.Tensor:
    """Flatten gradients already stored in .grad (e.g., after backward())."""
    flats = []
    for p in params:
        if p.grad is None:
            flats.append(torch.zeros_like(p).view(-1))
        else:
            flats.append(p.grad.detach().view(-1))
    return torch.cat(flats)

def split_flat_like_params(params: List[torch.nn.Parameter], flat: torch.Tensor) -> List[torch.Tensor]:
    out = []
    offset = 0
    for p in params:
        n = p.numel()
        out.append(flat[offset:offset + n].view_as(p))
        offset += n
    assert offset == flat.numel()
    return out

def train_teacher(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    cfg: ExperimentConfig,
    device: torch.device,
) -> Dict[str, float]:
    """Train teacher on MNIST CE using regular logits only."""
    model = model.to(device)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr_teacher)

    for _epoch in range(cfg.teacher_epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            loss = F.cross_entropy(logits_regular(model(x)), y)
            loss.backward()
            opt.step()

    teacher_acc = accuracy_on_loader(model, test_loader, device)
    return {"teacher_test_acc": float(teacher_acc)}

def make_infinite_iterator(loader: DataLoader):
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)

@torch.no_grad()
def collect_audit_batches(audit_stream, num_batches: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    return [next(audit_stream) for _ in range(num_batches)]

def trait_loss_on_batches(
    model: nn.Module,
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> torch.Tensor:
    losses = []
    for x, y in batches:
        x, y = x.to(device), y.to(device)
        losses.append(F.cross_entropy(logits_regular(model(x)), y))
    return torch.stack(losses).mean()

def compute_trait_grad_flat_from_batches(
    model: nn.Module,
    params: List[torch.nn.Parameter],
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> Tuple[float, torch.Tensor]:
    loss = trait_loss_on_batches(model, batches, device)
    grads = torch.autograd.grad(loss, params, retain_graph=False, create_graph=False)
    g_flat = torch.cat([g.detach().view(-1) for g in grads])
    return float(loss.detach()), g_flat

def compute_trait_loss_and_curvature_vHv_from_batches(
    model: nn.Module,
    params: List[torch.nn.Parameter],
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    v_flat: torch.Tensor,
    device: torch.device,
) -> Tuple[float, float, float]:
    """
    Returns:
      trait_loss, v^T H v, (v^T H v) / ||v||^2
    """
    loss = trait_loss_on_batches(model, batches, device)

    grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)

    v_list = split_flat_like_params(params, v_flat.detach())
    gv = torch.zeros((), device=device)
    for g, v in zip(grads, v_list):
        gv = gv + (g * v).sum()

    hvp = torch.autograd.grad(gv, params, retain_graph=False, create_graph=False)
    hvp_flat = torch.cat([h.detach().view(-1) for h in hvp])

    v_det = v_flat.detach()
    vHv = float((v_det @ hvp_flat).detach())
    v_norm2 = float((v_det @ v_det).detach())
    vHv_norm = vHv / max(v_norm2, 1e-30)

    return float(loss.detach()), vHv, vHv_norm

def distill_loss_aux_only(student: nn.Module, teacher: nn.Module, x_noise: torch.Tensor, aux_m: int) -> torch.Tensor:
    with torch.no_grad():
        t_aux = logits_aux(teacher(x_noise), aux_m)
        t_prob = F.softmax(t_aux, dim=1)

    s_aux = logits_aux(student(x_noise), aux_m)
    s_logprob = F.log_softmax(s_aux, dim=1)

    kl = torch.nn.KLDivLoss(reduction="batchmean")
    return kl(s_logprob, t_prob)

# Intervention: liminal-training-style regularization
def select_logits_for_anchor(logits: torch.Tensor, mode: str, aux_m: int) -> torch.Tensor:
    if mode == "reg10":
        return logits[:, :10]
    if mode == "aux":
        return logits[:, 10:10 + aux_m]
    if mode == "all":
        return logits
    raise ValueError(f"Unknown anchor_mode={mode}")

def kl_anchor_to_base(
    student: nn.Module,
    base: nn.Module,
    x: torch.Tensor,
    cfg: ExperimentConfig,
) -> torch.Tensor:
    """
    KL(p_base^T || p_student^T) computed on selected logits (reg10 or all).
    Uses KLDivLoss(log p_student, p_base) which equals KL(p_base || p_student).
    """
    kl = torch.nn.KLDivLoss(reduction="batchmean")
    T = cfg.anchor_T

    with torch.no_grad():
        b = select_logits_for_anchor(base(x), cfg.anchor_mode, cfg.aux_m) / T
        p_base = F.softmax(b, dim=1)

    s = select_logits_for_anchor(student(x), cfg.anchor_mode, cfg.aux_m) / T
    logp_student = F.log_softmax(s, dim=1)

    return (T ** 2) * kl(logp_student, p_base)

def anchor_lambda(epoch: int, global_step: int, steps_per_epoch: int, cfg: ExperimentConfig) -> float:
    """
    Schedule for lambda_k:
      - constant: lambda0 throughout
      - early_then_decay: lambda0 for first `anchor_warmup_epochs` epochs,
        then (optionally) linear decay to 0 by end of training
    """
    if cfg.lambda0 <= 0:
        return 0.0

    if cfg.anchor_schedule == "constant":
        return float(cfg.lambda0)

    if cfg.anchor_schedule == "early_then_decay":
        # Warmup phase
        if epoch < cfg.anchor_warmup_epochs:
            return float(cfg.lambda0)

        # After warmup
        if not cfg.anchor_decay_to_zero:
            return float(cfg.lambda0)

        total_steps = steps_per_epoch * cfg.student_epochs
        warmup_steps = steps_per_epoch * cfg.anchor_warmup_epochs
        remaining = max(total_steps - warmup_steps, 1)

        t = (global_step - warmup_steps) / remaining  # This does not perfectly become 0 at the end
        return float(cfg.lambda0 * max(0.0, 1.0 - t))

    if cfg.anchor_schedule == "original":
        return float(anchor_lambda_original(global_step, steps_per_epoch, cfg))

    raise ValueError(f"Unknown anchor_schedule={cfg.anchor_schedule}")

def anchor_lambda_original(global_step: int, steps_per_epoch: int, cfg: ExperimentConfig) -> float:
    """
    Implements Yanagisawa et al. (Liminal Training) Eq. (3)-(4):
      - Phase 1 (first epoch): anneal from lambda0 to 1 (except lambda0 = 1 case)
      - Phase 2 (rest): decay linearly to 0 with scale lambda0
    """
    lambda0 = float(cfg.lambda0)
    if lambda0 <= 0:
        return 0.0

    total_steps = steps_per_epoch * cfg.student_epochs
    if total_steps <= 1:
        return lambda0

    # normalized training progress t in [0,1]
    t = global_step / (total_steps - 1)

    E = float(cfg.student_epochs)
    tau2 = 1.0 / E  # end of first epoch in normalized time

    if t <= tau2 + 1e-12:
        # Phase 1: lambda(t) = 1 + (lambda0 - 1) * (1 - t/tau2)
        s = t / max(tau2, 1e-12)
        return 1.0 + (lambda0 - 1.0) * (1.0 - s)
    else:
        # Phase 2: lambda(t) = lambda0 * (1 - (t - tau2)/(1 - tau2))
        s = (t - tau2) / max(1.0 - tau2, 1e-12)
        return lambda0 * max(0.0, 1.0 - s)

# Student run (control vs intervention)
# Key detail (liminal anchoring):
# - In **intervention**, the update is computed from
#       L_total = L_distill_aux + lambda_k * KL_anchor(student || base)
#   so it does NOT require the trait gradient.
# - Therefore, audit batches are drawn only at logging steps (same as control),
#   and are used solely to compute:
#       (i) trait loss / curvature, and
#       (ii) alignment metrics between g_trait and the actual update gradient g_update.
# - At a logging step k, the same audit batches are used for trait loss, curvature,
#   and g_trait (for alignment), ensuring consistency of logged quantities.
def run_student_condition(
    condition: str,  # "control" or "intervention"
    student: nn.Module,
    teacher: nn.Module,
    base_model: nn.Module,
    noise_loader: DataLoader,
    audit_loader: DataLoader,
    test_loader: DataLoader,
    cfg: ExperimentConfig,
    seed: int,
    device: torch.device,
) -> Tuple[pd.DataFrame, Dict[str, float]]:
    assert condition in ("control", "intervention")
    assert cfg.audit_batches_for_trait >= 1

    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)

    params = get_params(student)
    opt = torch.optim.Adam(student.parameters(), lr=cfg.lr_student)

    audit_stream = make_infinite_iterator(audit_loader)

    logs: List[Dict[str, float]] = []
    global_step = 0
    steps_per_epoch = len(noise_loader)

    student.train()
    for epoch in range(cfg.student_epochs):
        for x_noise, _ in noise_loader:
            x_noise = x_noise.to(device)
#
            # Decide if we log this step (needed before computing g_distill)
            do_log = (global_step % cfg.metrics_every_n_steps == 0)

            # --- Forward losses ---
            opt.zero_grad(set_to_none=True)

            dloss = distill_loss_aux_only(student, teacher, x_noise, cfg.aux_m)

            lam = 0.0
            anchor_loss = torch.tensor(0.0, device=device)

            if condition == "intervention":
                lam = anchor_lambda(epoch, global_step, steps_per_epoch, cfg)
                if lam > 0.0:
                    anchor_loss = kl_anchor_to_base(student, base_model, x_noise, cfg)

            # If logging, compute distill-only gradient (pre) without touching .grad
            g_distill = None
            if do_log:
                grads_distill = torch.autograd.grad(dloss, params, retain_graph=True, create_graph=False)
                g_distill = torch.cat([g.detach().view(-1) for g in grads_distill])

            # --- Backward actual update gradient ---
            total_loss = dloss + lam * anchor_loss
            total_loss.backward()

            # Actual update direction (post)
            g_update = flatten_grads_from_params(params)
#
            # --- logging: trait loss + curvature + alignment (pre vs post) ---
            if do_log:
                # Draw trait batches only at logging steps (both conditions)
                trait_batches_k = collect_audit_batches(audit_stream, cfg.audit_batches_for_trait)
                # Trait grad on these batches
                _trait_loss_tmp, g_trait = compute_trait_grad_flat_from_batches(
                    model=student,
                    params=params,
                    batches=trait_batches_k,
                    device=device,
                )

                # Trait loss + curvature along actual update direction g_update
                trait_loss_val, vHv, vHv_norm = compute_trait_loss_and_curvature_vHv_from_batches(
                    model=student,
                    params=params,
                    batches=trait_batches_k,
                    v_flat=g_update,
                    device=device,
                )

                # Alignment metrics: pre (distill-only) vs post (actual update)
                inner_pre = float((g_distill @ g_trait).detach()) if g_distill is not None else np.nan
                cos_pre = float((g_distill @ g_trait).detach() / ((g_distill.norm() * g_trait.norm()).clamp_min(cfg.eps))) if g_distill is not None else np.nan

                inner_post = float((g_update @ g_trait).detach())
                cos_post = float((g_update @ g_trait).detach() / ((g_update.norm() * g_trait.norm()).clamp_min(cfg.eps)))

                logs.append({
                    "seed": seed,
                    "condition": condition,
                    "step": global_step,
                    "epoch": epoch,

                    "trait_loss": trait_loss_val,
                    "distill_loss": float(dloss.detach()),
                    "total_loss": float(total_loss.detach()),

                    "trait_curvature_vHv": vHv,
                    "trait_curvature_vHv_norm": vHv_norm,

                    "inner_pre": inner_pre,
                    "cos_pre": cos_pre,
                    "inner_post": inner_post,
                    "cos_post": cos_post,

                    "lambda_k": float(lam),
                    "anchor_loss": float(anchor_loss.detach()),
                    "anchor_mode": cfg.anchor_mode,
                    "anchor_schedule": cfg.anchor_schedule,
                })

            opt.step()
            global_step += 1

    student_acc = accuracy_on_loader(student, test_loader, device)

    df = pd.DataFrame(logs)
    info = {
        "student_test_acc": float(student_acc),
        "total_steps": int(global_step),
        "num_logged_rows": int(len(df)),
    }

    return df, info

# Aggregate summary across runs
def aggregate_across_runs(out_dir: str, epoch=None) -> pd.DataFrame:
    out_dir = Path(out_dir)
    rows = []
    for seed_dir in sorted(out_dir.glob("seed_*")):
        for cond, fname in [("control", "control_metrics.csv"), ("intervention", "intervention_metrics.csv")]:
            p = seed_dir / fname
            if not p.exists():
                continue

            df = pd.read_csv(p)
            if len(df) == 0:
                continue

            if epoch is not None:
                if isinstance(epoch, (list, tuple, set, np.ndarray, pd.Series)):
                    df = df[df["epoch"].isin(list(epoch))]
                else:
                    df = df[df["epoch"] == int(epoch)]
                if len(df) == 0:
                    continue

            # Helpers for safe means
            def _mean(col: str):
                return float(df[col].mean()) if col in df.columns else np.nan

            rows.append({
                "seed": int(df["seed"].iloc[0]),
                "condition": cond,

                # performance
                "teacher_test_acc": float(df["teacher_test_acc"].iloc[0]) if "teacher_test_acc" in df.columns else np.nan,
                "student_test_acc": float(df["student_test_acc"].iloc[0]) if "student_test_acc" in df.columns else np.nan,

                # losses
                "mean_trait_loss": _mean("trait_loss"),
                "mean_distill_loss": _mean("distill_loss"),

                # curvature
                "mean_trait_curvature_vHv": _mean("trait_curvature_vHv"),
                "mean_trait_curvature_vHv_norm": _mean("trait_curvature_vHv_norm"),

                # alignment summaries
                "mean_inner_pre": _mean("inner_pre"),
                "mean_cos_pre": _mean("cos_pre"),
                "mean_inner_post": _mean("inner_post"),
                "mean_cos_post": _mean("cos_post"),

                # fractions of positive alignment (pre/post)
                "frac_pos_pre": float((df["inner_pre"] > 0).mean()) if "inner_pre" in df.columns else np.nan,
                "frac_pos_post": float((df["inner_post"] > 0).mean()) if "inner_post" in df.columns else np.nan,

                "num_logged_rows": int(len(df)),
            })

    return pd.DataFrame(rows).sort_values(["seed", "condition"]).reset_index(drop=True)

# One seed: train teacher once, run control + intervention students
def run_one_seed(
    seed: int,
    cfg: ExperimentConfig,
    build_model_fn: Callable[[ExperimentConfig], nn.Module],
    device: torch.device,
) -> Dict[str, Path]:
    set_global_seed(seed, deterministic=True)

    run_dir = Path(cfg.out_dir) / f"seed_{seed:02d}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Data
    data_root = str(Path(cfg.out_dir) / "data_cache")
    mnist_train, mnist_test = get_mnist_datasets(root=data_root)
    train_split, audit_split = split_train_audit(mnist_train, audit_size=cfg.audit_size, seed=seed)

    train_loader = make_loader(train_split, cfg.batch_size, shuffle=True,  seed=seed + 100, num_workers=cfg.num_workers)
    audit_loader = make_loader(audit_split, cfg.batch_size, shuffle=True,  seed=seed + 200, num_workers=cfg.num_workers)
    test_loader  = make_loader(mnist_test,  cfg.batch_size, shuffle=False, seed=seed + 300, num_workers=cfg.num_workers)

    # Same noise ordering for both conditions
    noise_ds = NoiseImages(length=cfg.noise_dataset_size, seed=seed + 400, shape=(1, 28, 28), dist="normal")
    noise_loader_control = make_loader(noise_ds, cfg.batch_size, shuffle=True, seed=seed + 500, num_workers=cfg.num_workers)
    noise_loader_interv  = make_loader(noise_ds, cfg.batch_size, shuffle=True, seed=seed + 500, num_workers=cfg.num_workers)

    # Reference init
    reference = build_model_fn(cfg)
    ref_state = {k: v.clone().detach().cpu() for k, v in reference.state_dict().items()}
    # Base model for anchoring (same init as teacher/student), frozen
    base_model = build_model_fn(cfg)
    base_model.load_state_dict(ref_state)
    base_model = base_model.to(device)
    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad_(False)

    # Teacher
    teacher = build_model_fn(cfg)
    teacher.load_state_dict(ref_state)
    teacher_info = train_teacher(teacher, train_loader, test_loader, cfg, device)

    # Students (same init)
    student_control = build_model_fn(cfg); student_control.load_state_dict(ref_state)
    student_interv  = build_model_fn(cfg); student_interv.load_state_dict(ref_state)

    # Run conditions
    control_df, control_info = run_student_condition(
        condition="control",
        student=student_control,
        teacher=teacher,
        base_model=base_model,
        noise_loader=noise_loader_control,
        audit_loader=audit_loader,
        test_loader=test_loader,
        cfg=cfg,
        seed=seed,
        device=device,
    )
    interv_df, interv_info = run_student_condition(
        condition="intervention",
        student=student_interv,
        teacher=teacher,
        base_model=base_model,
        noise_loader=noise_loader_interv,
        audit_loader=audit_loader,
        test_loader=test_loader,
        cfg=cfg,
        seed=seed,
        device=device,
    )

    # Add final perf to all rows
    for df, info in [(control_df, control_info), (interv_df, interv_info)]:
        df["teacher_test_acc"] = teacher_info["teacher_test_acc"]
        df["student_test_acc"] = info["student_test_acc"]
        df["total_steps"] = info["total_steps"]

    # Save
    control_csv = run_dir / "control_metrics.csv"
    interv_csv  = run_dir / "intervention_metrics.csv"
    control_df.to_csv(control_csv, index=False)
    interv_df.to_csv(interv_csv, index=False)

    meta = {
        "seed": seed,
        "config": asdict(cfg),
        "teacher_info": teacher_info,
        "control_info": control_info,
        "intervention_info": interv_info,
        "created_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "device": str(device),
    }
    with open(run_dir / "metadata.json", "w") as f:
        json.dump(meta, f, indent=2)

    print(
        f"[seed {seed}] teacher_acc={teacher_info['teacher_test_acc']:.4f} | "
        f"control_acc={control_info['student_test_acc']:.4f} | "
        f"interv_acc={interv_info['student_test_acc']:.4f} | "
        f"rows(control/interv)={len(control_df)}/{len(interv_df)}"
    )

    return {"control_csv": control_csv, "intervention_csv": interv_csv}

# Run all seeds (1..10)
all_paths = []
for s in cfg.seeds:
    all_paths.append(run_one_seed(seed=s, cfg=cfg, build_model_fn=build_model_mlp, device=device))

all_paths[0]

100%|██████████| 9.91M/9.91M [00:01<00:00, 6.70MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 158kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.24MB/s]


[seed 1] teacher_acc=0.9296 | control_acc=0.1171 | interv_acc=0.1106 | rows(control/interv)=295/295
[seed 2] teacher_acc=0.9276 | control_acc=0.1760 | interv_acc=0.1157 | rows(control/interv)=295/295
[seed 3] teacher_acc=0.9315 | control_acc=0.3504 | interv_acc=0.2274 | rows(control/interv)=295/295
[seed 4] teacher_acc=0.9291 | control_acc=0.1962 | interv_acc=0.2416 | rows(control/interv)=295/295
[seed 5] teacher_acc=0.9296 | control_acc=0.4617 | interv_acc=0.2266 | rows(control/interv)=295/295
[seed 6] teacher_acc=0.9322 | control_acc=0.3791 | interv_acc=0.2389 | rows(control/interv)=295/295
[seed 7] teacher_acc=0.9332 | control_acc=0.3537 | interv_acc=0.1842 | rows(control/interv)=295/295
[seed 8] teacher_acc=0.9318 | control_acc=0.2055 | interv_acc=0.1156 | rows(control/interv)=295/295
[seed 9] teacher_acc=0.9304 | control_acc=0.3518 | interv_acc=0.2604 | rows(control/interv)=295/295
[seed 10] teacher_acc=0.9303 | control_acc=0.1048 | interv_acc=0.1454 | rows(control/interv)=295/295

{'control_csv': PosixPath('runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/seed_01/control_metrics.csv'),
 'intervention_csv': PosixPath('runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/seed_01/intervention_metrics.csv')}

In [3]:
summary = aggregate_across_runs(cfg.out_dir, epoch=0)
summary_path = Path(cfg.out_dir) / "summary_by_seed_condition_epoch0.csv"
summary.to_csv(summary_path, index=False)
summary = aggregate_across_runs(cfg.out_dir)
summary_path = Path(cfg.out_dir) / "summary_by_seed_condition_all_epoch.csv"
summary.to_csv(summary_path, index=False)
summary

Unnamed: 0,seed,condition,teacher_test_acc,student_test_acc,mean_trait_loss,mean_distill_loss,mean_trait_curvature_vHv,mean_trait_curvature_vHv_norm,mean_inner_pre,mean_cos_pre,mean_inner_post,mean_cos_post,frac_pos_pre,frac_pos_post,num_logged_rows
0,1,control,0.9296,0.1171,2.266227,0.004177,1.358804e-06,0.001369,5.2e-05,0.006719,5.2e-05,0.006719,0.722034,0.722034,295
1,1,intervention,0.9296,0.1106,2.290538,0.004914,1.255712e-06,0.002575,-0.000149,-0.006627,2.5e-05,0.003333,0.088136,0.674576,295
2,2,control,0.9276,0.176,2.24869,0.00509,1.632704e-06,0.001696,5.5e-05,0.00702,5.5e-05,0.00702,0.813559,0.813559,295
3,2,intervention,0.9276,0.1157,2.288977,0.00621,2.204672e-06,0.003272,0.000218,0.00947,9e-06,0.000814,1.0,0.60678,295
4,3,control,0.9315,0.3504,2.261541,0.00537,1.291556e-06,0.001499,5.2e-05,0.006026,5.2e-05,0.006026,0.847458,0.847458,295
5,3,intervention,0.9315,0.2274,2.294408,0.006305,1.388646e-06,0.002306,5.9e-05,0.002724,2.3e-05,0.003108,0.894915,0.698305,295
6,4,control,0.9291,0.1962,2.281892,0.004539,2.827718e-06,0.001667,3.4e-05,0.005548,3.4e-05,0.005548,0.698305,0.698305,295
7,4,intervention,0.9291,0.2416,2.293795,0.005584,6.372073e-06,0.007422,-0.000234,-0.008345,1.5e-05,0.002173,0.050847,0.576271,295
8,5,control,0.9296,0.4617,2.245084,0.004345,3.776265e-07,0.001307,5.5e-05,0.006611,5.5e-05,0.006611,0.759322,0.759322,295
9,5,intervention,0.9296,0.2266,2.287129,0.004877,6.911313e-07,0.001875,7.8e-05,0.005937,2.2e-05,0.003308,0.99661,0.694915,295


In [4]:
!zip -r runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all.zip ./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all

  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/ (stored 0%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/ (stored 0%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/ (stored 0%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/raw/ (stored 0%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/raw/t10k-labels-idx1-ubyte (deflated 55%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/raw/t10k-images-idx3-ubyte (deflated 79%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/raw/t10k-images-idx3-ubyte.gz (deflated 0%)
  adding: runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/data_cache/MNIST/raw/t10k-labels-idx1-ub

In [5]:
summary_path = "/content/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/summary_by_seed_condition_epoch0.csv"
df = pd.read_csv(summary_path)
num_cols = [c for c in df.select_dtypes(include="number").columns if c != "seed"]
avg_by_condition = (
    df.groupby("condition")[num_cols]
      .agg(["mean", "std", "count"])
)

In [6]:
cols = ["mean_distill_loss", "mean_trait_loss", "mean_cos_pre", "mean_cos_post", "student_test_acc", "frac_pos_pre", "frac_pos_post"]
df.groupby("condition")[cols].agg(["mean","std","count"])

Unnamed: 0_level_0,mean_distill_loss,mean_distill_loss,mean_distill_loss,mean_trait_loss,mean_trait_loss,mean_trait_loss,mean_cos_pre,mean_cos_pre,mean_cos_pre,mean_cos_post,mean_cos_post,mean_cos_post,student_test_acc,student_test_acc,student_test_acc,frac_pos_pre,frac_pos_pre,frac_pos_pre,frac_pos_post,frac_pos_post,frac_pos_post
Unnamed: 0_level_1,mean,std,count,mean,std,count,mean,std,count,mean,...,count,mean,std,count,mean,std,count,mean,std,count
condition,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
control,0.012524,0.002329,10,2.293811,0.005385,10,0.007518,0.003805,10,0.007518,...,10,0.26963,0.123794,10,0.808475,0.153283,10,0.808475,0.153283,10
intervention,0.013327,0.002461,10,2.299757,0.003338,10,0.004312,0.007086,10,0.002375,...,10,0.18664,0.059646,10,0.735593,0.353837,10,0.628814,0.106134,10


In [7]:
summary_path = "/content/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/summary_by_seed_condition_all_epoch.csv"
df = pd.read_csv(summary_path)
num_cols = [c for c in df.select_dtypes(include="number").columns if c != "seed"]
avg_by_condition = (
    df.groupby("condition")[num_cols]
      .agg(["mean", "std", "count"])
)

In [8]:
cols = ["mean_distill_loss", "mean_trait_loss", "mean_cos_pre", "mean_cos_post", "student_test_acc", "frac_pos_pre", "frac_pos_post"]
df.groupby("condition")[cols].agg(["mean","std","count"])

Unnamed: 0_level_0,mean_distill_loss,mean_distill_loss,mean_distill_loss,mean_trait_loss,mean_trait_loss,mean_trait_loss,mean_cos_pre,mean_cos_pre,mean_cos_pre,mean_cos_post,mean_cos_post,mean_cos_post,student_test_acc,student_test_acc,student_test_acc,frac_pos_pre,frac_pos_pre,frac_pos_pre,frac_pos_post,frac_pos_post,frac_pos_post
Unnamed: 0_level_1,mean,std,count,mean,std,count,mean,std,count,mean,...,count,mean,std,count,mean,std,count,mean,std,count
condition,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
control,0.004679,0.000674,10,2.257484,0.012862,10,0.006373,0.001021,10,0.006373,...,10,0.26963,0.123794,10,0.819322,0.089245,10,0.819322,0.089245,10
intervention,0.00549,0.000856,10,2.291026,0.002841,10,0.001799,0.006037,10,0.00253,...,10,0.18664,0.059646,10,0.654576,0.424156,10,0.647797,0.047765,10
