In [1]:
import re
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import zipfile
from pathlib import Path

In [11]:
def unpack_zip(zip_path, out_dir):
    """Extract the archive at ``zip_path`` into ``out_dir`` once.

    Does nothing if ``out_dir`` already exists (treated as a previous
    extraction) or if ``zip_path`` does not exist (so the notebook can run on
    already-extracted data).

    The archive is first extracted into a sibling ``<out_dir>.partial``
    directory, which is renamed to ``out_dir`` only after ``extractall``
    succeeds. A corrupt archive or an interrupted extraction therefore never
    leaves a half-filled ``out_dir`` behind that later calls would silently
    accept as complete.

    Raises:
        zipfile.BadZipFile (or other I/O errors) if the archive cannot be read;
        in that case ``out_dir`` is not created.
    """
    import shutil

    zip_path = Path(zip_path)
    out_dir = Path(out_dir)
    if out_dir.exists():
        return
    if not zip_path.exists():
        return

    staging_dir = out_dir.with_name(out_dir.name + ".partial")
    if staging_dir.exists():
        # Leftover from an earlier failed attempt; start clean.
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True)
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(staging_dir)
        staging_dir.rename(out_dir)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a partial extraction around.
        shutil.rmtree(staging_dir, ignore_errors=True)
        raise
    print(f"Unzipped {zip_path} -> {out_dir}")

def discover_seed_dirs(out_dir: str | Path) -> list[tuple[int, Path]]:
    out_dir = Path(out_dir)
    pairs = []
    for d in out_dir.glob("seed_*"):
        if not d.is_dir():
            continue
        m = re.match(r"seed_(\d+)$", d.name)
        if not m:
            continue
        seed = int(m.group(1))
        if (d / "metrics.csv").exists():
            pairs.append((seed, d))
    return sorted(pairs, key=lambda x: x[0])



def load_seed_metrics(seed_dir: str | Path) -> pd.DataFrame:
    seed_dir = Path(seed_dir)
    df = pd.read_csv(seed_dir / "metrics.csv")

    numeric_cols = [
        "seed", "step", "epoch",
        "trait_test_acc_step",
        "inner_product", "cosine_similarity",
        "trait_loss", "distill_loss",
        "g_trait_norm", "g_distill_norm",
        "teacher_test_acc", "student_test_acc",
    ]
    for c in numeric_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    if "step" in df.columns:
        df = df.sort_values("step").reset_index(drop=True)

    return df

def load_all_seeds(out_dir: str | Path) -> list[pd.DataFrame]:
    """Load the metrics of every seed run found under ``out_dir``.

    Seed directories are found with ``discover_seed_dirs`` and each one is read
    with ``load_seed_metrics``. Every returned frame is guaranteed to have a
    ``seed`` column. If the CSV lacks one, it is filled with the seed parsed
    from the directory name. If the CSV has one, any missing/unparseable
    (NaN) values in it are filled the same way.

    Returns:
        One DataFrame per seed, in ascending seed order.

    Raises:
        FileNotFoundError: if no ``seed_*/metrics.csv`` exists under ``out_dir``.
    """
    seed_dirs = discover_seed_dirs(out_dir)
    if not seed_dirs:
        raise FileNotFoundError(f"No seed_*/metrics.csv found under: {out_dir}")

    dfs = []
    for seed, d in seed_dirs:
        df = load_seed_metrics(d)
        if "seed" in df.columns:
            # load_seed_metrics coerces unparseable values to NaN; fall back
            # to the directory's seed rather than propagating NaN seeds.
            df["seed"] = df["seed"].fillna(seed)
        else:
            df["seed"] = seed
        dfs.append(df)
    return dfs

def aggregate_stepwise(
    dfs: list[pd.DataFrame],
    value_col: str,
    step_col: str = "step",
    mode: str = "intersection",
    min_count: int | None = None,
) -> pd.DataFrame:
    """
    mode:
      - "intersection": keep steps that have at least `min_count` non-NaN values (default: all seeds)
      - "union": keep any step that appears (mean/sd computed over available seeds)
    """
    if min_count is None:
        min_count = len(dfs) if mode == "intersection" else 1

    long = pd.concat(
        [df[[step_col, "seed", value_col]].copy() for df in dfs if value_col in df.columns],
        ignore_index=True,
    )

    long = long.dropna(subset=[step_col, value_col])

    g = long.groupby(step_col)[value_col]
    agg = pd.DataFrame({
        "step": g.mean().index.astype(int),
        "mean": g.mean().values,
        "sd": g.std(ddof=0).values,
        "count": g.count().values,
    }).sort_values("step").reset_index(drop=True)

    if mode == "intersection":
        agg = agg[agg["count"] >= min_count].reset_index(drop=True)

    return agg

