In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.pymcInference import pymcModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import pymc3 as pm
import scipy
import theano.tensor as tt
import theano
import math
from tqdm import tqdm
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Saved inference runs, identified by a date/time stamp plus the "_instance" suffix.
# Each id names a pickle that is loaded by the cells below.
run_ids = [
    f"{timestamp}_instance"
    for timestamp in (
        "2022-10-01_1542",
        "2022-10-07_1112",
        "2022-10-07_1821",
        "2022-10-08_1052",
        "2022-10-11_1324",
        "2022-10-15_2126",
        "2022-10-20_2106",
        "2022-11-02_1223",
        "2022-11-02_1224",
        "2022-11-03_1048",
        "2022-11-06_1430",
    )
]

In [None]:
# Position in run_ids of the run to analyse; -1 picks the last entry (currently the latest timestamp).
idx = -1

In [None]:
# Load the saved inference instance for the selected run.
# SECURITY: pickle.load can execute arbitrary code -- only open pickles you produced yourself.
instance_path = f"pymc_data/inference_data/{run_ids[idx]}.pkl"
with open(instance_path, "rb") as instance_file:
    instance_params = pickle.load(instance_file)
# Settings the observations were generated with; reused below to rebuild the simulator.
simulation_params = instance_params["simulation_params"]

In [None]:
# --- Connectivity: two regions ("R1", "R2") with symmetric weight 2.0 and tract length 2.5.
connectivity = Connectivity()
connectivity.weights = np.array([[0., 2.], [2., 0.]])
connectivity.region_labels = np.array(["R1", "R2"])
connectivity.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
connectivity.tract_lengths = np.array([[0., 2.5], [2.5, 0.]])
connectivity.configure()

# --- Model: class looked up by name; parameters a, b, c, d, I are read from the
# "<param>_sim" entries of simulation_params, each wrapped in a 1-element array.
model_cls = getattr(tvb.simulator.models, simulation_params["model"])
model_kwargs = {
    param: np.asarray([simulation_params[f"{param}_sim"]])
    for param in ("a", "b", "c", "d", "I")
}
oscillator_model = model_cls(**model_kwargs)
oscillator_model.configure()

# --- Integrator: step size and noise strength from the stored settings.
# Assumes the named integrator is stochastic (it must expose `.noise`).
integrator_cls = getattr(tvb.simulator.integrators, simulation_params["integrator"])
integrator = integrator_cls(dt=simulation_params["dt"])
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.configure()

# --- Global coupling, instantiated with its default parameters
# (sim.coupling.a is later used as the ground-truth coupling value).
coupling_cls = getattr(tvb.simulator.coupling, simulation_params["coupling"])
coupling = coupling_cls()

# --- Monitor, instantiated with its default parameters.
monitor_cls = getattr(tvb.simulator.monitors, simulation_params["monitor"])
monitor = monitor_cls()

In [None]:
# Assemble the TVB simulator from the components built in the previous cell;
# the simulation length comes from the stored run settings.
simulator_kwargs = dict(
    model=oscillator_model,
    connectivity=connectivity,
    coupling=coupling,
    integrator=integrator,
    monitors=(monitor,),
    simulation_length=simulation_params["simulation_length"],
)
sim = Simulator(**simulator_kwargs)

sim.configure()

In [None]:
# Observed time series stored with the run. Later cells index it as
# X[:, state_var, node, mode] (assumed TVB monitor layout -- TODO confirm), so fail fast
# here if the stored array does not have exactly four axes.
X = instance_params["obs"]
assert np.ndim(X) == 4, f"expected 4-D observations, got shape {np.shape(X)}"

In [None]:
# Observed trace of state variable 0 (labelled V), mode 0, for both regions.
# The x-axis is sample index x monitor.period so that it really is in ms -- this assumes
# X was recorded with the monitor configured above (NOTE(review): confirm).
time_ms = np.arange(X.shape[0]) * monitor.period

f1, ax1 = plt.subplots(figsize=(14, 8))
for node in range(2):
    ax1.plot(time_ms, X[:, 0, node, 0], label=f"Region {node + 1}")
ax1.set_title("Observed time series", fontsize=16)
ax1.set_ylabel("V (a.u.)", fontsize=16)
ax1.set_xlabel("time (ms)", fontsize=16)
ax1.legend(fontsize=16)
ax1.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# Wrap the configured simulator in the project's PyMC inference helper
# (pymcModel, imported from tvb.contrib.inversion.pymcInference).
pymc_model = pymcModel(sim)

In [None]:
# Load the saved inference results for the selected run into pymc_model
# (presumably what populates inference_data / trace / summary used below).
# NOTE(review): this passes a bare "<run_id>.pkl", while the instance pickle was read from
# "pymc_data/inference_data/<run_id>.pkl" -- presumably pymcModel.load prepends its own
# directory; confirm.
pymc_model.load(f"{run_ids[idx]}.pkl")

In [None]:
# Display the prior statistics held by pymc_model.
pymc_model.prior_stats

In [None]:
# Optional: histogram of the global_noise posterior draws. The earlier version reshaped to a
# hard-coded (700,) (presumably n_chains * n_draws of one particular run); .ravel() works for any run:
#plt.hist(pymc_model.inference_data.posterior.global_noise.values.ravel(), bins=100);

