In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.pymcInference import pymcModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import pymc3 as pm
import scipy
import theano.tensor as tt
import theano
import math
from tqdm import tqdm
import pickle

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Load the simulation parameters and simulated trajectory of the limit-cycle run.
# NOTE: pickle.load can execute arbitrary code — only open trusted files here.
simulation_pkl_path = "../limit-cycle_simulation.pkl"
with open(simulation_pkl_path, "rb") as pkl_file:
    simulation_params = pickle.load(pkl_file)

In [None]:
# Inspect the shape of the stored initial conditions ("x0").
stored_x0 = simulation_params["x0"]
stored_x0.shape

In [None]:
# Two-region network with symmetric weights and tract lengths.
connectivity = Connectivity()
connectivity.weights = np.array([[0., 2.], [2., 0.]])
connectivity.region_labels = np.array(["R1", "R2"])
connectivity.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
connectivity.tract_lengths = np.array([[0., 2.5], [2.5, 0.]])
connectivity.configure()

# Neural mass model named in the pickle; each parameter is wrapped in a 1-element array.
model_class = getattr(tvb.simulator.models, simulation_params["model"])
oscillator_model = model_class(**{
    param_name: np.asarray([simulation_params[pickle_key]])
    for param_name, pickle_key in (
        ("a", "a_sim"),
        ("b", "b_sim"),
        ("c", "c_sim"),
        ("d", "d_sim"),
        ("I", "I_sim"),
    )
})
oscillator_model.configure()

# Integrator (class and step size from the pickle) with the simulated noise strength.
integrator_class = getattr(tvb.simulator.integrators, simulation_params["integrator"])
integrator = integrator_class(dt=simulation_params["dt"])
integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.configure()

# Global coupling function and monitor, both looked up by name and default-constructed.
coupling = getattr(tvb.simulator.coupling, simulation_params["coupling"])()
monitor = getattr(tvb.simulator.monitors, simulation_params["monitor"])()

In [None]:
# Assemble the TVB simulator from the components above. configure() is called
# here, before `sim.history` is used in later cells.
simulator_components = dict(
    model=oscillator_model,
    connectivity=connectivity,
    coupling=coupling,
    integrator=integrator,
    monitors=(monitor,),
)
sim = Simulator(**simulator_components, simulation_length=simulation_params["simulation_length"])
sim.configure()

In [None]:
# Observed trajectory stored in the pickle. Axis 2 is the region axis
# (labelled R1/R2 in the plot below); the other axes are presumably
# (time, state variable, ..., mode) — TODO confirm.
observed_key = "simulation"
X = simulation_params[observed_key]

In [None]:
# Observed state-variable index 0 for each of the two regions.
f1, ax1 = plt.subplots(figsize=(14, 8))
for region_idx, region_label in ((0, "R1"), (1, "R2")):
    ax1.plot(X[:, 0, region_idx, 0], label=region_label)
ax1.set_ylabel("states", fontsize=18)
ax1.set_xlabel("time (ms)", fontsize=18)
ax1.legend(fontsize=18)
ax1.tick_params(axis="both", labelsize=18)
plt.show()

In [None]:
# Integration step as a Theano shared variable so `scheme` can use it symbolically.
dt = theano.shared(value=simulation_params["dt"], name="dt")

In [None]:
# Fixed parameter values for the Theano `scheme` simulation below
# (model a-d and I are the simulated values from the pickle).
# NOTE(review): the PyMC cell further down rebinds `priors` with "model.a"-style
# keys; re-running the scan cell after it makes `scheme` pick up that dict instead.
priors = dict(
    model_a=np.array([simulation_params["a_sim"]]),
    model_b=np.array([simulation_params["b_sim"]]),
    model_c=np.array([simulation_params["c_sim"]]),
    model_d=np.array([simulation_params["d_sim"]]),
    model_I=np.array([simulation_params["I_sim"]]),
    model_tau=np.array([1.0]),
    model_e=np.array([3.0]),
    model_f=np.array([1.0]),
    model_g=np.array([0.0]),
    model_alpha=np.array([1.0]),
    model_beta=np.array([1.0]),
    model_gamma=np.array([1.0]),
    coupling_a=np.array([0.1]),
    local_coupling=0.0,
)

In [None]:
from theano.tensor.random.utils import RandomStream

