In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import wrapper_data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np

import fitsio

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Sub-directories of global_setup.path_models holding the two trained runs:
# without domain adaptation (no_DA) and with domain adaptation (DA).
path_load_no_DA, path_load_DA = "09_no_DA", "09_DA"

In [None]:
def _load_encoder_and_downstream(run_dir):
    """Load the encoder and downstream-head checkpoints of one training run.

    Reads `<global_setup.path_models>/<run_dir>/model_encoder.pt` and `model_downstream.pt`,
    both rebuilt with `model_building_tools.create_mlp`, and returns
    `(model_encoder, model_downstream)` switched to eval mode for inference.
    """
    models = []
    for checkpoint_name in ("model_encoder.pt", "model_downstream.pt"):
        _, model = save_load_tools.load_model_from_checkpoint(
            os.path.join(global_setup.path_models, run_dir, checkpoint_name),
            model_building_tools.create_mlp,
        )
        # The mode the loader leaves the model in is not visible here; eval() makes inference
        # deterministic if the MLP contains Dropout / BatchNorm layers.
        model.eval()
        models.append(model)
    return tuple(models)


model_encoder_no_DA, model_downstream_no_DA = _load_encoder_and_downstream(path_load_no_DA)
model_encoder_DA, model_downstream_DA = _load_encoder_and_downstream(path_load_DA)

# Quick check of how much the downstream heads of the two runs differ.
_ = evaluation_tools.compare_model_parameters(model_downstream_no_DA, model_downstream_DA, rtol=1e-2, atol=1e-2)

In [None]:
# Training configs of both runs; only the second value returned by the loader is kept.
_, config_no_DA = wrapper_tools.load_and_massage_config_file(
    os.path.join(global_setup.path_models, path_load_no_DA, "config.yaml"),
    path_load_no_DA,
)
_, config_DA = wrapper_tools.load_and_massage_config_file(
    os.path.join(global_setup.path_models, path_load_DA, "config.yaml"),
    path_load_DA,
)

In [None]:
# All data settings are taken from the DA run's config.
# NOTE(review): the no-DA model is also evaluated on inputs prepared (and normalised) with these
# settings, which assumes both runs were trained with the same data config — confirm.
config_data = config_DA["data"]

# Raw-data location and which JPAS / DESI datasets to load.
tmp_key = "data_paths"
root_path, load_JPAS_data, load_DESI_data, random_seed_load = (
    config_data[tmp_key][option]
    for option in ("root_path", "load_JPAS_data", "load_DESI_data", "random_seed_load")
)

# Cleaning / masking options.
tmp_key = "dict_clean_data_options"
(apply_masks, mask_indices, magic_numbers, i_band_sn_threshold,
 magnitude_flux_key, magnitude_threshold, z_lim_QSO_cut) = (
    config_data[tmp_key][option]
    for option in ("apply_masks", "mask_indices", "magic_numbers", "i_band_sn_threshold",
                   "magnitude_flux_key", "magnitude_threshold", "z_lim_QSO_cut")
)

# Train/val/test ratios and seeds, for objects present in both surveys and for DESI-only objects.
tmp_key = "dict_split_data_options"
(train_ratio_both, val_ratio_both, test_ratio_both, random_seed_split_both,
 train_ratio_only_DESI, val_ratio_only_DESI, test_ratio_only_DESI, random_seed_split_only_DESI) = (
    config_data[tmp_key][option]
    for option in ("train_ratio_both", "val_ratio_both", "test_ratio_both", "random_seed_split_both",
                   "train_ratio_only_DESI", "val_ratio_only_DESI", "test_ratio_only_DESI",
                   "random_seed_split_only_DESI")
)

# Loaders to build, input feature keys, and the label / ID / flux columns carried alongside.
define_dataset_loaders_keys = ['DESI_only', "JPAS_matched"]
keys_xx = config_data["features_labels_options"]["keys_xx"]
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]
normalize = True  # not referenced in the visible cells; normalisation below is always applied
provided_normalization = config_data["provided_normalization"]

