# Imports and notebook configuration

In [None]:
import os, sys
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import CL_inference as cl_inference

%load_ext autoreload

%matplotlib notebook
plt.style.use('default')
plt.close('all')

font, rcnew = cl_inference.plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%config InlineBackend.figure_format = 'retina'

N_threads = cl_inference.train_tools.set_N_threads_(N_threads=1)
device = cl_inference.train_tools.set_torch_device_()

# Compare the losses of the different sweep runs

In [None]:
# Root directory holding one sub-directory per experiment (`main_name`).
# Override it with the CL_INFERENCE_MODELS_DIR environment variable. The trailing
# separator is enforced because later cells build paths as `save_path + main_name`.
save_path = os.path.join(
    os.environ.get("CL_INFERENCE_MODELS_DIR", "/cosmos_storage/dlopez/Projects/CL_inference/models/"),
    ""
)

In [None]:
# Default experiment name; the dataset-selection cell below overrides it.
main_name = (
    "only_inference_models_all"
    "_kmax_0.6_box_3000"
)

In [None]:
# main_name = "only_inference_models_all_kmax_-1.2_box_3000"

# main_name = "only_inference_models_v1v2_kmax_-1.2_box_3000"

# main_name = "only_CL_Wein_models_v1v2_kmax_-1.2_box_3000"
# main_name = "only_inference_CL_Wein_models_v1v2_kmax_-1.2_box_3000"

In [None]:
# Dataset combination of the experiment to evaluate. Pick one of:
#   "v1v2", "v1v3", "v2v3"
#   "f0f1", "f2f3", "f4f5", "f6f7", "f8f9"
#   "illustris_eagle", "bahamas_illustris", "eagle_bahamas"
tmp_dataset_str = "eagle_bahamas"

# Experiment prefix. Available prefixes: "only_inference_models_",
# "only_CL_Wein_models_", "only_inference_CL_Wein_models_".
experiment_prefix = "only_inference_CL_Wein_models_"
main_name = experiment_prefix + tmp_dataset_str + "_kmax_0.6_box_3000"

In [None]:
# Evaluation mode passed to reload_models: experiments whose name contains
# "only_CL" use 'eval_CL', every other experiment uses 'eval_CL_and_inference'.
# (Other available mode: 'eval_inference_supervised'.)
evalute_mode = 'eval_CL' if "only_CL" in main_name else 'eval_CL_and_inference'

In [None]:
# How many runs (ranked by lowest minimum validation loss) are kept for evaluation.
select_N_best_runs: int = 1

In [None]:
# Every sub-directory of the experiment whose name contains "sweep" is one run.
# Sorting makes the run order (and therefore colours and tie-breaking in the
# selection below) independent of the file system's listing order.
experiment_dir = os.path.join(save_path, main_name)
sweep_names = np.array(sorted(
    name for name in os.listdir(experiment_dir)
    if os.path.isdir(os.path.join(experiment_dir, name)) and "sweep" in name
))
assert len(sweep_names) > 0, "No sweep directories found in " + experiment_dir


def load_register_losses(sweep_name):
    """Load the loss register (register.txt) of one run as a 2-D array.

    Column 0 is plotted as the training loss and column 1 as the validation
    loss. ``ndmin=2`` keeps a single-epoch register indexable by column.
    """
    return np.loadtxt(os.path.join(experiment_dir, sweep_name, "register.txt"), ndmin=2)


custom_lines = [
    mpl.lines.Line2D([0], [0], color='grey', ls='-', lw=3, marker=None, markersize=9),
    mpl.lines.Line2D([0], [0], color='grey', ls='--', lw=3, marker=None, markersize=9)
]
fig, ax = cl_inference.plot_utils.simple_plot(
    custom_labels=[r'Train', r'Val'],
    custom_lines=custom_lines,
    x_label='Epoch',
    y_label='Loss'
)
ax.set_title(main_name, fontsize=16)

