In [None]:
"""
Nonlinear demo (dt=0.2, y=[sin(theta), cos(theta)]) where
UKN = UKF + learned time-varying R_t (Cholesky-param) and constant Q
compared with UKF (fixed R) and EKF (fixed R).

State: x = [theta, omega]^T
Dynamics:
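  theta_{t+1} = wrap(theta_t + dt * omega_t + w_theta_t)
  omega_{t+1} = omega_t + dt * (-g/L*sin(theta_t) - c*omega_t) + w_omega_t
  with process noise [w_theta_t, w_omega_t] ~ N(0, Q) and theta wrapped to [-pi, pi)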

Measurement:
  y_t = [sin(theta_t), cos(theta_t)] + v_t
Measurement noise is a mixture (outliers):
  with prob (1-p_out): v ~ N(0, r_nom^2 I)
  with prob p_out     : v ~ N(0, r_out^2 I)

UKF/EKF baselines assume fixed R_base ~ r_nom^2 I (optimistic vs outliers).
UKN learns R_t via Cholesky scheme (NOT scalar*R_base).
"""

import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List


# -------------------------
# Repro / device
# -------------------------
def set_seed(seed: int = 0):
    """Seed NumPy and PyTorch (CPU plus all CUDA devices) with the same value."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# -------------------------
# Angle utils
# -------------------------
def wrap_angle(a: torch.Tensor) -> torch.Tensor:
    """Map angles in radians onto the principal interval [-pi, pi)."""
    full_turn = 2 * math.pi
    shifted = a + math.pi
    return torch.remainder(shifted, full_turn) - math.pi


def state_error(x_true: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """
    Estimation error for states laid out as [..., 2] = (theta, omega).

    The theta component of (x_true - x_hat) is passed through wrap_angle;
    the omega component is the plain difference. Output shape is [..., 2].
    """
    diff = x_true - x_hat
    return torch.stack((wrap_angle(diff[..., 0]), diff[..., 1]), dim=-1)


# -------------------------
# SPD / numerics (autograd-safe)
# -------------------------
def symmetrize(P: torch.Tensor) -> torch.Tensor:
    """Return the symmetric part (P + P^T) / 2, transposing the last two dims."""
    P_t = torch.transpose(P, -2, -1)
    return (P + P_t) * 0.5


def ensure_spd_batch(P: torch.Tensor, eps: float = 1e-6, max_shift: float = 1e6) -> torch.Tensor:
    """
    Push each matrix of a [B,n,n] batch towards positive definiteness.

    The input is symmetrized; if any entry is NaN/Inf those entries are
    replaced by 0 and the result is symmetrized again. Each matrix then gets
    shift * I added, where shift = clamp(eps - min_eigenvalue, 0, max_shift).
    The shift is computed under no_grad, so it acts as a constant for autograd.
    """
    P = symmetrize(P)
    if not torch.isfinite(P).all():
        P = symmetrize(torch.nan_to_num(P, nan=0.0, posinf=0.0, neginf=0.0))

    batch, dim, _ = P.shape
    eye = torch.eye(dim, device=P.device, dtype=P.dtype).unsqueeze(0).expand(batch, -1, -1)

    with torch.no_grad():
        smallest = torch.linalg.eigvalsh(P).amin(dim=-1)              # [B]
        lift = torch.clamp(eps - smallest, min=0.0, max=max_shift)    # [B]

    return P + lift.view(batch, 1, 1) * eye


def robust_cholesky(P: torch.Tensor, eps: float = 1e-6, tries: int = 7) -> torch.Tensor:
    """
    Batched Cholesky factorization with escalating regularization.

    Attempt k (k = 0..tries-1) regularizes P via ensure_spd_batch with
    eps * 10**k and tries torch.linalg.cholesky. The first success is
    returned; if every attempt raises RuntimeError, the last one is re-raised.
    No tensor is modified in place.
    """
    failure = None
    for attempt in range(tries):
        jitter = eps * (10.0 ** attempt)
        candidate = symmetrize(ensure_spd_batch(P, eps=jitter))
        try:
            factor = torch.linalg.cholesky(candidate)
        except RuntimeError as exc:
            failure = exc
            continue
        return factor
    raise failure


def l2_normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Divide x by (its L2 norm along the last dim + eps); eps avoids 0/0."""
    length = x.norm(p=2, dim=-1, keepdim=True)
    denom = length + eps
    return x / denom


# -------------------------
# Nonlinear data generation (pendulum + outliers) with y=[sin,cos]
# -------------------------
def sample_pendulum_sequences_sincos(
    num_seq: int,
    T: int,
    dt: float,
    g_over_L: float,
    damping: float,
    Q_true: torch.Tensor,  # [2,2]
    r_nom: float,
    r_out: float,
    p_out: float,
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Simulate damped-pendulum trajectories with sin/cos measurements whose
    noise occasionally switches to a larger "outlier" standard deviation.

    Returns:
      x: [N,T,2]         states (theta, omega); theta is wrapped each step
      y: [N,T,2]         [sin(theta), cos(theta)] + noise
      out_mask: [N,T,1]  1.0 where the outlier std was used, else 0.0
    """
    state_dim, meas_dim = 2, 2

    # Cholesky factor of Q_true, used to colour the standard-normal process noise.
    chol_Q = torch.linalg.cholesky(Q_true)

    x = torch.zeros(num_seq, T, state_dim, device=device)
    y = torch.zeros(num_seq, T, meas_dim, device=device)
    out_mask = torch.zeros(num_seq, T, 1, device=device)

    # Initial state chosen so the pendulum moves a fair amount:
    # theta uniform in [-pi, pi], omega standard normal.
    x[:, 0, 0] = (torch.rand(num_seq, device=device) * 2.0 - 1.0) * math.pi
    x[:, 0, 1] = torch.randn(num_seq, device=device) * 1.0

    last_step = T - 1
    for t in range(T):
        theta, omega = x[:, t, 0], x[:, t, 1]

        # Mixture noise: one Bernoulli(p_out) draw per sequence picks the std
        # shared by both measurement channels at this step.
        outlier = (torch.rand(num_seq, device=device) < p_out).float().view(num_seq, 1)
        noise_std = r_nom + outlier * (r_out - r_nom)  # [N,1]
        meas_noise = torch.randn(num_seq, meas_dim, device=device) * noise_std  # [N,2]
        out_mask[:, t, :] = outlier

        y[:, t, 0] = torch.sin(theta) + meas_noise[:, 0]
        y[:, t, 1] = torch.cos(theta) + meas_noise[:, 1]

        if t == last_step:
            continue  # no successor state to generate

        theta_det = theta + dt * omega
        omega_det = omega + dt * (-g_over_L * torch.sin(theta) - damping * omega)

        proc_noise = torch.randn(num_seq, state_dim, device=device) @ chol_Q.T
        x[:, t + 1, 0] = wrap_angle(theta_det + proc_noise[:, 0])
        x[:, t + 1, 1] = omega_det + proc_noise[:, 1]

    return x, y, out_mask


class SeqDataset(torch.utils.data.Dataset):
    """Dataset over pre-generated sequences; item i is (x[i], y[i], out_mask[i])."""

    def __init__(self, x: torch.Tensor, y: torch.Tensor, out_mask: torch.Tensor):
        # The tensors are stored as given and indexed along dim 0.
        self.x, self.y, self.out_mask = x, y, out_mask

    def __len__(self) -> int:
        return self.x.shape[0]

    def __getitem__(self, idx: int):
        return tuple(t[idx] for t in (self.x, self.y, self.out_mask))


# -------------------------
# Unscented Transform
# -------------------------
def sigma_points(x, P, alpha=0.2, beta=2.0, kappa=0.0):
    """
    Scaled unscented-transform sigma points and weights.

    x: [B,n] means, P: [B,n,n] covariances.
    Returns Xi: [B,2n+1,n] (mean, then mean + columns of gamma*chol(P),
    then mean - those columns) and the mean/covariance weights Wm, Wc: [2n+1].
    """
    _, n = x.shape
    lam = alpha**2 * (n + kappa) - n
    c = n + lam
    gamma = math.sqrt(max(c, 1e-12))  # floor guards against a non-positive c

    # All non-central points share the weight 1/(2c); only index 0 differs.
    num_pts = 2 * n + 1
    Wm = x.new_zeros(num_pts)
    Wc = x.new_zeros(num_pts)
    Wm[0] = lam / c
    Wc[0] = lam / c + (1 - alpha**2 + beta)
    side_weight = 1.0 / (2 * c)
    Wm[1:] = side_weight
    Wc[1:] = side_weight

    chol = robust_cholesky(ensure_spd_batch(P, eps=1e-6), eps=1e-6)  # [B,n,n]

    # After the transpose, row i of `offsets` is column i of gamma * chol.
    offsets = (gamma * chol).transpose(1, 2)  # [B,n,n]
    centre = x.unsqueeze(1)                   # [B,1,n]
    Xi = torch.cat([centre, centre + offsets, centre - offsets], dim=1)  # [B,2n+1,n]
    return Xi, Wm, Wc


def unscented_mean_cov(X, Wm, Wc, noise=None, eps=1e-6):
    """
    Weighted mean and covariance of a batch of sigma points.

    X: [B,L,d] points, Wm/Wc: [L] mean/covariance weights.
    noise: optional additive covariance, either [d,d] (shared) or [B,d,d].
    Returns mean: [B,d] and cov: [B,d,d]; cov is symmetrized and passed
    through ensure_spd_batch with the given eps.
    """
    B, L, d = X.shape
    mean = (Wm.view(1, L, 1) * X).sum(dim=1)

    centred = X - mean.unsqueeze(1)
    cov = torch.zeros(B, d, d, device=X.device, dtype=X.dtype)
    for idx in range(L):
        col = centred[:, idx, :].unsqueeze(-1)            # [B,d,1]
        cov = cov + Wc[idx] * (col @ col.transpose(-1, -2))

    if noise is not None:
        # A 2-D noise matrix is broadcast over the batch.
        cov = cov + (noise.unsqueeze(0) if noise.dim() == 2 else noise)

    cov = ensure_spd_batch(symmetrize(cov), eps=eps)
    return mean, cov


def cross_cov(X, Y, x_mean, y_mean, Wc):
    """
    Weighted cross-covariance between two sets of sigma points.

    X: [B,L,n], Y: [B,L,m] sigma points; x_mean: [B,n], y_mean: [B,m] their
    means; Wc: [L] covariance weights.
    Returns Pxy: [B,n,m] = sum_l Wc[l] * (X_l - x_mean) (Y_l - y_mean)^T.
    """
    Xc = X - x_mean.unsqueeze(1)
    Yc = Y - y_mean.unsqueeze(1)
    # One batched contraction over the sigma-point axis replaces the
    # per-point Python loop of outer products.
    return torch.einsum("l,bln,blm->bnm", Wc, Xc, Yc)


# -------------------------
# Baseline UKF (fixed Q, fixed R)
# -------------------------
@torch.no_grad()
def batch_ukf_filter(
    y: torch.Tensor,          # [B,T,m]
    f_fn,
    h_fn,
    Q: torch.Tensor,          # [2,2]
    R: torch.Tensor,          # [m,m]
    x0: torch.Tensor,         # [B,2]
    P0: torch.Tensor,         # [B,2,2]
    ut_params=(0.2, 2.0, 0.0),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Batched unscented Kalman filter with constant Q and R.

    f_fn / h_fn map sigma points [B,L,n] to propagated states / predicted
    measurements. The measurement statistics reuse the propagated sigma
    points (no re-draw after the time update). After each update the first
    state component is wrapped with wrap_angle.

    Returns the filtered means [B,T,2] and covariances [B,T,2,2].
    """
    alpha, beta, kappa = ut_params
    _, num_steps, _ = y.shape

    mean = x0
    cov = ensure_spd_batch(P0, eps=1e-6)

    means, covs = [], []
    for k in range(num_steps):
        cov = ensure_spd_batch(cov, eps=1e-6)
        pts, Wm, Wc = sigma_points(mean, cov, alpha, beta, kappa)

        # Time update.
        pts_prop = f_fn(pts)
        mean_prior, cov_prior = unscented_mean_cov(pts_prop, Wm, Wc, noise=Q)

        # Predicted measurement and innovation covariance.
        meas_pts = h_fn(pts_prop)
        meas_mean, S = unscented_mean_cov(meas_pts, Wm, Wc, noise=R)
        Pxy = cross_cov(pts_prop, meas_pts, mean_prior, meas_mean, Wc)

        # Gain K = Pxy S^{-1}, obtained by solving S K^T = Pxy^T via Cholesky.
        S = ensure_spd_batch(S, eps=1e-9)
        chol_S = robust_cholesky(S, eps=1e-9)
        gain = torch.cholesky_solve(Pxy.transpose(1, 2), chol_S).transpose(1, 2)  # [B,n,m]

        innovation = y[:, k, :] - meas_mean
        mean_post = mean_prior + torch.bmm(gain, innovation.unsqueeze(-1)).squeeze(-1)
        mean = torch.stack([wrap_angle(mean_post[:, 0]), mean_post[:, 1]], dim=-1)

        cov_post = cov_prior - torch.bmm(torch.bmm(gain, S), gain.transpose(-1, -2))
        cov = ensure_spd_batch(symmetrize(cov_post), eps=1e-6)

        means.append(mean)
        covs.append(cov)

    return torch.stack(means, dim=1), torch.stack(covs, dim=1)


# -------------------------
# Baseline EKF (fixed Q, fixed R) for y=[sin,cos]
# -------------------------
@torch.no_grad()
def batch_ekf_filter_sincos(
    y: torch.Tensor,          # [B,T,2]
    dt: float,
    g_over_L: float,
    damping: float,
    Q: torch.Tensor,          # [2,2]
    R: torch.Tensor,          # [2,2]
    x0: torch.Tensor,         # [B,2]
    P0: torch.Tensor,         # [B,2,2]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Batched extended Kalman filter for the pendulum with measurement
    h(x) = [sin(theta), cos(theta)] and constant Q, R.

    The covariance update uses the Joseph form
    P = (I - K H) P_pred (I - K H)^T + K R K^T.
    Returns the filtered means [B,T,2] and covariances [B,T,2,2].
    """
    dev, dtp = y.device, y.dtype
    batch, num_steps, _ = y.shape
    eye = torch.eye(2, device=dev, dtype=dtp).unsqueeze(0).expand(batch, -1, -1)

    mean = x0
    cov = ensure_spd_batch(P0, eps=1e-6)

    means, covs = [], []

    R_b = R.unsqueeze(0).expand(batch, -1, -1)  # [B,2,2]
    Q_b = Q.unsqueeze(0).expand(batch, -1, -1)  # [B,2,2]

    for k in range(num_steps):
        th, om = mean[:, 0], mean[:, 1]

        # ---- Time update ----
        th_prior = wrap_angle(th + dt * om)
        om_prior = om + dt * (-g_over_L * torch.sin(th) - damping * om)
        mean_prior = torch.stack([th_prior, om_prior], dim=-1)

        # Dynamics Jacobian, evaluated at the previous posterior mean.
        jac_f = torch.zeros(batch, 2, 2, device=dev, dtype=dtp)
        jac_f[:, 0, 0] = 1.0
        jac_f[:, 0, 1] = dt
        jac_f[:, 1, 0] = -dt * g_over_L * torch.cos(th)
        jac_f[:, 1, 1] = 1.0 - dt * damping

        cov_prior = ensure_spd_batch(jac_f @ cov @ jac_f.transpose(-1, -2) + Q_b, eps=1e-6)

        # ---- Measurement update ----
        sin_p, cos_p = torch.sin(th_prior), torch.cos(th_prior)
        innovation = y[:, k, :] - torch.stack([sin_p, cos_p], dim=-1)  # [B,2]

        # Measurement Jacobian at the predicted mean:
        #   d sin(theta)/d theta =  cos(theta)
        #   d cos(theta)/d theta = -sin(theta)
        # The omega column stays zero because h does not depend on omega.
        jac_h = torch.zeros(batch, 2, 2, device=dev, dtype=dtp)
        jac_h[:, 0, 0] = cos_p
        jac_h[:, 1, 0] = -sin_p

        S = ensure_spd_batch(jac_h @ cov_prior @ jac_h.transpose(-1, -2) + R_b, eps=1e-9)
        chol_S = robust_cholesky(S, eps=1e-9)

        # K^T = S^{-1} (H P_pred), solved through the Cholesky factor of S.
        gain = torch.cholesky_solve(jac_h @ cov_prior, chol_S).transpose(1, 2)  # [B,2,2]

        mean_post = mean_prior + torch.bmm(gain, innovation.unsqueeze(-1)).squeeze(-1)
        mean = torch.stack([wrap_angle(mean_post[:, 0]), mean_post[:, 1]], dim=-1)

        # Joseph-form covariance update.
        resid = eye - torch.bmm(gain, jac_h)
        cov_post = resid @ cov_prior @ resid.transpose(-1, -2) + torch.bmm(
            torch.bmm(gain, R_b), gain.transpose(-1, -2)
        )
        cov = ensure_spd_batch(symmetrize(cov_post), eps=1e-6)

        means.append(mean)
        covs.append(cov)

    return torch.stack(means, dim=1), torch.stack(covs, dim=1)


# -------------------------
# Cholesky parameterization helper for R_t
# -------------------------
def build_L_from_params(p: torch.Tensor, m: int, eps_diag: float) -> torch.Tensor:
    """
    Assemble a batch of lower-triangular factors from flat parameter vectors.

    p: [B, pdim] with pdim = m + m(m-1)/2. The first m entries feed the
    diagonal through softplus(.) + eps_diag (hence strictly positive); the
    remaining entries fill the strict lower triangle row by row:
    (1,0), (2,0), (2,1), ...
    Returns L: [B,m,m].
    """
    batch = p.shape[0]
    chol = p.new_zeros(batch, m, m)

    # Positive diagonal.
    pos_diag = torch.nn.functional.softplus(p[:, :m]) + eps_diag
    for d in range(m):
        chol[:, d, d] = pos_diag[:, d]

    # Unconstrained strict-lower entries, in row-major order.
    strict_lower = p[:, m:]
    positions = [(row, col) for row in range(1, m) for col in range(row)]
    for k, (row, col) in enumerate(positions):
        chol[:, row, col] = strict_lower[:, k]

    return chol


# -------------------------
# UKNet: learn time-varying R_t via Cholesky; Q is constant
# (GRU with pre/post FC layers)
# -------------------------
class UKNet_RCholesky(nn.Module):
    """
    UKF whose measurement-noise covariance R_t is produced at every step by a
    small recurrent network (pre-MLP -> GRUCell -> post-MLP -> linear head).

    The head outputs a bounded increment to a parameter vector p; p is mapped
    to a lower-triangular factor L by build_L_from_params and R_t = L L^T.
    The process-noise covariance is a constant passed to forward().
    """

    def __init__(
        self,
        n: int,
        m: int,
        hidden_size: int = 64,
        emb_size: int = 64,
        ut_params=(0.2, 2.0, 0.0),
        clip: float = 2.5,
        delta_max: float = 0.15,
        eps_diag: float = 1e-4,
        r_min: float = 1e-6,
        dropout: float = 0.0,
    ):
        """
        n, m: state / measurement dimensions.
        hidden_size, emb_size: GRU hidden width and pre-MLP embedding width.
        ut_params: (alpha, beta, kappa) for the unscented transform.
        clip: p is clamped to [-clip, clip] after every increment.
        delta_max: bound on each per-step increment of p (tanh * delta_max).
        eps_diag: additive floor for the diagonal of L.
        r_min: eps passed to ensure_spd_batch when regularizing R_t.
        dropout: dropout probability inside the pre/post MLPs.
        """
        super().__init__()
        self.n = n
        self.m = m
        self.alpha, self.beta, self.kappa = ut_params

        self.clip = clip
        self.delta_max = delta_max
        self.eps_diag = eps_diag
        self.r_min = r_min

        # Number of free parameters of an m x m lower-triangular factor.
        self.pdim = m + (m * (m - 1)) // 2

        # features: e(m), de(m), dy(m), |e|(m), diagP(n), diagS(m) -> 5m + n
        in_dim = 5 * m + n

        # Pre-MLP (GRU input)
        self.pre = nn.Sequential(
            nn.Linear(in_dim, emb_size),
            nn.LayerNorm(emb_size),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(emb_size, emb_size),
            nn.SiLU(),
        )

        self.gru = nn.GRUCell(emb_size, hidden_size)

        # Post-MLP (GRU output)
        self.post = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
        )

        self.fc_dp = nn.Linear(hidden_size, self.pdim)

        # Zero-initialised head: every increment dp is 0 at the start of
        # training, so p stays at its initial zeros until the head is updated.
        nn.init.zeros_(self.fc_dp.weight)
        nn.init.zeros_(self.fc_dp.bias)

    def forward(self, y, f_fn, h_fn, x0, P0, Q_const: torch.Tensor):
        """
        y: [B,T,m], x0:[B,n], P0:[B,n,n], Q_const:[n,n]
        f_fn / h_fn map sigma points [B,L,n] to propagated states / predicted
        measurements.
        returns xhat:[B,T,n], P_hist:[B,T,n,n], p_hist:[B,T,pdim], R_hist:[B,T,m,m]
        """
        device = y.device
        B, T, m = y.shape
        n = self.n

        x = x0
        P = ensure_spd_batch(P0, eps=1e-6)

        h = torch.zeros(B, self.gru.hidden_size, device=device, dtype=y.dtype)

        e_prev = torch.zeros(B, m, device=device, dtype=y.dtype)
        y_prev = torch.zeros(B, m, device=device, dtype=y.dtype)

        # p is the Cholesky-parameter vector for R
        p = torch.zeros(B, self.pdim, device=device, dtype=y.dtype)

        x_list, P_list, p_list, R_list = [], [], [], []

        # tiny covs for feature preview stability
        Q_feat = torch.eye(n, device=device, dtype=y.dtype) * 1e-6
        R_feat = torch.eye(m, device=device, dtype=y.dtype) * 1e-6

        Qb = Q_const.unsqueeze(0).expand(B, -1, -1)  # [B,n,n]

        for t in range(T):
            P = ensure_spd_batch(P, eps=1e-6)

            Xi, Wm, Wc = sigma_points(x, P, self.alpha, self.beta, self.kappa)
            X_pred = f_fn(Xi)
            # Predicted-measurement sigma points are needed both for the
            # feature preview and for the actual update; evaluate h_fn once.
            Y_pred = h_fn(X_pred)

            # Preview statistics (tiny additive noise) used only as features.
            _, P_pred0 = unscented_mean_cov(X_pred, Wm, Wc, noise=Q_feat)
            y_pred0, S0 = unscented_mean_cov(Y_pred, Wm, Wc, noise=R_feat)

            e0 = y[:, t, :] - y_pred0
            de = e0 - e_prev
            dy = y[:, t, :] - y_prev
            ae = torch.abs(e0)

            diagP = torch.diagonal(P_pred0, dim1=-2, dim2=-1)
            diagS = torch.diagonal(S0, dim1=-2, dim2=-1)

            # NOTE(review): each feature group is L2-normalized separately, so
            # the network sees directions but not magnitudes (e.g. |e0|).
            z = torch.cat([
                l2_normalize(e0),
                l2_normalize(de),
                l2_normalize(dy),
                l2_normalize(ae),
                l2_normalize(diagP.detach()),
                l2_normalize(diagS.detach()),
            ], dim=-1)

            z_emb = self.pre(z)
            h = self.gru(z_emb, h)
            h2 = self.post(h)

            # Bounded increment, then clamp: p evolves slowly and stays in
            # [-clip, clip].
            dp = torch.tanh(self.fc_dp(h2)) * self.delta_max
            p = torch.clamp(p + dp, -self.clip, self.clip)

            L = build_L_from_params(p, m=self.m, eps_diag=self.eps_diag)
            R_t = torch.bmm(L, L.transpose(-1, -2))
            R_t = ensure_spd_batch(R_t, eps=self.r_min)  # enforce min eigenvalue

            # UKF predict/update with Q_const and R_t
            x_pred, P_pred = unscented_mean_cov(X_pred, Wm, Wc, noise=Qb)
            y_pred, S = unscented_mean_cov(Y_pred, Wm, Wc, noise=R_t)

            Pxy = cross_cov(X_pred, Y_pred, x_pred, y_pred, Wc)

            # Gain K = Pxy S^{-1}, obtained by solving S K^T = Pxy^T.
            S = ensure_spd_batch(S, eps=1e-9)
            Ls = robust_cholesky(S, eps=1e-9)
            KT = torch.cholesky_solve(Pxy.transpose(1, 2), Ls)  # [B,m,n]
            K = KT.transpose(1, 2)  # [B,n,m]

            innov = y[:, t, :] - y_pred
            x = x_pred + torch.bmm(K, innov.unsqueeze(-1)).squeeze(-1)
            x = torch.stack([wrap_angle(x[:, 0]), x[:, 1]], dim=-1)

            P = P_pred - torch.bmm(torch.bmm(K, S), K.transpose(-1, -2))
            P = ensure_spd_batch(symmetrize(P), eps=1e-6)

            x_list.append(x)
            P_list.append(P)
            p_list.append(p)
            R_list.append(R_t)

            # Detached copies: the next step's difference features do not
            # back-propagate through this step's innovation.
            e_prev = innov.detach()
            y_prev = y[:, t, :].detach()

        return (
            torch.stack(x_list, dim=1),
            torch.stack(P_list, dim=1),
            torch.stack(p_list, dim=1),
            torch.stack(R_list, dim=1),
        )


# -------------------------
# Loss: state NLL from P (theta wrapped)
# -------------------------
def state_nll_from_P(e: torch.Tensor, P: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """
    Mean Gaussian negative log-likelihood (up to an additive constant) of the
    errors e under covariances P.

    e: [B,T,n], P: [B,T,n,n]
    Per element: 0.5 * (e^T P^{-1} e + logdet(P)); the scalar mean over all
    (B,T) entries is returned. eps is added inside the log of the Cholesky
    diagonal.
    """
    B, T, n = e.shape
    err_col = e.reshape(B * T, n, 1)

    cov = ensure_spd_batch(P.reshape(B * T, n, n), eps=1e-6)
    chol = robust_cholesky(cov, eps=1e-6)

    # Mahalanobis term e^T P^{-1} e via a Cholesky solve.
    weighted = torch.cholesky_solve(err_col, chol)                         # [B*T,n,1]
    maha = torch.bmm(err_col.transpose(1, 2), weighted).reshape(B, T)      # [B,T]

    # logdet(P) = 2 * sum(log(diag(chol))).
    log_diag = torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + eps)
    logdet = (2.0 * torch.sum(log_diag, dim=-1)).reshape(B, T)

    return 0.5 * (maha + logdet).mean()


def rmse_over_time(xhat: torch.Tensor, x_true: torch.Tensor) -> torch.Tensor:
    """RMSE per time step, averaging squared state_error over dims 0 and 2 of [N,T,2]."""
    sq_err = state_error(x_true, xhat).pow(2)  # [N,T,2]
    # The 1e-12 offset is added before the square root.
    return (sq_err.mean(dim=(0, 2)) + 1e-12).sqrt()


def mse_state(xhat: torch.Tensor, x_true: torch.Tensor) -> torch.Tensor:
    """Scalar mean of the squared state_error (theta error wrapped) over all entries."""
    return state_error(x_true, xhat).pow(2).mean()


# -------------------------
# Plots + coverage
# -------------------------
def plot_loss_history(train_hist: List[float], test_hist: List[float], title: str):
    """Plot the per-epoch training loss and test MSE curves in one new figure."""
    epoch_axis = np.arange(1, len(train_hist) + 1)
    plt.figure()
    for series, label in ((train_hist, "Train total"), (test_hist, "Test MSE")):
        plt.plot(epoch_axis, series, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss / MSE")
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()


@torch.no_grad()
def plot_error_with_ci_99(x_true_1, x_hat_1, P_1, title_prefix=""):
    """
    For one trajectory, draw the estimation error of each state together with
    a symmetric band of +/- z99 * sqrt(diag(P)), one figure per state.

    x_true_1: [T,2], x_hat_1: [T,2], P_1: [T,2,2] (indexed as NumPy arrays).
    """
    z99 = 2.5758293035489004
    num_steps = x_true_1.shape[0]
    steps = np.arange(num_steps)

    # Theta error is wrapped; omega error is the plain difference.
    theta_err = ((x_true_1[:, 0] - x_hat_1[:, 0] + np.pi) % (2*np.pi)) - np.pi
    omega_err = (x_true_1[:, 1] - x_hat_1[:, 1])
    errors = np.stack([theta_err, omega_err], axis=1)

    variances = np.stack([np.diag(P_1[k]) for k in range(num_steps)], axis=0)  # [T,2]
    half_width = z99 * np.sqrt(np.maximum(variances, 1e-12))

    for col, label in enumerate(["theta", "omega"]):
        plt.figure()
        plt.plot(steps, errors[:, col], label=f"error ({label})")
        plt.fill_between(steps, -half_width[:, col], half_width[:, col], alpha=0.2, label="99% CI from P")
        plt.axhline(0.0, linewidth=1)
        plt.xlabel("Time step")
        plt.ylabel("Error")
        plt.title(f"{title_prefix} Error + 99% CI ({label})")
        plt.legend()
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show()


@torch.no_grad()
def coverage_99(x_true: torch.Tensor, x_hat: torch.Tensor, P: torch.Tensor) -> np.ndarray:
    """
    Fraction of (sequence, time) pairs whose absolute state error lies within
    z99 * sqrt(diag(P)), reported separately per state component.

    Returns a NumPy array of length 2 (theta, omega).
    """
    z99 = 2.5758293035489004
    abs_err = state_error(x_true, x_hat).abs()                               # [N,T,2]
    sigma = torch.diagonal(P, dim1=-2, dim2=-1).clamp(min=1e-12).sqrt()      # [N,T,2]
    within = (abs_err <= z99 * sigma).float()
    return within.mean(dim=(0, 1)).cpu().numpy()


# -------------------------
# Train + compare (dt=0.2, y=[sin,cos])
# -------------------------
def train_pendulum_demo_dt02_y_sincos():
    """
    End-to-end demo: generate pendulum data with outlier-contaminated sin/cos
    measurements, train the UKN (learned R_t), then compare it on the test set
    with a fixed-R UKF and a fixed-R EKF (MSE, RMSE over time, 99% coverage)
    and show diagnostic plots.

    Returns the trained UKNet_RCholesky model.
    """
    set_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    # --- Nonlinear problem params ---
    dt = 0.2
    g_over_L = 1.5
    damping = 0.03

    # UT params
    ut_params = (0.25, 2.0, 0.0)

    # True process noise
    Q_true = torch.diag(torch.tensor([5e-4, 2e-3], device=device, dtype=dtype))

    # Measurement outlier mixture
    r_nom = 0.08
    r_out = 1.00
    p_out = 0.06

    # Baseline assumed covariances (fixed)
    gamma_Q = 8.0
    Q_base = gamma_Q * Q_true
    R_base = torch.eye(2, device=device, dtype=dtype) * (r_nom ** 2)

    print("=== Adopted Nonlinear Problem (dt=0.2): Pendulum + Outlier Measurements + y=[sin,cos] ===")
    print(f"device={device}")
    print(f"dt={dt}, g/L={g_over_L}, damping={damping}")
    print("Q_true diag =", torch.diag(Q_true).detach().cpu().numpy())
    print("Q_base diag =", torch.diag(Q_base).detach().cpu().numpy(), f"(gamma_Q={gamma_Q})")
    print("R_base diag =", torch.diag(R_base).detach().cpu().numpy(), "(~ r_nom^2 I)")
    print(f"Measurement noise: nominal std={r_nom}, outlier std={r_out}, p_out={p_out}")
    print()

    # f/h for sigma points
    def f_fn(X):
        # X:[B,L,2]
        theta = X[..., 0]
        omega = X[..., 1]
        # Do NOT wrap theta per sigma point. The unscented mean/covariance are
        # plain weighted sums, so if some points were wrapped across +/-pi and
        # others were not, the mean and covariance would be corrupted (the
        # centre weight is large and negative for alpha=0.25). The filters
        # wrap the posterior theta after each update, and h_fn is
        # 2*pi-periodic, so propagating the unwrapped angle is safe.
        theta_next = theta + dt * omega
        omega_next = omega + dt * (-g_over_L * torch.sin(theta) - damping * omega)
        return torch.stack([theta_next, omega_next], dim=-1)

    def h_fn(X):
        theta = X[..., 0]
        return torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1)  # [B,L,2]

    # --- Data ---
    T = 80
    N_train, N_test = 7000, 1400

    x_train, y_train, out_train = sample_pendulum_sequences_sincos(
        N_train, T, dt, g_over_L, damping, Q_true, r_nom, r_out, p_out, device=device
    )
    x_test, y_test, out_test = sample_pendulum_sequences_sincos(
        N_test, T, dt, g_over_L, damping, Q_true, r_nom, r_out, p_out, device=device
    )

    train_loader = DataLoader(SeqDataset(x_train, y_train, out_train), batch_size=128, shuffle=True, num_workers=0)
    test_loader = DataLoader(SeqDataset(x_test, y_test, out_test), batch_size=256, shuffle=False, num_workers=0)

    # --- Init x0, P0 (shared by all filters; expanded per batch below) ---
    x0 = torch.zeros(1, 2, device=device, dtype=dtype)
    P0 = torch.diag(torch.tensor([1.0, 1.0], device=device, dtype=dtype)).unsqueeze(0)

    # --- Model (m=2) ---
    model = UKNet_RCholesky(
        n=2, m=2,
        hidden_size=64,
        emb_size=64,
        ut_params=ut_params,
        clip=2.5,
        delta_max=0.10,    # 0.05~0.10 recommended to keep R_t from jumping
        eps_diag=1e-4,
        r_min=1e-6,
        dropout=0.05,
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

    # losses
    beta_cov = 0.10         # MSE vs NLL mixing
    lam_dp = 5e-4           # dp penalty (discourages R_t from changing abruptly)
    num_epochs = 40         # change here to train for more epochs

    train_hist_total: List[float] = []
    test_hist_mse: List[float] = []

    # --------- epoch loop starts ----------
    for ep in range(1, num_epochs + 1):
        model.train()
        tot_loss = 0.0
        tot_mse = 0.0

        for xb, yb, _ in train_loader:
            B = xb.shape[0]
            x0b = x0.expand(B, -1).contiguous()
            P0b = P0.expand(B, -1, -1).contiguous()

            opt.zero_grad()

            xhat, P_hist, p_hist, _ = model(yb, f_fn, h_fn, x0b, P0b, Q_const=Q_base)

            loss_state = mse_state(xhat, xb)
            e = state_error(xb, xhat)
            loss_cov = state_nll_from_P(e, P_hist)

            # dp penalty: p_hist[t]-p_hist[t-1]
            dp = p_hist[:, 1:, :] - p_hist[:, :-1, :]
            loss_dp = (dp ** 2).mean()

            loss = (1.0 - beta_cov) * loss_state + beta_cov * loss_cov + lam_dp * loss_dp

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # Weight by batch size so the epoch average is per-sequence.
            tot_loss += loss.item() * B
            tot_mse += loss_state.item() * B

        train_total = tot_loss / len(train_loader.dataset)
        train_mse = tot_mse / len(train_loader.dataset)

        # test MSE
        model.eval()
        with torch.no_grad():
            tot = 0.0
            for xb, yb, _ in test_loader:
                B = xb.shape[0]
                x0b = x0.expand(B, -1).contiguous()
                P0b = P0.expand(B, -1, -1).contiguous()
                xhat, _, _, _ = model(yb, f_fn, h_fn, x0b, P0b, Q_const=Q_base)
                tot += mse_state(xhat, xb).item() * B
            test_mse = tot / len(test_loader.dataset)

        train_hist_total.append(train_total)
        test_hist_mse.append(test_mse)

        print(f"Epoch {ep:02d} | Train total: {train_total:.6f} (MSE={train_mse:.6f}) | Test MSE: {test_mse:.6f}")
    # --------- epoch loop ends ----------

    plot_loss_history(train_hist_total, test_hist_mse, title="Pendulum(dt=0.2, sincos): UKN(R-Cholesky) - Loss History")

    # --- Final comparison ---
    model.eval()
    with torch.no_grad():
        B = y_test.shape[0]
        x0b = x0.expand(B, -1).contiguous()
        P0b = P0.expand(B, -1, -1).contiguous()

        # UKN
        xhat_ukn, P_ukn, p_hist, R_hist = model(y_test, f_fn, h_fn, x0b, P0b, Q_const=Q_base)
        ukn_mse = mse_state(xhat_ukn, x_test).item()

        # UKF fixed
        xhat_ukf, P_ukf = batch_ukf_filter(
            y=y_test, f_fn=f_fn, h_fn=h_fn, Q=Q_base, R=R_base, x0=x0b, P0=P0b, ut_params=ut_params
        )
        ukf_mse = mse_state(xhat_ukf, x_test).item()

        # EKF fixed
        xhat_ekf, P_ekf = batch_ekf_filter_sincos(
            y=y_test, dt=dt, g_over_L=g_over_L, damping=damping,
            Q=Q_base, R=R_base, x0=x0b, P0=P0b
        )
        ekf_mse = mse_state(xhat_ekf, x_test).item()

        print("\n===== Final (Test Set) =====")
        print(f"UKN (R-Chol) MSE: {ukn_mse:.6e}")
        print(f"UKF (fixed)  MSE: {ukf_mse:.6e}")
        print(f"EKF (fixed)  MSE: {ekf_mse:.6e}")

        # RMSE(t)
        rmse_ukn = rmse_over_time(xhat_ukn, x_test).cpu().numpy()
        rmse_ukf = rmse_over_time(xhat_ukf, x_test).cpu().numpy()
        rmse_ekf = rmse_over_time(xhat_ekf, x_test).cpu().numpy()

        t = np.arange(len(rmse_ukn))
        plt.figure()
        plt.plot(t, rmse_ukn, label="UKN RMSE(t)")
        plt.plot(t, rmse_ukf, label="UKF RMSE(t)")
        plt.plot(t, rmse_ekf, label="EKF RMSE(t)")
        plt.xlabel("Time step")
        plt.ylabel("RMSE")
        plt.title("RMSE over time (test avg)")
        plt.legend()
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show()

        # 99% coverage
        cov_ukn = coverage_99(x_test, xhat_ukn, P_ukn)
        cov_ukf = coverage_99(x_test, xhat_ukf, P_ukf)
        cov_ekf = coverage_99(x_test, xhat_ekf, P_ekf)
        print("\n99% CI Coverage (fraction inside band)")
        print(f"UKN: theta={cov_ukn[0]:.3f}, omega={cov_ukn[1]:.3f}")
        print(f"UKF: theta={cov_ukf[0]:.3f}, omega={cov_ukf[1]:.3f}")
        print(f"EKF: theta={cov_ekf[0]:.3f}, omega={cov_ekf[1]:.3f}")

        # Sample trajectories + learned R diag
        sample_idx = 0
        x_true_1 = x_test[sample_idx].cpu().numpy()
        x_ukn_1 = xhat_ukn[sample_idx].cpu().numpy()
        x_ukf_1 = xhat_ukf[sample_idx].cpu().numpy()
        x_ekf_1 = xhat_ekf[sample_idx].cpu().numpy()

        names = ["theta", "omega"]
        for i in range(2):
            plt.figure()
            plt.plot(t, x_true_1[:, i], label=f"True {names[i]}")
            plt.plot(t, x_ukn_1[:, i], label=f"UKN  {names[i]}")
            plt.plot(t, x_ukf_1[:, i], label=f"UKF  {names[i]}")
            plt.plot(t, x_ekf_1[:, i], label=f"EKF  {names[i]}")
            plt.xlabel("Time step")
            plt.ylabel(names[i])
            plt.title(f"Sample trajectory (idx={sample_idx}) - {names[i]}")
            plt.legend()
            plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
            plt.show()

        # outlier mask vs learned R diag
        out_1 = out_test[sample_idx].cpu().numpy().squeeze(-1)  # [T]
        Rdiag_1 = torch.diagonal(R_hist[sample_idx], dim1=-2, dim2=-1).cpu().numpy()  # [T,2]
        plt.figure()
        plt.plot(t, Rdiag_1[:, 0], label="R_t[0,0]")
        plt.plot(t, Rdiag_1[:, 1], label="R_t[1,1]")
        # The 0/1 mask is scaled to the largest R diagonal so both are visible.
        plt.plot(t, out_1 * max(Rdiag_1.max(), 1e-6), label="outlier mask (scaled)")
        plt.xlabel("Time step")
        plt.title("Sample: learned R diag vs outlier mask")
        plt.legend()
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show()

        # LAST PLOT: error + 99% CI band from P (UKN)
        P_ukn_1 = P_ukn[sample_idx].cpu().numpy()
        plot_error_with_ci_99(x_true_1, x_ukn_1, P_ukn_1, title_prefix="UKN")

    return model


# Script entry point: run the full train / evaluate / plot demo when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    train_pendulum_demo_dt02_y_sincos()


=== Adopted Nonlinear Problem (dt=0.2): Pendulum + Outlier Measurements + y=[sin,cos] ===
device=cpu
dt=0.2, g/L=1.5, damping=0.03
Q_true diag = [0.0005 0.002 ]
Q_base diag = [0.004 0.016] (gamma_Q=8.0)
R_base diag = [0.0064 0.0064] (~ r_nom^2 I)
Measurement noise: nominal std=0.08, outlier std=1.0, p_out=0.06

Epoch 01 | Train total: 1.738306 (MSE=1.060322) | Test MSE: 1.104195
Epoch 02 | Train total: 1.754052 (MSE=1.067480) | Test MSE: 1.126536
Epoch 03 | Train total: 1.829426 (MSE=1.123686) | Test MSE: 1.221302
Epoch 04 | Train total: 1.888822 (MSE=1.164145) | Test MSE: 1.253752
Epoch 05 | Train total: 1.861185 (MSE=1.193488) | Test MSE: 1.186606
Epoch 06 | Train total: 1.695645 (MSE=1.333348) | Test MSE: 1.458809
Epoch 07 | Train total: 1.726726 (MSE=1.449187) | Test MSE: 1.444406
Epoch 08 | Train total: 1.705366 (MSE=1.377525) | Test MSE: 1.391967
Epoch 09 | Train total: 1.659776 (MSE=1.271194) | Test MSE: 1.182407
Epoch 10 | Train total: 1.731441 (MSE=1.184647) | Test MSE: 1.1729