In [None]:
import pandas as pd
import matplotlib.pyplot as plt

import zipfile
from pathlib import Path

In [None]:
def unpack_zip(zip_path, out_dir):
    """Extract a zip archive into ``out_dir`` the first time it is needed.

    The call is a silent no-op when ``out_dir`` already exists (whatever it
    contains) or when ``zip_path`` does not exist. Otherwise ``out_dir`` is
    created, the whole archive is extracted into it, and a one-line message
    is printed.

    Parameters
    ----------
    zip_path : str or path-like
        Location of the ``.zip`` archive.
    out_dir : str or path-like
        Directory to extract into.

    Returns
    -------
    None
    """
    archive = Path(zip_path)
    target = Path(out_dir)

    # An existing target counts as "already extracted"; a missing archive
    # simply means there is nothing to extract.
    if target.exists() or not archive.exists():
        return

    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive, "r") as bundle:
        bundle.extractall(target)
    print(f"Unzipped {archive} -> {target}")


def load_seed_metrics(out_dir, seed):
    """Load the control and intervention metric tables for one seed.

    Reads ``control_metrics.csv`` and ``intervention_metrics.csv`` from
    ``out_dir / f"seed_{seed:02d}"``. In each table the known metric columns
    that are present are converted to numeric (unparseable entries become
    NaN), rows are ordered by ``step`` and the index is renumbered from 0.

    Parameters
    ----------
    out_dir : str or path-like
        Directory containing the per-seed ``seed_XX`` folders.
    seed : int
        Seed number; formatted with ``:02d`` to build the folder name.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(control, interv)``.

    Notes
    -----
    Both files must exist and both tables need a ``step`` column, because
    the sort on ``step`` is unconditional.
    """
    run_dir = Path(out_dir) / f"seed_{seed:02d}"
    metric_names = ("step", "distill_loss", "trait_loss", "cos_pre", "cos_post")

    def _tidy(frame):
        # Coerce only the metric columns this table actually has.
        coerced = {
            name: pd.to_numeric(frame[name], errors="coerce")
            for name in metric_names
            if name in frame.columns
        }
        return frame.assign(**coerced).sort_values("step").reset_index(drop=True)

    # Read both files before touching either, so a missing file is reported
    # ahead of any problem with the contents.
    control_raw = pd.read_csv(run_dir / "control_metrics.csv")
    interv_raw = pd.read_csv(run_dir / "intervention_metrics.csv")

    return _tidy(control_raw), _tidy(interv_raw)


def plot_and_save_seed_dynamics_pdf(
    control,
    interv,
    seed,
    save_dir="plots_pdf",
    cos_col="cos_post",
    show_interv_pre=True,
    basename=None,
):
    """Plot control vs. intervention training curves for one seed and save a PDF.

    The figure has three stacked panels sharing the x-axis (``step``):
    ``distill_loss``, ``trait_loss`` and the cosine-similarity column
    ``cos_col``. The figure is always closed before returning, including
    when plotting or saving raises, so a failed call does not leave a
    half-drawn figure registered with pyplot.

    Parameters
    ----------
    control, interv : pandas.DataFrame
        Metric tables for the control and intervention runs. Both are
        indexed by column name for ``step``, ``distill_loss``,
        ``trait_loss`` and ``cos_col``.
    seed : int
        Seed number; formatted with ``:02d`` in the title and default file name.
    save_dir : str or path-like, default "plots_pdf"
        Output directory; created (with parents) if it does not exist.
    cos_col : str, default "cos_post"
        Column drawn in the cosine-similarity panel for both runs.
    show_interv_pre : bool, default True
        If true, ``cos_col`` is not ``"cos_pre"`` and ``interv`` has a
        ``cos_pre`` column, that column is added as a dashed line.
    basename : str or None, default None
        File name without extension; defaults to ``f"seed_{seed:02d}_dynamics"``.

    Returns
    -------
    pathlib.Path
        Path of the written file, ``save_dir / f"{basename}.pdf"``.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    if basename is None:
        basename = f"seed_{seed:02d}_dynamics"

    pdf_path = save_dir / f"{basename}.pdf"

    fig, axes = plt.subplots(3, 1, figsize=(6.5, 8.0), sharex=True)
    try:
        # --- Loss panels: same layout, different column / y-label ---
        loss_panels = (
            (axes[0], "distill_loss", "KL divergence"),
            (axes[1], "trait_loss", "Cross-entropy"),
        )
        for ax, column, ylabel in loss_panels:
            ax.plot(control["step"], control[column], label="control")
            ax.plot(interv["step"], interv[column], label="intervention")
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # --- Cosine similarity ---
        cos_ax = axes[2]
        cos_ax.plot(control["step"], control[cos_col], label="control")
        cos_ax.plot(interv["step"], interv[cos_col], label=f"intervention ({cos_col})")
        if show_interv_pre and (cos_col != "cos_pre") and ("cos_pre" in interv.columns):
            cos_ax.plot(
                interv["step"],
                interv["cos_pre"],
                linestyle="--",
                label="intervention (cos_pre)",
            )
        # Zero reference line for the cosine-similarity panel.
        cos_ax.axhline(0.0, linewidth=1)
        cos_ax.set_ylabel("Cosine similarity")
        cos_ax.set_xlabel("Logged step")
        cos_ax.legend()
        cos_ax.grid(True, alpha=0.3)

        fig.suptitle(f"Training dynamics (seed {seed:02d})")
        fig.tight_layout()

        fig.savefig(pdf_path, bbox_inches="tight")
    finally:
        # Release the figure even if plotting/saving failed; otherwise it
        # stays in pyplot's figure registry.
        plt.close(fig)

    print(f"Saved: {pdf_path.resolve()}")
    return pdf_path

In [8]:
# Archive with the logged runs, and the directory it gets extracted into.
zip_path = "./runs_mnist_conflicting_grad.zip"
zip_out_dir = "./runs_mnist_conflicting_grad"
# Directory passed to load_seed_metrics: one level below zip_out_dir
# (presumably the archive holds a top-level folder of the same name -- TODO confirm).
out_dir = "./runs_mnist_conflicting_grad/runs_mnist_conflicting_grad"
# Seed to visualise; load_seed_metrics reads it from the "seed_01" subfolder of out_dir.
seed = 1

# No-op if zip_out_dir already exists, so re-running this cell does not re-extract.
unpack_zip(zip_path, zip_out_dir)
control, interv = load_seed_metrics(out_dir, seed)
# Writes plots/seed_01_dynamics.pdf (default basename) and returns its path.
pdf_path = plot_and_save_seed_dynamics_pdf(
    control, interv,
    seed=seed,
    save_dir="plots",
    cos_col="cos_post",
    show_interv_pre=True,
)
# Bare last expression so the notebook displays the returned path.
pdf_path

Saved: /content/plots/seed_01_dynamics.pdf


PosixPath('plots/seed_01_dynamics.pdf')