In [None]:
# ───────────────────────────────────────────────────── #
# 1. Load raw JPAS and DESI datasets
# ───────────────────────────────────────────────────── #
logging.info("\n\n1️⃣: Loading datasets from disk...")
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load
)

# ───────────────────────────────────────────────────── #
# 2. Apply cleaning and masking procedures
# ───────────────────────────────────────────────────── #
logging.info("\n\n2️⃣: Cleaning and masking data...")
DATA = cleaning_tools.clean_and_mask_data(
    DATA=DATA,
    apply_masks=apply_masks,
    mask_indices=mask_indices,
    magic_numbers=magic_numbers,
    i_band_sn_threshold=i_band_sn_threshold,
    magnitude_flux_key=magnitude_flux_key,
    magnitude_threshold=magnitude_threshold,
    z_lim_QSO_cut=z_lim_QSO_cut
)

# ───────────────────────────────────────────────────── #
# 3. Crossmatch JPAS and DESI using TARGETID
# ───────────────────────────────────────────────────── #
logging.info("\n\n3️⃣: Crossmatching JPAS and DESI TARGETIDs...")
Dict_LoA = {"both": {}, "only": {}}
IDs_only_DESI, IDs_only_JPAS, IDs_both, \
Dict_LoA["only"]["DESI"], Dict_LoA["only"]["JPAS"], \
Dict_LoA["both"]["DESI"], Dict_LoA["both"]["JPAS"] = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI"]['TARGETID'], DATA["JPAS"]['TARGETID']
)

# ───────────────────────────────────────────────────── #
# 4. Perform train/val/test splits
# ───────────────────────────────────────────────────── #
# The matched JPAS and DESI lists share ratios and seed so both surveys are split the same way.
logging.info("\n\n4️⃣: Splitting data into train/val/test...")
Dict_LoA_split = {"both": {}, "only": {}}

Dict_LoA_split["both"]["JPAS"] = process_dset_splits.split_LoA(
    Dict_LoA["both"]["JPAS"], train_ratio_both, val_ratio_both, test_ratio_both, seed=random_seed_split_both
)
Dict_LoA_split["both"]["DESI"] = process_dset_splits.split_LoA(
    Dict_LoA["both"]["DESI"], train_ratio_both, val_ratio_both, test_ratio_both, seed=random_seed_split_both
)
Dict_LoA_split["only"]["DESI"] = process_dset_splits.split_LoA(
    Dict_LoA["only"]["DESI"], train_ratio_only_DESI, val_ratio_only_DESI, test_ratio_only_DESI, seed=random_seed_split_only_DESI
)

# ───────────────────────────────────────────────────── #
# 5. Load data
# ───────────────────────────────────────────────────── #
# Only the evaluation splits are materialised: xx_dict[split][loader] holds the normalised,
# column-stacked features as a float32 CPU tensor; yy_dict[split][loader] the keys_yy columns.
logging.info("\n\n5️⃣: Load and normalize data...")

