In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from pathlib import Path
from contextlib import contextmanager


In [None]:

# W&B project info
ENTITY = "justin_yang-university-of-california-berkeley"
PROJECT = "cs182-project-GPT-opt"

# Local cache directory for plotting data (ignored by git).
# Relative path: resolved against the kernel's current working directory and
# created as soon as this cell runs.
PLOTTING_CACHE_DIR = Path("plotting") / "cache"
PLOTTING_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Map high-level optimizer categories to sweep IDs or run IDs.
# Each ID is tried as a sweep first and, failing that, as a single run
# (see _collect_runs_for_ids); all resulting runs are pooled per category.
ATTN_CATEGORIES = {
    "AdamW": ["pxrm9xt8", "hkw05gba", "eifb7cyl", "4uytpec5"],
    "Muon-NS-All": ["qmzevjll", "lh2tzeh1", "capygls4"],
    "Muon-NS-VOFFN": ["88vr9lps"],
    "Muon-PE-All": ["emzbw2qc", "q02rln6i"],
    "Muon-PE-VOFFN": ["67g7mm51"],
}

# Titles for the attention entropy grid (category key -> display name).
# Categories missing here fall back to their key.
ATTN_CATEGORY_TITLES = {
    "AdamW": "AdamW",
    "Muon-NS-All": "Muon Newton-Schulz (QKV, O, FFN)",
    "Muon-NS-VOFFN": "Muon Newton-Schulz (V, O, FFN)",
    "Muon-PE-All": "Muon Polar Express (QKV, O, FFN)",
    "Muon-PE-VOFFN": "Muon Polar Express (V, O, FFN)",
}

# Plot layout knobs for the combined attention grid
DEFAULT_YLIM = (0.0, 6.0)  # y-axis range of the entropy panels
USE_EPOCH_X = True  # x-axis shows steps rescaled to [0, EPOCH_MAX] instead of raw steps
EPOCH_MAX = 0.3  # runs are 0.3 epochs

# Hand-tuned figure-fraction coordinates used by plot_all_categories_combined:
# legend anchor y, per-column subtitle y, suptitle y, and the tight_layout rect
# (the top 14% of the figure is reserved for the legend/titles).
COMBINED_LEGEND_Y = 1.08
COMBINED_SUBTITLE_Y = 0.81
COMBINED_TITLE_Y = 0.87
COMBINED_TIGHT_RECT = (0, 0, 1, 0.86)
# Per-category x offsets (figure fraction) added to each column's computed
# subtitle center.
COMBINED_SUBTITLE_X_NUDGE = {
    "AdamW": -0.07,
    "Muon-NS-All": -0.03,
    "Muon-PE-All": 0,
    "Muon-NS-VOFFN": 0.03,
    "Muon-PE-VOFFN": 0.07,
}


In [None]:

def _collect_runs_for_ids(ids, entity=ENTITY, project=PROJECT):
    """Resolve a mix of W&B sweep IDs and run IDs into one flat list of runs.

    Each ID is first looked up as a sweep (contributing all of its runs). If that
    lookup fails for any reason, the ID is retried as a single run. IDs that
    resolve as neither are reported with a warning and skipped.
    """
    api = wandb.Api()
    collected = []
    for ident in ids:
        path = f"{entity}/{project}/{ident}"
        try:
            collected.extend(list(api.sweep(path).runs))
        except Exception:
            # Not a sweep (or the sweep lookup failed): fall back to a run lookup.
            try:
                collected.append(api.run(path))
            except Exception as err:
                print(f"Warning: could not load sweep or run '{ident}': {err}")
    return collected


def _attn_cache_path(category_name, layer):
    """Return the pickle path caching attention-entropy data for one category and layer.

    Slashes in the category name become underscores so the name stays a single
    path component under PLOTTING_CACHE_DIR.
    """
    filename = "attn_{}_layer{}.pkl".format(str(category_name).replace("/", "_"), layer)
    return PLOTTING_CACHE_DIR / filename


def _load_attn_cache(category_name, layer):
    """Load cached attention-entropy data, or return None if it is absent or unreadable.

    Read failures only print a warning. The file is unpickled, so only caches
    written by this notebook should live in PLOTTING_CACHE_DIR.
    """
    cache_file = _attn_cache_path(category_name, layer)
    if not cache_file.exists():
        return None
    try:
        return pd.read_pickle(cache_file)
    except Exception as e:
        print(f"Warning: failed to read attention cache {cache_file}: {e}")
        return None


def _save_attn_cache(category_name, layer, df):
    """Best-effort pickle of ``df`` into the attention-entropy cache; failures only warn."""
    target = _attn_cache_path(category_name, layer)
    try:
        df.to_pickle(target)
    except Exception as err:
        print(f"Warning: failed to write attention cache {target}: {err}")


In [None]:

def _run_lr(config):
    """Return the learning rate stored in a W&B run config, or None if absent.

    Tries the nested ``config["optimizer_params"]["args"]["lr"]`` layout first,
    then the flattened ``"optimizer_params.args.lr"`` key.
    """
    try:
        return config["optimizer_params"]["args"]["lr"]
    except Exception:
        return config.get("optimizer_params.args.lr", None)


