In [None]:
import numpy as np
import torch

%run ../../homework/setup/pltstyle
from tqdm.notebook import trange

from magi_psvgd.torch import magi
from magi_psvgd.torch.magi import MAGISolver

from magi_psvgd.tests.models import fitzhugh_nagumo as fn
from magi_psvgd.tests.make_test import ODEmodel
from magi_psvgd.tests import test_helpers

import time

In [None]:
# Benchmark target: the FitzHugh-Nagumo test model and its reference ODE solution
# (solved up to the last time point of model.I).
MODEL = fn
model = ODEmodel(MODEL)
model.get_ode_solution(T=model.I.max())

# Human-readable labels used in figure titles (rendered inside LaTeX math mode).
model_name = "FitzHugh-Nagumo"
X_names = ["V", "R"]
theta_names = [r"a", r"b", r"c"]

# When True, states/observations are exponentiated before plotting.
logmodel = False
# Passed as `rounded=` to test_helpers.discretize_data in the benchmark loop.
rounded = 3

In [None]:
# Sanity-check the model's analytic derivatives (dfdx, dfdtheta) against its ODE
# right-hand side via the project's test helper.
test_helpers.check_gradients(
    MODEL.ode,
    MODEL.dfdx,
    MODEL.dfdtheta,
    n=10,
    D=len(MODEL.hyperparameters["X0"]),      # number of state components
    p=len(MODEL.hyperparameters["theta"]),   # number of ODE parameters
    trials=100,
)

In [None]:
# Repeated-sampling benchmark: fit MAGI-pSVGD to B freshly generated datasets and
# record, per replicate, the parameter error, trajectory error, subspace dimension
# returned by solve(), and wall-clock runtime.
thetas_errs = []
x_errs = []
dims = []
times = []

n = 321  # number of points in the discretization grid; test cases: n = 41, 81, 161, 321
I = np.linspace(0, 20, n)

B = 100  # number of replicates
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

for _ in trange(B):
    # random_seed=None -> presumably a different sample every replicate (confirm in ODEmodel)
    data = model.generate_sample(random_seed=None)
    data_disc = test_helpers.discretize_data(data, I, rounded=rounded)

    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()

    magisolver = MAGISolver(
        ode=MODEL.ode, dfdx=MODEL.dfdx, dfdtheta=MODEL.dfdtheta, data=data_disc,
        theta_guess=np.array([1., 1., 1.]), theta_conf=0,
        sigmas=np.array([0.2, 0.2]),
        X_guess=1,
        mu=None, mu_dot=None,
        pos_X=False, pos_theta=False,
        prior_temperature=None,
        bayesian_sigma=True
    )
    magisolver.initialize_particles(k_0=400, dtype=torch.float32, device=DEVICE, init_sd=0.1, random_seed=None)

    # Built fresh each replicate: solve() may rewrite the 'params' entry.
    optimizer = torch.optim.Adam
    optimizer_kwargs = {'params': True, 'lr': 1e-2}

    X_result, theta_result, sigma_result, ss_dim = magisolver.solve(
        optimizer=optimizer, optimizer_kwargs=optimizer_kwargs,
        max_iter=10, subspace_updates=1000, alpha=0.01, atol=5, rtol=0, monitor_convergence=True,
    )

    end_time = time.perf_counter()
    total_time = end_time - start_time

    # Honor the notebook-level `logmodel` flag instead of hard-coding False.
    x_err, t_err = model.evaluate(X_result, theta_result, sigma_result, magisolver.I, logmodel=logmodel)

    thetas_errs.append(t_err)
    x_errs.append(x_err)
    dims.append(ss_dim.item())
    times.append(total_time)

In [None]:
# Per-replicate wall-clock runtimes in seconds (one entry per benchmark-loop iteration).
times

In [None]:
# average runtime: mean wall-clock seconds per replicate
np.mean(times)

In [None]:
# average subspace dimension returned by solve(), across replicates
np.mean(dims)

In [None]:
# prmse: root-mean-square over replicates of each entry of t_err
# (presumably one entry per ODE parameter)
np.sqrt(np.mean(np.square(np.array(thetas_errs)), axis=0))

In [None]:
# mtrmse: for each replicate, sum squared trajectory errors along axis 1, divide by
# magisolver.I.shape[0] and take the square root; then average over replicates
np.mean(np.sqrt(np.sum(np.square(np.array(x_errs)), axis=1) / magisolver.I.shape[0]), axis=0)

In [None]:
# Ground-truth trajectories with the observations from the LAST replicate of the
# benchmark loop (`data` is whatever the final iteration generated).
n_components = data.shape[1] - 1  # column 0 of `data` holds the observation times
# squeeze=False keeps `axes` 2-D so the loop also works for a single component.
fig, axes = plt.subplots(1, n_components, figsize=(10, 3), squeeze=False)
to_natural = np.exp if logmodel else (lambda v: v)