xx_dict = {}
yy_dict = {}
for key_dset in ["val", "test"]:
    xx_dict[key_dset] = {}
    yy_dict[key_dset] = {}
    logging.info(f"⚙️ Preparing split: {key_dset}")
    for key_loader in define_dataset_loaders_keys:
        logging.info(f"├── {key_loader}")
        if key_loader == "DESI_combined":
            LoA, xx, yy = process_dset_splits.extract_and_combine_DESI_data(
                Dict_LoA_split["only"]["DESI"][key_dset], Dict_LoA_split["both"]["DESI"][key_dset], DATA["DESI"], keys_xx, keys_yy
            )
        elif key_loader == "DESI_only":
            LoA, xx, yy = process_dset_splits.extract_data_using_LoA(
                Dict_LoA_split["only"]["DESI"][key_dset], DATA["DESI"], keys_xx, keys_yy
            )
        elif key_loader == "DESI_matched":
            LoA, xx, yy = process_dset_splits.extract_data_using_LoA(
                Dict_LoA_split["both"]["DESI"][key_dset], DATA["DESI"], keys_xx, keys_yy
            )
        elif key_loader == "JPAS_matched":
            LoA, xx, yy = process_dset_splits.extract_data_using_LoA(
                Dict_LoA_split["both"]["JPAS"][key_dset], DATA["JPAS"], keys_xx, keys_yy
            )
        else:
            # Without this, an unknown key would silently reuse xx/yy from the previous iteration.
            raise ValueError(
                f"Unknown dataset loader key {key_loader!r}; expected one of "
                "'DESI_combined', 'DESI_only', 'DESI_matched', 'JPAS_matched'."
            )
        # For every feature group: subtract provided_normalization[0][ii], divide by
        # provided_normalization[1][ii], flatten to (n_samples, n_columns), then stack columns.
        # NOTE(review): relies on the iteration order of `xx` matching the order of the
        # normalisation entries — confirm.
        xx_stacked = np.concatenate([
            np.atleast_2d((xx[kk] - provided_normalization[0][ii]) / provided_normalization[1][ii]).reshape(xx[kk].shape[0], -1)
            for ii, kk in enumerate(xx)
        ], axis=1)

        # Store as torch tensor
        xx_dict[key_dset][key_loader] = torch.tensor(xx_stacked, dtype=torch.float32, device="cpu")
        yy_dict[key_dset][key_loader] = yy


In [None]:
def _predict_split(model_encoder, model_downstream, dset, loader, prefix):
    """Run encoder + downstream head on xx_dict[dset][loader] and store the results in yy_dict.

    Writes '<prefix>_features', '<prefix>_pred_Probabilities' (softmax over classes) and
    '<prefix>_pred_labels' (argmax) into yy_dict[dset][loader], and returns
    (true SPECTYPE_int labels, predicted probabilities, predicted labels).
    """
    # The models' train/eval mode is not visible in this cell; eval() makes inference
    # deterministic if the MLPs contain Dropout / BatchNorm layers.
    model_encoder.eval()
    model_downstream.eval()
    with torch.no_grad():
        features = model_encoder(xx_dict[dset][loader])
        logits = model_downstream(features)
    probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

    yy = yy_dict[dset][loader]
    yy[f"{prefix}_features"] = features.cpu().numpy()
    yy[f"{prefix}_pred_Probabilities"] = probabilities
    yy[f"{prefix}_pred_labels"] = np.argmax(probabilities, axis=1)
    return yy['SPECTYPE_int'], yy[f"{prefix}_pred_Probabilities"], yy[f"{prefix}_pred_labels"]


# Model trained without domain adaptation: DESI-mock validation set and JPAS test set.
yy_true_no_DA_val, yy_pred_P_no_DA_val, yy_pred_no_DA_val = _predict_split(
    model_encoder_no_DA, model_downstream_no_DA, "val", "DESI_only", "no_DA"
)
yy_true_no_DA_test, yy_pred_P_no_DA_test, yy_pred_no_DA_test = _predict_split(
    model_encoder_no_DA, model_downstream_no_DA, "test", "JPAS_matched", "no_DA"
)

# Model trained with domain adaptation: JPAS validation and test sets.
yy_true_DA_val, yy_pred_P_DA_val, yy_pred_DA_val = _predict_split(
    model_encoder_DA, model_downstream_DA, "val", "JPAS_matched", "DA"
)
yy_true_DA_test, yy_pred_P_DA_test, yy_pred_DA_test = _predict_split(
    model_encoder_DA, model_downstream_DA, "test", "JPAS_matched", "DA"
)

