In [None]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch import nn
import matplotlib.pyplot as plt
from IPython.display import display
from scipy.optimize import curve_fit

# ---- Project-local imports (make ../scripts and ../utility importable) ----
sys.path.append("../scripts")
sys.path.append("../utility")
from network import KoopmanNet

# ============================
# Minimal configuration
# ============================
# Experiment name: the log is read from ../log/<PROJECT_NAME>/ and all outputs go to ./<PROJECT_NAME>/.
PROJECT_NAME = "Sep_21"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Environments to evaluate (also the iteration order for tables and plots).
ENVS = [
    "DampingPendulum",
    "Franka",
    "DoublePendulum",
    "Polynomial",
    "Kinova",
    "G1",
    "Go2",
]
# Only log rows with one of these seeds are evaluated.
SEEDS = [17382, 76849, 20965, 84902, 51194]

# Only log rows whose train_samples is in this list are evaluated.
TRAIN_SAMPLES = [1000, 4000, 16000, 64000, 140000]

# Polynomial env: only runs/datasets with this m value are used.
M_POLY = 100

# Control-input width per env; envs missing here (e.g. Polynomial) are looked up with default 0.
U_DIM = {
    "Franka": 7,
    "DoublePendulum": 2,
    "DampingPendulum": 1,
    "G1": 23,
    "Go2": 12,
    "Kinova": 7,
}

# Normalization tag in dataset filenames; envs missing here use "nonorm".
NORMALIZE = {"G1": "norm", "Go2": "norm"}

# Per-step discount of the rollout-error weights (per-env override, else the default).
GAMMA_DEFAULT = 0.8
GAMMA_OVERRIDES = {"G1": 0.99, "Go2": 0.99}

# Allowed z / state_dim multipliers; encode dims are snapped to the nearest one.
REL_MULT_TARGETS = [1, 2, 4, 8, 16]

# ============================
# Helpers
# ============================
def gmean(vals, eps: float = 1e-12) -> float:
    """Return the geometric mean of ``vals``.

    Each value is floored at ``eps`` before taking logs, so zeros (or
    negatives) contribute ``log(eps)`` instead of ``-inf``/NaN.
    """
    floored = np.maximum(np.asarray(list(vals), dtype=float), eps)
    return float(np.exp(np.log(floored).mean()))

def env_has_control(env: str) -> bool:
    """True when ``env`` has a positive control width in ``U_DIM`` (unlisted envs count as 0)."""
    u_width = U_DIM.get(env, 0)
    return u_width > 0

def find_dataset_path(env: str, ksteps: int, m_val: int) -> str | None:
    """Return the FIRST dataset path found for this env over its adjusted train sizes."""
    norm = NORMALIZE.get(env, "nonorm")
    if env == "Polynomial":
        path = os.path.join(
            "..", "data", "datasets",
            f"dataset_{env}_{norm}_m_{m_val}_Ktrain_140000_Kval_20000_Ktest_20000_Ksteps_{ksteps}.pt",
        )
    else:
        path = os.path.join(
            "..", "data", "datasets",
            f"dataset_{env}_{norm}_Ktrain_140000_Kval_20000_Ktest_20000_Ksteps_{ksteps}.pt",
        )
    if os.path.exists(path):
        return path
    else:
        return None