# Validation loss of every run, one colour per run; registers are cached so the
# selected runs below are not read from disk a second time.
colors = cl_inference.plot_utils.get_N_colors(len(sweep_names), mpl.colormaps['prism'])
losses_by_sweep = {}
min_loss = []
for ii, sweep_name in enumerate(sweep_names):
    print(sweep_name)
    losses = load_register_losses(sweep_name)
    losses_by_sweep[sweep_name] = losses
    ax.plot(losses[:, 1], c=colors[ii], lw=1, ls='--')
    # nanmin skips NaN entries of the validation column.
    min_loss.append(np.nanmin(losses[:, 1]))
min_loss = np.array(min_loss)

# ------------------------ select_N_best_runs ------------------------ #
# Keep the runs with the lowest minimum validation loss (argsort places NaN last).
sorted_sweeps_indexes = np.argsort(min_loss)
selected_sweeps = sweep_names[sorted_sweeps_indexes][:select_N_best_runs]

print("\n SELECTED SWEEPS\n")
for sweep_name in selected_sweeps:
    print(sweep_name)
    ax.plot(losses_by_sweep[sweep_name][:, 1], c='k', lw=1, ls='-')

custom_lines = [
    mpl.lines.Line2D([0], [0], color='k', ls='-', lw=3, marker=None, markersize=8)
    for _ in selected_sweeps
]
legend = ax.legend(custom_lines, selected_sweeps, loc='upper left',
                   fancybox=True, shadow=True, ncol=2, fontsize=7)
ax.add_artist(legend)

# Loss ranges differ between inference-only and contrastive-only experiments.
if "only_inference" in main_name:
    ax.set_ylim([-7.5, -3.5])
if "only_CL" in main_name:
    ax.set_ylim([0.1, 20])
    ax.set_yscale('log')

fig.set_tight_layout(True)
fig.savefig(os.path.join(experiment_dir, 'eval_loss.png'))

# Load the configuration of each selected run

In [None]:
# Read the config.yaml of each selected run. "manual-sweep-0" is read with the
# plain train_tools loader; every other run uses the wandb-format reader.
configs = {}
for sweep_name in selected_sweeps:
    print(sweep_name)
    path_to_config = save_path + main_name + "/" + sweep_name
    if sweep_name != "manual-sweep-0":
        configs[sweep_name] = cl_inference.evaluation_tools.load_config_file_wandb_format(
            path_to_config=path_to_config,
            config_file_name="config.yaml"
        )
    else:
        configs[sweep_name] = cl_inference.train_tools.load_config_file(
            path_to_config=path_to_config,
            config_file_name="config.yaml"
        )

In [None]:
# Train (solid) and validation (dashed) loss curves of the selected runs.
custom_lines = [
    mpl.lines.Line2D([0], [0], color='k', ls='-', lw=3, marker=None, markersize=9),
    mpl.lines.Line2D([0], [0], color='k', ls='--', lw=3, marker=None, markersize=9)
]

fig, ax = cl_inference.plot_utils.simple_plot(
    custom_labels=[r'Train', r'Val'],
    custom_lines=custom_lines,
    x_label='Epoch',
    y_label='Loss'
)

for sweep_name in selected_sweeps:
    path_to_register = os.path.join(save_path, main_name, sweep_name, "register.txt")
    # ndmin=2 keeps a single-epoch register indexable by column.
    losses = np.loadtxt(path_to_register, ndmin=2)
    ax.plot(losses[:, 0], c='k', lw=1, ls='-')
    ax.plot(losses[:, 1], c='k', lw=1, ls='--')

# Lay out and save once, after every run has been drawn, so no partially drawn
# figure is written. Each distinct output directory still gets the figure.
# NOTE(review): if "path_save" equals save_path/main_name, this overwrites the
# eval_loss.png written by the sweep-comparison cell; confirm this is intended.
fig.set_tight_layout(True)
for out_dir in dict.fromkeys(configs[run]["path_save"] for run in selected_sweeps):
    fig.savefig(out_dir + "/eval_loss.png")

### Check compatibility of config files

In [None]:
# Config keys whose values must be identical across all selected runs; the
# assertion below fails with the offending key and values otherwise.
list_assert_compatible_keys = [
    "normalize",
    "CL_loss",
    "NN_params_out",
    "NN_augs_batch",
    "add_noise_Pk",
    "boxsize_cosmic_variance",
    "inference_loss",
    "input_encoder",
    "kmax",
    "list_model_names",
    "load_encoder_model_path",
    "output_encoder",
    "output_projector",
    "path_load",
    "path_save",
    "seed_mode",
    "train_mode"
]