def get_attn_entropy_for_ids(ids, layer, entity=ENTITY, project=PROJECT):
    """Aggregate attn/layer{layer}/entropy/mean across runs, grouped by (lr, step).

    Args:
        ids: W&B sweep IDs and/or run IDs (resolved by ``_collect_runs_for_ids``).
        layer: layer index substituted into the metric key.
        entity, project: W&B location of the sweeps/runs.

    Returns:
        DataFrame with columns ``lr, step, mean, std, n, sem`` sorted by
        (lr, step), where statistics are taken across runs sharing an LR.
        ``std``/``sem`` are NaN where a single run contributes. Runs without
        a learning rate in their config are skipped. When no values are found,
        an empty frame with those columns is returned.
    """
    metric_key = f"attn/layer{layer}/entropy/mean"
    runs = _collect_runs_for_ids(ids, entity=entity, project=project)
    records = []

    for run in runs:
        lr = _run_lr(run.config)
        if lr is None:
            continue
        lr = float(lr)
        # Fetch only the two columns used here. scan_history(keys=...) yields only
        # rows where every requested key is defined -- the same filter applied
        # below -- instead of downloading each run's full history.
        for row in run.scan_history(keys=[metric_key, "_step"]):
            val = row.get(metric_key)
            step = row.get("_step")
            if val is None or step is None:
                continue
            records.append({"lr": lr, "step": int(step), "value": float(val), "run_id": run.id})

    if not records:
        return pd.DataFrame(columns=["lr", "step", "mean", "std", "sem", "n"])

    grouped = (
        pd.DataFrame(records)
        .groupby(["lr", "step"])["value"]
        .agg(["mean", "std", "count"])
        .reset_index()
        .rename(columns={"count": "n"})
    )
    grouped["sem"] = grouped["std"] / np.sqrt(grouped["n"].clip(lower=1))
    return grouped.sort_values(["lr", "step"])


def plot_attn_entropy_category(
    category_name,
    ids,
    layer,
    ax=None,
    ylim=DEFAULT_YLIM,
    add_legend=True,
    use_epoch_x=USE_EPOCH_X,
    epoch_max=EPOCH_MAX,
    use_cache=True,
    lr_min=None,
    lr_max=None,
    print_end_stats=False,
    adamw_lr_cap=0.0024,
):
    """Plot mean attention entropy over training for one optimizer category.

    Draws one line per learning rate with a +/- 1 SEM band across runs.

    Args:
        category_name: key into ATTN_CATEGORY_TITLES; also used as the cache key.
        ids: sweep/run IDs to fetch when the cache misses.
        layer: layer index of the ``attn/layer{layer}/entropy/mean`` metric.
        ax: Axes to draw on; a new 7x4 figure is created when None.
        ylim: y-limits, or None to autoscale.
        add_legend: attach an LR legend to ``ax``.
        use_epoch_x: rescale steps linearly to [0, epoch_max] (None means 1.0);
            otherwise plot raw steps.
        epoch_max: value the last plotted step maps to.
        use_cache: read/write the pickled aggregate in PLOTTING_CACHE_DIR.
            NOTE: the cache key is (category_name, layer) only, so changing
            ``ids`` for a category requires use_cache=False or deleting the file.
        lr_min, lr_max: inclusive learning-rate filter bounds (None = unbounded).
        print_end_stats: print the final mean +/- SEM for each learning rate.
        adamw_lr_cap: for the "AdamW" category only, drop learning rates above
            this value (the previously hard-coded 0.0024). None disables the cap.

    Returns:
        The Axes drawn on, or None when no data survives the LR filters.
    """
    df = _load_attn_cache(category_name, layer) if use_cache else None
    if df is None:
        df = get_attn_entropy_for_ids(ids, layer)
        if use_cache and not df.empty:
            _save_attn_cache(category_name, layer, df)

    # The tiny epsilon keeps an LR equal to the cap despite float rounding.
    if category_name == "AdamW" and adamw_lr_cap is not None and not df.empty:
        df = df[df["lr"] <= adamw_lr_cap + 1e-12]
    if lr_min is not None:
        df = df[df["lr"] >= lr_min]
    if lr_max is not None:
        df = df[df["lr"] <= lr_max]

    if df.empty:
        print(f"No data found for {category_name}, layer {layer} after LR filtering.")
        return None

    # Style must be set before the Axes is created: rcParams such as axes.grid
    # and axes.facecolor are only read at Axes creation time.
    sns.set_style("whitegrid")
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 4))

    min_step = df["step"].min()
    span = max(df["step"].max() - min_step, 1)
    x_scale = epoch_max if epoch_max is not None else 1.0

    for lr, sub in df.groupby("lr"):
        sub = sub.sort_values("step")
        steps = sub["step"].values.astype(float)
        mean_vals = sub["mean"].values
        # Single-run points have NaN SEM; draw them without a band.
        sem_vals = sub["sem"].fillna(0.0).values

        x_vals = (steps - min_step) / span * x_scale if use_epoch_x else steps

        ax.plot(x_vals, mean_vals, label=f"{lr:g}")
        ax.fill_between(x_vals, mean_vals - sem_vals, mean_vals + sem_vals, alpha=0.25)

        if print_end_stats and len(mean_vals) > 0:
            final_mean = float(mean_vals[-1])
            final_sem = float(sem_vals[-1])
            lower = final_mean - final_sem
            upper = final_mean + final_sem
            print(
                f"{category_name}, layer {layer}, lr={lr:g}: "
                f"final mean={final_mean:.4f}, lower={lower:.4f}, upper={upper:.4f}"
            )

    title_base = ATTN_CATEGORY_TITLES.get(category_name, category_name)
    ax.set_title(f"Layer {layer} Attention Entropy for {title_base}", pad=14)
    ax.set_ylabel("Mean Entropy")
    if ylim is not None:
        ax.set_ylim(*ylim)

    if use_epoch_x:
        ax.set_xlabel("Epoch")
        if epoch_max is not None:
            ax.set_xlim(0.0, epoch_max)
    else:
        ax.set_xlabel("Step")
        ax.set_xlim(left=0.0)

    # Keep y tick labels visible even when the caller created axes with sharey=True.
    ax.tick_params(axis="y", labelleft=True)
    if add_legend:
        ax.legend(title="Learning Rate", fontsize=8)

    return ax