# Seeded symbolic RNG used for the noise term in `scheme`.
THEANO_RNG_SEED = 42
random_stream = RandomStream(seed=THEANO_RNG_SEED)

In [None]:
def scheme(*args):
    """One stochastic Heun step of the delayed two-level network, built with Theano ops.

    Used as ``fn`` for ``theano.scan``: scan passes one argument per entry of
    ``taps`` (oldest lag first), so ``args[-1]`` is the most recent state.

    Relies on notebook globals: ``sim`` (configured TVB Simulator), ``priors``
    (parameter dict passed to ``dfun_tensor`` / ``post_tensor``), ``dt`` (Theano
    shared step size) and ``random_stream`` (seeded Theano RNG).

    Parameters
    ----------
    *args : TensorVariable
        Lagged states, presumably each shaped (n_state_vars, n_regions, 1)
        like ``x_init[-1]`` — TODO confirm for every tap.

    Returns
    -------
    TensorVariable
        The next state, built from ``args[-1]`` plus drift and noise.
    """
    Nsv = len(sim.model.state_variables)
    Nr = sim.connectivity.number_of_regions

    x_prev = args[-1]

    # Current coupling-variable values, rearranged into an (Nr, Nr) matrix.
    x_i = x_prev[sim.model.cvar, :, :]
    x_i = tt.transpose(tt.reshape(tt.tile(x_i, (1, Nr)), (Nr, Nr)))

    # Delayed coupling-variable values picked from the stacked history.
    # NOTE(review): the flat index depends only on the delay, not on the source
    # region j; verify this really selects history[t - idelays[i, j], j] for Nr > 1.
    x_j = tt.stack(args, axis=0)
    x_j = x_j[:, sim.model.cvar, :, :]
    x_j = tt.flatten(x_j)[-1 * sim.connectivity.idelays - 1]

    pre = sim.coupling.pre(x_i, x_j)
    gx = tt.sum(sim.connectivity.weights * pre, axis=-1)
    nc = sim.coupling.post_tensor(gx, priors)

    # Noise increment gfun(x) * sqrt(dt) * N(0, 1), shaped like the state.
    # The sqrt(dt) factor makes this a Wiener increment, as TVB's white-noise
    # generator does; without it the noise amplitude does not match the
    # TVB-simulated observations.
    noise = random_stream.normal(size=(Nsv, Nr, 1))
    noise = noise * tt.sqrt(dt)
    noise = noise * sim.integrator.noise.gfun(x_prev)

    # Heun predictor-corrector: Euler predictor, then average of both drifts.
    # The same noise sample is reused in predictor and corrector.
    m_dx_tn = sim.model.dfun_tensor(x_prev, priors, nc)
    inter = x_prev + dt * m_dx_tn + noise
    x_next = x_prev + (m_dx_tn + sim.model.dfun_tensor(inter, priors, nc)) * dt / 2.0 + noise

    return x_next

In [None]:
# Problem dimensions read from the configured simulator.
Nt = int(sim.simulation_length)
Nsv, Nr = len(sim.model.state_variables), sim.connectivity.number_of_regions
Ncv, Nc = sim.history.n_cvar, 1   # Nc: presumably the number of modes — TODO confirm
idmax = sim.connectivity.idelays.max()   # largest entry of the integer delay matrix
cvars = sim.history.cvars

In [None]:
# Random initial state: each state variable is drawn around the centre of its
# `state_variable_range`, with half the range as standard deviation.
# A seeded generator (same seed as the Theano RNG) keeps Restart & Run All
# reproducible; the previous unseeded np.random draw changed on every run.
init_rng = np.random.default_rng(42)
x0_init = np.zeros((Nsv, Nr, 1))
for i, value in enumerate(sim.model.state_variable_range.values()):
    loc = (value[0] + value[1]) / 2
    scale = (value[1] - value[0]) / 2
    x0_init[i, :, :] = init_rng.normal(loc=loc, scale=scale, size=(Nr, 1))

# History buffer for theano.scan: idmax + 1 past states, all zero except the
# most recent one, which holds x0_init.
x_history = np.zeros((idmax + 1, Nsv, Nr, 1))
x_init = theano.shared(x_history, name="x_init")
x_init = tt.set_subtensor(x_init[-1], x0_init)

