In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.sbiInference import sbiModel, sbiPrior

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import arviz as az
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Saved SBI inference instances, one ID per line (file stem without ".pkl").
# The load cells below pick run_ids[-2].
# Some IDs have unpadded minute fields (e.g. "10-1-59"); they are kept verbatim
# because they presumably have to match the file names on disk.
#
# Earlier batches, kept for reference:
#   2022-11-16_09-38-21-545432_instance  2022-11-16_09-38-21-555473_instance
#   2022-11-16_09-38-21-642074_instance  2022-11-16_09-38-21-707618_instance
#   2022-11-16_09-40-42-780023_instance  2022-11-16_09-40-42-786705_instance
#   2022-11-16_09-40-42-813028_instance  2022-11-16_09-40-42-814468_instance
run_ids = """
2022-11-22_16-45-06-670747_instance
2022-11-22_16-45-44-588078_instance
2022-11-23_09-50-33-332108_instance
2022-11-23_10-1-59-508195_instance
2022-11-23_10-39-43-303435_instance
2022-11-23_22-43-43-610772_instance
2022-11-23_23-19-27-526335_instance
2022-11-24_09-30-46-725224_instance
2022-11-24_10-42-57-333240_instance
2022-11-24_14-25-03-153868_instance
2022-11-24_14-25-03-154492_instance
2022-11-24_16-46-25-876713_instance
2022-11-24_16-48-21-754821_instance
2022-11-25_10-1-29-944905_instance
2022-11-25_10-2-04-719592_instance
2022-11-25_16-33-24-503044_instance
2022-11-25_16-34-14-974959_instance
2022-11-25_20-34-23-718078_instance
2022-11-25_20-36-32-281698_instance
2022-11-27_13-47-55-965479_instance
2022-11-27_13-48-56-280657_instance
2022-11-28_14-20-45-301958_instance
2022-11-28_14-24-33-141601_instance
""".split()

In [None]:
# Load the saved inference instance: the observation plus the parameters used to simulate it.
# NOTE: pickle.load can execute arbitrary code — only open files you produced yourself.
inference_path = f"sbi_data/inference_data/{run_ids[-2]}.pkl"
with open(inference_path, "rb") as instance_file:
    instance_params = pickle.load(instance_file)
simulation_params = instance_params["simulation_params"]

In [None]:
# Rebuild the TVB components described by the saved `simulation_params`.

# Connectivity: only the hand-built two-region network ("Own") is reconstructed here.
if simulation_params["connectivity"] == "Own":
    connectivity = Connectivity()
    connectivity.weights = np.array([[0., 2.], [2., 0.]])
    connectivity.region_labels = np.array(["R1", "R2"])
    connectivity.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
    connectivity.tract_lengths = np.array([[0., 2.5], [2.5, 0.]])
    # connectivity.configure()
else:
    # Without this branch `connectivity` stays undefined. The Simulator cell would
    # then fail with a NameError, or silently reuse a stale value from an earlier
    # kernel run.
    raise ValueError(
        f"Unsupported connectivity {simulation_params['connectivity']!r}; "
        "this notebook only reconstructs the 'Own' two-region connectivity."
    )

# Model: class looked up by name, with parameters wrapped as 1-element arrays.
oscillator_model = getattr(tvb.simulator.models, simulation_params["model"])(
    a=np.asarray([simulation_params["a_sim"]]),
    b=np.asarray([simulation_params["b_sim"]]),
    c=np.asarray([simulation_params["c_sim"]]),
    d=np.asarray([simulation_params["d_sim"]]),
    I=np.asarray([simulation_params["I_sim"]]),
)
# oscillator_model.configure()

# Integrator
integrator = getattr(tvb.simulator.integrators, simulation_params["integrator"])(dt=simulation_params["dt"])
# Only set nsig when the integrator exposes a `noise` attribute. This presumably
# holds only for stochastic integrators; for others the original assignment
# raised AttributeError.
if hasattr(integrator, "noise"):
    integrator.noise.nsig = np.array([simulation_params["nsig"]])
# integrator.configure()