In [None]:
# Confusion matrices, in order: no-DA model on the DESI-mock validation set, no-DA model on
# the JPAS test set, DA model on the JPAS test set.
for true_labels, pred_probs, cm_title in (
    (yy_true_no_DA_val, yy_pred_P_no_DA_val, "Validation no-DA"),
    (yy_true_no_DA_test, yy_pred_P_no_DA_test, "Test no-DA"),
    (yy_true_DA_test, yy_pred_P_DA_test, "Test DA"),
):
    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        true_labels, pred_probs, class_names=global_setup.class_names, cmap=plt.cm.RdYlGn, title=cm_title
    )

# Pairwise comparisons (TPR-difference matrix, then summary metrics / F1):
#   1. no-DA model, mock validation vs JPAS test  -> effect of the mock-to-JPAS domain shift;
#   2. no-DA vs DA model, both on the JPAS test set -> effect of domain adaptation.
for true_1, probs_1, true_2, probs_2, tpr_title, label_1, label_2 in (
    (yy_true_no_DA_val, yy_pred_P_no_DA_val, yy_true_no_DA_test, yy_pred_P_no_DA_test,
     'Performance lost no-DA -- Validation (mocks) VS Test (JPAS)', "Val. Mock", "Test JPAS"),
    (yy_true_no_DA_test, yy_pred_P_no_DA_test, yy_true_DA_test, yy_pred_P_DA_test,
     'Performance no-DA VS DA (Tests JPAS spectra)', "No-DA", "With DA"),
):
    evaluation_tools.compare_TPR_confusion_matrices(
        true_1, probs_1, true_2, probs_2,
        class_names=global_setup.class_names, figsize=(10, 7), cmap='seismic',
        title=tpr_title, name_1=label_1, name_2=label_2
    )
    metrics, F1_1, F1_2 = evaluation_tools.compare_sets_performance(
        true_1, probs_1, true_2, probs_2,
        class_names=global_setup.class_names, name_1=label_1, name_2=label_2
    )

In [None]:
# Reference classifications from an external JPAS x DESI-DR1 catalogue, with per-class
# confidences from two classifiers (CBM and TRANS) to compare against.
# The default absolute path is machine-specific; set JPAS_IGNASI_CATALOG to override it.
path_JPAS_Ignasi = os.environ.get(
    "JPAS_IGNASI_CATALOG",
    "/home/dlopez/Documentos/0.profesional/Postdoc/USP/Projects/JPAS_Domain_Adaptation/DATA/jpas_idr_classification_xmatch_desi_dr1.fits.gz",
)
JPAS_Ignasi = fitsio.read(path_JPAS_Ignasi)
# Every quantity below is restricted to the rows flagged as present in DESI DR1.
mask_in_desi_Ignasi = JPAS_Ignasi["is_in_desi_dr1"]

# Split QSOs into low/high redshift. A dedicated name keeps the config value `z_lim_QSO_cut`
# (used by the cleaning step) from being silently overwritten by this cell.
z_lim_QSO_cut_Ignasi = 2.1
if z_lim_QSO_cut_Ignasi != z_lim_QSO_cut:
    logging.warning(
        f"QSO redshift cut used for the reference catalogue ({z_lim_QSO_cut_Ignasi}) differs from "
        f"the config value ({z_lim_QSO_cut}); true labels may not be comparable."
    )

spectype_Ignasi = np.array(JPAS_Ignasi['SPECTYPE'][mask_in_desi_Ignasi]).astype(np.str_)
REDSHIFT = np.array(JPAS_Ignasi['z'][mask_in_desi_Ignasi])
# QSOs with z < cut -> "QSO_low", otherwise (including NaN z) -> "QSO_high"; other types unchanged.
spectype_Ignasi = np.where(
    spectype_Ignasi == "QSO",
    np.where(REDSHIFT < z_lim_QSO_cut_Ignasi, "QSO_low", "QSO_high"),
    spectype_Ignasi,
)
yy_true_Ignasi, class_mapping = cleaning_tools.encode_strings_to_integers(spectype_Ignasi.tolist())
# NOTE(review): the probability columns below are ordered (gal, hqso, lqso, star); they only line
# up with these integer labels (and with global_setup.class_names) if class_mapping encodes the
# classes in that same order — confirm.
logging.info(f"Reference-catalogue class mapping: {class_mapping}")

