In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim


In [2]:
# ==========================================
# 0. Reproducibility (seeding) and device (CUDA) setup
# ==========================================
def set_seed(seed=42):
    """Seed NumPy's global RNG and PyTorch's CPU and CUDA generators with `seed`."""
    # Order: NumPy, then the torch CPU generator, then every CUDA device.
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

set_seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# (Optional) float32 matmul precision; guarded because the setter may not
# exist on the installed torch.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# DataLoaders below request pinned host memory only when running on CUDA.
pin_mem = device.type == "cuda"

Using device: cuda


In [3]:
# ==========================================
# 1. Linear System Dataset (generated on the CPU)
# ==========================================
class LinearSystemDataset:
    """
    Synthetic linear-Gaussian state-space model.

        x_{t+1} = F x_t + w_t,  w ~ N(0,Q)
        y_t     = H x_t + v_t,  v ~ N(0,R)

    q_scale and r_scale are treated as standard deviations, i.e.
    Q = (q_scale^2) I and R = (r_scale^2) I.

    F is a fixed 2x2 rotation by pi/20, so the state dimension must be 2.
    H is the n x m "identity" (torch.eye(n, m)).

    Args:
        T: number of time steps per trajectory.
        N: number of trajectories.
        m: state dimension (must be 2).
        n: observation dimension.
        q_scale: process-noise standard deviation.
        r_scale: observation-noise standard deviation.

    Raises:
        ValueError: if m != 2.
    """
    def __init__(self, T=50, N=600, m=2, n=2, q_scale=0.1, r_scale=1.0):
        # F below is hard-coded as a 2x2 matrix; fail early with a clear message
        # instead of an opaque matmul shape error inside generate_data().
        if m != 2:
            raise ValueError(
                f"LinearSystemDataset uses a fixed 2x2 rotation for F, so m must be 2 (got m={m})."
            )

        self.T = T
        self.N = N
        self.m = m
        self.n = n

        theta = np.pi / 20  # rotation angle
        self.F = torch.tensor([[np.cos(theta), -np.sin(theta)],
                               [np.sin(theta),  np.cos(theta)]], dtype=torch.float32)
        self.H = torch.eye(n, m, dtype=torch.float32)

        self.Q = (q_scale**2) * torch.eye(m, dtype=torch.float32)
        self.R = (r_scale**2) * torch.eye(n, dtype=torch.float32)

    def generate_data(self):
        """
        Simulate N trajectories of length T.

        The initial state x_0 ~ N(0, I) is drawn but not stored: X[:, t] holds the
        state after t+1 transitions and Y[:, t] is the noisy observation of that
        same state.

        Returns:
            X: [N, T, m] float32 states.
            Y: [N, T, n] float32 observations.
        """
        X = torch.zeros(self.N, self.T, self.m, dtype=torch.float32)
        Y = torch.zeros(self.N, self.T, self.n, dtype=torch.float32)

        x_curr = torch.randn(self.N, self.m, dtype=torch.float32)

        # Cholesky factors turn standard-normal draws into N(0,Q) / N(0,R) samples.
        LQ = torch.linalg.cholesky(self.Q)
        LR = torch.linalg.cholesky(self.R)

        for t in range(self.T):
            w = torch.randn(self.N, self.m) @ LQ.T
            x_next = (self.F @ x_curr.T).T + w

            v = torch.randn(self.N, self.n) @ LR.T
            y_curr = (self.H @ x_next.T).T + v

            X[:, t, :] = x_next
            Y[:, t, :] = y_curr
            x_curr = x_next

        return X, Y

In [None]:
# ==========================================
# 2. UCKN Model (Unscented + Cholesky PDEL)
# ==========================================
def safe_cholesky(P, jitter=1e-6):
    """
    Lower Cholesky factor of the symmetrised input plus a diagonal jitter.

    The jitter is always added (no branching or try/except), which keeps the
    function autograd-friendly.

    Args:
        P: [..., m, m] matrix or batch of matrices.
        jitter: scalar multiplied by the identity and added before factorising.

    Returns:
        The factor of 0.5 * (P + P^T) + jitter * I, same shape as P.
    """
    symmetric = 0.5 * (P + P.transpose(-1, -2))
    ridge = jitter * torch.eye(P.shape[-1], device=P.device, dtype=P.dtype)
    return torch.linalg.cholesky(symmetric + ridge)