def evaluate_model(model, data, u_dim: int | None, gamma: float, state_dim: int, device) -> tuple[float, float]:
    model.eval()
    with torch.no_grad():
        T = data.shape[0]
        if u_dim is None or u_dim == 0:
            X = model.encode(data[0].to(device))
        else:
            X = model.encode(data[0, :, u_dim:].to(device))  # encode state-only slice
        z0 = X[:, state_dim:]

        # Geometrically weighted rollout MSE on states
        beta, beta_sum = 1.0, 0.0
        wloss = 0.0
        for t in range(T - 1):
            if u_dim is None or u_dim == 0:
                X      = model.forward(X, None)
                target = data[t + 1].to(device)
            else:
                X      = model.forward(X, data[t, :, :u_dim].to(device))
                target = data[t + 1, :, u_dim:].to(device)
            err   = nn.MSELoss()(X[:, :state_dim], target)
            wloss = wloss + beta * err
            beta_sum += beta
            beta *= gamma
        wloss = wloss / max(beta_sum, 1e-12)

        # Off-diagonal covariance loss on initial encoding (normalized)
        zc = z0 - z0.mean(dim=0, keepdim=True)
        C  = (zc.t() @ zc) / max(zc.size(0) - 1, 1)
        off = C - torch.diag(torch.diag(C))
        cov = torch.norm(off, p='fro') ** 2
        zdim = max(X.shape[1] - state_dim, 1)
        cov_norm = cov.item() / (zdim * (zdim - 1)) if zdim > 1 else cov.item()
        return float(wloss.item()), float(cov_norm)


def build_model_from_checkpoint(chkpt: dict, state_dim: int, u_dim_hint: int | None, device):
    layers = chkpt["layer"]
    sd     = chkpt["model"]

    # Nkoop from lA
    if "lA.weight" not in sd or sd["lA.weight"].ndim != 2:
        raise RuntimeError("Checkpoint missing a square lA.weight for Nkoop inference")
    Nkoop = int(sd["lA.weight"].shape[0])

    # u_dim from lB if present
    u_dim_trained = int(sd["lB.weight"].shape[1]) if ("lB.weight" in sd and sd["lB.weight"].ndim == 2) else None#int(u_dim_hint or 0)

    model = KoopmanNet(layers, Nkoop, u_dim_trained).to(device)
    model.load_state_dict(sd, strict=True)
    enc_dim = Nkoop - state_dim
    return model, u_dim_trained, enc_dim, Nkoop


# ============================
# 1) Load CSV log
# ============================
log_csv = os.path.join("..", "log", PROJECT_NAME, "koopman_results_log.csv")
assert os.path.exists(log_csv), f"CSV log not found: {log_csv}"
log = pd.read_csv(log_csv)

# Coerce dtypes of the columns used for filtering and keying below.
# Columns that are absent from this log are skipped.
for col, dtype in [
    ("use_covariance_loss", int),
    ("use_control_loss", int),
    ("env_name", str),
    ("seed", int),
    ("model_path", str),
]:
    if col in log.columns:
        log[col] = log[col].astype(dtype)
assert "train_samples" in log.columns, "Log must include 'train_samples'"

# ============================
# 2) Infer state_dim per env from any available dataset
# ============================
# State width per env, or None when no dataset file exists for it.
env_state_dim: dict[str, int | None] = {}
for env in ENVS:
    ksteps = 1 if env == "Polynomial" else 15
    ds_path = find_dataset_path(env, ksteps, M_POLY)
    if ds_path is not None:
        # NOTE: weights_only=False fully unpickles the file; only load trusted datasets.
        data = torch.load(ds_path, weights_only=False)
        # State width = total feature width minus the configured control columns.
        env_state_dim[env] = int(data["Ktest_data"].shape[2]) - U_DIM.get(env, 0)
    else:
        env_state_dim[env] = None

# ============================
# 3) Filter rows (lean rules)
# ============================
def row_ok(r) -> bool:
    """Decide whether a log row belongs to this study.

    A row is kept when all of the following hold:
    - its env, seed and train size are in ENVS / SEEDS / TRAIN_SAMPLES;
    - for Polynomial rows, ``m`` parses as ``M_POLY``;
    - the optional loss-flag columns, when present, hold 0 or 1.
    """
    env = r.get("env_name")
    in_scope = (
        env in ENVS
        and r.get("seed") in SEEDS
        and r.get("train_samples") in set(TRAIN_SAMPLES)
    )
    if not in_scope:
        return False

    if env == "Polynomial":
        # Missing / unparsable m disqualifies the row.
        try:
            m_matches = int(r.get("m")) == M_POLY
        except Exception:
            return False
        if not m_matches:
            return False

    for flag in ("use_covariance_loss", "use_control_loss"):
        if flag in r.index and r.get(flag) not in {0, 1}:
            return False
    return True

