# IMPORTS

In [1]:
# Notebook setup: thread limits, imports, plotting defaults and torch device.
# The thread count is requested *before* torch/numpy are imported; presumably
# set_N_threads_ configures thread-related settings that must be in place at
# import time -- TODO confirm. Keep this ordering.
import CL_inference as cl_inference
N_threads = cl_inference.train_tools.set_N_threads_(N_threads=1)

import os, sys
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Apply the same thread count to PyTorch's intra-op and inter-op thread pools.
torch.set_num_threads(N_threads)
torch.set_num_interop_threads(N_threads)

# NOTE(review): the autoreload extension is loaded but no `%autoreload` mode is
# selected in this cell -- confirm whether `%autoreload 2` was intended.
%load_ext autoreload

# Interactive notebook backend; reset the style and close figures left over
# from a previous run of this cell.
%matplotlib notebook
plt.style.use('default')
plt.close('all')

# Project-wide font / rcParams defaults, then the colour-blind friendly palette.
font, rcnew = cl_inference.plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
# NOTE(review): this option belongs to the inline backend while
# `%matplotlib notebook` is selected above -- confirm it has any effect here.
%config InlineBackend.figure_format = 'retina'

# Torch device used for every model evaluation below (recorded output: cuda).
device = cl_inference.train_tools.set_torch_device_()

N_threads: 1
Device: cuda


# Select Runs

In [2]:
# Run selection: these three values are combined into the run-folder name in
# the next cell.

# kmax is used as a log10 exponent further down (it is passed to np.logspace).
kmax = 0.6
# Options: 0.6, 0.2, -0.2, -0.6, -1.0, -1.4

# Contrastive-learning tag; only used by the commented-out run-name
# alternatives in the next cell.
tmp_CL_str = "VICReg"
# Options: Wein, VICReg

# Tag of the model set the run was trained on.
tmp_dataset_str = "all"
# Options:
# all
# v1_v2, v1_v3, v2_v3
# f0_f1, f2_f3, f4_f5, f6_f7, f8_f9
# illustris_eagle, bahamas_illustris, eagle_bahamas

In [3]:
# Root folder holding one sub-folder per run family.
models_path = "/cosmos_storage/home/dlopez/Projects/CL_inference/models/"

# Suffix shared by every run-family name: model-set tag followed by kmax.
tmp_str = f"_models_{tmp_dataset_str}_kmax_{kmax!s}"

# Active run family. To evaluate another one, swap this assignment for one of
# the commented alternatives below.
main_name = "only_inference_also_baryons" + tmp_str
# main_name = "only" + "_inference"                                    + tmp_str
# main_name = "only" + "_CL_"                             + tmp_CL_str + tmp_str
# main_name = "only" + "_inference_CL_"                   + tmp_CL_str + tmp_str
# main_name = "join" + "_inference_CL_"                   + tmp_CL_str + tmp_str
# main_name = "join" + "_reload_inference_CL_"            + tmp_CL_str + tmp_str

# From here on models_path points at the selected run family's folder.
models_path = os.path.join(models_path, main_name)

# Check loss of runs & get config files of N_best_runs

In [4]:
# How many runs to keep; forwarded to get_config_files in the next cell
# (presumably ranked by loss, per the section title -- TODO confirm).
select_N_best_runs = 1

In [5]:
# Fetch the config of each selected run under models_path. The result is
# indexed by run name further down (configs.keys(), configs[name]["path_save"]).
# The recorded output shows this call also emits notebook figures.
configs = cl_inference.plot_utils.get_config_files(models_path, select_N_best_runs=select_N_best_runs)

<IPython.core.display.Javascript object>



<IPython.core.display.Javascript object>

# Reload Models

In [6]:
# Evaluation mode handed to reload_models.
# Known modes: "eval_CL", "eval_CL_and_inference", "eval_inference_supervised".
# Run families whose name contains "only_CL" are evaluated with 'eval_CL';
# every other family uses 'eval_CL_and_inference'.
evalute_mode = 'eval_CL' if "only_CL" in main_name else 'eval_CL_and_inference'

In [7]:
# Reload the trained networks of every run in `configs` onto `device`.
# Two collections come back (encoders and inference networks); the recorded log
# shows one of each being loaded per run.
models_encoder, models_inference = cl_inference.evaluation_tools.reload_models(
    models_path, evalute_mode, configs, device
)

Loaded model encoder: misty-sweep-21
Loaded model inference: misty-sweep-21


