# IMPORTS

In [None]:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

import os, sys
import torch
import numpy as np

import CL_inference as cl_inference
N_threads = cl_inference.train_tools.set_N_threads_(N_threads=1)
torch.set_num_threads(N_threads)
torch.set_num_interop_threads(N_threads)
device = cl_inference.train_tools.set_torch_device_()

%load_ext autoreload

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib notebook
plt.style.use('default')
plt.close('all')
font, rcnew = cl_inference.plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%config InlineBackend.figure_format = 'retina'

# Select Runs

In [None]:
# Upper k limit of the analysis. It is used further down as the upper exponent of
# np.logspace (i.e. as a log10 value) and as part of the model folder name.
kmax = 0.6
# Available values: 0.6, 0.2, -0.2, -0.6, -1.0, -1.4
 
# Contrastive-loss tag; only used by the main_name variants that contain "_CL_".
tmp_CL_str = "Wein"
# Available values: Wein, VICReg
 
# Dataset tag used to build the model folder name.
tmp_dataset_str = "illustris_eagle"
# Available values:
# all
# v1_v2, v1_v3, v2_v3
# f0_f1, f2_f3, f4_f5, f6_f7, f8_f9
# illustris_eagle, bahamas_illustris, eagle_bahamas

In [None]:
# Suffix shared by every run-folder name: dataset tag plus the selected kmax.
tmp_str = f"_models_{tmp_dataset_str}_kmax_{kmax}"

# Pick exactly one run family by leaving a single `main_name` line uncommented.
# main_name = "only" + "_inference_" + "also_baryons"                  + tmp_str
main_name = "only" + "_inference"                                    + tmp_str
# main_name = "only" + "_CL_"                             + tmp_CL_str + tmp_str
# main_name = "only" + "_inference_CL_"                   + tmp_CL_str + tmp_str
# main_name = "join" + "_inference_CL_"                   + tmp_CL_str + tmp_str
# main_name = "join" + "_reload_inference_CL_"            + tmp_CL_str + tmp_str

# Folder of the selected run family inside the models root directory.
models_path = os.path.join(
    "/cosmos_storage/home/dlopez/Projects/CL_inference/models_box_5000/",
    main_name,
)

# Check loss of runs & get config files of N_best_runs

In [None]:
# Number of best runs to select; forwarded to get_config_files in the next cell.
select_N_best_runs = 1

In [None]:
# Fetch the config of the selected run(s) under models_path.
# The result is used below as a mapping: run/sweep name -> config (indexed with string keys such as "path_save").
configs = cl_inference.plot_utils.get_config_files(
    models_path, select_N_best_runs=select_N_best_runs,
    wandb_entity="daniellopezcano"
)

# Reload Models

In [None]:
# Evaluation mode handed to reload_models.
# Known options: "eval_CL", "eval_CL_and_inference", "eval_inference_supervised"
# Runs whose folder name contains "only_CL" are evaluated in CL-only mode.
evalute_mode = 'eval_CL' if "only_CL" in main_name else 'eval_CL_and_inference'

In [None]:
# Reload the trained models for every selected config onto `device`.
# Both returned objects are passed unchanged to the evaluation helpers further down.
models_encoder, models_inference = cl_inference.evaluation_tools.reload_models(
    models_path, evalute_mode, configs, device
)

# Load datasets

In [None]:
# The first entry of `configs` provides the config, the normalisation sweep name
# and the output folder used by the rest of the notebook.
key0_dset = list(configs)[0]
sweep_name_load_norm_dset = key0_dset
config = configs[sweep_name_load_norm_dset]
save_root = configs[key0_dset]["path_save"]

### Complete some config default values (be careful with this)

In [None]:
# Fill in config entries that some runs did not store.
# Only a missing key triggers the fallback (KeyError); any other error is no longer
# silently swallowed as it was with the previous bare `except:`.
# NOTE(review): confirm these fallback values are right for the runs being loaded.
config_fallbacks = {
    'include_baryon_params': False,
    'box': 2000,
    'factor_kmin_cut': 4,
}
for config_key, fallback_value in config_fallbacks.items():
    try:
        print(config_key + ":", config[config_key])
    except KeyError:
        config[config_key] = fallback_value
        # Make the fallback visible instead of applying it silently.
        print(config_key + ":", fallback_value, "(missing in config, default applied)")