# Per-classifier confidence matrix of shape (n_objects, 4); column order follows the key lists.
classification_keys = {
    "TRANS" : ['conf_gal_TRANS', 'conf_hqso_TRANS', 'conf_lqso_TRANS', 'conf_star_TRANS'],
    "CBM"   : ['conf_gal_CBM', 'conf_hqso_CBM', 'conf_lqso_CBM', 'conf_star_CBM']
}
yy_pred_P = {
    method: np.column_stack([np.array(JPAS_Ignasi[column][mask_in_desi_Ignasi]) for column in columns])
    for method, columns in classification_keys.items()
}

# Match our JPAS test objects to the reference catalogue by TARGETID. A distinct name is used
# for the matched IDs so `IDs_both` from the main JPAS/DESI crossmatch is not overwritten.
TARGETID_Ignasi = np.array(JPAS_Ignasi["TARGETID"][mask_in_desi_Ignasi])
IDs_only_1, IDs_only_2, IDs_both_Ignasi, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
    yy_dict["test"]["JPAS_matched"]['TARGETID'],
    TARGETID_Ignasi,
)
idxs_both_me = np.concatenate(idxs_both_1)  # rows in our test set (not used further in the visible cells)
idxs_both_Ignasi = np.concatenate(idxs_both_2)  # rows in the DESI-DR1-restricted reference arrays

yy_true_Ignasi_crossmatch = yy_true_Ignasi[idxs_both_Ignasi]
yy_pred_P_Ignasi_crossmatch_CBM = yy_pred_P["CBM"][idxs_both_Ignasi]
yy_pred_P_Ignasi_crossmatch_TRANS = yy_pred_P["TRANS"][idxs_both_Ignasi]

In [None]:
# Confusion matrices of the two reference classifiers on the objects shared with our test set.
for pred_probs, cm_title in (
    (yy_pred_P_Ignasi_crossmatch_CBM, "CBM"),
    (yy_pred_P_Ignasi_crossmatch_TRANS, "TRANS"),
):
    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_Ignasi_crossmatch, pred_probs,
        class_names=global_setup.class_names, cmap=plt.cm.RdYlGn, title=cm_title
    )

In [None]:
# R-band magnitudes (from DESI fluxes) for the (split, loader) combinations analysed below.
key_pairs = [("test", "JPAS_matched"), ("val", "DESI_only")]
mag_dict = {}
for key_dset, key_loader in key_pairs:
    flux_R = np.asarray(yy_dict[key_dset][key_loader]['DESI_FLUX_R'])
    # Allocate a floating output even if fluxes are stored as integers (np.full_like on an
    # integer array cannot hold NaN); floating inputs keep their own precision.
    magnitude_R = np.full(flux_R.shape, np.nan, dtype=np.result_type(flux_R.dtype, np.float32))
    # Non-positive fluxes have no defined magnitude and stay NaN.
    valid_flux = flux_R > 0
    # NOTE(review): 22.5 - 2.5*log10(flux) is the AB zero point for fluxes in nanomaggies —
    # assumes DESI_FLUX_R uses that unit; confirm.
    magnitude_R[valid_flux] = 22.5 - 2.5 * np.log10(flux_R[valid_flux])
    mag_dict[(key_dset, key_loader)] = magnitude_R

# Overall magnitude range across the selected sets (informational; not used in the visible cells).
all_mags = np.concatenate([v[np.isfinite(v)] for v in mag_dict.values()])
min_mag, max_mag = np.nanmin(all_mags), np.nanmax(all_mags)