# Load datasets

In [8]:
# The first run in `configs` provides both the dataset normalisation and the
# output folder for everything saved below.
key0_dset = list(configs)[0]
sweep_name_load_norm_dset = key0_dset
config = configs[sweep_name_load_norm_dset]
save_root = config["path_save"]

In [9]:
# Wavenumber grid `kk` used as the x-axis of the example P(k) plots.
# Box side length; units are not stated here (presumably Mpc/h -- TODO confirm).
box = 2000
# Fundamental frequency of the box.
kf = 2.0 * np.pi / box
# Lower limit as a log10 exponent, matching kmax which is also a log10 exponent
# (both are passed straight to np.logspace).
kmin=np.log10(4*kf)
# NOTE(review): this divides a log10 range by a linear step (8*kf) -- confirm it
# reproduces the number of k-bins stored in the dataset.
N_kk = int((kmax - kmin) / (8*kf))
kk = np.logspace(kmin, kmax, num=N_kk)

In [10]:
# Configs may lack the 'include_baryon_params' key. Report the stored value
# when present; otherwise record an explicit False in the config so that every
# later consumer sees the same setting. An explicit membership test replaces
# the former bare `except:`, which would also have hidden unrelated errors.
if 'include_baryon_params' in config:
    print("include_baryon_params:", config['include_baryon_params'])
else:
    config['include_baryon_params'] = False
include_baryon_params = config['include_baryon_params']

# Per-parameter plot titles, axis limits and prior ranges for the chosen
# parameter set (with or without baryonic parameters).
custom_titles, limits_plots_inference, list_range_priors = cl_inference.plot_utils.get_titles_limits_and_priors(include_baryon_params)

include_baryon_params: True


In [11]:
# Models to evaluate: "Model_vary_all" followed by the model names stored in
# the run's config.
list_model_names = ["Model_vary_all"] + config["list_model_names"]
# Known model names:
# Model_vary_all
# Model_vary_1, Model_vary_2, Model_vary_3
# Model_fixed_0, Model_fixed_1, Model_fixed_2, Model_fixed_3, Model_fixed_4, Model_fixed_5, Model_fixed_6, Model_fixed_7, Model_fixed_8, Model_fixed_9
# Model_fixed_eagle, Model_fixed_illustris, Model_fixed_bahamas,

# Plot colours derived from the model names (mapping defined in plot_utils).
colors = cl_inference.plot_utils.colors_dsets(list_model_names)

In [12]:
# Build the TEST-split loader with the same preprocessing options the run was
# trained with; normalisation statistics are read from the run's save folder.
dsets = {}
dset_name = "TEST"
loader_kwargs = dict(
    path_load=os.path.join(config['path_load'], dset_name),
    list_model_names=list_model_names,
    normalize=config['normalize'],
    path_load_norm=os.path.join(config['path_save'], sweep_name_load_norm_dset),
    NN_augs_batch=config['NN_augs_batch'],
    add_noise_Pk=config['add_noise_Pk'],
    kmax=config['kmax'],
    include_baryon_params=include_baryon_params,
)
dsets[dset_name] = cl_inference.data_tools.def_data_loader(**loader_kwargs)

In [13]:
# fig, axs = cl_inference.plot_utils.theta_distrib_plot(dsets=dsets, custom_titles=custom_titles)
# fig.savefig(save_root + "/theta_distrib.png")

In [14]:
# Draw a small, reproducible subset of cosmologies for the example figures.
NN_plot = 5
np.random.seed(config["seed"])
indexes = np.random.choice(dsets[dset_name].NN_cosmos, NN_plot, replace=False)

# Evaluate the reloaded models on just those cosmologies.
xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key=dset_name,
    indexes_cosmo=indexes,
    use_all_dataset_augs_ordered=False,
)

In [15]:
# Example figures for the sampled cosmologies: input spectra, latent space and
# inferred parameters.
fig, axs, fig1, ax1, fig2, axs2 = cl_inference.plot_utils.plot_dataset_and_prediction_examples(
    dsets, dset_name, xx, hh, theta_true, theta_pred, Cov, list_model_names, len_models, colors, kk,
    custom_titles, limits_plots_inference, plot_as_Pk=True
)

# Tighten all three layouts before showing, then write each figure to save_root.
for figure in (fig, fig1, fig2):
    figure.set_tight_layout(True)