include_baryon_params = config['include_baryon_params']
box = config['box']
factor_kmin_cut = config['factor_kmin_cut']

In [None]:
# Wavenumber grid used on the x-axis of the P(k) plots.
# kf = 2*pi/box; units follow those of `box` (not shown here).
kf = 2.0 * np.pi / box
# Lower limit as a log10 value: factor_kmin_cut times kf.
kmin=np.log10(factor_kmin_cut*kf)
# Number of grid points.
# NOTE(review): this divides a difference of log10 values (kmax - kmin) by a linear width (8*kf);
# confirm it matches how the stored spectra were binned.
N_kk = int((kmax - kmin) / (8*kf))
# N_kk log-spaced wavenumbers between 10**kmin and 10**kmax.
kk = np.logspace(kmin, kmax, num=N_kk)

In [None]:
# Per-parameter plot titles, plot limits and prior ranges.
# len(custom_titles) is used below as the number of inferred parameters, and
# list_range_priors[i][0] / list_range_priors[i][1] as the two ends of the prior of parameter i.
custom_titles, limits_plots_inference, list_range_priors = cl_inference.plot_utils.get_titles_limits_and_priors(include_baryon_params)

In [None]:
# Models to evaluate: "Model_vary_all" followed by the models listed in the run config.
list_model_names = ["Model_vary_all"] + config["list_model_names"]
# Model names that can appear:
# Model_vary_all
# Model_vary_1, Model_vary_2, Model_vary_3
# Model_fixed_0, Model_fixed_1, Model_fixed_2, Model_fixed_3, Model_fixed_4, Model_fixed_5, Model_fixed_6, Model_fixed_7, Model_fixed_8, Model_fixed_9
# Model_fixed_eagle, Model_fixed_illustris, Model_fixed_bahamas

# One colour per entry of list_model_names (indexed by model position in the plots below).
colors = cl_inference.plot_utils.colors_dsets(list_model_names)

In [None]:
# Build the data loader of the TEST split and store it under its split name.
dsets = {}
dset_name = "TEST"

# Data are read from the split sub-folder; normalisation is read from the selected sweep folder.
path_load_dset = os.path.join(config['path_load'], dset_name)
path_load_norm = os.path.join(config['path_save'], sweep_name_load_norm_dset)

# NOTE(review): the loader receives config['kmax'] while the plotting grid `kk` uses the
# notebook-level `kmax` - confirm both agree for the selected run.
dsets[dset_name] = cl_inference.data_tools.def_data_loader(
    path_load=path_load_dset,
    list_model_names=list_model_names,
    normalize=config['normalize'],
    path_load_norm=path_load_norm,
    NN_augs_batch=config['NN_augs_batch'],
    add_noise_Pk=config['add_noise_Pk'],
    kmax=config['kmax'],
    include_baryon_params=include_baryon_params,
)

In [None]:
# fig, axs = cl_inference.plot_utils.theta_distrib_plot(dsets=dsets, custom_titles=custom_titles)
# fig.savefig(save_root + "/theta_distrib.png")

In [None]:
# Evaluate a small random subset of the TEST split for the example plots below.
NN_plot = 5  # number of examples to draw
# Seed NumPy's global RNG with the run seed so the same examples are drawn on every run.
np.random.seed(config["seed"])
# Draw NN_plot distinct values (replace=False); presumably NN_cosmos is the number of cosmologies - TODO confirm.
indexes = np.random.choice(dsets[dset_name].NN_cosmos, NN_plot, replace=False)

# The six outputs feed plot_dataset_Pk (xx), plot_dataset_latents (hh) and
# plot_dataset_predictions (theta_true, theta_pred, Cov); len_models is passed to all three.
xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names,
    models_encoder,
    models_inference,
    dset_key=dset_name,
    indexes_cosmo=indexes,
    use_all_dataset_augs_ordered=False
)

