# Imports

In [None]:
# Number of worker threads allowed for the numerical backends; exported through
# environment variables in the setup cell below.
N_threads: int = 6

In [None]:
%load_ext autoreload

import sys, os
os.environ["OMP_NUM_THREADS"] = str(N_threads)
os.environ["OPENBLAS_NUM_THREADS"] = str(N_threads)
os.environ["MKL_NUM_THREADS"] = str(N_threads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(N_threads)
os.environ["NUMEXPR_NUM_THREADS"] = str(N_threads)

import numpy as np
import pickle

from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble

import SBIBE as sbibe

%matplotlib notebook
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')

np.random.seed(seed=0)

# Load posteriors from the N best trained networks

In [None]:
# How many of the top runs in the sweep register get combined into the ensemble.
N_best_combine = 2
# Directory holding the W&B sweep register and the trained models.
path_wandb_sweep = "/dipc_storage/dlopez/Projects/SBI-baccoemu/wandb_models"

In [None]:
# Read the sweep register: one name and one loss per trained run.
# NOTE(review): later cells treat the first N_best_combine entries as the best
# runs, which assumes the register is returned sorted by ascending loss — confirm.
sweep_register = sbibe.sbi_wandb_utils.load_wandb_sweep_register(path_wandb_sweep)
sweep_names, losses = sweep_register

In [None]:
# Legend entries for the base plot: the loss curve and the selection cut.
custom_lines = [
    mpl.lines.Line2D([0], [0], color='royalblue', ls='-', lw=2, marker=None, markersize=9),
    mpl.lines.Line2D([], [], color='k', marker='|', linestyle='None', markersize=10, markeredgewidth=2)
]
fig, ax = sbibe.plot_utils.simple_plot(
    x_label=r'Sweep ID',
    y_label=r'Loss',
    custom_labels=[r'loss', r'cut'],
    custom_lines=custom_lines
)
# Loss of every run in the register; the vertical line separates the first
# N_best_combine runs (the ones combined into the ensemble) from the rest.
ax.plot(losses, color='royalblue', lw=2)
ax.axvline(N_best_combine-0.5, color='k', lw=2)

# Highlight each selected run and collect its name for a second legend.
custom_lines = []
custom_labels = []
colors = sbibe.plot_utils.get_N_colors(N_best_combine, mpl.colormaps['jet'])
for ii in range(N_best_combine):
    # `color=` rather than `c=`: an RGB(A) value passed via `c` can be mistaken
    # for data values to colour-map.
    ax.scatter([ii], [losses[ii]], color=colors[ii])
    custom_lines.append(mpl.lines.Line2D([0], [0], color=colors[ii], ls='-', lw=0, marker='o', markersize=9))
    custom_labels.append(sweep_names[ii])

# Axes.legend() replaces the Axes' current legend, so grab the loss/cut legend
# created by simple_plot (if any) and keep it as a plain artist. The new legend
# is already the Axes legend and must not be added a second time.
loss_legend = ax.get_legend()
legend = ax.legend(custom_lines, custom_labels, loc='lower left',
                   fancybox=True, shadow=True, ncol=1, fontsize=14)
if loss_legend is not None and loss_legend not in ax.artists:
    ax.add_artist(loss_legend)

plt.tight_layout()
plt.show()

In [None]:
# Load the posteriors of the first N_best_combine runs and wrap them in an sbi ensemble.
best_sweep_names = sweep_names[:N_best_combine]
posteriors = sbibe.sbi_wandb_utils.load_posteriors(path_wandb_sweep, best_sweep_names)
posterior = NeuralPosteriorEnsemble(posteriors=posteriors)

# Test set: extraction, posterior inference, and rank statistics

In [None]:
# ------------------ extract test ------------------ #

# [min, max] range of every parameter used to draw the test set.
dict_bounds_test = {
    'omega_cold'    :  [0.25, 0.38],
    'omega_baryon'  :  [0.042, 0.058],
    'hubble'        :  [0.62, 0.78],
    'ns'            :  [0.93, 1.00],
    'sigma8_cold'   :  [0.75, 0.88],
}
N_test = 987  # number of Latin-hypercube test points
theta_test = sbibe.sbi_data_utils.sample_latin_hypercube(dict_bounds_test, N_test)
xx_test, kk = sbibe.sbi_data_utils.get_xx(dict_bounds_test, theta_test)

# ------------------ posterior inference ------------------ #

# NOTE(review): an earlier version normalised the inputs with
# `scaler.transform(xx_test)`, but no `scaler` exists in this notebook, so the
# raw xx_test is passed — confirm this matches how the networks were trained.
inferred_theta_test = sbibe.sbi_utils.sample_posteriors_theta_test(
    posterior,
    xx_test,
    dict_bounds_test
)

# ------------------ rank stats ------------------ #

ranks = sbibe.sbi_utils.compute_ranks(theta_test, inferred_theta_test)

# Visualizations

In [None]:
# LaTeX panel titles, listed in the same order as the keys of dict_bounds_test.
custom_titles = [
    r'$\Omega_\mathrm{c}$',       # omega_cold
    r'$\Omega_\mathrm{b}$',       # omega_baryon
    r'$h$',                       # hubble
    r'$n_\mathrm{s}$',            # ns
    r'$\sigma_{8,\mathrm{c}}$'    # sigma8_cold
]

# Pick a few distinct test cases (no repetition) to inspect individually.
N_examples = 5
n_test_cases = inferred_theta_test.shape[0]
indexes = np.random.choice(n_test_cases, N_examples, replace=False)

## Visualize test examples

In [None]:
custom_lines = [
    mpl.lines.Line2D([0], [0], color='k', ls='-', lw=2, marker=None, markersize=9),
]

# The curves below are drawn against np.log10(kk), so the x label names the
# axis as log10(k) rather than k.
# NOTE(review): the y label says P(k); confirm xx_test is not already in log space.
fig, ax = sbibe.plot_utils.simple_plot(
    x_label=r'$\log_{10}\left(k \left[ h\, \mathrm{Mpc}^{-1} \right]\right)$',
    y_label=r'$P(k) \left[ \left(h^{-1} \mathrm{Mpc}\right)^{3} \right]$',
    custom_labels=[r'Test'],
    custom_lines=custom_lines
)

# One faint line per test example, drawn in a random order. Rows of xx_test are
# examples, so the transpose puts the k axis first as ax.plot expects.
tmp_xx_plot = xx_test
tmp_xx_plot = tmp_xx_plot[np.random.choice(tmp_xx_plot.shape[0], tmp_xx_plot.shape[0], replace=False)].T
ax.plot(np.log10(kk), tmp_xx_plot, c='k', alpha=0.1, lw=0.5)

plt.tight_layout()
plt.show()

## (Optional) Visualize xx associated with the inferred posteriors

In [None]:
# Evaluate compute_baccoemu_predictions_batch on the posterior samples of the
# selected test cases, then plot them against the true xx of those cases.
selected_theta_samples = inferred_theta_test[indexes]
param_names = list(dict_bounds_test.keys())
tmp_inferred_xx_test = sbibe.sbi_data_utils.compute_baccoemu_predictions_batch(
    selected_theta_samples, param_names
)

true_xx_selected = xx_test[indexes]
fig, ax, ax_res = sbibe.plot_utils.plot_xx_from_sampled_posteriors(
    true_xx_selected, tmp_inferred_xx_test, kk
)
plt.show()

## Posterior examples

In [None]:
colors = sbibe.plot_utils.get_N_colors(N_examples, mpl.colormaps['prism'])
# Use the same randomly selected test cases (`indexes`) as the xx comparison,
# so each corner plot can be matched with the corresponding curves.
for ii_sample, idx in enumerate(indexes):
    fig, axs = sbibe.plot_utils.corner_plot(
        theta_test[idx],
        inferred_theta_test[idx],
        custom_titles,
        color_infer=colors[ii_sample]
    )
    plt.show()

## Rank statistics

In [None]:
# NOTE(review): axis 1 of inferred_theta_test is presumably the number of
# posterior draws per test case — confirm against plot_rank_statistcis
# (the misspelled name is the one defined in SBIBE).
n_draws = inferred_theta_test.shape[1]
fig, axs = sbibe.plot_utils.plot_rank_statistcis(ranks, n_draws, custom_titles)
plt.tight_layout()
plt.subplots_adjust(wspace=0.05)  # keep the panels close together horizontally
plt.show()

## Inference errors

In [None]:
# Inferred parameters versus the true test parameters.
fig, axs = sbibe.plot_utils.plot_parameter_prediction_vs_truth(
    inferred_theta_test,
    theta_test,
    custom_titles,
)
plt.tight_layout()
plt.subplots_adjust(wspace=0.6)  # widen horizontal gaps after tight_layout
plt.show()