for d, ax in enumerate(axes[0]):
    traj = to_natural(model.solution[:, d + 1])
    obs = to_natural(data[:, d + 1])

    ax.plot(model.solution[:, 0], traj, c='orange', zorder=1, label="Ground Truth")
    ax.scatter(data[:, 0], obs, c='r', zorder=2, label="Observed Data")

    ax.set_title(f"Component ${X_names[d]}$")
    ax.set_xlabel("Time")

axes[0, -1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
fig.suptitle(f"Sparse {model_name} Data on Ground Truth")

plt.tight_layout()

In [None]:
# Initial trajectory guess (magisolver.x_init) of the last fitted replicate, overlaid
# with its observations and the ground truth; exponentiated when `logmodel` is set.
to_natural = np.exp if logmodel else (lambda v: v)

plt.plot(magisolver.I, to_natural(magisolver.x_init))
# Loop over every state column instead of hard-coding columns 1 and 2.
for col in range(1, data.shape[1]):
    plt.scatter(data[:, 0], to_natural(data[:, col]))

plt.plot(model.solution[:, 0], to_natural(model.solution[:, 1:]))
plt.title(f"Initial Trajectory Guess vs. {model_name} Data")
plt.xlabel("Time")
plt.show()

In [None]:
# Unpack the final particle set of the last replicate's solver into X, theta and sigma.
# NOTE(review): this overwrites the X_result/theta_result/sigma_result returned by
# solve() in the benchmark loop — confirm both are meant to hold the same values.
X_result, theta_result, sigma_result = magisolver.from_svgd_vector(magisolver.particles)

In [None]:
# Trajectory posterior summaries for plotting, mapped back to the natural scale when
# `logmodel` is set. Copies keep model.solution / data untouched.
X_true = model.solution.copy()
obs_data = data.copy()

# detach() so numpy can consume the tensor even if it still tracks gradients;
# converting once avoids mixing torch and numpy ops below.
if logmodel:
    X_preds = torch.exp(X_result).detach().cpu().numpy()
    X_true[:, 1:] = np.exp(X_true[:, 1:])
    obs_data[:, 1:] = np.exp(obs_data[:, 1:])
else:
    X_preds = X_result.detach().cpu().numpy()

# Axis 0 presumably indexes particles (confirm with from_svgd_vector).
X_means = X_preds.mean(axis=0)
X_quantiles = np.quantile(X_preds, [0.025, 0.975], axis=0)  # pointwise 2.5% / 97.5% band

ts = magisolver.I.flatten()

In [None]:
# Histograms of each ODE parameter across the pSVGD particles, titled with the sample mean.
p = model.theta.shape[0]
# detach() so matplotlib/numpy can consume the samples even if they track gradients.
theta_samples = theta_result.detach().cpu().numpy()

# squeeze=False keeps the axes indexable even when p == 1.
fig, hist_axes = plt.subplots(1, p, figsize=(10, 3), squeeze=False)
hist_axes = hist_axes[0]
for i in range(p):
    samples_i = theta_samples[:, i]
    hist_axes[i].hist(samples_i, bins=14)
    hist_axes[i].set_title(fr'$E[{theta_names[i]}] \approx {samples_i.mean():.4f}$')

temp_theta = r"$\boldsymbol{\theta}$"
fig.suptitle(fr"pSVGD Posterior Distributions of {model_name} System ODE Parameters {temp_theta}")
plt.tight_layout()

In [None]:
# Posterior mean and quantile band (X_quantiles) of each trajectory component vs. ground truth.
D = X_means.shape[1]
show_observations = False  # set True to overlay the (last replicate's) observed data

# squeeze=False keeps the axes indexable even when D == 1.
fig, traj_axes = plt.subplots(1, D, figsize=(10, 3), squeeze=False)
traj_axes = traj_axes[0]

for d, ax in enumerate(traj_axes):
    ax.plot(ts, X_means[:, d], color='b', zorder=3, label="Posterior Mean")
    ax.fill_between(x=ts, y1=X_quantiles[0, :, d], y2=X_quantiles[1, :, d], alpha=0.2, color='b', zorder=1, label="95% CI")
    ax.plot(X_true[:, 0], X_true[:, d + 1], color='orange', zorder=2, label="Ground Truth")
    if show_observations:
        ax.scatter(obs_data[:, 0], obs_data[:, d + 1], color='r', zorder=4, label="Observed Data")
    ax.set_title(f'Component ${X_names[d]}$')
    ax.set_xlabel('Time')

traj_axes[-1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
temp_x = r"$\boldsymbol{X}$"
fig.suptitle(fr"pSVGD Posterior Distributions of {model_name} System ODE Trajectories {temp_x}")

plt.tight_layout()