In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.sbiInference import sbiModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import arviz as az
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Earlier inference batches (older timestamps), kept commented out for provenance.
# Only the active list at the bottom of this cell is used.
#run_ids = [
#    "2022-09-13_08-20-56-466197_instance",
#    "2022-09-13_08-20-56-525471_instance",
#    "2022-09-13_08-20-56-675430_instance",
#    "2022-09-13_08-20-56-756109_instance"
#]

#run_ids = [
#    "2022-09-13_08-49-00-122871_instance",
#    "2022-09-13_08-49-00-133581_instance",
#    "2022-09-13_08-49-00-218087_instance",
#    "2022-09-13_08-49-00-292717_instance"
#]

#run_ids = [
#    "2022-09-13_09-35-43-404460_instance",
#    "2022-09-13_09-35-43-411359_instance",
#    "2022-09-13_09-35-43-450581_instance",
#    "2022-09-13_09-35-43-452609_instance"
#]

#run_ids = [
#    "2022-09-13_16-37-13-003930_instance",
#    "2022-09-13_16-37-13-010276_instance",
#    "2022-09-13_16-37-13-112678_instance",
#    "2022-09-13_16-37-13-116596_instance"
#]

#run_ids = [
#    "2022-09-14_10-46-45-546821_instance",
#    "2022-09-14_10-46-45-606165_instance",
#    "2022-09-14_10-46-46-077166_instance",
#    "2022-09-14_10-46-46-130627_instance"
#]

#run_ids = [
#    "2022-09-14_12-52-25-374945_instance",
#    "2022-09-14_12-52-25-376720_instance",
#    "2022-09-14_12-52-25-377273_instance",
#    "2022-09-14_12-52-25-381972_instance"
#]

#run_ids = [
#    "2022-09-14_16-3-58-776480_instance",
#    "2022-09-14_16-3-58-777523_instance",
#    "2022-09-14_16-3-58-812755_instance",
#    "2022-09-14_16-3-58-824044_instance"
#]

#run_ids = [
#    "2022-09-20_14-19-51-374717_instance",
#    "2022-09-20_14-19-51-404129_instance",
#    "2022-09-20_14-19-51-416879_instance",
#    "2022-09-20_14-19-51-433541_instance"
#]

#run_ids = [
#    "2022-09-23_16-4-31-962206_instance",
#    "2022-09-23_16-4-31-984204_instance",
#    "2022-09-23_16-4-32-263625_instance",
#    "2022-09-23_16-4-32-330215_instance"
#]

# Active batch: four run identifiers that share the same timestamp prefix and
# differ only in the trailing microsecond field. Later cells read run_ids[0].
run_ids = [
    f"2022-09-27_13-18-07-{suffix}_instance"
    for suffix in ("815553", "815833", "823480", "824198")
]

In [None]:
# Load the pickled record of the first run.
# NOTE: pickle.load can execute arbitrary code; only open files produced by a
# trusted inference run.
with open(f"sbi_data/inference_data/{run_ids[0]}.pkl", mode="rb") as pkl_file:
    instance_params = pickle.load(pkl_file)

# Simulation settings stored with the run (model/integrator names, *_sim values,
# nsig, dt); the next cell uses them to rebuild the model and integrator.
simulation_params = instance_params["simulation_params"]

In [None]:
# Model: look up the TVB model class by the name stored in the run settings and
# pass each parameter (a, b, c, d, I) as a one-element array of its *_sim value.
oscillator_model = getattr(tvb.simulator.models, simulation_params["model"])(
    **{
        param: np.asarray([simulation_params[f"{param}_sim"]])
        for param in ("a", "b", "c", "d", "I")
    }
)
oscillator_model.configure()

# Integrator: same lookup-by-name for the integrator class, stepping with the
# stored dt; its noise amplitude is the run's nsig.
integrator = getattr(tvb.simulator.integrators, simulation_params["integrator"])(
    dt=simulation_params["dt"]
)
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.noise.configure()
integrator.noise.configure_white(dt=integrator.dt)
# NOTE(review): random_state=None presumably leaves the noise stream unseeded,
# so results may differ between kernel runs; confirm in TVB.
integrator.set_random_state(random_state=None)
integrator.configure()
integrator.configure_boundaries(oscillator_model)

In [None]:
# Observed time series stored with the run. Later cells index it as
# X[:, state, 0, 0], so it is rank 4.
# NOTE(review): axis meaning assumed to be (time, state variable, node, mode); TODO confirm.
X = instance_params["obs"]

In [None]:
# Plot the two observed state variables (state index 0 -> "V", 1 -> "W"),
# taking index 0 on the last two axes.
f1, ax1 = plt.subplots(figsize=(14, 8))
ax1.plot(X[:, 0, 0, 0], label="V")
ax1.plot(X[:, 1, 0, 0], label="W")
ax1.set_ylabel("states", fontsize=16)
# NOTE(review): x values are sample indices; the "ms" label holds only if one
# sample equals 1 ms; confirm against the monitor period.
ax1.set_xlabel("time (ms)", fontsize=16)
ax1.tick_params(axis="both", labelsize=16)
# Create the legend exactly once: a second legend() call would replace this one
# and silently drop fontsize=16.
ax1.legend(fontsize=16)
plt.show()

### SNPE inference

In [None]:
# Wrap the configured TVB model and integrator in an sbiModel that uses the
# "SNPE" method, with the observed series X as its observation.
snpe_model = sbiModel(
    model_instance=oscillator_model,
    integrator_instance=integrator,
    obs=X,
    method="SNPE",
)

