This code is to analyze the dynamics of learning across multiple runs with varying hidden state initializations and neuronal activations. It loads runs, computes and aggregates per-run metrics, plots and saves mean/std loss across runs, aggregates metrics timeseries and gradients.

In [2]:
from pathlib import Path
from typing import Optional, Tuple, Dict, List
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# CONFIG
data_dir = Path("../data/Ns100_SeqN100/")
model_root = Path("../Elman_SGD/Remap_predloss/N100T100/")

hidden_weights_inits = [
    "he",
    "shift",
    "cyclic-shift",
    "shift",
    "cmh",
    "mh",
    "ctridiag",
    "tridiag",
    "orthog",
]
input_types = ["gaussian", "onehot", "khot", "small-gaussian"]

SINGLE_DIR = "single-run"
MULTIRUNS_DIR = "multiruns"
RUN_PREFIX = "run_"
MODEL_FNAME = "Ns100_SeqN100_predloss_full.pth.tar"
HIDDEN_WEIGHTS_SUBDIR = "hidden-weights"

In [None]:
def _iter_multirun_files(base_dir: Path):
    """Yield (run_id, path) pairs for each run file found under multiruns/run_XX."""
    multiruns_dir = base_dir / MULTIRUNS_DIR
    if not multiruns_dir.exists():
        print(f"[WARN] Multirun dir does not exist: {multiruns_dir}")
        return
    for run_dir in sorted(multiruns_dir.glob(f"{RUN_PREFIX}*")):
        path = run_dir / MODEL_FNAME
        if path.exists():
            run_id = run_dir.name.replace(
                RUN_PREFIX, "", 1
            )  # Extract run ID ('00' from 'run_00')
            yield run_id, path

In [16]:
def _load_torch(p):
    """Load torch file; returns None if missing or corrupt."""
    try:
        return torch.load(p)
    except Exception as e:
        print(f"[WARN] Could not load {p}: {e}")
        return None

In [None]:
def _extract_loss_series(ckpt) -> Optional[List[float]]:
    """Extract loss series from checkpoint dict. Returns None if not found."""
    if ckpt is None:
        return None
    if "loss" in ckpt:
        return [float(x) for x in ckpt["loss"]]
    else:
        print("[WARN] No loss series found in checkpoint.")
        return None

In [None]:
def _metrics_from_loss(loss: Optional[List[float]]) -> Optional[Dict[str, float]]:
    """Given a loss list, compute final loss, best loss and best_epoch (index)."""
    if not loss:
        return None
    final_loss = float(loss[-1])
    best_epoch = int(np.argmin(loss))
    best_loss = float(loss[best_epoch])
    # auc (lower is better); trapezoidal rule
    auc = float(np.trapz(loss, dx=1.0))
    # time-to-110% of best (how fast it gets close to best)
    threshold = 1.1 * best_loss
    t110 = int(next((i for i, v in enumerate(loss) if v <= threshold), len(loss) - 1))
    return {
        "final_loss": final_loss,
        "best_loss": best_loss,
        "best_epoch": best_epoch,
        "loss_auc": auc,
        "time_to_110pct_best": t110,
    }

In [None]:
def _extract_metrics_list(ckpt) -> Optional[List[Dict]]:
    """Extract metrics list from checkpoint dict. Returns None if not found."""
    if ckpt is None:
        return None
    # save metric as list of dicts (per recorded epoch)
    m = ckpt.get("metrics", None)
    if isinstance(m, list) and (len(m) == 0 or isinstance(m[0], dict)):
        return m
    return None

In [None]:
def _metrics_df_from_list(metrics_list: List[Dict], run_id: str) -> pd.DataFrame:
    """Convert list of metrics dicts to a DataFrame, adding run_id column."""
    if not metrics_list:
        return pd.DataFrame()
    df = pd.DataFrame(metrics_list)
    df["run_id"] = run_id
    return df

In [None]:
def _extract_grad_list(ckpt) -> Optional[List[Dict]]:
    """Extract gradient norms list from checkpoint dict. Returns None if not found."""
    if ckpt is None:
        return None
    g = ckpt.get("grad_list", None)
    if isinstance(g, list):
        return g
    return None

