In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import wrapper_data_loaders
from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np

# Reproducibility: ask cuDNN for deterministic algorithms and switch off its
# benchmark-based autotuner, which may otherwise select different kernels per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Start from a clean slate: no open figures and matplotlib's stock style.
plt.close('all')
plt.style.use('default')

# Layer the project-wide font and rcParams defaults on top of the stock style.
font, rcnew = plotting_utils.matplotlib_default_config()
plt.rc('font', **font)
mpl.rcParams.update(rcnew)

# Applied last, so this style sheet wins wherever it overlaps with the settings above.
plt.style.use('tableau-colorblind10')
%matplotlib inline

from JPAS_DA.utils import aux_tools
# Fix the random seed through the project helper.
# NOTE(review): presumably seeds Python / NumPy / torch RNGs — confirm in aux_tools.
aux_tools.set_seed(42)

In [None]:
# Run directory names, resolved relative to global_setup.path_models when the
# checkpoints and config files are loaded below.
path_load_no_DA = "09_no_DA"
path_load_DA = "09_DA"

In [None]:
# Each run directory holds an encoder checkpoint and a downstream-head checkpoint.
run_dir_no_DA = os.path.join(global_setup.path_models, path_load_no_DA)
run_dir_DA = os.path.join(global_setup.path_models, path_load_DA)
model_factory = model_building_tools.create_mlp

# The loader returns a pair; only the second element (the model) is kept.
_, model_encoder_no_DA = save_load_tools.load_model_from_checkpoint(
    os.path.join(run_dir_no_DA, "model_encoder.pt"), model_factory
)
_, model_downstream_no_DA = save_load_tools.load_model_from_checkpoint(
    os.path.join(run_dir_no_DA, "model_downstream.pt"), model_factory
)

_, model_encoder_DA = save_load_tools.load_model_from_checkpoint(
    os.path.join(run_dir_DA, "model_encoder.pt"), model_factory
)
_, model_downstream_DA = save_load_tools.load_model_from_checkpoint(
    os.path.join(run_dir_DA, "model_downstream.pt"), model_factory
)

# Compare the two downstream heads with loose tolerances.
# NOTE(review): presumably a sanity check that the head is shared between runs — confirm.
_ = evaluation_tools.compare_model_parameters(
    model_downstream_no_DA,
    model_downstream_DA,
    rtol=1e-2,
    atol=1e-2,
)

In [None]:
# Every run directory stores its configuration next to the checkpoints.
config_file_no_DA = os.path.join(global_setup.path_models, path_load_no_DA, "config.yaml")
config_file_DA = os.path.join(global_setup.path_models, path_load_DA, "config.yaml")

# The helper returns a pair; only the second element (the config) is kept.
_, config_no_DA = wrapper_tools.load_and_massage_config_file(config_file_no_DA, path_load_no_DA)
_, config_DA = wrapper_tools.load_and_massage_config_file(config_file_DA, path_load_DA)

In [None]:
# The data pipeline is rebuilt from the DA run's configuration.
dict_data = config_DA["data"]

# Pull out the four mandatory sub-sections (a missing one raises KeyError).
data_paths, clean_opts, split_opts, feature_opts = (
    dict_data[section_name]
    for section_name in (
        'data_paths',
        'dict_clean_data_options',
        'dict_split_data_options',
        'features_labels_options',
    )
)
# Optional entry: falls back to None when the config does not provide it.
provided_normalization = dict_data.get('provided_normalization')

logging.info("🔧 Launching wrapper_data_loaders with loaded configuration...")

# Each keyword argument of wrapper_data_loaders is read verbatim from the
# sub-section that owns it; keys are listed in the order they are looked up.
loader_kwargs = {
    option_name: section[option_name]
    for section, option_names in (
        (data_paths, (
            'root_path',
            'load_JPAS_data',
            'load_DESI_data',
            'random_seed_load',
        )),
        (clean_opts, (
            'apply_masks',
            'mask_indices',
            'magic_numbers',
            'i_band_sn_threshold',
            'magnitude_flux_key',
            'magnitude_threshold',
            'z_lim_QSO_cut',
        )),
        (split_opts, (
            'train_ratio_both',
            'val_ratio_both',
            'test_ratio_both',
            'random_seed_split_both',
            'train_ratio_only_DESI',
            'val_ratio_only_DESI',
            'test_ratio_only_DESI',
            'random_seed_split_only_DESI',
        )),
        (feature_opts, (
            'define_dataset_loaders_keys',
            'keys_xx',
            'keys_yy',
            'normalization_source_key',
            'normalize',
        )),
    )
    for option_name in option_names
}
loader_kwargs['provided_normalization'] = provided_normalization