keep_mask = log.apply(row_ok, axis=1)
filtered = log[keep_mask].copy()
if filtered.empty:
    # Nothing to evaluate: write a header-only summary CSV and stop.
    print("[Error] No models after filtering. Check envs/train_samples/seeds.")
    out_dir = PROJECT_NAME
    os.makedirs(out_dir, exist_ok=True)
    summary_columns = [
        "Environment", "TrainSamples", "EncodeDim", "UseCovLoss", "UseControlLoss",
        "WeightedError", "NormalizedCovLoss",
    ]
    pd.DataFrame(columns=summary_columns).to_csv(
        os.path.join(out_dir, "evaluation_summary.csv"), index=False
    )
    raise SystemExit(0)

# ============================
# 4) Index checkpoints by (env, Ntrain, enc_dim_logged, cov, ctrl)
# ============================
# Group checkpoints by configuration:
#   (env, train_samples, logged encode_dim, cov-loss flag, control-loss flag) -> [(seed, model_path), ...]
# Missing optional columns fall back to encode_dim=-1 and flags=0.
index = {}
for _, row in filtered.iterrows():
    cfg = (
        row["env_name"],
        int(row["train_samples"]),
        int(row.get("encode_dim", -1)),
        int(row.get("use_covariance_loss", 0)),
        int(row.get("use_control_loss", 0)),
    )
    if cfg not in index:
        index[cfg] = []
    index[cfg].append((int(row["seed"]), row["model_path"]))
# Side table stored under a sentinel string key. It is filled during evaluation
# with cfg -> encode dim inferred from the checkpoint.
index["_z_actual"] = {}

# ============================
# 5) Evaluate & aggregate (geometric mean across seeds)
# ============================
# For each env: load its test split once, evaluate every checkpoint group of
# that env, and aggregate the per-seed metrics with a geometric mean.
agg = {}
for env in ENVS:
    ksteps = 1 if env == "Polynomial" else 15
    gamma  = GAMMA_OVERRIDES.get(env, GAMMA_DEFAULT)

    ds_path = find_dataset_path(env, ksteps, M_POLY)
    if ds_path is None:
        print(f"[Skip] No dataset found for {env}.")
        continue

    # NOTE: weights_only=False fully unpickles the file; only load trusted datasets.
    d = torch.load(ds_path, weights_only=False)
    # Ktest_data is a NumPy array; the whole test split is moved to DEVICE at once.
    test = torch.from_numpy(d["Ktest_data"]).float().to(DEVICE)
    u_hint = U_DIM.get(env, 0)
    # Features beyond the configured control width are treated as state.
    state_dim = int(test.shape[2]) - int(u_hint)

    # Scan all checkpoint groups; skip the sentinel entry and other envs' groups.
    for key, ckpts in list(index.items()):
        if key == "_z_actual" or key[0] != env:
            continue
        _, N, enc_logged, cov_reg, ctrl_reg = key

        w_errs, covs = [], []
        z_recorded = False
        for seed, path in ckpts:
            # Skip entries without a usable checkpoint file.
            if not isinstance(path, str) or not os.path.exists(path):
                print(f"[Warn] Missing checkpoint: {path}")
                continue
            try:
                chkpt = torch.load(path, map_location=DEVICE, weights_only=False)
                # NOTE(review): u_eval comes from the checkpoint (None if it has no lB) and may
                # disagree with u_hint, while state_dim is always derived from u_hint.
                model, u_eval, z_actual, _ = build_model_from_checkpoint(chkpt, state_dim, u_hint, DEVICE)
                # Record the checkpoint-derived encode dim once per group (first seed that loads).
                # It is recorded even if the evaluation below then fails.
                if not z_recorded:
                    index["_z_actual"][key] = z_actual
                    z_recorded = True
                w, c = evaluate_model(model, test, u_eval, gamma, state_dim, DEVICE)
                w_errs.append(w); covs.append(c)
            except Exception as e:
                # Best-effort: report the failure and continue with the next seed.
                print(f"[Error] Load/eval failed for {path}: {e}")
                continue
        if not w_errs:
            continue
        # Despite the "_mean" suffix these are geometric means over the seeds that evaluated.
        # gmean already floors values at 1e-12, so the +1e-12 offset only matters for near-zero covariances.
        agg[key] = {
            "WeightedError_mean":     gmean(w_errs),
            "NormalizedCovLoss_mean": gmean(np.array(covs) + 1e-12),
        }

