In [1]:
import os
import h5py
import tqdm
import json
import torch
import pickle
import helpers
import warnings
import numpy as np 
from copy import copy
import pylab as plt
from os.path import join
from scipy import signal
from copy import deepcopy
from vbi.utility import kop
from sbi.analysis import pairplot
from multiprocessing import Pool
from scipy.signal import hilbert
from vbi.Models.St_Lan_swig import Stuart_Landau
from sbi.utils.user_input_checks import process_prior

warnings.filterwarnings("ignore")


ModuleNotFoundError: No module named 'vbi'

In [2]:
# Fix the RNG state of torch and numpy so that sampling is reproducible.
SEED = 2
torch.manual_seed(SEED)
np.random.seed(SEED)

NameError: name 'np' is not defined

In [None]:
# Directory for every artefact this notebook writes (theta, features, posterior, figures).
# exist_ok=True replaces the racy "check, then create" pattern.
data_path = "output"
os.makedirs(data_path, exist_ok=True)
print(os.getcwd())

In [None]:
# Output folder (exist_ok avoids the check-then-create race of an exists() test).
data_path = "output"
os.makedirs(data_path, exist_ok=True)


def load_mat(mat_filename, key):
    """Read dataset ``key`` from an HDF5-based (MATLAB v7.3) ``.mat`` file.

    Returns the dataset as a NumPy array.
    NOTE(review): MATLAB stores arrays column-major, so h5py returns the
    transpose of the MATLAB matrix. This is harmless for symmetric matrices,
    but worth confirming for the connectomes loaded below.
    """
    with h5py.File(mat_filename, 'r') as f:
        return np.array(f[key])


# Subject metadata; the file handle is closed as soon as the JSON is parsed.
with open("dataset.json") as fh:
    dataset_data = json.load(fh)
group = dataset_data['group']
subject = dataset_data['subject']

# Subject-specific structural connectivity and distance matrices.
# Open the archive once and close it (an NpzFile keeps its file open otherwise).
with np.load("data/SC.npz") as sc_archive:
    SC = sc_archive['SC']
    Dist = sc_archive['Dist']

# Group-average connectome, used to fill missing subject entries.
SC_avg = load_mat("data/test_conn.mat", "test_conn")
SC_avg = SC_avg / SC_avg.max()
np.fill_diagonal(SC_avg, 0.0)
Dist_avg = load_mat("data/test_dist.mat", "test_dist") / 1000.0  # presumably mm -> m

# Remove self-connections and normalise the weights to a maximum of 1.
np.fill_diagonal(SC, 0.0)
np.fill_diagonal(Dist, 0.0)
SC = SC / np.max(SC)
SC = np.abs(SC)
assert np.trace(SC) == 0.0
num_nodes = SC.shape[0]

# Natural frequency of every oscillator, stored as angular frequency [rad/s].
freq = 40.0                                   # [Hz]
omega = 2 * np.pi * freq * np.ones(num_nodes)

# Substitute (near-)zero subject weights with the group average.
SC = np.where(SC < 1e-10, SC_avg, SC)
# NOTE(review): a threshold of 1e4 presumably exceeds every tract length (in mm or m).
# If so, this replaces *every* entry of Dist with Dist_avg and discards the
# subject's own distances. If only zeros were meant to be replaced, use a tiny
# threshold (as for SC) AND make sure Dist uses the same unit as Dist_avg.
# Kept unchanged here to preserve the published results.
Dist = np.where(Dist < 1e4, Dist_avg, Dist)
SC = SC * 2.0

In [None]:
# Log10 heat maps of the final SC and distance matrices, drawn without ticks.
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
for panel, (matrix, label) in zip(ax, ((SC, "SC"), (Dist, "Dist"))):
    helpers.plot_matrix(np.log10(matrix), ax=panel, title=label, xlabel="", ylabel="")
    panel.set_xticks([])
    panel.set_yticks([])

