In [2]:
# 0. imports
import time, numpy as np, pandas as pd, torch
import matplotlib.pyplot as plt
from scipy.integrate import odeint

from scripts.magix.dynamic import nnSTModule, nnMTModule  # MAGI-X NN module
from scripts.magix.inference import FMAGI                 # MAGI-X inference


In [1]:
# 1. UQ metric utilities

# ---------- UQ metric helpers ----------
def interp_truth_to(t_truth, X_truth, t_eval):
    """Linearly interpolate every column of X_truth (T0, D), sampled on t_truth,
    onto the times t_eval; returns an array of shape (len(t_eval), D)."""
    columns = [np.interp(t_eval, t_truth, X_truth[:, j]) for j in range(X_truth.shape[1])]
    return np.column_stack(columns)

def percentile_band(samples, alpha=0.10):
    """Central (1 - alpha) percentile band across axis 0 of `samples`.

    Returns the tuple (lower, upper) at the alpha/2 and 1 - alpha/2 percentiles.
    """
    q_lower, q_upper = 100*(alpha/2), 100*(1 - alpha/2)
    return tuple(np.percentile(samples, q, axis=0) for q in (q_lower, q_upper))

def coverage_and_width(samples, t_full, t_truth, X_truth, t_end_fit, alpha=0.10):
    """
    Empirical coverage and width of the central (1 - alpha) percentile band.

    samples:   (N, T, D)  stochastic trajectories (e.g., MC-dropout)
    t_full:    (T,)       time grid of samples (initial point included)
    t_truth:   (T0,)      truth grid (increasing, as np.interp requires)
    X_truth:   (T0,D)     truth values
    t_end_fit: float      points with t_full <= t_end_fit form the "fit" region,
                          the remaining points the "forecast" region
    alpha:     float      band = [alpha/2, 1 - alpha/2] percentiles over samples

    Returns a dict with the band (lo, hi), the truth interpolated onto t_full
    (X_eval), per-region coverage / mean width / range-normalised width, and a
    per-dimension breakdown (per_dim). Statistics for a region that contains no
    time points are NaN instead of raising.
    """
    X_eval = interp_truth_to(t_truth, X_truth, t_full)  # (T,D)
    lo, hi = percentile_band(samples, alpha=alpha)      # (T,D)
    width  = hi - lo
    inside = (X_eval >= lo) & (X_eval <= hi)

    fit_mask = t_full <= t_end_fit
    fct_mask = ~fit_mask

    def _region_mean(values, mask):
        # mean over the selected time points; NaN for an empty region
        return values[mask].mean() if mask.any() else np.nan

    def summarize(mask):
        if not mask.any():
            # e.g. t_end_fit beyond the last sample time -> no forecast points
            return np.nan, np.nan, np.nan
        cov = inside[mask].mean()
        w   = width[mask].mean()
        # normalized width: divide per-dim mean width by the truth's dynamic range in the region
        Xm   = X_eval[mask]
        span = Xm.max(axis=0) - Xm.min(axis=0)
        span[span == 0] = 1.0  # constant truth in the region: fall back to raw width
        w_norm = (width[mask].mean(axis=0) / span).mean()
        return cov, w, w_norm

    cov_fit, w_fit, wN_fit = summarize(fit_mask)
    cov_fct, w_fct, wN_fct = summarize(fct_mask)

    # per-dim breakdown (handy for appendix tables)
    per_dim = [
        dict(
            cov_fit = _region_mean(inside[:, d], fit_mask),
            cov_fct = _region_mean(inside[:, d], fct_mask),
            w_fit   = _region_mean(width[:, d], fit_mask),
            w_fct   = _region_mean(width[:, d], fct_mask),
        )
        for d in range(X_truth.shape[1])
    ]

    return dict(
        lo=lo, hi=hi, X_eval=X_eval,
        coverage_fit=cov_fit, coverage_fct=cov_fct,
        width_fit=w_fit, width_fct=w_fct,
        width_norm_fit=wN_fit, width_norm_fct=wN_fct,
        per_dim=per_dim
    )


In [9]:
# 2) MC-dropout sampler for MAGI-X