# ============================
# 6) Save summary CSV
# ============================
# One summary row per configuration. EncodeDim prefers the encode dim inferred
# from the checkpoint and falls back to the logged value.
z_actual_by_key = index["_z_actual"]
rows = [
    {
        "Environment": key[0],
        "TrainSamples": key[1],
        "EncodeDim": z_actual_by_key.get(key, key[2]),
        "UseCovLoss": key[3],
        "UseControlLoss": key[4],
        "WeightedError": met["WeightedError_mean"],
        "NormalizedCovLoss": met["NormalizedCovLoss_mean"],
    }
    for key, met in agg.items()
]

df = pd.DataFrame(rows)
os.makedirs(PROJECT_NAME, exist_ok=True)
out_csv = os.path.join(PROJECT_NAME, "evaluation_summary.csv")
df.to_csv(out_csv, index=False)
print(f"Saved summary to {out_csv}")
if not df.empty:
    display(df)

# ============================
# 7) Per-env plots (cov=1 and ctrl=1 if env has control)
# ============================
if not df.empty:
    def _plot_error_vs(sub: pd.DataFrame, x_col: str, marker: str, x_label: str,
                       title: str, out_path: str) -> None:
        """Plot the geometric-mean WeightedError per ``x_col`` value (log y-axis) and save it."""
        curve = (
            sub.groupby(x_col, as_index=False)
               .agg(WeightedError_gmean=("WeightedError", lambda s: gmean(s.values)))
               .sort_values(x_col)
        )
        if curve.empty:
            return
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(curve[x_col], curve["WeightedError_gmean"], marker=marker)
        ax.set_xscale("linear")
        ax.set_yscale("log")
        ax.set_xlabel(x_label)
        ax.set_ylabel("Weighted Prediction Error (MSE, geom. mean)")
        ax.set_title(title)
        ax.grid(True, which="both", ls="--", alpha=0.6)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
        plt.close(fig)
        print(f"Saved: {out_path}")

    for env in ENVS:
        # "Strict" subset: cov loss on, plus control loss on for envs with control.
        mask = (df.Environment == env) & (df.UseCovLoss == 1)
        if env_has_control(env) and ("UseControlLoss" in df.columns):
            mask &= df.UseControlLoss == 1
            tag = "(cov=1, ctrl=1)"
        else:
            tag = "(cov=1)"
        sub = df[mask].copy()
        if sub.empty:
            print(f"[Skip] No eligible rows for {env} {tag}")
            continue

        out_dir = os.path.join(PROJECT_NAME, env)
        os.makedirs(out_dir, exist_ok=True)

        # A) TrainSamples vs error
        _plot_error_vs(
            sub, "TrainSamples", "o", "Train Samples (Ktrain)",
            f"{env} — TrainSamples vs Error {tag}",
            os.path.join(out_dir, f"{env}_TrainSamples_vs_Error.png"),
        )
        # B) EncodeDim vs error
        _plot_error_vs(
            sub, "EncodeDim", "s", "Encode Dimension (z)",
            f"{env} — EncodeDim vs Error {tag}",
            os.path.join(out_dir, f"{env}_EncodeDim_vs_Error.png"),
        )

