In [None]:
import os
import pickle
from multiprocessing import Pool
from os.path import join

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sbi.utils as utils
import torch
import tqdm
from sbi.analysis import pairplot
from sklearn.preprocessing import StandardScaler

from vbi import report_cfg
from vbi.inference import Inference
from vbi.models.cpp.vep import VEP
from vbi.utils import LoadSample

In [None]:
# Reproducibility: seed PyTorch (used for prior sampling) and NumPy's global RNG.
seed = 2
torch.manual_seed(seed)
np.random.seed(seed)

# Output directory for the saved training data, posterior and samples.
path = "output/vep84"
os.makedirs(path, exist_ok=True)

In [None]:
# Connectivity matrix passed to VEP as `weights`; its size defines the number of regions.
weights = np.loadtxt("data/weights1.txt")
nn = weights.shape[0]  # number of brain regions
# Fail early with a clear message instead of deep inside the simulator.
assert weights.ndim == 2 and weights.shape == (nn, nn), (
    f"expected a square connectivity matrix, got shape {weights.shape}"
)

Ground-truth $\eta$ values for regions in the healthy zone (HZ), the propagation zone (PZ) and the epileptogenic zone (EZ).

In [None]:
# Ground-truth eta per zone: healthy (HZ), propagation (PZ), epileptogenic (EZ).
hz_val, pz_val, ez_val = -3.65, -2.4, -1.6

In [None]:
# Region indices of the epileptogenic zone (EZ) and of the two propagation-zone groups.
ez_idx = np.array([6, 34], dtype=np.int32)
pz_wplng_idx = np.array([5, 11], dtype=np.int32)
pz_kplng_idx = np.array([27], dtype=np.int32)
pz_idx = np.append(pz_kplng_idx, pz_wplng_idx)

# Ground-truth eta: every region starts at the healthy value, then PZ and EZ regions
# are overwritten (EZ last, so it wins if an index were listed in both).
eta_true = np.full(nn, hz_val)
eta_true[pz_idx] = pz_val
eta_true[ez_idx] = ez_val

In [None]:
# Initial state vector of length 2 * nn: the first nn entries are -2.5, the last nn are 3.5
# (presumably one value per region for each of two state variables — confirm against VEP).
initial_state = np.concatenate((np.full(nn, -2.5), np.full(nn, 3.5)))

In [None]:
# Simulation settings for the VEP model. `eta` and `G` are defaults here; every run
# below also passes a control dict with "eta" and "G" (presumably overriding these).
params = dict(
    G=1.0,
    seed=seed,
    initial_state=initial_state,
    weights=weights,
    tau=10.0,
    eta=-3.5,
    noise_sigma=0.0,  # zero noise amplitude
    iext=3.1,
    dt=0.1,  # integration step; presumably ms (wrapper converts 1/dt*1000 to Hz)
    tend=14.0,
    tcut=1.0,
    noise_seed=0,
    record_step=1,
    method="heun",
    output="output",
)

In [None]:
# Reference simulation at the ground-truth parameters: eta_true (from the zone
# assignment cell) and the global coupling g_true.
vep_model = VEP(params)
g_true = 1.0
control_true = {"eta": eta_true, "G": g_true}

sim_true = vep_model.run(par=control_true)
ts = sim_true["x"]  # indexed as ts[region, time] in the plot below
t = sim_true["t"]

# Set to True to plot the simulated activity (EZ in red, PZ in orange, others in green).
PLOT_SOURCE_ACTIVITY = False
if PLOT_SOURCE_ACTIVITY:
    fig, ax = plt.subplots(figsize=(10, 16))
    for i in range(nn):
        if i in ez_idx:
            ax.plot(t, ts[i, :] + i, "r", lw=3)
        elif i in pz_idx:
            ax.plot(t, ts[i, :] + i, "orange", lw=3)
        else:
            ax.plot(t, ts[i, :] + i, "g")
    # Each trace is shifted up by its region index; the -2 tick offset presumably
    # aligns each label with its trace's baseline.
    ax.set_yticks(np.arange(nn) - 2)
    ax.set_yticklabels(np.arange(nn), fontsize=10)
    ax.tick_params(axis="x", labelsize=16)
    ax.set_title("Source brain activity", fontsize=18)
    ax.set_xlabel("Time", fontsize=22)
    ax.set_ylabel("Brain Regions#", fontsize=22)
    fig.tight_layout()
    fig.savefig("output/vep_sde.png", dpi=300)
    plt.show()