for key in list_assert_compatible_keys:
    values = [configs[run][key] for run in selected_sweeps]
    assert all(value == values[0] for value in values), (
        "key: " + key + ". Not all config files share the same value: " + str(values)
    )

# RELOAD MODELS

In [None]:
# Rebuild the encoder and inference models for the runs in `configs`.
reloaded = cl_inference.evaluation_tools.reload_models(
    save_path, main_name, evalute_mode, configs, device
)
models_encoder, models_inference = reloaded

# CHECK DATASET EMPLOYED FOR TRAINING

In [None]:
# LaTeX labels of the five inferred parameters and the [lower, upper] plotting
# range of each (presumably in theta column order; confirm against the data loader).
custom_titles = [
    r'$\Omega_\mathrm{c}$',
    r'$\Omega_\mathrm{b}$',
    r'$h$',
    r'$n_\mathrm{s}$',
    r'$\sigma_{8,\mathrm{c}}$',
]
limits_plots = [
    [0.23, 0.4],     # Omega_c
    [0.038, 0.062],  # Omega_b
    [0.60, 0.80],    # h
    [0.92, 1.01],    # n_s
    [0.73, 0.9],     # sigma_8,c
]

In [None]:
# Reference run: the first selected sweep. `sweep_name` is bound here explicitly
# instead of reusing the loop variable left over from earlier cells (which holds
# the *last* selected sweep), so `config` and the normalisation path read below
# refer to the same run.
sweep_name = list(configs.keys())[0]
config = configs[sweep_name]

# Models to load; alternatives used in this notebook are the "Model_vary_*",
# "Model_fixed_<i>" and "Model_fixed_<hydro>" lists of the evaluation cells.
list_model_names = config["list_model_names"]

kmax = config['kmax']

dset_name = "TEST"
loaded_theta, loaded_xx, len_models = cl_inference.data_tools.load_stored_data(
    path_load=os.path.join(config['path_load'], dset_name),
    list_model_names=list_model_names,
    return_len_models=True
)
dsets = {
    dset_name: cl_inference.data_tools.data_loader(
        loaded_theta,
        loaded_xx,
        normalize=config['normalize'],
        path_load_norm=os.path.join(config['path_save'], sweep_name),
        NN_augs_batch=np.sum(len_models),
        add_noise_Pk=config['add_noise_Pk'],
        kmax=kmax,
        boxsize_cosmic_variance=config['boxsize_cosmic_variance'],  # Mpc/h
    )
}

In [None]:
# Distribution of the parameters in the loaded data set(s); reuses the shared
# parameter labels instead of a duplicated literal.
fig, axs = cl_inference.plot_utils.theta_distrib_plot(
    dsets=dsets,
    custom_titles=custom_titles
)
fig.savefig(config["path_save"] + "/theta_distrib.png")

In [None]:
# Evaluate two hand-picked cosmologies of the first data set for the P(k),
# latent-space and prediction figures below.
NN_plot = 2
plot_as_Pk = True
dset_key = next(iter(dsets))
np.random.seed(config["seed"])
# Fixed choice; a random one would be
# np.random.choice(dsets[dset_key].NN_cosmos, NN_plot, replace=False).
indexes = np.array([137, 1024])

xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    config,
    sweep_name=sweep_name,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key=dset_key,
    indexes_cosmo=indexes,
    use_all_dataset_augs_ordered=False
)

In [None]:
# Undo the normalisation and the log10 to plot P(k); otherwise plot the raw
# (normalised) network inputs.
if plot_as_Pk:
    xx_plot = 10**(xx*dsets[dset_key].norm_std + dsets[dset_key].norm_mean)
else:
    xx_plot = xx

n_params = theta_pred.shape[-1]