def plot_all_categories_combined(
    categories_order=("AdamW", "Muon-NS-All", "Muon-PE-All", "Muon-NS-VOFFN", "Muon-PE-VOFFN"),
    layers=(0, 5, 11),
    categories=ATTN_CATEGORIES,
    ylim=DEFAULT_YLIM,
    use_epoch_x=USE_EPOCH_X,
    epoch_max=EPOCH_MAX,
    use_cache=True,
    lr_min=None,
    lr_max=None,
    lr_ranges_by_category=None,
    base_lr_by_category=None,
    print_end_stats=False,
):
    """Grid of attention-entropy plots: one row per layer, one column per category.

    Args:
        categories_order: category keys, one per column, left to right.
        layers: layer indices, one per row, top to bottom.
        categories: mapping category -> sweep/run IDs.
        ylim, use_epoch_x, epoch_max, use_cache, print_end_stats: forwarded to
            ``plot_attn_entropy_category``.
        lr_min, lr_max: default LR filter for every column.
        lr_ranges_by_category: optional {category: (lr_min, lr_max)} overriding
            the defaults per column.
        base_lr_by_category: accepted for interface compatibility; currently unused.

    Returns:
        (fig, axes) where ``axes`` is always a 2-D (len(layers), len(categories_order)) array.
    """
    n_rows = len(layers)
    n_cols = len(categories_order)
    # squeeze=False guarantees a 2-D Axes array even for a single row or column.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), sharex=True, sharey=True, squeeze=False
    )

    for col, category_name in enumerate(categories_order):
        if category_name not in categories:
            print(f"Warning: category {category_name} missing from ATTN_CATEGORIES; skipping column.")
            continue
        ids = categories[category_name]
        cat_lr_min, cat_lr_max = lr_min, lr_max
        if lr_ranges_by_category and category_name in lr_ranges_by_category:
            cat_lr_min, cat_lr_max = lr_ranges_by_category[category_name]

        for row, layer in enumerate(layers):
            ax = axes[row, col]
            plot_attn_entropy_category(
                category_name,
                ids,
                layer,
                ax=ax,
                ylim=ylim,
                add_legend=(row == 0),
                use_epoch_x=use_epoch_x,
                epoch_max=epoch_max,
                use_cache=use_cache,
                lr_min=cat_lr_min,
                lr_max=cat_lr_max,
                print_end_stats=print_end_stats,
            )
            if row > 0:
                ax.set_xlabel("")
            if col > 0:
                ax.set_ylabel("")

        top_ax = axes[0, col]
        pos = top_ax.get_position()
        # Positions are read before tight_layout() moves the axes; the hand-tuned
        # per-category nudges compensate for that shift.
        x_center = 0.5 * (pos.x0 + pos.x1) + COMBINED_SUBTITLE_X_NUDGE.get(category_name, 0)

        handles, labels = top_ax.get_legend_handles_labels()
        if top_ax.get_legend() is not None:
            top_ax.get_legend().remove()
        if handles and labels:
            # Each column has its own LR set (and color assignment), so its legend is
            # anchored above that column. A shared x anchor stacked every column's
            # legend on top of the others. Two legend columns keep the legend within
            # the width of its subplot column.
            fig.legend(
                handles,
                labels,
                title="Learning Rate",
                loc="upper center",
                bbox_to_anchor=(x_center, COMBINED_LEGEND_Y),
                ncol=min(2, len(labels)),
                fontsize=8,
            )
        subtitle = ATTN_CATEGORY_TITLES.get(category_name, category_name)
        fig.text(x_center, COMBINED_SUBTITLE_Y, subtitle, ha="center", va="center", fontsize=10, fontweight="bold")

    x_name = "Epoch" if use_epoch_x else "Step"
    fig.suptitle(f"Attention Entropy vs {x_name}", y=COMBINED_TITLE_Y, fontsize=16, fontweight="bold")
    fig.tight_layout(rect=COMBINED_TIGHT_RECT)
    plt.show()
    return fig, axes


