In [1]:
import re
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import zipfile
from pathlib import Path

In [7]:
def unpack_zip(zip_path, out_dir):
    """Extract the archive at ``zip_path`` into ``out_dir`` once.

    Nothing happens when ``out_dir`` already exists, and nothing happens
    when the archive itself is missing. On extraction the target directory
    (including parents) is created and a one-line confirmation is printed.
    """
    zip_path, out_dir = Path(zip_path), Path(out_dir)

    # Existing target: skip. Missing archive: skip silently as well.
    if out_dir.exists() or not zip_path.exists():
        return

    out_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(out_dir)
    print(f"Unzipped {zip_path} -> {out_dir}")

def discover_seed_dirs(out_dir: str | Path) -> list[tuple[int, Path]]:
    out_dir = Path(out_dir)
    pairs = []
    for d in out_dir.glob("seed_*"):
        if not d.is_dir():
            continue
        m = re.match(r"seed_(\d+)$", d.name)
        if not m:
            continue
        seed = int(m.group(1))
        if (d / "control_metrics.csv").exists() and (d / "intervention_metrics.csv").exists():
            pairs.append((seed, d))
    return sorted(pairs, key=lambda x: x[0])

def load_all_seed_metrics(out_dir: str | Path) -> tuple[list[pd.DataFrame], list[pd.DataFrame], list[int]]:
    """Load the control and intervention metric CSVs of every seed directory.

    Seed directories are located with ``discover_seed_dirs``. In each CSV the
    known metric columns that are present are coerced to numeric (unparseable
    entries become NaN), and frames that have a ``step`` column are sorted by
    it and re-indexed from zero.

    Returns:
        ``(controls, intervs, seeds)`` -- three parallel lists in the order
        returned by ``discover_seed_dirs``.

    Raises:
        FileNotFoundError: if no qualifying seed directory is found.
    """
    out_dir = Path(out_dir)
    seed_dirs = discover_seed_dirs(out_dir)
    if not seed_dirs:
        raise FileNotFoundError(f"No seed_*/(control_metrics.csv, intervention_metrics.csv) found under: {out_dir}")

    numeric_columns = ("seed", "step", "epoch", "distill_loss", "trait_loss", "cos_pre", "cos_post")

    def _read_metrics(csv_path):
        # Read one metrics file and normalise dtypes / row order.
        frame = pd.read_csv(csv_path)
        for name in numeric_columns:
            if name in frame.columns:
                frame[name] = pd.to_numeric(frame[name], errors="coerce")
        if "step" in frame.columns:
            frame = frame.sort_values("step").reset_index(drop=True)
        return frame

    controls, intervs, seeds = [], [], []
    for seed, seed_dir in seed_dirs:
        controls.append(_read_metrics(seed_dir / "control_metrics.csv"))
        intervs.append(_read_metrics(seed_dir / "intervention_metrics.csv"))
        seeds.append(seed)

    return controls, intervs, seeds

def aggregate_stepwise(
    dfs: list[pd.DataFrame],
    value_col: str,
    step_col: str = "step",
    mode: str = "intersection",
    min_count: int | None = None,
) -> pd.DataFrame:
    """
    mode:
      - "intersection": keep steps that have at least `min_count` non-NaN values (default: all seeds)
      - "union": keep any step that appears (mean/sd computed over available seeds)
    """
    if min_count is None:
        min_count = len(dfs) if mode == "intersection" else 1

    long = pd.concat(
        [df[[step_col, "seed", value_col]].copy() for df in dfs if value_col in df.columns],
        ignore_index=True,
    )

    long = long.dropna(subset=[step_col, value_col])

    g = long.groupby(step_col)[value_col]
    agg = pd.DataFrame({
        "step": g.mean().index.astype(int),
        "mean": g.mean().values,
        "sd": g.std(ddof=0).values,
        "count": g.count().values,
    }).sort_values("step").reset_index(drop=True)

    if mode == "intersection":
        agg = agg[agg["count"] >= min_count].reset_index(drop=True)

    return agg

def _plot_mean_sd(ax, agg: pd.DataFrame, label: str, linestyle: str = "-"):
    x = agg["step"].to_numpy()
    m = agg["mean"].to_numpy()
    s = agg["sd"].to_numpy()

    (line,) = ax.plot(x, m, linewidth=2, label=label, linestyle=linestyle)
    ax.fill_between(x, m - s, m + s, alpha=0.25, color=line.get_color())

