In [1]:
import os
import json
import time
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision
import torchvision.transforms as T

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

torch: 2.9.0+cpu
torchvision: 0.24.0+cpu
device: cpu


In [2]:
# Inducing conflicting gradients (PCGrad projection with reversed criterion)
#
# Conditions:
# - **Control**: normal auxiliary-logit distillation (student matches teacher aux logits on noise).
# - **Intervention**: whenever alignment is **positive** (g_distill · g_trait > 0),
#   modify the distill gradient by projecting it onto the normal plane of g_trait:
#       g' = g_distill - proj_{g_trait}(g_distill)
#
# Measurements (logged every N steps):
# - trait loss (CE on first 10 logits, on audit batches)
# - distill loss (KL on aux logits, on current noise batch)
# - trait curvature along the *actual update direction* v (v^T H_trait v)
# - final performance (student test accuracy on MNIST)
# - For the **intervention** condition, the audit batches used to compute g_trait for projection at step k
#   are the **exact same audit batches** used to compute the logged trait loss/curvature at step k
#   (when step k is a logging step).
#
# Output:
# - **one CSV per (seed, condition)**: runs/.../seed_01/control_metrics.csv, intervention_metrics.csv

@dataclass
class ExperimentConfig:
    # Runs
    seeds: List[int] = None # e.g., [1,2,...]
    out_dir: str = "./runs_mnist_conflicting_grad"

    # Data
    batch_size: int = 1024
    num_workers: int = 0  # keep 0 for strict determinism
    audit_size: int = 10_000
    noise_dataset_size: int = 60_000

    # Model (MLP from paper)
    hidden_dim: int = 256
    aux_m: int = 3

    # Training
    teacher_epochs: int = 5
    student_epochs: int = 5
    lr_teacher: float = 3e-4
    lr_student: float = 3e-4

    # Logging
    metrics_every_n_steps: int = 50
    # Trait batch size (USED BOTH for projection + for logged trait loss/curvature in intervention condition)
    # For control, it's only used at logging steps.
    audit_batches_for_trait: int = 1

    # Numerics
    eps: float = 1e-12

cfg = ExperimentConfig(
    seeds=list(range(1, 11)),
    out_dir="./runs_mnist_conflicting_grad",
    batch_size=1024,
    num_workers=0,
    audit_size=10_000,
    noise_dataset_size=60_000,
    hidden_dim=256,
    aux_m=3,
    teacher_epochs=5,
    student_epochs=5,
    lr_teacher=3e-4,
    lr_student=3e-4,
    metrics_every_n_steps=10,
    audit_batches_for_trait=1,
    eps=1e-12,
)

Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)
cfg