In [None]:
from vbi.feature_extraction.features_settings import *
from vbi.feature_extraction.calc_features import *

In [None]:
# Sampling frequency in Hz. Must match the conversion in `wrapper` (dt presumably in ms);
# the previous 1/dt/1000 gave 0.01 instead of 10000 and mismatched the training features.
fs = 1.0 / params["dt"] * 1000
cfg = get_features_by_domain(domain="statistical")
# cfg = get_features_by_given_names(cfg, names=["calc_moments"])
cfg = get_features_by_given_names(cfg, names=["auc"])
# report_cfg(cfg)

# Features of the reference simulation, as a quick check of the feature configuration.
features_true = extract_features_df([ts], fs, cfg=cfg, n_workers=1)
features_true.values.shape

In [None]:
def wrapper(params, control, x0, cfg, verbose=False):
    """Simulate the VEP model once and return the feature vector of the simulation.

    Parameters
    ----------
    params : dict
        Simulation settings used to construct ``VEP``.
    control : dict
        Passed to ``VEP.run`` (this notebook uses ``{"eta": ..., "G": ...}``).
    x0 : array-like
        Initial state, passed to ``VEP.run`` as ``x0``.
    cfg : dict
        Feature-extraction configuration.
    verbose : bool, optional
        Forwarded to ``extract_features``.

    Returns
    -------
    The first row of the extracted feature values (one row per input time series;
    only one series is passed here).
    """
    simulation = VEP(params).run(control, x0=x0)

    sampling_hz = 1.0 / params["dt"] * 1000  # dt presumably in ms -> Hz
    features = extract_features(
        ts=[simulation["x"]],
        cfg=cfg,
        fs=sampling_hz,
        n_workers=1,
        verbose=verbose,
    )
    return features.values[0]

In [None]:
def batch_run(params, control_list, x0, cfg, n_workers=1):
    """Run ``wrapper`` for every control dict in a process pool.

    Parameters
    ----------
    params : dict
        Simulation settings shared by all runs.
    control_list : sequence of dict
        One control dict per simulation.
    x0 : array-like
        Initial state shared by all runs.
    cfg : dict
        Feature-extraction configuration.
    n_workers : int, optional
        Number of worker processes.

    Returns
    -------
    list
        Feature vectors, in the same order as ``control_list``.

    Notes
    -----
    NOTE(review): ``wrapper`` must be picklable by the workers; with the "spawn"
    start method (default on Windows/macOS) notebook-defined functions usually
    are not, so this presumably relies on "fork" — confirm on the target platform.
    Worker exceptions are re-raised by ``get()``.
    """
    n = len(control_list)
    with Pool(processes=n_workers) as pool, tqdm.tqdm(total=n) as pbar:
        async_results = [
            pool.apply_async(
                wrapper,
                args=(params, control, x0, cfg, False),
                callback=lambda _: pbar.update(),
            )
            for control in control_list
        ]
        # Collect inside the `with`: leaving the Pool context terminates the workers.
        return [res.get() for res in async_results]


# Smoke test: a single run with the ground-truth control before launching the batch.
# (Previously `cfg` was passed as `x0` and the `cfg` argument was missing -> TypeError.)
x_ = wrapper(params, control_true, initial_state, cfg)

In [None]:
# Simulation budget and parallelism for training-data generation.
num_sim = 5000
num_workers = 10

# Uniform prior bounds: each regional eta, and the global coupling G.
eta_min = -5.0
eta_max = -1.0
gmin = 0.0
gmax = 2.0

In [None]:
# Parameter vector layout: theta[0] is the global coupling G, theta[1:] are the nn regional etas.
prior_min = [gmin] + nn * [eta_min]
prior_max = [gmax] + nn * [eta_max]
prior = utils.BoxUniform(
    low=torch.tensor(prior_min),
    high=torch.tensor(prior_max),
)