In [None]:
# Distributions of the final connectivity weights and distances (log-count axes).
# SC was multiplied by 2.0 above, so its upper histogram bound comes from the
# data; a fixed 1.0 silently dropped the strongest connections from the plot.
sc_upper = max(1.0, float(SC.max()))
fig, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].hist(SC.flatten(), bins=100, range=(0.01, sc_upper), log=True)
ax[1].hist(Dist.flatten(), bins=100, range=(0.01, 0.25), log=True)
ax[0].set_title("SC")
ax[1].set_title("Dist");

In [None]:
# Configuration of the Stuart-Landau network simulation.
parameters = dict(
    G=1000.0,                   # global coupling strength
    a=-5.0,                     # bifurcation parameter
    dt=1e-4,                    # integration time step [s]
    sigma_r=1e-4,               # noise strength
    sigma_v=1e-4,               # noise strength
    omega=omega,                # natural frequency as angular frequency [rad/s] (2*pi*freq)
    fix_seed=0,
    velocity=6.0,               # velocity [m/s]

    t_initial=0.0,              # initial time    [s]
    t_transition=2.0,           # transition time [s]
    t_final=10.0,               # end time        [s]
    method="euler",

    control=["G", "velocity"],  # control parameters; presumably set from `par`, in theta's column order

    adj=SC,                     # weighted connection matrix
    distances=Dist,             # distance matrix
    record_step=1,              # keep every n-th integration step (see fs = 1/(dt*record_step))
    data_path="output",         # output directory
    RECORD_TS=0,                # true to store large time series in file
)


In [None]:
import json 
import torch
import sbi.utils as utils
import multiprocessing as mp
from vbi.utility import brute_sample
from vbi.Models.St_Lan_swig import Inference
from vbi.feature_extraction import Features

NUM_SIMULATIONS = 5000   # size of the simulated training set
N_JOBS = 10              # worker processes used by batch_run

# Uniform prior box for the control parameters [G, velocity].
prior_PAR_min = [800.0, 2.5]
prior_PAR_max = [1700.0, 20.0]

# Sampling rate [Hz] of the recorded time series.
fs = 1.0/(parameters["dt"]*parameters["record_step"])

prior_dist = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_PAR_min),
    high=torch.as_tensor(prior_PAR_max))
prior, _, _ = process_prior(prior_dist)

# Reuse saved parameter samples when they exist, so that the in-memory ``theta``
# always matches the samples the cached simulations (data_x.pt) were run with.
# Drawing a fresh sample on every run while saving only the first one let the
# two silently diverge. Delete output/theta.pt to draw a new sample.
theta_path = join(data_path, "theta.pt")
if os.path.exists(theta_path):
    theta = torch.load(theta_path)
else:
    theta = prior.sample((NUM_SIMULATIONS,))
    torch.save(theta, theta_path)
    torch.save(prior, join(data_path, "prior.pt"))

In [None]:
theta.size()  # (number of samples, number of parameters)

