---<br>
jupyter:<br>
  jupytext:<br>
    cell_metadata_filter: -all<br>
    custom_cell_magics: kql<br>
    text_representation:<br>
      extension: .py<br>
      format_name: percent<br>
      format_version: '1.3'<br>
      jupytext_version: 1.11.2<br>
  kernelspec:<br>
    display_name: vbi_paper<br>
    language: python<br>
    name: python3<br>
---

%%

In [None]:
import os
import pickle
from multiprocessing import Pool
from time import time

import numpy as np
import torch
import scipy.stats as stats
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
import sbi.utils as utils
from sbi.analysis import pairplot

from helpers import plot_mat
from vbi.utils import LoadSample
from vbi.utils import timer
from vbi.inference import Inference
from vbi.models.cpp.ww import WW_sde
from vbi.feature_extraction.features_utils import get_fc, get_fcd2

%%

In [None]:
import vbi
from vbi import extract_features
from vbi import get_features_by_domain, get_features_by_given_names

%%

In [None]:
seed = 2
# Seed NumPy's legacy global RNG (used below for the random `w` vector in
# `par`) and PyTorch's default generator (presumably what prior sampling
# draws from — verify against `Inference.sample_prior`).
np.random.seed(seed)
# `torch.manual_seed` returns a Generator; the trailing semicolon keeps
# Jupyter from echoing its repr as this cell's output.
torch.manual_seed(seed);

%%

In [None]:
# Shared font size for axis labels and x/y tick labels in every figure below.
LABESSIZE = 12
plt.rcParams.update(
    {
        "axes.labelsize": LABESSIZE,
        "xtick.labelsize": LABESSIZE,
        "ytick.labelsize": LABESSIZE,
    }
)

%%

In [None]:
def visual(t, s, t_fmri, d_fmri, k=30, **kwargs):
    """Plot one simulation: raw time series, BOLD signals, FCD and FC.

    Layout (mosaic): panel A shows ``s.T`` against ``t``, panel B shows
    ``d_fmri.T`` against ``t_fmri``, panel C shows the FCD matrix and
    panel D the FC matrix, each with its own colorbar.

    Parameters
    ----------
    t, s : array-like
        Time vector and simulated time series. ``s.T`` is plotted against
        ``t``, so presumably ``s`` is (n_regions, n_times) — TODO confirm.
    t_fmri, d_fmri : array-like
        Time vector and BOLD time series, plotted the same way.
    k : int, default 30
        Passed to ``vbi.utils.set_diag`` for the FCD matrix.
        NOTE(review): whether this is a fill value or a diagonal offset is
        not visible here — verify against ``set_diag``.
    **kwargs
        Forwarded to ``get_fcd2`` (e.g. ``wwidth``, ``maxNwindows``, ``olap``).
    """
    fc = get_fc(d_fmri)["full"]
    fcd = get_fcd2(d_fmri, **kwargs)

    # Post-process the diagonals before plotting (FC with 0, FCD with `k`).
    fc = vbi.utils.set_diag(fc, 0)
    fcd = vbi.utils.set_diag(fcd, k)

    mosaic = """
    AACD
    BBCD
    """
    fig = plt.figure(constrained_layout=True, figsize=(12, 3.5))
    ax = fig.subplot_mosaic(mosaic)
    ax["A"].plot(t, s.T, lw=0.1, alpha=1.0)
    ax["B"].plot(t_fmri, d_fmri.T, lw=0.1, alpha=1.0)

    # Each matrix gets a colorbar built from its OWN mappable; previously the
    # FC colorbar reused the FCD image, showing FCD's color scale next to FC.
    im_fcd = ax["C"].imshow(fcd, cmap="viridis")
    fig.colorbar(im_fcd, ax=ax["C"])
    ax["C"].set_title("FCD")

    im_fc = ax["D"].imshow(fc, cmap="viridis")
    fig.colorbar(im_fc, ax=ax["D"])
    ax["D"].set_title("FC")
    