# ============================
# 8) Combined plot (normalized error vs relative multiplier)
# ============================
if not df.empty:
    def nearest_rel_mult(env: str, z_abs: float) -> float:
        """Snap z_abs / state_dim(env) to the closest value in REL_MULT_TARGETS.

        Returns NaN when the env's state dimension is unknown (no dataset was
        found when env_state_dim was built) or is non-positive.
        """
        st = env_state_dim.get(env, None)
        if not st or st <= 0:
            return float("nan")
        r = float(z_abs) / float(st)
        # min() keeps the first minimum, so ties go to the smaller target.
        return float(min(REL_MULT_TARGETS, key=lambda m: abs(m - r)))

    rows = []
    for env in ENVS:
        # Same "strict" subset as the per-env plots: cov loss on, plus control loss on for envs with control.
        if env_has_control(env) and ("UseControlLoss" in df.columns):
            sub = df[(df.Environment == env) & (df.UseCovLoss == 1) & (df.UseControlLoss == 1)].copy()
        else:
            sub = df[(df.Environment == env) & (df.UseCovLoss == 1)].copy()
        if sub.empty:
            continue

        sub["RelMult"] = sub["EncodeDim"].apply(lambda z: nearest_rel_mult(env, z))
        sub = sub.replace([np.inf, -np.inf], np.nan).dropna(subset=["RelMult"])
        # Encode dims that snap to the same multiplier are pooled into one geometric mean.
        g = (
            sub.groupby(["Environment", "RelMult"], as_index=False)
               .agg(WeightedError_gmean=("WeightedError", lambda s: gmean(s.values)))
               .sort_values("RelMult")
        )
        if g.empty:
            continue

        # Normalize by error at smallest multiplier
        E0 = g.loc[g["RelMult"].idxmin(), "WeightedError_gmean"]
        g["RelError"] = g["WeightedError_gmean"] / max(E0, 1e-12)

        # Log–log slope
        # Least-squares slope of log(RelError) vs log(RelMult); the same value is stored on every row of g.
        if g.shape[0] >= 2:
            x = np.log(g["RelMult"].to_numpy(dtype=float))
            y = np.log(np.maximum(g["RelError"].to_numpy(dtype=float), 1e-12))
            b1, _ = np.polyfit(x, y, 1)
            g["Slope"] = float(b1)
        else:
            g["Slope"] = float("nan")

        # Noise proxy: mean RelError at the (up to) two LARGEST multipliers.
        # g is sorted by RelMult, so tail(k) picks the largest multipliers, not the largest errors.
        k = min(2, len(g))
        g["NoiseRel"] = float(g.tail(k)["RelError"].mean())
        rows.append(g)

    if rows:
        GG = pd.concat(rows, ignore_index=True)
        fig, ax = plt.subplots(figsize=(8, 6))
        for env in sorted(GG["Environment"].unique()):
            ge = GG[GG.Environment == env].sort_values("RelMult")
            (line,) = ax.plot(ge["RelMult"], ge["RelError"], marker="o", label=f"{env} (slope={ge['Slope'].iloc[0]:.2f})")
            # Dashed noise-proxy level, drawn in the same color as the env's curve.
            ax.hlines(ge["NoiseRel"].iloc[0], ge["RelMult"].min(), ge["RelMult"].max(), linestyles="dashed", alpha=0.4, colors=[line.get_color()])
        # Fallback for Matplotlib versions that only accept the `basex` keyword.
        try:
            ax.set_xscale("log", base=2)
        except TypeError:
            ax.set_xscale("log", basex=2)
        ax.set_xlabel("Relative encode multiplier (z / state_dim, log₂)")
        ax.set_ylabel("Relative prediction error (E / E@min multiplier)")
        ax.set_title("Normalized error vs relative encode multiplier (strict)")
        ax.grid(True, which="both", ls="--", alpha=0.6)
        ax.legend(ncol=2, fontsize=9)
        fig.tight_layout()
        p = os.path.join(PROJECT_NAME, "AllEnvs_RelError_vs_RelMultiplier_STRICT.png")
        plt.savefig(p, dpi=300); plt.close(fig); print(f"Saved combined plot: {p}")