In [None]:
# scan taps: every lag from -(longest nnz delay + 1) up to -1, oldest first.
# (nnz_idelays presumably lists delays of non-zero-weight connections.)
max_nnz_delay = np.max(sim.history.nnz_idelays)
taps = list(np.arange(-(max_nnz_delay + 1), 0))

In [None]:
# Unroll `scheme` symbolically for as many steps as there are observed samples.
scan_output_spec = {"initial": x_init, "taps": taps}
x_sim, updates = theano.scan(
    scheme,
    outputs_info=[scan_output_spec],
    n_steps=X.shape[0],
)

In [None]:
# Evaluate the symbolic scan once into a NumPy array.
# `updates` returned by theano.scan is not passed to eval() here.
simulated_states = x_sim.eval()
x_sim_np = simulated_states

In [None]:
# Sanity check of the Theano scan simulation (before any inference).
# Same indexing as the observed-data figure: [:, state var 0, region r, mode 0].
fig_scan, ax_scan = plt.subplots(figsize=(18, 10))
for region_idx, region_label in enumerate(connectivity.region_labels):
    ax_scan.plot(x_sim_np[:, 0, region_idx, 0], label=region_label)
ax_scan.set_xlabel("time step")
ax_scan.set_ylabel("states")
ax_scan.set_title("Theano scan simulation")
ax_scan.legend()
plt.show()

### Inference using a non-centered parameterization

In [None]:
# Global inference settings.
shape = X.shape          # observed-data shape, reused for the noise RV and posterior reshape
draws, tune = 500, 500   # posterior draws / tuning steps (per chain in PyMC3)
num_cores = 2

In [None]:
# Wrap the configured simulator in tvb.contrib's pymcModel; its `stat_model`
# is the PyMC model context the priors are defined in below.
pymc_model = pymcModel(sim)

In [None]:
with pymc_model.stat_model:
    # Non-centered parameterization: each free parameter is a standard-normal
    # "*_star" variable mapped through an affine transform.
    a_star = pm.Normal(name="a_star", mu=0.0, sigma=1.0)
    a = pm.Deterministic(name="a", var=2.0 + a_star)

    a_coupling_star = pm.Normal(name="a_coupling_star", mu=0.0, sigma=1.0)
    a_coupling = pm.Deterministic(name="coupling", var=0.1 + 0.05 * a_coupling_star)

    # Normal truncated below at zero, used for non-negative scale parameters.
    BoundedNormal = pm.Bound(pm.Normal, lower=0.0)

    # sd should be in the range of sqrt(2*nsig)
    noise_gfun_star = BoundedNormal(name="noise_gfun_star", mu=0.0, sigma=1.0)
    noise_gfun = pm.Deterministic(name="noise_gfun", var=0.05 + 0.1 * noise_gfun_star)

    # One standard-normal draw per observed sample, scaled by noise_gfun.
    noise_star = pm.Normal(name="noise_star", mu=0.0, sigma=1.0, shape=tuple(shape))
    noise = pm.Deterministic(name="noise", var=noise_gfun * noise_star)

    epsilon = BoundedNormal(name="epsilon", mu=0.0, sigma=1.0)

    # Parameters held fixed at their simulated / default values.
    fixed_params = {
        "model.b": np.array([simulation_params["b_sim"]]),
        "model.c": np.array([simulation_params["c_sim"]]),
        "model.d": np.array([simulation_params["d_sim"]]),
        "model.I": np.array([simulation_params["I_sim"]]),
        "model.tau": np.array([1.0]),
        "model.e": np.array([3.0]),
        "model.f": np.array([1.0]),
        "model.g": np.array([0.0]),
        "model.alpha": np.array([1.0]),
        "model.beta": np.array([1.0]),
        "model.gamma": np.array([1.0]),
    }
    # Random variables and fixed values handed to pymcModel.set_model.
    priors = {
        "model.a": a,
        **fixed_params,
        "coupling.a": a_coupling,
        "integrator.noise": noise,
        "global.noise": epsilon,
        "local_coupling": 0.0,
    }

In [None]:
# Prior mean / sd per parameter, mirroring the affine maps in the model cell.
# NOTE(review): "global.epsilon" differs from the "global.noise" key used in
# `priors`; confirm which key pymcModel expects.
prior_stats = {
    "model.a": {"mean": 2.0, "sd": 1.0},
    "coupling.a": {"mean": 0.1, "sd": 0.05},
    "noise_gfun": {"mean": 0.05, "sd": 0.1},
    "global.epsilon": {"mean": 0.0, "sd": 1.0},
}
pymc_model.prior_stats = prior_stats