In [None]:
def simulation_wrapper(params, par, fmin=6.0, fmax=13.0, fmax_psd=50.0):
    """Simulate one Stuart-Landau network and summarise its power spectrum.

    Parameters
    ----------
    params : dict
        Model configuration. It must contain ``dt`` and ``record_step``, which
        set the sampling rate, plus whatever ``Stuart_Landau`` requires. It is
        deep-copied, so the caller's dict is not handed to the model.
    par : sequence
        Control-parameter values passed to ``Stuart_Landau.simulate``.
    fmin, fmax : float
        Band [Hz] over which the area under the PSD is integrated.
    fmax_psd : float
        Highest frequency [Hz] kept for the peak search.

    Returns
    -------
    tuple of 6 numbers
        Peak frequency, peak power and band area of the node-wise median PSD,
        followed by the same three quantities for the node-wise mean PSD.
    """
    _par = deepcopy(params)
    data = Stuart_Landau(_par).simulate(par)
    x = data['x']

    # Taken from ``params`` (not the notebook-global ``parameters``), so the
    # function depends only on its arguments.
    fs = 1.0 / (params['dt'] * params['record_step'])

    freq, pxx = signal.welch(x, fs=fs, nperseg=x.shape[1] // 2)
    keep = (freq >= 0.0) & (freq <= fmax_psd)
    freq = freq[keep]
    pxx = pxx[:, keep]

    pxx_avg = np.mean(pxx, axis=0)     # mean over axis 0 of data['x']
    pxx_med = np.median(pxx, axis=0)   # median over axis 0 of data['x']
    ind = np.argmax(pxx_avg)
    ind_m = np.argmax(pxx_med)

    # Area under the PSD between fmin and fmax.
    # np.trapz is deprecated in NumPy 2; prefer np.trapezoid when present.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    band = (freq >= fmin) & (freq <= fmax)
    area_avg = trapezoid(pxx_avg[band], freq[band])
    area_med = trapezoid(pxx_med[band], freq[band])

    return freq[ind_m], pxx_med[ind_m], area_med, freq[ind], pxx_avg[ind], area_avg

In [None]:
def batch_run(parameters, par_list):
    """Run ``simulation_wrapper`` for every entry of ``par_list`` in parallel.

    Uses a pool of ``N_JOBS`` worker processes and shows a tqdm progress bar
    that advances as each simulation completes. Results are returned in the
    same order as ``par_list``.
    """
    total = len(par_list)
    with Pool(processes=N_JOBS) as pool, tqdm.tqdm(total=total) as progress:
        def _tick(_result):
            progress.update()

        pending = []
        for k in range(total):
            job = pool.apply_async(simulation_wrapper,
                                   args=(parameters, par_list[k]),
                                   callback=_tick)
            pending.append(job)
        results = [job.get() for job in pending]

    return results


In [None]:
# One [G, velocity] pair of plain Python floats per prior sample, in row order.
par_list = [[g, v] for g, v, *_ in theta.tolist()]
num_simulations = len(theta)
print(len(par_list))

In [None]:
# x = torch.as_tensor(batch_run(parameters, par_list), dtype=torch.float32)[:, :3]

In [None]:
# Simulated training features. They are computed once with batch_run and then
# cached, so Restart & Run All reuses the file instead of re-running every
# simulation. Delete output/data_x.pt to force a recomputation.
x_cache_path = join(data_path, "data_x.pt")
if os.path.exists(x_cache_path):
    x = torch.load(x_cache_path)
else:
    # Keep the first three statistics returned by simulation_wrapper
    # (frequency, power and area of the node-median PSD).
    x = torch.as_tensor(batch_run(parameters, par_list), dtype=torch.float32)[:, :3]
    torch.save(x, x_cache_path)

In [None]:
# Shape plus per-column mean, standard deviation and maximum of the features.
feature_stats = (x.mean(dim=0), x.std(dim=0), x.max(dim=0).values)
(x.shape, *[stat.tolist() for stat in feature_stats])

In [None]:
# Scatter each simulated feature against each prior parameter:
# panels 0-2 use G (red), panels 3-5 use the velocity V (blue).
x_ = x.numpy()
fig, ax = plt.subplots(1,6, figsize=(18, 3))

feature_titles = ("frequency", "power", "area")
for p_idx, (p_label, p_color) in enumerate((("G", 'r'), ("V", "b"))):
    p_values = theta[:, p_idx].numpy()
    for f_idx, f_title in enumerate(feature_titles):
        panel = ax[3 * p_idx + f_idx]
        # the power feature (column 1) is plotted on a log10 scale
        f_values = np.log10(x_[:, f_idx]) if f_idx == 1 else x_[:, f_idx]
        panel.plot(p_values, f_values, 'o', alpha=0.5, color=p_color, ms=1)
        panel.set_xlabel(p_label, fontsize=12)
        panel.tick_params(labelsize=12)
        panel.set_title(f_title)

plt.tight_layout()

In [None]:
# print(f_features(parameters, par_list[0]))

In [None]:
# Load observation data features

def obs_wrapper(par, opts, data_dir="/home/ziaee/projects/02_Peirpaolo/dataset/data"):
    """Compute (peak frequency, peak power, band area) for one recorded subject.

    Parameters
    ----------
    par : (group, subject)
        Identifiers passed to ``P_Dataset.load_TS``.
    opts : dict
        Must contain ``fmin`` and ``fmax`` [Hz]. It is also forwarded as
        ``opt`` to ``helpers.PSD_under_area``.
    data_dir : str
        Root folder of the recordings handed to ``P_Dataset``. The default is
        the original absolute path, kept for backward compatibility; pass
        your own location instead of editing the function.

    Returns
    -------
    tuple
        Frequency and power at the maximum of the averaged PSD inside
        [fmin, fmax], and the first element returned by ``PSD_under_area``.

    Raises
    ------
    ValueError
        If no PSD frequency lies inside [fmin, fmax]. Without this check,
        argmax would fail on an empty array.
    """
    # NOTE(review): ``P_Dataset`` is neither imported nor defined in this
    # notebook; it must be in scope before this function is called.
    group, subject = par
    fmin = opts['fmin']
    fmax = opts['fmax']

    DS = P_Dataset(data_dir)
    t, ts = DS.load_TS(group, subject)
    ts = DS.moving_average(ts, 10)

    # Presumably 1024 is the recording's sampling rate [Hz]; TODO confirm.
    freq, pxx = DS.welch(ts, 1024, nperseg=4096)
    pxx = np.mean(pxx, axis=0)
    band = np.logical_and(freq >= fmin, freq <= fmax)
    if not np.any(band):
        raise ValueError(
            f"no PSD frequencies in [{fmin}, {fmax}] Hz; check opts['fmin']/opts['fmax']")
    pxx = pxx[band]
    freq = freq[band]
    idx = np.argmax(pxx)
    area = helpers.PSD_under_area(freq, pxx.reshape(1, -1), opt=opts)

    return freq[idx], pxx[idx], area[0]


In [None]:
# opts = {"fmax":13.0, "fmin":6.0, "normalize": False}
# x_obs_ = obs_wrapper([group, subject], opts)
# x_obs_ = torch.as_tensor(x_obs_, dtype=torch.float32).reshape(1, -1)

In [None]:
# x_obs_.shape, x_obs_

In [None]:
# Observed-data features, presumably produced once by obs_wrapper and cached on disk.
obs_features_file = join(data_path, "data_x_obs.pt")
x_obs_ = torch.load(obs_features_file)

In [None]:
# Reload the training set from disk, so the rest of the notebook uses exactly
# the cached parameter samples and simulated features.
theta = torch.load(join(data_path, "theta.pt"))
x_ = torch.load(join(data_path, "data_x.pt"))
print(x_.shape, theta.shape, x_[0])

In [None]:
# Rescale the power-based features (every column after the first, i.e. power
# and area); the factor depends on the subject's group.
# NOTE(review): x_obs_ is not multiplied by this coefficient. Confirm that the
# scaled simulated features and the observed features share the same units.
coefficient = 2e12 if group != "control" else 1e12

x = x_.detach().clone()
x[:, 1:] *= coefficient
x[0, :]


In [None]:
# vbi inference helper built from the simulation configuration. It is used
# below to sample the posterior (the commented line shows how it was trained).
obj = Inference(parameters)
# posterior = obj.train(num_simulations, prior, x[:, :], theta, 8)

In [None]:
# Trained posterior: loaded from disk when available, otherwise trained and cached.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
posterior_path = join(data_path, "posterior_.pickle")
if os.path.exists(posterior_path):
    with open(posterior_path, "rb") as cf:
        posterior = pickle.load(cf)['posterior']
else:
    # Arguments copied from the original training call; the meaning of the
    # trailing 8 is defined by vbi's Inference.train (TODO confirm).
    posterior = obj.train(num_simulations, prior, x[:, :], theta, 8)
    with open(posterior_path, "wb") as cf:
        pickle.dump({"posterior": posterior}, cf)

In [None]:
# Draw posterior samples conditioned on the observed features.
n_posterior_samples = 10_000
samples = obj.sample_posterior(x_obs_[:, :], n_posterior_samples, posterior)

In [None]:
samples_file = join(data_path, "samples.pt")
torch.save(samples, samples_file)

In [None]:
# Plot limits per parameter, taken from the prior box: [[min, max], ...].
limits = [list(bounds) for bounds in zip(prior_PAR_min, prior_PAR_max)]

# Corner plot of the posterior samples (KDE on the diagonal and upper triangle).
fig, ax = pairplot(
    samples,
    labels=["G", "V"],
    limits=limits,
    figsize=(5, 5),
    diag='kde',
    upper='kde',
    title=f"n = {len(theta)}",
    samples_colors="k",
    points_colors="r",
    points_offdiag={'markersize': 10},
)

ax[0, 0].tick_params(labelsize=14)
ax[0, 0].margins(y=0)
for k, symbol in enumerate((r"$G$", r"$V$")):
    ax[k, k].set_xlabel(symbol, fontsize=16)
fig.savefig(join(data_path, "triangleplot_.jpeg"), dpi=150)


In [None]:
from vbi.utility import posterior_peaks
# Per-parameter peak of the posterior samples, as computed by vbi's posterior_peaks.
peak_labels = ["G", "V"]
theta_peak = posterior_peaks(samples, labels=peak_labels)
print(theta_peak)

In [None]:
tuple(map(type, (theta_peak, theta_peak[0])))  # container type and element type

In [None]:
from sbi.analysis import ActiveSubspace

In [None]:
# Sensitivity analysis with sbi's ActiveSubspace. First condition the posterior
# on the observed features, then obtain eigenvalues/eigenvectors using the
# posterior log-probability as the property of interest.
# NOTE(review): set_default_x may modify ``posterior`` in place and return it;
# if so, posterior_active and posterior are the same object.
posterior_active = posterior.set_default_x(x_obs_)
sensitivity = ActiveSubspace(posterior_active)

e_vals, e_vecs = sensitivity.find_directions(posterior_log_prob_as_property=True)

In [None]:
# Show the ActiveSubspace eigenvalues and eigenvectors computed above.
print("Eigenvalues: \n", e_vals, "\n")
print("Eigenvectors: \n", e_vecs)

In [None]:
import seaborn as sns
sns.set_theme(style="darkgrid")  # global seaborn/matplotlib style for the figures below
def plot_bar(ax, labels, y, ylabel=None):
    """Draw a seaborn bar chart of ``y`` on ``ax``, one bar per label.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    labels : sequence of str
        Tick labels, one per bar.
    y : array-like
        Bar heights, same length as ``labels``.
    ylabel : str, optional
        Y-axis label; omitted when None.
    """
    x = np.arange(len(labels))
    # Pass ax explicitly. Without it, seaborn draws on plt.gca(), which only
    # coincides with ``ax`` when ``ax`` happens to be the current axes.
    sns.barplot(x=x, y=y, palette="Reds_d", ax=ax)
    ax.set_xticks(x, labels)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=18)
    ax.grid(False)
    sns.despine(ax=ax, bottom=False, left=False)

