# IMPORTS

In [None]:
import CL_inference as cl_inference
# NOTE(review): the thread count is fixed before the remaining imports —
# presumably so the limit is in effect before torch/numpy create their
# thread pools; keep this ordering.
N_threads = cl_inference.train_tools.set_N_threads_(N_threads=1)

import os, sys
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

torch.set_num_threads(N_threads)
# The inter-op pool size can only be set once per process (torch raises a
# RuntimeError on a second call). Skip the call when the pool already has the
# requested size so that re-running this cell does not fail.
if torch.get_num_interop_threads() != N_threads:
    torch.set_num_interop_threads(N_threads)

%load_ext autoreload

%matplotlib notebook
plt.style.use('default')
plt.close('all')

font, rcnew = cl_inference.plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%config InlineBackend.figure_format = 'retina'

device = cl_inference.train_tools.set_torch_device_()

# SETUP - config file

In [None]:
# YAML experiment configuration; it is read from ../config_files in the next cell.
config_file_name = (
    "conf_only_inference_also_baryons_models_all_kmax_0.6.yaml"
)

In [None]:
# Parse the run configuration; later cells read every setting from it by key.
config = cl_inference.train_tools.load_config_file(
    config_file_name=config_file_name,
    path_to_config="../config_files",
)

# LOAD DATASETS

In [None]:
# Per-run output directory: used below for the normalisation files
# (path_save_norm / path_load_norm), training outputs and register.txt.
path_save = os.path.join(
    config['path_save'],
    "manual-sweep-0",
)

In [None]:
# Where the datasets live, which models to load and whether to normalise.
path_load = config['path_load']
list_model_names = config['list_model_names']
normalize = config['normalize']

# Per-batch sampling / preprocessing options passed to every data loader.
NN_augs_batch = config['NN_augs_batch']  # presumably augmentations drawn per cosmology — confirm
add_noise_Pk = config['add_noise_Pk']
kmax = config['kmax']  # used as a log10 exponent in the preview plot (np.logspace, 10**kmax)
include_baryon_params = config['include_baryon_params']

In [None]:
# Build the TRAIN and VAL loaders with identical options. The only difference
# is normalisation I/O: TRAIN gets path_save_norm=path_save and VAL gets
# path_load_norm=path_save. Presumably TRAIN computes and saves the statistics
# and VAL reuses them (TODO confirm in data_tools), so TRAIN must be built first.
dsets = {}
for dset_name, path_save_norm, path_load_norm in (
    ("TRAIN", path_save, None),
    ("VAL",   None,      path_save),
):
    dsets[dset_name] = cl_inference.data_tools.def_data_loader(
        path_load               = os.path.join(path_load, dset_name),
        list_model_names        = list_model_names,
        normalize               = normalize,
        path_save_norm          = path_save_norm,
        path_load_norm          = path_load_norm,
        NN_augs_batch           = NN_augs_batch,
        add_noise_Pk            = add_noise_Pk,
        kmax                    = kmax,
        include_baryon_params   = include_baryon_params
    )

In [None]:
# Quick visual check of a few TRAIN samples before training.
NN_plot = 5          # number of cosmologies drawn for the preview
plot_as_Pk = False   # True: undo normalisation and plot P(k) vs k; False: plot model inputs vs bin index
dset_plot = dsets["TRAIN"]
_, xx_plot, _ = dset_plot(batch_size=NN_plot, seed=0)

if plot_as_Pk:
    # Invert the normalisation: the stored inputs are standardised log10 P(k).
    xx_plot = 10**(xx_plot*dset_plot.norm_std + dset_plot.norm_mean)
    kmin = -2.3  # NOTE(review): assumed log10 lower edge of the k-grid — confirm against the data files
    # One k value per stored bin, so x and y always have matching lengths.
    # (The previous formula truncated a hard-coded 100-point fraction and
    # could disagree with the data length, making ax.plot raise.)
    N_kk = xx_plot.shape[-1]
    kk = np.logspace(kmin, kmax, num=N_kk)
