In [None]:
from tvb.simulator.models.oscillator import Generic2dOscillator
from tvb.simulator.integrators import HeunStochastic
import matplotlib.pyplot as plt
import numpy as np
import torch
import math
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
import arviz as az
from datetime import datetime
%load_ext autoreload
%autoreload 2

In [None]:
# Simulation parameters: values used to configure the Generic2dOscillator below
# (presumably also the ones that generated limit-cycle_simulation.npy).
a_sim = 2.0
b_sim = -10.0
c_sim = 0.0
d_sim = 0.02
I_sim = 0.0
nsig = 0.003             # noise amplitude assigned to integrator.noise.nsig
dt = 0.1                 # integrator step size
simulation_length = 300  # NOTE(review): not referenced anywhere in this notebook

# Seed the global RNGs so sbi's torch-based training/sampling is repeatable.
# NOTE(review): the TVB integrator is given random_state=None, so its noise
# stream is not covered by these seeds.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# TVB model and integrator setup.
# Each oscillator parameter is wrapped in a 1-element array before being passed in.
oscillator_model = Generic2dOscillator(
    **{
        param_name: np.asarray([param_value])
        for param_name, param_value in (
            ("a", a_sim),
            ("b", b_sim),
            ("c", c_sim),
            ("d", d_sim),
            ("I", I_sim),
        )
    }
)
oscillator_model.configure()

# Stochastic Heun integrator. The noise object is configured (including its
# white-noise step) before the integrator itself; keep this call order.
integrator = HeunStochastic(dt=dt)
integrator.noise.nsig = np.array([nsig])
integrator.noise.configure()
integrator.noise.configure_white(dt=integrator.dt)
integrator.set_random_state(random_state=None)  # None -> unseeded noise stream
integrator.configure()
integrator.configure_boundaries(oscillator_model)

In [None]:
# Load the pre-computed limit-cycle trajectory used as the observation.
# Later cells index it as X[:, state, 0, 0] with state 0 (V) and 1 (W), so it
# must be 4-D with at least two state variables on axis 1.
X = np.load("limit-cycle_simulation.npy")
assert X.ndim == 4 and X.shape[1] >= 2, f"unexpected observation shape {X.shape}"

In [None]:
# Quick look at both observed state variables before running inference.
# The x-axis is the sample index of the stored series.
f1, ax = plt.subplots(figsize=(14, 8))
ax.plot(X[:, 0, 0, 0], label="V")
ax.plot(X[:, 1, 0, 0], label="W")
ax.set(xlabel="time step", ylabel="states", title="Observed trajectory")
ax.legend()
plt.show()

In [None]:
# The loaded trajectory is the observation every inference method conditions on.
obs = X

# Prior specification per parameter. Each entry is presumably
# [lower bound, upper bound, flag]; the meaning of the boolean flag is defined
# by sbiModel and is not visible here (TODO confirm).
priors = dict(
    a=[1.8, 2.2, False],
    b=[-10.3, -9.7, False],
    c=[-0.1, 0.1, False],
    d=[0.01, 0.03, False],
    I=[-0.1, 0.1, False],
    epsilon=[0.0, 0.01, False],
)

shape = obs.shape

In [None]:
from tvb.contrib.inversion.sbiInference import sbiModel

### SNPE inference (Sequential Neural Posterior Estimation)

In [None]:
# SNPE wrapper around the shared TVB simulator, priors and observation.
snpe_model = sbiModel(
    method="SNPE",
    obs=obs,
    obs_shape=shape,
    priors=priors,
    model_instance=oscillator_model,
    integrator_instance=integrator,
)

In [None]:
# Run SNPE: 800 simulator calls on a single worker, num_samples=2000
# (presumably the number of posterior draws kept).
snpe_model.run_inference(num_samples=2000, num_simulations=800, num_workers=1)