# Magnitude bins used throughout the per-bin evaluation, with one colour / colormap per bin.
magnitude_ranges = [(17, 19), (19, 21), (21, 22), (22, 22.5)]
colors = ['blue', 'green', 'orange', 'red']
colormaps = [
    plt.cm.Blues,
    plt.cm.Greens,
    plt.cm.YlOrBr,
    plt.cm.Reds
]

masks_all = plotting_utils.plot_histogram_with_ranges_multiple(
    mag_dict, ranges=magnitude_ranges, colors=colors, bins=42,
    x_label="DESI Magnitude (R)",
    title="DESI R-band Magnitudes by Dataset Split and Loader"
)

# Turn the per-range boolean masks into one integer bin index per object: the position of the
# range in `magnitude_ranges`, or -1 if the object falls in no range. If ranges overlap at an
# edge, the later range wins (it is assigned last).
bin_index_dict = {}
for key in masks_all.keys():
    n_samples = len(next(iter(masks_all[key].values())))  # length from first mask
    bin_indices = np.full(n_samples, -1, dtype=int)
    for bin_id, mag_range in enumerate(magnitude_ranges):
        mask = masks_all[key][mag_range]
        bin_indices[mask] = bin_id
    bin_index_dict[key] = bin_indices

# Attach the bin index as a new column of the corresponding yy_dict entry.
for key in bin_index_dict:
    key_dset, key_loader = key
    yy_dict[key_dset][key_loader]['MAG_BIN_ID'] = bin_index_dict[key]

In [None]:
# Shared y-limits for the ΔF1 panels produced by compare_sets_performance.
y_min_Delta_F1 = -0.6
y_max_Delta_F1 = 0.6

# Per-class F1 scores for each model, one entry appended per magnitude bin (in
# `magnitude_ranges` order).
F1_scores_per_bin = {
    "Val. Mock.": [],
    "DA": [],
    "TRANS": [],
    "CBM": []
}


def _mag_tag(mag_range):
    """Format a magnitude range as used in the figure titles, e.g. 'Mag: (17, 19)'."""
    return f"Mag: ({mag_range[0]}, {mag_range[1]})"


# Reference-catalogue TARGETIDs restricted to DESI DR1 objects; loop-invariant, so extracted once.
TARGETID_Ignasi = np.array(JPAS_Ignasi["TARGETID"][JPAS_Ignasi["is_in_desi_dr1"]])