class UnscentedCholeskyKalmanNet(nn.Module):
    """
    Recurrent Kalman-style filter with a learned gain.

    - Prior prediction via the unscented transform (sigma points).
    - A GRU-based network predicts the Kalman gain K_t.
    - CKN-style PDEL head: the posterior covariance is produced from a Cholesky
      parameterisation (P = L L^T with a strictly positive diagonal in L), which
      keeps P_post symmetric positive definite.

    Args:
        m: state dimension.
        n: observation dimension.
        F: [m, m] state-transition matrix.
        H: [n, m] observation matrix.
        Q: [m, m] process-noise covariance (added to the UT prior covariance).
        R: observation-noise covariance; stored as the buffer ``R_sys`` but not
            read by any method of this class.
        alpha, beta, kappa: scaled unscented-transform parameters.
        hidden_dim: width of the FC/GRU layers; defaults to ``m * 10``.
        pdel_clip: raw PDEL outputs are clamped to [-pdel_clip, pdel_clip].
        diag_min: constant added to the softplus-ed diagonal of L.

    Raises:
        ValueError: if m + lambda = alpha^2 * (m + kappa) is not positive.
    """
    def __init__(self, m, n, F, H, Q, R,
                 alpha=1, beta=0.0, kappa=0.0,
                 hidden_dim=None,
                 pdel_clip=5.0,
                 diag_min=1e-4):
        super().__init__()
        self.m = m
        self.n = n

        # system matrices (buffers so .to(device) moves them)
        self.register_buffer("F_sys", F.clone())
        self.register_buffer("H_sys", H.clone())
        self.register_buffer("Q_sys", Q.clone())
        self.register_buffer("R_sys", R.clone())

        # UT params
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.kappa = float(kappa)

        lam = self.alpha**2 * (m + self.kappa) - m
        c = m + lam
        if c <= 0:
            # c == alpha^2 * (m + kappa): shrinking alpha can never make it positive.
            raise ValueError(
                f"Unscented params invalid: m+lambda must be > 0, got {c}. "
                f"m+lambda = alpha^2 * (m + kappa), so use a non-zero alpha and kappa > -m."
            )
        self.lam = lam
        self.c = c
        self.scale = math.sqrt(c)  # sigma-point spread: sqrt(m + lambda)

        # weights (register as buffers)
        Wm = torch.full((2*m + 1,), 1.0/(2.0*c), dtype=torch.float32)
        Wc = torch.full((2*m + 1,), 1.0/(2.0*c), dtype=torch.float32)
        Wm[0] = lam / c
        Wc[0] = lam / c + (1.0 - self.alpha**2 + self.beta)
        self.register_buffer("Wm", Wm)
        self.register_buffer("Wc", Wc)

        # features: innovation(n) + state_diff(m)
        feat_dim = n + m

        if hidden_dim is None:
            hidden_dim = m * 10

        # FC layers on both sides of the GRU: pre_fc -> GRUCell -> post_fc
        self.pre_fc = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
        )
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)
        self.post_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Kalman-gain head: m*n outputs reshaped to [m, n] in forward().
        self.fc_kg = nn.Linear(hidden_dim, m * n)

        # PDEL head: one output per lower-triangular entry of L.
        self.pdel_dim = m * (m + 1) // 2
        self.fc_pdel = nn.Linear(hidden_dim, self.pdel_dim)

        # stability knobs
        self.pdel_clip = float(pdel_clip)
        self.diag_min = float(diag_min)

        # Zero-initialise both heads (for stability): the initial gain is exactly 0
        # and the initial L is diagonal with entries softplus(0) + diag_min.
        nn.init.zeros_(self.fc_kg.weight); nn.init.zeros_(self.fc_kg.bias)
        nn.init.zeros_(self.fc_pdel.weight); nn.init.zeros_(self.fc_pdel.bias)

    def pdel_layer(self, out_vec):
        """
        Map raw head outputs to a covariance matrix.

        out_vec: [B, pdel_dim] -> P = L L^T, L lower-triangular with
        softplus(.) + diag_min on the diagonal. Returns P: [B, m, m].
        """
        B = out_vec.shape[0]
        m = self.m
        out_vec = torch.clamp(out_vec, -self.pdel_clip, self.pdel_clip)

        # Scatter the flat vector into the lower triangle of L.
        L = torch.zeros(B, m, m, device=out_vec.device, dtype=out_vec.dtype)
        idx = torch.tril_indices(row=m, col=m, offset=0, device=out_vec.device)
        L[:, idx[0], idx[1]] = out_vec

        # Force a strictly positive diagonal.
        d = torch.arange(m, device=out_vec.device)
        L[:, d, d] = torch.nn.functional.softplus(L[:, d, d]) + self.diag_min

        P = L @ L.transpose(1, 2)
        P = 0.5 * (P + P.transpose(1, 2))
        return P

    def sigma_points(self, x, P):
        """
        Build the 2m+1 sigma points around x.

        x:[B,m], P:[B,m,m] -> Xi:[B,2m+1,m]
        """
        L = safe_cholesky(P, jitter=1e-6)          # [B,m,m]
        U = self.scale * L                         # [B,m,m]
        Ucols = U.transpose(1, 2)                  # [B,m,m] row i holds column i of U

        x0 = x.unsqueeze(1)                        # [B,1,m]
        Xi_plus = x0 + Ucols                       # [B,m,m]
        Xi_minus = x0 - Ucols                      # [B,m,m]
        Xi = torch.cat([x0, Xi_plus, Xi_minus], dim=1)  # [B,2m+1,m]
        return Xi

    def unscented_predict(self, x_post, P_post):
        """
        Prior from the unscented transform.

        X_pred = f(Xi); here f is linear: f(x) = F x.
        P_prior is the Wc-weighted sigma-point covariance plus Q.

        Returns (x_prior [B,m], P_prior [B,m,m], X_pred [B,2m+1,m]).
        """
        Xi = self.sigma_points(x_post, P_post)         # [B,L,m]
        X_pred = Xi @ self.F_sys.T                     # [B,L,m]

        Wm = self.Wm.view(1, -1, 1)                    # [1,L,1]
        Wc = self.Wc.view(1, -1, 1)                    # [1,L,1]

        x_prior = torch.sum(Wm * X_pred, dim=1)        # [B,m]
        diff = X_pred - x_prior.unsqueeze(1)           # [B,L,m]
        P_prior = diff.transpose(1, 2) @ (diff * Wc)   # [B,m,m]
        P_prior = 0.5 * (P_prior + P_prior.transpose(1, 2)) + self.Q_sys.unsqueeze(0)
        return x_prior, P_prior, X_pred

    def unscented_obs_predict(self, X_pred):
        """
        Predicted observation mean.

        y_sig = h(X_pred); here h is linear: h(x) = H x.
        y_prior = sum_i Wm_i * y_sig_i. Returns y_prior: [B,n].
        """
        y_sig = X_pred @ self.H_sys.T                  # [B,L,n]
        Wm = self.Wm.view(1, -1, 1)
        y_prior = torch.sum(Wm * y_sig, dim=1)         # [B,n]
        return y_prior

    def forward(self, Y, x_init, P_init):
        """
        Run the filter over a batch of observation sequences.

        Y:[B,T,n], x_init:[B,m], P_init:[B,m,m]
        returns x_est:[B,T,m], P_est:[B,T,m,m]

        For T == 0 the returned histories are empty along the time axis.
        """
        B, T, _ = Y.shape
        if T == 0:
            # torch.stack cannot stack an empty list; return empty histories instead.
            return (x_init.new_zeros(B, 0, self.m),
                    P_init.new_zeros(B, 0, self.m, self.m))

        x_post = x_init
        P_post = P_init
        h = torch.zeros(B, self.gru.hidden_size, device=Y.device, dtype=Y.dtype)

        x_hist = []
        P_hist = []

        for t in range(T):
            # The prior covariance is not consumed here: the gain comes from the GRU.
            x_prior, _, X_pred = self.unscented_predict(x_post, P_post)
            y_prior = self.unscented_obs_predict(X_pred)

            innovation = Y[:, t, :] - y_prior          # [B,n]
            state_diff = x_prior - x_post              # [B,m]
            feat = torch.cat([innovation, state_diff], dim=1)  # [B,n+m]

            u = self.pre_fc(feat)
            h = self.gru(u, h)
            h2 = self.post_fc(h)

            K = self.fc_kg(h2).view(B, self.m, self.n)
            P_post = self.pdel_layer(self.fc_pdel(h2))

            correction = (K @ innovation.unsqueeze(-1)).squeeze(-1)  # [B,m]
            x_post = x_prior + correction

            x_hist.append(x_post)
            P_hist.append(P_post)

        return torch.stack(x_hist, dim=1), torch.stack(P_hist, dim=1)