In [None]:

# Combined attention entropy grid (per-category LR ranges).
# Each entry is the inclusive (lr_min, lr_max) window plotted for that category.
lr_ranges = {
    "AdamW": (0.0003, 0.0024),
    "Muon-NS-All": (0.01, 0.1),
    "Muon-PE-All": (0.01, 0.1),
    "Muon-NS-VOFFN": (0.003, 0.01),
    "Muon-PE-VOFFN": (0.003, 0.01),
}

# The base LR of each category is the lower end of its plotted range.
base_lrs = {category: lo for category, (lo, _hi) in lr_ranges.items()}

fig, axes = plot_all_categories_combined(
    lr_ranges_by_category=lr_ranges,
    base_lr_by_category=base_lrs,
    print_end_stats=False,
)


In [None]:

# Benji branch plots: condition number, effective rank, spectral gap for Polar safety sweep

def get_data(sweep_id, optimizer_param_name, optimizer_param_vals, update, layer, matrix, metric):
    """Fetch per-run SVD metric histories from a W&B sweep, grouped by one optimizer arg.

    Args:
        sweep_id: sweep under ENTITY/PROJECT.
        optimizer_param_name: key in ``run.config["optimizer_params"]["args"]``
            used to group runs.
        optimizer_param_vals: parameter values to collect.
        update, layer, matrix, metric: select the history key
            ``svd/[update_]layer{layer}_{matrix}/{metric}``.

    Returns:
        dict mapping every value in ``optimizer_param_vals`` to a list of 1-D numpy
        arrays, one per run. Each array has one entry per history row, with
        ``None`` where the metric was not logged on that row. Runs that lack the
        parameter or use a value not requested are skipped with a warning instead
        of raising KeyError.
    """
    update_prefix = "update_" if update else ""
    metric_key = f"svd/{update_prefix}layer{layer}_{matrix}/{metric}"
    api = wandb.Api()
    sweep = api.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}")
    data = {val: [] for val in optimizer_param_vals}

    for run in sweep.runs:
        try:
            run_param = run.config["optimizer_params"]["args"][optimizer_param_name]
        except (KeyError, TypeError):
            print(f"Warning: run {run.id} has no optimizer_params.args.{optimizer_param_name}; skipping.")
            continue
        if run_param not in data:
            print(
                f"Warning: run {run.id} has {optimizer_param_name}={run_param!r}, "
                "which was not requested; skipping."
            )
            continue
        # One entry per history row (None where absent) keeps row positions aligned across runs.
        data[run_param].append(np.array([row.get(metric_key) for row in run.scan_history()]))
    return data


def _summarize_param_runs(arrays, use_ci):
    """Collapse the per-run metric histories of one parameter value into summary curves.

    Args:
        arrays: non-empty list of 1-D arrays, one per run, as built by ``get_data``;
            entries are None on history rows where the metric was not logged.
        use_ci: if True, the center is the mean and the band is a 95% normal
            approximation CI, mean +/- 1.96 * SEM. The SEM uses the sample std
            (ddof=1), matching the pandas ``.std()`` used elsewhere in this file,
            and is 0 for a single run. Otherwise the center is the median and the
            band spans min..max.

    Returns:
        (steps, center, lower, upper). ``steps`` are the history-row positions at
        which the first run logged the metric; the other three are float arrays.
    """
    # Runs that stopped early have shorter histories; compare only the common prefix.
    n_rows = min(len(a) for a in arrays)
    trimmed = [np.asarray(a)[:n_rows] for a in arrays]
    # The first run decides which rows carry the metric. Other runs' None entries
    # become NaN below.
    valid_mask = np.array([v is not None for v in trimmed[0]], dtype=bool)
    steps = np.flatnonzero(valid_mask)
    stacked = np.asarray([a[valid_mask] for a in trimmed], dtype=float)

    if use_ci:
        center = stacked.mean(axis=0)
        n_runs = stacked.shape[0]
        if n_runs > 1:
            sem = stacked.std(axis=0, ddof=1) / np.sqrt(n_runs)
        else:
            sem = np.zeros_like(center)
        lower = center - 1.96 * sem
        upper = center + 1.96 * sem
    else:
        center = np.median(stacked, axis=0)
        lower = stacked.min(axis=0)
        upper = stacked.max(axis=0)
    return steps, center, lower, upper