for ii, mag_range in enumerate(magnitude_ranges):
    mag_tag = _mag_tag(mag_range)
    mask_val_DESI = yy_dict["val"]["DESI_only"]['MAG_BIN_ID'] == ii
    mask_test_JPAS = yy_dict["test"]["JPAS_matched"]['MAG_BIN_ID'] == ii

    # no-DA model on the DESI-mock validation set
    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_no_DA_val[mask_val_DESI], yy_pred_P_no_DA_val[mask_val_DESI],
        class_names=global_setup.class_names, cmap=colormaps[ii],
        title=f"Validation no-DA. {mag_tag}. #Obj.: {np.sum(mask_val_DESI)}"
    )

    # DA model on the JPAS test set
    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_DA_test[mask_test_JPAS], yy_pred_P_DA_test[mask_test_JPAS],
        class_names=global_setup.class_names, cmap=colormaps[ii],
        title=f"Test DA. {mag_tag}. #Obj.: {np.sum(mask_test_JPAS)}"
    )

    # no-DA model: mock validation vs JPAS test (domain-shift effect)
    evaluation_tools.compare_TPR_confusion_matrices(
        yy_true_no_DA_val[mask_val_DESI], yy_pred_P_no_DA_val[mask_val_DESI], yy_true_no_DA_test[mask_test_JPAS], yy_pred_P_no_DA_test[mask_test_JPAS],
        class_names=global_setup.class_names, figsize=(10, 7), cmap='seismic',
        title=f"Performance lost no-DA -- Validation (mocks) VS Test (JPAS). {mag_tag}",
        name_1 = "Val. Mock", name_2 = "Test JPAS"
    )
    metrics, F1_1, F1_2 = evaluation_tools.compare_sets_performance(
        yy_true_no_DA_val[mask_val_DESI], yy_pred_P_no_DA_val[mask_val_DESI], yy_true_no_DA_test[mask_test_JPAS], yy_pred_P_no_DA_test[mask_test_JPAS],
        class_names=global_setup.class_names, name_1="Val. Mock.", name_2="Test JPAS", plot_ROC_curves=False, y_min_Delta_F1=y_min_Delta_F1, y_max_Delta_F1=y_max_Delta_F1
    )
    F1_scores_per_bin["Val. Mock."].append(F1_1)

    # no-DA vs DA model, both on the JPAS test set
    evaluation_tools.compare_TPR_confusion_matrices(
        yy_true_no_DA_test[mask_test_JPAS], yy_pred_P_no_DA_test[mask_test_JPAS], yy_true_DA_test[mask_test_JPAS], yy_pred_P_DA_test[mask_test_JPAS],
        class_names=global_setup.class_names, figsize=(10, 7), cmap='seismic',
        title=f"Performance no-DA VS DA (Tests JPAS spectra). {mag_tag}",
        name_1 = "No-DA", name_2 = "With DA"
    )
    metrics, F1_1, F1_2 = evaluation_tools.compare_sets_performance(
        yy_true_no_DA_test[mask_test_JPAS], yy_pred_P_no_DA_test[mask_test_JPAS], yy_true_DA_test[mask_test_JPAS], yy_pred_P_DA_test[mask_test_JPAS],
        class_names=global_setup.class_names, name_1="No-DA", name_2="With DA", plot_ROC_curves=False, y_min_Delta_F1=y_min_Delta_F1, y_max_Delta_F1=y_max_Delta_F1
    )
    F1_scores_per_bin["DA"].append(F1_2)

    # Reference classifiers, restricted to the test objects of this magnitude bin.
    IDs_only_1, IDs_only_2, IDs_both, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
        yy_dict["test"]["JPAS_matched"]['TARGETID'][mask_test_JPAS],
        TARGETID_Ignasi
    )
    idxs_both_Ignasi = np.concatenate(idxs_both_2)
    yy_true_Ignasi_crossmatch = yy_true_Ignasi[idxs_both_Ignasi]
    yy_pred_P_Ignasi_crossmatch_CBM = yy_pred_P["CBM"][idxs_both_Ignasi]
    yy_pred_P_Ignasi_crossmatch_TRANS = yy_pred_P["TRANS"][idxs_both_Ignasi]

    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_Ignasi_crossmatch, yy_pred_P_Ignasi_crossmatch_CBM,
        class_names=global_setup.class_names, cmap=colormaps[ii],
        title=f"CBM. {mag_tag}. #Obj.: {len(yy_pred_P_Ignasi_crossmatch_CBM)}"
    )
    metrics, F1_1, F1_2 = evaluation_tools.compare_sets_performance(
        yy_true_no_DA_val[mask_val_DESI], yy_pred_P_no_DA_val[mask_val_DESI], yy_true_Ignasi_crossmatch, yy_pred_P_Ignasi_crossmatch_CBM,
        class_names=global_setup.class_names, name_1="Val. Mock.", name_2="CBM", plot_ROC_curves=False, y_min_Delta_F1=y_min_Delta_F1, y_max_Delta_F1=y_max_Delta_F1
    )
    F1_scores_per_bin["CBM"].append(F1_2)

    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_Ignasi_crossmatch, yy_pred_P_Ignasi_crossmatch_TRANS,
        class_names=global_setup.class_names, cmap=colormaps[ii],
        title=f"TRANS. {mag_tag}. #Obj.: {len(yy_pred_P_Ignasi_crossmatch_TRANS)}"
    )
    metrics, F1_1, F1_2 = evaluation_tools.compare_sets_performance(
        yy_true_no_DA_val[mask_val_DESI], yy_pred_P_no_DA_val[mask_val_DESI], yy_true_Ignasi_crossmatch, yy_pred_P_Ignasi_crossmatch_TRANS,
        class_names=global_setup.class_names, name_1="Val. Mock.", name_2="TRANS", plot_ROC_curves=False, y_min_Delta_F1=y_min_Delta_F1, y_max_Delta_F1=y_max_Delta_F1
    )
    F1_scores_per_bin["TRANS"].append(F1_2)


