In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib inline

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Identifier of the run handed to the training wrapper.
run_name = "09_DA"

# YAML configuration file, resolved relative to the project's config directory.
config_path = os.path.join(
    global_setup.path_configs,
    "config_DA.yaml",
)

In [None]:
# Call the project wrapper with the config file and run name, and unpack the
# nine values it returns (configs, datasets, both models, and the minimum
# validation loss — names as given by the unpack target below).
(
    config_0,
    config,
    dset_train,
    dset_val,
    config_encoder,
    model_encoder,
    config_downstream,
    model_downstream,
    min_val_loss,
) = wrapper_tools.wrapper_train_from_config(config_path, run_name)

In [None]:
# Plot the training curves from the save path recorded in the run's config.
path_save_training = config["training"]["path_save"]
plotting_utils.plot_training_curves(path_save_training)

In [None]:
# Class labels used for every evaluation plot below.
# Alternative label set kept for reference: ['QSO_high', 'no_QSO_high']
class_names = global_setup.class_names

# Evaluate on whatever device the encoder's first parameter lives on.
first_encoder_param = next(iter(model_encoder.parameters()))
device = first_encoder_param.device

In [None]:
def _predict_on_dataset(dset, encoder, downstream, device):
    """Draw one batch of size ``dset.NN_xx`` and run encoder + downstream head on it.

    Both models are switched to eval mode first so that layers with
    training-specific behaviour (dropout, batch-norm, ...) act as they should
    at inference time. This is a no-op if they are already in eval mode.
    NOTE(review): assumes both models are ``torch.nn.Module`` instances — confirm.

    Parameters
    ----------
    dset : callable dataset exposing ``NN_xx``
        Called with ``batch_size=dset.NN_xx`` (presumably the full set — confirm),
        ``seed=0`` and ``sampling_strategy="true_random"``.
    encoder, downstream : callables
        ``downstream(encoder(xx))`` yields the logits.
    device : torch.device
        Device on which the batch is created.

    Returns
    -------
    tuple
        ``(xx, yy_true, features, yy_pred_P, yy_pred, logits)`` where the first
        five entries are NumPy arrays and ``logits`` is the raw torch tensor.
        ``yy_pred_P`` is the softmax over ``dim=1`` and ``yy_pred`` its argmax.
    """
    encoder.eval()
    downstream.eval()

    xx, yy_true = dset(
        batch_size=dset.NN_xx,
        seed=0,
        sampling_strategy="true_random",
        to_torch=True,
        device=device,
    )

    # No gradients are needed for evaluation; keeping the softmax inside the
    # context avoids building any autograd state at all.
    with torch.no_grad():
        features = encoder(xx)
        logits = downstream(features)
        probs = torch.nn.functional.softmax(logits, dim=1)

    yy_pred_P = probs.cpu().numpy()
    yy_pred = np.argmax(yy_pred_P, axis=1)

    return (
        xx.cpu().numpy(),
        yy_true.cpu().numpy(),
        features.cpu().numpy(),
        yy_pred_P,
        yy_pred,
        logits,
    )


# === Train ===
(
    xx_train,
    yy_true_train,
    features_train,
    yy_pred_P_train,
    yy_pred_train,
    logits,
) = _predict_on_dataset(dset_train, model_encoder, model_downstream, device)

evaluation_tools.plot_confusion_matrix(
    yy_true_train, yy_pred_P_train,
    class_names=class_names,
    cmap=plt.cm.RdYlGn, title="train"
)


# === Val ===
# `logits` is rebound here, so after this cell it holds the validation logits.
(
    xx_val,
    yy_true_val,
    features_val,
    yy_pred_P_val,
    yy_pred_val,
    logits,
) = _predict_on_dataset(dset_val, model_encoder, model_downstream, device)

evaluation_tools.plot_confusion_matrix(
    yy_true_val, yy_pred_P_val,
    class_names=class_names,
    cmap=plt.cm.RdYlGn, title="val"
)


# === Compare Train VS Val ===
evaluation_tools.compare_TPR_confusion_matrices(
    yy_true_train,
    yy_pred_P_train,
    yy_true_val,
    yy_pred_P_val,
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='TPR Comparison: Val vs Train'
)

metrics = evaluation_tools.compare_sets_performance(
    yy_true_train, yy_pred_P_train,
    yy_true_val, yy_pred_P_val,
    class_names=class_names,
    name_1="train",
    name_2="val"
)