plt.show()
for figure, file_name in (
    (fig, "/example_Pk.png"),
    (fig1, "/example_latent.png"),
    (fig2, "/example_inference.png"),
):
    figure.savefig(save_root + file_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Evaluate every model in list_model_names on the complete TEST split, with the
# augmentations taken in order.
results_all_models = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST",
    use_all_dataset_augs_ordered=True,
)
xx, hh, theta_true, theta_pred, Cov, len_models = results_all_models

# Predicted vs. true parameters, drawn separately per model and coloured by
# model.
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names, len_models, theta_true, theta_pred, Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots_inference,
    colors=colors,
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(save_root + "/eval_inference_hydros.png")

<IPython.core.display.Javascript object>

In [None]:
# Same evaluation as the previous cell, restricted to "Model_vary_all".
# NOTE: this rebinds the notebook-wide `list_model_names`. Every cell executed
# after this one (and any earlier cell re-run afterwards) sees only
# "Model_vary_all".
list_model_names = ["Model_vary_all"]

# Full TEST split with the augmentations taken in order.
xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name_load_norm_dset,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST",
    use_all_dataset_augs_ordered=True
)
# Single model, so one fixed colour replaces the per-model `colors`.
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names,
    len_models,
    theta_true,
    theta_pred,
    Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots_inference,
    colors=["grey"]
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(save_root + "/eval_inference_all.png")

# Generate bias and errorbar figures

In [None]:
# Settings for the bias / error-bar statistics computed in the next cell.

# Thresholds at which the fraction of biased points is evaluated.
threshold_bias = np.linspace(0.5, 6, 20)

# Only the all-varying model is used for these statistics. The batch loop in
# the next cell reads `list_model_names` (plural), so that name is bound here;
# previously only the misspelled `list_model_name` was set and the loop silently
# depended on an earlier cell having reassigned `list_model_names`.
list_model_names = ["Model_vary_all"]
list_model_name = list_model_names  # legacy alias of the former (misspelled) name

# Shuffle the test cosmologies and cut them into NN_split equally sized groups.
# NOTE(review): the shuffle draws from the global NumPy RNG, which was last
# seeded in the example-subset cell; re-seed here if exact groups must be
# reproducible independently of the cells run in between.
NN_avail_cosmo_test = 2048
indexes_cosmos = np.arange(NN_avail_cosmo_test)
np.random.shuffle(indexes_cosmos)
NN_split = 20
# np.split yields NN_split full groups plus the remainder; the remainder is
# dropped by the trailing [:-1].
indexes_cosmos_groups = np.split(indexes_cosmos, (np.arange(NN_split)+1) * int(NN_avail_cosmo_test/NN_split))[:-1]

# Number of histogram bins for the bias distribution.
NN_bins_hist = 60

# Number of histogram bins for the error-bar distribution.
NN_bins_hist_err = 60

# Upper histogram limit for the error bars, one entry per inferred parameter.
max_err_hist = [0.05, 0.012, 0.12, 0.042, 0.06, 3.2, 1.5, 1.5, 3., .4, 1., 3.]

In [None]:
# Compute bias and error-bar histograms separately for each group of test
# cosmologies, collect them in per-batch arrays and save those to save_root.

# Index along axis 1 of the histogram outputs. The inputs are reshaped below so
# that axis 1 has length 1, hence slot 0.
ii_aug = 0

n_batches = len(indexes_cosmos_groups)
n_params = len(custom_titles)

# Per-batch accumulators. Histogram arrays carry two extra entries because the
# bin centres are padded with the outermost bin edges further down.
fraction_biased_batches = np.zeros((n_batches, n_params, len(threshold_bias)))
NN_points_batches = np.zeros((n_batches, n_params))
bin_centers_batches = np.zeros((n_batches, n_params, NN_bins_hist+2))
y_hists_batches = np.zeros((n_batches, n_params, NN_bins_hist+2))
# The error histograms use their own bin count (NN_bins_hist_err).
bin_centers_err_batches = np.zeros((n_batches, n_params, NN_bins_hist_err+2))
y_hists_err_batches = np.zeros((n_batches, n_params, NN_bins_hist_err+2))
median_err_batches = np.zeros((n_batches, n_params))
std_err_batches = np.zeros((n_batches, n_params))