# One array per model, indexed [magnitude bin, class].
F1_dict = {k: np.array(v) for k, v in F1_scores_per_bin.items()}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# Radar plot of the per-class F1 scores collected in F1_dict by the per-magnitude-bin evaluation:
# one polar panel per magnitude bin, one closed polygon per model.

# The F1 arrays were computed with class_names=global_setup.class_names, so use that order.
class_names = list(global_setup.class_names)
num_classes = len(class_names)
magnitude_bins = [f"{lo}–{hi}" for lo, hi in magnitude_ranges]
n_bins = len(magnitude_bins)
models = ["Val. Mock.", "DA", "TRANS", "CBM"]

# Each model needs one row per magnitude bin and one column per class.
for model in models:
    assert np.shape(F1_dict[model]) == (n_bins, num_classes), (
        f"F1_dict[{model!r}] has shape {np.shape(F1_dict[model])}, expected {(n_bins, num_classes)}"
    )

angles = np.linspace(0, 2 * np.pi, num_classes, endpoint=False).tolist()
angles += angles[:1]  # repeat the first angle to close each polygon

# Define styles
model_colors = {
    "Val. Mock.": "royalblue",
    "DA": "royalblue",
    "TRANS": "crimson",
    "CBM": "limegreen"
}
model_styles = {
    "Val. Mock.": "dashed",
    "DA": "solid",
    "TRANS": "dotted",
    "CBM": "dotted"
}

# Two panels per row; any unused trailing panel is hidden.
n_cols = 2
n_rows = (n_bins + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), subplot_kw=dict(polar=True))
axes = np.atleast_1d(axes).flatten()
for ax in axes[n_bins:]:
    ax.set_visible(False)

for i, ax in enumerate(axes[:n_bins]):
    for j, model in enumerate(models):
        f1_vals = F1_dict[model][i].tolist()
        f1_vals += f1_vals[:1]
        ax.plot(angles, f1_vals, color=model_colors[model], linestyle=model_styles[model], linewidth=2)

        # Macro-F1 label for this model, placed outside the radial range; labels are spread
        # evenly around the circle by model index.
        macro_f1 = np.mean(F1_dict[model][i])
        angle_pos = np.pi / 4 + (2 * np.pi / len(models)) * j
        r_pos = 1.2
        ax.text(
            angle_pos, r_pos,
            f"{model}\nF1={macro_f1:.2f}",
            color=model_colors[model],
            fontsize=9,
            ha="center", va="center",
            bbox=dict(facecolor='white', edgecolor=model_colors[model], boxstyle='round,pad=0.4', lw=1.5, ls=model_styles[model])
        )

    ax.set_title(f"Magnitude Bin: {magnitude_bins[i]}", fontsize=14, pad=15)
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), class_names, fontsize=10)
    ax.set_ylim(0, 1)
    ax.tick_params(labelsize=8)  # Make radial ticks (F1) smaller here

legend_lines = [Line2D([0], [0], color=model_colors[m], lw=2, linestyle=model_styles[m], label=m) for m in models]
fig.legend(handles=legend_lines, loc="center right", title="Model", fontsize=14, title_fontsize=13)

plt.suptitle("F1-score Radar Plot per Class for Each Magnitude Bin", fontsize=16)
plt.tight_layout(rect=[0, 0, 0.88, 0.95])
plt.show()