# BAB Dataset: Linear & Nonlinear Models

Training and simulation of linear, Stribeck-friction, black-box (neural ODE), and hybrid physics + neural-network models.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
import bab_datasets as nod

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")

## 1) Protocol + Explicit Train/Test Split


In [None]:
# Optional workaround for SSL certificate errors on macOS (noted for Python 3.14),
# presumably hit when bab_datasets fetches data over HTTPS.
# NOTE(review): uncommenting the last line below disables TLS certificate
# verification for every HTTPS request made by this process. Prefer pointing
# Python at certifi's CA bundle first, e.g. export SSL_CERT_FILE="$(python -m certifi)".
# import ssl
# import certifi
# ssl._create_default_https_context = ssl._create_unverified_context

# --- Experiment configuration ------------------------------------------------
velMethod = "central"  # forwarded to load_experiment as y_dot_method (velocity estimate)
PROTOCOL_MODE = "split_50_50"  # "split_50_50" or "classic_train_test"
TRAIN_DATASETS_CLASSIC = ["multisine_05"]  # training datasets used only by "classic_train_test"
resample_factor = 50  # forwarded to load_experiment

# Load all preprocessed datasets once (explicitly reused for train/test protocol).
# Each cache entry stores the input u, output y, reference y_ref, velocity
# estimate y_dot, the stacked two-column state y_sim = [y, y_dot], the sampling
# time Ts, and the number of samples N.
datasets_cache = {}
all_datasets = nod.list_experiments()
for ds_name in all_datasets:
    data_ds = nod.load_experiment(
        ds_name,
        preprocess=True,
        plot=False,
        end_idx=None,
        resample_factor=resample_factor,
        zoom_last_n=200,
        y_dot_method=velMethod,
    )
    # The returned object unpacks into (u, y, y_ref, y_dot) and also exposes .sampling_time.
    u_ds, y_ds, y_ref_ds, y_dot_ds = data_ds
    Ts_ds = data_ds.sampling_time
    y_sim_ds = np.column_stack([y_ds, y_dot_ds])
    datasets_cache[ds_name] = {
        "u": u_ds,
        "y": y_ds,
        "y_ref": y_ref_ds,
        "y_dot": y_dot_ds,
        "y_sim": y_sim_ds,
        "Ts": Ts_ds,
        "N": len(u_ds),
    }

# Build explicit train/test index sets per dataset.
# "Core" datasets (multisine / random_steps) are the only ones that can be
# halved for training in split_50_50; every other dataset is test-only there.
core_datasets = [name for name in all_datasets if "multisine" in name or "random_steps" in name]
external_datasets = [name for name in all_datasets if name not in core_datasets]


def _protocol_split(ds_name, n_samples):
    """Return {"train_idx", "test_idx"} integer index arrays for one dataset under PROTOCOL_MODE."""
    no_samples = np.array([], dtype=int)
    if PROTOCOL_MODE == "split_50_50":
        if ds_name not in core_datasets:
            return {"train_idx": no_samples, "test_idx": np.arange(0, n_samples, dtype=int)}
        half = n_samples // 2
        return {
            "train_idx": np.arange(0, half, dtype=int),
            "test_idx": np.arange(half, n_samples, dtype=int),
        }
    if PROTOCOL_MODE == "classic_train_test":
        if ds_name in TRAIN_DATASETS_CLASSIC:
            return {"train_idx": np.arange(0, n_samples, dtype=int), "test_idx": no_samples}
        return {"train_idx": no_samples, "test_idx": np.arange(0, n_samples, dtype=int)}
    raise ValueError(f"Unknown PROTOCOL_MODE: {PROTOCOL_MODE}")


split_map = {name: _protocol_split(name, datasets_cache[name]["N"]) for name in all_datasets}

# Collect (dataset, cache entry, train indices) for every dataset that
# contributes training samples, then concatenate only those partitions.
_train_parts = [
    (name, datasets_cache[name], split_map[name]["train_idx"])
    for name in all_datasets
    if len(split_map[name]["train_idx"]) > 0
]
train_blocks_u = [entry["u"][ix] for _, entry, ix in _train_parts]
train_blocks_y = [entry["y_sim"][ix] for _, entry, ix in _train_parts]
train_block_meta = [(name, len(ix)) for name, _, ix in _train_parts]
Ts_values = [entry["Ts"] for _, entry, _ in _train_parts]

if not train_blocks_u:
    raise RuntimeError("No training samples selected by the current protocol.")

# All training partitions must share one sampling time before they are joined.
Ts = Ts_values[0]
if not np.allclose(Ts_values, Ts, rtol=0, atol=1e-12):
    raise RuntimeError(f"Inconsistent Ts across training datasets: {Ts_values}")

u = np.concatenate(train_blocks_u, axis=0)
y_sim = np.concatenate(train_blocks_y, axis=0)
y, y_dot = y_sim[:, 0], y_sim[:, 1]
y_ref = np.full_like(y, np.nan)
t = np.arange(len(u)) * Ts

# Half-open [start, end) span of each dataset inside the concatenated vectors,
# used later so k-step windows never straddle two datasets.
_segment_ends = np.cumsum([length for _, length in train_block_meta]).tolist()
train_segments = [
    (name, end - length, end)
    for (name, length), end in zip(train_block_meta, _segment_ends)
]

print(f"Protocol mode: {PROTOCOL_MODE}")
print(f"Core datasets: {core_datasets}")
print(f"External datasets (test-only in split_50_50): {external_datasets}")
print(f"Training datasets/segments: {train_block_meta}")
print(f"Total training samples: {len(u)}")

In [None]:
"""
Plotting script for core and external datasets.
Append this to your existing data-loading code (after datasets_cache, split_map,
core_datasets, and external_datasets are defined).
"""

import matplotlib.pyplot as plt
import numpy as np


def plot_dataset_group(ds_names, datasets_cache, split_map, group_title, velMethod):
    """Plot input, output and velocity for each dataset, one row per dataset.

    Columns are ``u``, ``y`` and ``y_dot`` from ``datasets_cache`` against time
    ``arange(N) * Ts``. The train and test index ranges from ``split_map`` are
    shaded blue and red. When ``ds_names`` is empty a message is printed and
    nothing is drawn.
    """
    if len(ds_names) == 0:
        print(f"No datasets in group '{group_title}'.")
        return

    n_rows = len(ds_names)
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 3.5 * n_rows), squeeze=False)
    fig.suptitle(f"{group_title}  (velMethod = '{velMethod}')", fontsize=14, fontweight="bold", y=1.01)

    for row, ds_name in enumerate(ds_names):
        entry = datasets_cache[ds_name]
        time = np.arange(entry["N"]) * entry["Ts"]
        shaded_ranges = (
            (split_map[ds_name]["train_idx"], "blue", "train"),
            (split_map[ds_name]["test_idx"], "red", "test"),
        )
        panels = (
            ("u  (input)", entry["u"], "tab:blue"),
            ("y  (output)", entry["y"], "tab:orange"),
            ("ẏ  (velocity)", entry["y_dot"], "tab:green"),
        )

        for col, (ylabel, signal, line_color) in enumerate(panels):
            ax = axes[row][col]
            ax.plot(time, signal, color=line_color, linewidth=0.7, alpha=0.85)

            # Shade the first..last sample of each non-empty split.
            for idx, shade, tag in shaded_ranges:
                if len(idx) > 0:
                    ax.axvspan(time[idx[0]], time[idx[-1]], alpha=0.08, color=shade, label=tag)

            ax.set_xlabel("Time [s]")
            ax.set_ylabel(ylabel)
            if col == 0:
                ax.set_title(ds_name, fontsize=11, fontweight="bold", loc="left")
            if row == 0:
                ax.legend(fontsize=8, loc="upper right")
            ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()


# Plot the core group (multisine / random_steps) first, then the test-only external group.
for _ds_group, _group_title in (
    (core_datasets, "Core Datasets (multisine / random_steps)"),
    (external_datasets, "External Datasets (test-only)"),
):
    plot_dataset_group(_ds_group, datasets_cache, split_map, _group_title, velMethod)

## 2) Tensor Prep (From Explicit Training Split)


In [None]:
# Move the concatenated training arrays onto the compute device as float32.
t_tensor = torch.as_tensor(t, dtype=torch.float32, device=device)
u_tensor = torch.as_tensor(u, dtype=torch.float32, device=device).reshape(-1, 1)
y_tensor = torch.as_tensor(y_sim, dtype=torch.float32, device=device)

# Window starts for K_STEPS-long rollouts, taken per segment so that no
# minibatch window crosses the boundary between two concatenated datasets.
K_STEPS = 20
_starts_per_segment = [
    np.arange(s0, s1 - K_STEPS, dtype=int)
    for _, s0, s1 in train_segments
    if s1 - s0 > K_STEPS
]
valid_train_start_idx = (
    np.concatenate(_starts_per_segment) if _starts_per_segment else np.array([], dtype=int)
)

if valid_train_start_idx.size == 0:
    raise RuntimeError("No valid k-step windows found in training split. Decrease K_STEPS or increase train split size.")

print(f"Tensor shapes -> t: {tuple(t_tensor.shape)}, u: {tuple(u_tensor.shape)}, y: {tuple(y_tensor.shape)}")
print(f"Valid train starts: {len(valid_train_start_idx)}")

## 3) Model Definitions

In [None]:
class LinearPhysODE(nn.Module):
    """Linear second-order joint model: ``J*thdd + R*thd + K*(th + delta) = Tau*u``.

    ``J, R, K, Tau`` are stored as log-parameters so they stay positive;
    ``delta`` is an unconstrained offset. Before ``forward`` is called, the
    caller sets ``t_series`` (1-D sample times) and ``u_series`` (input samples
    aligned with ``t_series``). ``batch_start_times``, when not None, is added
    to the solver time to give each batch row its absolute time.
    """

    def __init__(self):
        super().__init__()

        def log_param(value):
            return nn.Parameter(torch.tensor(np.log(value), dtype=torch.float32))

        # Registration order is kept stable (affects state_dict / optimizer order).
        self.log_J = log_param(0.1)
        self.log_R = log_param(0.1)
        self.log_K = log_param(1.0)
        self.delta = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
        self.log_Tau = log_param(1.0)
        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    def get_params(self):
        """Return ``(J, R, K, delta, Tau)`` with the log-parameters exponentiated."""
        J, R, K, Tau = (torch.exp(p) for p in (self.log_J, self.log_R, self.log_K, self.log_Tau))
        return J, R, K, self.delta, Tau

    def _interp_u(self, t, x):
        """Linearly interpolate ``u_series`` at the absolute time of each row of ``x``."""
        if self.batch_start_times is None:
            t_abs = t * torch.ones_like(x[:, 0:1])
        else:
            t_abs = self.batch_start_times + t
        last = len(self.t_series) - 1
        hi = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True).clamp(1, last)
        lo = hi - 1
        t_lo = self.t_series[lo].unsqueeze(1)
        t_hi = self.t_series[hi].unsqueeze(1)
        gap = t_hi - t_lo
        # Repeated time stamps would divide by ~0; use a unit gap instead.
        gap = torch.where(gap < 1e-6, torch.ones_like(gap), gap)
        frac = (t_abs - t_lo) / gap
        u_lo = self.u_series[lo]
        return u_lo + frac * (self.u_series[hi] - u_lo)

    def forward(self, t, x):
        """Return ``[thd, thdd]`` for state rows ``x = [th, thd]`` at solver time ``t``."""
        J, R, K, delta, Tau = self.get_params()
        u_t = self._interp_u(t, x)
        th = x[:, 0:1]
        thd = x[:, 1:2]
        thdd = (Tau * u_t - R * thd - K * (th + delta)) / J
        return torch.cat([thd, thdd], dim=1)


class StribeckPhysODE(nn.Module):
    """Linear joint model plus a smooth Stribeck friction torque.

    ``J*thdd + R*thd + K*(th + delta) + F_str = Tau*u`` with
    ``F_str = (Fc + (Fs - Fc) * exp(-(thd/vs)**2)) * tanh(thd/1e-3) + b*thd``.
    ``tanh(thd/1e-3)`` is a smooth stand-in for ``sign(thd)``. Every coefficient
    except ``delta`` is stored as a log-parameter so it stays positive.
    ``t_series``, ``u_series`` and optionally ``batch_start_times`` are set by
    the caller before ``forward`` is used.
    """

    def __init__(self):
        super().__init__()

        def log_param(value):
            return nn.Parameter(torch.tensor(np.log(value), dtype=torch.float32))

        # Linear part.
        self.log_J = log_param(0.1)
        self.log_R = log_param(0.1)
        self.log_K = log_param(1.0)
        self.delta = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
        self.log_Tau = log_param(1.0)

        # Friction part: level at high speed (Fc), level near zero speed (Fs),
        # Stribeck velocity scale (vs), and an extra viscous coefficient (b).
        self.log_Fc = log_param(0.1)
        self.log_Fs = log_param(0.2)
        self.log_vs = log_param(0.1)
        self.log_b = log_param(0.01)

        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    def get_params(self):
        """Return ``(J, R, K, delta, Tau, Fc, Fs, vs, b)`` with logs exponentiated."""
        J, R, K, Tau, Fc, Fs, vs, b = (
            torch.exp(p)
            for p in (
                self.log_J, self.log_R, self.log_K, self.log_Tau,
                self.log_Fc, self.log_Fs, self.log_vs, self.log_b,
            )
        )
        return J, R, K, self.delta, Tau, Fc, Fs, vs, b

    def _interp_u(self, t, x):
        """Linearly interpolate ``u_series`` at the absolute time of each row of ``x``."""
        if self.batch_start_times is None:
            t_abs = t * torch.ones_like(x[:, 0:1])
        else:
            t_abs = self.batch_start_times + t
        last = len(self.t_series) - 1
        hi = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True).clamp(1, last)
        lo = hi - 1
        t_lo = self.t_series[lo].unsqueeze(1)
        t_hi = self.t_series[hi].unsqueeze(1)
        gap = t_hi - t_lo
        # Repeated time stamps would divide by ~0; use a unit gap instead.
        gap = torch.where(gap < 1e-6, torch.ones_like(gap), gap)
        frac = (t_abs - t_lo) / gap
        u_lo = self.u_series[lo]
        return u_lo + frac * (self.u_series[hi] - u_lo)

    def forward(self, t, x):
        """Return ``[thd, thdd]`` for state rows ``x = [th, thd]`` at solver time ``t``."""
        J, R, K, delta, Tau, Fc, Fs, vs, b = self.get_params()
        u_t = self._interp_u(t, x)
        th, thd = x[:, 0:1], x[:, 1:2]
        smooth_sign = torch.tanh(thd / 1e-3)
        stribeck_level = Fc + (Fs - Fc) * torch.exp(-(thd / vs) ** 2)
        F_str = stribeck_level * smooth_sign + b * thd
        thdd = (Tau * u_t - R * thd - K * (th + delta) - F_str) / J
        return torch.cat([thd, thdd], dim=1)


class BlackBoxODE(nn.Module):
    """Purely data-driven ODE: ``d[th, thd]/dt = net([th, thd, u])``.

    The MLP maps 3 inputs to 2 outputs through SELU layers, with AlphaDropout(0.05)
    after the first two hidden layers. The dropout layers are stochastic in
    training mode, so call ``eval()`` for deterministic simulation.
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        half = hidden_dim // 2
        layers = [
            nn.Linear(3, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
            nn.Linear(hidden_dim, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
            nn.Linear(hidden_dim, half), nn.SELU(),
            nn.Linear(half, 2),
        ]
        self.net = nn.Sequential(*layers)
        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    def _interp_u(self, t, x):
        """Linearly interpolate ``u_series`` at the absolute time of each row of ``x``."""
        if self.batch_start_times is None:
            t_abs = t * torch.ones_like(x[:, 0:1])
        else:
            t_abs = self.batch_start_times + t
        last = len(self.t_series) - 1
        hi = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True).clamp(1, last)
        lo = hi - 1
        t_lo = self.t_series[lo].unsqueeze(1)
        t_hi = self.t_series[hi].unsqueeze(1)
        gap = t_hi - t_lo
        # Repeated time stamps would divide by ~0; use a unit gap instead.
        gap = torch.where(gap < 1e-6, torch.ones_like(gap), gap)
        frac = (t_abs - t_lo) / gap
        u_lo = self.u_series[lo]
        return u_lo + frac * (self.u_series[hi] - u_lo)

    def forward(self, t, x):
        """Return the learned state derivative for rows ``x = [th, thd]``."""
        u_t = self._interp_u(t, x)
        return self.net(torch.cat([x, u_t], dim=1))


class HybridJointODE(nn.Module):
    """Linear physics plus a learned acceleration residual, trained jointly.

    ``thdd = (Tau*u - R*thd - K*(th + delta)) / J + net([th, thd, u])``.
    The physics coefficients (log-parameterised except ``delta``) and the MLP
    are optimised together. The MLP contains AlphaDropout, so call ``eval()``
    for deterministic simulation.
    """

    def __init__(self, hidden_dim=128):
        super().__init__()

        def log_param(value):
            return nn.Parameter(torch.tensor(np.log(value), dtype=torch.float32))

        self.log_J = log_param(0.1)
        self.log_R = log_param(0.1)
        self.log_K = log_param(1.0)
        self.delta = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
        self.log_Tau = log_param(1.0)

        half = hidden_dim // 2
        self.net = nn.Sequential(
            *[
                nn.Linear(3, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
                nn.Linear(hidden_dim, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
                nn.Linear(hidden_dim, half), nn.SELU(),
                nn.Linear(half, 1),
            ]
        )

        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    def _interp_u(self, t, x):
        """Linearly interpolate ``u_series`` at the absolute time of each row of ``x``."""
        if self.batch_start_times is None:
            t_abs = t * torch.ones_like(x[:, 0:1])
        else:
            t_abs = self.batch_start_times + t
        last = len(self.t_series) - 1
        hi = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True).clamp(1, last)
        lo = hi - 1
        t_lo = self.t_series[lo].unsqueeze(1)
        t_hi = self.t_series[hi].unsqueeze(1)
        gap = t_hi - t_lo
        # Repeated time stamps would divide by ~0; use a unit gap instead.
        gap = torch.where(gap < 1e-6, torch.ones_like(gap), gap)
        frac = (t_abs - t_lo) / gap
        u_lo = self.u_series[lo]
        return u_lo + frac * (self.u_series[hi] - u_lo)

    def forward(self, t, x):
        """Return ``[thd, thdd_phys + thdd_residual]`` for rows ``x = [th, thd]``."""
        J, R, K, Tau = (torch.exp(p) for p in (self.log_J, self.log_R, self.log_K, self.log_Tau))
        u_t = self._interp_u(t, x)
        th = x[:, 0:1]
        thd = x[:, 1:2]
        physics = (Tau * u_t - R * thd - K * (th + self.delta)) / J
        residual = self.net(torch.cat([th, thd, u_t], dim=1))
        return torch.cat([thd, physics + residual], dim=1)


class HybridFrozenPhysODE(nn.Module):
    """Frozen linear physics plus a trainable acceleration residual.

    The physics coefficients are copied once from ``phys_model.get_params()``
    (expected order ``J, R, K, delta, Tau``) into non-trainable buffers, so
    only the MLP is optimised:
    ``thdd = (Tau0*u - R0*thd - K0*(th + delta0)) / J0 + net([th, thd, u])``.
    """

    def __init__(self, phys_model, hidden_dim=128):
        super().__init__()
        frozen = zip(("J0", "R0", "K0", "delta0", "Tau0"), phys_model.get_params())
        for buffer_name, value in frozen:
            # detach + clone: no gradient link back to phys_model, no shared storage.
            self.register_buffer(buffer_name, value.detach().clone())

        half = hidden_dim // 2
        self.net = nn.Sequential(
            *[
                nn.Linear(3, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
                nn.Linear(hidden_dim, hidden_dim), nn.SELU(), nn.AlphaDropout(0.05),
                nn.Linear(hidden_dim, half), nn.SELU(),
                nn.Linear(half, 1),
            ]
        )

        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    def _interp_u(self, t, x):
        """Linearly interpolate ``u_series`` at the absolute time of each row of ``x``."""
        if self.batch_start_times is None:
            t_abs = t * torch.ones_like(x[:, 0:1])
        else:
            t_abs = self.batch_start_times + t
        last = len(self.t_series) - 1
        hi = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True).clamp(1, last)
        lo = hi - 1
        t_lo = self.t_series[lo].unsqueeze(1)
        t_hi = self.t_series[hi].unsqueeze(1)
        gap = t_hi - t_lo
        # Repeated time stamps would divide by ~0; use a unit gap instead.
        gap = torch.where(gap < 1e-6, torch.ones_like(gap), gap)
        frac = (t_abs - t_lo) / gap
        u_lo = self.u_series[lo]
        return u_lo + frac * (self.u_series[hi] - u_lo)

    def forward(self, t, x):
        """Return ``[thd, thdd_frozen_phys + thdd_residual]`` for rows ``x = [th, thd]``."""
        u_t = self._interp_u(t, x)
        th = x[:, 0:1]
        thd = x[:, 1:2]
        physics = (self.Tau0 * u_t - self.R0 * thd - self.K0 * (th + self.delta0)) / self.J0
        residual = self.net(torch.cat([th, thd, u_t], dim=1))
        return torch.cat([thd, physics + residual], dim=1)

## 4) Training Helper

In [None]:
def train_model_obs(model, name, epochs=500, lr=0.02, obs_dim=2, batch_size=128):
    """Fit ``model`` by K_STEPS-long multiple shooting on the training split.

    Each epoch draws ``batch_size`` window starts from ``valid_train_start_idx``,
    so no window crosses a dataset boundary. It integrates ``model`` with RK4
    for ``K_STEPS`` samples from the measured state at each start, and minimises
    the MSE between predicted and measured position. Velocity is used only as
    part of the initial condition.

    Args:
        model: ODE module exposing ``u_series``, ``t_series`` and
            ``batch_start_times`` attributes.
        name: Label used in progress messages.
        epochs: Epoch count; ``epochs + 1`` optimizer steps are taken (epoch 0
            included), so the final epoch is logged when divisible by 100.
        lr: Adam learning rate.
        obs_dim: Number of leading state components kept as observations.
        batch_size: Number of windows per optimizer step (default 128).

    Returns:
        The same ``model`` instance, trained in place and left on ``device``.
    """
    print(f"--- Training {name} ---")
    model.to(device)
    model.train()  # dropout must be active even if an earlier cell switched to eval()
    model.u_series = u_tensor
    model.t_series = t_tensor

    optimizer = optim.Adam(model.parameters(), lr=lr)
    dt_local = (t_tensor[1] - t_tensor[0]).item()
    # An integer arange scaled by dt always has exactly K_STEPS points. A float
    # arange(0, K*dt, dt) can yield K_STEPS + 1 points because of rounding in its
    # length computation, which breaks the comparison with the K_STEPS targets.
    t_eval = torch.arange(K_STEPS, device=device, dtype=torch.float32) * dt_local
    # Column vector of step offsets; adding the batch starts gives a [T, B] index grid.
    step_offsets = np.arange(K_STEPS)[:, None]

    for epoch in range(epochs + 1):
        optimizer.zero_grad()

        # Sample only from valid indices to keep each k-step rollout inside one training segment
        start_idx = np.random.choice(valid_train_start_idx, size=batch_size, replace=True)
        x0 = y_tensor[start_idx]
        model.batch_start_times = t_tensor[start_idx].reshape(-1, 1)

        pred_state = odeint(model, x0, t_eval, method='rk4')
        pred_obs = pred_state[..., :obs_dim]

        # Gather all targets in one indexing op: [T, B, state_dim], same layout as odeint.
        target_idx = torch.as_tensor(step_offsets + start_idx[None, :], device=device)
        y_target = y_tensor[target_idx]

        # Position-only loss (theta), velocity is used only as IC
        loss = torch.mean((pred_obs[..., 0:1] - y_target[..., 0:1]) ** 2)

        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.6f}")

    return model

## 5) Train Linear Model

In [None]:
# Fit the linear physics model (parameters J, R, K, delta, Tau).
lin_model = train_model_obs(LinearPhysODE(), "Linear Model", epochs=1000, lr=0.01)

## 6) Train Stribeck Model

In [None]:
# Fit the linear model extended with Stribeck friction.
str_model = train_model_obs(StribeckPhysODE(), "Stribeck Model", epochs=1000, lr=0.01)

## 7) Train Black-box Model

In [None]:
# Fit the purely data-driven neural ODE.
bb_model = train_model_obs(BlackBoxODE(hidden_dim=128), "Black-Box Model", epochs=1000, lr=0.01)

## 8) Train Hybrid Joint Model

In [None]:
# Fit physics coefficients and the NN residual together.
hjoint_model = train_model_obs(HybridJointODE(hidden_dim=128), "Hybrid-Joint Model", epochs=1000, lr=0.01)

## 9) Train Hybrid Frozen-Physics Model

In [None]:
# The frozen hybrid copies the *trained* linear coefficients into buffers, so it
# must be constructed after lin_model has been fitted; only its NN is trained.
hfrozen_model = train_model_obs(
    HybridFrozenPhysODE(lin_model, hidden_dim=128), "Hybrid-Frozen Model", epochs=1000, lr=0.01
)

## 10) Simulate Linear Model

In [None]:
# Free-run simulation of the linear model over the training split.
# The training vectors concatenate segments from different experiments, so each
# segment is integrated separately from its own measured initial state instead of
# carrying the simulated state across the jumps between experiments.
_was_training = lin_model.training
lin_model.eval()
# Rebind the input series: the evaluation cell points them at other datasets.
lin_model.u_series = u_tensor
lin_model.t_series = t_tensor
with torch.no_grad():
    lin_model.batch_start_times = torch.zeros(1, 1).to(device)  # t_tensor slices are already absolute
    pred_lin = np.concatenate(
        [
            odeint(lin_model, y_tensor[s0].unsqueeze(0), t_tensor[s0:s1], method='rk4').squeeze(1).cpu().numpy()
            for _, s0, s1 in train_segments
        ],
        axis=0,
    )
lin_model.train(_was_training)


## 11) Simulate Stribeck Model

In [None]:
# Free-run simulation of the Stribeck model over the training split.
# Each concatenated training segment (one per experiment) starts from its own
# measured state, so the simulation does not run across experiment boundaries.
_was_training = str_model.training
str_model.eval()
# Rebind the input series: the evaluation cell points them at other datasets.
str_model.u_series = u_tensor
str_model.t_series = t_tensor
with torch.no_grad():
    str_model.batch_start_times = torch.zeros(1, 1).to(device)  # t_tensor slices are already absolute
    pred_str = np.concatenate(
        [
            odeint(str_model, y_tensor[s0].unsqueeze(0), t_tensor[s0:s1], method='rk4').squeeze(1).cpu().numpy()
            for _, s0, s1 in train_segments
        ],
        axis=0,
    )
str_model.train(_was_training)


## 12) Simulate Black-box Model

In [None]:
# Free-run simulation of the black-box model over the training split.
# eval() switches AlphaDropout off; torch.no_grad() alone leaves it active, which
# made every "simulation" a random dropout sample. Each concatenated training
# segment starts from its own measured state (no integration across experiments).
_was_training = bb_model.training
bb_model.eval()
# Rebind the input series: the evaluation cell points them at other datasets.
bb_model.u_series = u_tensor
bb_model.t_series = t_tensor
with torch.no_grad():
    bb_model.batch_start_times = torch.zeros(1, 1).to(device)  # t_tensor slices are already absolute
    pred_bb = np.concatenate(
        [
            odeint(bb_model, y_tensor[s0].unsqueeze(0), t_tensor[s0:s1], method='rk4').squeeze(1).cpu().numpy()
            for _, s0, s1 in train_segments
        ],
        axis=0,
    )
bb_model.train(_was_training)


## 13) Simulate Hybrid Joint Model

In [None]:
# Free-run simulation of the hybrid-joint model over the training split.
# eval() switches the residual network's AlphaDropout off (torch.no_grad() does
# not). Each concatenated training segment starts from its own measured state.
_was_training = hjoint_model.training
hjoint_model.eval()
# Rebind the input series: the evaluation cell points them at other datasets.
hjoint_model.u_series = u_tensor
hjoint_model.t_series = t_tensor
with torch.no_grad():
    hjoint_model.batch_start_times = torch.zeros(1, 1).to(device)  # t_tensor slices are already absolute
    pred_hjoint = np.concatenate(
        [
            odeint(hjoint_model, y_tensor[s0].unsqueeze(0), t_tensor[s0:s1], method='rk4').squeeze(1).cpu().numpy()
            for _, s0, s1 in train_segments
        ],
        axis=0,
    )
hjoint_model.train(_was_training)


## 14) Simulate Hybrid Frozen-Physics Model

In [None]:
# Free-run simulation of the hybrid frozen-physics model over the training split.
# eval() switches the residual network's AlphaDropout off (torch.no_grad() does
# not). Each concatenated training segment starts from its own measured state.
_was_training = hfrozen_model.training
hfrozen_model.eval()
# Rebind the input series: the evaluation cell points them at other datasets.
hfrozen_model.u_series = u_tensor
hfrozen_model.t_series = t_tensor
with torch.no_grad():
    hfrozen_model.batch_start_times = torch.zeros(1, 1).to(device)  # t_tensor slices are already absolute
    pred_hfrozen = np.concatenate(
        [
            odeint(hfrozen_model, y_tensor[s0].unsqueeze(0), t_tensor[s0:s1], method='rk4').squeeze(1).cpu().numpy()
            for _, s0, s1 in train_segments
        ],
        axis=0,
    )
hfrozen_model.train(_was_training)


## 15) Comparison Plots (Training Split)

In [None]:
# Metrics and residuals for all models on the training split
# (y_sim here is the concatenated training data, not held-out test data).
models = {
    "Linear": pred_lin,
    "Stribeck": pred_str,
    "Black-box": pred_bb,
    "Hybrid-Joint": pred_hjoint,
    "Hybrid-Frozen": pred_hfrozen,
}
colors = {
    "Linear": "tab:red",
    "Stribeck": "tab:blue",
    "Black-box": "tab:green",
    "Hybrid-Joint": "tab:orange",
    "Hybrid-Frozen": "tab:purple",
}
styles = {
    "Linear": "--",
    "Stribeck": "-.",
    "Black-box": ":",
    "Hybrid-Joint": "-",
    "Hybrid-Frozen": (0, (3, 1, 1, 1)),
}


def _r2_and_fit(residual, reference):
    """Return (R2, FIT%) of `residual` relative to `reference`.

    Both are NaN when `reference` is constant. Previously only R2 was guarded,
    and FIT% divided by a zero norm.
    """
    centred = reference - np.mean(reference)
    ss_tot = np.sum(centred ** 2)
    if not ss_tot > 0:
        return np.nan, np.nan
    r2 = 1 - np.sum(residual ** 2) / ss_tot
    fit = 100 * (1 - np.linalg.norm(residual) / np.linalg.norm(centred))
    return r2, fit


residuals = {}
metrics = {}
for name, pred in models.items():
    res_pos = y_sim[:, 0] - pred[:, 0]
    res_vel = y_sim[:, 1] - pred[:, 1]
    r2_pos, fit_pos = _r2_and_fit(res_pos, y_sim[:, 0])
    r2_vel, fit_vel = _r2_and_fit(res_vel, y_sim[:, 1])
    residuals[name] = {"pos": res_pos, "vel": res_vel}
    metrics[name] = {
        "rmse_pos": np.sqrt(np.mean(res_pos ** 2)),
        "rmse_vel": np.sqrt(np.mean(res_vel ** 2)),
        "r2_pos": r2_pos,
        "r2_vel": r2_vel,
        "fit_pos": fit_pos,
        "fit_vel": fit_vel,
    }

# Metrics tables. The name column is sized to the longest model name:
# "Hybrid-Frozen" (13 characters) overflowed the old fixed 12-character column,
# and the header widths did not match the row format, so columns drifted.
_name_w = max([len("Model")] + [len(n) for n in models])


def _print_metrics_table(title, suffix):
    """Print RMSE / R2 / FIT% for every model, using the metric keys ending in `suffix`."""
    print(title)
    print(f"{'Model':<{_name_w}} {'RMSE':<9} {'R2':<8} {'FIT%':<8}")
    for model_name in models:
        m = metrics[model_name]
        print(
            f"{model_name:<{_name_w}} {m['rmse_' + suffix]:<9.4f} "
            f"{m['r2_' + suffix]:<8.4f} {m['fit_' + suffix]:<8.2f}"
        )


_print_metrics_table("Metrics - Position", "pos")
_print_metrics_table("Metrics - Velocity", "vel")

# Comparison plots (no metrics in titles): position on top, velocity below.
plt.figure(figsize=(12, 8))
for _col, _quantity in enumerate(("Position", "Velocity")):
    plt.subplot(2, 1, _col + 1)
    plt.plot(t, y_sim[:, _col], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models.items():
        plt.plot(t, pred[:, _col], linestyle=styles[name], color=colors[name], linewidth=1.5, label=name)
    if _col == 1:
        plt.xlabel("Time (s)")  # only the bottom panel carries the time label
    plt.ylabel(_quantity)
    plt.legend()
    plt.grid(True)
    plt.title(f"{_quantity}: Measured vs Predicted")

plt.tight_layout()
plt.show()

# Zooms (position only): three windows of `win_sec` seconds at start / middle / end.
win_sec = 5.0
win_n = int(win_sec / Ts)
_n_t = len(t)
starts = [0, max(0, (_n_t - win_n) // 2), max(0, _n_t - win_n)]
labels = ["Start", "Middle", "End"]

plt.figure(figsize=(12, 8))
for i, (s, _zoom_label) in enumerate(zip(starts, labels)):
    e = min(_n_t, s + win_n)
    _window = slice(s, e)
    plt.subplot(3, 1, i + 1)
    plt.plot(t[_window], y_sim[_window, 0], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models.items():
        plt.plot(t[_window], pred[_window, 0], linestyle=styles[name], color=colors[name], linewidth=1.5, label=name)
    plt.title(f"Zoom ({_zoom_label})")
    plt.ylabel('Position')
    plt.grid(True)
    if i == 0:
        plt.legend()
    if i == len(starts) - 1:
        plt.xlabel('Time (s)')

plt.tight_layout()
plt.show()

# Residuals (all models): position residuals on top, velocity residuals below.
plt.figure(figsize=(12, 8))
for _row, (_key, _title) in enumerate((('pos', 'Position Residuals'), ('vel', 'Velocity Residuals'))):
    plt.subplot(2, 1, _row + 1)
    for name in models:
        plt.plot(t, residuals[name][_key], color=colors[name], linewidth=1.2, label=name)
    plt.axhline(0, color='gray', linewidth=1)
    plt.title(_title)
    plt.ylabel('Residual')
    if _row == 1:
        plt.xlabel('Time (s)')
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()

# Raincloud plot (position residuals): mirrored KDE, IQR bar, median tick and
# jittered raw residuals, one row per model.
from scipy.stats import gaussian_kde

labels = list(models.keys())
res_list = [residuals[name]['pos'] for name in labels]

plt.figure(figsize=(10, 5))
for i, (label, res) in enumerate(zip(labels, res_list)):
    res = res[np.isfinite(res)]
    if res.size == 0:
        continue  # no finite residuals: np.percentile / KDE would raise
    # gaussian_kde raises on a degenerate sample (a single value or zero spread),
    # so the density is drawn only when the residuals actually vary.
    if np.ptp(res) > 0:
        kde = gaussian_kde(res)
        xs = np.linspace(np.min(res), np.max(res), 200)
        ys = kde(xs)
        ys = ys / ys.max() * 0.3
        plt.fill_between(xs, i + ys, i - ys, color=colors[label], alpha=0.3)
    q1, q2, q3 = np.percentile(res, [25, 50, 75])
    plt.plot([q1, q3], [i, i], color=colors[label], linewidth=6)
    plt.plot([q2, q2], [i-0.1, i+0.1], color='k', linewidth=1)
    jitter = (np.random.rand(len(res)) - 0.5) * 0.2
    plt.scatter(res, i + jitter, s=3, color=colors[label], alpha=0.3)

plt.yticks(range(len(labels)), labels)
plt.xlabel('Residual (Position)')
plt.title('Residual Raincloud Plot (Position)')
plt.grid(True, axis='x')
plt.tight_layout()
plt.show()

# y vs yhat (position): each model's predictions against the measurement;
# the dashed identity line marks a perfect fit.
_pos_meas = y_sim[:, 0]
_diag = [_pos_meas.min(), _pos_meas.max()]
plt.figure(figsize=(6, 6))
plt.plot(_diag, _diag, 'k--', label='Ideal')
for name, pred in models.items():
    plt.scatter(_pos_meas, pred[:, 0], s=6, alpha=0.4, color=colors[name], label=name)
plt.xlabel('Measured y')
plt.ylabel('Predicted y')
plt.title('y vs y_hat (Position)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

## 16) Diagnostics (Position, Training Split)

In [None]:
# Residual ACF (position) for all models
N = len(t)
max_lag = min(2000, N - 1)
conf = 1.96 / np.sqrt(N)  # approximate 95% band for white-noise residuals
# Zero-pad to a power of two >= 2N so the FFT correlation is linear (no circular
# wrap-around) and matches np.correlate(mode='full') at lags 0..max_lag.
_n_fft = 1 << int(np.ceil(np.log2(2 * N)))

plt.figure(figsize=(10, 4))
for name in models.keys():
    res = residuals[name]['pos'] - np.mean(residuals[name]['pos'])
    # Autocovariance via FFT (Wiener-Khinchin): O(N log N) instead of the O(N^2)
    # direct sum computed by np.correlate over all 2N-1 lags.
    spec = np.fft.rfft(res, n=_n_fft)
    acov = np.fft.irfft(np.abs(spec) ** 2, n=_n_fft)[:max_lag + 1]
    # A constant residual has zero variance; normalising would divide by zero.
    acf = acov / acov[0] if acov[0] > 0 else np.full(max_lag + 1, np.nan)
    plt.plot(np.arange(0, max_lag+1), acf, color=colors[name], linewidth=1.2, label=name)

plt.axhline(conf, color='red', linestyle='--', linewidth=1)
plt.axhline(-conf, color='red', linestyle='--', linewidth=1)
plt.title('Residual ACF (Position)')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Spectra (position): the measured spectrum is shown in both figures as a
# reference, first against each model's prediction, then against each
# model's position residual.
freqs = np.fft.rfftfreq(len(t), d=Ts)
Y_meas = np.fft.rfft(y_sim[:, 0])

_spectrum_figures = (
    ('Spectrum (Position)', 'Magnitude', {n: models[n][:, 0] for n in models}),
    ('Spectrum of Position Residuals', 'Magnitude of FFT', {n: residuals[n]['pos'] for n in models}),
)
for _title, _ylabel, _signals in _spectrum_figures:
    plt.figure(figsize=(10, 4))
    plt.semilogy(freqs, np.abs(Y_meas), color='k', label='Measured')
    for name, _sig in _signals.items():
        plt.semilogy(freqs, np.abs(np.fft.rfft(_sig)), color=colors[name], linewidth=1.2, label=name)
    plt.title(_title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel(_ylabel)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

## 17) Evaluate All Models on All Datasets (Train/Test Splits)



In [None]:
# Reuse protocol from training cell
print("Using PROTOCOL_MODE:", PROTOCOL_MODE)
print("Using TRAIN_DATASETS_CLASSIC:", TRAIN_DATASETS_CLASSIC)

# Evaluate all datasets with protocol, print R2 tables, and plot rainclouds:
# - rows = models
# - columns = train/test
# - each subplot y-axis = datasets

# =========================
# 1) Evaluation config
# =========================
PROTOCOL_MODE = "split_50_50"  # "split_50_50" or "classic_train_test"
TRAIN_DATASETS_CLASSIC = ["multisine_05"]

model_objects = {
    "Linear": lin_model,
    "Stribeck": str_model,
    "Black-box": bb_model,
    "Hybrid-Joint": hjoint_model,
    "Hybrid-Frozen": hfrozen_model,
}
model_names = list(model_objects.keys())

all_datasets = nod.list_experiments()
core_datasets = [d for d in all_datasets if ("multisine" in d or "random_steps" in d)]
external_datasets = [d for d in all_datasets if d not in core_datasets]

print(f"PROTOCOL_MODE: {PROTOCOL_MODE}")
print(f"Core datasets: {core_datasets}")
print(f"External datasets: {external_datasets}")

eval_store = {}

# =========================
# 2) Simulate all datasets
# =========================
for ds_name in all_datasets:
    data_ds = nod.load_experiment(
        ds_name,
        preprocess=True,
        plot=False,
        end_idx=None,
        resample_factor=50,
        zoom_last_n=200,
        y_dot_method=velMethod,
    )

    u_ds, y_ds, y_ref_ds, y_dot_ds = data_ds
    Ts_ds = data_ds.sampling_time
    t_ds = np.arange(len(u_ds)) * Ts_ds
    y_sim_ds = np.column_stack([y_ds, y_dot_ds])

    n = len(t_ds)
    split_i = int(0.5 * n)

    if PROTOCOL_MODE == "split_50_50":
        if ds_name in core_datasets:
            train_idx = np.arange(0, split_i)
            test_idx = np.arange(split_i, n)
        else:
            train_idx = None
            test_idx = np.arange(0, n)
    elif PROTOCOL_MODE == "classic_train_test":
        if ds_name in TRAIN_DATASETS_CLASSIC:
            train_idx = np.arange(0, n)
            test_idx = None
        else:
            train_idx = None
            test_idx = np.arange(0, n)
    else:
        raise ValueError(f"Unknown PROTOCOL_MODE: {PROTOCOL_MODE}")

    t_ds_tensor = torch.tensor(t_ds, dtype=torch.float32).to(device)
    u_ds_tensor = torch.tensor(u_ds, dtype=torch.float32).reshape(-1, 1).to(device)
    y_ds_tensor = torch.tensor(y_sim_ds, dtype=torch.float32).to(device)

    preds_ds = {}
    with torch.no_grad():
        x0_ds = y_ds_tensor[0].unsqueeze(0)
        for model_name, model_obj in model_objects.items():
            model_obj.u_series = u_ds_tensor
            model_obj.t_series = t_ds_tensor
            model_obj.batch_start_times = torch.zeros(1, 1).to(device)
            pred = odeint(model_obj, x0_ds, t_ds_tensor, method="rk4").squeeze(1).cpu().numpy()
            preds_ds[model_name] = pred

    metrics_ds = {m: {"train": None, "test": None} for m in model_names}
    residuals_ds = {m: {"train": None, "test": None} for m in model_names}

    for mname, pred in preds_ds.items():
        for split_name, idx in [("train", train_idx), ("test", test_idx)]:
            if idx is None or len(idx) < 2:
                continue

            y_true = y_sim_ds[idx]
            y_hat = pred[idx]
            res_pos = y_true[:, 0] - y_hat[:, 0]
            res_vel = y_true[:, 1] - y_hat[:, 1]

            rmse_pos = np.sqrt(np.mean(res_pos**2))
            rmse_vel = np.sqrt(np.mean(res_vel**2))

            ss_res_pos = np.sum(res_pos**2)
            ss_tot_pos = np.sum((y_true[:, 0] - np.mean(y_true[:, 0]))**2)
            ss_res_vel = np.sum(res_vel**2)
            ss_tot_vel = np.sum((y_true[:, 1] - np.mean(y_true[:, 1]))**2)

            r2_pos = 1 - ss_res_pos / ss_tot_pos if ss_tot_pos > 0 else np.nan
            r2_vel = 1 - ss_res_vel / ss_tot_vel if ss_tot_vel > 0 else np.nan

            fit_pos = 100 * (1 - np.linalg.norm(res_pos) / np.linalg.norm(y_true[:, 0] - np.mean(y_true[:, 0])))
            fit_vel = 100 * (1 - np.linalg.norm(res_vel) / np.linalg.norm(y_true[:, 1] - np.mean(y_true[:, 1])))

            residuals_ds[mname][split_name] = {"pos": res_pos, "vel": res_vel}
            metrics_ds[mname][split_name] = {
                "rmse_pos": rmse_pos,
                "rmse_vel": rmse_vel,
                "r2_pos": r2_pos,
                "r2_vel": r2_vel,
                "fit_pos": fit_pos,
                "fit_vel": fit_vel,
            }

    eval_store[ds_name] = {
        "t": t_ds,
        "Ts": Ts_ds,
        "y_sim": y_sim_ds,
        "preds": preds_ds,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "metrics": metrics_ds,
        "residuals": residuals_ds,
    }

# =========================
# 3) R2 tables
# =========================
def print_r2_table(split_name, component, *, store=None, datasets=None, models=None):
    """Print a model x dataset table of R² scores for one split and one state component.

    Parameters
    ----------
    split_name : str
        Split key inside each metrics entry, e.g. ``"train"`` or ``"test"``.
    component : {"pos", "vel"}
        Which score to report (``"r2_pos"`` or ``"r2_vel"``).
    store : dict, optional
        Mapping dataset -> {"metrics": {model: {split: dict | None}}}.
        Defaults to the notebook-level ``eval_store``.
    datasets : sequence of str, optional
        Column order; defaults to ``all_datasets``.
    models : sequence of str, optional
        Row order; defaults to ``model_names``.

    Cells whose metrics entry is ``None`` (split not evaluated) or whose value is
    NaN are printed as ``nan``. Dataset names are truncated to 12 characters so
    the columns stay aligned.

    Raises
    ------
    ValueError
        If ``component`` is neither ``"pos"`` nor ``"vel"`` (previously any other
        value silently produced the velocity table).
    """
    if component not in ("pos", "vel"):
        raise ValueError(f"component must be 'pos' or 'vel', got {component!r}")
    store = eval_store if store is None else store
    datasets = all_datasets if datasets is None else datasets
    models = model_names if models is None else models

    key = f"r2_{component}"
    print(f"\nR2 table ({component}, {split_name})")
    print("model".ljust(14) + " " + " ".join(d[:12].ljust(12) for d in datasets))
    for mname in models:
        cells = []
        for ds in datasets:
            m = store[ds]["metrics"][mname][split_name]
            v = np.nan if m is None else m[key]
            cells.append(("nan" if np.isnan(v) else f"{v:.4f}").ljust(12))
        print(mname.ljust(14) + " " + " ".join(cells))

# Emit the four R2 tables: position first, then velocity; train before test within each.
for _component in ("pos", "vel"):
    for _split in ("train", "test"):
        print_r2_table(_split, _component)




In [None]:
# =========================
# 4) Rainclouds
# rows = models, cols = [train, test], y-axis = datasets
# x-axis shared and centered at zero for all subplots
# =========================
# Imported here because, in the visible cell order, gaussian_kde is otherwise only
# imported in a later cell, so running this cell first would raise NameError.
from scipy.stats import gaussian_kde

splits = ["train", "test"]
n_models = len(model_names)
n_datasets = len(all_datasets)


def _finite_pos_residuals(ds, mname, split_name):
    """Finite position residuals for one (dataset, model, split); None if that split was not evaluated."""
    rpack = eval_store[ds]["residuals"][mname][split_name]
    if rpack is None:
        return None
    res = np.asarray(rpack["pos"])
    return res[np.isfinite(res)]


# Global symmetric x-limits around zero across all models/datasets/splits
all_residuals = []
for mname in model_names:
    for split_name in splits:
        for ds in all_datasets:
            res = _finite_pos_residuals(ds, mname, split_name)
            if res is not None and len(res) > 0:
                all_residuals.append(res)

if len(all_residuals) == 0:
    raise RuntimeError("No residuals found for raincloud plotting.")

res_concat = np.concatenate(all_residuals)
x_abs = np.max(np.abs(res_concat))
pad = 0.05 * (x_abs + 1e-12)
xlim_shared = (-(x_abs + pad), (x_abs + pad))

# squeeze=False keeps `axes` 2-D (n_models x 2) even when there is a single model.
fig, axes = plt.subplots(
    n_models, 2,
    figsize=(14, max(3.0 * n_models, 8)),
    sharex=True,
    sharey=True,
    squeeze=False,
)

for r, mname in enumerate(model_names):
    model_label = mname.replace(" Dataset", "").replace("_Dataset", "")

    for c, split_name in enumerate(splits):
        ax = axes[r, c]

        for i, ds in enumerate(all_datasets):
            res = _finite_pos_residuals(ds, mname, split_name)
            if res is None or len(res) < 5:
                continue

            # cloud: KDE outline scaled to half-height 0.25. gaussian_kde raises on
            # zero-spread data (singular covariance), so constant residuals get no cloud.
            if np.ptp(res) > 0:
                xs = np.linspace(np.min(res), np.max(res), 200)
                ys = gaussian_kde(res)(xs)
                ys = ys / ys.max() * 0.25
                ax.fill_between(xs, i + ys, i - ys, color="gray", alpha=0.25)

            # box-like summary: IQR bar plus median tick
            q1, q2, q3 = np.percentile(res, [25, 50, 75])
            ax.plot([q1, q3], [i, i], color="black", linewidth=5)
            ax.plot([q2, q2], [i - 0.12, i + 0.12], color="black", linewidth=1)

            # rain: raw residuals with random vertical jitter
            jitter = (np.random.rand(len(res)) - 0.5) * 0.18
            ax.scatter(res, i + jitter, s=2, color="gray", alpha=0.25)

        if r == 0:
            ax.set_title(f"Residual Raincloud ({split_name})")

        if c == 0:
            ax.set_ylabel(model_label)
        else:
            ax.set_ylabel("")

        ax.set_xlabel("Position residual")
        ax.grid(True, axis="x", alpha=0.35)
        ax.axvline(0.0, color="k", linewidth=0.8, alpha=0.6)

        # show dataset labels on every subplot
        ax.set_yticks(range(n_datasets))
        ax.set_yticklabels(all_datasets)

        # enforce shared symmetric limits
        ax.set_xlim(*xlim_shared)

plt.tight_layout()
plt.show()


## 14) Simulate All Models on Test Data

In [None]:
# Pick the dataset whose stored simulation results are exposed below.
# Any key of eval_store is accepted; the error message lists the valid names.
plot_dataset = "swept_sine"

if plot_dataset not in eval_store:
    raise ValueError(f"Dataset '{plot_dataset}' not found in eval_store. Available: {list(eval_store.keys())}")

# Unpack the per-dataset entry into the *_test names used by downstream cells.
t_test, Ts_test, y_sim_test, models_test, residuals_test, metrics_test = (
    eval_store[plot_dataset][field]
    for field in ("t", "Ts", "y_sim", "preds", "residuals", "metrics")
)

print(f"Selected dataset for plots: {plot_dataset}")


## 15) Test Comparison Plots and Metrics

In [None]:
# Select one dataset and one split for all plots/tables that follow.
plot_dataset = "random_steps_01"  # any key from eval_store
plot_split = "test"               # "train" or "test"

# Validate the selection up front so later cells fail with a clear message.
if plot_dataset not in eval_store:
    raise ValueError(f"Dataset '{plot_dataset}' not found. Available: {list(eval_store.keys())}")
if plot_split not in ("train", "test"):
    raise ValueError("plot_split must be 'train' or 'test'")

idx = eval_store[plot_dataset][f"{plot_split}_idx"]
if idx is None or len(idx) < 2:
    raise ValueError(f"No '{plot_split}' samples for dataset '{plot_dataset}'")

# Full-length arrays/dicts for the chosen dataset.
base = eval_store[plot_dataset]
t_all, y_all, Ts_sel = base["t"], base["y_sim"], base["Ts"]
models_all, metrics_all, residuals_all = base["preds"], base["metrics"], base["residuals"]

# Time, measurements and predictions are cut to the split indices; residuals and
# metrics are already stored per split, so they are just looked up.
t_sel, y_sel = t_all[idx], y_all[idx]
models_sel = {name: models_all[name][idx] for name in models_all}
residuals_sel = {name: residuals_all[name][plot_split] for name in models_all}
metrics_sel = {name: metrics_all[name][plot_split] for name in models_all}

print(f"Selected dataset: {plot_dataset} | split: {plot_split} | samples: {len(idx)}")


In [None]:
# Metrics tables for selected dataset/split.
# Row widths (name 14, RMSE 9, R2 9, each followed by one space) match the header,
# so every value sits under its label (R2 previously used width 8 and shifted FIT%).
print(f"Metrics - Position ({plot_dataset}, {plot_split})")
print("Model          RMSE      R2        FIT%")
for name in models_sel:
    m = metrics_sel[name]
    print(f"{name:<14} {m['rmse_pos']:<9.4f} {m['r2_pos']:<9.4f} {m['fit_pos']:<8.2f}")

print(f"\nMetrics - Velocity ({plot_dataset}, {plot_split})")
print("Model          RMSE      R2        FIT%")
for name in models_sel:
    m = metrics_sel[name]
    print(f"{name:<14} {m['rmse_vel']:<9.4f} {m['r2_vel']:<9.4f} {m['fit_vel']:<8.2f}")

# Predictions vs measured: row 1 = position (column 0), row 2 = velocity (column 1).
plt.figure(figsize=(12, 8))
for col, quantity in enumerate(("Position", "Velocity")):
    plt.subplot(2, 1, col + 1)
    plt.plot(t_sel, y_sel[:, col], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models_sel.items():
        plt.plot(t_sel, pred[:, col], linestyle=styles[name], color=colors[name], linewidth=1.5, label=name)
    if col == 1:
        plt.xlabel("Time (s)")
    plt.ylabel(quantity)
    plt.legend()
    plt.grid(True)
    plt.title(f"{quantity}: Measured vs Predicted ({plot_dataset}, {plot_split})")
plt.tight_layout()
plt.show()

# Zooms: three windows of win_sec seconds at the start, middle and end of the split.
win_sec = 5.0
win_n = int(win_sec / Ts_sel)
starts = [0, max(0, (len(t_sel) - win_n) // 2), max(0, len(t_sel) - win_n)]
labels = ["Start", "Middle", "End"]

plt.figure(figsize=(12, 8))
for i, (s, where) in enumerate(zip(starts, labels)):
    e = min(len(t_sel), s + win_n)
    plt.subplot(3, 1, i + 1)
    plt.plot(t_sel[s:e], y_sel[s:e, 0], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models_sel.items():
        plt.plot(t_sel[s:e], pred[s:e, 0], linestyle=styles[name], color=colors[name], linewidth=1.5, label=name)
    plt.title(f"Zoom ({where}) - {plot_dataset} ({plot_split})")
    plt.ylabel('Position')
    plt.grid(True)
    if i == 0:
        plt.legend()
    if i == 2:
        plt.xlabel('Time (s)')
plt.tight_layout()
plt.show()

# Residuals: row 1 = position, row 2 = velocity.
plt.figure(figsize=(12, 8))
for row, (res_key, quantity) in enumerate((('pos', 'Position'), ('vel', 'Velocity'))):
    plt.subplot(2, 1, row + 1)
    for name in models_sel:
        plt.plot(t_sel, residuals_sel[name][res_key], color=colors[name], linewidth=1.2, label=name)
    plt.axhline(0, color='gray', linewidth=1)
    plt.title(f'{quantity} Residuals ({plot_dataset}, {plot_split})')
    plt.ylabel('Residual')
    if row == 1:
        plt.xlabel('Time (s)')
    plt.grid(True)
    plt.legend()
plt.tight_layout()
plt.show()

# Raincloud (position residuals)
from scipy.stats import gaussian_kde
labels_sel = list(models_sel.keys())
res_list_sel = [residuals_sel[name]['pos'] for name in labels_sel]

plt.figure(figsize=(10, 5))
for i, (label, res) in enumerate(zip(labels_sel, res_list_sel)):
    res = np.asarray(res)
    res = res[np.isfinite(res)]
    if res.size == 0:
        continue  # nothing finite to draw for this model (np.min/percentile would fail)
    # Cloud: KDE outline scaled to half-height 0.3. gaussian_kde raises on zero-spread
    # data (singular covariance), so only the cloud is skipped for constant residuals.
    if np.ptp(res) > 0:
        kde = gaussian_kde(res)
        xs = np.linspace(np.min(res), np.max(res), 200)
        ys = kde(xs)
        ys = ys / ys.max() * 0.3
        plt.fill_between(xs, i + ys, i - ys, color=colors[label], alpha=0.3)
    # Box: interquartile range as a thick bar, median as a short black tick.
    q1, q2, q3 = np.percentile(res, [25, 50, 75])
    plt.plot([q1, q3], [i, i], color=colors[label], linewidth=6)
    plt.plot([q2, q2], [i-0.1, i+0.1], color='k', linewidth=1)
    # Rain: raw residuals with random vertical jitter.
    jitter = (np.random.rand(len(res)) - 0.5) * 0.2
    plt.scatter(res, i + jitter, s=3, color=colors[label], alpha=0.3)

plt.yticks(range(len(labels_sel)), labels_sel)
plt.xlabel('Residual (Position)')
plt.title(f'Residual Raincloud Plot (Position, {plot_dataset}, {plot_split})')
plt.grid(True, axis='x')
plt.tight_layout()
plt.show()

# y vs yhat: measured vs predicted position; the dashed diagonal marks perfect prediction.
plt.figure(figsize=(6, 6))
y_lo, y_hi = y_sel[:, 0].min(), y_sel[:, 0].max()
plt.plot([y_lo, y_hi], [y_lo, y_hi], 'k--', label='Ideal')
for name, pred in models_sel.items():
    plt.scatter(y_sel[:, 0], pred[:, 0], s=6, alpha=0.4, color=colors[name], label=name)
plt.xlabel('Measured y')
plt.ylabel('Predicted y')
plt.title(f'y vs y_hat (Position, {plot_dataset}, {plot_split})')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


## 16) Test Diagnostics (ACF and Spectra)

In [None]:
# ACF (position) for selected dataset/split.
# Dashed red lines: approximate 95% band (±1.96/sqrt(N)) expected for white residuals.
N_sel = len(t_sel)
max_lag_sel = min(2000, N_sel - 1)
conf_sel = 1.96 / np.sqrt(N_sel)

plt.figure(figsize=(10, 4))
for name in models_sel.keys():
    res = residuals_sel[name]['pos'] - np.mean(residuals_sel[name]['pos'])
    acf = np.correlate(res, res, mode='full')
    lag0_idx = len(res) - 1  # 'full' mode places lag 0 at index len(res) - 1
    lag0 = acf[lag0_idx]
    if not lag0 > 0:
        # Zero-variance (or non-finite) residuals: normalising by lag 0 would give NaN/inf.
        print(f"Skipping ACF for {name}: residual variance is zero or non-finite")
        continue
    acf = acf[lag0_idx:lag0_idx + max_lag_sel + 1] / lag0
    plt.plot(np.arange(len(acf)), acf, color=colors[name], linewidth=1.2, label=name)

plt.axhline(conf_sel, color='red', linestyle='--', linewidth=1)
plt.axhline(-conf_sel, color='red', linestyle='--', linewidth=1)
plt.title(f'Residual ACF (Position, {plot_dataset}, {plot_split})')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

def _plot_spectra_vs_measured(curves, title, ylabel):
    """Open a figure with |rfft| of each (model name, signal) pair on a log axis,
    drawn after the measured-position spectrum (black) used as a reference."""
    plt.figure(figsize=(10, 4))
    plt.semilogy(freqs_sel, np.abs(Y_meas_sel), color='k', label='Measured')
    for model_name, signal in curves:
        plt.semilogy(freqs_sel, np.abs(np.fft.rfft(signal)), color=colors[model_name], linewidth=1.2, label=model_name)
    plt.title(title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


# Spectrum: measured vs predicted position
freqs_sel = np.fft.rfftfreq(len(t_sel), d=Ts_sel)
Y_meas_sel = np.fft.rfft(y_sel[:, 0])
_plot_spectra_vs_measured(
    [(name, pred[:, 0]) for name, pred in models_sel.items()],
    f'Spectrum (Position, {plot_dataset}, {plot_split})',
    'Magnitude',
)

# Residual spectrum (the measured spectrum stays in the plot as a scale reference)
_plot_spectra_vs_measured(
    [(name, residuals_sel[name]['pos']) for name in models_sel],
    f'Spectrum of Position Residuals ({plot_dataset}, {plot_split})',
    'Magnitude of FFT',
)


## Iterate and Plot All Datasets

# HYCO — Hybrid Cooperative Model

**Idea:** Keep a *physical* ODE and a *black-box* neural ODE as two completely separate sub-models.  
The composite loss for each mini-batch is:

$$
\mathcal{L} = \underbrace{\mathcal{L}_{\text{phys}}}_{\text{physical model vs data}}
            + \lambda_{\text{bb}} \, \underbrace{\mathcal{L}_{\text{bb}}}_{\text{black-box vs data}}
            + \lambda_{\text{cons}} \, \underbrace{\|\mathbf{x}_{\text{phys}} - \mathbf{x}_{\text{bb}}\|^2}_{\text{consistency}}
$$

All three terms are mean-squared errors over the simulated horizon and the batch. They are evaluated on the **position** component only; velocity is not penalised.

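**Alternating optimisation** (here an "epoch" is a single mini-batch step; epoch 0 is even, so training starts with a black-box update):
- **Odd epochs** → only the *physical* parameters are updated  
- **Even epochs** → only the *black-box* NN weights are updated

Both sub-models are simulated at every step, so the consistency term still couples them, but only the active optimiser takes a step. This encourages both sub-models to (a) fit the data individually and (b) agree with each other, without one dominating the gradient flow.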

In [None]:
# =====================================================================
# HYCO Model Definition
# =====================================================================

class HYCOModel(nn.Module):
    """
    Hybrid Cooperative (HYCO) model container.

    Holds a physical ODE (``self.phys``, a ``LinearPhysODE``) and a black-box
    neural ODE (``self.bb``, a ``BlackBoxODE``) side by side, with completely
    independent parameters (see ``phys_params`` / ``bb_params``).

    This wrapper defines no ``forward`` of its own: each sub-model is
    integrated separately (``odeint(model.phys, ...)`` and
    ``odeint(model.bb, ...)``), so the training loop can build the composite
    loss from both trajectories.

    ``u_series``, ``t_series`` and ``batch_start_times`` are set on the
    wrapper and copied onto both sub-models by ``_sync_series()``; call it
    again whenever any of them changes. (Presumably the sub-models use these
    for input interpolation; confirm in their definitions.)
    """

    def __init__(self, hidden_dim=128):
        super().__init__()

        # ---------- Physical sub-model (Linear ODE) ----------
        self.phys = LinearPhysODE()

        # ---------- Black-box sub-model ----------
        self.bb = BlackBoxODE(hidden_dim=hidden_dim)

        # Shared time-series pointers (set before each integration, then synced)
        self.u_series = None
        self.t_series = None
        self.batch_start_times = None

    # ---- helpers to propagate series pointers ----
    def _sync_series(self):
        """Copy ``u_series``, ``t_series`` and ``batch_start_times`` onto both sub-models."""
        for sub in (self.phys, self.bb):
            sub.u_series = self.u_series
            sub.t_series = self.t_series
            sub.batch_start_times = self.batch_start_times

    def phys_params(self):
        """Iterator over physical-model parameters only."""
        return self.phys.parameters()

    def bb_params(self):
        """Iterator over black-box-model parameters only."""
        return self.bb.parameters()


def train_hyco(
    hyco_model,
    epochs=1000,
    lr_phys=0.01,
    lr_bb=0.01,
    lambda_bb=1.0,
    lambda_cons=0.5,
    batch_size=128,
    obs_dim=2,
    log_every=50,
):
    """
    Alternating-step training for the HYCO model.

    Each "epoch" is one mini-batch step: ``batch_size`` start indices are drawn
    with replacement from ``valid_train_start_idx``. Both sub-models are
    integrated with RK4 over a ``K_STEPS``-sample horizon, starting from the
    measured state at each start index. ``epochs + 1`` steps are run
    (epoch 0 … ``epochs``).

    Odd epochs  → update physical parameters only
    Even epochs → update black-box NN weights only

    Loss (position component only, MSE over horizon and batch):
        L = L_phys + lambda_bb * L_bb + lambda_cons * L_cons
    where L_cons compares the two sub-models' predicted positions.

    Relies on notebook globals ``u_tensor``, ``t_tensor``, ``y_tensor``,
    ``valid_train_start_idx``, ``K_STEPS`` and ``device``.

    Parameters
    ----------
    hyco_model : HYCOModel
        Model to train; moved to ``device`` in place.
    epochs : int
        Index of the last step (``epochs + 1`` steps run).
    lr_phys, lr_bb : float
        Adam learning rates for the physical / black-box parameters.
    lambda_bb, lambda_cons : float
        Weights of the black-box data loss and of the consistency loss.
    batch_size : int
        Number of trajectory start points per step.
    obs_dim : int
        Currently unused (kept for signature compatibility); the losses always
        use the position column only.
    log_every : int
        Print the losses every ``log_every`` epochs.

    Returns
    -------
    hyco_model : trained model
    history    : dict of per-epoch lists: loss_total, loss_phys, loss_bb,
                 loss_consistency, active_branch ("phys" or "bb")
    """

    hyco_model.to(device)
    hyco_model.u_series = u_tensor
    hyco_model.t_series = t_tensor

    opt_phys = optim.Adam(hyco_model.phys_params(), lr=lr_phys)
    opt_bb   = optim.Adam(hyco_model.bb_params(),   lr=lr_bb)

    # Horizon built from an integer range: a float-step torch.arange(0, K*dt, dt)
    # can return K_STEPS + 1 points through rounding, which would no longer match
    # the K_STEPS-long targets below.
    dt_local = (t_tensor[1] - t_tensor[0]).item()
    t_eval = torch.arange(K_STEPS, device=device) * dt_local
    # Offsets 0..K_STEPS-1 (shape [T, 1]) used to gather all target windows at once.
    horizon_offsets = torch.arange(K_STEPS, device=y_tensor.device).unsqueeze(1)

    history = {
        "loss_total": [],
        "loss_phys": [],
        "loss_bb": [],
        "loss_consistency": [],
        "active_branch": [],       # "phys" or "bb"
    }

    print(f"--- Training HYCO  (λ_bb={lambda_bb}, λ_cons={lambda_cons}) ---")
    for epoch in range(epochs + 1):

        # ---------- choose which branch is active ----------
        update_phys = epoch % 2 == 1
        branch = "phys" if update_phys else "bb"
        active_opt = opt_phys if update_phys else opt_bb

        # Clear both optimisers so no stale gradients linger on either branch.
        opt_phys.zero_grad()
        opt_bb.zero_grad()

        # ---------- sample mini-batch ----------
        start_idx = np.random.choice(valid_train_start_idx,
                                     size=batch_size, replace=True)
        x0 = y_tensor[start_idx]
        hyco_model.batch_start_times = t_tensor[start_idx].reshape(-1, 1)
        hyco_model._sync_series()

        # ---------- forward both sub-models ----------
        # Only the active branch needs an autograd graph. The frozen branch is
        # integrated without one: its gradients would be discarded anyway, and the
        # active branch's gradients (incl. through the consistency term) are unchanged.
        with torch.set_grad_enabled(update_phys):
            pred_phys = odeint(hyco_model.phys, x0, t_eval, method='rk4')  # [T, B, 2]
        with torch.set_grad_enabled(not update_phys):
            pred_bb = odeint(hyco_model.bb, x0, t_eval, method='rk4')  # [T, B, 2]

        # ---------- target: y_tensor[i:i + K_STEPS] for every start i ----------
        start_row = torch.as_tensor(start_idx, device=y_tensor.device).unsqueeze(0)  # [1, B]
        y_target = y_tensor[horizon_offsets + start_row]  # [T, B, 2]

        # position-only data losses (consistent with other models)
        loss_phys = torch.mean((pred_phys[..., 0:1] - y_target[..., 0:1]) ** 2)
        loss_bb   = torch.mean((pred_bb[...,   0:1] - y_target[..., 0:1]) ** 2)

        # consistency term: both sub-models should agree on position
        loss_cons = torch.mean((pred_phys[..., 0:1] - pred_bb[..., 0:1]) ** 2)

        loss_total = loss_phys + lambda_bb * loss_bb + lambda_cons * loss_cons

        # ---------- backward + step (only active branch) ----------
        loss_total.backward()
        active_opt.step()

        # ---------- bookkeeping ----------
        history["loss_total"].append(loss_total.item())
        history["loss_phys"].append(loss_phys.item())
        history["loss_bb"].append(loss_bb.item())
        history["loss_consistency"].append(loss_cons.item())
        history["active_branch"].append(branch)

        if epoch % log_every == 0:
            print(
                f"Epoch {epoch:>5d} [{branch:>4s}] | "
                f"L_total={loss_total.item():.6f}  "
                f"L_phys={loss_phys.item():.6f}  "
                f"L_bb={loss_bb.item():.6f}  "
                f"L_cons={loss_cons.item():.6f}"
            )

    return hyco_model, history


# ---------- instantiate & train ----------
# Fresh HYCO model with a 128-unit black-box, trained for 2000 alternating steps
# (equal learning rates, consistency weight 0.5) and logged every 100 steps.
hyco_model = HYCOModel(hidden_dim=128)
hyco_model, hyco_history = train_hyco(
    hyco_model,
    epochs=2000,
    batch_size=128,
    log_every=100,
    lr_phys=0.01,
    lr_bb=0.01,
    lambda_cons=0.5,
    lambda_bb=1.0,
)

### HYCO Training Diagnostics

Loss curves split by component and by active branch, plus convergence analysis.

In [None]:
# =====================================================================
# HYCO Diagnostic Plots — Training Loss Curves
# =====================================================================

epochs_arr = np.arange(len(hyco_history["loss_total"]))
# Boolean masks marking the epochs in which each branch was being optimised.
phys_mask = np.asarray(hyco_history["active_branch"]) == "phys"
bb_mask   = ~phys_mask

# ── 1) All loss components over epochs ───────────────────────────────
fig, axes = plt.subplots(2, 2, figsize=(14, 9), sharex=True)
panel_specs = [
    # (row, col), history key, line colour, panel title
    ((0, 0), "loss_total", "black", "Total loss  $\\mathcal{L}$"),
    ((0, 1), "loss_phys", "tab:red", "Physical model loss  $\\mathcal{L}_{\\mathrm{phys}}$"),
    ((1, 0), "loss_bb", "tab:blue", "Black-box loss  $\\mathcal{L}_{\\mathrm{bb}}$"),
    ((1, 1), "loss_consistency", "tab:green", "Consistency loss  $\\|x_{\\mathrm{phys}} - x_{\\mathrm{bb}}\\|^2$"),
]
for (r, c), key, color, title in panel_specs:
    ax = axes[r, c]
    ax.semilogy(epochs_arr, hyco_history[key], linewidth=0.8, color=color)
    ax.set_title(title)
    if r == 1:
        ax.set_xlabel("Epoch")  # x-axis is shared, so only the bottom row is labelled
    ax.set_ylabel("Loss (log)")
    ax.grid(True, alpha=0.3)

fig.suptitle("HYCO — Training Loss Components", fontsize=14, fontweight="bold", y=1.01)
fig.tight_layout()
plt.show()

# ── 2) Losses split by active branch ────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
branch_series = (("phys epoch", phys_mask, "tab:red"), ("bb epoch", bb_mask, "tab:blue"))
loss_panels = (
    ("loss_phys", "$\\mathcal{L}_{\\mathrm{phys}}$"),
    ("loss_bb", "$\\mathcal{L}_{\\mathrm{bb}}$"),
    ("loss_consistency", "Consistency"),
)
for ax, (key, label) in zip(axes, loss_panels):
    vals = np.array(hyco_history[key])
    # Each loss is scattered twice, coloured by which branch was active that epoch.
    for branch_label, mask, branch_color in branch_series:
        ax.semilogy(epochs_arr[mask], vals[mask], '.', markersize=2,
                    alpha=0.5, color=branch_color, label=branch_label)
    ax.set_title(f"{label} by active branch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss (log)")
    ax.legend(fontsize=8, markerscale=4)
    ax.grid(True, alpha=0.3)

fig.suptitle("HYCO — Loss per Active Branch", fontsize=13, fontweight="bold", y=1.02)
fig.tight_layout()
plt.show()

# ── 3) Running average of losses (smoothed) ─────────────────────────
def smooth(arr, window=51):
    """Moving average of ``arr`` over ``window`` samples (``np.convolve``, 'valid' mode).

    Output sample ``k`` is the mean of ``arr[k:k + window]``, so the result has
    ``len(arr) - window + 1`` points. Plot it at x offset ``window // 2`` to centre it.

    ``window`` is clamped to ``[1, len(arr)]``. A zero window would give an empty
    kernel (``np.convolve`` raises). A window longer than the data makes 'valid'
    mode swap its operands and return a meaningless result. An empty ``arr``
    yields an empty float array.
    """
    arr = np.asarray(arr, dtype=float)
    if arr.size == 0:
        return arr
    window = int(max(1, min(window, arr.size)))
    kernel = np.ones(window) / window
    return np.convolve(arr, kernel, mode='valid')

fig, ax = plt.subplots(figsize=(12, 5))
# Smoothing window: 51 epochs, but at most a quarter of the run and never below 1
# (len // 4 is 0 for runs shorter than 4 epochs, which would give an empty kernel).
w = max(1, min(51, len(epochs_arr) // 4))
for key, label, color in [
    ("loss_total",       "Total",       "black"),
    ("loss_phys",        "Physical",    "tab:red"),
    ("loss_bb",          "Black-box",   "tab:blue"),
    ("loss_consistency", "Consistency", "tab:green"),
]:
    s = smooth(hyco_history[key], window=w)
    # 'valid' convolution drops w - 1 samples; shift by w // 2 to centre the curve.
    ax.semilogy(np.arange(len(s)) + w // 2, s, linewidth=1.8, color=color, label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Smoothed loss (log)")
ax.set_title(f"HYCO — Running-Average Losses (window={w})")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# ── 4) Ratio  L_bb / L_phys  and  L_cons / L_phys ──────────────────
loss_p = np.array(hyco_history["loss_phys"])
loss_b = np.array(hyco_history["loss_bb"])
loss_c = np.array(hyco_history["loss_consistency"])

# The tiny epsilon avoids division by zero if the physical loss ever reaches 0.
ratio_bb   = loss_b / (loss_p + 1e-15)
ratio_cons = loss_c / (loss_p + 1e-15)

fig, ax = plt.subplots(figsize=(12, 4))
s_bb   = smooth(ratio_bb,   window=w)
s_cons = smooth(ratio_cons, window=w)
ax.plot(np.arange(len(s_bb))   + w // 2, s_bb,   linewidth=1.5, color="tab:blue",  label="$L_{bb}/L_{phys}$")
ax.plot(np.arange(len(s_cons)) + w // 2, s_cons, linewidth=1.5, color="tab:green", label="$L_{cons}/L_{phys}$")
ax.axhline(1.0, color="gray", linestyle="--", linewidth=0.8)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss ratio")
ax.set_title("HYCO — Loss-Component Ratios (smoothed)")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print("Diagnostic plots complete.")

### HYCO — Simulate & Compare with Other Models

Run both HYCO sub-models on the full training signal and form a simple ensemble (HYCO-Ens, the average of the two). Then compare all three against the previously trained models (Linear, Stribeck, Black-box, Hybrid-Joint, Hybrid-Frozen).

In [None]:
# =====================================================================
# HYCO — Simulate on training signal & compare with all other models
# =====================================================================

# --- 1) Simulate HYCO sub-models on the full training time vector ---
# Point the wrapper at the full training series and push them to both sub-models.
hyco_model.u_series, hyco_model.t_series = u_tensor, t_tensor
hyco_model._sync_series()

with torch.no_grad():
    # One trajectory from the first measured state; each sub-model starts at t = 0.
    x0 = y_tensor[0].unsqueeze(0)
    for _sub in (hyco_model.phys, hyco_model.bb):
        _sub.batch_start_times = torch.zeros(1, 1, device=device)

    pred_hyco_phys, pred_hyco_bb = (
        odeint(_sub, x0, t_tensor, method='rk4').squeeze(1).cpu().numpy()
        for _sub in (hyco_model.phys, hyco_model.bb)
    )

# Simple ensemble: average of both sub-model predictions
pred_hyco_ens = 0.5 * (pred_hyco_phys + pred_hyco_bb)

# --- 2) Build augmented model dict (includes HYCO variants) ---------
# One row per model: display name, prediction array, colour, line style.
_model_table = [
    ("Linear",        pred_lin,       "tab:red",    "--"),
    ("Stribeck",      pred_str,       "tab:blue",   "-."),
    ("Black-box",     pred_bb,        "tab:green",  ":"),
    ("Hybrid-Joint",  pred_hjoint,    "tab:orange", "-"),
    ("Hybrid-Frozen", pred_hfrozen,   "tab:purple", (0, (3, 1, 1, 1))),
    ("HYCO-Phys",     pred_hyco_phys, "tab:cyan",   (0, (5, 2))),
    ("HYCO-BB",       pred_hyco_bb,   "darkviolet", (0, (1, 1))),
    ("HYCO-Ens",      pred_hyco_ens,  "gold",       "-"),
]
models_all = {label: pred_arr for label, pred_arr, _, _ in _model_table}
colors_all = {label: col for label, _, col, _ in _model_table}
styles_all = {label: sty for label, _, _, sty in _model_table}

model_names_all = list(colors_all.keys())

# --- 3) Metrics table (training signal) ------------------------------
def _signal_fit_metrics(y_true, y_hat):
    """Return (residual, RMSE, R², FIT%) of one predicted component against its measurement.

    R² and FIT% both normalise by the spread of ``y_true`` around its mean. For a
    constant ``y_true`` that spread is zero, so both are reported as NaN instead of
    dividing by zero (previously only R² was guarded).
    """
    res = y_true - y_hat
    centred = y_true - np.mean(y_true)
    ss_tot = np.sum(centred**2)
    rmse = np.sqrt(np.mean(res**2))
    if ss_tot > 0:
        r2 = 1 - np.sum(res**2) / ss_tot
        fit = 100 * (1 - np.linalg.norm(res) / np.linalg.norm(centred))
    else:
        r2 = fit = np.nan
    return res, rmse, r2, fit


print("=" * 85)
print(f"{'Model':<16} {'RMSE pos':>10} {'R² pos':>10} {'FIT% pos':>10}"
      f" {'RMSE vel':>10} {'R² vel':>10} {'FIT% vel':>10}")
print("-" * 85)

metrics_all_dict = {}
for name, pred in models_all.items():
    # Column 0 = position, column 1 = velocity.
    res_pos, rmse_pos, r2_pos, fit_pos = _signal_fit_metrics(y_sim[:, 0], pred[:, 0])
    res_vel, rmse_vel, r2_vel, fit_vel = _signal_fit_metrics(y_sim[:, 1], pred[:, 1])

    metrics_all_dict[name] = dict(
        rmse_pos=rmse_pos, r2_pos=r2_pos, fit_pos=fit_pos,
        rmse_vel=rmse_vel, r2_vel=r2_vel, fit_vel=fit_vel,
        res_pos=res_pos, res_vel=res_vel,
    )
    print(f"{name:<16} {rmse_pos:10.4f} {r2_pos:10.4f} {fit_pos:10.2f}"
          f" {rmse_vel:10.4f} {r2_vel:10.4f} {fit_vel:10.2f}")
print("=" * 85)

# --- 4) Position & Velocity comparison plot --------------------------
# Row 1 = position (column 0), row 2 = velocity (column 1).
fig, axes = plt.subplots(2, 1, figsize=(14, 8))
for col, (ax, quantity) in enumerate(zip(axes, ("Position", "Velocity"))):
    ax.plot(t, y_sim[:, col], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models_all.items():
        ax.plot(t, pred[:, col], linestyle=styles_all[name],
                color=colors_all[name], linewidth=1.3, label=name)
    if col == 1:
        ax.set_xlabel("Time (s)")
    ax.set_ylabel(quantity)
    ax.legend(fontsize=7, ncol=3)
    ax.grid(True, alpha=0.3)
    ax.set_title(f"{quantity}: Measured vs Predicted (all models incl. HYCO)")

fig.tight_layout()
plt.show()

# --- 5) Zoom plots (position only) -----------------------------------
# Three win_sec-second windows at the start, middle and end of the signal.
win_sec = 5.0
win_n = int(win_sec / Ts)
starts = [0, max(0, (len(t) - win_n) // 2), max(0, len(t) - win_n)]
zoom_labels = ["Start", "Middle", "End"]

fig, axes = plt.subplots(3, 1, figsize=(14, 9))
for i, (ax, s, zoom_label) in enumerate(zip(axes, starts, zoom_labels)):
    window = slice(s, min(len(t), s + win_n))
    ax.plot(t[window], y_sim[window, 0], 'k-', alpha=0.6, linewidth=2, label='Measured')
    for name, pred in models_all.items():
        ax.plot(t[window], pred[window, 0], linestyle=styles_all[name],
                color=colors_all[name], linewidth=1.3, label=name)
    ax.set_title(f"Zoom ({zoom_label})")
    ax.set_ylabel("Position")
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend(fontsize=7, ncol=3)
    if i == 2:
        ax.set_xlabel("Time (s)")
fig.tight_layout()
plt.show()

# --- 6) Residuals (position + velocity) ------------------------------
fig, axes = plt.subplots(2, 1, figsize=(14, 7))
for ax, res_key, quantity in zip(axes, ("res_pos", "res_vel"), ("Position", "Velocity")):
    for name in models_all:
        ax.plot(t, metrics_all_dict[name][res_key],
                color=colors_all[name], linewidth=0.9, label=name)
    ax.axhline(0, color='gray', linewidth=0.8)
    ax.set_title(f"{quantity} Residuals (all models incl. HYCO)")
    ax.set_ylabel("Residual")
    ax.legend(fontsize=7, ncol=3)
    ax.grid(True, alpha=0.3)
axes[1].set_xlabel("Time (s)")

fig.tight_layout()
plt.show()

# --- 7) Raincloud plot (position residuals) ---------------------------
from scipy.stats import gaussian_kde

fig, ax = plt.subplots(figsize=(12, 6))
for i, name in enumerate(model_names_all):
    res = metrics_all_dict[name]["res_pos"]
    res = res[np.isfinite(res)]
    if len(res) < 5:
        continue
    # Cloud: KDE outline scaled to half-height 0.3. gaussian_kde raises on zero-spread
    # data (singular covariance), so only the cloud is skipped for constant residuals.
    if np.ptp(res) > 0:
        kde = gaussian_kde(res)
        xs = np.linspace(np.min(res), np.max(res), 200)
        ys = kde(xs)
        ys = ys / ys.max() * 0.3
        ax.fill_between(xs, i + ys, i - ys, color=colors_all[name], alpha=0.3)
    # Box: interquartile range as a thick bar, median as a short black tick.
    q1, q2, q3 = np.percentile(res, [25, 50, 75])
    ax.plot([q1, q3], [i, i], color=colors_all[name], linewidth=6)
    ax.plot([q2, q2], [i - 0.1, i + 0.1], color='k', linewidth=1)
    # Rain: raw residuals with random vertical jitter.
    jitter = (np.random.rand(len(res)) - 0.5) * 0.2
    ax.scatter(res, i + jitter, s=2, color=colors_all[name], alpha=0.25)

ax.set_yticks(range(len(model_names_all)))
ax.set_yticklabels(model_names_all)
ax.set_xlabel("Residual (Position)")
ax.set_title("Residual Raincloud — Position (all models incl. HYCO)")
ax.axvline(0, color='gray', linewidth=0.8, alpha=0.5)
ax.grid(True, axis='x', alpha=0.3)
fig.tight_layout()
plt.show()

# --- 8) y vs ŷ scatter (position) ------------------------------------
# Predicted vs measured position for every model; points on the dashed
# diagonal are perfect predictions.
fig, ax = plt.subplots(figsize=(7, 7))
measured_pos = y_sim[:, 0]
ymin, ymax = measured_pos.min(), measured_pos.max()
ax.plot([ymin, ymax], [ymin, ymax], 'k--', linewidth=1, label='Ideal')
for name, pred in models_all.items():
    ax.scatter(measured_pos, pred[:, 0], s=4, alpha=0.35,
               color=colors_all[name], label=name)
ax.set(xlabel="Measured y",
       ylabel="Predicted y",
       title="y vs ŷ — Position (all models incl. HYCO)")
ax.legend(fontsize=7, markerscale=3)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# --- 9) HYCO-specific: Phys vs BB agreement plot ---------------------
# Top panel: measured position, both HYCO sub-models, their ensemble and the
# band spanned by Phys and BB. Bottom panel: signed Phys − BB gap over time.
phys_pos = pred_hyco_phys[:, 0]
bb_pos = pred_hyco_bb[:, 0]

fig, axes = plt.subplots(2, 1, figsize=(14, 7))
top, bottom = axes

top.plot(t, y_sim[:, 0], 'k-', alpha=0.5, linewidth=2, label='Measured')
for series, colour, width, tag in (
    (phys_pos, 'tab:cyan', 1.3, 'HYCO-Phys'),
    (bb_pos, 'darkviolet', 1.3, 'HYCO-BB'),
    (pred_hyco_ens[:, 0], 'gold', 1.5, 'HYCO-Ens'),
):
    top.plot(t, series, color=colour, linewidth=width, label=tag)
top.fill_between(t,
                 np.minimum(phys_pos, bb_pos),
                 np.maximum(phys_pos, bb_pos),
                 alpha=0.15, color='gray', label='Phys–BB envelope')
top.set_ylabel("Position")
top.set_title("HYCO Sub-Model Agreement — Position")
top.legend(fontsize=8)
top.grid(True, alpha=0.3)

diff = phys_pos - bb_pos
bottom.plot(t, diff, color='tab:olive', linewidth=0.9)
bottom.axhline(0, color='gray', linewidth=0.8)
bottom.set_xlabel("Time (s)")
bottom.set_ylabel("Phys − BB")
bottom.set_title("HYCO Position Disagreement (Phys − BB)")
bottom.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# --- 10) ACF of HYCO-Ens position residuals vs others ----------------
# Normalised autocorrelation of each model's (mean-removed) position residual
# for lags 0..max_lag_acf, with the ±1.96/sqrt(N) white-noise band.
# The ACF is computed with a zero-padded FFT (Wiener–Khinchin). This gives
# the same values as np.correlate(res, res, 'full') at these lags, but in
# O(N log N) instead of O(N²). Only <= 2001 lags are kept anyway.
N_acf = len(t)
max_lag_acf = min(2000, N_acf - 1)
conf_acf = 1.96 / np.sqrt(N_acf)
# Pad to >= 2N-1 samples so the circular correlation equals the linear one.
nfft_acf = 1 << (2 * N_acf - 1).bit_length()

fig, ax = plt.subplots(figsize=(12, 4))
for name in models_all:
    res = metrics_all_dict[name]["res_pos"]
    res = res - np.mean(res)
    spec = np.fft.rfft(res, n=nfft_acf)
    acf = np.fft.irfft(np.abs(spec) ** 2, n=nfft_acf)[: max_lag_acf + 1]
    # A zero-energy residual has no defined normalised ACF; skip it rather
    # than dividing by zero.
    if not acf[0] > 0:
        continue
    acf = acf / acf[0]
    ax.plot(np.arange(max_lag_acf + 1), acf, color=colors_all[name],
            linewidth=1.0, label=name)
ax.axhline( conf_acf, color='red', linestyle='--', linewidth=0.8)
ax.axhline(-conf_acf, color='red', linestyle='--', linewidth=0.8)
ax.set_title("Residual ACF — Position (all models incl. HYCO)")
ax.set_xlabel("Lag")
ax.set_ylabel("Autocorrelation")
ax.legend(fontsize=7, ncol=3)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# --- 11) Spectrum of position residuals ------------------------------
# Log-magnitude spectra: the measured position (black) as a reference,
# overlaid with each model's position-residual spectrum.
freqs_all = np.fft.rfftfreq(len(t), d=Ts)
Y_meas_fft = np.fft.rfft(y_sim[:, 0])

fig, ax = plt.subplots(figsize=(12, 4))
ax.semilogy(freqs_all, np.abs(Y_meas_fft), color='k', linewidth=1.2, label='Measured')
residual_magnitudes = {
    model_name: np.abs(np.fft.rfft(metrics_all_dict[model_name]["res_pos"]))
    for model_name in models_all
}
for model_name, magnitude in residual_magnitudes.items():
    ax.semilogy(freqs_all, magnitude, color=colors_all[model_name],
                linewidth=0.9, label=model_name)
ax.set_title("Spectrum of Position Residuals (all models incl. HYCO)")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Magnitude")
ax.legend(fontsize=7, ncol=3)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print("HYCO comparison on training signal complete.\n")


# =====================================================================
# PER-DATASET COMPARISON (all models including HYCO)
# =====================================================================

# Sub-model attributes that the simulations below overwrite per dataset.
_HYCO_SERIES_ATTRS = ("u_series", "t_series", "batch_start_times")
_HYCO_ATTR_MISSING = object()


def _hyco_r2(residual, reference):
    """Return R² of *residual* w.r.t. *reference*; NaN if *reference* is constant."""
    ss_tot = np.sum((reference - np.mean(reference)) ** 2)
    if ss_tot > 0:
        return 1 - np.sum(residual ** 2) / ss_tot
    return np.nan


def _hyco_fit_percent(residual, reference):
    """Return FIT% = 100 * (1 - ||residual|| / ||reference - mean(reference)||).

    Returns NaN when *reference* is constant (zero denominator), the same
    convention used for R², instead of dividing by zero.
    """
    denom = np.linalg.norm(reference - np.mean(reference))
    if denom > 0:
        return 100 * (1 - np.linalg.norm(residual) / denom)
    return np.nan


def _hyco_split_metrics(y_true, y_hat):
    """Compute residuals and RMSE / R² / FIT% for one index split.

    Column 0 of *y_true* / *y_hat* is position and column 1 velocity (the
    layout used by ``y_sim`` throughout this notebook). Residuals are
    ``y_true - y_hat``.

    Returns
    -------
    (residuals, metrics)
        ``residuals`` is ``{"pos": ..., "vel": ...}``; ``metrics`` has keys
        rmse_pos, rmse_vel, r2_pos, r2_vel, fit_pos, fit_vel.
    """
    res_pos = y_true[:, 0] - y_hat[:, 0]
    res_vel = y_true[:, 1] - y_hat[:, 1]
    metrics = {
        "rmse_pos": np.sqrt(np.mean(res_pos ** 2)),
        "rmse_vel": np.sqrt(np.mean(res_vel ** 2)),
        "r2_pos": _hyco_r2(res_pos, y_true[:, 0]),
        "r2_vel": _hyco_r2(res_vel, y_true[:, 1]),
        "fit_pos": _hyco_fit_percent(res_pos, y_true[:, 0]),
        "fit_vel": _hyco_fit_percent(res_vel, y_true[:, 1]),
    }
    return {"pos": res_pos, "vel": res_vel}, metrics


def _simulate_hyco_submodel(submodel, x0, t_tensor, u_tensor):
    """Simulate one HYCO sub-model open-loop with RK4 over *t_tensor*.

    The sub-model's input/time series are set to the given tensors first.
    The odeint output (len(t), 1, n_state) is squeezed to a
    (len(t), n_state) numpy array. Call under ``torch.no_grad()``.
    """
    submodel.u_series = u_tensor
    submodel.t_series = t_tensor
    submodel.batch_start_times = torch.zeros(1, 1).to(device)
    return odeint(submodel, x0, t_tensor, method="rk4").squeeze(1).cpu().numpy()


# Augment eval_store with HYCO predictions, residuals and metrics per dataset.
hyco_model.eval()
# The per-dataset simulations overwrite the sub-models' series attributes.
# Remember the current values so the trained model is left as we found it
# instead of silently pointing at the last dataset's input.
_saved_hyco_series = [
    (sub, {attr: getattr(sub, attr, _HYCO_ATTR_MISSING) for attr in _HYCO_SERIES_ATTRS})
    for sub in (hyco_model.phys, hyco_model.bb)
]
try:
    for ds_name in all_datasets:
        base = eval_store[ds_name]
        t_ds = base["t"]
        y_sim_ds = base["y_sim"]

        t_ds_tensor = torch.tensor(t_ds, dtype=torch.float32).to(device)
        u_ds_tensor = torch.tensor(datasets_cache[ds_name]["u"],
                                   dtype=torch.float32).reshape(-1, 1).to(device)
        y_ds_tensor = torch.tensor(y_sim_ds, dtype=torch.float32).to(device)

        # Simulate both HYCO sub-models from the first measured state.
        with torch.no_grad():
            x0_ds = y_ds_tensor[0].unsqueeze(0)
            pred_hp = _simulate_hyco_submodel(hyco_model.phys, x0_ds, t_ds_tensor, u_ds_tensor)
            pred_hb = _simulate_hyco_submodel(hyco_model.bb, x0_ds, t_ds_tensor, u_ds_tensor)
        pred_he = 0.5 * (pred_hp + pred_hb)  # ensemble = mean of Phys and BB

        hyco_preds = {"HYCO-Phys": pred_hp, "HYCO-BB": pred_hb, "HYCO-Ens": pred_he}
        base["preds"].update(hyco_preds)

        # Residuals and metrics per split; splits with < 2 samples stay None.
        for hyco_name, hyco_pred in hyco_preds.items():
            base["metrics"][hyco_name] = {"train": None, "test": None}
            base["residuals"][hyco_name] = {"train": None, "test": None}
            for split_name in ("train", "test"):
                idx = base[f"{split_name}_idx"]
                if idx is None or len(idx) < 2:
                    continue
                split_res, split_metrics = _hyco_split_metrics(y_sim_ds[idx], hyco_pred[idx])
                base["residuals"][hyco_name][split_name] = split_res
                base["metrics"][hyco_name][split_name] = split_metrics
finally:
    for sub, saved in _saved_hyco_series:
        for attr, value in saved.items():
            if value is _HYCO_ATTR_MISSING:
                if hasattr(sub, attr):
                    delattr(sub, attr)
            else:
                setattr(sub, attr, value)

print("eval_store augmented with HYCO predictions for all datasets.\n")

# --- Per-dataset plotting loop ----------------------------------------
# For each dataset and each split name in `splits`, print position and
# velocity metric tables and draw: (A) measured vs predicted, (B) zoom
# windows, (C) residual traces, (D) residual raincloud, (E) y vs ŷ scatter.
win_sec_ds = 5.0
zoom_labs = ["Start", "Middle", "End"]

for plot_dataset_name in all_datasets:
    # Everything fetched here is split-independent, so look it up once.
    base = eval_store[plot_dataset_name]
    t_all = base["t"]
    y_all = base["y_sim"]
    preds_all = base["preds"]
    metrics_ds = base["metrics"]
    residuals_ds = base["residuals"]
    Ts_sel = base["Ts"]
    # Zoom window in samples: round (int() truncation can lose a sample to
    # floating-point error) and keep at least one sample.
    win_n_ds = max(1, int(round(win_sec_ds / Ts_sel)))

    for plot_split in splits:
        idx = base[f"{plot_split}_idx"]
        if idx is None or len(idx) < 2:
            continue

        t_sel = t_all[idx]
        y_sel = y_all[idx]
        models_sel = {name: pred[idx] for name, pred in preds_all.items()}
        residuals_sel = {name: residuals_ds[name][plot_split] for name in preds_all}
        metrics_sel = {name: metrics_ds[name][plot_split] for name in preds_all}

        print(f"\n{'=' * 85}")
        print(f"  Dataset: {plot_dataset_name} | Split: {plot_split} | Samples: {len(idx)}")
        print(f"{'=' * 85}")

        # --- Metrics tables ---
        print(f"\nMetrics — Position ({plot_dataset_name}, {plot_split})")
        print(f"{'Model':<16} {'RMSE':>10} {'R²':>10} {'FIT%':>10}")
        print("-" * 50)
        for name in model_names_all:
            m = metrics_sel.get(name)
            if m is None:
                continue
            print(f"{name:<16} {m['rmse_pos']:10.4f} {m['r2_pos']:10.4f} {m['fit_pos']:10.2f}")

        print(f"\nMetrics — Velocity ({plot_dataset_name}, {plot_split})")
        print(f"{'Model':<16} {'RMSE':>10} {'R²':>10} {'FIT%':>10}")
        print("-" * 50)
        for name in model_names_all:
            m = metrics_sel.get(name)
            if m is None:
                continue
            print(f"{name:<16} {m['rmse_vel']:10.4f} {m['r2_vel']:10.4f} {m['fit_vel']:10.2f}")

        # --- (A) Measured vs Predicted: position & velocity ---
        fig, axes = plt.subplots(2, 1, figsize=(14, 8))
        axes[0].plot(t_sel, y_sel[:, 0], 'k-', alpha=0.6, linewidth=2, label='Measured')
        for name, pred in models_sel.items():
            axes[0].plot(t_sel, pred[:, 0], linestyle=styles_all[name],
                         color=colors_all[name], linewidth=1.3, label=name)
        axes[0].set_ylabel("Position")
        axes[0].legend(fontsize=7, ncol=3)
        axes[0].grid(True, alpha=0.3)
        axes[0].set_title(f"Position: Measured vs Predicted ({plot_dataset_name}, {plot_split})")

        axes[1].plot(t_sel, y_sel[:, 1], 'k-', alpha=0.6, linewidth=2, label='Measured')
        for name, pred in models_sel.items():
            axes[1].plot(t_sel, pred[:, 1], linestyle=styles_all[name],
                         color=colors_all[name], linewidth=1.3, label=name)
        axes[1].set_xlabel("Time (s)")
        axes[1].set_ylabel("Velocity")
        axes[1].legend(fontsize=7, ncol=3)
        axes[1].grid(True, alpha=0.3)
        axes[1].set_title(f"Velocity: Measured vs Predicted ({plot_dataset_name}, {plot_split})")
        fig.tight_layout()
        plt.show()

        # --- (B) Zoom plots (position) ---
        starts_ds = [0, max(0, (len(t_sel) - win_n_ds) // 2), max(0, len(t_sel) - win_n_ds)]

        fig, axes = plt.subplots(3, 1, figsize=(14, 9))
        for i, s in enumerate(starts_ds):
            e = min(len(t_sel), s + win_n_ds)
            if s >= e:
                continue
            ax = axes[i]
            ax.plot(t_sel[s:e], y_sel[s:e, 0], 'k-', alpha=0.6, linewidth=2, label='Measured')
            for name, pred in models_sel.items():
                ax.plot(t_sel[s:e], pred[s:e, 0], linestyle=styles_all[name],
                        color=colors_all[name], linewidth=1.3, label=name)
            ax.set_title(f"Zoom ({zoom_labs[i]}) — {plot_dataset_name} ({plot_split})")
            ax.set_ylabel("Position")
            ax.grid(True, alpha=0.3)
            if i == 0:
                ax.legend(fontsize=7, ncol=3)
            if i == 2:
                ax.set_xlabel("Time (s)")
        fig.tight_layout()
        plt.show()

        # --- (C) Residuals (position + velocity) ---
        fig, axes = plt.subplots(2, 1, figsize=(14, 7))
        for name in model_names_all:
            r = residuals_sel.get(name)
            if r is None or r.get("pos") is None:
                continue
            axes[0].plot(t_sel, r["pos"], color=colors_all[name], linewidth=0.9, label=name)
        axes[0].axhline(0, color='gray', linewidth=0.8)
        axes[0].set_title(f"Position Residuals ({plot_dataset_name}, {plot_split})")
        axes[0].set_ylabel("Residual")
        axes[0].legend(fontsize=7, ncol=3)
        axes[0].grid(True, alpha=0.3)

        for name in model_names_all:
            r = residuals_sel.get(name)
            if r is None or r.get("vel") is None:
                continue
            axes[1].plot(t_sel, r["vel"], color=colors_all[name], linewidth=0.9, label=name)
        axes[1].axhline(0, color='gray', linewidth=0.8)
        axes[1].set_title(f"Velocity Residuals ({plot_dataset_name}, {plot_split})")
        axes[1].set_ylabel("Residual")
        axes[1].set_xlabel("Time (s)")
        axes[1].legend(fontsize=7, ncol=3)
        axes[1].grid(True, alpha=0.3)
        fig.tight_layout()
        plt.show()

        # --- (D) Raincloud plot (position residuals) ---
        # Rows are packed: only models with >= 5 finite residuals get a row.
        fig, ax = plt.subplots(figsize=(12, 6))
        plot_idx = 0
        plotted_names = []
        for name in model_names_all:
            r = residuals_sel.get(name)
            if r is None or r.get("pos") is None:
                continue
            res = r["pos"]
            res = res[np.isfinite(res)]
            if len(res) < 5:
                continue
            # gaussian_kde raises LinAlgError on zero-spread data (singular
            # covariance); only draw the cloud when the residuals vary.
            if np.ptp(res) > 0:
                kde = gaussian_kde(res)
                xs = np.linspace(np.min(res), np.max(res), 200)
                ys_kde = kde(xs)
                ys_kde = ys_kde / ys_kde.max() * 0.3
                ax.fill_between(xs, plot_idx + ys_kde, plot_idx - ys_kde,
                                color=colors_all[name], alpha=0.3)
            q1, q2, q3 = np.percentile(res, [25, 50, 75])
            ax.plot([q1, q3], [plot_idx, plot_idx], color=colors_all[name], linewidth=6)
            ax.plot([q2, q2], [plot_idx - 0.1, plot_idx + 0.1], color='k', linewidth=1)
            jitter = (np.random.rand(len(res)) - 0.5) * 0.2
            ax.scatter(res, plot_idx + jitter, s=2, color=colors_all[name], alpha=0.25)
            plotted_names.append(name)
            plot_idx += 1

        ax.set_yticks(range(len(plotted_names)))
        ax.set_yticklabels(plotted_names)
        ax.set_xlabel("Residual (Position)")
        ax.set_title(f"Residual Raincloud — Position ({plot_dataset_name}, {plot_split})")
        ax.axvline(0, color='gray', linewidth=0.8, alpha=0.5)
        ax.grid(True, axis='x', alpha=0.3)
        fig.tight_layout()
        plt.show()

        # --- (E) y vs ŷ scatter (position) ---
        # y_sel is non-empty here (the split guard requires >= 2 samples).
        fig, ax = plt.subplots(figsize=(7, 7))
        y_min_ds = y_sel[:, 0].min()
        y_max_ds = y_sel[:, 0].max()
        ax.plot([y_min_ds, y_max_ds], [y_min_ds, y_max_ds], 'k--', linewidth=1, label='Ideal')
        for name, pred in models_sel.items():
            ax.scatter(y_sel[:, 0], pred[:, 0], s=4, alpha=0.35,
                       color=colors_all[name], label=name)
        ax.set_xlabel("Measured y")
        ax.set_ylabel("Predicted y")
        ax.set_title(f"y vs ŷ — Position ({plot_dataset_name}, {plot_split})")
        ax.legend(fontsize=7, markerscale=3)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        plt.show()

print("\nPer-dataset comparison (all models incl. HYCO) complete.")

## HYCO Hyperparameter Analysis — λ_bb × λ_cons Grid Search

In [None]:
import itertools
import time as _time

# =====================================================================
# Hyperparameter grid search: lambda_bb x lambda_cons
# =====================================================================

# Training settings shared by every grid point (passed to train_hyco).
HP_EPOCHS = 2000
HP_LR_PHYS = 0.01
HP_LR_BB = 0.01
HP_BATCH = 128
HP_HIDDEN = 128

# Candidate values for the train_hyco `lambda_bb` / `lambda_cons` weights;
# the full Cartesian product (6 x 6 = 36 runs) is evaluated.
lambda_bb_vals   = [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
lambda_cons_vals = [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]

# Split names iterated by the per-dataset plotting loop; the grid search
# itself only scores the "test" split.
splits = ["train", "test"]

def evaluate_hyco_hp(lam_bb, lam_cons, seed=None, split="test"):
    """Train a fresh HYCO model with the given loss weights and score it.

    A new ``HYCOModel(hidden_dim=HP_HIDDEN)`` is trained with ``train_hyco``
    using the fixed HP_* settings (logging silenced). Each sub-model
    (physics, black-box) is then simulated open-loop with RK4 from the first
    measured state of every dataset in ``all_datasets``. The ensemble
    prediction is the mean of the two.

    Parameters
    ----------
    lam_bb, lam_cons : float
        Forwarded to ``train_hyco`` as ``lambda_bb`` / ``lambda_cons``.
    seed : int or None, optional
        If given, ``torch.manual_seed(seed)`` is called before the model is
        built. Every grid point then starts from the same initialisation, so
        differences between cells are not dominated by init noise.
        ``None`` (default) keeps the unseeded behaviour.
    split : str, optional
        Which ``eval_store`` index set (``f"{split}_idx"``) to score on;
        defaults to ``"test"``.

    Returns
    -------
    dict
        ``{dataset_name: {"HYCO-Phys": rmse, "HYCO-BB": rmse, "HYCO-Ens": rmse}}``
        holding the *position* RMSE on the chosen split. The value is NaN
        when that split has fewer than 2 samples.

    NOTE(review): ranking hyperparameters by test RMSE uses the test split
    for model selection; a separate validation split would keep the reported
    test numbers unbiased.
    """
    sub_names = ("HYCO-Phys", "HYCO-BB", "HYCO-Ens")

    if seed is not None:
        torch.manual_seed(seed)
    m = HYCOModel(hidden_dim=HP_HIDDEN)
    m, _ = train_hyco(
        m,
        epochs=HP_EPOCHS,
        lr_phys=HP_LR_PHYS,
        lr_bb=HP_LR_BB,
        lambda_bb=lam_bb,
        lambda_cons=lam_cons,
        batch_size=HP_BATCH,
        log_every=HP_EPOCHS + 1,   # silent
    )
    m.eval()

    per_ds = {}
    for ds_name in all_datasets:
        base = eval_store[ds_name]
        idx = base[f"{split}_idx"]
        if idx is None or len(idx) < 2:
            # Nothing to score: skip the (costly) simulation entirely.
            per_ds[ds_name] = {sub_name: np.nan for sub_name in sub_names}
            continue

        t_ds = base["t"]
        y_sim_ds = base["y_sim"]

        t_ds_t = torch.tensor(t_ds, dtype=torch.float32).to(device)
        u_ds_t = torch.tensor(datasets_cache[ds_name]["u"],
                              dtype=torch.float32).reshape(-1, 1).to(device)
        y_ds_t = torch.tensor(y_sim_ds, dtype=torch.float32).to(device)

        with torch.no_grad():
            x0_ds = y_ds_t[0].unsqueeze(0)

            m.phys.u_series = u_ds_t
            m.phys.t_series = t_ds_t
            m.phys.batch_start_times = torch.zeros(1, 1).to(device)
            p_phys = odeint(m.phys, x0_ds, t_ds_t, method="rk4").squeeze(1).cpu().numpy()

            m.bb.u_series = u_ds_t
            m.bb.t_series = t_ds_t
            m.bb.batch_start_times = torch.zeros(1, 1).to(device)
            p_bb = odeint(m.bb, x0_ds, t_ds_t, method="rk4").squeeze(1).cpu().numpy()

        p_ens = 0.5 * (p_phys + p_bb)

        # Position (column 0) RMSE on the chosen split for each sub-model.
        y_pos = y_sim_ds[idx, 0]
        per_ds[ds_name] = {
            sub_name: np.sqrt(np.mean((y_pos - sub_pred[idx, 0]) ** 2))
            for sub_name, sub_pred in zip(sub_names, (p_phys, p_bb, p_ens))
        }

    return per_ds


# --- Run grid search ---
# Train/evaluate one fresh HYCO model per (λ_bb, λ_cons) pair and collect
# dataset-averaged test RMSEs into `results`.
n_combos = len(lambda_bb_vals) * len(lambda_cons_vals)
print(f"Grid: {len(lambda_bb_vals)} x {len(lambda_cons_vals)} = "
      f"{n_combos} combos, "
      f"{HP_EPOCHS} epochs each\n")

# Scenario 2 excludes the ramp experiments. The filter is identical for every
# grid point, so build it once (excl is also reused by the reports below).
excl = {"rampa_positiva", "rampa_negativa"}
ds_filtered = [d for d in all_datasets if d not in excl]

results = []
t0 = _time.time()

for i, (lam_bb, lam_cons) in enumerate(
        itertools.product(lambda_bb_vals, lambda_cons_vals)):
    print(f"[{i+1:>3d}/{n_combos}] "
          f"λ_bb={lam_bb:.1f}, λ_cons={lam_cons:.1f} ... ", end="", flush=True)
    ts = _time.time()
    per_ds = evaluate_hyco_hp(lam_bb, lam_cons)
    elapsed = _time.time() - ts

    # Average each sub-model's test RMSE across datasets (NaN-aware);
    # the rankings below sort on the ensemble average.
    avg_ens_rmse = np.nanmean([per_ds[d]["HYCO-Ens"] for d in all_datasets])
    avg_phys_rmse = np.nanmean([per_ds[d]["HYCO-Phys"] for d in all_datasets])
    avg_bb_rmse = np.nanmean([per_ds[d]["HYCO-BB"] for d in all_datasets])

    # Scenario 2: same ensemble average without the ramp datasets.
    avg_ens_rmse_no_ramp = np.nanmean([per_ds[d]["HYCO-Ens"] for d in ds_filtered])

    results.append({
        "lam_bb": lam_bb,
        "lam_cons": lam_cons,
        "avg_ens_rmse": avg_ens_rmse,
        "avg_phys_rmse": avg_phys_rmse,
        "avg_bb_rmse": avg_bb_rmse,
        "avg_ens_rmse_no_ramp": avg_ens_rmse_no_ramp,
        "per_ds": per_ds,
    })
    print(f"RMSE(Ens)={avg_ens_rmse:.5f}  (no ramp={avg_ens_rmse_no_ramp:.5f})  [{elapsed:.1f}s]")

total_time = _time.time() - t0
print(f"\nGrid search finished in {total_time/60:.1f} min.\n")


# =====================================================================
# Scenario reports: ranking tables and per-dataset breakdown of the winner
#   Scenario 1 — All datasets
#   Scenario 2 — Without rampa_positiva & rampa_negativa
# =====================================================================

def _print_best_breakdown(best, ds_names, heading):
    """Print per-dataset Phys/BB/Ens RMSE rows of *best* for *ds_names*."""
    print(heading)
    print(f"  {'Dataset':<25} {'HYCO-Phys':>11} {'HYCO-BB':>11} {'HYCO-Ens':>11}")
    print("  " + "-" * 60)
    for ds_key in ds_names:
        row = best["per_ds"][ds_key]
        print(f"  {ds_key:<25} {row['HYCO-Phys']:11.5f} {row['HYCO-BB']:11.5f} {row['HYCO-Ens']:11.5f}")


banner = "=" * 80

# --- Scenario 1: every dataset --------------------------------------
print(banner)
print("  SCENARIO 1: All datasets")
print(banner)
sorted_all = sorted(results, key=lambda entry: entry["avg_ens_rmse"])
print(f"\n{'Rank':<5} {'λ_bb':>6} {'λ_cons':>7} {'RMSE Ens':>10} {'RMSE Phys':>11} {'RMSE BB':>9}")
print("-" * 55)
for position, entry in enumerate(sorted_all[:15], start=1):
    tag = "" if position != 1 else " <-- best"
    print(f"{position:<5} {entry['lam_bb']:6.1f} {entry['lam_cons']:7.1f} "
          f"{entry['avg_ens_rmse']:10.5f} {entry['avg_phys_rmse']:11.5f} "
          f"{entry['avg_bb_rmse']:9.5f}{tag}")

best_all = sorted_all[0]
print(f"\n★ Best (all datasets): λ_bb={best_all['lam_bb']:.1f}, "
      f"λ_cons={best_all['lam_cons']:.1f}  ⇒  RMSE={best_all['avg_ens_rmse']:.5f}")

_print_best_breakdown(best_all, all_datasets,
                      "\n  Per-dataset breakdown (best combo):")

# --- Scenario 2: ramp experiments excluded --------------------------
print("\n" + banner)
print("  SCENARIO 2: Excluding rampa_positiva & rampa_negativa")
print(banner)
sorted_no_ramp = sorted(results, key=lambda entry: entry["avg_ens_rmse_no_ramp"])
print(f"\n{'Rank':<5} {'λ_bb':>6} {'λ_cons':>7} {'RMSE Ens':>10}")
print("-" * 32)
for position, entry in enumerate(sorted_no_ramp[:15], start=1):
    tag = "" if position != 1 else " <-- best"
    print(f"{position:<5} {entry['lam_bb']:6.1f} {entry['lam_cons']:7.1f} "
          f"{entry['avg_ens_rmse_no_ramp']:10.5f}{tag}")

best_nr = sorted_no_ramp[0]
print(f"\n★ Best (excl. ramps): λ_bb={best_nr['lam_bb']:.1f}, "
      f"λ_cons={best_nr['lam_cons']:.1f}  ⇒  RMSE={best_nr['avg_ens_rmse_no_ramp']:.5f}")

ds_filt = [d for d in all_datasets if d not in excl]
_print_best_breakdown(best_nr, ds_filt,
                      "\n  Per-dataset breakdown (best combo, excl. ramps):")


# =====================================================================
# Heatmaps
# =====================================================================
# One heatmap per ranking metric: rows = λ_cons, columns = λ_bb, each cell
# annotated with its averaged Ens RMSE (viridis_r: low RMSE = bright colour).
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

for ax, key, title in [
    (axes[0], "avg_ens_rmse",         "RMSE (Ens) — All datasets"),
    (axes[1], "avg_ens_rmse_no_ramp", "RMSE (Ens) — Excl. ramps"),
]:
    grid = np.full((len(lambda_cons_vals), len(lambda_bb_vals)), np.nan)
    for r in results:
        i = lambda_cons_vals.index(r["lam_cons"])
        j = lambda_bb_vals.index(r["lam_bb"])
        grid[i, j] = r[key]

    im = ax.imshow(grid, origin="lower", aspect="auto", cmap="viridis_r")
    ax.set_xticks(range(len(lambda_bb_vals)))
    ax.set_xticklabels([f"{v:.1f}" for v in lambda_bb_vals])
    ax.set_yticks(range(len(lambda_cons_vals)))
    ax.set_yticklabels([f"{v:.1f}" for v in lambda_cons_vals])
    ax.set_xlabel("λ_bb")
    ax.set_ylabel("λ_cons")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, shrink=0.8, label="RMSE")

    # Annotate values. Text colour follows where a value sits on the colour
    # scale: above the midpoint of [min, max] the viridis_r colour is dark,
    # so white text is readable; otherwise black. A median threshold gives
    # unreadable text when the RMSEs are skewed (e.g. one diverged cell).
    # Computed once per grid rather than per cell.
    finite_vals = grid[np.isfinite(grid)]
    text_threshold = (0.5 * (finite_vals.min() + finite_vals.max())
                      if finite_vals.size else np.nan)
    for ii, jj in np.ndindex(grid.shape):
        val = grid[ii, jj]
        if np.isfinite(val):
            ax.text(jj, ii, f"{val:.4f}", ha="center", va="center",
                    fontsize=7, color="white" if val > text_threshold else "black")

fig.suptitle("HYCO Hyperparameter Grid Search", fontsize=14, fontweight="bold")
fig.tight_layout()
plt.show()


# =====================================================================
# Bar chart: top-5 combos for each scenario
# =====================================================================
# Each scenario carries the result key it was ranked by, so the plotted
# RMSE always matches the ranking (no string comparison on the title).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, sorted_list, scenario, rmse_key in [
    (axes[0], sorted_all,     "All datasets", "avg_ens_rmse"),
    (axes[1], sorted_no_ramp, "Excl. ramps",  "avg_ens_rmse_no_ramp"),
]:
    top5 = sorted_list[:5]
    labels = [f"bb={r['lam_bb']:.1f}\ncons={r['lam_cons']:.1f}" for r in top5]
    rmses = [r[rmse_key] for r in top5]
    # gold / silver / bronze for the podium, then neutral colours.
    bars = ax.bar(labels, rmses, color=["gold", "silver", "#cd7f32", "skyblue", "lightgray"])
    ax.set_ylabel("Avg Test RMSE (Ens)")
    ax.set_title(f"Top-5 HP combos — {scenario}")
    ax.grid(True, axis="y", alpha=0.3)
    for b, v in zip(bars, rmses):
        ax.text(b.get_x() + b.get_width()/2, b.get_height(),
                f"{v:.5f}", ha="center", va="bottom", fontsize=8)

fig.tight_layout()
plt.show()

print("\nHyperparameter analysis complete.")