def _plot_mean_sd(ax, agg: pd.DataFrame, ylabel: str, title: str | None = None, zero_line: bool = False):
    x = agg["step"].to_numpy()
    m = agg["mean"].to_numpy()
    s = agg["sd"].to_numpy()

    ax.plot(x, m, linewidth=2)
    ax.fill_between(x, m - s, m + s, alpha=0.25)
    if zero_line:
        ax.axhline(0.0, linewidth=1)

    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    ax.grid(True, alpha=0.3)

def plot_run_mean_dynamics(
    out_dir: str | Path,
    save_path: str | Path | None = None,
    mode: str = "intersection",
    min_count: int | None = None,
    trait_loss_col: str = "trait_loss",
    inner_col: str = "inner_product",
    cos_col: str = "cosine_similarity",
):
    """Plot seed-averaged trait loss, gradient inner product and cosine similarity.

    Loads every ``seed_*/metrics.csv`` under ``out_dir`` and aggregates each
    metric per logged step with ``aggregate_stepwise`` (``mode``/``min_count``
    are forwarded to it). The results are drawn as three stacked mean ± SD
    panels that share the x-axis.

    Font sizes are applied only to this figure (via ``mpl.rc_context``) rather
    than by mutating the global rcParams.

    NOTE(review): a later cell in this notebook defines another function with
    this same name (an accuracy variant). Running all cells top-to-bottom makes
    that later definition the one the plotting cell calls.

    Returns:
        ``Path(save_path)`` after saving (the figure is closed) when
        ``save_path`` is given, otherwise ``(fig, axes)``.
    """
    dfs = load_all_seeds(out_dir)

    ce_agg = aggregate_stepwise(dfs, trait_loss_col, mode=mode, min_count=min_count)
    inn_agg = aggregate_stepwise(dfs, inner_col, mode=mode, min_count=min_count)
    cos_agg = aggregate_stepwise(dfs, cos_col, mode=mode, min_count=min_count)

    style = {
        "text.usetex": False,
        "font.size": 10,
        "axes.titlesize": 15,
        "axes.labelsize": 15,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "figure.titlesize": 15,
    }

    # Scoped style: a global rcParams.update would restyle every later plot.
    with mpl.rc_context(style):
        fig, axes = plt.subplots(3, 1, figsize=(7.2, 9.0), sharex=True)

        _plot_mean_sd(axes[0], ce_agg, ylabel="Cross-entropy")
        _plot_mean_sd(axes[1], inn_agg, ylabel="Inner product", zero_line=True)
        _plot_mean_sd(axes[2], cos_agg, ylabel="Cosine similarity", zero_line=True)
        axes[2].set_xlabel("Logged step")

        fig.tight_layout()

        if save_path is not None:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved: {save_path.resolve()}")
            return save_path

    return fig, axes

In [None]:
def plot_run_mean_dynamics(
    out_dir: str | Path,
    save_path: str | Path | None = None,
    mode: str = "intersection",
    min_count: int | None = None,
    accuracy_col: str = "trait_test_acc_step",
    inner_col: str = "inner_product",
    cos_col: str = "cosine_similarity",
    acc_ylim: tuple[float, float] | None = (0, 0.4),
):
    """Plot seed-averaged accuracy, gradient inner product and cosine similarity.

    Loads every ``seed_*/metrics.csv`` under ``out_dir`` and aggregates each
    metric per logged step with ``aggregate_stepwise`` (``mode``/``min_count``
    are forwarded to it). The results are drawn as three stacked mean ± SD
    panels that share the x-axis, under a figure title.

    Args:
        acc_ylim: y-limits of the accuracy panel. The default ``(0, 0.4)``
            matches the original hard-coded limits but clips any accuracy above
            0.4; pass wider limits, or ``None`` to let matplotlib autoscale.

    Font sizes are applied only to this figure (via ``mpl.rc_context``) rather
    than by mutating the global rcParams.

    NOTE(review): this cell redefines ``plot_run_mean_dynamics`` from the
    earlier cell. Its ``In [None]`` marker shows it was not executed for the
    recorded outputs, but under Restart & Run All it shadows the loss version,
    so the plotting cell would draw accuracy instead. Consider renaming it
    (e.g. ``plot_run_accuracy_dynamics``); the name is kept here so existing
    callers keep working.

    Returns:
        ``Path(save_path)`` after saving (the figure is closed) when
        ``save_path`` is given, otherwise ``(fig, axes)``.
    """
    dfs = load_all_seeds(out_dir)

    acc_agg = aggregate_stepwise(dfs, accuracy_col, mode=mode, min_count=min_count)
    inn_agg = aggregate_stepwise(dfs, inner_col, mode=mode, min_count=min_count)
    cos_agg = aggregate_stepwise(dfs, cos_col, mode=mode, min_count=min_count)

    style = {
        "text.usetex": False,
        "font.size": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "figure.titlesize": 14,
    }

    # Scoped style: a global rcParams.update would restyle every later plot.
    with mpl.rc_context(style):
        fig, axes = plt.subplots(3, 1, figsize=(7.2, 9.0), sharex=True)

        _plot_mean_sd(axes[0], acc_agg, ylabel="Accuracy")
        if acc_ylim is not None:
            axes[0].set_ylim(*acc_ylim)

        _plot_mean_sd(axes[1], inn_agg, ylabel="Inner product", zero_line=True)
        _plot_mean_sd(axes[2], cos_agg, ylabel="Cosine similarity", zero_line=True)
        axes[2].set_xlabel("Logged step")

        fig.suptitle("Per-step training statistics (mean ± SD)")
        fig.tight_layout()

        if save_path is not None:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved: {save_path.resolve()}")
            return save_path

    return fig, axes

