In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.sbiInference import sbiModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import arviz as az
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Load the pickled reference simulation together with the settings that produced it.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open('../limit-cycle_simulation.pkl', 'rb') as pkl_file:
    simulation_params = pickle.load(pkl_file)

In [None]:
# --- Connectivity: a minimal two-region network (R1 <-> R2) ---
connectivity = Connectivity()
connectivity.weights = np.array([[0., 2/3], [2/3, 0.]])        # symmetric coupling weights
connectivity.region_labels = np.array(["R1", "R2"])
connectivity.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
connectivity.tract_lengths = np.array([[0., 0.1], [0.1, 0.]])
connectivity.configure()

# --- Model: class name and parameter values are read from the pickled simulation ---
model_cls = getattr(tvb.simulator.models, simulation_params["model"])
oscillator_model = model_cls(
    a=np.asarray([simulation_params["a_sim"]]),
    b=np.asarray([simulation_params["b_sim"]]),
    c=np.asarray([simulation_params["c_sim"]]),
    d=np.asarray([simulation_params["d_sim"]]),
    I=np.asarray([simulation_params["I_sim"]]),
)
oscillator_model.configure()

# --- Integrator ---
integrator_cls = getattr(tvb.simulator.integrators, simulation_params["integrator"])
integrator = integrator_cls(dt=simulation_params["dt"])
# NOTE(review): assumes the chosen integrator exposes `.noise` (i.e. is stochastic).
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.configure()

# --- Global coupling (default-constructed; not configured explicitly here) ---
coupling_cls = getattr(tvb.simulator.coupling, simulation_params["coupling"])
coupling = coupling_cls()

# --- Monitor (default-constructed; presumably configured by sim.configure()) ---
monitor_cls = getattr(tvb.simulator.monitors, simulation_params["monitor"])
monitor = monitor_cls()

In [None]:
# Simulator
sim = Simulator(
    model=oscillator_model,
    connectivity=connectivity,
    coupling=coupling,
    integrator=integrator,
    monitors=(monitor,),
    simulation_length=simulation_params["simulation_length"]
)

sim.configure()

In [None]:
# Observed time series used as the inference target.
X = simulation_params["simulation"]
# Downstream cells index X with four indices and treat axis 2 as the region
# axis (X[:, 0, 1, 0] is plotted as "R2"). Check this up front so that a
# mismatched pickle fails here with a clear message instead of inside a plot/reshape.
assert X.ndim == 4, f"expected a 4-D simulation array, got shape {X.shape}"
assert X.shape[2] >= 2, f"expected at least 2 regions on axis 2, got shape {X.shape}"

In [None]:
# Quick look at the observed data: index 0 on axis 1 (presumably the first
# state variable) for both regions.
f1, ax1 = plt.subplots(figsize=(14, 8))
ax1.plot(X[:, 0, 0, 0], label="R1")
ax1.plot(X[:, 0, 1, 0], label="R2")
# Values are plotted against their array index, not physical time.
ax1.set(title="Observed simulation", xlabel="sample index", ylabel="states")
ax1.legend()
plt.show()

In [None]:
# Observations passed to every inference method below.
obs = X

# Prior specification, grouped by the component each parameter belongs to:
# "model", "coupling", "integrator.noise", plus "global" parameters that are
# not attributes of a TVB component. Each entry is a two-element list.
# In the posterior plots, parameters are named "<parameter>_<group>"
# (e.g. "a_model", "nsig_integrator.noise").
# NOTE(review): the pairs look like (centre, spread); how they become a
# distribution depends on `prior_dist` passed to run_inference — confirm.
prior_vars = {
    "model": {
        "a": [2.0, 0.5],
        "b": [-10.0, 0.5]
    },
    "coupling": {
        "a": [0.1, 0.075]
    },
    "integrator.noise": {
        "nsig": [0.003, 0.002]
    },
    "global": {
        "epsilon": [0.0, 0.1]
    },
}

# Observed array shape; used below to reshape simulator output back to this layout.
shape = X.shape