In [None]:
# vbi inference helper, used below for prior sampling, training and posterior sampling.
obj = Inference()
theta = obj.sample_prior(prior, num_sim)
# float64 NumPy copy of the draws, used to build the simulator inputs.
theta_np = np.asarray(theta.numpy(), dtype=float)
print(theta_np.shape)

In [None]:
# One control dict per prior draw, following the theta layout (column 0 = G, rest = eta).
# Built once; the previous loop was immediately overwritten by an identical comprehension.
control_list = [{"eta": row[1:], "G": row[0]} for row in theta_np[:num_sim]]

In [None]:
# Expensive step: num_sim simulations plus feature extraction across num_workers processes.
stat_vec = batch_run(params, control_list, initial_state, cfg, n_workers=num_workers)

In [None]:
# Standardize each feature (zero mean, unit variance) before training the density estimator.
scalar = StandardScaler()
stat_vec_st = scalar.fit_transform(np.array(stat_vec))
stat_vec_st = torch.tensor(stat_vec_st, dtype=torch.float32)

# Persist the raw training data AND the fitted scaler: observations must be transformed
# with this exact scaler before querying the posterior, including after a kernel restart.
torch.save(theta, join(path, "theta.pt"))
torch.save(stat_vec, join(path, "stat_vec.pt"))
with open(join(path, "scaler.pkl"), "wb") as f:
    pickle.dump(scalar, f)

In [None]:
# Sanity check: one parameter row per feature row.
print(theta.shape, stat_vec_st.shape)

# Neural posterior estimation (SNPE) with a "maf" density estimator.
posterior = obj.train(
    theta,
    stat_vec_st,
    prior,
    method="SNPE",
    density_estimator="maf",
    num_threads=8,
)

In [None]:
# Persist the trained posterior so it can be reloaded without retraining.
posterior_file = join(path, "posterior.pkl")
with open(posterior_file, "wb") as fh:
    pickle.dump(posterior, fh)

In [None]:
# Reload the trained posterior.
# NOTE: pickle.load can execute arbitrary code — only load files written by this notebook.
posterior_file = join(path, "posterior.pkl")
with open(posterior_file, "rb") as fh:
    posterior = pickle.load(fh)

In [None]:
# "Observed" features: the ground-truth simulation, standardized with the training scaler.
xo = wrapper(params, control_true, x0=initial_state, cfg=cfg)
# StandardScaler expects a 2-D (n_samples, n_features) input, hence the single-row reshape.
xo_st = scalar.transform(xo.reshape(1, -1))

In [None]:
# Draw posterior samples conditioned on the observed (standardized) features and save them.
num_posterior_samples = 10000
samples = obj.sample_posterior(xo_st, num_posterior_samples, posterior)
torch.save(samples, join(path, "samples.pt"))

In [None]:
# Pairplot of the first `n_plot` parameters: G followed by eta_0 .. eta_{n_plot-2}.
# Uses new names instead of truncating prior_min / prior_max / eta_true in place, so the
# cell is idempotent and the full ground truth stays available.
n_plot = 10
limits = [[lo, hi] for lo, hi in zip(prior_min[:n_plot], prior_max[:n_plot])]

# Ground truth in the same layout as theta. Concatenate explicitly: `[g_true] + eta_true`
# broadcast-adds g_true to every eta entry instead of prepending it.
theta_true = np.concatenate([[g_true], eta_true])
points = [theta_true[:n_plot]]

fig, ax = pairplot(
    samples[:, :n_plot],
    limits=limits,
    figsize=(5, 5),
    points=points,
    labels=["G"] + [f"eta_{i}" for i in range(n_plot - 1)],
    offdiag="kde",
    diag="kde",
    points_colors="r",
    samples_colors="k",
    points_offdiag={"markersize": 10},
)
ax[0, 0].tick_params(labelsize=14)
ax[0, 0].margins(y=0)
plt.tight_layout()
fig.savefig(join(path, "triangleplot.jpeg"), dpi=300)