In [12]:
# Colab paths. The archive's top-level folder is runs_mnist_subliminal_crt_period/,
# so extracting it into zip_out_dir puts the seed_* directories one level down.
zip_path = "/content/runs_mnist_subliminal_crt_period.zip"
zip_out_dir = "./runs_mnist_subliminal_crt_period"
out_dir = f"{zip_out_dir}/runs_mnist_subliminal_crt_period"
out_dir_pdf = "/content"  # parent of the plots/ folder (used for every format, not only PDF)

# No-op if zip_out_dir already exists, e.g. from a previous session.
unpack_zip(zip_path, zip_out_dir)

# NOTE(review): under Restart & Run All, the later cell that redefines
# plot_run_mean_dynamics (accuracy variant) is the one called here.
saved_paths = [
    plot_run_mean_dynamics(
        out_dir=out_dir,
        save_path=Path(out_dir_pdf) / "plots" / f"mean_dynamics.{ext}",
    )
    for ext in ("pdf", "svg")
]
saved_paths

Saved: /content/plots/mean_dynamics.pdf
Saved: /content/plots/mean_dynamics.svg


PosixPath('/content/plots/mean_dynamics.svg')

In [2]:
# NOTE(review): redundant with unpack_zip(...) in the plotting cell, and it lays
# the files out differently. unzip extracts into the current directory, giving
# ./runs_mnist_subliminal_crt_period/seed_*, whereas out_dir expects the nested
# ./runs_mnist_subliminal_crt_period/runs_mnist_subliminal_crt_period/seed_*.
# If this runs first on a fresh runtime, ./runs_mnist_subliminal_crt_period
# exists, so unpack_zip returns early and load_all_seeds finds no seed_* dirs
# under out_dir. Without -o/-n, unzip also prompts before overwriting existing
# files. Consider deleting this cell.
!unzip /content/runs_mnist_subliminal_crt_period.zip

Archive:  /content/runs_mnist_subliminal_crt_period.zip
   creating: runs_mnist_subliminal_crt_period/
   creating: runs_mnist_subliminal_crt_period/seed_06/
  inflating: runs_mnist_subliminal_crt_period/seed_06/metadata.json  
  inflating: runs_mnist_subliminal_crt_period/seed_06/metrics.csv  
   creating: runs_mnist_subliminal_crt_period/seed_02/
  inflating: runs_mnist_subliminal_crt_period/seed_02/metadata.json  
  inflating: runs_mnist_subliminal_crt_period/seed_02/metrics.csv  
   creating: runs_mnist_subliminal_crt_period/data_cache/
   creating: runs_mnist_subliminal_crt_period/data_cache/MNIST/
   creating: runs_mnist_subliminal_crt_period/data_cache/MNIST/raw/
 extracting: runs_mnist_subliminal_crt_period/data_cache/MNIST/raw/t10k-labels-idx1-ubyte.gz  
  inflating: runs_mnist_subliminal_crt_period/data_cache/MNIST/raw/t10k-images-idx3-ubyte  
 extracting: runs_mnist_subliminal_crt_period/data_cache/MNIST/raw/train-labels-idx1-ubyte.gz  
  inflating: runs_mnist_subliminal_crt