In [5]:
# ==========================================
# 3. Loss: NLL + (small) MSE
# ==========================================
def nll_plus_mse(x_est, x_true, P_est, w_mse=0.5, eps=1e-5):
    """
    Gaussian negative log-likelihood plus a weighted MSE term.

        NLL   = 0.5 * mean(e^T P^{-1} e + logdet(P)),  e = x_est - x_true
        total = NLL + w_mse * mean(e^2)

    Args:
        x_est:  [B,T,m] estimated states.
        x_true: [B,T,m] reference states.
        P_est:  [B,T,m,m] estimated covariances.
        w_mse:  weight of the MSE term.
        eps:    diagonal loading added to every covariance before factorising.

    Returns:
        Scalar loss tensor; NaN and +/-inf are replaced by 1e6.
    """
    n_batch, n_steps, dim = x_est.shape
    flat = n_batch * n_steps

    err = x_est - x_true                                   # [B,T,m]
    err_col = err.reshape(flat, dim, 1)                    # [B*T,m,1]
    cov = P_est.reshape(flat, dim, dim)                    # [B*T,m,m]

    # Symmetrise and load the diagonal so the factorisation is well defined.
    eye = torch.eye(dim, device=cov.device, dtype=cov.dtype).unsqueeze(0)
    cov = 0.5 * (cov + cov.transpose(-1, -2)) + eps * eye

    chol = torch.linalg.cholesky(cov)                      # [B*T,m,m]
    p_inv_err = torch.cholesky_solve(err_col, chol)        # [B*T,m,1]
    maha = (err_col.transpose(1, 2) @ p_inv_err).squeeze(-1).squeeze(-1)  # [B*T]

    # logdet(P) = 2 * sum(log(diag(chol))); 1e-12 guards the log.
    chol_diag = torch.diagonal(chol, dim1=-2, dim2=-1)
    logdet = 2.0 * torch.log(chol_diag + 1e-12).sum(dim=-1)  # [B*T]

    nll = 0.5 * (maha + logdet).mean()
    mse = (err**2).mean()

    # NaN/inf guard on the final scalar.
    return torch.nan_to_num(nll + w_mse * mse, nan=1e6, posinf=1e6, neginf=1e6)

