In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.pymcInference import pymcModel1node

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import pymc3 as pm
import scipy
import theano.tensor as tt
import theano
import math
from tqdm import tqdm
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Identifiers of saved inference runs; each is a file stem under
# pymc_data/inference_data/ (see the loading cell). The order matters because
# the selected run is picked by position (`idx`).
# NOTE(review): the trailing notes appear to record which parameters were
# inferred and whether their priors were centred on the ground truth or
# shifted — confirm against the run logs.
run_ids = [
    # September 2022
    "2022-09-11_1858_instance",
    "2022-09-12_0858_instance",
    "2022-09-12_2353_instance",
    "2022-09-13_1420_instance",
    "2022-09-14_1038_instance",
    "2022-09-16_0905_instance",
    "2022-09-20_1251_instance",
    "2022-09-23_1555_instance",
    "2022-09-27_1313_instance",
    "2022-09-29_1238_instance",  # a, b shifted
    "2022-09-30_2201_instance",  # a, b around ground truth
    # October 2022
    "2022-10-01_1007_instance",  # a, b, c shifted
    "2022-10-03_0751_instance",  # a, b around ground truth
    "2022-10-03_1129_instance",  # a, b, c around ground truth
    "2022-10-04_0941_instance",  # a, b shifted
    "2022-10-08_1034_instance",
    "2022-10-08_1039_instance",
    "2022-10-13_1454_instance",
    "2022-10-14_1402_instance",  # a, b, c, I around ground truth
    "2022-10-15_1056_instance",  # a, b, c, I shifted
    "2022-10-16_1149_instance",
    "2022-10-17_1226_instance",
    "2022-10-18_1002_instance",
    "2022-10-19_2353_instance",
]

In [None]:
# Position in `run_ids` of the run to analyse; -1 selects the last
# (latest-dated) entry.
idx = -1

In [None]:
# Load the pickled inference results for the selected run and pull out the
# parameters that were used for the simulation.
# NOTE(review): pickle.load can execute arbitrary code; only open result files
# produced by your own runs.
with open(f"pymc_data/inference_data/{run_ids[idx]}.pkl", "rb") as f:
    instance_params = pickle.load(f)
simulation_params = instance_params["simulation_params"]

In [None]:
# Model
oscillator_model = getattr(tvb.simulator.models, simulation_params["model"])(
    a=np.asarray([simulation_params["a_sim"]]),
    b=np.asarray([simulation_params["b_sim"]]),
    c=np.asarray([simulation_params["c_sim"]]),
    d=np.asarray([simulation_params["d_sim"]]),
    I=np.asarray([simulation_params["I_sim"]]),
)
oscillator_model.configure()

# Integrator
integrator = getattr(tvb.simulator.integrators, simulation_params["integrator"])(dt=simulation_params["dt"])
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.noise.configure()
integrator.noise.configure_white(dt=integrator.dt)
integrator.set_random_state(random_state=None)
integrator.configure()
integrator.configure_boundaries(oscillator_model)

In [None]:
# Observed time series stored with the run. Later cells index it as
# X[:, k, 0, 0], with k = 0 for V and k = 1 for W, so it is 4-D; presumably
# TVB's (time, state variable, node, mode) layout — confirm.
X = instance_params["obs"]

In [None]:
# Display the configured oscillator model instance.
oscillator_model

In [None]:
# Observed V and W traces for the selected run.
# NOTE(review): the x-axis is the sample index; it is in ms only if the
# observations were recorded every 1 ms — confirm against the monitor period.
f1, ax1 = plt.subplots(figsize=(14, 8))
for state_idx, state_name in enumerate(["V", "W"]):
    ax1.plot(X[:, state_idx, 0, 0], label=state_name)
ax1.set_ylabel("state (a.u.)", fontsize=16)
ax1.set_xlabel("time (ms)", fontsize=16)
ax1.legend(fontsize=16)
ax1.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# Build the PyMC inference wrapper (tvb.contrib.inversion) around the
# configured oscillator model.
pymc_model = pymcModel1node(oscillator_model)

In [None]:
# Restore the saved inference results for the selected run into the wrapper.
# NOTE(review): only the file name is passed, whereas the observations were read
# from pymc_data/inference_data/ above; presumably load() resolves the
# directory itself — confirm.
pymc_model.load(f"{run_ids[idx]}.pkl")

In [None]:
# Display the prior statistics held by the loaded wrapper.
pymc_model.prior_stats

In [None]:
#pymc_model.plot_posterior_samples(
#    init_params={"a": simulation_params["a_sim"],
#                 #"b": simulation_params["b_sim"],
#                 "epsilon": 0.0,
#                 "noise": np.sqrt(2 * simulation_params["nsig"])
#                }
#)

