In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.sbiInference import sbiModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import arviz as az
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Identifiers of previously saved inference runs; only the first entry is read below.
# TODO: fill in the run id before executing the notebook.
run_ids = [""]

In [None]:
# Load the pickled record of a saved run; its "simulation_params" entry holds the
# model/integrator settings read by the cells below.
# SECURITY: pickle.load can execute arbitrary code -- only open files you produced yourself.
if not run_ids or not run_ids[0]:
    # Without this guard an empty placeholder would try to open ".../.pkl".
    raise ValueError("run_ids[0] is empty: set it to the id of a saved run before loading.")

with open(f"sbi_data/inference_data/{run_ids[0]}.pkl", "rb") as f:
    instance_params = pickle.load(f)

simulation_params = instance_params["simulation_params"]

In [None]:
# Model
# Rebuild the model class named by simulation_params["model"] (looked up in
# tvb.simulator.models) with the saved a/b/c/d/I values, each wrapped in a 1-element array.
oscillator_model = getattr(tvb.simulator.models, simulation_params["model"])(
    a=np.asarray([simulation_params["a_sim"]]),
    b=np.asarray([simulation_params["b_sim"]]),
    c=np.asarray([simulation_params["c_sim"]]),
    d=np.asarray([simulation_params["d_sim"]]),
    I=np.asarray([simulation_params["I_sim"]]),
)
oscillator_model.configure()

# Integrator
# Rebuild the integrator class named by simulation_params["integrator"] with the saved step size.
integrator = getattr(tvb.simulator.integrators, simulation_params["integrator"])(dt=simulation_params["dt"])
# Noise amplitude taken from the saved settings.
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.noise.configure()
integrator.noise.configure_white(dt=integrator.dt)
# NOTE(review): random_state=None sets no fixed seed, so noise realisations are presumably
# not reproducible between runs -- pass a seed here if exact reproducibility matters.
integrator.set_random_state(random_state=None)
integrator.configure()
integrator.configure_boundaries(oscillator_model)

In [None]:
# Observed data stored with the run. Cells below index it as X[:, 0, region, 0] and plot
# axis 0 against time -- TODO confirm the meaning of axes 1 and 3.
X = instance_params["obs"]

In [None]:
# Observed trace of the first two regions (third index of X), same styling for both.
f1, ax1 = plt.subplots(figsize=(14, 8))
for region_idx in (0, 1):
    ax1.plot(X[:, 0, region_idx, 0], label=f"Region {region_idx + 1}")
ax1.set_ylabel("states", fontsize=16)
ax1.set_xlabel("time (ms)", fontsize=16)
ax1.legend(fontsize=16)
ax1.tick_params(axis="both", labelsize=16)
plt.show()

### SNPE inference (Sequential Neural Posterior Estimation)

In [None]:
# SNPE inference object wrapping the rebuilt model, integrator and the loaded observation.
snpe_model = sbiModel(
    model_instance=oscillator_model,
    integrator_instance=integrator,
    obs=X,
    method="SNPE",
)

In [None]:
# Restore the saved SNPE state for this run.
# NOTE(review): this passes a bare "<run_id>.pkl", whereas the loading cell near the top opens
# "sbi_data/inference_data/<run_id>.pkl" -- presumably sbiModel.load prepends the directory; confirm.
snpe_model.load(f"{run_ids[0]}.pkl")

In [None]:
# Posterior marginals with the reference ("true") parameter values overlaid.
# The noise reference comes from the saved settings -- the same value the integrator was
# configured with -- instead of a hard-coded 0.003 that would be wrong for other runs.
snpe_model.plot_posterior_samples(
    init_params={"a_model": simulation_params["a_sim"],
                 "nsig_integrator.noise": simulation_params["nsig"],
                 "epsilon_global": 0.0}
)

In [None]:
# Per-parameter spread of the posterior vs. the prior, plus the resulting shrinkages.
# The prior's marginal std dev is the norm of each row of its Cholesky factor L
# (Var_i = sum_j L_ij**2); diag(L) equals it only when the prior is uncorrelated.
print("posterior std dev:", snpe_model.posterior_samples.std(dim=0).numpy())
print("prior std dev:", snpe_model.priors.scale_tril.pow(2).sum(dim=-1).sqrt().numpy())
print("shrinkages:", snpe_model.posterior_shrinkage().numpy())