In [6]:
# ==========================================
# 4. Train / Eval (mini-batch + GPU friendly)
# ==========================================
def train_linear_demo(
    m=2, n=2,
    T=50,
    N_train=500,
    N_test=100,
    q_scale=0.1,
    r_scale=1.0,
    batch_size=128,
    test_batch_size=256,
    num_epochs=50,
    lr=1e-3,
    ut_alpha=0.5,
    ut_beta=2.0,
    ut_kappa=0.0,
    w_mse=0.05,
):
    """
    Generate a linear-system dataset, train the UCKN model on it and plot diagnostics.

    Args:
        m, n: state / observation dimensions.
        T: sequence length.
        N_train, N_test: number of training / test trajectories (one dataset of
            N_train + N_test trajectories is generated and split in order).
        q_scale, r_scale: process / observation noise standard deviations.
        batch_size, test_batch_size: mini-batch sizes of the two DataLoaders.
        num_epochs: number of passes over the training set.
        lr: Adam learning rate.
        ut_alpha, ut_beta, ut_kappa: unscented-transform parameters of the model.
        w_mse: weight of the MSE term in the loss.

    Returns:
        (model, (train_loss_hist, test_loss_hist)) with one loss value per epoch.

    Side effects:
        Prints progress on epoch 1 and every 10th epoch, and shows matplotlib
        figures (training curve, error plots with 99% bands, one sample trajectory).
    """
    # ---- data (CPU) ----
    dataset = LinearSystemDataset(T=T, N=N_train+N_test, m=m, n=n, q_scale=q_scale, r_scale=r_scale)
    X, Y = dataset.generate_data()   # CPU tensors
    F_sys, H_sys = dataset.F, dataset.H
    Q_sys, R_sys = dataset.Q, dataset.R

    X_train = X[:N_train]
    Y_train = Y[:N_train]
    X_test  = X[N_train:]
    Y_test  = Y[N_train:]

    train_loader = DataLoader(
        TensorDataset(Y_train, X_train),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,          # single-process loading (stable on Windows)
        pin_memory=pin_mem
    )
    test_loader = DataLoader(
        TensorDataset(Y_test, X_test),
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=pin_mem
    )

    # ---- model ----
    model = UnscentedCholeskyKalmanNet(
        m, n, F_sys, H_sys, Q_sys, R_sys,
        alpha=ut_alpha, beta=ut_beta, kappa=ut_kappa,
        hidden_dim=m*10,
        pdel_clip=5.0,
        diag_min=1e-4,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    I_m = torch.eye(m, device=device).unsqueeze(0)  # [1,m,m]

    def _initial_state(batch):
        """Zero initial mean [batch,m] and identity initial covariance [batch,m,m]."""
        return torch.zeros(batch, m, device=device), I_m.expand(batch, -1, -1).contiguous()

    train_loss_hist = []
    test_loss_hist  = []

    print("Training...")
    for epoch in range(1, num_epochs + 1):
        # -----------------
        # Train
        # -----------------
        model.train()
        total_train = 0.0
        n_train = 0

        for yb_cpu, xb_cpu in train_loader:
            yb = yb_cpu.to(device, non_blocking=True)
            xb = xb_cpu.to(device, non_blocking=True)
            B = yb.size(0)

            x_init, P_init = _initial_state(B)

            optimizer.zero_grad(set_to_none=True)
            x_est, P_est = model(yb, x_init, P_init)

            loss = nll_plus_mse(x_est, xb, P_est, w_mse=w_mse, eps=1e-5)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample average.
            total_train += loss.item() * B
            n_train += B

        train_loss = total_train / max(n_train, 1)
        train_loss_hist.append(train_loss)

        # -----------------
        # Test
        # -----------------
        model.eval()
        total_test = 0.0
        n_test = 0
        with torch.no_grad():
            for yb_cpu, xb_cpu in test_loader:
                yb = yb_cpu.to(device, non_blocking=True)
                xb = xb_cpu.to(device, non_blocking=True)
                B = yb.size(0)

                x_init, P_init = _initial_state(B)

                x_est, P_est = model(yb, x_init, P_init)
                loss_t = nll_plus_mse(x_est, xb, P_est, w_mse=w_mse, eps=1e-5)

                total_test += loss_t.item() * B
                n_test += B

        test_loss = total_test / max(n_test, 1)
        test_loss_hist.append(test_loss)

        if epoch % 10 == 0 or epoch == 1:
            print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.6f} | Test Loss: {test_loss:.6f}")

    # ==========================================
    # 5. Plots: Training curves
    # ==========================================
    epochs = np.arange(1, num_epochs + 1)
    plt.figure()
    plt.plot(epochs, train_loss_hist, label="Train Loss")
    plt.plot(epochs, test_loss_hist, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (NLL + w*MSE)")
    plt.title("Training Curve")
    plt.legend()
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()

    # ==========================================
    # 6. Final eval on full test set (one pass)
    # ==========================================
    model.eval()
    with torch.no_grad():
        y_all = Y_test.to(device)
        x_all = X_test.to(device)
        B = y_all.size(0)

        x_init, P_init = _initial_state(B)

        x_est_all, P_est_all = model(y_all, x_init, P_init)

    # Move results to the CPU as NumPy arrays.
    x_est_np = x_est_all.cpu().numpy()      # [N_test,T,m]
    x_true_np = X_test.numpy()
    P_np = P_est_all.cpu().numpy()          # [N_test,T,m,m]

    err = (x_true_np - x_est_np)            # [N_test,T,m]
    mean_err = err.mean(axis=0)             # [T,m]

    # CI band: per-time-step std from diag(P), averaged over the test samples.
    var = np.diagonal(P_np, axis1=2, axis2=3)          # [N_test,T,m]
    std = np.sqrt(np.maximum(var, 1e-12))              # [N_test,T,m]
    mean_std = std.mean(axis=0)                        # [T,m]
    z99 = 2.5758293035489004                           # two-sided 99% normal quantile
    ci = z99 * mean_std                                # [T,m]
    ci_label = "99% CI (from P)"                       # must match z99 above

    t = np.arange(T)

    # ==========================================
    # 7. Plot: all errors (light gray) + mean + 99% CI
    # ==========================================
    plt.figure(figsize=(12, 5))
    for i in range(m):
        plt.subplot(1, m, i+1)

        # every sample's error trace in light gray
        for s in range(err.shape[0]):
            plt.plot(t, err[s, :, i], color="0.85", linewidth=0.7, alpha=0.5)

        # 99% CI band (centred on 0)
        plt.fill_between(t, -ci[:, i], ci[:, i], alpha=0.2, label=ci_label)

        # mean error
        plt.plot(t, mean_err[:, i], label="Mean Error")

        plt.axhline(0.0, color="k", linestyle="--", alpha=0.5)
        plt.title(f"Error + 99% CI : state x{i+1}")
        plt.xlabel("Time step")
        plt.ylabel("Error")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.legend()

    plt.tight_layout()
    plt.show()

    # ==========================================
    # 8. Plot: one sample trajectory + sample error/CI
    # ==========================================
    sample_idx = 0
    xs_true = x_true_np[sample_idx]     # [T,m]
    xs_est  = x_est_np[sample_idx]      # [T,m]
    Ps      = P_np[sample_idx]          # [T,m,m]
    err_s   = xs_true - xs_est          # [T,m]
    std_s   = np.sqrt(np.maximum(np.diagonal(Ps, axis1=1, axis2=2), 1e-12))  # [T,m]
    ci_s    = z99 * std_s               # [T,m]

    # trajectory
    for i in range(m):
        plt.figure()
        plt.plot(t, xs_true[:, i], label=f"True x{i+1}")
        plt.plot(t, xs_est[:, i], label=f"Est  x{i+1}")
        plt.xlabel("Time step")
        plt.ylabel(f"x{i+1}")
        plt.title(f"Sample Trajectory (idx={sample_idx})")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.legend()
        plt.show()

    # sample error + CI
    for i in range(m):
        plt.figure()
        plt.plot(t, err_s[:, i], label=f"Error x{i+1}")
        plt.fill_between(t, -ci_s[:, i], ci_s[:, i], alpha=0.2, label=ci_label)
        plt.axhline(0.0, color="k", linestyle="--", alpha=0.5)
        plt.xlabel("Time step")
        plt.ylabel("Error")
        plt.title(f"Sample Error + 99% CI (idx={sample_idx}, x{i+1})")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.legend()
        plt.show()

    return model, (train_loss_hist, test_loss_hist)

In [None]:
# ==========================================
# 9. Run
# ==========================================
if __name__ == "__main__":
    model, (train_hist, test_hist) = train_linear_demo(
        # system / data
        m=2, n=2, T=50,
        N_train=500, N_test=100,
        q_scale=0.1, r_scale=1.0,
        # optimisation
        batch_size=128, test_batch_size=256,
        num_epochs=200,     # change the number of epochs here
        lr=1e-3,
        # unscented transform
        ut_alpha=0.5,       # if training is unstable, try lowering to 0.2 or 1e-3
        ut_beta=2.0,
        ut_kappa=0.0,
        # loss
        w_mse=1.0,
    )

Training...
Epoch 001 | Train Loss: 2.234144 | Test Loss: 0.868780
Epoch 010 | Train Loss: -0.420100 | Test Loss: -0.450373
Epoch 020 | Train Loss: -0.813008 | Test Loss: -0.826002
Epoch 030 | Train Loss: -0.988051 | Test Loss: -0.937741
Epoch 040 | Train Loss: -1.017907 | Test Loss: -0.969115
Epoch 050 | Train Loss: -1.037259 | Test Loss: -0.978675
Epoch 060 | Train Loss: -1.043826 | Test Loss: -0.984680
Epoch 070 | Train Loss: -1.049619 | Test Loss: -0.991547
Epoch 080 | Train Loss: -1.053043 | Test Loss: -0.996571
Epoch 090 | Train Loss: -1.055289 | Test Loss: -0.994144