def plot_data(
    sweep_id,
    param_name,
    param_id,
    param_vals,
    update,
    layer,
    matrix,
    metric,
    start,
    end,
    use_ci=False,
    log_scale=False,
    plot_med=True,
    save_plot=False,
):
    """Plot an SVD metric over training, one band per value of an optimizer hyperparameter.

    Args:
        sweep_id: W&B sweep to read runs from.
        param_name: human-readable parameter name, used as the legend title.
        param_id: key under ``optimizer_params.args`` that distinguishes runs.
        param_vals: parameter values to plot.
        update, layer, matrix, metric: select the metric key
            ``svd/[update_]layer{layer}_{matrix}/{metric}``.
        start, end: inclusive positional row range to plot (rows are sorted by step).
        use_ci: mean with 95% CI instead of median with min/max envelope.
        log_scale: plot the natural log of center and band values.
        plot_med: also draw the center line (mean or median) over the band.
        save_plot: save a JPEG under ``plotting/plots/``.

    Raises:
        ValueError: if a requested parameter value has no runs in the sweep.
    """
    print("Getting Data...")
    data = get_data(sweep_id, param_id, param_vals, update, layer, matrix, metric)
    update_prefix = "update_" if update else ""

    columns = {}
    for val in param_vals:
        if not data[val]:
            raise ValueError(f"No runs in sweep {sweep_id} with {param_id}={val!r}")
        steps, center, lower, upper = _summarize_param_runs(data[val], use_ci)
        if log_scale:
            center, lower, upper = np.log(center), np.log(lower), np.log(upper)
        # Index by step so parameter values whose runs logged on different rows still line up.
        columns[f"med_vec_{val}"] = pd.Series(center, index=steps)
        columns[f"min_vec_{val}"] = pd.Series(lower, index=steps)
        columns[f"max_vec_{val}"] = pd.Series(upper, index=steps)

    df = pd.DataFrame(columns).sort_index().rename_axis("Step").reset_index()
    df_iloc = df.iloc[start : end + 1]

    print("Plotting Data...")
    sns.set_style("whitegrid")
    fig, ax = plt.subplots()
    for val in param_vals:
        if plot_med:
            sns.lineplot(data=df_iloc, x="Step", y=f"med_vec_{val}", ax=ax)
        ax.fill_between(
            df_iloc["Step"],
            df_iloc[f"min_vec_{val}"],
            df_iloc[f"max_vec_{val}"],
            alpha=0.3,
            label=f"{val}",
        )

    metric_title = metric.replace("_", " ").title()
    ax.set_title(f"{update_prefix.replace('_', '').title()} Layer {layer} {matrix.title()} - {metric_title}")
    ax.set_xlabel("Step")
    ax.set_ylabel(f"{metric_title} (Log Scale)" if log_scale else metric_title)
    ax.legend(title=param_name)

    if save_plot:
        Path("plotting/plots").mkdir(parents=True, exist_ok=True)
        fig.savefig(f"./plotting/plots/{param_id}_{metric}_layer{layer}.jpg")
        print("Plot Saved.")

    plt.show()


In [None]:

# Reproduce Benji's polar_safety sweep plots
POLAR_SWEEP_ID = "5pllqyjx"
PARAM_NAME = "Safety Factor"
PARAM_ID = "polar_safety"
PARAM_VALS = [1, 1.01]

# One entry per metric: (metric, first history row, use_ci, log_scale).
# Every metric is drawn for layers 0, 5, 11 on the stacked update matrix, rows [start, 32].
_polar_metric_specs = [
    ("condition_number", 2, True, True),  # log-scale
    ("effective_rank", 1, True, False),
    ("spectral_gap", 0, False, False),
]

for _metric, _start, _use_ci, _log_scale in _polar_metric_specs:
    for layer in (0, 5, 11):
        plot_data(
            POLAR_SWEEP_ID,
            PARAM_NAME,
            PARAM_ID,
            PARAM_VALS,
            update=True,
            layer=layer,
            matrix="stacked",
            metric=_metric,
            start=_start,
            end=32,
            use_ci=_use_ci,
            log_scale=_log_scale,
            plot_med=False,
            save_plot=True,
        )


In [None]:

# Variant-based plots for Muon/PE/AdamW sweeps
VARIANT_SWEEP_IDS = ["vqkcitxv", "ov703ihc", "js8l8c7m"]
# Raw `muon_variant` config value -> display label (see _get_variant_label).
# Runs without a muon_variant are labeled via the "adamw" entry; unknown values
# are shown unchanged.
VARIANT_TITLE_MAP = {
    "pe_all": "Muon+PE",
    "ns_all": "Muon+NS",
    "pe_voffn": "Muon+PE(VO/FFN)",
    "ns_voffn": "Muon+NS(VO/FFN)",
    "pe_mod_all": "Muon+PE (cheap)",
    "adamw": "AdamW",
}
# Display order and color assignment for variant plots; only labels listed here
# are plotted. "Muon+PE (split)" has no VARIANT_TITLE_MAP entry: _get_variant_label
# assigns it to runs whose muon_mode is "split_qkv".
CANONICAL_VARIANTS = [
    "AdamW",
    "Muon+NS(VO/FFN)",
    "Muon+NS",
    "Muon+PE(VO/FFN)",
    "Muon+PE",
    "Muon+PE (cheap)",
    "Muon+PE (split)",
]


In [None]:

def _collect_variant_runs(sweep_ids=VARIANT_SWEEP_IDS, entity=ENTITY, project=PROJECT):
    """Gather every run from the given W&B sweeps.

    Sweeps that fail to load are reported with a warning and skipped.
    """
    api = wandb.Api()
    all_runs = []
    for sweep_id in sweep_ids:
        try:
            sweep_runs = list(api.sweep(f"{entity}/{project}/{sweep_id}").runs)
        except Exception as e:
            print(f"Warning: could not load sweep '{sweep_id}': {e}")
            continue
        all_runs.extend(sweep_runs)
    return all_runs