In [None]:
# Posterior diagnostics: z-score against shrinkage, presumably one marker per inferred
# parameter. Labelled axes let the figure stand alone; plt.show() suppresses the Line2D repr.
f2, ax2 = plt.subplots(figsize=(12, 8))
ax2.plot(snpe_model.posterior_shrinkage(), snpe_model.posterior_zscore(),
         color="blue", linewidth=0, marker="*", markersize=12)
ax2.set_xlabel("posterior shrinkage", fontsize=16)
ax2.set_ylabel("posterior z-score", fontsize=16)
ax2.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# Take the first element returned by get_sample() and display it.
snpe_draws = snpe_model.get_sample()
posterior_sample = snpe_draws[0]
posterior_sample

In [None]:
# Forward-simulate the drawn parameters and lay the flat result out like the observation
# (column-major order -- presumably matching how simulation_wrapper flattens; TODO confirm).
posterior_obs = (
    snpe_model.simulation_wrapper(params=posterior_sample)
    .numpy()
    .reshape(X.shape, order="F")
)

In [None]:
# Observed vs. posterior-simulated trace, one panel per region.
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18,15))
snpe_observed = snpe_model.inference_data.observed_data.x_obs.values
panel_specs = (
    (0, "R1", "blue", "cyan"),
    (1, "R2", "red", "orange"),
)
for region_idx, tag, obs_color, post_color in panel_specs:
    panel = axes3[region_idx]
    panel.plot(snpe_observed[:, 0, region_idx, 0], label=f"{tag}_observed", color=obs_color)
    panel.plot(posterior_obs[:, 0, region_idx, 0], label=f"{tag}_posterior", color=post_color)
    panel.legend(fontsize=16)
    panel.set_xlabel("time (ms)", fontsize=16)
    panel.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Optional: uncomment to call snpe_model.save(). What it writes and where is defined by
# sbiModel (not visible in this notebook) -- check before enabling to avoid overwriting a run.
#snpe_model.save()

### SNLE inference (Sequential Neural Likelihood Estimation)

In [None]:
# SNLE inference object, built like the SNPE one: from the rebuilt TVB model and
# integrator and the loaded observation X (this notebook defines no `sim`, `obs` or
# `priors` variables, so it cannot reference them).
# NOTE(review): the prior is reused from the loaded SNPE run so every method is compared
# under the same prior; confirm sbiModel accepts a torch distribution via `priors=`.
snle_model = sbiModel(
    method="SNLE",
    obs=X,
    model_instance=oscillator_model,
    integrator_instance=integrator,
    priors=snpe_model.priors,
)

In [None]:
# Run SNLE with 800 simulations on 4 workers and draw 2000 posterior samples.
snle_model.run_inference(num_samples=2000, num_workers=4, num_simulations=800)

In [None]:
# Convert the SNLE results to ArviZ InferenceData (save=True presumably also persists them).
# NOTE(review): later cells read snle_model.inference_data, not this variable, which the SNRE
# section overwrites -- presumably to_arviz_data also stores the result on the model; confirm.
inference_data = snle_model.to_arviz_data(save=True)

In [None]:
# SNLE posterior marginals with reference values read from the saved settings
# (the bare names a_sim, b_sim, ... are not defined anywhere in this notebook).
# NOTE(review): the SNPE cell keys its references "a_model"-style; confirm which parameter
# names plot_posterior_samples expects for this model.
snle_model.plot_posterior_samples(
    init_params={
        "a": simulation_params["a_sim"],
        "b": simulation_params["b_sim"],
        "c": simulation_params["c_sim"],
        "d": simulation_params["d_sim"],
        "I": simulation_params["I_sim"],
        "epsilon": 0.0,
    },
    bins=50,
)

In [None]:
# Point estimate from get_map_estimator() (presumably the posterior MAP); shown as cell output.
map_estimator = snle_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() for the SNLE posterior and display it.
snle_draws = snle_model.get_sample()
posterior_sample = snle_draws[0]
posterior_sample

In [None]:
# Forward-simulate the drawn SNLE parameters and reshape to the observation's layout.
# `shape` is never defined in this notebook; X.shape is what the SNPE section uses.
posterior_obs = snle_model.simulation_wrapper(params=posterior_sample)
posterior_obs = posterior_obs.numpy().reshape(X.shape, order="F")