# The loop target has its own name so the shuffled `indexes_cosmos` array from
# the settings cell is not rebound.
for ii, indexes_cosmos_batch in enumerate(indexes_cosmos_groups):
    xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
        config,
        sweep_name_load_norm_dset,
        list_model_names=list_model_names,
        models_encoder=models_encoder,
        models_inference=models_inference,
        device=device,
        dset_key="TEST",
        use_all_dataset_augs_ordered=False,
        indexes_cosmo=indexes_cosmos_batch
    )

    # Merge the two leading axes into one sample axis and add a singleton axis
    # at position 1, which is the slot ii_aug selects afterwards.
    theta_true = np.reshape(theta_true, (theta_true.shape[0]*theta_true.shape[1], theta_true.shape[-1]))[:,np.newaxis]
    theta_pred = np.reshape(theta_pred, (theta_pred.shape[0]*theta_pred.shape[1], theta_pred.shape[-1]))[:,np.newaxis]
    Cov = np.reshape(Cov, (Cov.shape[0]*Cov.shape[1], Cov.shape[-2], Cov.shape[-1]))[:,np.newaxis]

    NN_samples = xx.shape[0]*xx.shape[1]

    # --------------------------- compute_bias_hist_augs --------------------------- #

    bin_edges, bin_centers, y_hists, NN_points = cl_inference.plot_utils.compute_bias_hist_augs(
        theta_true, theta_pred, Cov, min_x=-6, max_x=6, bins=NN_bins_hist
    )

    # Pad the bin centres with the first and last bin edge so they have the same
    # length as y_hists (presumably y_hists includes under/overflow bins --
    # TODO confirm).
    bin_centers = np.insert(bin_centers, 0, bin_edges[..., 0], axis=-1)
    bin_centers = np.insert(bin_centers, bin_edges.shape[-1], bin_edges[..., -1], axis=-1)

    # Fraction of points whose |bias| lies beyond each threshold, per parameter.
    # The inner index is named ii_param (not `kk`) so the global wavenumber
    # array `kk` used by the example plots is left untouched.
    fraction_biased_list = np.zeros((y_hists.shape[0], len(threshold_bias)))
    for jj in range(len(threshold_bias)):
        mask = np.abs(bin_centers) > threshold_bias[jj]
        for ii_param in range(y_hists.shape[0]):
            fraction_biased_list[ii_param, jj] = (
                np.sum(y_hists[ii_param, ii_aug][mask[ii_param, ii_aug]]) / NN_points[ii_param, ii_aug]
            )
    fraction_biased_batches[ii] = fraction_biased_list

    bin_centers_batches[ii] = bin_centers[:,ii_aug]
    y_hists_batches[ii] = y_hists[:,ii_aug]
    NN_points_batches[ii] = NN_points[:,ii_aug]

    # --------------------------- compute_err_hist_augs --------------------------- #

    bin_edges_err, bin_centers_err, y_hists_err, median_err, std_err = cl_inference.plot_utils.compute_err_hist_augs(
        Cov, max_x=max_err_hist, bins=NN_bins_hist_err
    )

    # Same edge padding as for the bias histogram.
    bin_centers_err = np.insert(bin_centers_err, 0, bin_edges_err[..., 0], axis=-1)
    bin_centers_err = np.insert(bin_centers_err, bin_edges_err.shape[-1], bin_edges_err[..., -1], axis=-1)

    bin_centers_err_batches[ii] = bin_centers_err[:,ii_aug]
    y_hists_err_batches[ii] = y_hists_err[:,ii_aug]
    median_err_batches[ii] = median_err[:,ii_aug]
    std_err_batches[ii] = std_err[:,ii_aug]

# Persist every per-batch array next to the figures.
np.save(save_root + "/fraction_biased_batches.npy", fraction_biased_batches)
np.save(save_root + "/NN_points_batches.npy", NN_points_batches)
np.save(save_root + "/bin_centers_batches.npy", bin_centers_batches)
np.save(save_root + "/y_hists_batches.npy", y_hists_batches)
np.save(save_root + "/bin_centers_err_batches.npy", bin_centers_err_batches)
np.save(save_root + "/y_hists_err_batches.npy", y_hists_err_batches)
np.save(save_root + "/median_err_batches.npy", median_err_batches)
np.save(save_root + "/std_err_batches.npy", std_err_batches)