In [None]:
# Values the observations were simulated with, overlaid on the posterior plots.
# noise_gfun is sqrt(2 * nsig); global_noise was 0 in the simulation.
posterior_plot_truth = {
    "coupling_a": sim.coupling.a[0],
    "model_a": simulation_params["a_sim"],
    # "model_b": simulation_params["b_sim"],
    "noise_gfun": np.sqrt(2 * simulation_params["nsig"]),
    "global_noise": 0.0,
}
pymc_model.plot_posterior_samples(init_params=posterior_plot_truth)

In [None]:
# Posterior diagnostics: shrinkage per parameter, and the z-score of each posterior
# against the ground-truth values used for the simulation.
# NOTE(review): s and z are plotted pairwise in the next cell. That assumes both come back
# in the same parameter order -- confirm, since this dict's key order (global_noise before
# noise_gfun) differs from the one passed to plot_posterior_samples.
s = pymc_model.posterior_shrinkage()
zscore_truth = {
    "coupling_a": sim.coupling.a[0],
    "model_a": simulation_params["a_sim"],
    "global_noise": 0.0,
    "noise_gfun": np.sqrt(2 * simulation_params["nsig"]),
}
z = pymc_model.posterior_zscore(init_params=zscore_truth)

In [None]:
# Posterior shrinkage vs. z-score, one star per parameter.
# The cell previously ended with a no-op `plt.plot();`; it now ends with plt.show().
f2, ax2 = plt.subplots(figsize=(12, 8))
ax2.plot(s, z, color="blue", linewidth=0, marker="*", markersize=12)
ax2.set_xlabel("posterior shrinkage")
ax2.set_ylabel("posterior zscore")
ax2.set_title("Posterior shrinkage vs. z-score")
# ax2.set_xlim([0.0, 1.1])
# ax2.set_ylim([0.0, 5])
plt.show()

In [None]:
# Posterior sample layout: the sizes of the "draw" and "chain" dimensions
# of the sample_stats group.
sample_stats = pymc_model.inference_data.sample_stats
n_draws = sample_stats.sizes["draw"]
n_chains = sample_stats.sizes["chain"]

In [None]:
# Flatten the leading (chain, draw) dims of the posterior-predictive x_obs into one sample
# axis, giving posterior_x_obs[sample, *X.shape]. Using -1 lets numpy infer the sample count,
# so this also works when the posterior-predictive group holds a different number of draws
# than sample_stats (the old n_chains*n_draws reshape would raise there).
posterior_x_obs = pymc_model.inference_data.posterior_predictive.x_obs.values.reshape((-1, *X.shape))

In [None]:
# Posterior-predictive check: one panel per region, showing the 2.5-97.5 percentile band
# of the simulated traces against the observed trace. Both panels show state variable 0
# (V), mode 0; the panel index is the node index. The second panel is Region 2's V, the
# same slice plotted as "Region 2" earlier -- not W.
def plot_ppc_band(ax, time, predicted, observed, region_label):
    """Draw the 95% posterior-predictive band and the observed trace on `ax`.

    ax           : matplotlib Axes to draw on
    time         : (n_times,) x-axis values
    predicted    : (n_samples, n_times) posterior-predictive traces
    observed     : (n_times,) observed trace
    region_label : panel title
    """
    lower, upper = np.percentile(predicted, [2.5, 97.5], axis=0)
    # Label only one bound so the band gets a single legend entry.
    ax.plot(time, lower, "k", label=r"$V_{95\% PP}(t)$")
    ax.plot(time, upper, "k")
    ax.plot(time, observed, label="V_observed")
    ax.set_title(region_label, fontsize=16)
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)


# x-axis in ms, assuming the observations were recorded with the monitor configured above.
time_ms = np.arange(X.shape[0]) * monitor.period

f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))
for node, ax in enumerate(axes3):
    plot_ppc_band(ax, time_ms, posterior_x_obs[:, :, 0, node, 0], X[:, 0, node, 0], f"Region {node + 1}")

plt.show()

In [None]:
# Model-comparison criteria ("WAIC", "LOO") as computed by pymc_model.model_criteria.
criteria = pymc_model.model_criteria(["WAIC", "LOO"])

In [None]:
# Display the criteria computed in the previous cell.
criteria

In [None]:
# Posterior summary rows for the parameters compared against ground truth above.
# Add "b", "c", "d", "I" to the list if those model parameters are inferred as well.
summary_rows = ["coupling_a", "model_a", "noise_gfun", "global_noise"]
pymc_model.summary.loc[summary_rows]

In [None]:
# NUTS sampler diagnostics from the loaded trace.
# For a PyMC3 MultiTrace, trace["<stat>"] concatenates the per-draw statistic over ALL chains,
# whereas len(trace) is the draw count of a single chain. Dividing by len(trace) therefore
# inflated the percentage by the number of chains, so divide by the number of draws actually
# inspected (divergent.size).
divergent = np.asarray(pymc_model.trace["diverging"])
n_divergent = int(np.count_nonzero(divergent))
print("Number of Divergent %d" % n_divergent)
# An empty trace has no defined divergence rate; report nan instead of dividing by zero.
divperc = n_divergent / divergent.size * 100 if divergent.size else float("nan")
print("Percentage of Divergent %.1f" % divperc)
# Acceptance probabilities lie in [0, 1], so one decimal place is too coarse to be useful.
print("Mean tree accept %.3f" % pymc_model.trace['mean_tree_accept'].mean())

In [None]:
# Sampling wall-clock time recorded on the sample_stats group (seconds), reported in hours.
sample_stats = pymc_model.inference_data.sample_stats
print("Sampling time in hours:", sample_stats.sampling_time / 3600)