# Imports and setup

In [None]:
import os, sys
import numpy as np
import torch
import CL_SBI_mini as CL_SBI_mini

%load_ext autoreload

N_threads = 1
os.environ["OMP_NUM_THREADS"] = str(N_threads)
os.environ["OPENBLAS_NUM_THREADS"] = str(N_threads)
os.environ["MKL_NUM_THREADS"] = str(N_threads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(N_threads)
os.environ["NUMEXPR_NUM_THREADS"] = str(N_threads)
torch.set_num_threads(N_threads)
torch.set_num_interop_threads(N_threads)

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')

# Load the training dataset

In [None]:
# Options forwarded to the dataset loader; "kmax", "box" and
# "factor_kmin_cut" are also read by the plotting cell below.
dset_dict = {
    "add_noise_Pk": "cosmic_var_gauss",
    "NN_noise_realizations": 10,
    "kmax": 0.6,
    "box": 2000,
    "factor_kmin_cut": 4,
}
normalize = True

datasets_root = os.path.join("../", "DATASETS")

# Load parameters (theta) and summaries (xx) for both hydro models.
theta, xx = CL_SBI_mini.load_datasets.load_and_preprocess_dset(
    path_load_dset=os.path.join(datasets_root, "TRAIN"),
    list_model_names=["Model_eagle", "Model_illustris"],
    dset_dict=dset_dict,
    normalize=normalize,
    path_save_norm=None,
    # NOTE(review): the normalization is loaded from the TEST directory rather
    # than derived from this training set -- confirm this is intended.
    path_load_norm=os.path.join(datasets_root, "TEST"),
)

In [None]:
# Wavenumber bookkeeping derived from the dataset options.
# presumably kf = 2*pi/L is the fundamental wavenumber of a box of side L
# (units assumed h/Mpc given the y-label) -- TODO confirm.
kf = 2.0 * np.pi / dset_dict["box"]
kmin = np.log10(dset_dict["factor_kmin_cut"] * kf)
# NOTE(review): kmin is a log10 value while kmax is linear, so N_kk mixes units
# and need not equal the number of k-bins in xx. It is kept for reference only;
# the plot below uses xx.shape[1] so the x/y lengths always agree.
N_kk = int((dset_dict["kmax"] - kmin) / (8 * kf))

# Pick a reproducible random subset of training spectra to display.
N_plot_cosmo = 10
plot_seed = 0
rng = np.random.default_rng(plot_seed)
n_show = min(N_plot_cosmo, xx.shape[0])  # replace=False cannot draw more than exist
indexes = rng.choice(xx.shape[0], n_show, replace=False)
# assumes xx is 2-D (n_samples, n_k_bins) -- TODO confirm against the loader.
k_index = np.arange(xx.shape[1])

fig, ax = plt.subplots(1, 1, figsize=(9, 6))
ax.set_title(f'{n_show} randomly selected training spectra')
ax.set_ylabel(r'$\mathrm{Norm}\left(P(k) \left[ \left(h^{-1} \mathrm{Mpc}\right)^{3} \right]\right)$')
ax.set_xlabel(r'$k$-index [adim]')

ax.plot(k_index, xx[indexes].T, c='limegreen', lw=1.5, alpha=0.7)

plt.show()

# Train the SBI pipeline (SNPE with a MAF density estimator)

In [None]:
from sbi.inference import SNPE
from sbi import utils

In [None]:
# Size of the 'maf' density estimator configured in the next cell.
num_hidden_features, num_transforms, num_blocks = 64, 4, 3

In [None]:
# Seed torch's global RNG so later stochastic steps (presumably network
# initialisation and the train/validation split inside inference.train)
# are repeatable -- no cell in between draws from it.
torch.manual_seed(0)

maf_config = dict(
    model='maf',
    hidden_features=num_hidden_features,
    num_transforms=num_transforms,
    num_blocks=num_blocks,
)
# Builder function handed to SNPE; it constructs the network when training starts.
density_estimator_build_fun = utils.get_nn_models.posterior_nn(**maf_config)

In [None]:
device = "cpu"

# Uniform prior range [lower, upper] per parameter. get_prior stacks these in
# key order, so the key order fixes the prior's dimension order
# (NOTE(review): must match the column order of theta -- TODO confirm).
dict_bounds_params = dict(
    omega_cold=[0.23, 0.40],
    omega_baryon=[0.04, 0.06],
    hubble=[0.60, 0.80],
    ns=[0.92, 1.01],
    sigma8_cold=[0.73, 0.90],
)

def get_prior(dict_bounds, device="cpu"):
    """Build an sbi uniform box prior from per-parameter bounds.

    Parameters
    ----------
    dict_bounds : dict
        Maps each parameter name to a ``[lower, upper]`` pair. The prior's
        dimensions follow the dict's iteration (insertion) order.
        NOTE(review): that order must match the column order of ``theta``
        given to ``append_simulations`` -- confirm against the dataset loader.
    device : str, optional
        Torch device on which the bound tensors are placed (default ``"cpu"``).

    Returns
    -------
    utils.BoxUniform
        Uniform prior over the box ``[lower_i, upper_i]`` with float32 bounds.

    Raises
    ------
    ValueError
        If ``dict_bounds`` is empty, an entry is not a ``[lower, upper]`` pair,
        or a lower bound is not strictly smaller than its upper bound.
    """
    if not dict_bounds:
        raise ValueError("dict_bounds must contain at least one parameter")

    # Stack the bounds once into an (n_params, 2) float32 array.
    bounds = np.asarray([dict_bounds[key] for key in dict_bounds], dtype=np.float32)
    if bounds.ndim != 2 or bounds.shape[1] != 2:
        raise ValueError(
            "each value in dict_bounds must be a [lower, upper] pair, "
            f"got array of shape {bounds.shape}"
        )

    # `not low < high` also rejects NaN bounds.
    invalid = [name for name, (low, high) in zip(dict_bounds, bounds) if not low < high]
    if invalid:
        raise ValueError(f"lower bound must be < upper bound for: {invalid}")

    # Transpose + copy so each row is a contiguous 1-D array for torch.from_numpy.
    lower_np, upper_np = np.ascontiguousarray(bounds.T)
    lower_bound = torch.from_numpy(lower_np).to(device)
    upper_bound = torch.from_numpy(upper_np).to(device)

    return utils.BoxUniform(lower_bound, upper_bound)

In [None]:
# Neural posterior estimator: box prior over the parameters plus the flow
# builder configured above.
box_prior = get_prior(dict_bounds_params, device)
inference = SNPE(prior=box_prior, density_estimator=density_estimator_build_fun, device=device)

In [None]:
# Register (parameters, summaries) pairs as float32 tensors on the target device.
theta_torch, xx_torch = (
    torch.from_numpy(arr.astype('float32')).to(device) for arr in (theta, xx)
)
inference.append_simulations(theta_torch, xx_torch)

In [None]:
# Optimiser settings consumed by inference.train below.
batch_size, lr = 8, 0.001

In [None]:
%%time

density_estimator = inference.train(
    training_batch_size=batch_size,
    validation_fraction=0.2,
    learning_rate=lr,
    show_train_summary=True
)

In [None]:
# Wrap the trained density estimator into a posterior object.
posterior = inference.build_posterior(density_estimator)