In [1]:
"""
CKN-style UKN (UKF + learned R_t via Cholesky + SPD enforcing layer on P)
- No linearization needed: accepts arbitrary f_fn, h_fn (ODE/FEA/black-box OK)
- Demo: pendulum, dt=0.2
    x = [theta, omega]
    y = [sin(theta), cos(theta)] + noise
    noise is mixture (outlier bursts) so time-varying R_t helps a lot

Baselines:
- UKF (fixed R_base)
- EKF (needs Jacobians; only for demo comparison)
"""

import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# =========================
# Repro / device
# =========================
def set_seed(seed=0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def wrap_angle(a: torch.Tensor) -> torch.Tensor:
    return (a + math.pi) % (2 * math.pi) - math.pi

def state_error(x_true: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """
    x_true, x_hat: [B,T,2]
    theta error wrapped, omega error raw
    """
    e = x_true - x_hat
    e_theta = wrap_angle(e[..., 0])
    e_omega = e[..., 1]
    return torch.stack([e_theta, e_omega], dim=-1)


# =========================
# CKN-style SPD enforcing
# =========================
def symmetrize(P: torch.Tensor) -> torch.Tensor:
    return 0.5 * (P + P.transpose(-1, -2))

def spd_project_eig(P: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    CKN-style PD enforcing layer:
    - symmetrize
    - eigen-decomp
    - clamp eigenvalues >= eps
    """
    P = symmetrize(P)
    # sanitize
    if not torch.isfinite(P).all():
        P = torch.nan_to_num(P, nan=0.0, posinf=0.0, neginf=0.0)
        P = symmetrize(P)

    evals, evecs = torch.linalg.eigh(P)           # [...,n]
    evals = torch.clamp(evals, min=eps)
    return evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)

def robust_cholesky(P: torch.Tensor, eps: float = 1e-6, tries: int = 6) -> torch.Tensor:
    """
    Robust Cholesky for batch matrices:
    repeatedly apply SPD projection with increasing eps if needed.
    """
    last_err = None
    for k in range(tries):
        P2 = spd_project_eig(P, eps=eps * (10.0 ** k))
        try:
            return torch.linalg.cholesky(P2)
        except RuntimeError as e:
            last_err = e
    raise last_err


# =========================
# Data: pendulum + outlier bursts
# =========================
def sample_pendulum(
    N: int, T: int, dt: float,
    g_over_L: float, damping: float,
    Q_true: torch.Tensor,             # [2,2]
    r_nom: float, r_out: float, p_out: float,
    device: str
):
    """
    x: [N,T,2], y: [N,T,2], out_mask: [N,T,1]
    y = [sin(theta), cos(theta)] + v
    v ~ mixture Gaussian:
        prob(1-p_out): N(0, r_nom^2 I)
        prob(p_out)  : N(0, r_out^2 I)
    """
    n = 2
    LQ = torch.linalg.cholesky(Q_true)

    x = torch.zeros(N, T, n, device=device)
    y = torch.zeros(N, T, 2, device=device)
    out_mask = torch.zeros(N, T, 1, device=device)

    # init: make it move faster than "too slow" case
    x[:, 0, 0] = (torch.rand(N, device=device) * 2 - 1) * math.pi
    x[:, 0, 1] = torch.randn(N, device=device) * 1.0  # omega init

    I2 = torch.eye(2, device=device)

    for t in range(T):
        theta = x[:, t, 0]
        omega = x[:, t, 1]

        # outlier mixture
        is_out = (torch.rand(N, device=device) < p_out).float().view(N, 1)
        out_mask[:, t, :] = is_out

        sigma = r_nom + is_out * (r_out - r_nom)  # [N,1]
        v = torch.randn(N, 2, device=device) * sigma  # isotropic

        y[:, t, 0] = torch.sin(theta) + v[:, 0]
        y[:, t, 1] = torch.cos(theta) + v[:, 1]

        if t < T - 1:
            theta_next = theta + dt * omega
            omega_next = omega + dt * (-g_over_L * torch.sin(theta) - damping * omega)

            w = torch.randn(N, n, device=device) @ LQ.T
            x[:, t + 1, 0] = wrap_angle(theta_next + w[:, 0])
            x[:, t + 1, 1] = omega_next + w[:, 1]

    return x, y, out_mask


class SeqDataset(torch.utils.data.Dataset):
    def __init__(self, x, y, out_mask):
        self.x = x
        self.y = y
        self.out = out_mask
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, i):
        return self.x[i], self.y[i], self.out[i]


# =========================
# Unscented Transform (UT)
# =========================
def sigma_points(x, P, alpha=0.2, beta=2.0, kappa=0.0):
    """
    x: [B,n], P:[B,n,n]
    Xi: [B,2n+1,n], weights Wm/Wc: [2n+1]
    """
    B, n = x.shape
    lam = alpha**2 * (n + kappa) - n
    c = n + lam
    gamma = math.sqrt(max(c, 1e-12))

    Wm = x.new_zeros(2*n + 1)
    Wc = x.new_zeros(2*n + 1)
    Wm[0] = lam / c
    Wc[0] = lam / c + (1 - alpha**2 + beta)
    Wm[1:] = 1.0 / (2*c)
    Wc[1:] = 1.0 / (2*c)

    L = robust_cholesky(P, eps=1e-6)  # [B,n,n]
    U = (gamma * L).transpose(1, 2)   # [B,n,n] columns as offsets

    x0 = x.unsqueeze(1)               # [B,1,n]
    Xi = torch.cat([x0, x0 + U, x0 - U], dim=1)  # [B,2n+1,n]
    return Xi, Wm, Wc

def ut_mean_cov(X, Wm, Wc, noise=None, eps=1e-6):
    """
    X: [B,L,d]
    mean:[B,d], cov:[B,d,d]
    noise can be [d,d] or [B,d,d]
    """
    B, L, d = X.shape
    mean = torch.sum(Wm.view(1, L, 1) * X, dim=1)
    Xm = X - mean.unsqueeze(1)

    cov = torch.zeros(B, d, d, device=X.device, dtype=X.dtype)
    for i in range(L):
        vi = Xm[:, i, :].unsqueeze(-1)
        cov = cov + Wc[i] * (vi @ vi.transpose(-1, -2))

    if noise is not None:
        cov = cov + (noise if noise.dim() == 3 else noise.unsqueeze(0))

    cov = spd_project_eig(cov, eps=eps)
    return mean, cov

def cross_cov(X, Y, xm, ym, Wc):
    """
    X:[B,L,n], Y:[B,L,m] -> Pxy:[B,n,m]
    """
    B, L, n = X.shape
    m = Y.shape[-1]
    Xc = X - xm.unsqueeze(1)
    Yc = Y - ym.unsqueeze(1)

    Pxy = torch.zeros(B, n, m, device=X.device, dtype=X.dtype)
    for i in range(L):
        xi = Xc[:, i, :].unsqueeze(-1)
        yi = Yc[:, i, :].unsqueeze(-1)
        Pxy = Pxy + Wc[i] * (xi @ yi.transpose(-1, -2))
    return Pxy


# =========================
# Baseline UKF (fixed Q/R)
# =========================
@torch.no_grad()
def batch_ukf(y, f_fn, h_fn, Q, R, x0, P0, ut_params=(0.2,2.0,0.0), DIFF_THROUGH_FH=False):
    alpha, beta, kappa = ut_params
    B, T, m = y.shape
    x = x0
    P = spd_project_eig(P0, eps=1e-6)

    xs, Ps = [], []
    for t in range(T):
        Xi, Wm, Wc = sigma_points(x, P, alpha, beta, kappa)

        # f/h may be black-box
        if DIFF_THROUGH_FH:
            X_pred = f_fn(Xi)
            Y_pred = h_fn(X_pred)
        else:
            X_pred = f_fn(Xi)
            Y_pred = h_fn(X_pred)

        x_pred, P_pred = ut_mean_cov(X_pred, Wm, Wc, noise=Q)
        y_pred, S      = ut_mean_cov(Y_pred, Wm, Wc, noise=R)

        Pxy = cross_cov(X_pred, Y_pred, x_pred, y_pred, Wc)

        Ls = robust_cholesky(S, eps=1e-9)
        KT = torch.cholesky_solve(Pxy.transpose(1,2), Ls)  # [B,m,n]
        K  = KT.transpose(1,2)                              # [B,n,m]

        e = y[:, t, :] - y_pred
        x = x_pred + torch.bmm(K, e.unsqueeze(-1)).squeeze(-1)
        x = torch.stack([wrap_angle(x[:,0]), x[:,1]], dim=-1)

        P = P_pred - torch.bmm(torch.bmm(K, S), K.transpose(-1,-2))
        P = spd_project_eig(P, eps=1e-6)

        xs.append(x); Ps.append(P)

    return torch.stack(xs, dim=1), torch.stack(Ps, dim=1)


# =========================
# EKF baseline (demo only, needs Jacobians)
# =========================
@torch.no_grad()
def batch_ekf(y, dt, g_over_L, damping, Q, R, x0, P0):
    """
    y=[sin(theta), cos(theta)]
    H = [[cos(theta), 0],
         [-sin(theta),0]]
    """
    device = y.device
    B, T, m = y.shape
    n = 2
    I = torch.eye(n, device=device, dtype=y.dtype).unsqueeze(0).expand(B, -1, -1)

    x = x0
    P = spd_project_eig(P0, eps=1e-6)

    xs, Ps = [], []
    for t in range(T):
        theta = x[:,0]
        omega = x[:,1]

        # predict
        theta_p = wrap_angle(theta + dt*omega)
        omega_p = omega + dt*(-g_over_L*torch.sin(theta) - damping*omega)
        x_pred = torch.stack([theta_p, omega_p], dim=-1)

        F = torch.zeros(B,2,2,device=device,dtype=y.dtype)
        F[:,0,0] = 1.0
        F[:,0,1] = dt
        F[:,1,0] = -dt*g_over_L*torch.cos(theta)
        F[:,1,1] = 1.0 - dt*damping

        P_pred = F @ P @ F.transpose(-1,-2) + Q.unsqueeze(0).expand(B,-1,-1)
        P_pred = spd_project_eig(P_pred, eps=1e-6)

        # update
        y_pred = torch.stack([torch.sin(theta_p), torch.cos(theta_p)], dim=-1)  # [B,2]
        e = y[:,t,:] - y_pred

        H = torch.zeros(B,2,2,device=device,dtype=y.dtype)
        H[:,0,0] = torch.cos(theta_p)
        H[:,1,0] = -torch.sin(theta_p)
        H[:,0,1] = 0.0
        H[:,1,1] = 0.0

        S = H @ P_pred @ H.transpose(-1,-2) + R.unsqueeze(0).expand(B,-1,-1)
        S = spd_project_eig(S, eps=1e-9)

        Ls = robust_cholesky(S, eps=1e-9)
        # K = P_pred H^T S^{-1}
        KT = torch.cholesky_solve((P_pred @ H.transpose(-1,-2)).transpose(1,2), Ls)  # [B,2,2]
        K  = KT.transpose(1,2)

        x = x_pred + torch.bmm(K, e.unsqueeze(-1)).squeeze(-1)
        x = torch.stack([wrap_angle(x[:,0]), x[:,1]], dim=-1)

        # Joseph form
        KH = torch.bmm(K, H)
        A = I - KH
        P = A @ P_pred @ A.transpose(-1,-2) + torch.bmm(torch.bmm(K, R.unsqueeze(0).expand(B,-1,-1)), K.transpose(-1,-2))
        P = spd_project_eig(P, eps=1e-6)

        xs.append(x); Ps.append(P)

    return torch.stack(xs, dim=1), torch.stack(Ps, dim=1)


# =========================
# UKN: learn R_t via Cholesky (CKN-style PD)
# =========================
class UKNet_RCholesky(nn.Module):
    """
    UKF core + GRU predicts Cholesky factor of R_t (lower-triangular with positive diag).
    Q is kept constant (user gives Q_base).
    """
    def __init__(
        self,
        n: int, m: int,
        R_base: torch.Tensor,            # [m,m] SPD
        hidden_size: int = 64,
        ut_params=(0.2, 2.0, 0.0),
        diag_clip: float = 3.0,          # clip on log-diag delta
        off_max: float = 0.25,           # tanh scaling for off-diag delta
        r_min: float = 1e-6              # minimum diag magnitude
    ):
        super().__init__()
        self.n = n
        self.m = m
        self.alpha, self.beta, self.kappa = ut_params
        self.diag_clip = diag_clip
        self.off_max = off_max
        self.r_min = r_min

        # Base Cholesky of R_base
        with torch.no_grad():
            Lb = torch.linalg.cholesky(R_base)
            logd = torch.log(torch.diagonal(Lb))
        self.register_buffer("L_base", Lb)         # [m,m]
        self.register_buffer("log_diag_base", logd) # [m]

        # Feature dimension: e, de, dy, |e|, diagP, diagS
        # e/de/dy/|e|: 4m, diagP: n, diagS: m  => 5m + n
        in_dim = 5*m + n

        # "FC - GRU - FC" (you asked for FC on both sides)
        self.fc_pre = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.gru = nn.GRUCell(hidden_size, hidden_size)

        # outputs: diag deltas (m) + off-diag deltas (m(m-1)/2)
        k = m + (m*(m-1))//2
        self.fc_post = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, k),
        )

        # start near UKF baseline
        nn.init.zeros_(self.fc_post[-1].weight)
        nn.init.zeros_(self.fc_post[-1].bias)

    def build_LR(self, params: torch.Tensor) -> torch.Tensor:
        """
        params: [B, k], k = m + m(m-1)/2
        return L_R: [B,m,m] lower-tri with positive diagonal
        """
        B = params.shape[0]
        m = self.m

        diag_raw = params[:, :m]                   # [B,m]
        off_raw  = params[:, m:]                   # [B, m(m-1)/2]

        # diagonal: exp(log_diag_base + clipped_delta) + r_min
        d = torch.clamp(diag_raw, -self.diag_clip, self.diag_clip)
        diag = torch.exp(self.log_diag_base.view(1,m) + d) + self.r_min  # [B,m]

        # start from base
        L = self.L_base.view(1,m,m).expand(B,-1,-1).clone()

        # set diag (out-of-place-ish on a fresh clone)
        idx = torch.arange(m, device=L.device)
        L[:, idx, idx] = diag

        # set lower off-diagonals
        k = 0
        for i in range(1, m):
            for j in range(i):
                delta = torch.tanh(off_raw[:, k]) * self.off_max
                L[:, i, j] = self.L_base[i, j] + delta
                k += 1

        return L

    def forward(
        self,
        y: torch.Tensor,              # [B,T,m]
        f_fn, h_fn,
        x0: torch.Tensor,             # [B,n]
        P0: torch.Tensor,             # [B,n,n]
        Q_base: torch.Tensor,         # [n,n] (constant)
        DIFF_THROUGH_FH: bool = False
    ):
        device = y.device
        B, T, m = y.shape
        n = self.n

        x = x0
        P = spd_project_eig(P0, eps=1e-6)

        h = torch.zeros(B, self.gru.hidden_size, device=device, dtype=y.dtype)
        e_prev = torch.zeros(B, m, device=device, dtype=y.dtype)
        y_prev = torch.zeros(B, m, device=device, dtype=y.dtype)

        x_list, P_list, R_list, param_list = [], [], [], []

        # feature covariances (tiny) for stable diagP/diagS features
        Q_feat = torch.eye(n, device=device, dtype=y.dtype).unsqueeze(0).expand(B,-1,-1) * 1e-6
        R_feat = torch.eye(m, device=device, dtype=y.dtype).unsqueeze(0).expand(B,-1,-1) * 1e-6

        for t in range(T):
            Xi, Wm, Wc = sigma_points(x, P, self.alpha, self.beta, self.kappa)

            # f/h evaluation: allow black-box by default (no grad through them)
            if DIFF_THROUGH_FH:
                X_pred = f_fn(Xi)
                Y_pred0 = h_fn(X_pred)
            else:
                with torch.no_grad():
                    X_pred = f_fn(Xi)
                    Y_pred0 = h_fn(X_pred)

            # provisional stats for features
            x_pred0, P_pred0 = ut_mean_cov(X_pred, Wm, Wc, noise=Q_feat)
            y_pred0, S0      = ut_mean_cov(Y_pred0, Wm, Wc, noise=R_feat)

            e0 = y[:, t, :] - y_pred0
            de = e0 - e_prev
            dy = y[:, t, :] - y_prev
            ae = torch.abs(e0)

            diagP = torch.diagonal(P_pred0, dim1=-2, dim2=-1)  # [B,n]
            diagS = torch.diagonal(S0,      dim1=-2, dim2=-1)  # [B,m]

            z = torch.cat([e0, de, dy, ae, diagP.detach(), diagS.detach()], dim=-1)  # [B,5m+n]
            z = self.fc_pre(z)
            h = self.gru(z, h)
            params = self.fc_post(h)  # [B,k]

            Lr = self.build_LR(params)                # [B,m,m]
            R_t = Lr @ Lr.transpose(-1,-2)            # [B,m,m] SPD guaranteed
            R_t = spd_project_eig(R_t, eps=1e-9)

            # full UKF predict/update using Q_base & learned R_t
            x_pred, P_pred = ut_mean_cov(X_pred, Wm, Wc, noise=Q_base)  # Q_base broadcast inside ut_mean_cov

            if DIFF_THROUGH_FH:
                Y_pred = h_fn(X_pred)
            else:
                with torch.no_grad():
                    Y_pred = h_fn(X_pred)

            y_pred, S = ut_mean_cov(Y_pred, Wm, Wc, noise=R_t)

            Pxy = cross_cov(X_pred, Y_pred, x_pred, y_pred, Wc)

            Ls = robust_cholesky(S, eps=1e-9)
            KT = torch.cholesky_solve(Pxy.transpose(1,2), Ls)  # [B,m,n]
            K  = KT.transpose(1,2)                              # [B,n,m]

            e = y[:, t, :] - y_pred
            x = x_pred + torch.bmm(K, e.unsqueeze(-1)).squeeze(-1)
            x = torch.stack([wrap_angle(x[:,0]), x[:,1]], dim=-1)

            P = P_pred - torch.bmm(torch.bmm(K, S), K.transpose(-1,-2))
            P = spd_project_eig(P, eps=1e-6)

            x_list.append(x)
            P_list.append(P)
            R_list.append(R_t)
            param_list.append(params)

            e_prev = e.detach()
            y_prev = y[:, t, :].detach()

        return (
            torch.stack(x_list, dim=1),         # xhat [B,T,n]
            torch.stack(P_list, dim=1),         # P_hist [B,T,n,n]
            torch.stack(R_list, dim=1),         # R_hist [B,T,m,m]
            torch.stack(param_list, dim=1),     # raw params [B,T,k]
        )


# =========================
# Losses + plots
# =========================
def mse_state(xhat, xtrue):
    e = state_error(xtrue, xhat)
    return (e**2).mean()

def nll_from_P(xtrue, xhat, P_hist, eps=1e-9):
    """
    Gaussian NLL (up to constant):
    0.5*(e^T P^{-1} e + logdet(P))
    Can be negative if logdet(P) < 0 and errors small -> that's OK.
    """
    e = state_error(xtrue, xhat)                # [B,T,2]
    B, T, n = e.shape
    e2 = e.reshape(B*T, n, 1)
    P2 = P_hist.reshape(B*T, n, n)
    P2 = spd_project_eig(P2, eps=1e-6)
    L  = robust_cholesky(P2, eps=1e-6)

    sol = torch.cholesky_solve(e2, L)
    maha = (e2.transpose(1,2) @ sol).reshape(B, T)

    diag = torch.diagonal(L, dim1=-2, dim2=-1)
    logdet = 2.0 * torch.sum(torch.log(diag + eps), dim=-1).reshape(B, T)

    return 0.5 * (maha + logdet).mean()

@torch.no_grad()
def plot_error_ci_99(x_true_1, x_hat_1, P_1, title_prefix=""):
    z99 = 2.5758293035489004
    T = x_true_1.shape[0]
    t = np.arange(T)

    err_theta = ((x_true_1[:,0] - x_hat_1[:,0] + np.pi) % (2*np.pi)) - np.pi
    err_omega = (x_true_1[:,1] - x_hat_1[:,1])
    err = np.stack([err_theta, err_omega], axis=1)

    var = np.stack([np.diag(P_1[k]) for k in range(T)], axis=0)
    std = np.sqrt(np.maximum(var, 1e-12))
    band = z99 * std

    names = ["theta", "omega"]
    for i in range(2):
        plt.figure()
        plt.plot(t, err[:,i], label=f"error ({names[i]})")
        plt.fill_between(t, -band[:,i], band[:,i], alpha=0.2, label="99% CI from P")
        plt.axhline(0.0, linewidth=1)
        plt.xlabel("Time step")
        plt.ylabel("Error")
        plt.title(f"{title_prefix} Error + 99% CI ({names[i]})")
        plt.legend()
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show()


# =========================
# Training demo
# =========================
def train_demo():
    set_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    # -------------------------
    # Problem settings (you can tune)
    # -------------------------
    dt = 0.2
    g_over_L = 4.0       # make motion faster (avoid "too slow")
    damping = 0.05

    T = 100              # longer horizon -> easier to separate filters
    N_train, N_test = 6000, 1200

    # True process noise
    Q_true = torch.diag(torch.tensor([2e-4, 8e-4], device=device, dtype=dtype))

    # Outlier mixture (measurement)
    r_nom = 0.05
    r_out = 0.8
    p_out = 0.08

    # Baseline assumed covs (intentionally optimistic for measurement)
    Q_base = 5.0 * Q_true
    R_base = (r_nom**2) * torch.eye(2, device=device, dtype=dtype)  # UKF baseline uses this fixed R

    ut_params = (0.2, 2.0, 0.0)

    print("device =", device)
    print(f"dt={dt}, g/L={g_over_L}, damping={damping}, T={T}")
    print("Q_true diag =", torch.diag(Q_true).detach().cpu().numpy())
    print("Q_base diag =", torch.diag(Q_base).detach().cpu().numpy())
    print("R_base diag =", torch.diag(R_base).detach().cpu().numpy())
    print(f"outliers: p_out={p_out}, r_nom={r_nom}, r_out={r_out}\n")

    # f/h (vectorized over sigma points)
    def f_fn(X):
        th = X[...,0]
        om = X[...,1]
        th2 = wrap_angle(th + dt*om)
        om2 = om + dt*(-g_over_L*torch.sin(th) - damping*om)
        return torch.stack([th2, om2], dim=-1)

    def h_fn(X):
        th = X[...,0]
        return torch.stack([torch.sin(th), torch.cos(th)], dim=-1)

    # data
    x_tr, y_tr, out_tr = sample_pendulum(N_train, T, dt, g_over_L, damping, Q_true, r_nom, r_out, p_out, device)
    x_te, y_te, out_te = sample_pendulum(N_test,  T, dt, g_over_L, damping, Q_true, r_nom, r_out, p_out, device)

    train_loader = DataLoader(SeqDataset(x_tr, y_tr, out_tr), batch_size=128, shuffle=True, num_workers=0)
    test_loader  = DataLoader(SeqDataset(x_te, y_te, out_te), batch_size=256, shuffle=False, num_workers=0)

    # init prior
    x0 = torch.zeros(1,2, device=device, dtype=dtype)
    P0 = torch.diag(torch.tensor([1.0, 1.0], device=device, dtype=dtype)).unsqueeze(0)

    # model
    model = UKNet_RCholesky(
        n=2, m=2,
        R_base=R_base,
        hidden_size=64,
        ut_params=ut_params,
        diag_clip=3.0,
        off_max=0.25,
        r_min=1e-6
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

    # -------------------------
    # Epoch control: 늘리고 싶으면 여기!
    # -------------------------
    num_epochs = 60   # <<<<<<<<<<<<<<<<<<<<<<<< 여기서 epoch 늘리면 됨

    beta_nll = 0.10   # total = (1-beta)*MSE + beta*NLL
    lam_smooth = 1e-3 # R-parameter smoothness penalty (optional but helpful)

    DIFF_THROUGH_FH = False  # black-box/FEA/ODE면 False 추천

    train_hist, test_hist = [], []

    for ep in range(1, num_epochs + 1):
        model.train()
        tot_loss = 0.0
        tot_mse  = 0.0

        for xb, yb, _ in train_loader:
            B = xb.shape[0]
            x0b = x0.expand(B,-1).contiguous()
            P0b = P0.expand(B,-1,-1).contiguous()

            opt.zero_grad()

            xhat, P_hist, _, param_hist = model(yb, f_fn, h_fn, x0b, P0b, Q_base=Q_base, DIFF_THROUGH_FH=DIFF_THROUGH_FH)

            loss_mse = mse_state(xhat, xb)
            loss_nll = nll_from_P(xb, xhat, P_hist)

            # smoothness penalty on params (discourage "Rt jumping")
            dp = param_hist[:,1:,:] - param_hist[:,:-1,:]
            loss_smooth = (dp**2).mean()

            loss = (1.0 - beta_nll)*loss_mse + beta_nll*loss_nll + lam_smooth*loss_smooth

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            tot_loss += loss.item() * B
            tot_mse  += loss_mse.item() * B

        train_loss = tot_loss / len(train_loader.dataset)
        train_mse  = tot_mse  / len(train_loader.dataset)

        # test MSE
        model.eval()
        with torch.no_grad():
            tot = 0.0
            for xb, yb, _ in test_loader:
                B = xb.shape[0]
                x0b = x0.expand(B,-1).contiguous()
                P0b = P0.expand(B,-1,-1).contiguous()
                xhat, _, _, _ = model(yb, f_fn, h_fn, x0b, P0b, Q_base=Q_base, DIFF_THROUGH_FH=DIFF_THROUGH_FH)
                tot += mse_state(xhat, xb).item() * B
            test_mse = tot / len(test_loader.dataset)

        train_hist.append(train_loss)
        test_hist.append(test_mse)

        if ep % 5 == 0 or ep == 1:
            print(f"Epoch {ep:03d} | Train total={train_loss:.6f} (MSE={train_mse:.6f}) | Test MSE={test_mse:.6f}")

    # plot losses
    plt.figure()
    plt.plot(np.arange(1, num_epochs+1), train_hist, label="Train total")
    plt.plot(np.arange(1, num_epochs+1), test_hist, label="Test MSE")
    plt.xlabel("Epoch"); plt.ylabel("Loss/MSE")
    plt.title("Training history (UKN: R_t via Cholesky + SPD layer)")
    plt.legend(); plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()

    # final comparison
    model.eval()
    with torch.no_grad():
        B = y_te.shape[0]
        x0b = x0.expand(B,-1).contiguous()
        P0b = P0.expand(B,-1,-1).contiguous()

        # UKN
        x_ukn, P_ukn, _, _ = model(y_te, f_fn, h_fn, x0b, P0b, Q_base=Q_base, DIFF_THROUGH_FH=DIFF_THROUGH_FH)
        mse_ukn = mse_state(x_ukn, x_te).item()

        # UKF fixed
        x_ukf, P_ukf = batch_ukf(y_te, f_fn, h_fn, Q_base, R_base, x0b, P0b, ut_params=ut_params, DIFF_THROUGH_FH=False)
        mse_ukf = mse_state(x_ukf, x_te).item()

        # EKF fixed (demo)
        x_ekf, P_ekf = batch_ekf(y_te, dt, g_over_L, damping, Q_base, R_base, x0b, P0b)
        mse_ekf = mse_state(x_ekf, x_te).item()

        print("\n===== Final Test MSE =====")
        print(f"UKN : {mse_ukn:.6e}")
        print(f"UKF : {mse_ukf:.6e}")
        print(f"EKF : {mse_ekf:.6e}")

        # sample trajectory
        idx = 0
        t = np.arange(T)

        x_true_1 = x_te[idx].cpu().numpy()
        x_ukn_1  = x_ukn[idx].cpu().numpy()
        x_ukf_1  = x_ukf[idx].cpu().numpy()
        x_ekf_1  = x_ekf[idx].cpu().numpy()

        names = ["theta", "omega"]
        for i in range(2):
            plt.figure()
            plt.plot(t, x_true_1[:,i], label=f"True {names[i]}")
            plt.plot(t, x_ukn_1[:,i],  label=f"UKN  {names[i]}")
            plt.plot(t, x_ukf_1[:,i],  label=f"UKF  {names[i]}")
            plt.plot(t, x_ekf_1[:,i],  label=f"EKF  {names[i]}")
            plt.xlabel("Time step"); plt.ylabel(names[i])
            plt.title(f"Sample trajectory (idx={idx}) - {names[i]}")
            plt.legend(); plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
            plt.show()

        # LAST PLOT: per-state error + 99% CI from P (UKN)
        P_ukn_1 = P_ukn[idx].cpu().numpy()
        plot_error_ci_99(x_true_1, x_ukn_1, P_ukn_1, title_prefix="UKN")

    return model


if __name__ == "__main__":
    train_demo()


device = cuda
dt=0.2, g/L=4.0, damping=0.05, T=100
Q_true diag = [0.0002 0.0008]
Q_base diag = [0.001 0.004]
R_base diag = [0.0025 0.0025]
outliers: p_out=0.08, r_nom=0.05, r_out=0.8

Epoch 001 | Train total=nan (MSE=0.752423) | Test MSE=0.538479


KeyboardInterrupt: 