# ============================
# 9) Scaling-law fits per env:  E(D) = A * D^{-alpha} + C
# ============================
if not df.empty:
    def strict_subset(dfx: pd.DataFrame, env: str) -> pd.DataFrame:
        """Rows of ``dfx`` for ``env`` with cov loss on and, for envs with control, control loss on."""
        if env_has_control(env) and ("UseControlLoss" in dfx.columns):
            return dfx[(dfx.Environment == env) & (dfx.UseCovLoss == 1) & (dfx.UseControlLoss == 1)].copy()
        return dfx[(dfx.Environment == env) & (dfx.UseCovLoss == 1)].copy()

    def agg_over_z(dfe: pd.DataFrame) -> pd.DataFrame:
        """Geometric-mean WeightedError per EncodeDim, sorted by EncodeDim."""
        return (
            dfe.groupby("EncodeDim", as_index=False)
               .agg(WeightedError_gmean=("WeightedError", lambda s: gmean(s.values)))
               .sort_values("EncodeDim")
        )

    def scaling_model(D, A, alpha, C):
        """Saturating power law E(D) = A * D**(-alpha) + C."""
        return A * np.power(D, -alpha) + C

    fits = []
    for env in ENVS:
        sub = strict_subset(df, env)
        if sub.empty:
            print(f"[Skip] Scaling-law: no rows for {env}")
            continue
        G = agg_over_z(sub)
        if len(G) < 3:
            print(f"[Skip] Scaling-law: need >=3 points for {env}, have {len(G)}")
            continue

        D = G["EncodeDim"].astype(float).to_numpy()
        E = G["WeightedError_gmean"].astype(float).to_numpy()

        # Initial guesses: C0 = mean error at the two largest encode dims (G is sorted by EncodeDim),
        # A0 = excess of the largest error over C0 (floored at 1e-6).
        C0 = float(G.tail(min(2, len(G)))["WeightedError_gmean"].mean())
        A0 = max(float(E.max() - C0), 1e-6)
        alpha0 = 0.7

        try:
            # Bounds keep A >= 0, C >= 0 and 0 <= alpha <= 4.
            popt, pcov = curve_fit(
                scaling_model, D, E,
                p0=[A0, alpha0, C0],
                bounds=([0.0, 0.0, 0.0], [np.inf, 4.0, np.inf]),
                maxfev=20000,
            )
            A_hat, alpha_hat, C_hat = map(float, popt)
            # Standard errors from the diagonal of pcov (NOTE(review): may be inf if scipy cannot estimate it).
            perr = np.sqrt(np.maximum(np.diag(pcov), 0.0))
            dA, dalpha, dC = map(float, perr)
        except Exception as ex:
            # Fallback: plain power law (C = 0) fitted as a line in log-log space; no standard errors.
            # NOTE(review): this fallback is not guarded itself. If D contains values <= 0,
            # np.log yields -inf/NaN and polyfit can raise, which aborts the whole loop.
            print(f"[Warn] curve_fit failed for {env}: {ex}")
            x = np.log(D); y = np.log(np.maximum(E, 1e-12))
            b1, b0 = np.polyfit(x, y, 1)
            A_hat, alpha_hat, C_hat = float(np.exp(b0)), -float(b1), 0.0
            dA = dalpha = dC = float("nan")

        # R² on the linear error scale.
        E_pred = scaling_model(D, A_hat, alpha_hat, C_hat)
        ss_res = float(np.sum((E - E_pred) ** 2))
        ss_tot = float(np.sum((E - np.mean(E)) ** 2))
        R2_lin = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

        # R² in log space after subtracting the floor C, using only points that stay above C.
        mask = (E > C_hat + 1e-12) & (E_pred > C_hat + 1e-12)
        if np.count_nonzero(mask) >= 2:
            y_log = np.log(E[mask] - C_hat)
            yhat  = np.log(E_pred[mask] - C_hat)
            ss_res_l = float(np.sum((y_log - yhat) ** 2))
            ss_tot_l = float(np.sum((y_log - np.mean(y_log)) ** 2))
            R2_log = 1.0 - ss_res_l / ss_tot_l if ss_tot_l > 0 else float("nan")
        else:
            R2_log = float("nan")

        fits.append({
            "Environment": env,
            "A": A_hat, "A_se": dA,
            "alpha": alpha_hat, "alpha_se": dalpha,
            "C": C_hat, "C_se": dC,
            "R2_linear": R2_lin, "R2_log_after_C": R2_log,
            "n_points": int(len(D)),
            "D_min": float(np.min(D)), "D_max": float(np.max(D)),
        })

        # Plot overlay
        out_dir = os.path.join(PROJECT_NAME, env)
        os.makedirs(out_dir, exist_ok=True)
        # Dense geometric grid over the observed z range, used to draw the fitted curve.
        Dg = np.geomspace(max(1e-6, D.min()), D.max(), 256)
        Eg = scaling_model(Dg, A_hat, alpha_hat, C_hat)

        fig, ax = plt.subplots(figsize=(7, 5))
        ax.plot(D, E, "o", label="geom. mean (per z)")
        ax.plot(Dg, Eg, "-", label=f"fit: A={A_hat:.3g}, α={alpha_hat:.2f}, C={C_hat:.3g}\nR²={R2_lin:.2f} (lin), {R2_log:.2f} (log−C)")
        # NOTE(review): when C_hat == 0 this floor line sits at 0, which the log y-axis cannot display.
        ax.hlines(C_hat, D.min(), D.max(), linestyles="dashed", alpha=0.5, label="noise floor C")
        # Fallback for Matplotlib versions that only accept the `basex` keyword.
        try:
            ax.set_xscale("log", base=2)
        except TypeError:
            ax.set_xscale("log", basex=2)
        ax.set_yscale("log")
        ax.set_xlabel("Encode Dimension (z)")
        ax.set_ylabel("Weighted Prediction Error (MSE)")
        ax.set_title(f"{env} — E(z)=A z^(−α)+C")
        ax.grid(True, which="both", ls="--", alpha=0.6)
        ax.legend(fontsize=9)
        fig.tight_layout()
        p = os.path.join(out_dir, f"{env}_ScalingLawFit.png")
        plt.savefig(p, dpi=300); plt.close(fig); print(f"Saved scaling-law fit plot: {p}")

    if fits:
        df_fits = pd.DataFrame(fits)
        p = os.path.join(PROJECT_NAME, "scaling_law_fits.csv")
        df_fits.to_csv(p, index=False)
        print(f"Saved scaling-law fit table: {p}")
        display(df_fits)
    else:
        print("[Skip] No scaling-law fits produced.")