In [None]:
# Bias histograms: one panel per inferred parameter, one curve per batch.
fontsize = 26
fontsize1 = 18
fig, axs = plt.subplots(1, len(custom_titles), figsize=(5.2*len(custom_titles), 5.2))
axs[0].set_ylabel(r'Normalized Counts ', size=fontsize)
for ii_cosmo_param, panel_title in enumerate(custom_titles):
    ax = axs[ii_cosmo_param]
    ax.set_title(panel_title, size=fontsize+8, pad=16)
    ax.set_xlabel(r'Bias ', size=fontsize)
    # Dotted reference line at zero bias.
    ax.axvline(0, c='k', ls=':', lw=1)

    # Each batch's counts are divided by that batch's number of points.
    for ii_batch in range(y_hists_batches.shape[0]):
        batch_centers = bin_centers_batches[ii_batch, ii_cosmo_param]
        batch_counts = y_hists_batches[ii_batch, ii_cosmo_param]
        batch_points = NN_points_batches[ii_batch, ii_cosmo_param]
        ax.plot(batch_centers, batch_counts/batch_points, color='grey', lw=0.7, alpha=0.9)
fig.set_tight_layout(True)
fig.savefig(save_root + "/hist_bias.png")

In [None]:
# Fraction of biased points as a function of the threshold: one panel per
# inferred parameter, one thin curve per batch, plus the across-batch mean with
# its standard deviation as error bars.
fontsize=26
fontsize1=18
fig, axs = plt.subplots(1, len(custom_titles), figsize=(5.2*len(custom_titles), 5.2))
axs[0].set_ylabel(r'Fraction biased points ', size=fontsize)
for ii_cosmo_param in range(len(custom_titles)):
    # Select and configure the panel once per parameter. This used to happen
    # inside the per-batch loop, which repeated the setup for every batch and
    # left `ax` unset (or stale) for the summary markers when there were no
    # batches.
    ax = axs[ii_cosmo_param]
    ax.set_xlim([1.5, 5.])
    ax.set_ylim([0, 1])
    ax.set_title(custom_titles[ii_cosmo_param], size=fontsize+8, pad=16)
    ax.set_xlabel(r'$\sigma_\mathrm{thr.}$', size=fontsize)

    for jj in range(fraction_biased_batches.shape[0]):
        ax.plot(threshold_bias, fraction_biased_batches[jj, ii_cosmo_param], color='grey', lw=0.5, alpha=0.9)

    # Across-batch summary; the nan-aware reductions skip batches with no value.
    tmp_mean = np.nanmean(fraction_biased_batches[:, ii_cosmo_param], axis=0)
    tmp_std = np.nanstd(fraction_biased_batches[:, ii_cosmo_param], axis=0)
    ax.scatter(threshold_bias, tmp_mean, c='k', s=20)
    ax.errorbar(
        threshold_bias, tmp_mean,
        yerr=tmp_std,
        c='k', ls='', capsize=2, alpha=1., elinewidth=1.5
    )
fig.set_tight_layout(True)
fig.savefig(save_root + "/bias_vs_threshold.png")

In [None]:
# Error-bar histograms, with the x-axis expressed in units of the prior width:
# one panel per inferred parameter, one curve per batch, each batch's median
# marked by a vertical line and its median +/- std by a faint band.
fontsize = 26
fontsize1 = 18
fig, axs = plt.subplots(1, len(custom_titles), figsize=(5.2*len(custom_titles), 5.2))
axs[0].set_ylabel(r'Normalized Counts ', size=fontsize)
for ii_cosmo_param, panel_title in enumerate(custom_titles):
    ax = axs[ii_cosmo_param]
    ax.set_title(panel_title, size=fontsize+8, pad=16)
    ax.set_xlabel(r'$\frac{2\sigma}{\Delta \mathrm{Prior}}$', size=fontsize)
    for ii_batch in range(y_hists_batches.shape[0]):
        # Width of this parameter's prior interval.
        tmp_prior_range = (list_range_priors[ii_cosmo_param][1]-list_range_priors[ii_cosmo_param][0])

        batch_median = median_err_batches[ii_batch, ii_cosmo_param]
        batch_std = std_err_batches[ii_batch, ii_cosmo_param]

        ax.plot(
            bin_centers_err_batches[ii_batch, ii_cosmo_param]/tmp_prior_range,
            y_hists_err_batches[ii_batch, ii_cosmo_param]/NN_points_batches[ii_batch, ii_cosmo_param],
            color='grey', lw=0.5, alpha=0.9
        )
        ax.axvline(batch_median/tmp_prior_range, color='k', lw=0.5, alpha=0.9)
        ax.axvspan(
            (batch_median-batch_std)/tmp_prior_range,
            (batch_median+batch_std)/tmp_prior_range,
            color='k', alpha=0.01
        )

fig.set_tight_layout(True)
fig.savefig(save_root + "/hist_error.png")