In [None]:
# Bar chart of the ActiveSubspace eigenvalues on a log scale.
# NOTE(review): eigenvalues belong to eigen-directions (e_vecs), not directly to
# G and V. Labelling bar i as parameter i assumes those directions are roughly
# axis-aligned; check e_vecs. Also, results.json stores 1/e_vals as
# "sensitivity", while this plot shows e_vals.
fig, ax = plt.subplots(1, figsize=(4, 4))
eigenvalues = e_vals.numpy().squeeze()
plot_bar(ax, [r"$G$", r"$V$"], eigenvalues, ylabel="sensitivity")
ax.set_yscale('log')
fig.savefig(join(data_path, "sensitivity_.jpeg"), dpi=150);

In [None]:
# Store the posterior peaks and the sensitivity summary in a JSON file.
# float() makes the peaks JSON-serialisable even when posterior_peaks returns
# NumPy float32 scalars or 0-d tensors; json accepts only Python/float64 floats.
# NOTE(review): the sensitivity plot shows e_vals, but 1/e_vals is stored here;
# confirm which quantity is intended.
results = {
    "G": float(theta_peak[0]),
    "V": float(theta_peak[1]),
    "sensitivity": (1.0 / e_vals.numpy().squeeze()).tolist(),
}
with open(join(data_path, "results.json"), "w") as cf:
    json.dump(results, cf)