# Figure 1: P(k) of every augmentation (top) and its ratio to the mean over all
# augmentations of the same cosmology (bottom).
fig, axs = plt.subplots(2, 1, figsize=(9, 9), gridspec_kw={'height_ratios': [1.5, 1]})
axs[0].set_ylabel(r'$P(k) \left[ \left(h^{-1} \mathrm{Mpc}\right)^{3} \right]$')
axs[1].set_ylabel(r'$P_{Model}(k) / P_{mean}(k)$')
axs[1].set_xlabel(r'$\mathrm{Wavenumber}\, k \left[ h\, \mathrm{Mpc}^{-1} \right]$')

# Figure 2: first two latent dimensions. Figure 3: predicted vs true parameters
# (squeeze=False keeps axs2 indexable even for a single parameter).
fig1, ax1 = cl_inference.plot_utils.simple_plot(x_label=r'Latent x [adim]', y_label=r'Latent y [adim]')
fig2, axs2 = plt.subplots(1, n_params, figsize=(5.2*n_params, 5.2), squeeze=False)
axs2 = axs2[0]

# k grid of the kept bins. Assumes 100 log-spaced bins of log10 k on [kmin, 0.6]
# (TODO confirm against the stored data); the dotted line marks kmax.
kmin = -2.3
N_kk = int(((kmax-kmin)/(0.6+2.3))*100)
if plot_as_Pk:
    kk = np.logspace(kmin, kmax, num=N_kk)
    k_cut = 10**kmax
else:
    kk = np.arange(xx_plot.shape[-1])
    N_kk = N_kk - 1
    k_cut = N_kk
for ax_pk in axs:
    ax_pk.axvline(k_cut, c='k', ls=':', lw=1.)

# Explicit palette for two-model data sets; fall back to evenly spaced colormap
# colours when there are more models than palette entries.
colors = ['#1F77B4', '#FF7F0E']
if len(len_models) > len(colors):
    colors = list(cl_inference.plot_utils.get_N_colors(len(len_models), mpl.colormaps['cool']))

linestyles = cl_inference.plot_utils.get_N_linestyles(NN_plot)
markers = cl_inference.plot_utils.get_N_markers(NN_plot)

# One prediction panel per parameter: title, identity line and [lower, upper] limits.
for ii_cosmo_param in range(n_params):
    lower, upper = limits_plots[ii_cosmo_param]
    ax_param = axs2[ii_cosmo_param]
    ax_param.set_title(custom_titles[ii_cosmo_param], size=26, pad=16)
    ax_param.set_xlabel(r'True ', size=26)
    ax_param.plot([lower, upper], [lower, upper], c='k', lw=2, ls='-', alpha=1)
    ax_param.set_xlim([lower, upper])
    ax_param.set_ylim([lower, upper])
axs2[0].set_ylabel(r'Pred ', size=26)

# Legend entries: one colour per data model, one line style per cosmology.
custom_lines = []
custom_labels = []
custom_lines1 = [
    mpl.lines.Line2D([0], [0], color='grey', ls=linestyles[ii_cosmo], lw=3, marker=None, markersize=8)
    for ii_cosmo in range(xx_plot.shape[0])
]
custom_labels1 = ["Cosmo #" + str(indexes[ii_cosmo]) for ii_cosmo in range(xx_plot.shape[0])]