In [None]:
def _reduce_grad_snapshot_paramwise(d: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Reduce a single gradient snapshot (param -> stats dict) into global scalars. Keeps robust, comparable summaries"""
    if not d:
        return {}
    keys = ["mean", "std", "l2_norm", "mean_sq", "max_abs", "sparsity"]
    out = {f"grad_{k}_sum": 0.0 for k in keys}
    out.update({f"grad_{k}_mean": 0.0 for k in keys})
    out.update({f"grad_{k}_max": float("-inf") for k in keys})
    count = 0
    for stats in d.values():
        count += 1
        for k in keys:
            v = float(stats.get(k, 0.0))
            out[f"grad_{k}_sum"] += v
            out[f"grad_{k}_max"] = max(out[f"grad_{k}_max"], v)
        if count > 0:
            for k in keys:
                out[f"grad_{k}_mean"] = out[f"grad_{k}_sum"] / count
    return out

In [None]:
def _extract_history(ckpt) -> Optional[Dict[str, List]]:
    """Returns dict with keys present in history: 'epoch', 'grad_norm', 'loss', etc. Only keeps the list-like fields of equal length to 'epoch'."""
    if ckpt is None:
        return None
    history = ckpt.get("history", None)
    if not history or "epoch" not in history or not isinstance(history["epoch"], list):
        return None
    L = len(history["epoch"])
    out = {"epoch": list(map(int, history["epoch"]))}
    for k, v in history.items():
        if k == "epoch":
            continue
        if isinstance(v, list) and len(v) == L:
            out[k] = list(v)
    return out

In [None]:
def _attach_epoch_to_list(list: List[Dict], epoch_list=None) -> Optional[pd.DataFrame]:
    """Given a list of dicts (e.g. grad snapshots), attach epoch number if available."""
    if not list:
        return None
    df = pd.DataFrame(list)
    if epoch_list and len(epoch_list) == len(list):
        df["epoch"] = epoch_list
    else:
        df["epoch"] = range(len(list))
    return df

In [85]:
type(ckpt["history"])

dict

In [None]:
hidden_init = hidden_weights_inits[0]
input_type = input_types[0]


def collect_for_setting(hidden_init: str, input_type: str):
    """Collect data for a given (hidden_init, input_type) setting.

    Returns:
        per_run_rows: List of dicts with per-run summary metrics
        per_run_timeseries: Dict with keys:
            "losses": List of loss series (list of lists)
            "metrics_df_list": List of DataFrames with metrics time series
            "grad_df_list": List of DataFrames with gradient norms time series
    """
    base = model_root / hidden_init / input_type
    per_run_rows = []
    losses_all = []
    metrics_df_list = []
    grad_df_list = []

    for run_id, p in _iter_multirun_files(base):
        if run_id == "00":
            print(f"Run {run_id}: {p}")
            # Load checkpoint
            ckpt = _load_torch(p)
            print(f"  Keys: {list(ckpt.keys())}")

            # Extract loss series
            loss_series = _extract_loss_series(ckpt)
            if loss_series:
                losses_all.append(loss_series)

            # Load loss metrics from loss_series
            m = _metrics_from_loss(loss_series)
            if m:
                per_run_rows.append(
                    {
                        "hidden_init": hidden_init,
                        "input_type": input_type,
                        "run_kind": "multirun",
                        "run_id": run_id,
                        "path": str(p),
                        **m,
                    }
                )

            # Get metrics time series (over recorded epochs)
            mlist = _extract_metrics_list(ckpt)
            if mlist:
                metrics_df_list.append(_metrics_df_from_list(mlist, run_id))

            # Get gradient norms time series (over recorded epochs)
            glist = _extract_grad_list(ckpt)
            if glist and isinstance(glist[0], dict):
                reduced = [_reduce_grad_snapshot_paramwise(snap) for snap in glist]
                gdf = _attach_epoch_to_list(
                    reduced, epoch_list=ckpt.get("history", {}).get("epoch", None)
                )
                if gdf is not None:
                    gdf["run_id"] = run_id
                    grad_df_list.append(gdf)

            # Get history (epoch, grad_norm, loss, etc.)
            history = _extract_history(ckpt)
            if history:
                hist_df = pd.DataFrame(history)
                hist_df["run_id"] = run_id
        return per_run_rows, {
            "losses": losses_all,
            "metrics_df_list": metrics_df_list,
            "grad_df_list": grad_df_list,
            "history_df": hist_df,
        }

In [None]:
h_inits = hidden_weights_inits[:1]
in_types = input_types[:1]
# h_inits = h_inits or hidden_weight_inits
# in_types = in_types or input_types

all_rows = []
ts_bucket = {}  # (hidden_init, input_type) -> per-run timeseries dict

# Get per_run_row, losses, metrics, and gradients for each (hidden_init, input_type) setting
for h in h_inits:
    for it in in_types:
        print(f"Collecting for (hidden_init={h}, input_type={it})")
        rows, ts = collect_for_setting(h, it)
        if rows:
            all_rows.extend(rows)
        ts_bucket[(h, it)] = ts
per_run_df = (
    pd.DataFrame(all_rows)
    if all_rows
    else pd.DataFrame(
        columns=[
            "hidden_init",
            "input_type",
            "run_kind",
            "run_id",
            "path",
            "final_loss",
            "best_loss",
            "best_epoch",
            "loss_auc",
            "time_to_110pct_best",
        ]
    )
)

# Aggregate over multiruns (per setting)
agg_rows = []
if not per_run_df.empty:
    for (h, it), group in per_run_df.groupby(["hidden_init", "input_type"]):
        g_multi = group[group["run_kind"] == "multirun"]
        if g_multi.empty:
            agg_rows.append(
                {
                    "hidden_init": h,
                    "input_type": it,
                    "run_kind": "multirun",
                    "num_runs": 0,
                    "final_loss_mean": np.nan,
                    "final_loss_std": np.nan,
                    "best_loss_mean": np.nan,
                    "best_loss_std": np.nan,
                    "best_epoch_mean": np.nan,
                    "best_epoch_std": np.nan,
                    "loss_auc_mean": np.nan,
                    "loss_auc_std": np.nan,
                    "time_to_110pct_best_mean": np.nan,
                    "time_to_110pct_best_mean": np.nan,
                }
            )
        else:

            def s(col):
                return (
                    float(g_multi[col].mean()),
                    float(g_multi[col].std(ddof=1)) if g_multi.shape[0] > 1 else 0.0,
                )

            (final_loss_mean, final_loss_std) = s("final_loss")
            (best_loss_mean, best_loss_std) = s("best_loss")
            (best_epoch_mean, best_epoch_std) = s("best_epoch")
            (loss_auc_mean, loss_auc_std) = s("loss_auc")
            (t110_mean, t110_std) = s("time_to_110pct_best")
            agg_rows.append(
                {
                    "hidden_init": h,
                    "input_type": it,
                    "run_kind": "multirun",
                    "num_runs": g_multi.shape[0],
                    "final_loss_mean": final_loss_mean,
                    "final_loss_std": final_loss_std,
                    "best_loss_mean": best_loss_mean,
                    "best_loss_std": best_loss_std,
                    "best_epoch_mean": best_epoch_mean,
                    "best_epoch_std": best_epoch_std,
                    "loss_auc_mean": loss_auc_mean,
                    "loss_auc_std": loss_auc_std,
                    "time_to_110pct_best_mean": t110_mean,
                    "time_to_110pct_best_std": t110_std,
                }
            )
agg_df = pd.DataFrame(agg_rows)
# return per_run_df, agg_df, ts_bucket

Collecting for (hidden_init=he, input_type=gaussian)
Run 00: ../Elman_SGD/Remap_predloss/N100T100/he/gaussian/multiruns/run_00/Ns100_SeqN100_predloss_full.pth.tar
  Keys: ['state_dict', 'loss', 'history', 'grad_list', 'metrics', 'rng_init', 'init_hidden']


In [92]:
per_run_df

Unnamed: 0,hidden_init,input_type,run_kind,run_id,path,final_loss,best_loss,best_epoch,loss_auc,time_to_110pct_best
0,he,gaussian,multirun,0,../Elman_SGD/Remap_predloss/N100T100/he/gaussi...,0.004714,0.004714,49999,1128.609213,44661


In [93]:
agg_df

Unnamed: 0,hidden_init,input_type,run_kind,num_runs,final_loss_mean,final_loss_std,best_loss_mean,best_loss_std,best_epoch_mean,best_epoch_std,loss_auc_mean,loss_auc_std,time_to_110pct_best_mean,time_to_110pct_best_std
0,he,gaussian,multirun,1,0.004714,0.0,0.004714,0.0,49999.0,0.0,1128.609213,0.0,44661.0,0.0


In [98]:
ts_bucket.keys()

dict_keys([('he', 'gaussian')])

In [None]:
ts_bucket["he", "gaussian"].keys()

dict_keys(['losses', 'metrics_df_list', 'grad_df_list', 'history_df'])

In [104]:
print(ts_bucket["he", "gaussian"]["losses"][0][:10])
len(ts_bucket["he", "gaussian"]["losses"][0])

[0.17805549502372742, 0.17802384495735168, 0.17799217998981476, 0.17796054482460022, 0.17792890965938568, 0.17789728939533234, 0.177865669131279, 0.17783407866954803, 0.17780247330665588, 0.17777088284492493]


50000

In [None]:
print(ts_bucket["he", "gaussian"]["losses"][0][:10])
len(ts_bucket["he", "gaussian"]["losses"][0])

In [None]:
print(len(ts_bucket["he", "gaussian"]["metrics_df_list"][0]))
display(ts_bucket["he", "gaussian"]["metrics_df_list"][0].head())

50


Unnamed: 0,epoch,loss,loss_batch_mean,loss_batch_std,frob,drift_from_init,spectral_radius,spectral_norm,min_singular,cond_num,orth_err,w_max_abs,w_sparsity,act_mean,act_std,tanh_sat,run_id
0,0,0.178055,0.178055,0.0,5.774087,0.000262,0.61031,0.350312,-0.393485,-0.89028,7.434097,0.099981,0.0,0.280204,0.375673,0.0,0
1,1000,0.149493,0.149493,0.0,5.780917,0.245461,0.60999,0.373429,-0.395207,-0.944895,7.43145,0.107955,0.0,0.274906,0.385883,0.0,0
2,2000,0.123094,0.123094,0.0,5.806212,0.49052,0.629674,0.342245,-0.404419,-0.846265,7.422012,0.114865,0.0,0.264183,0.428688,0.0,0
3,3000,0.096362,0.096362,0.0,5.849654,0.755729,0.743869,0.372078,-0.447476,-0.831504,7.417208,0.121608,0.0,0.243764,0.498146,0.0,0
4,4000,0.072888,0.072888,0.0,5.902051,1.003481,0.93106,0.456371,-0.37712,-1.210149,7.432401,0.126674,0.0,0.220508,0.568742,0.0,0


In [109]:
print(len(ts_bucket["he", "gaussian"]["grad_df_list"][0]))
display(ts_bucket["he", "gaussian"]["grad_df_list"][0].head())

50


Unnamed: 0,grad_mean_sum,grad_std_sum,grad_l2_norm_sum,grad_mean_sq_sum,grad_max_abs_sum,grad_sparsity_sum,grad_mean_mean,grad_std_mean,grad_l2_norm_mean,grad_mean_sq_mean,grad_max_abs_mean,grad_sparsity_mean,grad_mean_max,grad_std_max,grad_l2_norm_max,grad_mean_sq_max,grad_max_abs_max,grad_sparsity_max,epoch,run_id
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1000,0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2000,0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3000,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4000,0


In [None]:
# print(len(ts_bucket["he", "gaussian"]["history_df"][0]))
# display(ts_bucket["he", "gaussian"]["history_df"][0].head())

In [113]:
ts_bucket["he", "gaussian"]["history_df"]

KeyboardInterrupt: 