In [None]:
from tvb.simulator.simulator import Simulator
from tvb.datatypes.connectivity import Connectivity
from tvb.contrib.inversion.pymcInference import NonCenteredModel

import tvb.simulator.models
import tvb.simulator.integrators
import tvb.simulator.coupling
import tvb.simulator.monitors

import matplotlib.pyplot as plt
import numpy as np
import math
import arviz as az
import pymc3 as pm
import scipy
import theano.tensor as tt
import theano
import pickle

%load_ext autoreload
%autoreload 2

In [None]:
# Load the parameters and the recorded time series of the reference
# limit-cycle simulation run.
# NOTE: pickle.load can execute arbitrary code — only open files you trust.
SIMULATION_FILE = '../limit-cycle_simulation.pkl'
with open(SIMULATION_FILE, 'rb') as pkl_file:
    simulation_params = pickle.load(pkl_file)

In [None]:
# Connectivity
# Only the hand-built two-region network ("Own") can be reconstructed here.
# Fail fast on any other value: otherwise `connectivity` would be undefined and
# the Simulator cell would fail later with an opaque NameError.
if simulation_params["connectivity"] == "Own":
    connectivity = Connectivity()
    # Symmetric coupling between the two regions, no self-connections.
    connectivity.weights = np.array([[0., 2/3], [2/3, 0.]])
    connectivity.region_labels = np.array(["R1", "R2"])
    connectivity.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
    connectivity.tract_lengths = np.array([[0., 0.1], [0.1, 0.]])
    connectivity.configure()
else:
    raise ValueError(
        f"Unsupported connectivity {simulation_params['connectivity']!r}; "
        "this notebook only reconstructs the 'Own' two-region connectivity."
    )

# Model
# The model class is looked up by name; each parameter is wrapped in a
# one-element array, matching how the original simulation was parameterised.
oscillator_model = getattr(tvb.simulator.models, simulation_params["model"])(
    a=np.asarray([simulation_params["a_sim"]]),
    b=np.asarray([simulation_params["b_sim"]]),
    c=np.asarray([simulation_params["c_sim"]]),
    d=np.asarray([simulation_params["d_sim"]]),
    I=np.asarray([simulation_params["I_sim"]]),
)
oscillator_model.configure()

# Integrator
integrator = getattr(tvb.simulator.integrators, simulation_params["integrator"])(dt=simulation_params["dt"])
# Only stochastic TVB integrators carry a `noise` attribute; setting it on a
# deterministic integrator would raise AttributeError.
if hasattr(integrator, "noise"):
    integrator.noise.nsig = np.array([simulation_params["nsig"]])
integrator.configure()

# Global coupling (class looked up by name, library defaults)
coupling = getattr(tvb.simulator.coupling, simulation_params["coupling"])()

# Monitor (class looked up by name, library defaults)
monitor = getattr(tvb.simulator.monitors, simulation_params["monitor"])()

In [None]:
# Simulator: assemble the components built in the previous cell and configure
# the simulator so it is ready to run.
simulation_length = simulation_params["simulation_length"]
sim = Simulator(
    connectivity=connectivity,
    model=oscillator_model,
    integrator=integrator,
    coupling=coupling,
    monitors=(monitor,),
    simulation_length=simulation_length,
)
sim.configure()

In [None]:
# Observed data: the time series stored by the reference simulation run.
X = simulation_params["simulation"]
# The plotting and inference cells index X with four axes; fail early with a
# clear message rather than with an IndexError further down.
assert X.ndim == 4, f"expected a 4-D simulation array, got shape {X.shape}"

In [None]:
# Plot the first state variable of each region (first mode) over the sample
# index. Assumes X is laid out as (time, state variable, region, mode), i.e. the
# TVB monitor output convention — TODO confirm for the chosen monitor.
f1, ax = plt.subplots(figsize=(14, 8))
for region_idx, region_label in enumerate(("R1", "R2")):
    ax.plot(X[:, 0, region_idx, 0], label=region_label)
ax.set(xlabel="sample index", ylabel="states", title="Simulated time series")
ax.legend()
plt.show()

In [None]:
# Global inference parameters used by the sampling cells that follow.
shape = X.shape           # dimensions of the observed data array
draws, tune = 1000, 1000  # presumably passed to pm.sample as draws / tuning steps — confirm