@timer
def run(par):
    """Simulate the WW_sde model once with the parameter dict ``par``.

    Returns
    -------
    tuple
        ``(t, s, t_fmri, d_fmri)`` read from the simulation output, or
        ``(None, None, None, None)`` (after printing a message) when ``s``
        or ``d_fmri`` contains NaN.
    """
    sim_out = WW_sde(par).run()
    t, s, t_fmri, d_fmri = (sim_out[key] for key in ("t", "s", "t_fmri", "d_fmri"))

    # Reject diverged simulations instead of returning NaN-filled arrays.
    diverged = np.isnan(s).any() or np.isnan(d_fmri).any()
    if diverged:
        print("Nan values detected")
        return None, None, None, None

    return t, s, t_fmri, d_fmri

%%

In [None]:
# Load a sample connectivity matrix from vbi. The argument 84 selects the
# sample (presumably an 84-region connectome — verify); `nn` is the number
# of nodes, taken from the first axis of the weights.
connectome = vbi.LoadSample(84)
weights = connectome.get_weights()
nn = weights.shape[0]

%%

In [None]:
# Baseline simulation parameters for WW_sde. Time values look like
# milliseconds (t_cut = 1 min, t_end = 5 min) — NOTE(review): confirm units.
par = dict(
    G=0.2,
    dt=2.5,
    t_cut=1 * 60 * 1000.0,
    t_end=5 * 60 * 1000.0,
    weights=weights,
    seed=seed,
    I_o=np.ones(nn) * 0.286,  # identical value for every node
    w=np.random.uniform(0.9, 1.0, nn),  # per-node draw from [0.9, 1.0)
    sigma_noise=0.008,
    ts_decimate=20,
    fmri_decimate=50,
    RECORD_TS=1,
    RECORD_FMRI=1,
)

In [None]:
t, s, t_fmri, d_fmri = run(par)
# `run` returns all-None when the simulation produced NaNs; skip plotting in
# that case instead of failing inside `visual` with an unrelated error.
if s is None:
    print("Skipping visualisation: simulation returned no data.")
else:
    visual(t, s, t_fmri, d_fmri, k=30, wwidth=200, maxNwindows=250, olap=0.94)

%% [markdown]<br>
if 1:

%%

In [None]:
# Feature configuration: from the connectivity domain keep only "fc_stat".
cfg = get_features_by_given_names(
    get_features_by_domain(domain="connectivity"),
    names=["fc_stat"],
)

%%

In [None]:
def wrapper(par, control, cfg, verbose=False):
    """Simulate WW_sde under ``control`` and return its feature vector.

    Parameters
    ----------
    par : dict
        Baseline model parameters passed to ``WW_sde``.
    control : dict
        Per-run overrides passed to ``WW_sde.run`` (e.g. ``{"G": {"value": g}}``).
    cfg : feature configuration for ``extract_features``.
    verbose : bool, default False
        Forwarded to ``extract_features``.

    Returns
    -------
    The first row of ``extract_features(...).values``, i.e. the features
    computed from ``d_fmri.T`` of this single simulation.
    """
    simulation = WW_sde(par).run(control)

    # Sampling frequency in Hz derived from the integration step `dt`
    # (the factor 1000 suggests dt is in ms).
    # NOTE(review): this ignores `fmri_decimate`; confirm whether `d_fmri`
    # is really sampled at 1/dt before using frequency-dependent features.
    fs = 1.0 / par["dt"] * 1000  # [Hz]

    features = extract_features(
        ts=[simulation["d_fmri"].T], cfg=cfg, fs=fs, verbose=verbose
    )
    return features.values[0]

%%

In [None]:
def batch_run(par, control_list, cfg, n_workers=1):
    """Run ``wrapper`` for every control dict on a process pool.

    Parameters
    ----------
    par : dict
        Baseline model parameters shared by every simulation.
    control_list : sequence of dict
        One control dict per simulation, forwarded to ``wrapper``.
    cfg : feature configuration forwarded to ``wrapper``.
    n_workers : int, default 1
        Number of worker processes.

    Returns
    -------
    list
        One feature vector per entry of ``control_list``, in the same order.
        Results are collected in submission order, not completion order.
    """
    if len(control_list) == 0:
        # Nothing to simulate: avoid spawning a pool of worker processes.
        return []

    with Pool(processes=n_workers) as pool, tqdm(total=len(control_list)) as pbar:
        # The callback runs in the parent process as each task completes,
        # so it only advances the progress bar.
        async_results = [
            pool.apply_async(
                wrapper,
                args=(par, control, cfg, False),
                callback=lambda _result: pbar.update(),
            )
            for control in control_list
        ]
        # .get() re-raises any exception from a worker.
        return [res.get() for res in async_results]