In [None]:
def _tighten_show_save(figure, file_suffix):
    """Apply tight layout, display the figure and save it as save_root + file_suffix."""
    figure.set_tight_layout(True)
    plt.show()
    figure.savefig(save_root + file_suffix)


# --------------------- plot_dataset_Pk --------------------- #

fig, axs = cl_inference.plot_utils.plot_dataset_Pk(
    dsets[dset_name].norm_mean,
    dsets[dset_name].norm_std,
    xx,
    list_model_names,
    len_models,
    colors,
    kk,
    plot_as_Pk=True,
)
_tighten_show_save(fig, "/example_Pk.png")

# --------------------- plot_dataset_latents --------------------- #

fig, ax = cl_inference.plot_utils.plot_dataset_latents(hh, list_model_names, len_models, colors)
_tighten_show_save(fig, "/example_latent.png")

# --------------------- plot_dataset_predictions --------------------- #

fig, axs = cl_inference.plot_utils.plot_dataset_predictions(
    theta_true,
    theta_pred,
    Cov,
    list_model_names,
    len_models,
    colors,
    custom_titles,
    limits_plots_inference,
)
_tighten_show_save(fig, "/example_inference.png")

In [None]:
# Evaluate the TEST split with use_all_dataset_augs_ordered=True and plot the
# inferred vs. true parameters split by model.
eval_results = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    dset_key="TEST",
    use_all_dataset_augs_ordered=True,
)
xx, hh, theta_true, theta_pred, Cov, len_models = eval_results

fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names, len_models, theta_true, theta_pred, Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots_inference,
    colors=colors,
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
# NOTE(review): the file is named "..._train.png" although dset_key is "TEST" - confirm the intended name.
fig.savefig(save_root + "/eval_inference_train.png")

# Most biased points

In [None]:
NN_plot_biased = 5   # number of points to select (the most or the least biased ones)
ii_cosmo_param = -1  # index of the parameter whose bias is ranked
biased_mode = "most" # "most" or "less"

# Absolute bias in units of the predicted standard deviation (sqrt of the covariance diagonal).
abs_bias = np.abs((theta_pred - theta_true) / np.sqrt(np.diagonal(Cov, axis1=2, axis2=3)))
# Collapse the first two axes so every point can be ranked along a single axis.
tmp_abs_bias = np.reshape(abs_bias, ((np.prod(abs_bias.shape[0:2]),) + (abs_bias.shape[-1],)))

if biased_mode == "most":
    idxs = np.argpartition(tmp_abs_bias[:, ii_cosmo_param], -NN_plot_biased)[-NN_plot_biased:]
elif biased_mode == "less":
    idxs = np.argpartition(tmp_abs_bias[:, ii_cosmo_param], NN_plot_biased)[:NN_plot_biased]
else:
    # Previously an unknown mode left idxs_cosmo / idxs_augs undefined (or stale from an earlier run).
    raise ValueError("biased_mode must be 'most' or 'less', got " + repr(biased_mode))

# Map the flat positions straight back to the first two axes. Sorting keeps the row-major
# order that the former threshold + np.where lookup produced, while always returning exactly
# NN_plot_biased points (ties no longer add extra points, NaN entries no longer empty the selection).
idxs = np.sort(idxs)
idxs_cosmo, idxs_augs = np.unravel_index(idxs, abs_bias.shape[0:2])

# Arguments shared by the two evaluations of the selected points.
shared_eval_kwargs = dict(
    models_encoder=models_encoder,
    models_inference=models_inference,
    dset_key="TEST",
    use_all_dataset_augs_ordered=False,
    indexes_cosmo=idxs_cosmo,
)

# 1) The selected points themselves: all models, with the selected second-axis index
#    passed as a column vector.
(
    xx_biased,
    hh_biased,
    theta_true_biased,
    theta_pred_biased,
    Cov_biased,
    _,
) = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=list_model_names,
    indexes_augs=idxs_augs[np.newaxis].T,
    **shared_eval_kwargs,
)

# 2) The same first-axis indices evaluated only on the models listed in the run config,
#    without fixing indexes_augs.
(
    xx_biased_from_train,
    hh_biased_from_train,
    theta_true_biased_from_train,
    theta_pred_biased_from_train,
    Cov_biased_from_train,
    _,
) = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=config["list_model_names"],
    indexes_augs=None,
    **shared_eval_kwargs,
)

