In [None]:
import os
import numpy as np
import networkx as nx
from vbi import report_cfg
import matplotlib.pyplot as plt
from vbi.models.cpp.mpr import MPR_sde

In [None]:
# One seed for the whole notebook: it seeds NumPy's legacy global RNG here and
# is also handed to the simulator through parameters["seed"].
seed = 2
np.random.seed(seed=seed)

In [None]:
# Shared font size for axis labels and tick labels in every figure below.
LABESSIZE = 14
plt.rcParams.update(
    {key: LABESSIZE for key in ("axes.labelsize", "xtick.labelsize", "ytick.labelsize")}
)

In [None]:
# Network size: number of nodes in the connectivity graph.
nn = 6
# All-to-all connectivity: the complete graph on `nn` nodes converted to a
# dense (nn x nn) NumPy adjacency matrix.
graph = nx.complete_graph(nn)
weights = nx.to_numpy_array(graph)

In [None]:
# Configuration handed to MPR_sde. The plotting cell divides `t` by 1000 to
# label it in seconds, so simulation times appear to be in ms unless noted.
parameters = dict(
    G=0.55,                      # global coupling strength
    dt=0.01,                     # integration step of the MPR model [ms]
    dt_bold=0.001,               # integration step of the Balloon model [s]
    J=14.5,                      # MPR model parameter
    eta=-4.6,                    # MPR model parameter
    tau=1.0,                     # MPR model parameter
    delta=0.7,                   # MPR model parameter
    decimate=500,                # sampling (decimation) from the MPR time series
    noise_amp=0.037,             # amplitude of the noise
    iapp=0.0,                    # constant applied current
    t_cut=0.5 * 60 * 1000.0,     # transition time; NOTE(review): earlier note said "* 10 [ms]" - confirm units
    t_end=2 * 60 * 1000,         # end time; same unit caveat as t_cut
    weights=weights,             # weighted connection matrix (nn x nn)
    seed=seed,                   # seed for the random number generator
    noise_seed=True,             # fix the seed used for the noise
    record_step=10,              # keep every n-th step of the MPR time series
    output="output",             # output directory
    RECORD_AVG=0,                # true to store large time series in a file
)

In [None]:
# Build the simulator from the base configuration and run it. `par` is passed
# to run(); presumably its entries override `parameters` for this call
# (G = 0.5 here vs. 0.55 in the config) - confirm against vbi's MPR_sde.
obj = MPR_sde(parameters)
control_dict = dict(G=0.5)
sol = obj.run(par=control_dict)
# Echo the `eta` value held by the simulator object.
print(obj.eta)

In [None]:
# Pull the time vector and the `x` series out of the simulation result dict.
t, x = sol["t"], sol["x"]

In [None]:
# Report array shapes, one per line, as a quick sanity check of the run.
print(f"t.shape = {t.shape}", f"x.shape = {x.shape}", sep="\n")

In [None]:
# Plot each column of x.T against time and save the figure to disk.
# `t` is divided by 1000 for a seconds axis, i.e. it is assumed to be in ms.
if x.ndim != 2:
    # Fail loudly instead of calling `exit(0)`: in Jupyter that shuts down the
    # kernel (losing all state), and in a plain script it reports success even
    # though nothing was plotted or analysed.
    raise ValueError(f"expected a 2-D array `x`, got x.ndim = {x.ndim} (shape {x.shape})")

fig, ax = plt.subplots(1, figsize=(10, 3))
ax.plot(t / 1000, x.T, alpha=0.8, lw=2)
ax.set_xlabel("Time [s]")
ax.set_ylabel("BOLD")
ax.margins(0, 0.1)
fig.tight_layout()

# Save next to the simulator output (parameters["output"] == "output").
out_dir = parameters["output"]
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, "mpr_sde_ts.png"), dpi=300)
plt.close(fig)  # the figure is only written to disk; release its memory

## Feature extraction

Compute connectivity-domain features from the simulated time series `x`.

In [None]:
from vbi.feature_extraction.features_settings import *
from vbi.feature_extraction.calc_features import *

In [None]:
# Select the connectivity-domain feature set and print its configuration.
cfg = get_features_by_domain(domain="connectivity")
report_cfg(cfg)
# Sampling frequency handed to the extractor: 1 / dt_bold, then divided by 1000.
# NOTE(review): with dt_bold = 0.001 s this gives 1.0 - confirm this matches the
# actual sampling of `x` and the unit extract_features_df expects.
fs = 1.0 / parameters["dt_bold"] / 1000
# The extractor takes a list of time-series arrays; only one simulation here.
data = extract_features_df([x], fs, cfg=cfg, n_workers=1)
print(data.values.shape)