def plot_and_save_mean_dynamics_pdf(
    out_dir,
    save_dir="plots_pdf",
    cos_col="cos_post",
    show_interv_pre=True,
    basename=None,
    mode: str = "intersection",
    min_count: int | None = None,
):
    """Plot mean +/- SD training dynamics across seeds and save as SVG and PDF.

    Three stacked panels sharing the x axis are drawn, each comparing the
    control and intervention runs: ``distill_loss`` (labelled "KL divergence"),
    ``trait_loss`` (labelled "Cross-entropy") and ``cos_col`` (labelled
    "Cosine similarity").

    Args:
        out_dir: directory containing the ``seed_*`` run folders.
        save_dir: output directory for the figures; created if missing.
        cos_col: cosine-similarity column shown in the third panel.
        show_interv_pre: also draw the intervention ``cos_pre`` curve when
            ``cos_col`` is not ``"cos_pre"`` and some intervention frame has
            that column.
        basename: output file stem; defaults to ``"mean_dynamics"``.
        mode, min_count: forwarded to ``aggregate_stepwise``.

    Returns:
        ``(pdf_path, svg_path)`` as ``Path`` objects.

    Raises:
        FileNotFoundError: propagated from ``load_all_seed_metrics`` when no
            seed directory is found.

    Note:
        This updates ``mpl.rcParams`` globally (font sizes, ``text.usetex``).
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    if basename is None:
        basename = "mean_dynamics"

    pdf_path = save_dir / f"{basename}.pdf"
    svg_path = save_dir / f"{basename}.svg"

    # The seed ids are returned as well but are not needed for plotting.
    controls, intervs, _ = load_all_seed_metrics(out_dir)

    def _aggregate(frames, column):
        # Same aggregation settings for every metric.
        return aggregate_stepwise(frames, column, mode=mode, min_count=min_count)

    ctrl_distill = _aggregate(controls, "distill_loss")
    int_distill = _aggregate(intervs, "distill_loss")

    ctrl_trait = _aggregate(controls, "trait_loss")
    int_trait = _aggregate(intervs, "trait_loss")

    ctrl_cos = _aggregate(controls, cos_col)
    int_cos = _aggregate(intervs, cos_col)

    int_cos_pre = None
    if show_interv_pre and (cos_col != "cos_pre") and any(("cos_pre" in df.columns) for df in intervs):
        int_cos_pre = _aggregate(intervs, "cos_pre")

    mpl.rcParams.update({
        "text.usetex": False,
        "font.size": 10,
        "axes.titlesize": 15,
        "axes.labelsize": 15,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "figure.titlesize": 15,
    })

    fig, axes = plt.subplots(3, 1, figsize=(6.5, 8.0), sharex=True)
    try:
        # --- Distill loss / trait loss: identical layout, different data and label ---
        loss_panels = (
            (axes[0], ctrl_distill, int_distill, "KL divergence"),
            (axes[1], ctrl_trait, int_trait, "Cross-entropy"),
        )
        for ax, ctrl, interv, ylabel in loss_panels:
            _plot_mean_sd(ax, ctrl, label="control", linestyle="-")
            _plot_mean_sd(ax, interv, label="intervention", linestyle="--")
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # --- Cosine similarity ---
        cos_ax = axes[2]
        _plot_mean_sd(cos_ax, ctrl_cos, label="control", linestyle="-")
        _plot_mean_sd(cos_ax, int_cos, label=f"intervention ({cos_col})", linestyle="--")
        if int_cos_pre is not None:
            _plot_mean_sd(cos_ax, int_cos_pre, label="intervention (cos_pre)", linestyle=":")
        cos_ax.axhline(0.0, linewidth=1)  # zero reference line
        cos_ax.set_ylabel("Cosine similarity")
        cos_ax.set_xlabel("Logged step")
        cos_ax.legend()
        cos_ax.grid(True, alpha=0.3)

        fig.tight_layout()

        fig.savefig(svg_path, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls in one kernel do not accumulate open figures.
        plt.close(fig)

    print(f"Saved: {pdf_path.resolve()}")
    return pdf_path, svg_path

In [8]:
# Run "original_aux": archive location, extraction target, and the directory
# holding the seed_* folders (the archive root is nested one level deeper
# under the extraction target).
# NOTE(review): /content looks like a Colab upload path -- confirm before running elsewhere.
zip_path = "/content/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_aux.zip"
zip_out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_aux"
out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_aux/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_aux"
# Not referenced by the calls in this cell.
seed = 5

# Extract once (skipped when zip_out_dir already exists), then plot the
# across-seed mean dynamics without the intervention cos_pre curve.
unpack_zip(zip_path, zip_out_dir)
pdf_path, svg_path = plot_and_save_mean_dynamics_pdf(out_dir, show_interv_pre=False, basename="mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_aux")
# Last expression: display the saved figure paths as the cell output.
pdf_path, svg_path

Saved: /content/plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_aux.pdf


(PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_aux.pdf'),
 PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_aux.svg'))

In [9]:
# Run "original_all": archive location, extraction target, and the directory
# holding the seed_* folders (the archive root is nested one level deeper
# under the extraction target).
# NOTE(review): /content looks like a Colab upload path -- confirm before running elsewhere.
zip_path = "/content/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all.zip"
zip_out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all"
out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_all"
# Not referenced by the calls in this cell.
seed = 5

# Extract once (skipped when zip_out_dir already exists), then plot the
# across-seed mean dynamics without the intervention cos_pre curve.
unpack_zip(zip_path, zip_out_dir)
pdf_path, svg_path = plot_and_save_mean_dynamics_pdf(out_dir, show_interv_pre=False, basename="mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_all")
# Last expression: display the saved figure paths as the cell output.
pdf_path, svg_path

Saved: /content/plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_all.pdf


(PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_all.pdf'),
 PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_all.svg'))

In [10]:
# Run "original_reg10": archive location, extraction target, and the directory
# holding the seed_* folders (the archive root is nested one level deeper
# under the extraction target).
# NOTE(review): /content looks like a Colab upload path -- confirm before running elsewhere.
zip_path = "/content/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_reg10.zip"
zip_out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_reg10"
out_dir = "./runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_reg10/runs_mnist_liminal_anchor_early_then_decay_lambda0_1_all_step_original_reg10"
# Not referenced by the calls in this cell.
seed = 5

# Extract once (skipped when zip_out_dir already exists), then plot the
# across-seed mean dynamics without the intervention cos_pre curve.
unpack_zip(zip_path, zip_out_dir)
pdf_path, svg_path = plot_and_save_mean_dynamics_pdf(out_dir, show_interv_pre=False, basename="mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_reg10")
# Last expression: display the saved figure paths as the cell output.
pdf_path, svg_path

Saved: /content/plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_reg10.pdf


(PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_reg10.pdf'),
 PosixPath('plots_pdf/mean_dynamics_liminal_training_early_then_decay_lambda0_1_all_step_original_reg10.svg'))