### SNPE inference

In [None]:
# Sequential neural posterior estimation (SNPE) wrapper around the configured simulator.
snpe_model = sbiModel(
    method="SNPE",
    simulator_instance=sim,
    obs=obs,
)

In [None]:
# Train SNPE on `num_simulations` simulator runs with a mixture-density-network
# ("mdn") estimator, then draw `num_samples` posterior samples.
# NOTE(review): prior_vars entries look like (centre, spread) pairs; confirm how
# sbiModel maps them onto a "Uniform" prior.
snpe_model.run_inference(
    prior_vars=prior_vars,
    prior_dist="Uniform",
    neural_net="mdn",
    num_simulations=600,
    num_samples=1000,
    num_workers=4,
)

In [None]:
# Convert SNPE results to ArviZ-style inference data (save=True presumably also writes it out).
inference_data = snpe_model.to_arviz_data(save=True, num_workers=4)

In [None]:
# Posterior marginals with reference parameter values marked for comparison.
snpe_model.plot_posterior_samples(
    init_params={
        "a_model": 2.0,
        "b_model": -10.0,
        "a_coupling": 0.1,
        "epsilon_global": 0.0,
        "amplitude_global": 0.0,
        "offset_global": 0.0,
        "nsig_integrator.noise": 0.003,
    }
)

In [None]:
# MAP estimate from the SNPE posterior (displayed as the cell output).
map_estimator = snpe_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() as a single posterior draw.
posterior_sample = snpe_model.get_sample()[0]
posterior_sample

In [None]:
# Re-simulate at the posterior draw and restore the observed array layout
# (order="F" presumably matches how the wrapper flattens its output — verify).
posterior_obs = (
    snpe_model.simulation_wrapper(params=posterior_sample)
    .numpy()
    .reshape(shape, order="F")
)

In [None]:
# Posterior predictive check: observed traces vs. traces re-simulated at one posterior draw.
x_obs_snpe = snpe_model.inference_data.observed_data.x_obs.values
f3, ax3 = plt.subplots(figsize=(13, 8))
ax3.plot(x_obs_snpe[:, 0, 0, 0], label="R1_observed", color="blue")
ax3.plot(x_obs_snpe[:, 0, 1, 0], label="R2_observed", color="red")
ax3.plot(posterior_obs[:, 0, 0, 0], label="R1_posterior", color="cyan")
ax3.plot(posterior_obs[:, 0, 1, 0], label="R2_posterior", color="orange")
# x values are array indices, not milliseconds.
ax3.set(title="SNPE: observed vs posterior sample", xlabel="sample index", ylabel="states")
ax3.legend()
plt.show()

In [None]:
# Model-comparison information criteria for the SNPE fit (shown as the cell output).
snpe_model.information_criteria()

In [None]:
# Persist the SNPE results; hand over a shallow copy so `simulation_params`
# itself is not modified by save().
snpe_model.save(
    simulation_params=simulation_params.copy()
)

### SNLE inference

In [None]:
snle_model = sbiModel(
    simulator_instance=sim,
    method="SNLE", 
    obs=obs, 
    prior_vars=priors,
    prior_dist="Normal",
)

In [None]:
snle_model.run_inference(
    num_simulations=800,
    num_workers=4,
    num_samples=2000
)

In [None]:
# Convert SNLE results to ArviZ-style inference data (save=True presumably also writes it out).
inference_data = snle_model.to_arviz_data(
    save=True
)

In [None]:
# Posterior marginals with reference values. Keys follow the "<parameter>_<group>"
# names derived from prior_vars, as in the SNPE posterior plot.
snle_model.plot_posterior_samples(
    init_params={
        "a_model": 2.0,
        "b_model": -10.0,
        "a_coupling": 0.1,
        "epsilon_global": 0.0,
        "amplitude_global": 0.0,
        "offset_global": 0.0,
        "nsig_integrator.noise": 0.003,
    },
    bins=50
)