ii_aug_column = 0
for ii_model_dataset, len_model in enumerate(len_models):
    color = colors[ii_model_dataset]
    custom_lines.append(mpl.lines.Line2D([0], [0], color=color, ls='-', lw=10, marker=None, markersize=8))
    custom_labels.append(list_model_names[ii_model_dataset])
    # Augmentation columns that belong to this data model.
    tmp_slice = slice(ii_aug_column, ii_aug_column + len_model)
    for ii_cosmo in range(xx_plot.shape[0]):
        xx_model = xx_plot[ii_cosmo, tmp_slice]
        axs[0].plot(
            np.array(kk), xx_model.T,
            c=color, linestyle=linestyles[ii_cosmo], lw=1., marker=None, ms=2, alpha=0.6
        )
        axs[1].plot(
            np.array(kk), (xx_model/np.mean(xx_plot[ii_cosmo], axis=0)).T,
            c=color, linestyle=linestyles[ii_cosmo], lw=1.1, marker=None, ms=2
        )
        # Distinct loop name so the notebook-level `sweep_name` is not overwritten.
        for selected_sweep in selected_sweeps:
            hh_model = hh[selected_sweep][ii_cosmo, tmp_slice]
            ax1.scatter(
                hh_model[..., 0], hh_model[..., 1],
                c=color, marker=markers[ii_cosmo], s=40
            )
        for ii_cosmo_param in range(theta_true.shape[-1]):
            tmp_theta_true = np.repeat(theta_true[ii_cosmo, ii_cosmo_param], len_model)
            tmp_theta_pred = theta_pred[ii_cosmo, tmp_slice, ii_cosmo_param]
            # Diagonal covariance element of this parameter; its sqrt is the error bar.
            tmp_Cov = Cov[ii_cosmo, tmp_slice, ii_cosmo_param, ii_cosmo_param]
            axs2[ii_cosmo_param].scatter(
                tmp_theta_true, tmp_theta_pred,
                color=color, marker=markers[ii_cosmo], s=40, alpha=1.
            )
            axs2[ii_cosmo_param].errorbar(
                tmp_theta_true, tmp_theta_pred,
                yerr=np.sqrt(tmp_Cov),
                c=color, ls='', capsize=2, alpha=1., elinewidth=1
            )
    ii_aug_column += len_model

legend = axs[0].legend(custom_lines, custom_labels, loc='upper right', fancybox=True, shadow=True, ncol=1, fontsize=14)
axs[0].add_artist(legend)
legend = axs[0].legend(custom_lines1, custom_labels1, loc='lower left', fancybox=True, shadow=True, ncol=2, fontsize=14)
axs[0].add_artist(legend)

if plot_as_Pk:
    axs[0].set_xscale('log')
    axs[0].set_yscale('log')
    axs[0].set_xlim([0.004, 4.5])
    axs[0].set_ylim([40., 70000.])
    axs[1].set_xscale('log')
    axs[1].set_xlim([0.004, 4.5])
    axs[1].set_ylim([0.8, 1.2])
else:
    axs[0].set_xlim([0., 100.])
    axs[0].set_ylim([-2.5, 2.5])
    axs[1].set_xlim([0., 100.])
    axs[1].set_ylim([0.8, 1.2])

axs[0].set_xticklabels([])

fig.set_tight_layout(True)
fig1.set_tight_layout(True)
fig2.set_tight_layout(True)

plt.show()

fig.savefig(configs[list(configs.keys())[0]]["path_save"] + "/Pk.png")

# EVALUATE vary AUGMENTATIONS

In [None]:
# TEST-set inference results split by the three "vary" augmentation models.
list_model_names = ["Model_vary_1", "Model_vary_2", "Model_vary_3"]
colors = ['#1F77B4', '#FF7F0E', '#2CA02C']

# Evaluate the same run whose config is used (first selected sweep), rather than
# the leftover notebook-level `sweep_name`.
reference_sweep = selected_sweeps[0]
reference_config = configs[reference_sweep]

xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    reference_config,
    sweep_name=reference_sweep,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST"
)
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names,
    len_models,
    theta_true,
    theta_pred,
    Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots,
    colors=colors
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(reference_config["path_save"] + "/eval_inference_vary.png")

# EVALUATE fixed AUGMENTATIONS

In [None]:
# TEST-set inference results split by the ten "fixed" augmentation models.
list_model_names = ["Model_fixed_" + str(ii_model) for ii_model in range(10)]
colors = list(cl_inference.plot_utils.get_N_colors(len(list_model_names), mpl.colormaps['cool']))

# Evaluate the same run whose config is used (first selected sweep), rather than
# the leftover notebook-level `sweep_name`.
reference_sweep = selected_sweeps[0]
reference_config = configs[reference_sweep]

xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    reference_config,
    sweep_name=reference_sweep,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST"
)
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names,
    len_models,
    theta_true,
    theta_pred,
    Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots,
    colors=colors
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(reference_config["path_save"] + "/eval_inference_fixed.png")

# EVALUATE Hydro AUGMENTATIONS