def _get_variant_label(cfg):
    raw_variant = None
    muon_mode = None
    try:
        args = cfg["optimizer_params"]["args"]
        raw_variant = args.get("muon_variant")
        muon_mode = args.get("muon_mode")
    except Exception:
        pass
    if raw_variant is None:
        raw_variant = cfg.get("optimizer_params.args.muon_variant")
    if muon_mode is None:
        muon_mode = cfg.get("optimizer_params.args.muon_mode")
    if raw_variant is None:
        raw_variant = "adamw"
    if muon_mode == "split_qkv":
        return "Muon+PE (split)"
    return VARIANT_TITLE_MAP.get(raw_variant, raw_variant)


def _variant_cache_path(metric_key, sweep_ids):
    """Return the pickle path caching per-variant aggregates of ``metric_key``.

    The filename embeds the metric (slashes replaced by underscores) and the
    sorted sweep IDs, so the same sweeps in any order share one cache file.
    """
    metric_part = metric_key.replace("/", "_")
    sweeps_part = "_".join(sorted(sweep_ids))
    return PLOTTING_CACHE_DIR.joinpath(f"variants_{metric_part}_{sweeps_part}.pkl")


def _load_variant_cache(metric_key, sweep_ids):
    """Read the cached per-variant aggregate for ``metric_key``/``sweep_ids``.

    Returns None when no cache file exists or it cannot be unpickled; the latter
    prints a warning. Unpickling can execute code, so only caches written by this
    notebook should be loaded.
    """
    cache_file = _variant_cache_path(metric_key, sweep_ids)
    if not cache_file.exists():
        return None
    try:
        return pd.read_pickle(cache_file)
    except Exception as exc:
        print(f"Warning: failed to read variant cache {cache_file}: {exc}")
    return None


def _save_variant_cache(metric_key, sweep_ids, df):
    """Best-effort pickle of ``df`` into the variant cache; write failures only warn."""
    target = _variant_cache_path(metric_key, sweep_ids)
    try:
        df.to_pickle(target)
    except Exception as exc:
        print(f"Warning: failed to write variant cache {target}: {exc}")


def get_loss_by_variant(metric_key, sweep_ids=VARIANT_SWEEP_IDS, entity=ENTITY, project=PROJECT):
    """Aggregate ``metric_key`` over all runs of the given sweeps, by optimizer variant and step.

    Each history row's x-coordinate is ``_step`` when present, else ``epoch``,
    else the row's position in the run history. Rows where the metric is missing
    or None are skipped.

    Returns:
        DataFrame with columns ``variant, step, mean, std, n, sem`` sorted by
        (variant, step); ``std``/``sem`` are NaN where only one run contributes.
        When nothing is found, an empty frame with those columns is returned.
    """
    records = []
    for run in _collect_variant_runs(sweep_ids=sweep_ids, entity=entity, project=project):
        variant = _get_variant_label(run.config)
        for position, row in enumerate(run.scan_history()):
            value = row.get(metric_key)
            if value is None:
                continue
            step = row.get("_step")
            if step is None:
                step = row.get("epoch")
            if step is None:
                step = position
            records.append({"variant": variant, "step": float(step), "value": float(value), "run_id": run.id})

    if not records:
        print(f"{metric_key}: no records found in sweeps {sweep_ids}")
        return pd.DataFrame(columns=["variant", "step", "mean", "std", "sem", "n"])

    stats = (
        pd.DataFrame(records)
        .groupby(["variant", "step"])["value"]
        .agg(["mean", "std", "count"])
        .reset_index()
        .rename(columns={"count": "n"})
    )
    stats["sem"] = stats["std"] / np.sqrt(stats["n"].clip(lower=1))
    stats = stats.sort_values(["variant", "step"])
    print(f"{metric_key}: found variants {sorted(stats['variant'].unique())}")
    return stats


In [None]:

def _load_or_fetch_variant_loss(metric_key, sweep_ids, use_cache):
    """Return aggregated ``metric_key`` data for ``sweep_ids``, using the local cache when enabled.

    Empty results are never written, and an empty cached frame is treated as a
    miss. A failed or empty fetch (e.g. a transient W&B error) is therefore
    retried on the next call instead of being pinned in the cache.
    """
    if use_cache:
        cached = _load_variant_cache(metric_key, sweep_ids)
        if cached is not None and not cached.empty:
            return cached
    df = get_loss_by_variant(metric_key, sweep_ids=sweep_ids)
    if use_cache and not df.empty:
        _save_variant_cache(metric_key, sweep_ids, df)
    return df