%%

In [None]:
# Ground-truth parameters, used later to simulate the "observed" features
# and marked as the reference point in the posterior plot.
theta_true = {"G": {"value": 0.65}}

%%

In [None]:
# Simulation budget and parallelism for `batch_run`.
num_sim = 200
num_workers = 10

# Uniform (box) prior over the global coupling G.
G_min, G_max = 0.0, 1.5
prior_min, prior_max = [G_min], [G_max]
prior = utils.BoxUniform(
    low=torch.tensor(prior_min),
    high=torch.tensor(prior_max),
)

%%

In [None]:
# Draw `num_sim` parameter sets from the prior and wrap each G value in the
# control-dict format that `wrapper` passes on to `WW_sde.run`.
obj = Inference()
theta = obj.sample_prior(prior, num_sim)
theta_np = theta.numpy().astype(float)
control_list = [
    {"G": {"value": theta_np[idx, 0]}}
    for idx in range(num_sim)
]

%%

In [None]:
# Simulate every prior draw in parallel and collect one feature vector each.
stat_vec = batch_run(par, control_list, cfg, n_workers=num_workers)

%%

In [None]:
# Standardize each feature; the fitted `scaler` is reused later to transform
# the observed data with the same statistics.
scaler = StandardScaler()
stat_vec_st = torch.tensor(scaler.fit_transform(np.array(stat_vec)), dtype=torch.float32)

# Persist the prior draws and the raw (unscaled) features. Create the output
# directory first so this cell also works on a fresh checkout.
os.makedirs("output", exist_ok=True)
torch.save(theta, "output/theta.pt")
torch.save(stat_vec, "output/stat_vec.pt")

%%

In [None]:
# Sanity check before training: exactly one feature vector per prior draw.
assert theta.shape[0] == stat_vec_st.shape[0], "theta and features are misaligned"
print(theta.shape, stat_vec_st.shape)

%%

In [None]:
# Train a neural posterior estimator (SNPE, MAF density estimator) on the
# standardized simulations.
posterior = obj.train(
    theta,
    stat_vec_st,
    prior,
    method="SNPE",
    density_estimator="maf",
)

%%

In [None]:
# Persist the trained posterior so it can be reloaded without retraining.
with open("output/posterior.pkl", "wb") as fh:
    pickle.dump(posterior, fh)

%%

In [None]:
# Reload the saved posterior. NOTE: pickle.load can execute arbitrary code,
# so only load files produced by this notebook.
with open("output/posterior.pkl", "rb") as fh:
    posterior = pickle.load(fh)

%%

In [None]:
# "Observed" features simulated at the ground-truth parameters, scaled with
# the scaler fitted on the training simulations (one row, all features).
xo = wrapper(par, theta_true, cfg)
xo_st = scaler.transform(np.reshape(xo, (1, -1)))

%%

In [None]:
# Draw posterior samples conditioned on the observed features and save them.
num_posterior_samples = 10000
samples = obj.sample_posterior(xo_st, num_posterior_samples, posterior)
torch.save(samples, "output/samples.pt")

%%

In [None]:
# Plot the posterior over G within the prior box, marking the true value.
limits = [list(bounds) for bounds in zip(prior_min, prior_max)]
points = [[theta_true["G"]["value"]]]
fig, ax = pairplot(
    samples,
    labels=["G"],
    limits=limits,
    figsize=(4, 4),
    diag="kde",
    offdiag="kde",
    samples_colors="k",
    points=points,
    points_colors="r",
    points_offdiag={"markersize": 10},
)
diag_ax = ax[0, 0]
diag_ax.tick_params(labelsize=14)
diag_ax.margins(y=0)
plt.tight_layout()
fig.savefig("output/tri.jpeg", dpi=300)