In [None]:
# Convert the SNPE results to ArviZ data under a method-specific name so the
# SNLE/SNRE sections do not overwrite it. save=True presumably also persists
# the data to disk (TODO confirm in sbiModel).
snpe_inference_data = snpe_model.to_arviz_data(save=True)

In [None]:
# SNPE posterior plot; init_params carries the values used to configure the
# simulator (with epsilon = 0.0).
snpe_model.plot_posterior_samples(
    init_params=dict(a=a_sim, b=b_sim, c=c_sim, d=d_sim, I=I_sim, epsilon=0.0)
)

In [None]:
# Point estimate from get_map_estimator() (presumably the SNPE posterior MAP),
# displayed as this cell's output.
map_estimator = snpe_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() (presumably one posterior
# draw) and display it; it is re-simulated in the next cell.
posterior_sample = snpe_model.get_sample()[0]
posterior_sample

In [None]:
# Re-simulate at the posterior draw and restore the observation's shape using
# a column-major reshape (presumably matching how the wrapper flattens output).
posterior_obs = (
    snpe_model.simulation_wrapper(params=posterior_sample)
    .numpy()
    .reshape(shape, order="F")
)

In [None]:
# Overlay the observed trajectory with the one re-simulated from the SNPE
# posterior draw.
# NOTE(review): the x-axis is the sample index of the stored series; its
# sampling period is not visible here, so it is not labelled in ms.
x_obs_snpe = snpe_model.inference_data.observed_data.x_obs.values
f3, ax = plt.subplots(figsize=(13, 8))
ax.plot(x_obs_snpe[:, 0, 0, 0], label="V_simulated", color="blue")
ax.plot(x_obs_snpe[:, 1, 0, 0], label="W_simulated", color="red")
ax.plot(posterior_obs[:, 0, 0, 0], label="V_posterior", color="cyan")
ax.plot(posterior_obs[:, 1, 0, 0], label="W_posterior", color="orange")
ax.set(xlabel="time step", ylabel="states", title="SNPE: observed vs. posterior simulation")
ax.legend()
plt.show()

In [None]:
# Information criteria computed by sbiModel for the SNPE fit; the return value
# is displayed as the cell output.
snpe_model.information_criteria()

### SNLE inference (Sequential Neural Likelihood Estimation)

In [None]:
# SNLE wrapper around the same simulator, priors and observation; only the
# sbi method differs from the SNPE section.
snle_model = sbiModel(
    method="SNLE",
    obs=obs,
    obs_shape=shape,
    priors=priors,
    model_instance=oscillator_model,
    integrator_instance=integrator,
)

In [None]:
# Run SNLE with the same budget as the SNPE section
# (800 simulations, 1 worker, num_samples=2000).
snle_model.run_inference(num_samples=2000, num_simulations=800, num_workers=1)

In [None]:
# Convert the SNLE results to ArviZ data under a method-specific name so they
# neither overwrite nor are overwritten by the other sections.
# save=True presumably also persists the data to disk (TODO confirm).
snle_inference_data = snle_model.to_arviz_data(save=True)

In [None]:
# SNLE posterior plot with 50 bins; init_params carries the values used to
# configure the simulator (with epsilon = 0.0).
snle_model.plot_posterior_samples(
    bins=50,
    init_params=dict(a=a_sim, b=b_sim, c=c_sim, d=d_sim, I=I_sim, epsilon=0.0),
)

In [None]:
# Point estimate from get_map_estimator() (presumably the SNLE posterior MAP),
# displayed as this cell's output.
map_estimator = snle_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() (presumably one posterior
# draw) and display it; it is re-simulated in the next cell.
posterior_sample = snle_model.get_sample()[0]
posterior_sample

In [None]:
# Re-simulate at the SNLE posterior draw and restore the observation's shape
# using a column-major reshape.
posterior_obs = (
    snle_model.simulation_wrapper(params=posterior_sample)
    .numpy()
    .reshape(shape, order="F")
)

In [None]:
# Overlay the observed trajectory with the one re-simulated from the SNLE
# posterior draw.
# NOTE(review): the x-axis is the sample index of the stored series; its
# sampling period is not visible here, so it is not labelled in ms.
x_obs_snle = snle_model.inference_data.observed_data.x_obs.values
f4, ax = plt.subplots(figsize=(13, 8))
ax.plot(x_obs_snle[:, 0, 0, 0], label="V_simulated", color="blue")
ax.plot(x_obs_snle[:, 1, 0, 0], label="W_simulated", color="red")
ax.plot(posterior_obs[:, 0, 0, 0], label="V_posterior", color="cyan")
ax.plot(posterior_obs[:, 1, 0, 0], label="W_posterior", color="orange")
ax.set(xlabel="time step", ylabel="states", title="SNLE: observed vs. posterior simulation")
ax.legend()
plt.show()

In [None]:
# Information criteria computed by sbiModel for the SNLE fit; the return value
# is displayed as the cell output.
snle_model.information_criteria()

### SNRE inference (Sequential Neural Ratio Estimation)

In [None]:
# SNRE wrapper around the same simulator, priors and observation; only the
# sbi method differs from the other sections.
snre_model = sbiModel(
    method="SNRE",
    obs=obs,
    obs_shape=shape,
    priors=priors,
    model_instance=oscillator_model,
    integrator_instance=integrator,
)

In [None]:
# Run SNRE with the same budget as the other sections
# (800 simulations, 1 worker, num_samples=2000).
snre_model.run_inference(num_samples=2000, num_simulations=800, num_workers=1)

In [None]:
# Convert the SNRE results to ArviZ data under a method-specific name so they
# do not silently replace the SNPE/SNLE results.
# save=True presumably also persists the data to disk (TODO confirm).
snre_inference_data = snre_model.to_arviz_data(save=True)

In [None]:
# SNRE posterior plot. Uses plot_posterior_samples, the same sbiModel method
# the SNPE and SNLE sections call, instead of plot_posterior.
# init_params carries the values used to configure the simulator (epsilon = 0.0).
snre_model.plot_posterior_samples(
    init_params={"a": a_sim, "b": b_sim, "c": c_sim, "d": d_sim, "I": I_sim, "epsilon": 0.0}
)

In [None]:
# Point estimate from get_map_estimator() (presumably the SNRE posterior MAP),
# displayed as this cell's output.
map_estimator = snre_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() (presumably one posterior
# draw) and display it as this cell's output.
posterior_sample = snre_model.get_sample()[0]
posterior_sample

In [None]:
# Re-simulate at the posterior draw (not the MAP point) so the "*_posterior"
# traces plotted below mean the same thing as in the SNPE/SNLE sections.
# The column-major reshape restores the observation's shape (presumably
# matching how simulation_wrapper flattens its output).
posterior_obs = snre_model.simulation_wrapper(params=posterior_sample)
posterior_obs = posterior_obs.numpy().reshape(shape, order="F")

In [None]:
# Overlay the observed trajectory with the one re-simulated for SNRE.
# Uses its own figure name (f5) instead of reusing the SNLE cell's f4.
# NOTE(review): the x-axis is the sample index of the stored series; its
# sampling period is not visible here, so it is not labelled in ms.
x_obs_snre = snre_model.inference_data.observed_data.x_obs.values
f5, ax = plt.subplots(figsize=(13, 8))
ax.plot(x_obs_snre[:, 0, 0, 0], label="V_simulated", color="blue")
ax.plot(x_obs_snre[:, 1, 0, 0], label="W_simulated", color="red")
ax.plot(posterior_obs[:, 0, 0, 0], label="V_posterior", color="cyan")
ax.plot(posterior_obs[:, 1, 0, 0], label="W_posterior", color="orange")
ax.set(xlabel="time step", ylabel="states", title="SNRE: observed vs. posterior simulation")
ax.legend()
plt.show()

In [None]:
# Information criteria computed by sbiModel for the SNRE fit; the return value
# is displayed as the cell output.
snre_model.information_criteria()