# Global coupling
coupling = getattr(tvb.simulator.coupling, simulation_params["coupling"])()

# Monitor
monitor = getattr(tvb.simulator.monitors, simulation_params["monitor"])()

In [None]:
# Assemble the simulator from the components built in the previous cell.
sim_components = dict(
    model=oscillator_model,
    connectivity=connectivity,
    coupling=coupling,
    integrator=integrator,
    monitors=(monitor,),
)
sim = Simulator(simulation_length=simulation_params["simulation_length"], **sim_components)

# The simulator is not configured here (the call was left commented out):
# sim.configure();

In [None]:
# Observed time series stored with the inference instance. Later cells index it
# as X[:, 0, region, 0] and reshape simulations to X.shape, so it needs at least
# 4 dimensions. Check that here rather than failing deep inside a plot call.
X = instance_params["obs"]
assert np.ndim(X) >= 4, f"expected an observation array with >= 4 dims, got ndim={np.ndim(X)}"

In [None]:
# Observed traces for both regions at index 0 of axes 1 and 3
# (presumably the first state variable and mode — confirm).
# The x values are sample indices of X; no monitor period is applied, so the
# axis is labelled in time steps rather than milliseconds.
f1, ax1 = plt.subplots(figsize=(14, 8))
for region_idx, region_label in enumerate(["Region 1", "Region 2"]):
    ax1.plot(X[:, 0, region_idx, 0], label=region_label)
ax1.set_ylabel("states", fontsize=16)
ax1.set_xlabel("time step", fontsize=16)
ax1.legend(fontsize=16)
ax1.tick_params(axis="both", labelsize=16)
plt.show()

### SNPE inference

In [None]:
# Wrap the observation and the simulator for SNPE inference.
inference_method = "SNPE"
snpe_model = sbiModel(method=inference_method, obs=X, simulator_instance=sim)

In [None]:
# Restore the trained SNPE state for the same run (run_ids[-2]) whose
# observation was unpickled earlier.
# NOTE(review): unlike the pickle cell, no "sbi_data/inference_data/" prefix is
# passed here; presumably sbiModel.load resolves its own directory — confirm.
# Earlier checkpoints tried in this cell:
#   2022-11-15_15-15-37-401909_instance, 2022-11-15_14-47-44-012410_instance,
#   2022-11-15_16-57-04-711747_instance, 2022-11-15_17-11-07-890014_instance
checkpoint_file = run_ids[-2] + ".pkl"
snpe_model.load(checkpoint_file)

In [None]:
# Tabulate the prior location and scale of each parameter in snpe_model.prior.identifier.
print("Priors\n------")
prior_means = np.array(snpe_model.prior.location)
prior_stds = np.array(snpe_model.prior.scale)
for ident in snpe_model.prior.identifier:
    # Entries are read as (index, name, owner); owner and name are joined as "owner.name".
    param_idx, param_name, owner = ident[0], ident[1], ident[2]
    print(f"{owner}.{param_name}:", "\t", "\t",
          "mean:", prior_means[param_idx], "\t",
          "std:", prior_stds[param_idx])

In [None]:
# Reference values passed to the posterior plot as `init_params`.
# d_model is not included; the coupling and noise entries are fixed constants
# rather than values read from simulation_params.
reference_values = {
    "a_model": simulation_params["a_sim"],
    "b_model": simulation_params["b_sim"],
    "c_model": simulation_params["c_sim"],
    "I_model": simulation_params["I_sim"],
    "a_coupling": 0.1,
    "nsig_integrator.noise": 0.003,
    "noise_global": 0.0,
}
snpe_model.plot_posterior_samples(init_params=reference_values)

In [None]:
# Compare the posterior spread with the prior spread, then print the shrinkage.
# NOTE(review): this cell reads `snpe_model.priors` while the prior summary reads
# `snpe_model.prior`; confirm both attributes exist. diag(scale_tril) equals the
# marginal prior std only if the prior covariance is diagonal.
spread_reports = (
    ("posterior std dev:", lambda: snpe_model.posterior_samples.std(dim=0)),
    ("prior std dev:", lambda: torch.diag(snpe_model.priors.scale_tril)),
    ("shrinkages:", lambda: snpe_model.posterior_shrinkage()),
)
for caption, compute in spread_reports:
    print(caption, compute().numpy())