In [None]:
# Load the saved state for run_ids[0] into snpe_model (presumably the trained
# posterior; confirm in sbiModel.load).
# NOTE(review): unlike the pickle opened in the data-loading cell, this path has no
# "sbi_data/inference_data/" prefix; presumably load() resolves its own directory.
snpe_model.load(f"{run_ids[0]}.pkl")

In [None]:
# Print one line per prior entry: its mean (priors.loc) and its std taken from the
# diagonal of priors.scale_tril (equal to the std when the covariance is diagonal).
print("Priors")
print("------")
prior_means = snpe_model.priors.loc.numpy()
prior_stds = np.diag(snpe_model.priors.scale_tril)
for key in snpe_model.prior_keys:
    idx = key[0]
    print(f"{key[2]}.{key[1]}:", "\t", "\t",
          "mean:", prior_means[idx], "\t",
          "std:", prior_stds[idx])

In [None]:
# Plot posterior samples, passing the simulated parameter values as init_params
# (presumably drawn as reference markers; confirm in sbiModel).
# The noise reference is the nsig actually set on the integrator, not a
# hard-coded constant, so it stays correct when another run is loaded.
snpe_model.plot_posterior_samples(
    init_params={"a_model": simulation_params["a_sim"],
                 "b_model": simulation_params["b_sim"],
                 "c_model": simulation_params["c_sim"],
                 "nsig_integrator.noise": simulation_params["nsig"],
                 # NOTE(review): 0.0 is a hard-coded reference for epsilon_global; confirm.
                 "epsilon_global": 0.0}
)

In [None]:
# Compare per-parameter spread: std of the posterior samples, prior std (diagonal
# of priors.scale_tril), and the shrinkage reported by the model.
posterior_std = snpe_model.posterior_samples.std(dim=0).numpy()
print("posterior std dev:", posterior_std)
prior_std = torch.diag(snpe_model.priors.scale_tril).numpy()
print("prior std dev:", prior_std)
shrinkages = snpe_model.posterior_shrinkage().numpy()
print("shrinkages:", shrinkages)

In [None]:
# Posterior z-scores against the simulated parameter values.
# The noise reference is the nsig actually set on the integrator, not a
# hard-coded constant, so the z-score uses the true value for whichever run is loaded.
posterior_zscores = snpe_model.posterior_zscore(
    init_params={"a_model": simulation_params["a_sim"],
                 "b_model": simulation_params["b_sim"],
                 "c_model": simulation_params["c_sim"],
                 "nsig_integrator.noise": simulation_params["nsig"],
                 # NOTE(review): 0.0 is a hard-coded reference for epsilon_global; confirm.
                 "epsilon_global": 0.0}
)

In [None]:
# Display the z-scores returned by posterior_zscore().
posterior_zscores

In [None]:
# Difference between the posterior-sample mean of column 1 and a reference of -10.0.
# NOTE(review): -10.0 is a hard-coded ground truth. If column 1 is b_model, this should
# presumably be simulation_params["b_sim"]; confirm the column order of posterior_samples.
snpe_model.posterior_samples[:, 1].mean() - (-10.0)

In [None]:
# Standard deviation of the posterior samples in column 1.
snpe_model.posterior_samples[:, 1].std()

In [None]:
# Posterior z-score vs. posterior shrinkage, drawn as unconnected star markers
# (presumably one point per inferred parameter).
f2, ax2 = plt.subplots(figsize=(12, 8))
ax2.plot(snpe_model.posterior_shrinkage(), posterior_zscores,
         color="blue", linewidth=0, marker="*", markersize=12)
ax2.set_xlabel("posterior shrinkage")
ax2.set_ylabel("posterior zscore")
ax2.set_title("Posterior z-score vs. posterior shrinkage")
# Optional zoom: ax2.set_xlim(0.0, 1.1); ax2.set_ylim(0.0, 1.1)
plt.show()

In [None]:
# Take the first element returned by get_sample() and display it.
# NOTE(review): assumes get_sample() returns a batch whose first element is a single
# parameter vector accepted by simulation_wrapper; confirm.
posterior_sample = snpe_model.get_sample()[0]
posterior_sample

In [None]:
# Simulate with the drawn parameters and reshape the flat result back to X's shape.
# order="F" presumably matches how simulation_wrapper flattens its output; confirm.
posterior_obs = np.reshape(
    snpe_model.simulation_wrapper(params=posterior_sample).numpy(),
    X.shape,
    order="F",
)

In [None]:
# Posterior predictive check: per state variable, overlay the observed trace and the
# trace simulated from one posterior sample (index 0 on the last two axes).
observed_series = snpe_model.inference_data.observed_data.x_obs.values
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))

# (state index, state name, observed colour, posterior colour) for each panel.
panel_specs = (
    (0, "V", "blue", "cyan"),
    (1, "W", "red", "orange"),
)
for ax, (state_idx, state_name, obs_color, post_color) in zip(axes3, panel_specs):
    ax.plot(observed_series[:, state_idx, 0, 0],
            label=f"{state_name}_observed", color=obs_color)
    ax.plot(posterior_obs[:, state_idx, 0, 0],
            label=f"{state_name}_posterior", color=post_color)
    ax.legend(fontsize=16)
    # NOTE(review): x values are sample indices; "ms" assumes one sample per ms.
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.set_ylabel(state_name, fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Optional: uncomment to call snpe_model.save() (disabled by default).
#snpe_model.save()