In [None]:
# Observed vs. SNLE posterior-simulated traces for indices 0 and 1 of axis 1 (labelled V and W).
# Observed curves are drawn first, then posterior curves, so the legend order is V/W sim, V/W post.
f4, ax4 = plt.subplots(figsize=(13,8))
snle_observed = snle_model.inference_data.observed_data.x_obs.values
for var_idx, (var_name, color) in enumerate((("V", "blue"), ("W", "red"))):
    ax4.plot(snle_observed[:, var_idx, 0, 0], label=f"{var_name}_simulated", color=color)
for var_idx, (var_name, color) in enumerate((("V", "cyan"), ("W", "orange"))):
    ax4.plot(posterior_obs[:, var_idx, 0, 0], label=f"{var_name}_posterior", color=color)
ax4.legend()
ax4.set_xlabel("time (ms)")
ax4.set_ylabel("states")
plt.show()

In [None]:
# Information criteria reported by sbiModel for the SNLE fit (shown as cell output).
snle_model.information_criteria()

### SNRE inference (Sequential Neural Ratio Estimation)

In [None]:
# SNRE inference object built from the rebuilt model/integrator and the loaded observation X.
# `obs`, `priors` and `shape` are not defined in this notebook; X and X.shape are used
# instead, matching the SNPE section.
# NOTE(review): the prior is reused from the loaded SNPE run so all methods share it; confirm
# this is the intended prior for SNRE.
snre_model = sbiModel(
    integrator_instance=integrator,
    model_instance=oscillator_model,
    method="SNRE",
    obs=X,
    priors=snpe_model.priors,
    obs_shape=X.shape,
)

In [None]:
# Run SNRE with 800 simulations on a single worker and draw 2000 posterior samples.
snre_model.run_inference(num_samples=2000, num_workers=1, num_simulations=800)

In [None]:
# Convert the SNRE results to ArviZ InferenceData (save=True presumably also persists them).
# NOTE(review): this rebinds the same `inference_data` name the SNLE section assigned; later
# cells read snre_model.inference_data instead.
inference_data = snre_model.to_arviz_data(save=True)

In [None]:
# SNRE posterior marginals with reference values from the saved settings.
# Uses plot_posterior_samples, the method the SNPE and SNLE sbiModel instances call, and reads
# the reference values from simulation_params (bare a_sim, b_sim, ... are undefined here).
# NOTE(review): confirm the expected parameter-name keys (the SNPE cell uses "a_model"-style).
snre_model.plot_posterior_samples(
    init_params={
        "a": simulation_params["a_sim"],
        "b": simulation_params["b_sim"],
        "c": simulation_params["c_sim"],
        "d": simulation_params["d_sim"],
        "I": simulation_params["I_sim"],
        "epsilon": 0.0,
    }
)

In [None]:
# Point estimate from get_map_estimator() for SNRE (presumably the posterior MAP); shown as output.
map_estimator = snre_model.get_map_estimator()
map_estimator

In [None]:
# Take the first element returned by get_sample() for the SNRE posterior and display it.
snre_draws = snre_model.get_sample()
posterior_sample = snre_draws[0]
posterior_sample

In [None]:
# Forward-simulate the SNRE estimate and reshape to the observation's layout.
# `shape` is never defined in this notebook; X.shape is what the SNPE section uses.
# NOTE(review): unlike the SNPE/SNLE sections this simulates map_estimator, not
# posterior_sample -- confirm this difference is intended.
posterior_obs = snre_model.simulation_wrapper(params=map_estimator)
posterior_obs = posterior_obs.numpy().reshape(X.shape, order="F")

In [None]:
# Observed vs. SNRE-simulated traces for indices 0 and 1 of axis 1 (labelled V and W).
# Observed curves are drawn first, then simulated ones, preserving the legend order.
f4, snre_ax = plt.subplots(figsize=(13,8))
snre_observed = snre_model.inference_data.observed_data.x_obs.values
for var_idx, (var_name, color) in enumerate((("V", "blue"), ("W", "red"))):
    snre_ax.plot(snre_observed[:, var_idx, 0, 0], label=f"{var_name}_simulated", color=color)
for var_idx, (var_name, color) in enumerate((("V", "cyan"), ("W", "orange"))):
    snre_ax.plot(posterior_obs[:, var_idx, 0, 0], label=f"{var_name}_posterior", color=color)
snre_ax.legend()
snre_ax.set_xlabel("time (ms)")
snre_ax.set_ylabel("states")
plt.show()

In [None]:
# Information criteria reported by sbiModel for the SNRE fit (shown as cell output).
snre_model.information_criteria()