In [None]:
# Output names follow the selection mode so that a "less" run neither overwrites nor
# mislabels the "most" figures. For biased_mode == "most" the names are unchanged
# (e.g. "/example_Pk_most_bias.png").
bias_file_tag = "_" + biased_mode + "_bias.png"

# --------------------- plot_dataset_biased_Pk --------------------- #

fig, axs = cl_inference.plot_utils.plot_dataset_biased_Pk(
    dsets[dset_name].norm_mean, dsets[dset_name].norm_std, xx_biased, xx_biased_from_train, kk, plot_as_Pk=False
)
fig.set_tight_layout(True)
plt.show()
fig.savefig(save_root + "/example_Pk" + bias_file_tag)

# --------------------- plot_dataset_biased_latents --------------------- #

fig, ax = cl_inference.plot_utils.plot_dataset_biased_latents(hh_biased, hh_biased_from_train)
fig.set_tight_layout(True)
plt.show()
fig.savefig(save_root + "/example_latent" + bias_file_tag)

# --------------------- plot_dataset_biased_predictions --------------------- #

fig, axs = cl_inference.plot_utils.plot_dataset_biased_predictions(
    theta_true_biased, theta_pred_biased, Cov_biased,
    theta_true_biased_from_train, theta_pred_biased_from_train, Cov_biased_from_train,
    custom_titles, limits_plots_inference
)
fig.set_tight_layout(True)
plt.show()
fig.savefig(save_root + "/example_inference" + bias_file_tag)

# Generate bias and errorbar figures

In [None]:
# Thresholds (x-axis of the "fraction biased" figure below).
thresholds_bias = np.linspace(0.5, 6, 20)

# NOTE(review): max_err_hist has 12 entries and NN_avail_cosmo_test is hard-coded to 2048;
# confirm both match the number of inferred parameters and the size of the TEST split.
bias_stats = cl_inference.evaluation_tools.compute_bias_and_errorbar_stats(
    config,
    sweep_name_load_norm_dset,
    ["Model_vary_all"] + config["list_model_names"],
    models_encoder,
    models_inference,
    len(custom_titles),
    save_root=save_root,
    thresholds_bias=thresholds_bias,
    NN_bins_hist=60,
    NN_bins_hist_err=60,
    NN_avail_cosmo_test=2048,
    NN_split=20,
    max_err_hist=[0.05, 0.012, 0.12, 0.042, 0.06, 3.2, 1.5, 1.5, 3., .4, 1., 3.],
)

# Unpack the eight result arrays; the plotting cells below index them as [batch, model, parameter].
(
    fraction_biased_batches,
    NN_points_batches,
    bin_centers_batches,
    y_hists_batches,
    bin_centers_err_batches,
    y_hists_err_batches,
    median_err_batches,
    std_err_batches,
) = bias_stats

In [None]:
# Bias histograms: one panel per parameter, one thin line per (model, batch).
fontsize = 26
fontsize1 = 18
n_panels = len(custom_titles)
n_batches = y_hists_batches.shape[0]

fig, axs = plt.subplots(1, n_panels, figsize=(5.2*n_panels, 5.2))
axs[0].set_ylabel(r'Normalized Counts ', size=fontsize)
for ii_cosmo_param, panel_title in enumerate(custom_titles):
    ax = axs[ii_cosmo_param]
    ax.set_title(panel_title, size=fontsize+8, pad=16)
    ax.set_xlabel(r'Bias ', size=fontsize)
    # Zero-bias reference line.
    ax.axvline(0, c='k', ls=':', lw=1)
    for ii_model_name in range(len(list_model_names)):
        for ii_batch in range(n_batches):
            sel = (ii_batch, ii_model_name, ii_cosmo_param)
            # Counts are divided by the number of points of the batch.
            normalized_counts = y_hists_batches[sel]/NN_points_batches[sel]
            ax.plot(
                bin_centers_batches[sel], normalized_counts,
                color=colors[ii_model_name], lw=0.2, alpha=0.9
            )