[Skip] No dataset found for Franka.
[Skip] No dataset found for DoublePendulum.
[Skip] No dataset found for Polynomial.
[Skip] No dataset found for Kinova.
[Skip] No dataset found for G1.
[Skip] No dataset found for Go2.
Saved summary to Sep_21/evaluation_summary.csv


Unnamed: 0,Environment,TrainSamples,EncodeDim,UseCovLoss,UseControlLoss,WeightedError,NormalizedCovLoss
0,DampingPendulum,1000,2,0,0,0.013001,4868574.0


[Skip] No eligible rows for DampingPendulum (cov=1, ctrl=1)
[Skip] No eligible rows for Franka (cov=1, ctrl=1)
[Skip] No eligible rows for DoublePendulum (cov=1, ctrl=1)
[Skip] No eligible rows for Polynomial (cov=1)
[Skip] No eligible rows for Kinova (cov=1, ctrl=1)
[Skip] No eligible rows for G1 (cov=1, ctrl=1)
[Skip] No eligible rows for Go2 (cov=1, ctrl=1)
[Skip] Scaling-law: no rows for DampingPendulum
[Skip] Scaling-law: no rows for Franka
[Skip] Scaling-law: no rows for DoublePendulum
[Skip] Scaling-law: no rows for Polynomial
[Skip] Scaling-law: no rows for Kinova
[Skip] Scaling-law: no rows for G1
[Skip] Scaling-law: no rows for Go2
[Skip] No scaling-law fits produced.


In [None]:
# use all points instead of mean