dset_loaders = wrapper_data_loaders.wrapper_data_loaders(**loader_kwargs)

# Splits used below: the no-DA model is validated on 'DESI_only', the DA model
# on 'JPAS_matched', and both are tested on the 'JPAS_matched' test split.
dset_val_no_DA = dset_loaders['DESI_only']["val"]
loaders_JPAS_matched = dset_loaders['JPAS_matched']
dset_val_DA = loaders_JPAS_matched["val"]
dset_test = loaders_JPAS_matched["test"]

In [None]:
def _draw_full_set(dset):
    """Sample `dset` once with the notebook's fixed sampling settings.

    Requests `dset.NN_xx` samples (presumably the whole split — confirm in the
    loader) with seed 0 and the "true_random" strategy, as CPU torch tensors.

    Returns:
        Whatever the loader call returns; the notebook unpacks it as
        `(inputs, labels)`.
    """
    return dset(batch_size=dset.NN_xx, seed=0, sampling_strategy="true_random", to_torch=True, device="cpu")


def _predict(encoder, downstream, inputs):
    """Run `inputs` through `encoder` and then `downstream` without gradients.

    Args:
        encoder: model mapping the inputs to latent features.
        downstream: model mapping latent features to class logits.
        inputs: batch accepted by `encoder`.

    Returns:
        Tuple of NumPy arrays `(features, class_probabilities, predicted_class)`.
        Probabilities are the softmax of the logits over axis 1; the predicted
        class is the argmax of those probabilities over axis 1.
    """
    # Inference must not run with dropout / batch-norm layers in training mode.
    # eval() is idempotent, so it is harmless if the checkpoint loader already
    # switched the models. NOTE(review): assumes both are torch.nn.Module.
    encoder.eval()
    downstream.eval()
    with torch.no_grad():
        features = encoder(inputs)
        probabilities = torch.nn.functional.softmax(downstream(features), dim=1)
    features_np = features.cpu().numpy()
    probabilities_np = probabilities.cpu().numpy()
    return features_np, probabilities_np, np.argmax(probabilities_np, axis=1)


# --- Validation split of the no-DA run, scored by the no-DA model ---
xx, yy_true = _draw_full_set(dset_val_no_DA)
yy_true_val_no_DA = yy_true.cpu().numpy()
features_val_no_DA, yy_pred_P_val_no_DA, yy_pred_val_no_DA = _predict(
    model_encoder_no_DA, model_downstream_no_DA, xx
)

# --- Validation split of the DA run, scored by the DA model ---
xx, yy_true = _draw_full_set(dset_val_DA)
yy_true_val_DA = yy_true.cpu().numpy()
features_val_DA, yy_pred_P_val_DA, yy_pred_val_DA = _predict(
    model_encoder_DA, model_downstream_DA, xx
)

# --- Test split, scored by both models ---
# The split is drawn once and fed to both models so that `yy_true_test` is
# guaranteed to be aligned with both sets of predictions, rather than relying
# on two separate draws returning the samples in the same order.
# `xx` / `yy_true` are left bound to this draw; the diagnostics cell prints them.
xx, yy_true = _draw_full_set(dset_test)
yy_true_test = yy_true.cpu().numpy()
features_test_no_DA, yy_pred_P_test_no_DA, yy_pred_test_no_DA = _predict(
    model_encoder_no_DA, model_downstream_no_DA, xx
)
features_test_DA, yy_pred_P_test_DA, yy_pred_test_DA = _predict(
    model_encoder_DA, model_downstream_DA, xx
)

In [None]:
# Inspect the raw inputs held by the test loader.
test_inputs = dset_test.xx
print(test_inputs.keys())
print(test_inputs['OBS'].shape)
print(test_inputs['MORPHTYPE_int'].shape)