In [None]:
# TEST-set inference results split by the three hydrodynamical models.
list_model_names = ["Model_fixed_eagle", "Model_fixed_illustris", "Model_fixed_bahamas"]
colors = ['#D62728', '#9467BD', '#8C564B']

# Evaluate the same run whose config is used (first selected sweep), rather than
# the leftover notebook-level `sweep_name`.
reference_sweep = selected_sweeps[0]
reference_config = configs[reference_sweep]

xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    reference_config,
    sweep_name=reference_sweep,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST"
)
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names,
    len_models,
    theta_true,
    theta_pred,
    Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots,
    colors=colors
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(reference_config["path_save"] + "/eval_inference_hydros.png")

# EVALUATE ALL AUGMENTATIONS

In [None]:
# TEST-set inference results with all augmentations together ("Model_vary_all").
list_model_names = ["Model_vary_all"]
colors = ['grey']

# Evaluate the same run whose config is used (first selected sweep), rather than
# the leftover notebook-level `sweep_name`.
reference_sweep = selected_sweeps[0]
reference_config = configs[reference_sweep]

xx, hh, theta_true, theta_pred, Cov, len_models = cl_inference.evaluation_tools.compute_dataset_results(
    reference_config,
    sweep_name=reference_sweep,
    list_model_names=list_model_names,
    models_encoder=models_encoder,
    models_inference=models_inference,
    device=device,
    dset_key="TEST"
)
fig, axs = cl_inference.plot_utils.plot_inference_split_models(
    list_model_names,
    len_models,
    theta_true,
    theta_pred,
    Cov,
    custom_titles=custom_titles,
    limits_plots=limits_plots,
    colors=colors
)
fig.suptitle(main_name, size=18)
fig.set_tight_layout(True)
fig.savefig(reference_config["path_save"] + "/eval_inference_all.png")

# Generate the bias figure

In [None]:
# Bias histograms on the TEST set: one per model listed in the run's config
# ("self"), plus one for all augmentations together ("Model_vary_all").

def _flatten_augmentations(theta_true, theta_pred, Cov):
    """Merge the (cosmology, augmentation) axes into a single sample axis.

    theta_true (one row per cosmology) is repeated once per augmentation so its
    rows line up with the flattened predictions; theta_pred and Cov get a
    singleton axis 1, the layout passed to compute_bias_hist_augs here.
    """
    theta_true_flat = np.repeat(theta_true, theta_pred.shape[1], axis=0)
    theta_pred_flat = np.reshape(
        theta_pred, (theta_pred.shape[0]*theta_pred.shape[1], theta_pred.shape[-1])
    )[:, np.newaxis]
    Cov_flat = np.reshape(
        Cov, (Cov.shape[0]*Cov.shape[1], Cov.shape[-2], Cov.shape[-1])
    )[:, np.newaxis]
    return theta_true_flat, theta_pred_flat, Cov_flat


def _bias_histograms(model_names):
    """Evaluate the first selected run on `model_names` and histogram its bias.

    Returns (bin_edges, bin_centers, y_hists, NN_points, n_samples): the first
    four as returned by compute_bias_hist_augs, and n_samples equal to
    xx.shape[0] * xx.shape[1] of the evaluated inputs.
    """
    reference_sweep = selected_sweeps[0]
    xx, _, theta_true, theta_pred, Cov, _ = cl_inference.evaluation_tools.compute_dataset_results(
        configs[reference_sweep],
        sweep_name=reference_sweep,
        list_model_names=model_names,
        models_encoder=models_encoder,
        models_inference=models_inference,
        device=device,
        dset_key="TEST"
    )
    bin_edges, bin_centers, y_hists, NN_points = cl_inference.plot_utils.compute_bias_hist_augs(
        *_flatten_augmentations(theta_true, theta_pred, Cov), min_x=-6, max_x=6, bins=60
    )
    return bin_edges, bin_centers, y_hists, NN_points, xx.shape[0]*xx.shape[1]


assert len(config["list_model_names"]) > 0, "config['list_model_names'] is empty: no bias histograms to build"
per_model_results = [_bias_histograms([model_name]) for model_name in config["list_model_names"]]
# Concatenate the per-model outputs along axis 1 in one step.
bin_edges_self, bin_centers_self, y_hists_self, NN_points_self = (
    np.concatenate([result[idx] for result in per_model_results], axis=1) for idx in range(4)
)
NN_samples = [result[4] for result in per_model_results]