else:
    kk = np.arange(xx_plot.shape[-1])

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
colors = cl_inference.plot_utils.get_N_colors(NN_plot, mpl.colormaps['prism'])
# xx_plot is indexed as [cosmology, augmentation, k-bin].
for ii_cosmo in range(xx_plot.shape[0]):
    for ii_aug in range(xx_plot.shape[1]):
        ax.plot(kk, xx_plot[ii_cosmo, ii_aug], c=colors[ii_cosmo], linestyle='-', lw=2, marker=None, ms=2)

if plot_as_Pk:
    ax.set_ylabel(r'$P(k) \left[ \left(h^{-1} \mathrm{Mpc}\right)^{3} \right]$')
    ax.set_xlabel(r'$\mathrm{Wavenumber}\, k \left[ h\, \mathrm{Mpc}^{-1} \right]$')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim([0.004, 4.5])
    ax.set_ylim([40., 70000])
    # Mark the k_max cut applied to the data.
    ax.axvline(10**kmax, c='k', ls=':', lw=1)
else:
    # Normalised inputs are dimensionless and indexed by bin, not by k.
    ax.set_ylabel(r'Standardised $\log_{10} P(k)$')
    ax.set_xlabel(r'$k$-bin index')
    ax.set_xlim([-1., 100.])
    ax.set_ylim([-2.5, 2.5])
    # Mark the last retained k-bin.
    ax.axvline(len(kk)-1, c='k', ls=':', lw=1)

plt.tight_layout()
plt.show()

# MODEL ARCHITECTURE

In [None]:
# Training mode and the loss used by the inference head.
train_mode = config['train_mode']
inference_loss = config['inference_loss']

# Encoder MLP: checkpoint path ('None' = do not load), input width, hidden widths, output width.
load_encoder_model_path = config['load_encoder_model_path']
input_encoder = config['input_encoder']
hidden_layers_encoder = config['hidden_layers_encoder']
output_encoder = config['output_encoder']

# Projector MLP; an empty hidden-layer list disables it.
hidden_layers_projector = config['hidden_layers_projector']
output_projector = config['output_projector']

# Inference MLP (empty hidden-layer list disables it) and number of predicted parameters.
hidden_layers_inference = config['hidden_layers_inference']
NN_params_out = config['NN_params_out']

# Optional checkpoints for the inference / projector heads ('None' = do not load).
load_inference_model_path = config['load_inference_model_path']
load_projector_model_path = config['load_projector_model_path']

In [None]:
# ----------------------- checkpoint helpers ----------------------- #

def _has_checkpoint(path):
    """Return True when `path` names a checkpoint that should be loaded.

    The config marks "no checkpoint" with the string 'None'; a YAML null
    (Python None) is treated the same way instead of crashing in torch.load.
    """
    return path not in (None, 'None')