# Inspect the loader's `means` sequence: its length and the shapes of its
# first and last entries.
test_means = dset_test.means
print(len(test_means))
print(test_means[0].shape)
print(test_means[-1].shape)

# `xx` / `yy_true` are left over from the last sampling call of the inference
# cell, i.e. the test split.
print(xx.shape)
print(yy_true.shape)

In [None]:
# Number of entries in the test loader's class-label list.
# NOTE(review): assumes `class_labels` holds exactly one entry per class — confirm.
n_classes = len(dset_test.class_labels)

In [None]:
# One confusion matrix per (split, model) pair: (title, true labels, class probabilities).
confusion_panels = (
    ("val no DA", yy_true_val_no_DA, yy_pred_P_val_no_DA),
    ("test no DA", yy_true_test, yy_pred_P_test_no_DA),
    ("val DA", yy_true_val_DA, yy_pred_P_val_DA),
    ("test DA", yy_true_test, yy_pred_P_test_DA),
)
for panel_title, labels_true, class_probabilities in confusion_panels:
    evaluation_tools.plot_confusion_matrix(
        labels_true,
        class_probabilities,
        class_names=global_setup.class_names,
        cmap=plt.cm.RdYlGn,
        title=panel_title,
    )

# Direct comparison of the two models on the same test labels. Kept as the
# cell's final bare expression so its return value is still the cell output.
evaluation_tools.compare_TPR_confusion_matrices(
    yy_true_test, yy_pred_P_test_no_DA,
    yy_true_test, yy_pred_P_test_DA,
    class_names=global_setup.class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='TPR Comparison: DA vs no DA',
    name_1="no DA",
    name_2="DA",
)

In [None]:
# Compare the no-DA and DA predictions against the same test labels; the first
# (labels, probabilities) pair is reported as "no DA", the second as "DA".
# The helper's return value is kept in `metrics` for later inspection.
metrics = evaluation_tools.compare_sets_performance(
    yy_true_test, yy_pred_P_test_no_DA,
    yy_true_test, yy_pred_P_test_DA,
    class_names=global_setup.class_names,
    name_1="no DA",
    name_2="DA"
)

In [None]:
# from sklearn.manifold import TSNE

# # === Stack all feature representations together ===
# n_val_no_DA = features_val_no_DA.shape[0]
# n_val_DA = features_val_DA.shape[0]
# n_test_no_DA = features_test_no_DA.shape[0]
# n_test_DA = features_test_DA.shape[0]

# X_all = np.vstack([
#     features_val_no_DA,
#     features_val_DA,
#     features_test_no_DA,
#     features_test_DA
# ])

# # === Perform shared t-SNE projection ===
# tsne = TSNE(n_components=2, perplexity=30, init='pca', random_state=42)
# X_all_tsne = tsne.fit_transform(X_all)

# # === Split back to original domains ===
# i0 = 0
# i1 = i0 + n_val_no_DA
# i2 = i1 + n_val_DA
# i3 = i2 + n_test_no_DA
# i4 = i3 + n_test_DA

# X_val_no_DA_tsne   = X_all_tsne[i0:i1]
# X_val_DA_tsne      = X_all_tsne[i1:i2]
# X_test_no_DA_tsne  = X_all_tsne[i2:i3]
# X_test_DA_tsne     = X_all_tsne[i3:i4]

In [None]:
# evaluation_tools.plot_tsne_comparison_single_pair(
#     X_val_no_DA_tsne, yy_true_val_no_DA,
#     X_test_no_DA_tsne, yy_true_test,
#     dset_test.class_counts,
#     class_names=global_setup.class_names,
#     title_set1="No DA - Validation",
#     title_set2="No DA - Test",
#     n_bins=128,
#     sigma=2.0,
#     scatter_size=1,
#     scatter_alpha=1.0
# )

# evaluation_tools.plot_tsne_comparison_single_pair(
#     X_val_DA_tsne, yy_true_val_DA,
#     X_test_DA_tsne, yy_true_test,
#     dset_test.class_counts,
#     class_names=global_setup.class_names,
#     title_set1="DA - Validation",
#     title_set2="DA - Test",
#     n_bins=128,
#     sigma=2.0,
#     scatter_size=1,
#     scatter_alpha=1.0
# )