def set_global_seed(seed: int, deterministic: bool = True) -> None:
    """Seed python, numpy, and torch. Optionally enable deterministic algorithms."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        try:
            torch.use_deterministic_algorithms(True)
        except Exception as e:
            print("Warning: could not enable full deterministic algorithms:", e)

def make_torch_generator(seed: int) -> torch.Generator:
    g = torch.Generator()
    g.manual_seed(seed)
    return g

class NoiseImages(Dataset):
    """Deterministic noise dataset: each index produces a reproducible noise image."""
    def __init__(self, length: int, seed: int, shape=(1, 28, 28), dist: str = "normal"):
        self.length = int(length)
        self.seed = int(seed)
        self.shape = tuple(shape)
        self.dist = dist

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int):
        # Per-index deterministic generation.
        g = torch.Generator()
        g.manual_seed(self.seed * 1_000_000 + int(idx))
        if self.dist == "normal":
            x = torch.randn(self.shape, generator=g)
        elif self.dist == "uniform":
            x = torch.rand(self.shape, generator=g) * 2 - 1
        else:
            raise ValueError(f"Unknown dist: {self.dist}")
        # Dummy label (unused)
        y = 0
        return x, y

def get_mnist_datasets(root: str):
    transform = T.Compose([T.ToTensor()])
    train = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)
    return train, test

def split_train_audit(train_ds, audit_size: int, seed: int):
    n = len(train_ds)
    audit_size = min(int(audit_size), n)
    train_size = n - audit_size
    g = make_torch_generator(seed)
    train_split, audit_split = random_split(train_ds, [train_size, audit_size], generator=g)
    return train_split, audit_split

def make_loader(ds, batch_size: int, shuffle: bool, seed: int, num_workers: int = 0):
    # Deterministic shuffling via DataLoader generator.
    g = make_torch_generator(seed)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        generator=g,
        drop_last=False,
    )

class MLPClassifier(nn.Module):
    """MLP from the Subliminal Learning MNIST experiment: (784, 256, 256, 10+m) with ReLU."""
    def __init__(self, hidden_dim: int = 256, aux_m: int = 3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.aux_m = aux_m
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 10 + aux_m)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def build_model_mlp(cfg: ExperimentConfig) -> nn.Module:
    """Swap model architecture by changing this builder (or passing another builder to the runner)."""
    return MLPClassifier(hidden_dim=cfg.hidden_dim, aux_m=cfg.aux_m)

def logits_regular(logits: torch.Tensor) -> torch.Tensor:
    return logits[:, :10]

def logits_aux(logits: torch.Tensor, aux_m: int) -> torch.Tensor:
    return logits[:, 10:10 + aux_m]

@torch.no_grad()
def accuracy_on_loader(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = logits_regular(model(x)).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)

def get_params(model: nn.Module) -> List[torch.nn.Parameter]:
    return [p for p in model.parameters() if p.requires_grad]

def flatten_grads_from_params(params: List[torch.nn.Parameter]) -> torch.Tensor:
    """Flatten gradients already stored in .grad (e.g., after backward())."""
    flats = []
    for p in params:
        if p.grad is None:
            flats.append(torch.zeros_like(p).view(-1))
        else:
            flats.append(p.grad.detach().view(-1))
    return torch.cat(flats)

def set_grads_from_flat(params: List[torch.nn.Parameter], flat: torch.Tensor) -> None:
    """Overwrite params[i].grad with slices from flat."""
    offset = 0
    for p in params:
        n = p.numel()
        g = flat[offset:offset + n].view_as(p).to(dtype=p.dtype, device=p.device)
        if p.grad is None:
            p.grad = g.clone()
        else:
            p.grad.copy_(g)
        offset += n
    assert offset == flat.numel()

def split_flat_like_params(params: List[torch.nn.Parameter], flat: torch.Tensor) -> List[torch.Tensor]:
    out = []
    offset = 0
    for p in params:
        n = p.numel()
        out.append(flat[offset:offset + n].view_as(p))
        offset += n
    assert offset == flat.numel()
    return out

def train_teacher(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    cfg: ExperimentConfig,
    device: torch.device,
) -> Dict[str, float]:
    """Train teacher on MNIST CE using regular logits only."""
    model = model.to(device)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr_teacher)

    for _epoch in range(cfg.teacher_epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            loss = F.cross_entropy(logits_regular(model(x)), y)
            loss.backward()
            opt.step()

    teacher_acc = accuracy_on_loader(model, test_loader, device)
    return {"teacher_test_acc": float(teacher_acc)}

def make_infinite_iterator(loader: DataLoader):
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)

@torch.no_grad()
def collect_audit_batches(audit_stream, num_batches: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    return [next(audit_stream) for _ in range(num_batches)]

def trait_loss_on_batches(
    model: nn.Module,
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> torch.Tensor:
    losses = []
    for x, y in batches:
        x, y = x.to(device), y.to(device)
        losses.append(F.cross_entropy(logits_regular(model(x)), y))
    return torch.stack(losses).mean()

def compute_trait_grad_flat_from_batches(
    model: nn.Module,
    params: List[torch.nn.Parameter],
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> Tuple[float, torch.Tensor]:
    loss = trait_loss_on_batches(model, batches, device)
    grads = torch.autograd.grad(loss, params, retain_graph=False, create_graph=False)
    g_flat = torch.cat([g.detach().view(-1) for g in grads])
    return float(loss.detach()), g_flat

def compute_trait_loss_and_curvature_vHv_from_batches(
    model: nn.Module,
    params: List[torch.nn.Parameter],
    batches: List[Tuple[torch.Tensor, torch.Tensor]],
    v_flat: torch.Tensor,
    device: torch.device,
) -> Tuple[float, float, float]:
    """
    Returns:
      trait_loss, v^T H v, (v^T H v) / ||v||^2
    """
    loss = trait_loss_on_batches(model, batches, device)

    grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)

    v_list = split_flat_like_params(params, v_flat.detach())
    gv = torch.zeros((), device=device)
    for g, v in zip(grads, v_list):
        gv = gv + (g * v).sum()

    hvp = torch.autograd.grad(gv, params, retain_graph=False, create_graph=False)
    hvp_flat = torch.cat([h.detach().view(-1) for h in hvp])

    v_det = v_flat.detach()
    vHv = float((v_det @ hvp_flat).detach())
    v_norm2 = float((v_det @ v_det).detach())
    vHv_norm = vHv / max(v_norm2, 1e-30)

    return float(loss.detach()), vHv, vHv_norm

def distill_loss_aux_only(student: nn.Module, teacher: nn.Module, x_noise: torch.Tensor, aux_m: int) -> torch.Tensor:
    with torch.no_grad():
        t_aux = logits_aux(teacher(x_noise), aux_m)
        t_prob = F.softmax(t_aux, dim=1)

    s_aux = logits_aux(student(x_noise), aux_m)
    s_logprob = F.log_softmax(s_aux, dim=1)

    kl = torch.nn.KLDivLoss(reduction="batchmean")
    return kl(s_logprob, t_prob)

# Intervention: PCGrad-style projection when dot(g_distill, g_trait) > 0
def project_distill_gradient_if_positive_alignment(
    g_distill: torch.Tensor,
    g_trait: torch.Tensor,
    eps: float = 1e-12,
) -> Tuple[torch.Tensor, bool, float]:
    dot = float((g_distill @ g_trait).detach())
    if dot > 0.0:
        denom = float((g_trait @ g_trait).detach()) + eps
        coeff = dot / denom
        g_update = g_distill - coeff * g_trait
        return g_update, True, dot
    return g_distill, False, dot

# Student run (control vs intervention)
#
# Key detail:
# - In **intervention**, each step we draw `audit_batches_for_trait` audit batches once (call it `trait_batches_k`).
#   Those are used to compute g_trait for projection.
# - If this step is also a logging step, we compute trait loss + curvature using that **same** `trait_batches_k`.
#
def run_student_condition(
    condition: str,  # "control" or "intervention"
    student: nn.Module,
    teacher: nn.Module,
    noise_loader: DataLoader,
    audit_loader: DataLoader,
    test_loader: DataLoader,
    cfg: ExperimentConfig,
    seed: int,
    device: torch.device,
) -> Tuple[pd.DataFrame, Dict[str, float]]:
    assert condition in ("control", "intervention")
    assert cfg.audit_batches_for_trait >= 1

    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)

    params = get_params(student)
    opt = torch.optim.Adam(student.parameters(), lr=cfg.lr_student)

    audit_stream = make_infinite_iterator(audit_loader)

    logs: List[Dict[str, float]] = []
    global_step = 0
    projected_steps = 0

    student.train()
    for epoch in range(cfg.student_epochs):
        for x_noise, _ in noise_loader:
            x_noise = x_noise.to(device)

            # --- distill gradient ---
            opt.zero_grad(set_to_none=True)
            dloss = distill_loss_aux_only(student, teacher, x_noise, cfg.aux_m)
            dloss.backward()

            g_distill = flatten_grads_from_params(params)
            g_update = g_distill
            did_project = False
            dot_dt = float("nan")

            # We'll compute g_trait from trait_batches_k for:
            # - projection (intervention only, every step)
            # - logging (control and intervention, logging steps)
            trait_batches_k = None
            g_trait = None

            # --- intervention projection gate (uses trait batches every step) ---
            if condition == "intervention":
                trait_batches_k = collect_audit_batches(audit_stream, cfg.audit_batches_for_trait)
                _trait_loss_tmp, g_trait = compute_trait_grad_flat_from_batches(
                    model=student,
                    params=params,
                    batches=trait_batches_k,
                    device=device,
                )
                g_update, did_project, dot_dt = project_distill_gradient_if_positive_alignment(
                    g_distill=g_distill,
                    g_trait=g_trait,
                    eps=cfg.eps,
                )
                if did_project:
                    projected_steps += 1
                    set_grads_from_flat(params, g_update)

            # --- logging: trait loss + curvature + alignment (pre vs post) ---
            do_log = (global_step % cfg.metrics_every_n_steps == 0)
            if do_log:
                if condition == "control":
                    # Control: draw trait batches only at logging steps
                    trait_batches_k = collect_audit_batches(audit_stream, cfg.audit_batches_for_trait)

                    # Need g_trait to compute inner/cos; compute it on *these same batches*
                    _trait_loss_tmp, g_trait = compute_trait_grad_flat_from_batches(
                        model=student,
                        params=params,
                        batches=trait_batches_k,
                        device=device,
                    )
                else:
                    # Intervention: trait_batches_k and g_trait already computed this step
                    assert trait_batches_k is not None and g_trait is not None

                # Trait loss + curvature computed on the SAME batches.
                trait_loss_val, vHv, vHv_norm = compute_trait_loss_and_curvature_vHv_from_batches(
                    model=student,
                    params=params,
                    batches=trait_batches_k,
                    v_flat=g_update,  # curvature along actual update direction
                    device=device,
                )

                # Alignment metrics: before vs after
                # pre: distill vs trait
                inner_pre = float((g_distill @ g_trait).detach())
                cos_pre = float((g_distill @ g_trait).detach() / ((g_distill.norm() * g_trait.norm()).clamp_min(cfg.eps)))

                # post: update vs trait (same as pre for control; ~0 for projected steps)
                inner_post = float((g_update @ g_trait).detach())
                cos_post = float((g_update @ g_trait).detach() / ((g_update.norm() * g_trait.norm()).clamp_min(cfg.eps)))

                logs.append({
                    "seed": seed,
                    "condition": condition,
                    "step": global_step,
                    "epoch": epoch,

                    "trait_loss": trait_loss_val,
                    "distill_loss": float(dloss.detach()),

                    "trait_curvature_vHv": vHv,
                    "trait_curvature_vHv_norm": vHv_norm,

                    # alignment pre vs post
                    "inner_pre": inner_pre,
                    "cos_pre": cos_pre,
                    "inner_post": inner_post,
                    "cos_post": cos_post,

                    # intervention diagnostics
                    "did_project": int(did_project) if condition == "intervention" else 0,
                    "dot_gdistill_gtrait_step": dot_dt if condition == "intervention" else np.nan,
                })

            opt.step()
            global_step += 1

    student_acc = accuracy_on_loader(student, test_loader, device)

    df = pd.DataFrame(logs)
    info = {
        "student_test_acc": float(student_acc),
        "total_steps": int(global_step),
        "num_logged_rows": int(len(df)),
        "projected_steps": int(projected_steps),
        "projected_fraction": float(projected_steps / max(global_step, 1)),
    }
    return df, info

# One seed: train teacher once, run control + intervention students
def run_one_seed(
    seed: int,
    cfg: ExperimentConfig,
    build_model_fn: Callable[[ExperimentConfig], nn.Module],
    device: torch.device,
) -> Dict[str, Path]:
    set_global_seed(seed, deterministic=True)

    run_dir = Path(cfg.out_dir) / f"seed_{seed:02d}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Data
    data_root = str(Path(cfg.out_dir) / "data_cache")
    mnist_train, mnist_test = get_mnist_datasets(root=data_root)
    train_split, audit_split = split_train_audit(mnist_train, audit_size=cfg.audit_size, seed=seed)

    train_loader = make_loader(train_split, cfg.batch_size, shuffle=True,  seed=seed + 100, num_workers=cfg.num_workers)
    audit_loader = make_loader(audit_split, cfg.batch_size, shuffle=True,  seed=seed + 200, num_workers=cfg.num_workers)
    test_loader  = make_loader(mnist_test,  cfg.batch_size, shuffle=False, seed=seed + 300, num_workers=cfg.num_workers)

    # Same noise ordering for both conditions
    noise_ds = NoiseImages(length=cfg.noise_dataset_size, seed=seed + 400, shape=(1, 28, 28), dist="normal")
    noise_loader_control = make_loader(noise_ds, cfg.batch_size, shuffle=True, seed=seed + 500, num_workers=cfg.num_workers)
    noise_loader_interv  = make_loader(noise_ds, cfg.batch_size, shuffle=True, seed=seed + 500, num_workers=cfg.num_workers)

    # Reference init
    reference = build_model_fn(cfg)
    ref_state = {k: v.clone().detach().cpu() for k, v in reference.state_dict().items()}

    # Teacher
    teacher = build_model_fn(cfg)
    teacher.load_state_dict(ref_state)
    teacher_info = train_teacher(teacher, train_loader, test_loader, cfg, device)

    # Students (same init)
    student_control = build_model_fn(cfg); student_control.load_state_dict(ref_state)
    student_interv  = build_model_fn(cfg); student_interv.load_state_dict(ref_state)

    # Run conditions
    control_df, control_info = run_student_condition(
        condition="control",
        student=student_control,
        teacher=teacher,
        noise_loader=noise_loader_control,
        audit_loader=audit_loader,
        test_loader=test_loader,
        cfg=cfg,
        seed=seed,
        device=device,
    )
    interv_df, interv_info = run_student_condition(
        condition="intervention",
        student=student_interv,
        teacher=teacher,
        noise_loader=noise_loader_interv,
        audit_loader=audit_loader,
        test_loader=test_loader,
        cfg=cfg,
        seed=seed,
        device=device,
    )

    # Add final perf to all rows
    for df, info in [(control_df, control_info), (interv_df, interv_info)]:
        df["teacher_test_acc"] = teacher_info["teacher_test_acc"]
        df["student_test_acc"] = info["student_test_acc"]
        df["total_steps"] = info["total_steps"]
        df["projected_steps"] = info["projected_steps"]
        df["projected_fraction"] = info["projected_fraction"]

    # Save
    control_csv = run_dir / "control_metrics.csv"
    interv_csv  = run_dir / "intervention_metrics.csv"
    control_df.to_csv(control_csv, index=False)
    interv_df.to_csv(interv_csv, index=False)

    meta = {
        "seed": seed,
        "config": asdict(cfg),
        "teacher_info": teacher_info,
        "control_info": control_info,
        "intervention_info": interv_info,
        "created_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "device": str(device),
    }
    with open(run_dir / "metadata.json", "w") as f:
        json.dump(meta, f, indent=2)

    print(
        f"[seed {seed}] teacher_acc={teacher_info['teacher_test_acc']:.4f} | "
        f"control_acc={control_info['student_test_acc']:.4f} | "
        f"interv_acc={interv_info['student_test_acc']:.4f} | "
        f"proj_frac={interv_info['projected_fraction']:.3f} | "
        f"rows(control/interv)={len(control_df)}/{len(interv_df)}"
    )

    return {"control_csv": control_csv, "intervention_csv": interv_csv}

# Run all seeds (1..10)
all_paths = []
for s in cfg.seeds:
    all_paths.append(run_one_seed(seed=s, cfg=cfg, build_model_fn=build_model_mlp, device=device))

all_paths[0]

100%|██████████| 9.91M/9.91M [00:00<00:00, 130MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 19.7MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 72.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.62MB/s]


[seed 1] teacher_acc=0.9296 | control_acc=0.1172 | interv_acc=0.1032 | proj_frac=0.542 | rows(control/interv)=30/30
[seed 2] teacher_acc=0.9276 | control_acc=0.1767 | interv_acc=0.1009 | proj_frac=0.590 | rows(control/interv)=30/30
[seed 3] teacher_acc=0.9315 | control_acc=0.3496 | interv_acc=0.1009 | proj_frac=0.539 | rows(control/interv)=30/30
[seed 4] teacher_acc=0.9291 | control_acc=0.1957 | interv_acc=0.1135 | proj_frac=0.542 | rows(control/interv)=30/30
[seed 5] teacher_acc=0.9295 | control_acc=0.4641 | interv_acc=0.1138 | proj_frac=0.559 | rows(control/interv)=30/30
[seed 6] teacher_acc=0.9323 | control_acc=0.3792 | interv_acc=0.0892 | proj_frac=0.505 | rows(control/interv)=30/30
[seed 7] teacher_acc=0.9332 | control_acc=0.3520 | interv_acc=0.1009 | proj_frac=0.583 | rows(control/interv)=30/30
[seed 8] teacher_acc=0.9316 | control_acc=0.2049 | interv_acc=0.0767 | proj_frac=0.569 | rows(control/interv)=30/30
[seed 9] teacher_acc=0.9305 | control_acc=0.3538 | interv_acc=0.0892 | p

{'control_csv': PosixPath('runs_mnist_conflicting_grad/seed_01/control_metrics.csv'),
 'intervention_csv': PosixPath('runs_mnist_conflicting_grad/seed_01/intervention_metrics.csv')}

In [3]:
# Aggregate summary across runs
def aggregate_across_runs(out_dir: str) -> pd.DataFrame:
    out_dir = Path(out_dir)
    rows = []
    for seed_dir in sorted(out_dir.glob("seed_*")):
        for cond, fname in [("control", "control_metrics.csv"), ("intervention", "intervention_metrics.csv")]:
            p = seed_dir / fname
            if not p.exists():
                continue

            df = pd.read_csv(p)
            if len(df) == 0:
                continue

            # Helpers for safe means
            def _mean(col: str):
                return float(df[col].mean()) if col in df.columns else np.nan

            rows.append({
                "seed": int(df["seed"].iloc[0]),
                "condition": cond,

                # performance
                "teacher_test_acc": float(df["teacher_test_acc"].iloc[0]) if "teacher_test_acc" in df.columns else np.nan,
                "student_test_acc": float(df["student_test_acc"].iloc[0]) if "student_test_acc" in df.columns else np.nan,

                # losses
                "mean_trait_loss": _mean("trait_loss"),
                "mean_distill_loss": _mean("distill_loss"),

                # curvature
                "mean_trait_curvature_vHv": _mean("trait_curvature_vHv"),
                "mean_trait_curvature_vHv_norm": _mean("trait_curvature_vHv_norm"),

                # alignment summaries
                "mean_inner_pre": _mean("inner_pre"),
                "mean_cos_pre": _mean("cos_pre"),
                "mean_inner_post": _mean("inner_post"),
                "mean_cos_post": _mean("cos_post"),

                # fractions of positive alignment (pre/post)
                "frac_pos_pre": float((df["inner_pre"] > 0).mean()) if "inner_pre" in df.columns else np.nan,
                "frac_pos_post": float((df["inner_post"] > 0).mean()) if "inner_post" in df.columns else np.nan,

                # intervention stats (can be 0/NaN for control depending on the CSV)
                "projected_fraction": float(df["projected_fraction"].iloc[0]) if "projected_fraction" in df.columns else np.nan,
                "projected_steps": int(df["projected_steps"].iloc[0]) if "projected_steps" in df.columns else np.nan,

                "num_logged_rows": int(len(df)),
            })

    return pd.DataFrame(rows).sort_values(["seed", "condition"]).reset_index(drop=True)

summary = aggregate_across_runs(cfg.out_dir)
summary_path = Path(cfg.out_dir) / "summary_by_seed_condition.csv"
summary.to_csv(summary_path, index=False)
summary

Unnamed: 0,seed,condition,teacher_test_acc,student_test_acc,mean_trait_loss,mean_distill_loss,mean_trait_curvature_vHv,mean_trait_curvature_vHv_norm,mean_inner_pre,mean_cos_pre,mean_inner_post,mean_cos_post,frac_pos_pre,frac_pos_post,projected_fraction,projected_steps,num_logged_rows
0,1,control,0.9296,0.1172,2.266872,0.00536,3.015326e-06,0.001232,7.4e-05,0.008369,7.4e-05,0.008369,0.766667,0.766667,0.0,0,30
1,1,intervention,0.9296,0.1032,2.881673,0.005389,2.77616e-06,-0.001099,0.000547,0.013531,-0.000135,-0.007071,0.566667,0.066667,0.542373,160,30
2,2,control,0.9276,0.1767,2.249522,0.006357,4.175203e-06,0.001606,7.7e-05,0.006263,7.7e-05,0.006263,0.833333,0.833333,0.0,0,30
3,2,intervention,0.9276,0.1009,3.018697,0.006366,4.18236e-06,0.002693,-0.000123,-0.002574,-0.000398,-0.01112,0.433333,0.0,0.589831,174,30
4,3,control,0.9315,0.3496,2.2623,0.006645,3.738242e-06,0.001343,6.2e-05,0.007039,6.2e-05,0.007039,0.866667,0.866667,0.0,0,30
5,3,intervention,0.9315,0.1009,3.205902,0.00665,3.510744e-06,-0.000105,0.000922,0.015501,-0.000223,-0.006574,0.666667,0.066667,0.538983,159,30
6,4,control,0.9291,0.1957,2.281864,0.006152,5.36897e-06,0.001862,3.7e-05,0.005289,3.7e-05,0.005289,0.7,0.7,0.0,0,30
7,4,intervention,0.9291,0.1135,3.000659,0.006176,4.737265e-06,-0.0005,0.000521,0.013675,-0.000492,-0.016196,0.6,0.033333,0.542373,160,30
8,5,control,0.9295,0.4641,2.245749,0.004986,6.885627e-07,0.001375,7.7e-05,0.006763,7.7e-05,0.006763,0.8,0.8,0.0,0,30
9,5,intervention,0.9295,0.1138,3.431964,0.004997,4.936192e-07,0.001769,0.000405,0.006537,-0.000631,-0.013592,0.533333,0.0,0.559322,165,30


In [4]:
!zip -r runs_mnist_conflicting_grad.zip ./runs_mnist_conflicting_grad

  adding: runs_mnist_conflicting_grad/ (stored 0%)
  adding: runs_mnist_conflicting_grad/seed_03/ (stored 0%)
  adding: runs_mnist_conflicting_grad/seed_03/intervention_metrics.csv (deflated 65%)
  adding: runs_mnist_conflicting_grad/seed_03/metadata.json (deflated 59%)
  adding: runs_mnist_conflicting_grad/seed_03/control_metrics.csv (deflated 66%)
  adding: runs_mnist_conflicting_grad/seed_10/ (stored 0%)
  adding: runs_mnist_conflicting_grad/seed_10/intervention_metrics.csv (deflated 65%)
  adding: runs_mnist_conflicting_grad/seed_10/metadata.json (deflated 59%)
  adding: runs_mnist_conflicting_grad/seed_10/control_metrics.csv (deflated 66%)
  adding: runs_mnist_conflicting_grad/seed_06/ (stored 0%)
  adding: runs_mnist_conflicting_grad/seed_06/intervention_metrics.csv (deflated 67%)
  adding: runs_mnist_conflicting_grad/seed_06/metadata.json (deflated 59%)
  adding: runs_mnist_conflicting_grad/seed_06/control_metrics.csv (deflated 66%)
  adding: runs_mnist_conflicting_grad/summary_