---<br>
jupyter:<br>
  jupytext:<br>
    cell_metadata_filter: -all<br>
    custom_cell_magics: kql<br>
    text_representation:<br>
      extension: .py<br>
      format_name: percent<br>
      format_version: '1.3'<br>
      jupytext_version: 1.11.2<br>
  kernelspec:<br>
    display_name: vbi_paper<br>
    language: python<br>
    name: python3<br>
---

%%

In [None]:
import torch
import pickle
import numpy as np
from time import time
from tqdm import tqdm
import networkx as nx
import sbi.utils as utils
import scipy.stats as stats
from helpers import plot_mat
import matplotlib.pyplot as plt
from multiprocessing import Pool
from sbi.analysis import pairplot
from vbi.inference import Inference
from vbi.models.numba.ww import WW_sde, ParWW, ParBaloon
from sklearn.preprocessing import StandardScaler
from vbi.feature_extraction.features_utils import get_fc, get_fcd2

%%

In [None]:
import vbi
from vbi import report_cfg
from vbi import extract_features
from vbi import get_features_by_domain, get_features_by_given_names

%%

In [None]:
# Seed the global RNGs up front so Restart & Run All reproduces the same draws
# (e.g. the per-node `w` vector sampled with np.random.uniform in the settings cell).
seed = 2
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(seed)  # NumPy first, then torch — same order as before

%%

In [None]:
# One shared font size for axis labels and both tick-label axes in every figure.
LABESSIZE = 12
plt.rcParams.update(
    {key: LABESSIZE for key in ("axes.labelsize", "xtick.labelsize", "ytick.labelsize")}
)

%%

In [None]:
# Connectivity matrix shipped with vbi (LoadSample(84) — presumably the 84-node
# sample; confirm). The node count `nn` is the length of its first axis.
weights = vbi.LoadSample(84).get_weights()
nn = np.shape(weights)[0]

%%

In [None]:
# Baseline settings handed to WW_sde. `G`, `sigma_noise` and `I_o` are only
# defaults: `simulate` replaces them on every run from its parameter vector.
par = dict(
    G=0.25,                              # replaced per run by `simulate`
    dt=5.0,
    t_cut=2 * 60_000.0,                  # == 2 * 60 * 1000.0 (2 min if times are in ms — assumed)
    t_end=5 * 60_000.0,                  # == 5 * 60 * 1000.0
    weights=weights,
    seed=seed,
    sigma_noise=0.001,                   # replaced per run by `simulate`
    I_o=0.3 * np.ones(nn),               # replaced per run by `simulate`
    w=np.random.uniform(0.9, 1.0, nn),   # one value per node from the seeded global RNG
    ts_decimate=10,
    fmri_decimate=10,
    method="heun",
    RECORD_TS=True,
    RECORD_FMRI=True,
)

%%

In [None]:
from vbi.feature_extraction.features_utils import get_fcd

%%

In [None]:
def simulate(params, par):
    """Run one ``WW_sde`` simulation for a given parameter vector.

    Parameters
    ----------
    params : dict
        Baseline simulation settings passed to ``WW_sde``. It is shallow-copied,
        not modified, so the caller's dict keeps its original values.
    par : sequence of float
        Parameter vector ``[G, sigma_noise, I_o]``. ``I_o`` is applied to every
        node (``nn`` entries, the notebook-level node count).

    Returns
    -------
    tuple
        ``(t, s, t_fmri, d_fmri)`` — the matching entries of the dict returned
        by ``WW_sde(...).run()``.
    """
    # Work on a copy: writing into ``params`` directly silently rewrote the
    # notebook-level settings dict after every call (hidden state).
    run_params = dict(params)
    run_params['G'] = par[0]
    run_params['sigma_noise'] = par[1]
    run_params['I_o'] = par[2] * np.ones(nn)  # same drive for every node

    data = WW_sde(run_params).run()
    return data['t'], data['s'], data['t_fmri'], data['d_fmri']

%%

In [None]:
def visual(t, s, t_fmri, d_fmri, k=30, **kwargs):
    """Plot simulated traces together with their FCD and FC matrices.

    Parameters
    ----------
    t, s : array-like
        Time axis and state traces; ``s`` is plotted as ``s.T`` against ``t``,
        so its second axis must have ``len(t)`` entries.
    t_fmri, d_fmri : array-like
        Time axis and fMRI signals, laid out like ``t`` and ``s``.
    k : int, default 30
        Passed to ``vbi.utils.set_diag`` for the FCD matrix (the FC matrix uses
        0). NOTE(review): exact effect depends on ``set_diag`` — confirm.
    **kwargs
        Forwarded to ``get_fcd2`` (e.g. ``wwidth``, ``maxNwindows``, ``olap``).
    """
    fc = get_fc(d_fmri)['full']
    fcd = get_fcd2(d_fmri, **kwargs)

    fc = vbi.utils.set_diag(fc, 0)
    fcd = vbi.utils.set_diag(fcd, k)

    # Left: two stacked traces (A over B); right: FCD (C) and FC (D) span both rows.
    mosaic = """
    AACD
    BBCD
    """
    fig = plt.figure(constrained_layout=True, figsize=(12, 3.5))
    ax = fig.subplot_mosaic(mosaic)

    ax['A'].plot(t, s.T, lw=0.1, alpha=1.0)
    ax['A'].set_ylabel("s")
    ax['B'].plot(t_fmri, d_fmri.T, lw=0.1, alpha=1.0)
    ax['B'].set_ylabel("fMRI")
    ax['B'].set_xlabel("t")

    # Build each colorbar from its own image so the FC colorbar reflects the
    # FC value range rather than the FCD one.
    im_fcd = ax['C'].imshow(fcd, cmap="viridis")
    ax['C'].set_title("FCD")
    fig.colorbar(im_fcd, ax=ax['C'])

    im_fc = ax['D'].imshow(fc, cmap="viridis")
    ax['D'].set_title("FC")
    fig.colorbar(im_fc, ax=ax['D'])

%%

In [None]:
# Smoke run: time a single simulation and report the shape of every output.
tic = time()
t, s, t_fmri, d_fmri = simulate(par, [0.2, 0.01, 0.3])  # [G, sigma_noise, I_o]
print(f"Elapsed time: {time() - tic:.3f} seconds.")
print(*(out.shape for out in (t, s, t_fmri, d_fmri)))

%%

In [None]:
# Representative run visualised with FCD windowing options forwarded to get_fcd2.
# Alternative setting also tried: [0.2, 0.003, 0.295].
theta_demo = [0.25, 0.008, 0.295]  # [G, sigma_noise, I_o]
fcd_opts = dict(wwidth=200, maxNwindows=250, olap=0.94)

tic = time()
t, s, t_fmri, d_fmri = simulate(par, theta_demo)
print(f"Elapsed time: {time() - tic:.3f} seconds.")
visual(t, s, t_fmri, d_fmri, k=30, **fcd_opts)

In [None]:
# Zoom in on the first node's trace (row 0 of `s`) at full line width.
fig, ax = plt.subplots(figsize=(15, 3))
ax.plot(t, s[0, :], lw=1)
ax.set(xlabel="t", ylabel="s", title="Node 0");

%%