In [None]:
# MAP estimate from the SNLE posterior (displayed as the cell output).
map_estimator = snle_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() as a single posterior draw.
posterior_sample = snle_model.get_sample()[0]
posterior_sample

In [None]:
# Re-simulate at the posterior draw and restore the observed array layout
# (order="F" presumably matches how the wrapper flattens its output — verify).
posterior_obs = (
    snle_model.simulation_wrapper(params=posterior_sample)
    .numpy()
    .reshape(shape, order="F")
)

In [None]:
# Posterior predictive check for SNLE. Axis 2 is the region axis, as in the
# observed-data and SNPE plots (index 0 -> R1, index 1 -> R2).
x_obs_snle = snle_model.inference_data.observed_data.x_obs.values
f4, ax4 = plt.subplots(figsize=(13, 8))
ax4.plot(x_obs_snle[:, 0, 0, 0], label="R1_observed", color="blue")
ax4.plot(x_obs_snle[:, 0, 1, 0], label="R2_observed", color="red")
ax4.plot(posterior_obs[:, 0, 0, 0], label="R1_posterior", color="cyan")
ax4.plot(posterior_obs[:, 0, 1, 0], label="R2_posterior", color="orange")
# x values are array indices, not milliseconds.
ax4.set(title="SNLE: observed vs posterior sample", xlabel="sample index", ylabel="states")
ax4.legend()
plt.show()

In [None]:
# Model-comparison information criteria for the SNLE fit (shown as the cell output).
snle_model.information_criteria()

### SNRE inference

In [None]:
snre_model = sbiModel(
    integrator_instance=integrator, 
    model_instance=oscillator_model, 
    method="SNRE", 
    obs=obs, 
    priors=priors,
    obs_shape=shape
)

In [None]:
snre_model.run_inference(
    num_simulations=800,
    num_workers=1,
    num_samples=2000
)

In [None]:
# Convert SNRE results to ArviZ-style inference data (save=True presumably also writes it out).
inference_data = snre_model.to_arviz_data(
    save=True
)

In [None]:
# Posterior marginals with reference values, using the same plotting method and
# "<parameter>_<group>" keys as the SNPE and SNLE sections.
snre_model.plot_posterior_samples(
    init_params={
        "a_model": 2.0,
        "b_model": -10.0,
        "a_coupling": 0.1,
        "epsilon_global": 0.0,
        "amplitude_global": 0.0,
        "offset_global": 0.0,
        "nsig_integrator.noise": 0.003,
    }
)

In [None]:
# MAP estimate from the SNRE posterior (displayed as the cell output).
map_estimator = snre_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() as a single posterior draw.
posterior_sample = snre_model.get_sample()[0]
posterior_sample

In [None]:
# Unlike the SNPE/SNLE sections, re-simulate at the MAP estimate (not a posterior
# draw), then restore the observed array layout.
posterior_obs = (
    snre_model.simulation_wrapper(params=map_estimator)
    .numpy()
    .reshape(shape, order="F")
)

In [None]:
# Posterior predictive check for SNRE (simulated at the MAP estimate). Axis 2 is
# the region axis, as in the observed-data and SNPE plots (0 -> R1, 1 -> R2).
x_obs_snre = snre_model.inference_data.observed_data.x_obs.values
f5, ax5 = plt.subplots(figsize=(13, 8))
ax5.plot(x_obs_snre[:, 0, 0, 0], label="R1_observed", color="blue")
ax5.plot(x_obs_snre[:, 0, 1, 0], label="R2_observed", color="red")
ax5.plot(posterior_obs[:, 0, 0, 0], label="R1_posterior", color="cyan")
ax5.plot(posterior_obs[:, 0, 1, 0], label="R2_posterior", color="orange")
# x values are array indices, not milliseconds.
ax5.set(title="SNRE: observed vs MAP simulation", xlabel="sample index", ylabel="states")
ax5.legend()
plt.show()

In [None]:
# Model-comparison information criteria for the SNRE fit (shown as the cell output).
snre_model.information_criteria()