In [None]:
# Posterior z-scores relative to the values used to simulate X.
# The keys must cover the same parameters as the posterior-plot cell and the
# 7-value sample used further down (a, b, c, I, coupling, nsig, global noise).
# The shrinkage-vs-zscore plot pairs these values element-wise with
# posterior_shrinkage(), so both arrays need the same length.
posterior_zscores = snpe_model.posterior_zscore(
    init_params={"a_model": simulation_params["a_sim"],
                 "b_model": simulation_params["b_sim"],
                 "c_model": simulation_params["c_sim"],
                 # d_model is left out, as in the posterior-plot cell
                 "I_model": simulation_params["I_sim"],
                 "a_coupling": 0.1,
                 "nsig_integrator.noise": 0.003,
                 # TODO(review): the posterior-plot cell names this parameter
                 # "noise_global"; confirm which key sbiModel expects.
                 "epsilon_global": 0.0}
)
posterior_zscores

In [None]:
# Posterior shrinkage, kept for the shrinkage-vs-zscore plot
# (presumably one value per inferred parameter — confirm).
posterior_shrinkages = snpe_model.posterior_shrinkage()
posterior_shrinkages

In [None]:
# Shrinkage vs z-score diagnostic, drawn as unconnected star markers.
# The axis limits are fixed, so points outside them are not shown.
f2, ax2 = plt.subplots(figsize=(12, 8))
ax2.plot(
    posterior_shrinkages,
    posterior_zscores,
    color="blue",
    linewidth=0,
    marker="*",
    markersize=12,
)
ax2.set(xlabel="posterior shrinkage", ylabel="posterior zscore",
        xlim=[-1.6, 1.1], ylim=[0.0, 2])
plt.show();

In [None]:
# Draw from the fitted posterior and keep the first entry
# (presumably a single parameter vector — confirm get_sample's return shape).
sample_draws = snpe_model.get_sample()
posterior_sample = sample_draws[0]
posterior_sample

In [None]:
# NOTE(review): this replaces the posterior draw from the previous cell with a
# hand-picked parameter vector. The "posterior" traces plotted below are then
# simulated from these fixed values; skip this cell to use the actual draw.
# The order presumably follows the posterior-plot keys: a, b, c, I, a_coupling,
# nsig, global noise — confirm.
manual_values = [2.0, -10.0, 0.0, 0.0, 0.1, 0.003, 0.00001]
posterior_sample = torch.tensor(manual_values)
posterior_sample

In [None]:
# Simulate with the chosen parameter vector and restore the observation's shape.
# order="F" presumably mirrors how simulation_wrapper flattens its output — confirm.
flat_simulation = snpe_model.simulation_wrapper(params=posterior_sample)
posterior_obs = np.reshape(flat_simulation.numpy(), X.shape, order="F")

In [None]:
# Observed vs simulated traces, one panel per region. Both use index 0 on
# axes 1 and 3, as in the earlier observation plot.
# Each entry: (region index, label prefix, observed colour, simulated colour).
region_styles = [
    (0, "R1", "blue", "cyan"),
    (1, "R2", "red", "orange"),
]
observed_ts = snpe_model.inference_data.observed_data.x_obs.values

f3, axes3 = plt.subplots(nrows=len(region_styles), ncols=1, figsize=(18, 15))
for ax, (region, prefix, obs_color, post_color) in zip(axes3, region_styles):
    ax.plot(observed_ts[:, 0, region, 0], label=f"{prefix}_observed", color=obs_color)
    ax.plot(posterior_obs[:, 0, region, 0], label=f"{prefix}_posterior", color=post_color)
    ax.legend(fontsize=16)
    # x values are sample indices; no monitor period is applied, so this is not milliseconds.
    ax.set_xlabel("time step", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
#snpe_model.save()