list_model_names = ["Model_vary_all"]
bin_edges_all, bin_centers_all, y_hists_all, NN_points_all, n_samples_all = _bias_histograms(list_model_names)
NN_samples.append(n_samples_all)
NN_samples = np.array(NN_samples)

In [None]:
def _append_edge_centers(centers, edges):
    """Return `centers` extended with the first and last bin edge on its last axis.

    The edges are cast to the dtype of `centers` first, so the extended curves
    start and end at the histogram range limits.
    """
    first_edge = edges[..., :1].astype(centers.dtype, copy=False)
    last_edge = edges[..., -1:].astype(centers.dtype, copy=False)
    return np.concatenate((first_edge, centers, last_edge), axis=-1)


bin_centers_self = _append_edge_centers(bin_centers_self, bin_edges_self)
bin_centers_all = _append_edge_centers(bin_centers_all, bin_edges_all)

In [None]:
# Join the per-model results with the "all augmentations" results along axis 1;
# the last entry along that axis is the "all" case.
y_final, bin_centers_final, NN_points_final = (
    np.concatenate((per_model_part, all_part), axis=1)
    for per_model_part, all_part in (
        (y_hists_self, y_hists_all),
        (bin_centers_self, bin_centers_all),
        (NN_points_self, NN_points_all),
    )
)

In [None]:
fontsize = 26
fontsize1 = 18

# Every curve except the last uses the palette (one colour per model); the last
# curve (all augmentations) is drawn in black. Fall back to colormap colours when
# there are more per-model curves than palette entries.
n_curves = y_final.shape[1]
colors = ['#1F77B4', '#FF7F0E', 'grey']
if n_curves - 1 > len(colors):
    colors = list(cl_inference.plot_utils.get_N_colors(n_curves - 1, mpl.colormaps['prism']))

fig, axs = plt.subplots(
    1, len(custom_titles), figsize=(5.2*len(custom_titles), 5.2)
)
axs[0].set_ylabel(r'Normalized Counts ', size=fontsize)

for ii_cosmo_param, ax in enumerate(axs):
    ax.set_title(custom_titles[ii_cosmo_param], size=fontsize+8, pad=16)
    ax.set_xlabel(r'Bias ', size=fontsize)
    ax.axvline(0, c='k', ls=':', lw=1)
    for ii_aug in range(n_curves):
        color = 'k' if ii_aug == n_curves - 1 else colors[ii_aug]
        # Each histogram is divided by its number of points.
        ax.plot(
            bin_centers_final[ii_cosmo_param, ii_aug],
            y_final[ii_cosmo_param, ii_aug]/NN_points_final[ii_cosmo_param, ii_aug],
            color=color, lw=3, alpha=0.9
        )

fig.set_tight_layout(True)
fig.savefig(configs[list(configs.keys())[0]]["path_save"] + "/bias.png")

In [None]:
# Fraction of points in each (parameter, curve) histogram falling in bins whose
# centre has |bias| > 2 (bias units as defined by compute_bias_hist_augs,
# presumably standard deviations of the prediction; TODO confirm).
mask = np.abs(bin_centers_final) > 2
n_rows, n_cols = y_final.shape[0], y_final.shape[1]
fraction_biased = np.zeros((n_rows, n_cols))
for ii, jj in np.ndindex(n_rows, n_cols):
    biased_counts = y_final[ii, jj][mask[ii, jj]]
    fraction_biased[ii, jj] = np.sum(biased_counts)/NN_points_final[ii, jj]

In [None]:
# Store the arrays behind the bias figure next to the run's other outputs.
output_dir = configs[list(configs.keys())[0]]["path_save"]
for file_suffix, array_to_save in (
    ("/y_final.npy", y_final),
    ("/bin_centers_final.npy", bin_centers_final),
    ("/fraction_biased.npy", fraction_biased),
    ("/NN_points_final.npy", NN_points_final),
):
    np.save(output_dir + file_suffix, array_to_save)