fig.set_tight_layout(True)
fig.savefig(save_root + "/hist_bias.png")

In [None]:
# Fraction of biased points versus threshold: mean and scatter over batches, per model and parameter.
fontsize = 26
fontsize1 = 18
n_panels = len(custom_titles)

fig, axs = plt.subplots(1, n_panels, figsize=(5.2*n_panels, 5.2))
axs[0].set_ylabel(r'Fraction biased points ', size=fontsize)
for ii_cosmo_param, panel_title in enumerate(custom_titles):
    ax = axs[ii_cosmo_param]
    ax.set_xlim([1.5, 5.])
    ax.set_ylim([0, 1])
    ax.set_title(panel_title, size=fontsize+8, pad=16)
    ax.set_xlabel(r'$\sigma_\mathrm{thr.}$', size=fontsize)
    # Only the first panel keeps its y ticks.
    if ii_cosmo_param > 0:
        ax.set_yticks([])
    for ii_model_name in range(len(list_model_names)):
        color = colors[ii_model_name]
        fraction_per_batch = fraction_biased_batches[:, ii_model_name, ii_cosmo_param]
        # NaN-aware statistics over the batch axis.
        frac_mean = np.nanmean(fraction_per_batch, axis=0)
        frac_std = np.nanstd(fraction_per_batch, axis=0)
        ax.scatter(thresholds_bias, frac_mean, c=color, s=20)
        ax.errorbar(thresholds_bias, frac_mean, yerr=frac_std, c=color, ls='', capsize=2, alpha=1., elinewidth=1.5)
        ax.fill_between(thresholds_bias, frac_mean-frac_std, frac_mean+frac_std, color=color, alpha=0.3)

fig.set_tight_layout(True)
fig.savefig(save_root + "/bias_vs_threshold.png")

In [None]:
# Error-bar histograms, with the x-axis divided by the prior width of each parameter.
fontsize=26
fontsize1=18
fig, axs = plt.subplots(1, len(custom_titles), figsize=(5.2*len(custom_titles), 5.2))
axs[0].set_ylabel(r'Normalized Counts ', size=fontsize)
for ii_cosmo_param in range(len(custom_titles)):
    ax = axs[ii_cosmo_param]
    ax.set_title(custom_titles[ii_cosmo_param], size=fontsize+8, pad=16)
    ax.set_xlabel(r'$\frac{2\sigma}{\Delta \mathrm{Prior}}$', size=fontsize)
    # The prior width depends only on the parameter, so it is computed once per panel.
    # It used to be set inside the batch loop but is also needed after it, which left it
    # undefined (or stale from the previous panel) whenever there were no batches.
    tmp_prior_range = (list_range_priors[ii_cosmo_param][1]-list_range_priors[ii_cosmo_param][0])
    for ii_model_name, list_model_name in enumerate(list_model_names):
        color = colors[ii_model_name]
        # One thin histogram line and one thin median marker per batch.
        for ii_batch in range(y_hists_batches.shape[0]):
            ax.plot(
                bin_centers_err_batches[ii_batch, ii_model_name, ii_cosmo_param]/tmp_prior_range,
                y_hists_err_batches[ii_batch, ii_model_name, ii_cosmo_param]/NN_points_batches[ii_batch, ii_model_name, ii_cosmo_param],
                color=color, lw=0.2, alpha=0.9
            )
            ax.axvline(median_err_batches[ii_batch, ii_model_name, ii_cosmo_param]/tmp_prior_range, color=color, lw=0.1, alpha=0.9)
        # Mean and scatter of the per-batch medians: thick line plus shaded band.
        tmp_mean = np.nanmean(median_err_batches[:, ii_model_name, ii_cosmo_param]/tmp_prior_range, axis=0)
        tmp_std = np.nanstd(median_err_batches[:, ii_model_name, ii_cosmo_param]/tmp_prior_range, axis=0)
        ax.axvline(tmp_mean, color=color, ls='-', lw=2)
        ax.axvspan(tmp_mean-tmp_std, tmp_mean+tmp_std, alpha=0.3, color=color)

fig.set_tight_layout(True)
fig.savefig(save_root + "/hist_error.png")