In [None]:
# Attach priors and observations; time_step is the integration dt from the pickle.
pymc_model.set_model(priors=priors, obs=X, time_step=simulation_params["dt"])

In [None]:
# Sampler settings; target_accept / max_treedepth are presumably forwarded to
# PyMC's NUTS sampler (raised above PyMC3's defaults of 0.8 / 10).
sampler_settings = dict(
    draws=draws,
    tune=tune,
    cores=num_cores,
    target_accept=0.9,
    max_treedepth=20,
    save=True,
)
inference_data = pymc_model.run_inference(**sampler_settings)

In [None]:
# Display the inference results stored on the model (presumably an ArviZ
# InferenceData; its posterior_predictive group is read below).
pymc_model.inference_data

In [None]:
# Reference values passed as init_params, presumably the ground truth shown
# alongside the posteriors: simulated `a`, zero observation noise, and
# sqrt(2 * nsig) for the noise gfun.
reference_values = {
    "a": simulation_params["a_sim"],
    "epsilon": 0.0,
    "noise_gfun": np.sqrt(2 * simulation_params["nsig"]),
}
pymc_model.plot_posterior_samples(init_params=reference_values)

In [None]:
# Flatten the leading (chain, draw) axes into one sample axis, giving
# (n_samples, *X.shape). Using -1 avoids assuming that the number of chains
# equals `num_cores` (PyMC's chain count is set independently of cores).
posterior_x_obs = pymc_model.inference_data.posterior_predictive.x_obs.values.reshape(
    (-1, *shape))

In [None]:
# Posterior-predictive check: 95% band of the predicted trajectory vs. the
# observation, one panel per region. Indexing matches the first figure, where
# [..., 0, r, 0] is state-variable index 0 of region r (labelled R1/R2) — the
# panels previously carried state-variable labels (V/W) by mistake.
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))
for region_idx, (ax, region_label) in enumerate(zip(axes3, connectivity.region_labels)):
    band_low, band_high = np.percentile(
        posterior_x_obs[:, :, 0, region_idx, 0], [2.5, 97.5], axis=0)
    # Label only one band edge so the legend lists the band once.
    ax.plot(band_low, "k", label=rf"${region_label}_{{95\% PP}}(t)$")
    ax.plot(band_high, "k")
    ax.plot(X[:, 0, region_idx, 0], label=f"{region_label}_observed")
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Information criteria for model comparison. `ncModel` was never defined in
# this notebook (NameError); the fitted model object is `pymc_model`.
criteria = pymc_model.model_criteria(["WAIC", "LOO"])

In [None]:
# Report each information criterion computed above.
for criterion_label, criterion_key in (("WAIC: ", "WAIC"), ("LOO: ", "LOO")):
    print(criterion_label, criteria[criterion_key])

In [None]:
# Parameters whose R-hat exceeds 1.2; an empty table means none were flagged.
summary_df = pymc_model.summary
summary_df[summary_df["r_hat"] > 1.2]

In [None]:
# Summary rows for the inferred parameters of interest.
parameters_of_interest = ["a", "noise_gfun"]
pymc_model.summary.loc[parameters_of_interest]

In [None]:
# Sampler diagnostics: count and share of divergent transitions, mean acceptance.
divergent = pymc_model.trace["diverging"]
n_divergent = int(divergent.sum())
print("Number of Divergent %d" % n_divergent)
# For a PyMC3 MultiTrace, trace["diverging"] concatenates all chains while
# len(trace) is the draws of a single chain; dividing by len(trace) inflated the
# percentage by the number of chains. Normalize by the flags actually returned.
divperc = n_divergent / divergent.size * 100 if divergent.size else 0.0
print("Percentage of Divergent %.1f" % divperc)
print("Mean tree accept %.1f" % pymc_model.trace['mean_tree_accept'].mean())

In [None]:
# Persist the fitted model with a copy of the simulation parameters.
# `ncModel` was never defined in this notebook (NameError); use `pymc_model`.
pymc_model.save(simulation_params=simulation_params.copy())