In [None]:
pymc_model.plot_posterior_samples(
    init_params={"model_a": simulation_params["a_sim"],
                 "model_b": simulation_params["b_sim"],
                 "model_c": simulation_params["c_sim"],
                 #"model_d": simulation_params["d_sim"],
                 "model_I": simulation_params["I_sim"],
                 #"model_tau": oscillator_model.tau[0],
                 "global_noise": 0.0,
                 "noise_gfun": np.sqrt(2 * simulation_params["nsig"])
                }
)

In [None]:
s = pymc_model.posterior_shrinkage()
z = pymc_model.posterior_zscore(init_params={
    "model_a": simulation_params["a_sim"],
    "model_b": simulation_params["b_sim"],
    "model_c": simulation_params["c_sim"],
    #"model_d": simulation_params["d_sim"],
    "model_I": simulation_params["I_sim"],
    #"model_tau": oscillator_model.tau[0],
    "global_noise": 0.0,
    "noise_gfun": np.sqrt(2 * simulation_params["nsig"])
})

In [None]:
#s = pymc_model.posterior_shrinkage()
#z = pymc_model.posterior_zscore(init_params={
#    "a": simulation_params["a_sim"],
#    "epsilon": 0.0,
#    "noise": np.sqrt(2 * simulation_params["nsig"])
#})

In [None]:
# Posterior z-score against posterior shrinkage, one star per entry of s / z
# (presumably one per inferred parameter — confirm).
f2, ax2 = plt.subplots(figsize=(12, 8))
ax2.plot(s, z, color="blue", linewidth=0, marker="*", markersize=12)
ax2.set_title("Posterior shrinkage vs. z-score", fontsize=16)
ax2.set_xlabel("posterior shrinkage", fontsize=16)
ax2.set_ylabel("posterior z-score", fontsize=16)
ax2.set_xlim([0.0, 1.1])
ax2.tick_params(axis="both", labelsize=16)
# Render the figure; a bare plt.plot() only adds an empty plot call.
plt.show()

In [None]:
# Dimension sizes of the posterior sample: number of chains and draws per chain.
n_chains = pymc_model.inference_data.sample_stats.sizes["chain"]
n_draws = pymc_model.inference_data.sample_stats.sizes["draw"]

In [None]:
# Flatten the leading (chain, draw) axes of the posterior-predictive samples
# (ArviZ's layout) into one sample axis, keeping the observation shape for the
# rest. The -1 lets numpy infer the sample count, so this still works when
# posterior-predictive sampling drew a different number of samples than the
# trace.
posterior_x_obs = pymc_model.inference_data.posterior_predictive.x_obs.values.reshape((-1, *X.shape))

In [None]:
# Posterior-predictive 95% band (2.5th to 97.5th percentile over samples)
# against the observed trace, with one panel per state variable.
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))
for ax, (state_idx, state_name) in zip(axes3, enumerate(["V", "W"])):
    lower, upper = np.percentile(
        posterior_x_obs[:, :, state_idx, 0, 0], [2.5, 97.5], axis=0
    )
    # Label only one bound so the legend lists the band once, not twice.
    ax.plot(lower, "k", label=rf"${state_name}_{{95\% PP}}(t)$")
    ax.plot(upper, "k")
    ax.plot(X[:, state_idx, 0, 0], label=f"{state_name}_observed")
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.set_ylabel(f"{state_name} (a.u.)", fontsize=16)
    ax.legend(fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Model criteria computed by the wrapper (their contents are not defined in this notebook).
criteria = pymc_model.model_criteria()

In [None]:
# Display the model criteria.
criteria

In [None]:
#pymc_model.summary[pymc_model.summary["r_hat"] >= 1.2]

In [None]:
#pymc_model.summary.loc[["a", "noise", "epsilon"]]

In [None]:
# Posterior summary rows for the parameters passed as init_params in the
# plotting / z-score cells.
pymc_model.summary.loc[["model_a", "model_b", "model_c", "model_I", "noise_gfun", "global_noise"]]

In [None]:
# Sampler diagnostics: count and rate of divergent transitions, plus the mean
# acceptance probability of the NUTS trees.
divergent = pymc_model.trace["diverging"]
n_divergent = divergent.nonzero()[0].size
print("Number of Divergent %d" % n_divergent)
# Rate over every recorded sample (chains x draws, concatenated by the trace).
# Summing n_draws and n_chains would not give a sample count.
n_samples = divergent.size
divperc = n_divergent / n_samples * 100 if n_samples else 0.0
print("Percentage of Divergent %.1f" % divperc)
# Acceptance probabilities lie in [0, 1]; one decimal place would hide most of the value.
print("Mean tree accept %.3f" % pymc_model.trace['mean_tree_accept'].mean())

In [None]:
# Wall-clock sampling time recorded with the sample stats, converted from seconds to hours.
sampling_hours = pymc_model.inference_data.sample_stats.sampling_time / 3600
print("Sampling time in hours:", sampling_hours)