# ---------- MC-dropout sampler ----------
def mc_dropout_samples(model, fOde, trecon, xinfer, N=50):
    """
    Draw N stochastic reconstructions from a fitted MAGI-X model with dropout active.

    Parameters
      model:  fitted inference object; called as
              model.predict(trecon[1:], trecon[:1], x0, random=True) -> (t, x),
              where x is assumed to be an array of shape (len(t), D) -- TODO confirm.
      fOde:   dynamics module; fOde.train() is called so dropout stays on.
      trecon: (T,) times; trecon[0] is the initial time, trecon[1:] the prediction times.
      xinfer: inferred trajectory; its first row is the initial state x0.
      N:      number of draws (must be >= 1).

    Returns:
      t_full:  (T,)       trecon[0] followed by the times returned by predict
      samples: (N, T, D)  every draw with x0 prepended as the first time point

    Side effect: if fOde exposes a boolean `training` flag that was False on entry
    (as a torch nn.Module in eval mode does), fOde.eval() is called on exit so the
    module is not left in training mode.

    Raises ValueError if N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")

    was_training = getattr(fOde, "training", None)
    # keep dropout ON while sampling
    fOde.train()
    try:
        x0 = xinfer[0, :].squeeze()
        x0_row = x0.reshape(1, -1)
        t_init, t_pred = trecon[:1], trecon[1:]

        # first stochastic draw: also fixes t_full, T and D, and is kept as sample 0
        tr1, xr1 = model.predict(t_pred, t_init, x0, random=True)
        t_full = np.concatenate([t_init, tr1])
        D = xr1.shape[1]
        T = t_full.size

        draws = np.empty((N, T, D), dtype=float)
        draws[0] = np.vstack([x0_row, xr1])
        for i in range(1, N):
            _, xri = model.predict(t_pred, t_init, x0, random=True)
            draws[i] = np.vstack([x0_row, xri])
    finally:
        if was_training is False:
            fOde.eval()

    return t_full, draws


In [4]:
# 3) ODEs + truth simulators (FN / LV (log) / Hes1)

# ---------- Dynamics ----------
def FN(y, t, a, b, c):
    """FitzHugh-Nagumo right-hand side with the odeint signature f(y, t, *args).

    y = (V, R). Returns the tuple (dV/dt, dR/dt); `t` is not used.
    """
    v, r = y
    return (
        c * (v - v**3/3.0 + r),
        -1.0/c * (v - a + b*r),
    )

def LV_log(y, t, a, b, c, d):
    """Lotka-Volterra right-hand side for a state stored in log-space (odeint signature).

    y holds (log x1, log x2). Returns the list [d log x1/dt, d log x2/dt];
    `t` is not used.
    """
    prey, pred = np.exp(y)
    prey_rate = a*prey - b*prey*pred
    pred_rate = c*prey*pred - d*pred
    # chain rule: d(log x)/dt = (dx/dt) / x
    return [prey_rate/prey, pred_rate/pred]

def Hes1(y, t, a, b, c, d, e, f, g):
    """Hes1 right-hand side with the odeint signature f(y, t, *args).

    y = (P, M, H). Returns the tuple (dP/dt, dM/dt, dH/dt); `t` is not used.
    """
    p, m, h = y
    interaction = a*p*h     # term subtracted from both dP/dt and dH/dt
    repress = 1 + p**2      # shared denominator of the e- and f- production terms
    return (
        -interaction + b*m - c*p,
        -d*m + e/repress,
        -interaction + f/repress - g*h,
    )

# ---------- Truth simulators ----------
def simulate_FN():
    """Ground-truth FitzHugh-Nagumo trajectory.

    Integrates FN with odeint on 1281 equally spaced times in [0, 40], from
    (V0, R0) = (-1, 1) with a = b = 0.2, c = 3. Returns (t, X) with X of shape (1281, 2).
    """
    params = (0.2, 0.2, 3.0)      # a, b, c
    state0 = (-1.0, 1.0)          # V0, R0
    t_grid = np.linspace(0, 40, 1281)
    traj = odeint(FN, state0, t_grid, args=params)
    return t_grid, traj

def simulate_LV_log():
    """Ground-truth Lotka-Volterra trajectory, integrated in log-space.

    Integrates LV_log with odeint on 321 equally spaced times in [0, 12], from
    populations (5, 0.2) with a, b, c, d = 1.5, 1, 1, 3, then exponentiates back.
    Returns (t, X) with X of shape (321, 2) on the original (positive) scale.
    """
    params = (1.5, 1.0, 1.0, 3.0)             # a, b, c, d
    log_state0 = np.log([5.0, 0.2])           # log of initial populations
    t_grid = np.linspace(0, 12, 321)
    log_traj = odeint(LV_log, log_state0, t_grid, args=params)
    # undo the log transform column by column
    populations = np.column_stack([np.exp(log_traj[:, 0]), np.exp(log_traj[:, 1])])
    return t_grid, populations

def simulate_Hes1():
    """Ground-truth Hes1 trajectory.

    Integrates Hes1 with odeint on 1281 equally spaced times in [0, 640] from
    (P0, M0, H0) = (1.438575, 2.037488, 17.90385). Returns (t, X) with X of shape (1281, 3).
    """
    params = (0.022, 0.3, 0.031, 0.028, 0.5, 20.0, 0.3)   # a, b, c, d, e, f, g
    state0 = (1.438575, 2.037488, 17.90385)               # P0, M0, H0
    t_grid = np.linspace(0, 640, 1281)
    traj = odeint(Hes1, state0, t_grid, args=params)
    return t_grid, traj


In [5]:
# 4) Observation maker: noisy, equally spaced observations on the first half of the horizon

def make_observations(t, X, no_train, noise, seed=0):
    """
    Draw noisy observations of every component at equally spaced times on the
    first half of the time horizon.

    t:        (T,)   time grid of the truth
    X:        (T,D)  truth values on t
    no_train: number of observation times (shared by all components)
    noise:    Gaussian std dev per component -- either list-like with at least
              D entries (entry d is used for component d) or a single scalar
              applied to every component
    seed:     passed to np.random.seed, i.e. NumPy's *global* RNG is re-seeded
              as a side effect

    returns:  (obs, t_end_fit) where obs is a list of D arrays of shape
              (no_train, 2) with columns (t, y_obs), and t_end_fit is the last
              observation time (fit boundary).
    raises:   ValueError if a list-like noise has fewer than D entries.
    """
    T, D = X.shape
    if np.ndim(noise) == 0:
        noise = [noise] * D
    elif len(noise) < D:
        raise ValueError(
            f"noise needs at least {D} entries (one per component), got {len(noise)}")

    # Global seeding kept on purpose (instead of a local Generator): later steps of
    # the pipeline may rely on NumPy's global RNG being seeded here -- TODO confirm.
    np.random.seed(seed)
    # observe first half of the horizon unless you change obs_idx
    obs_idx = np.linspace(0, (T-1)//2, no_train).astype(int)
    tobs = t[obs_idx]
    obs = []
    for d in range(D):
        yobs = X[obs_idx, d] + np.random.normal(0, noise[d], size=no_train)
        obs.append(np.c_[tobs, yobs])  # np.c_ builds a fresh array, no aliasing of t/X
    t_end_fit = max(o[:, 0].max() for o in obs)
    return obs, t_end_fit


In [6]:
# 5) One-system pipeline (fit → samples → metrics)

def run_system(system_name,
               simulate_fn,
               nn_hidden=512,
               grid_size=161,
               max_epoch=2500,
               lr=1e-3,
               dropout_p=0.1,
               N_train=41,
               noise=None,
               seed=188714368,
               N_samples=50):
    """
    Fit MAGI-X to one simulated system and score its MC-dropout bands.

    Pipeline: simulate truth -> noisy observations -> MAP fit -> N_samples
    MC-dropout reconstructions on a 321-point subgrid of the truth grid ->
    90% band coverage / width, split into fit and forecast regions.

    Returns a dict with keys metrics, t_full, samples, obs, t_truth, X_truth
    and rows (two table rows: region 'fit' then region 'forecast').
    """
    # truth & observations
    t_truth, X_truth = simulate_fn()
    D = X_truth.shape[1]
    noise = [0.1]*D if noise is None else noise
    obs, t_end_fit = make_observations(t_truth, X_truth, N_train, noise, seed=seed)

    # MAGI-X with a multi-task NN for the dynamics; torch is seeded before the net is built
    torch.manual_seed(seed)
    fOde = nnMTModule(D, [nn_hidden], dp=dropout_p)
    model = FMAGI(obs, fOde, grid_size=grid_size, interpolation_orders=3)

    # reconstruction times: 321 indices spread over the whole truth grid
    recon_idx = np.linspace(0, t_truth.size-1, 321).astype(int)
    trecon = t_truth[recon_idx]
    _, xinfer = model.map(max_epoch=max_epoch,
                          learning_rate=lr, decay_learning_rate=True,
                          hyperparams_update=False, dynamic_standardization=True,
                          verbose=False, returnX=True)

    # stochastic reconstructions and their UQ metrics
    t_full, samples = mc_dropout_samples(model, fOde, trecon, xinfer, N=N_samples)
    metrics = coverage_and_width(samples, t_full, t_truth, X_truth, t_end_fit, alpha=0.10)

    # one compact table row per region, keys in the column order of the paper table
    region_keys = (
        ('fit',      'coverage_fit', 'width_fit', 'width_norm_fit'),
        ('forecast', 'coverage_fct', 'width_fct', 'width_norm_fct'),
    )
    rows = [
        dict(system=system_name, region=region,
             coverage_90=metrics[cov_key], width=metrics[w_key], width_norm=metrics[wn_key],
             N_samples=N_samples, T=t_full.size, D=D, N_train=N_train)
        for region, cov_key, w_key, wn_key in region_keys
    ]

    return dict(metrics=metrics, t_full=t_full, samples=samples, obs=obs,
                t_truth=t_truth, X_truth=X_truth,
                rows=rows)


In [7]:
# 6) Orchestrator: run FN / LV / Hes1 and print a paper-ready table
def run_all_systems(save_csv_path=None):
    """
    Run run_system for FN, LV (log-state ODE) and Hes1, print a formatted
    coverage/width table and optionally save the raw metrics as CSV.

    Returns (df, per_system): df holds the unformatted rows (fit + forecast per
    system); per_system maps system name -> the full run_system output (for plots).
    """
    # (name, simulator, N_train, per-component noise std) -- tweak to match your experiments
    jobs = (
        ('FN',   simulate_FN,     41, [0.1, 0.1]),
        ('LV',   simulate_LV_log, 41, [0.1, 0.1]),
        ('Hes1', simulate_Hes1,   81, [0.1, 0.1, 0.1]),
    )
    rows = []
    per_system = {}
    for name, sim, n_train, noise_std in jobs:
        result = run_system(name, sim,
                            N_train=n_train, noise=noise_std,
                            dropout_p=0.1, N_samples=50,
                            nn_hidden=512, grid_size=161, max_epoch=2500, lr=1e-3)
        rows += result['rows']
        per_system[name] = result  # keep everything (for plots)

    df = pd.DataFrame(rows)
    # human-readable copy for printing; the raw frame is what gets returned/saved
    columns = ['system', 'region', 'coverage_90', 'width', 'width_norm', 'N_train', 'N_samples', 'T', 'D']
    df_display = df.assign(
        coverage_90=(100*df['coverage_90']).map(lambda x: f"{x:5.1f}%"),
        width=df['width'].map(lambda x: f"{x:.3f}"),
        width_norm=df['width_norm'].map(lambda x: f"{x:.3f}"),
    )[columns]

    print("\n=== Coverage@90% and Band Width (Fit vs Forecast) ===")
    print(df_display.to_string(index=False))

    if save_csv_path:
        df.to_csv(save_csv_path, index=False)
        print(f"\n[Saved raw metrics to {save_csv_path}]")

    return df, per_system


In [10]:
# Run FN / LV / Hes1 end to end and write the raw metrics CSV.
# df_metrics: unformatted metric rows (fit + forecast per system);
# out: system name -> full run_system output (samples, truth, observations) for plotting.
df_metrics, out = run_all_systems(save_csv_path="uq_metrics_fn_lv_hes1.csv")



=== Coverage@90% and Band Width (Fit vs Forecast) ===
system   region coverage_90              width         width_norm  N_train  N_samples   T  D
    FN      fit       66.7%              0.166              0.054       41         50 322  2
    FN forecast       92.5%              0.288              0.097       41         50 322  2
    LV      fit       94.8%           2044.380            298.722       41         50 322  2
    LV forecast       70.3% 96930221511194.078 14135303168782.176       41         50 322  2
  Hes1      fit       90.3%              3.403              0.419       81         50 322  3
  Hes1 forecast       95.6%              8.297              0.990       81         50 322  3

[Saved raw metrics to uq_metrics_fn_lv_hes1.csv]