def plot_full_and_zoomed_variants(
    sweep_ids=VARIANT_SWEEP_IDS,
    zoom_window=(0.95, 1.0),
    use_epoch_x=USE_EPOCH_X,
    epoch_max=EPOCH_MAX,
    use_cache=True,
    include_split=True,
):
    """2x2 figure of train/val loss per optimizer variant: full curves on top, zoomed tails below.

    The variants shown, with their order and colors, come from the module-level
    CANONICAL_VARIANTS at call time. AdamW is excluded from the zoomed row.

    Args:
        sweep_ids: W&B sweep IDs whose runs are aggregated.
        zoom_window: (lo, hi) x-range of the zoomed panels. It is always in
            normalized epoch units (steps rescaled to [0, epoch_max]),
            independent of ``use_epoch_x``. NOTE: the default (0.95, 1.0) only
            overlaps the data when epoch_max >= 1.0.
        use_epoch_x: top-row x-axis uses the rescaled epoch if True, raw step otherwise.
        epoch_max: value the last logged step maps to; None means 1.0.
        use_cache: read/write pickled aggregates under PLOTTING_CACHE_DIR.
        include_split: keep the "Muon+PE (split)" variant.

    Returns:
        (fig, axes), or (None, None) when there is nothing to plot.
    """
    train_df_full = _load_or_fetch_variant_loss("train/loss", sweep_ids, use_cache)
    val_df_full = _load_or_fetch_variant_loss("val/loss", sweep_ids, use_cache)

    if train_df_full.empty and val_df_full.empty:
        print("No loss data found for the specified sweeps.")
        return None, None

    # AdamW is left out of the zoomed row.
    train_df_zoom = train_df_full if train_df_full.empty else train_df_full[train_df_full["variant"] != "AdamW"]
    val_df_zoom = val_df_full if val_df_full.empty else val_df_full[val_df_full["variant"] != "AdamW"]

    def _present_variants(*frames):
        """CANONICAL_VARIANTS entries present in any non-empty frame, in canonical order."""
        present = set()
        for frame in frames:
            if not frame.empty:
                present.update(frame["variant"].unique())
        keep = [v for v in CANONICAL_VARIANTS if v in present]
        if not include_split:
            keep = [v for v in keep if v != "Muon+PE (split)"]
        return keep

    variants_full = _present_variants(train_df_full, val_df_full)
    variants_zoom = _present_variants(train_df_zoom, val_df_zoom)

    if not variants_full:
        print("No variants to plot after filtering.")
        return None, None

    # Colors are keyed on the canonical list, so a variant has the same color in all panels.
    palette = sns.color_palette(n_colors=len(CANONICAL_VARIANTS))
    color_map = {name: palette[i] for i, name in enumerate(CANONICAL_VARIANTS)}
    x_scale = epoch_max if epoch_max is not None else 1.0

    # Set the style before creating the figure: it is read at Axes creation time.
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    ax_train_full, ax_val_full = axes[0]
    ax_train_zoom, ax_val_zoom = axes[1]

    def _plot_full(df, ax, ylabel):
        """Draw each variant's full mean curve with a +/- 1 SEM band; hide ``ax`` if no data."""
        if df.empty:
            ax.set_visible(False)
            return
        min_step = df["step"].min()
        span = max(df["step"].max() - min_step, 1e-8)
        for variant in variants_full:
            sub = df[df["variant"] == variant]
            if sub.empty:
                continue
            sub = sub.sort_values("step")
            steps = sub["step"].values.astype(float)
            mean_vals = sub["mean"].values
            sem_vals = sub["sem"].fillna(0.0).values
            x_vals = (steps - min_step) / span * x_scale if use_epoch_x else steps
            color = color_map.get(variant)
            ax.plot(x_vals, mean_vals, color=color, alpha=1.0, label=variant)
            ax.fill_between(x_vals, mean_vals - sem_vals, mean_vals + sem_vals, alpha=0.2, color=color)
        if use_epoch_x:
            if epoch_max is not None:
                ax.set_xlim(0.0, epoch_max)
            ax.set_xlabel("Epoch")
        else:
            ax.set_xlim(left=0.0)
            ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.tick_params(axis="y", labelleft=True)

    def _plot_zoom(df, ax, ylabel):
        """Draw faint full curves, then solid curves and SEM bands inside zoom_window.

        Curves are linearly interpolated at the window edges so the solid segment
        spans the whole window; y-limits are fitted to the zoomed bands.
        """
        if df.empty:
            ax.set_visible(False)
            return
        emin, emax = zoom_window
        min_step = df["step"].min()
        span = max(df["step"].max() - min_step, 1e-8)
        df_local = df.assign(epoch=(df["step"] - min_step) / span * x_scale)
        zoom_ymin, zoom_ymax = None, None
        for variant in variants_zoom:
            sub = df_local[df_local["variant"] == variant]
            if sub.empty:
                continue
            sub = sub.sort_values("step")
            x_vals = sub["epoch"].values
            mean_vals = sub["mean"].values
            sem_vals = sub["sem"].fillna(0.0).values
            color = color_map.get(variant)
            # Faint full-range context line behind the highlighted zoom segment.
            ax.plot(x_vals, mean_vals, color=color, alpha=0.5)
            mask = (x_vals >= emin) & (x_vals <= emax)
            x_zoom = x_vals[mask]
            mean_zoom = mean_vals[mask]
            sem_zoom = sem_vals[mask]
            # Interpolate mean/SEM at each window edge so the segment reaches it.
            boundary_x, boundary_mean, boundary_sem = [], [], []
            if x_vals.size >= 2:
                for b in (emin, emax):
                    j = np.searchsorted(x_vals, b)
                    if 0 < j < x_vals.size:
                        x0, x1 = x_vals[j - 1], x_vals[j]
                        if x1 == x0:
                            continue
                        t = (b - x0) / (x1 - x0)
                        boundary_x.append(b)
                        boundary_mean.append(mean_vals[j - 1] + t * (mean_vals[j] - mean_vals[j - 1]))
                        boundary_sem.append(sem_vals[j - 1] + t * (sem_vals[j] - sem_vals[j - 1]))
            if boundary_x:
                x_ext = np.concatenate([x_zoom, np.array(boundary_x, dtype=float)])
                mean_ext = np.concatenate([mean_zoom, np.array(boundary_mean, dtype=float)])
                sem_ext = np.concatenate([sem_zoom, np.array(boundary_sem, dtype=float)])
                order = np.argsort(x_ext)
                x_zoom = x_ext[order]
                mean_zoom = mean_ext[order]
                sem_zoom = sem_ext[order]
            if x_zoom.size == 0:
                continue
            ax.plot(x_zoom, mean_zoom, color=color, alpha=1.0)
            ax.fill_between(x_zoom, mean_zoom - sem_zoom, mean_zoom + sem_zoom, alpha=0.2, color=color)
            cur_min = float((mean_zoom - sem_zoom).min())
            cur_max = float((mean_zoom + sem_zoom).max())
            zoom_ymin = cur_min if zoom_ymin is None else min(zoom_ymin, cur_min)
            zoom_ymax = cur_max if zoom_ymax is None else max(zoom_ymax, cur_max)
        ax.set_xlim(emin, emax)
        if zoom_ymin is not None and zoom_ymax is not None:
            pad = max(0.05 * (zoom_ymax - zoom_ymin), 0.02)
            ax.set_ylim(zoom_ymin - pad, zoom_ymax + pad)
        ax.set_ylabel(ylabel)
        ax.tick_params(axis="y", labelleft=True)

    _plot_full(train_df_full, ax_train_full, "Train Loss")
    _plot_full(val_df_full, ax_val_full, "Val Loss")
    _plot_zoom(train_df_zoom, ax_train_zoom, "Train Loss")
    _plot_zoom(val_df_zoom, ax_val_zoom, "Val Loss")

    # The zoomed panels always use the normalized epoch coordinate, whatever use_epoch_x is.
    ax_train_zoom.set_xlabel("Epoch")
    ax_val_zoom.set_xlabel("Epoch")

    # Merge legend entries from both full panels; either may be hidden (no data)
    # or missing a variant. Entries are listed in canonical order.
    legend_entries = {}
    for full_ax in (ax_train_full, ax_val_full):
        for handle, label in zip(*full_ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    labels = [v for v in variants_full if v in legend_entries]
    if labels:
        fig.legend(
            [legend_entries[label] for label in labels],
            labels,
            title="Optimizer Variant",
            loc="upper center",
            bbox_to_anchor=(0.5, 0.97),
            ncol=min(3, len(labels)),
            fontsize=8,
        )

    fig.suptitle("Train and Validation Loss by Optimizer Variant", y=0.99, fontsize=16, fontweight="bold")
    fig.tight_layout(rect=(0, 0, 1, 0.93))
    plt.show()
    return fig, axes


@contextmanager
def _temporary_variants(temp_variants):
    """Temporarily replace the module-level CANONICAL_VARIANTS list.

    plot_full_and_zoomed_variants reads CANONICAL_VARIANTS when called. Plots
    made inside the ``with`` block therefore only include (and color) the given
    variants. The previous list is restored on exit, even if the body raises.
    """
    global CANONICAL_VARIANTS
    saved = CANONICAL_VARIANTS
    CANONICAL_VARIANTS = list(temp_variants)
    try:
        yield
    finally:
        CANONICAL_VARIANTS = saved


In [None]:

# Default full + zoomed variants plot (every variant, including the split-QKV runs)
fig, axes = plot_full_and_zoomed_variants(
    sweep_ids=VARIANT_SWEEP_IDS,
    zoom_window=(0.95, 1.0),
    use_epoch_x=True,
    epoch_max=1.0,
    use_cache=True,
    include_split=True,
)

# Paper-specific subsets: each one temporarily narrows CANONICAL_VARIANTS for a single figure.
_base_variants = [v for v in CANONICAL_VARIANTS if v not in ("Muon+PE (cheap)", "Muon+PE (split)")]
_pe_and_cheap = ["Muon+PE", "Muon+PE (cheap)"]
_pe_family = ["Muon+PE", "Muon+PE(VO/FFN)", "Muon+PE (split)"]

for _subset, _include_split in (
    (_base_variants, False),
    (_pe_and_cheap, True),
    (_pe_family, True),
):
    with _temporary_variants(_subset):
        plot_full_and_zoomed_variants(zoom_window=(0.95, 1.0), epoch_max=1.0, include_split=_include_split)