def _load_weights(model, path):
    """Load the state_dict stored at `path` into `model` and switch it to eval mode.

    map_location=device remaps stored tensors onto the active device, so a
    checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    # (consider weights_only=True if the installed torch version supports it).
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

# ----------------------- define model encoder ----------------------- #

assert input_encoder == dsets["TRAIN"].xx.shape[-1], "input_encoder from config file must coincide with xx size"

# Only the fully-supervised mode requests last_bias=True; the other modes keep
# define_MLP_model's default for that argument.
encoder_extra_kwargs = dict(last_bias=True) if train_mode == "train_inference_fully_supervised" else {}
model_encoder = cl_inference.nn_tools.define_MLP_model(
    hidden_layers_encoder+[output_encoder], input_encoder, bn=True, **encoder_extra_kwargs
).to(device)
if _has_checkpoint(load_encoder_model_path):
    _load_weights(model_encoder, load_encoder_model_path)

# ----------------------- define model projector ----------------------- #

if len(hidden_layers_projector) != 0:
    model_projector = cl_inference.nn_tools.define_MLP_model(
        hidden_layers_projector+[output_projector], output_encoder, bn=True
    ).to(device)
    if _has_checkpoint(load_projector_model_path):
        _load_weights(model_projector, load_projector_model_path)
else:
    model_projector = None

# ----------------------- define model inference ----------------------- #

if len(hidden_layers_inference) != 0:
    if inference_loss == "MSE":
        # One output per predicted parameter.
        output_dim_inference = NN_params_out
    else:
        # NN_params_out outputs plus the n_tril entries of a lower-triangular
        # NN_params_out x NN_params_out matrix (presumably a Cholesky-style
        # covariance factor for the non-MSE loss — confirm in train_tools).
        n_tril = NN_params_out * (NN_params_out + 1) // 2
        output_dim_inference = NN_params_out + n_tril

    model_inference = cl_inference.nn_tools.define_MLP_model(
        hidden_layers_inference+[output_dim_inference], output_encoder, bn=True
    ).to(device)
    if _has_checkpoint(load_inference_model_path):
        _load_weights(model_inference, load_inference_model_path)
else:
    model_inference = None

# TRAIN

In [None]:
# Optimisation schedule.
NN_epochs = config['NN_epochs']
NN_batches_per_epoch = config['NN_batches_per_epoch']
batch_size = config['batch_size']

# Optimiser settings.
lr = config['lr']
weight_decay = config['weight_decay']
clip_grad_norm = config['clip_grad_norm']

# Seeding: seed is only relevant when seed_mode is 'overfit'.
seed_mode = config['seed_mode']
seed = config['seed']

# Loss settings, forwarded to train_model for both the train and val passes.
dict_loss = {
    'CL_loss': config['CL_loss'],
    'loss_hyperparameters': config['loss_hyperparameters'],
    'inference_loss': config['inference_loss'],
}
# Note: 'train' and 'val' share the very same dict object.
kwargs = {'train': dict_loss, 'val': dict_loss}

In [None]:
# Train the configured encoder / projector / inference models; outputs go to path_save.
min_val_loss = cl_inference.train_tools.train_model(
    dset_train=dsets["TRAIN"],
    train_mode=train_mode,
    model_encoder=model_encoder, model_projector=model_projector, model_inference=model_inference,
    NN_epochs=NN_epochs, NN_batches_per_epoch=NN_batches_per_epoch,
    lr=lr, weight_decay=weight_decay, clip_grad_norm=clip_grad_norm,
    batch_size=batch_size,
    dset_val=dsets["VAL"],
    # About one sixth of the VAL set per validation batch; floored at 1 so a
    # small VAL set (fewer than 6 samples) never yields a zero batch size.
    batch_size_val=max(1, dsets["VAL"].theta.shape[0] // 6),
    seed_mode=seed_mode,  # 'random', 'deterministic' or 'overfit'
    seed=seed,  # only relevant if mode is 'overfit'
    path_save=path_save,
    **kwargs
)

# PLOT LOSS

In [None]:
# Loss register read from path_save/register.txt (presumably written by
# train_model above); column 0 is plotted as Train and column 1 as Val.
# ndmin=2 keeps the array 2-D even when a single epoch has been logged;
# otherwise loadtxt returns 1-D and the column indexing below fails.
losses = np.loadtxt(os.path.join(path_save, 'register.txt'), ndmin=2)

# Legend handles: solid = train, dashed = validation.
custom_lines = [
    mpl.lines.Line2D([0], [0], color='k', ls='-', lw=3, marker=None, markersize=9),
    mpl.lines.Line2D([0], [0], color='k', ls='--', lw=3, marker=None, markersize=9)
]

fig, ax = cl_inference.plot_utils.simple_plot(
    custom_labels=[r'Train', r'Val'],
    custom_lines=custom_lines,
    x_label='Epoch',
    y_label='Loss'
)

ax.plot(losses[:, 0], c='k', lw=3, ls='-')
ax.plot(losses[:, 1], c='k', lw=3, ls='--')

plt.tight_layout()
plt.show()