In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import wandb_evaluation_tools

import os
import torch
import numpy as np
from sklearn.manifold import TSNE

# Ask cuDNN for deterministic kernels and switch off its benchmark-based
# algorithm auto-tuning, so repeated runs pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Start from matplotlib's stock style and discard any figure left open by a previous run.
plt.style.use('default')
plt.close('all')
# Project-wide font and rcParams overrides supplied by plotting_utils.
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# Applied last, so any rc value this style sheet defines wins over the ones set above.
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
# Fix the random seed through the project helper.
# NOTE(review): which RNGs (random / numpy / torch) it seeds is defined in aux_tools — confirm.
aux_tools.set_seed(42)

In [None]:
# Sub-directory names of the three W&B sweep families that are compared below:
# supervised ("direct") training, training without domain adaptation, and with it.
sweeps_direct_path_name, sweeps_no_DA_path_name, sweeps_DA_path_name = (
    "wandb_supervised",
    "wandb_no_DA",
    "wandb_DA",
)

In [None]:
# Number of runs of each sweep that the ranking helper is asked to plot.
N_selected_sweeps_direct = 1
N_selected_sweeps_no_DA = 1
N_selected_sweeps_DA = 1

# Every sweep family lives in its own folder under the models root.
path_wandb_sweep_direct = os.path.join(global_setup.path_models, sweeps_direct_path_name)
path_wandb_sweep_no_DA = os.path.join(global_setup.path_models, sweeps_no_DA_path_name)
path_wandb_sweep_DA = os.path.join(global_setup.path_models, sweeps_DA_path_name)

# Rank the runs of each sweep (supervised, no-DA, DA — in that order).
# NOTE(review): the returned names/losses are presumably sorted best-first — confirm in
# wandb_evaluation_tools.load_and_plot_sorted_sweeps.
(sorted_list_sweep_names_direct,
 sorted_losses_direct) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_direct,
    max_runs_to_plot=N_selected_sweeps_direct,
)
(sorted_list_sweep_names_no_DA,
 sorted_losses_no_DA) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_no_DA,
    max_runs_to_plot=N_selected_sweeps_no_DA,
)
(sorted_list_sweep_names_DA,
 sorted_losses_DA) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_DA,
    max_runs_to_plot=N_selected_sweeps_DA,
)

# Load the config files

In [None]:
def _load_validated_configs(paths_load):
    """Load the config of every run directory and check they share one data config.

    Parameters
    ----------
    paths_load : list of str
        Run directories; each must contain a ``config.yaml``.

    Returns
    -------
    list
        The configs returned by ``wrapper_tools.load_and_massage_config_file``, in the
        order of ``paths_load``. Element 0 is the reference the others are compared to.

    Raises
    ------
    ValueError
        If the ``'data'`` section of any config differs from that of the first config
        according to ``evaluation_tools.safe_compare``.
    """
    logging.info("🔍 Validating model configs...")
    configs = []
    for path_run in paths_load:
        _, config_run = wrapper_tools.load_and_massage_config_file(
            os.path.join(path_run, "config.yaml"), path_run
        )
        configs.append(config_run)

    # Every config of the family is compared against the first one of the SAME family.
    config_ref = configs[0]
    for i, cfg in enumerate(configs[1:], 1):
        logging.debug(f"🔍 Comparing config 0 and config {i}")
        if not evaluation_tools.safe_compare(cfg['data'], config_ref['data']):
            raise ValueError(f"🚫 Data config mismatch between model 0 and model {i}")
    return configs


# Supervised ("direct") model: best-ranked run of its sweep.
paths_load_direct = [os.path.join(global_setup.path_models, sweeps_direct_path_name, sorted_list_sweep_names_direct[0])]
configs_direct = _load_validated_configs(paths_load_direct)
config_ref_direct = configs_direct[0]

# Model trained without domain adaptation.
paths_load_no_DA = [os.path.join(global_setup.path_models, sweeps_no_DA_path_name, sorted_list_sweep_names_no_DA[0])]
configs_no_DA = _load_validated_configs(paths_load_no_DA)
config_ref_no_DA = configs_no_DA[0]

# Domain-adapted model; its own configs (not the no-DA ones) are validated here.
paths_load_DA = [os.path.join(global_setup.path_models, sweeps_DA_path_name, sorted_list_sweep_names_DA[0])]
configs_DA = _load_validated_configs(paths_load_DA)
config_ref_DA = configs_DA[0]

# The data pipeline below is driven by the DA model's data config.
# NOTE(review): nothing checks that the three families share the same data config,
# although all three models are fed the same inputs — confirm this is guaranteed upstream.
config_data = config_ref_DA["data"]

# Load the data

In [None]:
# All inference in this notebook runs on the CPU.
device = 'cpu'

# Where the catalogues live and the seed handed to the loader.
root_path = global_setup.DATA_path
random_seed_load = global_setup.default_seed

# Datasets to load, with the per-dataset options defined in the global setup.
list_of_datasets_to_load = ["JPAS_x_DESI_Raul", "DESI_mocks_Raul", "Ignasi"]
load_JPAS_x_DESI_Raul = global_setup.load_JPAS_x_DESI_Raul
load_DESI_mocks_Raul = global_setup.load_DESI_mocks_Raul
load_Ignasi = global_setup.load_Ignasi

# Cleaning recipe and input features are read from the models' data config.
config_dict_cleaning = config_data['cleaning_config']
keys_xx = config_data['keys_xx']

# Target-side columns carried along: class label, object ID and r-band flux.
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

# Train/val/test ratios and seeds used when splitting.
dict_split_data_options = global_setup.dict_split_data_options

In [None]:
# Per-dataset loader options, keyed by dataset name.
bundle_options = {
    "JPAS_x_DESI_Raul": {"datasets": load_JPAS_x_DESI_Raul},
    "DESI_mocks_Raul": {"datasets": load_DESI_mocks_Raul},
    "Ignasi": {"datasets": load_Ignasi},
}
DATA = loading_tools.load_data_bundle(
    root_path=root_path,
    include=list_of_datasets_to_load,
    random_seed=random_seed_load,
    **bundle_options,
)
DATA = cleaning_tools.clean_data_pipeline(DATA, config=config_dict_cleaning, in_place=True)

# Cross-match mocks against JPAS×DESI observations on TARGETID. The helper returns seven
# values; the last four are kept as per-dataset "outersection" / "intersection" lookups,
# each in the order (mocks, JPAS×DESI).
targetids_mocks = DATA["DESI_mocks_Raul"]['all_pd']['TARGETID']
targetids_jpas = DATA["JPAS_x_DESI_Raul"]['all_pd']['TARGETID']
(
    IDs1, IDs2, IDs12,
    LoA_outer_mocks, LoA_outer_jpas,
    LoA_inner_mocks, LoA_inner_jpas,
) = crossmatch_tools.crossmatch_IDs_two_datasets(targetids_mocks, targetids_jpas)

Dict_LoA = {
    "intersection": {"DESI_mocks_Raul": LoA_inner_mocks, "JPAS_x_DESI_Raul": LoA_inner_jpas},
    "outersection": {"DESI_mocks_Raul": LoA_outer_mocks, "JPAS_x_DESI_Raul": LoA_outer_jpas},
}

# Split the Lists of Arrays into training, validation, and testing sets.
# Only the JPAS×DESI intersection and the mocks outersection are split.
split_jobs = (
    (
        "intersection", "JPAS_x_DESI_Raul",
        dict(
            train_ratio=dict_split_data_options["train_ratio_intersection"],
            val_ratio=dict_split_data_options["val_ratio_intersection"],
            test_ratio=dict_split_data_options["test_ratio_intersection"],
            seed=dict_split_data_options["random_seed_split_intersection"],
        ),
    ),
    (
        "outersection", "DESI_mocks_Raul",
        dict(
            train_ratio=dict_split_data_options["train_ratio_outersection"],
            val_ratio=dict_split_data_options["val_ratio_outersection"],
            test_ratio=dict_split_data_options["test_ratio_outersection"],
            seed=dict_split_data_options["random_seed_split_outersection"],
        ),
    ),
)
Dict_LoA_split = {"intersection": {}, "outersection": {}}
for xmatch_name, dset_name, split_kwargs in split_jobs:
    Dict_LoA_split[xmatch_name][dset_name] = process_dset_splits.split_LoA(
        Dict_LoA[xmatch_name][dset_name], **split_kwargs
    )

In [None]:
def _count_LoA_entries(LoA):
    """Return the total number of indices stored in a List-of-Arrays.

    Gives the same value as ``len(np.concatenate(LoA))`` for a non-empty list, but
    returns 0 for an empty list, where ``np.concatenate`` raises ``ValueError``.
    It also avoids building the concatenated array just to measure it.
    """
    return sum(len(group) for group in LoA)


# (dataset, cross-match subset) pairs whose train/val/test splits are extracted.
extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]
xx = {}  # xx[dataset][split] -> float32 tensor of stacked input features
yy = {}  # yy[dataset][split] -> dict with the keys_yy columns
for key_dset, key_xmatch in extract_dsets:
    xx[key_dset] = {}
    yy[key_dset] = {}
    for split in global_setup.splits:
        # A split missing from the dictionary is treated as empty.
        LoA_ = Dict_LoA_split[key_xmatch][key_dset].get(split, [])
        _, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
            block=DATA[key_dset], LoA=LoA_, keys_xx=keys_xx, keys_yy=keys_yy
        )
        # Stack every extracted row, in order.
        xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(_count_LoA_entries(LoA_)))
        xx[key_dset][split] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
        yy[key_dset][split] = yy_


# The Ignasi catalogue is not split: every observation goes into a single "all" split,
# using one single-element group per row as its List-of-Arrays.
key_dset = "Ignasi"
split = "all"
xx[key_dset] = {}
yy[key_dset] = {}
LoA_, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
    block=DATA[key_dset],
    LoA=np.arange(DATA["Ignasi"]['all_observations'].shape[0])[:, None].tolist(),
    keys_xx=keys_xx,
    keys_yy=keys_yy
)
# Here the LoA returned by the extractor (not the one passed in) sets the row count.
xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(_count_LoA_entries(LoA_)))
xx[key_dset][split] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
yy[key_dset][split] = yy_

# Impose magnitude limit for test results

In [None]:
# Flux column the cut is computed from, and the faintest magnitude that is kept.
flux_key, Magnitude_limit = 'DESI_FLUX_R', 22.5

In [None]:
# Zero point of the flux -> magnitude conversion used for the cut.
# NOTE(review): presumably the AB zero point for fluxes in nanomaggies — confirm against
# the catalogue documentation. It is unrelated to Magnitude_limit despite the equal value.
MAG_ZERO_POINT = 22.5

# Keep only objects at least as bright as Magnitude_limit, in every survey and split.
# Features (xx) and every target column (yy) are filtered with the same mask.
for key_survey, yy_survey in yy.items():
    for key_split, yy_split in yy_survey.items():
        # Non-positive fluxes yield NaN / +inf magnitudes. Both compare False against
        # the limit, so those objects are dropped; the numpy warnings are only noise.
        with np.errstate(divide="ignore", invalid="ignore"):
            tmp_mag = -2.5 * np.log10(yy_split[flux_key]) + MAG_ZERO_POINT
        tmp_mag_mask = tmp_mag <= Magnitude_limit
        xx[key_survey][key_split] = xx[key_survey][key_split][tmp_mag_mask]
        for key_loaded in yy_split.keys():
            yy_split[key_loaded] = yy_split[key_loaded][tmp_mag_mask]

# Load the models

In [None]:
# For each model part: checkpoint file name, wording used in the "not found" error,
# and wording used in the log line.
_CHECKPOINT_PARTS = {
    "encoder": ("model_encoder.pt", "Encoder", "encoder"),
    "downstream": ("model_downstream.pt", "Downstream", "downstream model"),
}


def _load_eval_model(path_run, part, **builder_kwargs):
    """Load one model part of a run, switch it to eval mode and move it to ``device``.

    Parameters
    ----------
    path_run : str
        Run directory. It already contains the models root, so it is NOT joined with
        ``global_setup.path_models`` again (that would duplicate a relative root).
    part : {"encoder", "downstream"}
        Which checkpoint of the run to load.
    **builder_kwargs
        Extra keyword arguments forwarded to ``save_load_tools.load_model_from_checkpoint``.

    Returns
    -------
    tuple
        ``(path_checkpoint, model)``.

    Raises
    ------
    AssertionError
        If the checkpoint file does not exist.
    """
    file_name, name_in_error, name_in_log = _CHECKPOINT_PARTS[part]
    path_checkpoint = os.path.join(path_run, file_name)
    assert os.path.isfile(path_checkpoint), f"❌ {name_in_error} checkpoint not found: {path_checkpoint}"
    logging.info(f"📥 Loading {name_in_log} from checkpoint: {path_checkpoint}")
    _, model = save_load_tools.load_model_from_checkpoint(
        path_checkpoint, model_building_tools.create_mlp, **builder_kwargs
    )
    model.eval()
    model.to(device)
    return path_checkpoint, model


# Supervised ("direct") model.
path_load_encoder_direct, model_encoder_direct = _load_eval_model(paths_load_direct[0], "encoder")
path_load_downstream_direct, model_downstream_direct = _load_eval_model(paths_load_direct[0], "downstream")

# Model trained without domain adaptation.
path_load_encoder_no_DA, model_encoder_no_DA = _load_eval_model(paths_load_no_DA[0], "encoder")
path_load_downstream_no_DA, model_downstream_no_DA = _load_eval_model(paths_load_no_DA[0], "downstream")

# Domain-adapted model.
# NOTE(review): only this encoder is rebuilt with use_batchnorm=False — confirm the
# asymmetry with the other two encoders is intended.
path_load_encoder_DA, model_encoder_DA = _load_eval_model(paths_load_DA[0], "encoder", use_batchnorm=False)
path_load_downstream_DA, model_downstream_DA = _load_eval_model(paths_load_DA[0], "downstream")

# Compute the results

In [None]:
extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]

# (encoder, downstream head) of each training scheme, in the order they are evaluated.
models_by_case = {
    "Supervised": (model_encoder_direct, model_downstream_direct),
    "no-DA": (model_encoder_no_DA, model_downstream_no_DA),
    "DA": (model_encoder_DA, model_downstream_DA),
}


def _predict_case(encoder, downstream, xx_input):
    """Run encoder + head without gradients.

    Returns the encoder output, the soft-max class probabilities and the arg-max
    class of each row, all as NumPy arrays.
    """
    with torch.no_grad():
        features_ = encoder(xx_input)
        logits_ = downstream(features_)
        probs_ = torch.nn.functional.softmax(logits_, dim=1).cpu().numpy()
    return features_.cpu().numpy(), probs_, np.argmax(probs_, axis=1)


# Results are indexed as <container>[case][dataset][split].
features = {case: {} for case in models_by_case}
probs = {case: {} for case in models_by_case}
labels = {case: {} for case in models_by_case}

# Split datasets first, then the unsplit Ignasi catalogue with its single "all" split.
splits_per_dset = [(dset_name, tuple(global_setup.splits)) for dset_name, _ in extract_dsets]
splits_per_dset.append(("Ignasi", ("all",)))

for key_dset, splits_dset in splits_per_dset:
    for case in models_by_case:
        features[case][key_dset] = {}
        probs[case][key_dset] = {}
        labels[case][key_dset] = {}

    for split in splits_dset:
        xx_input = xx[key_dset][split]
        for case, (encoder_case, downstream_case) in models_by_case.items():
            (
                features[case][key_dset][split],
                probs[case][key_dset][split],
                labels[case][key_dset][split],
            ) = _predict_case(encoder_case, downstream_case, xx_input)

# Metrics

In [None]:
# Class names are the keys of the shared SPECTYPE encoding, in mapping order.
# NOTE(review): read from global_setup rather than from the loaded models' cleaning
# config — confirm both define the same encoding.
spectype_mapping = global_setup.config_dict_cleaning["encoding"]["shared_mappings"]["SPECTYPE"]
class_names = list(spectype_mapping.keys())

# 1) Confusion matrices (grid)

In [None]:
def _confusion_case(model_key, dset_key, split="test"):
    """Pair the true SPECTYPE labels of a split with one model's class probabilities."""
    return {
        "y_true": yy[dset_key][split]["SPECTYPE_int"],
        "y_pred": probs[model_key][dset_key][split],
    }


# One confusion matrix per (model, test set) combination, keyed by panel title.
dict_cm = {
    "JPAS Obs. (Supervised)": _confusion_case("Supervised", "JPAS_x_DESI_Raul"),
    "Mocks (no-DA)": _confusion_case("no-DA", "DESI_mocks_Raul"),
    "JPAS Obs. (no-DA)": _confusion_case("no-DA", "JPAS_x_DESI_Raul"),
    "JPAS Obs. (DA)": _confusion_case("DA", "JPAS_x_DESI_Raul"),
}

path_fig_confusion = os.path.join(global_setup.path_saved_figures, "confusion_matrices.pdf")
fig, axes = evaluation_tools.plot_confusion_matrices_grid(
    dict_cases=dict_cm,
    class_names=class_names,
    cmap="RdYlGn",
    save_path=path_fig_confusion,
)
plt.show()

# 2) Global metrics (single metric per subplot)

In [None]:
# (legend label, model key, dataset key, bar colour) of every compared case.
# All cases are evaluated on the test split of their dataset.
cmp_cases = (
    ("JPAS Obs. Supervised", "Supervised", "JPAS_x_DESI_Raul", "grey"),
    ("Mocks no-DA", "no-DA", "DESI_mocks_Raul", "royalblue"),
    ("JPAS Obs. no-DA", "no-DA", "JPAS_x_DESI_Raul", "firebrick"),
    ("JPAS Obs. DA", "DA", "JPAS_x_DESI_Raul", "green"),
)
dict_cmp = {
    case_label: {
        "y_true": yy[dset_key]["test"]["SPECTYPE_int"],
        "y_pred": probs[model_key][dset_key]["test"],
        "plot_kwargs": {"color": bar_color, "label": case_label},
    }
    for case_label, model_key, dset_key, bar_color in cmp_cases
}
%autoreload
# Layout and styling of the global-metrics figure (one metric per subplot).
global_metrics_style = dict(
    title=None,
    figsize=(5, 24),
    palette=["grey", "royalblue", "firebrick", "darkorange", "green"],
    nrows=7,
    subplot_hspace=-1.166,
    subplot_wspace=0.25,
    ylabel_text="Score",
    y_margin_frac=0.07,
    bar_alpha=0.9,
    bar_edgecolor="black",
    bar_width=0.7,
    annotate_values=True,
    value_label_fontsize=12,
)
# y_ranges is deliberately not passed. Values tried previously:
#     "Accuracy":        (0.92, 0.97),
#     "Macro F1":        (0.7, 0.84),
#     "Macro TPR":       (0.84, 0.9),
#     "Macro Precision": (0.65, 0.8),
#     "Macro AUROC":     (0.9, 1.0),
#     "ECE":             (0.00, 0.09)

path_fig_global_metrics = os.path.join(global_setup.path_saved_figures, "metrics_comparison.pdf")
fig, axes = evaluation_tools.compare_models_performance(
    dict_cases=dict_cmp,
    class_names=class_names,
    save_path=path_fig_global_metrics,
    include_metrics=("Accuracy", "Macro F1", "Macro TPR", "Macro Precision", "Macro AUROC", "ECE"),
    **global_metrics_style,
)
plt.show()


# 3) Per-class metrics (one subplot per metric; x-axis = classes)

In [None]:
# Fixed y-axis window of every per-class metric panel.
per_class_y_ranges = {
    "Accuracy":  (0.65, 1.01),
    "F1":        (0.3, 1.1),
    "TPR":       (0.65, 1.0),
    "Precision": (0.18, 1.1),
    "AUROC":     (0.9, 1.03),
    "ECE":       (0.0, 0.09),
    "Brier":     (0.0, 0.06),
}

# Layout and styling of the per-class figure (one subplot per metric; x-axis = classes).
per_class_style = dict(
    title=None,
    figsize=(9, 25),
    palette=["grey", "royalblue", "firebrick", "darkorange", "green"],
    nrows=7,
    subplot_hspace=0.5,
    subplot_wspace=0.25,
    y_margin_frac=0.07,
    bar_alpha=0.9,
    bar_edgecolor="black",
    group_width=0.9,
    annotate_values=True,
    value_label_fontsize=9,
    ylabel_text="Score",
    left_margin=0.10,
    ytick_step=0.05,
    ytick_format="{x:.2f}",
    two_line_class_xticklabels=False,
)

path_fig_per_class = os.path.join(global_setup.path_saved_figures, "metrics_comparison_per_class.pdf")
fig, axes = evaluation_tools.compare_models_performance_per_class(
    dict_cases=dict_cmp,
    class_names=class_names,  # e.g., ["GALAXY","QSO","STAR"]
    save_path=path_fig_per_class,
    include_metrics=("Accuracy", "F1", "TPR", "Precision", "AUROC", "ECE", "Brier"),
    y_ranges=per_class_y_ranges,
    **per_class_style,
)
# Runs after the plotting helper has been handed save_path.
fig.tight_layout()
plt.show()


# 4) Radar plot (e.g., per-class F1)

In [None]:
def _radar_case(label, model_key, dset_key, split, linestyle, color, marker):
    """Build one radar-plot case: true labels, predicted probabilities and line style.

    ``model_key`` selects the model in ``probs``; ``dset_key`` / ``split`` select the
    evaluated sample in both ``yy`` and ``probs``, so labels and predictions always
    refer to the same objects.
    """
    return {
        "y_true": yy[dset_key][split]["SPECTYPE_int"],
        "y_pred": probs[model_key][dset_key][split],
        "plot_kwargs": {
            "linestyle": linestyle, "linewidth": 2.0, "color": color,
            "marker": marker, "markersize": 10.0, "fill_alpha": 0.05,
            "label": label,
        },
    }


# (label, model, dataset, split, linestyle, colour, marker) of every radar curve.
# The supervised reference is evaluated on the JPAS×DESI test split, as its label says
# and as the confusion-matrix and global-metrics cases do (it was fed the mocks here).
radar_cases = (
    ("JPAS Obs. Supervised", "Supervised", "JPAS_x_DESI_Raul", "test", ":", "grey", "X"),
    ("Mocks no-DA", "no-DA", "DESI_mocks_Raul", "test", "--", "royalblue", "s"),
    ("JPAS Obs. no-DA", "no-DA", "JPAS_x_DESI_Raul", "test", "--", "firebrick", "v"),
    ("JPAS Obs. (Train) DA", "DA", "JPAS_x_DESI_Raul", "train", "-", "darkorange", "^"),
    ("JPAS Obs. DA", "DA", "JPAS_x_DESI_Raul", "test", "-", "green", "o"),
)
dict_radar = {case[0]: _radar_case(*case) for case in radar_cases}

fig, ax = evaluation_tools.radar_plot(
    dict_radar=dict_radar,
    class_names=class_names,
)
fig.savefig(os.path.join(global_setup.path_saved_figures, "F1_radar.pdf"), bbox_inches='tight')
plt.show()

# 5) ROC curves (single panel, macro one-vs-rest per class & case)

In [None]:
def _roc_case(label, model_key, dset_key, linestyle):
    """Build one ROC case from the test split of ``dset_key`` and model ``model_key``.

    True labels and predicted probabilities are taken from the same dataset and split.
    """
    return {
        "y_true": yy[dset_key]["test"]["SPECTYPE_int"],
        "y_pred": probs[model_key][dset_key]["test"],
        "plot_kwargs": {"linestyle": linestyle, "linewidth": 2.0, "marker": None, "markersize": 8,
                        "label": label},
    }


# (label, model, dataset, linestyle) of every ROC family.
# The supervised reference is evaluated on the JPAS×DESI test split, as its label says
# and as the confusion-matrix and global-metrics cases do (it was fed the mocks here).
roc_cases = (
    ("JPAS Obs. Supervised", "Supervised", "JPAS_x_DESI_Raul", ":"),
    ("JPAS Obs. no-DA", "no-DA", "JPAS_x_DESI_Raul", "--"),
    ("JPAS Obs. DA", "DA", "JPAS_x_DESI_Raul", "-"),
)
dict_roc = {case[0]: _roc_case(*case) for case in roc_cases}

fig, ax = evaluation_tools.plot_multiclass_rocs(
    dict_cases=dict_roc,
    class_names=class_names,
    title=None,
    # Zoom on the low-FPR / high-TPR corner.
    x_lims=(-0.01, 0.3),
    y_lims=(0.7, 1.01),
)
fig.savefig(os.path.join(global_setup.path_saved_figures, "ROC.pdf"), format="pdf", bbox_inches="tight")
plt.show()

# Latents

In [None]:
# (output key, model, dataset) of every latent set embedded below; all use the test split.
tsne_inputs = (
    ("latents_no_DA_Source", "no-DA", "DESI_mocks_Raul"),
    ("latents_no_DA_Target", "no-DA", "JPAS_x_DESI_Raul"),
    ("latents_DA_Target", "DA", "JPAS_x_DESI_Raul"),
)
feat_dict = {
    latent_name: features[model_key][dset_key]["test"]
    for latent_name, model_key, dset_key in tsne_inputs
}

# Embed each latent set with t-SNE: no standardisation, no subsampling.
tsne_options = dict(
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)
latents_tSNE = evaluation_tools.tsne_per_key(feat_dict, **tsne_options)

In [None]:
xlim = (-150, 150)
ylim = (-150, 150)

# True classes of the two test sets shown in the scatter plots.
labels_mocks_test = yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"]
labels_jpas_test = yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"]

# 2-D embeddings computed in the t-SNE cell.
tsne_no_DA_source = latents_tSNE['latents_no_DA_Source_tSNE']
tsne_no_DA_target = latents_tSNE['latents_no_DA_Target_tSNE']
tsne_DA_target = latents_tSNE['latents_DA_Target_tSNE']

# Options shared by the two-sample overlay scatters.
overlay_kwargs = dict(
    class_names=class_names,
    marker_val="o", marker_test="^",
    alpha_val=0.7, alpha_test=0.7,
    xlim=xlim, ylim=ylim,
    seed=137,
    edgecolor=None, linewidths=0.0,
)
# Options shared by the per-class scatters.
class_scatter_kwargs = dict(
    class_counts=None,
    class_names=class_names,
    n_bins=128, sigma=2.0,
    xlim=xlim, ylim=ylim,
)
# Options shared by the density maps.
density_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,  # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
    xlim=xlim, ylim=ylim,
)

# --- no-DA: source (mocks) versus target (observations) ---
evaluation_tools.plot_latents_scatter_val_test(
    X_val=tsne_no_DA_source, y_val=labels_mocks_test,
    X_test=tsne_no_DA_target, y_test=labels_jpas_test,
    title="Latents no-DA: Source (Mocks) vs Target (JPAS x DESI obs.)",
    size_val=14, size_test=14,
    subsample=4000,
    legend_split_1="Source (Mocks) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) no-DA",
    **overlay_kwargs,
)
evaluation_tools.plot_latents_scatter(
    tsne_no_DA_source, labels_mocks_test,
    title="Latents no-DA: Source (Mocks)",
    scatter_size=0.003, scatter_alpha=0.3,
    **class_scatter_kwargs,
)
evaluation_tools.plot_latent_density_2d(
    tsne_no_DA_source,
    title="Latents no-DA: Source (Mocks)",
    **density_kwargs,
)
evaluation_tools.plot_latents_scatter(
    tsne_no_DA_target, labels_jpas_test,
    title="Latents no-DA: Target (JPAS x DESI obs.)",
    scatter_size=0.1, scatter_alpha=0.5,
    **class_scatter_kwargs,
)
evaluation_tools.plot_latent_density_2d(
    tsne_no_DA_target,
    title="Latents no-DA: Target (JPAS x DESI obs.)",
    **density_kwargs,
)

# --- target (observations): no-DA versus DA ---
evaluation_tools.plot_latents_scatter_val_test(
    X_val=tsne_no_DA_target, y_val=labels_jpas_test,
    X_test=tsne_DA_target, y_test=labels_jpas_test,
    title="Latents Target (JPAS x DESI obs.): no-DA vs DA",
    size_val=8, size_test=8,
    subsample=None,
    legend_split_1="Target (JPAS x DESI obs.) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) DA",
    **overlay_kwargs,
)
evaluation_tools.plot_latents_scatter(
    tsne_DA_target, labels_jpas_test,
    title="Latents DA: Target (JPAS x DESI obs.)",
    scatter_size=0.1, scatter_alpha=0.5,
    **class_scatter_kwargs,
)
evaluation_tools.plot_latent_density_2d(
    tsne_DA_target,
    title="Latents DA: Target (JPAS x DESI obs.)",
    **density_kwargs,
)

In [None]:
# DA latents of the labelled JPAS×DESI test set and of the full Ignasi catalogue.
latents_DA_jpas_test = features["DA"]["JPAS_x_DESI_Raul"]["test"]
feat_dict = {
    "latents_DA_JPAS_x_DESI": latents_DA_jpas_test,
    "latents_DA_all_JPAS": features["DA"]["Ignasi"]["all"],
}

# Both sets are subsampled to the size of the JPAS×DESI test set.
subsample = latents_DA_jpas_test.shape[0]
subsample_per_key = {latent_name: subsample for latent_name in feat_dict}

latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample=subsample_per_key,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)

In [None]:
xlim = (-100, 100)
ylim = (-100, 100)

# Identical rendering options for both density maps, so they can be compared directly.
density_map_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,  # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
    xlim=xlim, ylim=ylim,
)

evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_JPAS_x_DESI_tSNE'],
    title="Latents JPAS x DESI",
    **density_map_kwargs,
)
evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_all_JPAS_tSNE'],
    title="Latents all JPAS",
    **density_map_kwargs,
)

In [None]:
# Step 1: match the (magnitude-limited) JPAS×DESI test objects against the Ignasi
# catalogue. Only the third output (kept as IDs) and the two intersection lookups
# are used afterwards.
targetids_jpas_test = yy['JPAS_x_DESI_Raul']['test']['TARGETID']
targetids_ignasi = DATA['Ignasi']['all_pd']['TARGETID']
(
    _, _, IDs, _, _,
    LoA_intersection_JPAS_x_DESI_Raul_test,
    LoA_intersection_Ignasi_all,
) = crossmatch_tools.crossmatch_IDs_two_datasets(targetids_jpas_test, targetids_ignasi)

# Step 2: locate those IDs in the full JPAS×DESI table, giving row indices into DATA
# rather than into the test split.
targetids_jpas_all = DATA['JPAS_x_DESI_Raul']['all_pd']['TARGETID']
(
    _, _, _, _, _,
    LoA_intersection_JPAS_x_DESI_Raul_DATA_test,
    _,
) = crossmatch_tools.crossmatch_IDs_two_datasets(targetids_jpas_all, IDs)

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# ───── Config ─────
# Dataset keys of the two surveys whose observations are compared.
survey_jpas, survey_ignasi = "JPAS_x_DESI_Raul", "Ignasi"

# At most this many matched IDs are drawn, with a fixed-seed generator.
NN_plot = 1000
rng = np.random.default_rng(0)
USE_LOG_Y = True

# Filtering params
# An ID is kept if max |ratio - 1| exceeds RATIO_DEV_THR, where
# ratio = log10(Ignasi) / log10(JPAS), and it has at least MIN_VALID_PTS finite ratio points.
RATIO_DEV_THR = 50.
MIN_VALID_PTS = 1

# ---- helpers to normalize index containers ----
def iter_indices(x):
    """Yield the integer indices held by ``x``.

    Parameters
    ----------
    x : None, scalar, list, tuple or numpy.ndarray
        ``None`` yields nothing. A list, tuple or array is iterated and each
        non-``None`` element is converted with ``int``. Anything else, including a
        0-d array, is treated as a single scalar index.

    Yields
    ------
    int
        The indices, in their original order.
    """
    if x is None:
        return
    # A 0-d array is an ndarray instance but cannot be iterated; handle it as a scalar.
    if isinstance(x, np.ndarray) and x.ndim == 0:
        yield int(x)
        return
    if isinstance(x, (list, tuple, np.ndarray)):
        for v in x:
            if v is None:
                continue
            yield int(v)
    else:
        yield int(x)

def first_index(x):
    """Return the first index produced by ``iter_indices(x)``, or ``None`` if there is none."""
    return next(iter_indices(x), None)

def compute_ratio_for_id(idx_ID):
    """Return log10(Ignasi) / log10(JPAS) per filter for one matched ID.

    Only the first occurrence of the ID in each survey is used. Observations are
    floored at 1e-12 before the logarithm, and entries where log10(JPAS) is exactly
    0 become NaN. Returns ``None`` when either survey has no occurrence.
    """
    row_jpas = first_index(LoA_intersection_JPAS_x_DESI_Raul_DATA_test[idx_ID])
    row_ignasi = first_index(LoA_intersection_Ignasi_all[idx_ID])
    if row_jpas is None or row_ignasi is None:
        return None

    def _floored_log10(values):
        # The floor keeps log10 finite for zero or negative observations.
        return np.log10(np.clip(np.asarray(values, dtype=float), 1e-12, None))

    log_jpas = _floored_log10(DATA[survey_jpas]['all_observations'][row_jpas])
    log_ignasi = _floored_log10(DATA[survey_ignasi]['all_observations'][row_ignasi])

    # A zero denominator is replaced by NaN instead of dividing by zero.
    log_jpas_nonzero = np.where(log_jpas == 0, np.nan, log_jpas)
    return log_ignasi / log_jpas_nonzero

def id_passes_threshold(idx_ID, thr=RATIO_DEV_THR):
    """
    Tell whether the ratio curve of `idx_ID` strays from 1 by more than `thr`.

    Returns False when no ratio can be computed or when fewer than
    MIN_VALID_PTS ratio points are finite.
    """
    ratio = compute_ratio_for_id(idx_ID)
    if ratio is None:
        return False

    finite_ratio = ratio[np.isfinite(ratio)]
    if finite_ratio.size < MIN_VALID_PTS:
        return False

    largest_deviation = np.nanmax(np.abs(finite_ratio - 1.0))
    return bool(largest_deviation > thr)

# ───── sampling ─────
# `IDs` and the two lookups are indexed in parallel (one entry per matched ID).
assert len(IDs) == len(LoA_intersection_JPAS_x_DESI_Raul_DATA_test) == len(LoA_intersection_Ignasi_all), \
    "Lengths of IDs and intersection lookups must match."

# Draw at most NN_plot distinct positions into `IDs`
NN_eff = min(NN_plot, len(IDs))
sample_all = rng.choice(len(IDs), NN_eff, replace=False)

# Keep only the sampled IDs whose ratio deviates from 1 by more than RATIO_DEV_THR
filtered_idxs = [idx for idx in sample_all if id_passes_threshold(idx)]
if not filtered_idxs:
    # The plotting loop below then has nothing to iterate over; the (empty) figure is still drawn.
    print(f"ℹ️ No IDs exceeded threshold (|ratio-1| > {RATIO_DEV_THR}). Nothing to plot.")

# colors/linestyles: one color per plotted ID, one linestyle per survey
colors = plt.cm.plasma(np.linspace(0.08, 0.92, max(1, len(filtered_idxs))))
linestyles = {survey_jpas: "--", survey_ignasi: "-"}

# figure (main + ratio)
fig, (ax, ax_ratio) = plt.subplots(
    2, 1, figsize=(9, 8), height_ratios=[3, 1], sharex=True, gridspec_kw={'hspace': 0.06}
)
ax.set_ylabel(r'Flux [arb. units]', fontsize=18)
ax_ratio.set_xlabel(r'$\mathrm{Filter~Index}$', fontsize=18)
ax_ratio.set_ylabel(r'$\log_{10}(\mathrm{Ignasi}) / \log_{10}(\mathrm{JPAS})$', fontsize=13)

# legends scaffolding
survey_handles = [
    mpl.lines.Line2D([0], [0], color="gray", linestyle=linestyles[survey_jpas], lw=2, label="JPAS×DESI"),
    mpl.lines.Line2D([0], [0], color="gray", linestyle=linestyles[survey_ignasi], lw=2, label="Ignasi"),
]
id_handles = []

# ───── plotting ─────
for j, idx_ID in enumerate(filtered_idxs):
    color = colors[j]
    tid_label = str(IDs[idx_ID])
    id_handles.append(mpl.lines.Line2D([0], [0], color=color, lw=3, label=f"ID {tid_label}"))

    # Every occurrence of this ID is drawn, JPAS first and Ignasi second.
    occurrences = (
        (survey_jpas,   LoA_intersection_JPAS_x_DESI_Raul_DATA_test[idx_ID]),
        (survey_ignasi, LoA_intersection_Ignasi_all[idx_ID]),
    )
    for survey, occ in occurrences:
        for ii in iter_indices(occ):
            obs = np.asarray(DATA[survey]['all_observations'][ii], dtype=float)
            ax.plot(np.arange(obs.size), obs, linestyle=linestyles[survey], lw=2.0,
                    marker='o', ms=3.0, color=color, alpha=0.9)

    # ratio: first occurrence each (already validated by filter)
    ratio = compute_ratio_for_id(idx_ID)
    if ratio is not None:
        ax_ratio.plot(np.arange(ratio.size), ratio, color=color, lw=1.8, alpha=0.95)

# ───── styling ─────
ax.tick_params(axis='both', labelsize=12)
ax_ratio.axhline(1.0, ls='--', lw=1.0, color='black', alpha=0.6)
ax_ratio.tick_params(axis='both', labelsize=11)

leg0 = ax.legend(handles=survey_handles, loc='upper left', fontsize=12,
                 fancybox=True, shadow=True, title="Survey", title_fontsize=13)

if id_handles:
    # A second ax.legend() call replaces the Axes' current legend, so the survey legend is
    # pinned as a plain artist first. The ID legend then stays as the Axes' own legend and
    # needs no add_artist (which would register it a second time).
    ax.add_artist(leg0)
    leg1 = ax.legend(handles=id_handles, loc='upper right', fontsize=11,
                     fancybox=True, shadow=True,
                     title=f"Sampled IDs (|ratio-1| > {RATIO_DEV_THR})", title_fontsize=12)

if USE_LOG_Y:
    ax.set_yscale("log")

plt.tight_layout()
plt.show()


In [None]:
# Flatten the per-ID cross-match lookups into row selections for the DA latent features
rows_raul_test  = np.concatenate(LoA_intersection_JPAS_x_DESI_Raul_test)
rows_ignasi_all = np.concatenate(LoA_intersection_Ignasi_all)

feat_dict = {
    "latents_DA_Raul_test_x_Ignasi": features["DA"]["JPAS_x_DESI_Raul"]["test"][rows_raul_test],
    "latents_DA_Ignasi_x_Raul_test": features["DA"]["Ignasi"]["all"][rows_ignasi_all],
}

# Options forwarded to evaluation_tools.tsne_per_key
tsne_options = dict(
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)
latents_tSNE = evaluation_tools.tsne_per_key(feat_dict, **tsne_options)

In [None]:
# 2-D t-SNE embeddings of the two matched latent sets
A = latents_tSNE['latents_DA_Raul_test_x_Ignasi_tSNE']
B = latents_tSNE['latents_DA_Ignasi_x_Raul_test_tSNE']

fig_tsne, ax_tsne = plt.subplots(figsize=(7.5, 6))

# A is drawn first (large blue crosses), B on top of it (small red dots)
for embedding, marker_style in (
    (A, dict(s=14, marker='X', alpha=0.9, color='royalblue')),
    (B, dict(s=5, marker='o', alpha=0.4, color='crimson')),
):
    ax_tsne.scatter(embedding[:, 0], embedding[:, 1], edgecolors='none', **marker_style)

ax_tsne.set_xlabel('t-SNE 1')
ax_tsne.set_ylabel('t-SNE 2')
ax_tsne.grid(True, linestyle='--', linewidth=0.5, alpha=0.25)
fig_tsne.tight_layout()
plt.show()

# Magnitude cuts

In [None]:
# Flux column used to derive magnitudes, and the edges of the magnitude bins
magnitude_key = "DESI_FLUX_R"
mag_bin_edges = (17, 19, 21, 22, 22.5)

# Consecutive edges paired up as (low, high) bins
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))

# One plain color and one colormap per magnitude bin
bin_colors = ['blue', 'green', 'orange', 'red']
colormaps = [getattr(plt.cm, cmap_name) for cmap_name in ('Blues', 'Greens', 'YlOrBr', 'Reds')]

In [None]:
# Histogram entries: (survey label, split label) -> (dataset key, split key) in `yy`
magnitude_cases = {
    ("Mocks", "Test"): ("DESI_mocks_Raul", "test"),
    ("JPAS x DESI", "Train"): ("JPAS_x_DESI_Raul", "train"),
    ("JPAS x DESI", "Test"): ("JPAS_x_DESI_Raul", "test"),
}

# Magnitudes are derived from the flux column named by `magnitude_key`, so changing that
# setting also changes this plot (the column name used to be hard-coded here).
magnitudes_plot = {
    case: -2.5 * np.log10(yy[dataset][split][magnitude_key]) + 22.5
    for case, (dataset, split) in magnitude_cases.items()
}
labels_plot = {
    case: yy[dataset][split]['SPECTYPE_int']
    for case, (dataset, split) in magnitude_cases.items()
}
masks_magnitudes, stats_magnitudes = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_plot, ranges=magnitude_ranges, colors=bin_colors, bins=42,
    x_label="DESI Magnitude (R)", title="DESI R-band Magnitudes",
    labels_dict=labels_plot, class_names=class_names, pct_decimals=0,
    annotate_mode='text', legend_fontsize=12
)

# Line style and legend label of each entry in the per-class counts figure
stats_magnitudes[("Mocks", "Test")]["plot_kwargs"] = {
    "linestyle": "--", "marker": "+", "markersize": 8.0,
    "label": "Source-Test (Mocks) no-DA",
}
stats_magnitudes[("JPAS x DESI", "Test")]["plot_kwargs"] = {
    "linestyle": "-.", "marker": "x", "markersize": 8.0,
    "label": "Target-Test (JPAS x DESI) no-DA",
}
stats_magnitudes[("JPAS x DESI", "Train")]["plot_kwargs"] = {
    "linestyle": ":", "marker": "^", "markersize": 8.0,
    "label": "Target-Training (JPAS x DESI) DA",
}
fig, ax = plotting_utils.plot_per_class_counts_together(
    stats_magnitudes,
    yscale="log",
    figsize=(12, 6),
    title="Per-class counts vs. magnitude (all entries together)",
    title_fontsize=18,  # title size
    class_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.0, 1.0),
        "fontsize": 10, "title": "Classes", "frameon": True, "fancybox": True, "shadow": True
    },
    entry_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.65, 1.0),
        "fontsize": 9, "ncol": 1, "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True
    },
    legend_outside=True,
)
plt.show()

In [None]:
def flux_to_mag(flux_array):
    """
    Convert fluxes to magnitudes with  m = -2.5 * log10(flux) + 22.5.

    Parameters
    ----------
    flux_array : array-like
        Fluxes; coerced to a float array so object-dtype inputs also work.

    Returns
    -------
    np.ndarray (or NumPy scalar for scalar input)
        Magnitudes with the shape of the input. A zero flux gives +inf and a
        negative flux gives NaN, since the magnitude is undefined there.
    """
    flux = np.asarray(flux_array, dtype=float)
    # Non-positive fluxes are expected in the data; keep their inf/NaN results
    # but silence the RuntimeWarnings NumPy would otherwise print for them.
    with np.errstate(divide="ignore", invalid="ignore"):
        return -2.5 * np.log10(flux) + 22.5

# Magnitudes per (dataset, split), computed once from the `magnitude_key` flux column.
# Each array is meant to share the indexing of the matching y_true / y_pred arrays.
MAG = {
    (dataset, split): flux_to_mag(yy[dataset][split][magnitude_key])
    for dataset, split in (
        ("DESI_mocks_Raul", "test"),
        ("JPAS_x_DESI_Raul", "test"),
        ("JPAS_x_DESI_Raul", "train"),
    )
}

# --------------------------------------------
# Helpers
# --------------------------------------------
def _with_suffix(path_or_stem, suffix, ext_if_missing=".pdf"):
    """
    If path_or_stem has an extension, insert suffix before it.
    If not, add ext_if_missing.
    """
    base, ext = os.path.splitext(path_or_stem)
    if ext:
        return f"{base}{suffix}{ext}"
    else:
        return f"{path_or_stem}{suffix}{ext_if_missing}"

def filter_dict_cases_by_mag(dict_cases, case_to_mag, mag_range):
    """
    Restrict every case to the samples whose magnitude lies in [low, high).

    Parameters
    ----------
    dict_cases : dict
        {case: {"y_true": (N,), "y_pred": (N, C), ...}}
    case_to_mag : dict
        {case: (N,) magnitudes aligned with the arrays of that case}
    mag_range : tuple
        (low, high); the lower edge is inclusive, the upper edge exclusive.

    Returns
    -------
    dict
        New dict whose payloads are shallow copies with "y_true"/"y_pred"
        replaced by the selected rows. Cases left without samples are omitted.

    Raises
    ------
    ValueError
        If a magnitude array and its y_true differ in length.
    """
    lower, upper = mag_range
    selected = {}
    for name, case_payload in dict_cases.items():
        labels    = np.asarray(case_payload["y_true"])
        scores    = np.asarray(case_payload["y_pred"])
        case_mags = np.asarray(case_to_mag[name])

        n_mags, n_labels = case_mags.shape[0], labels.shape[0]
        if n_mags != n_labels:
            raise ValueError(f"[{name}] magnitude array length {n_mags} "
                             f"!= y_true length {n_labels}")

        in_bin = (case_mags >= lower) & (case_mags < upper)
        if not in_bin.any():
            # nothing falls in this bin for the case: leave it out
            continue

        # other payload entries (e.g. plot_kwargs) are carried over by reference
        selected[name] = {**case_payload, "y_true": labels[in_bin], "y_pred": scores[in_bin]}
    return selected

def _recolor_all_cases(dict_cases, new_color):
    """Return a copy where every case's plot color is set to `new_color`."""
    out = {}
    for k, v in dict_cases.items():
        vv = {**v}
        pk = dict(vv.get("plot_kwargs", {}))
        pk["color"] = new_color
        vv["plot_kwargs"] = pk
        out[k] = vv
    return out

# Confusion matrices

In [None]:
# One row per confusion matrix: (case label, model key in `probs`, dataset key).
# Every case is evaluated on the "test" split of its dataset.
cm_case_specs = (
    ("JPAS Obs. (Supervised)", "Supervised", "JPAS_x_DESI_Raul"),
    ("Mocks (no-DA)",          "no-DA",      "DESI_mocks_Raul"),
    ("JPAS Obs. (no-DA)",      "no-DA",      "JPAS_x_DESI_Raul"),
    ("JPAS Obs. (DA)",         "DA",         "JPAS_x_DESI_Raul"),
)
dict_cm = {
    case_label: {
        "y_true": yy[dataset]["test"]["SPECTYPE_int"],
        "y_pred": probs[model_key][dataset]["test"],
    }
    for case_label, model_key, dataset in cm_case_specs
}
# Magnitudes aligned with each case, used to split the samples into bins
case_to_mag_cm = {
    case_label: MAG[(dataset, "test")]
    for case_label, _model_key, dataset in cm_case_specs
}

# One grid of confusion matrices per magnitude bin, each saved under a bin-specific file name
cm_figure_path = os.path.join(global_setup.path_saved_figures, "confusion_matrices.pdf")
for ii, mag_bin in enumerate(magnitude_ranges):
    low, high = mag_bin
    suffix = f"_mag_{low:.1f}-{high:.1f}"

    dict_cm_bin = filter_dict_cases_by_mag(dict_cm, case_to_mag_cm, mag_bin)
    if not dict_cm_bin:
        print(f"[confusion matrices] No samples in bin {suffix}; skipping.")
        continue

    save_path = _with_suffix(cm_figure_path, suffix)
    fig, axes = evaluation_tools.plot_confusion_matrices_grid(
        dict_cases=dict_cm_bin,
        class_names=class_names,
        cmap=colormaps[ii],
        save_path=save_path,
    )
    plt.show()

# Global metrics

In [None]:
# One row per compared case: (case label, model key in `probs`, dataset key, default color).
# Every case is evaluated on the "test" split; the label doubles as the legend entry.
cmp_case_specs = (
    ("JPAS Obs. Supervised", "Supervised", "JPAS_x_DESI_Raul", "grey"),
    ("Mocks no-DA",          "no-DA",      "DESI_mocks_Raul",  "royalblue"),
    ("JPAS Obs. no-DA",      "no-DA",      "JPAS_x_DESI_Raul", "firebrick"),
    ("JPAS Obs. DA",         "DA",         "JPAS_x_DESI_Raul", "green"),
)
dict_cmp = {
    case_label: {
        "y_true": yy[dataset]["test"]["SPECTYPE_int"],
        "y_pred": probs[model_key][dataset]["test"],
        "plot_kwargs": {"color": case_color, "label": case_label},
    }
    for case_label, model_key, dataset, case_color in cmp_case_specs
}
case_to_mag_cmp = {
    case_label: MAG[(dataset, "test")]
    for case_label, _model_key, dataset, _case_color in cmp_case_specs
}

# Metrics shown (one row each) and their y-axis limits
score_metrics = ("Accuracy", "Macro F1", "Macro TPR", "Macro Precision", "Macro AUROC")
cmp_metrics = score_metrics + ("ECE", "Brier Score")
cmp_y_ranges = {
    **{metric: (0.4, 1.2) for metric in score_metrics},
    "ECE":         (0.0, 0.13),
    "Brier Score": (0.0, 0.46),
}

cmp_figure_path = os.path.join(global_setup.path_saved_figures, "metrics_comparison.pdf")
for ii, mag_bin in enumerate(magnitude_ranges):
    low, high = mag_bin
    suffix = f"_mag_{low:.1f}-{high:.1f}"

    dict_cmp_bin = filter_dict_cases_by_mag(dict_cmp, case_to_mag_cmp, mag_bin)
    if not dict_cmp_bin:
        print(f"[global metrics] No samples in bin {suffix}; skipping.")
        continue

    # within a bin every case is drawn in that bin's color
    dict_cmp_bin = _recolor_all_cases(dict_cmp_bin, bin_colors[ii])

    save_path = _with_suffix(cmp_figure_path, suffix)
    fig, axes = evaluation_tools.compare_models_performance(
        dict_cases=dict_cmp_bin,
        class_names=class_names,
        title=f"Magnitude = [{mag_bin_edges[ii]}, {mag_bin_edges[ii+1]}]",
        palette=["grey", "royalblue", "firebrick", "darkorange", "green"],  # kept; plot_kwargs color takes precedence
        save_path=save_path,
        include_metrics=cmp_metrics,
        nrows=7,
        y_ranges=cmp_y_ranges,
    )
    plt.show()

# Radar

In [None]:
# Cases drawn on the radar plot. Each entry provides the true labels, the predicted
# class probabilities and the line/marker style of one model/dataset combination.
dict_radar = {
    # The supervised case is evaluated on the JPAS x DESI test split, as its label says and
    # as the confusion-matrix and global-metrics cells do (it previously read the mocks).
    "JPAS Obs. Supervised": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["Supervised"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {"linestyle": ":", "linewidth": 2.0, "color": "grey",
                        "marker": "X", "markersize": 10.0, "fill_alpha": 0.05,
                        "label": "JPAS Obs. Supervised"}
    },
    "Mocks no-DA": {
        "y_true": yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["DESI_mocks_Raul"]["test"],
        "plot_kwargs": {"linestyle": "--", "linewidth": 2.0, "color": "royalblue",
                        "marker": "s", "markersize": 10.0, "fill_alpha": 0.05,
                        "label": "Mocks no-DA"}
    },
    "JPAS Obs. no-DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {"linestyle": "--", "linewidth": 2.0, "color": "firebrick",
                        "marker": "v", "markersize": 10.0, "fill_alpha": 0.05,
                        "label": "JPAS Obs. no-DA"}
    },
    "JPAS Obs. DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {"linestyle": "-", "linewidth": 2.0, "color": "green",
                        "marker": "o", "markersize": 10.0, "fill_alpha": 0.05,
                        "label": "JPAS Obs. DA"}
    },
}
# Magnitudes aligned with each case; they must come from the same dataset/split as y_true
case_to_mag_radar = {
    "JPAS Obs. Supervised": MAG[("JPAS_x_DESI_Raul", "test")],
    "Mocks no-DA":                MAG[("DESI_mocks_Raul", "test")],
    "JPAS Obs. no-DA":            MAG[("JPAS_x_DESI_Raul", "test")],
    "JPAS Obs. DA":               MAG[("JPAS_x_DESI_Raul", "test")],
}

# One radar plot per magnitude bin, saved under a bin-specific file name
for ii, (low, high) in enumerate(magnitude_ranges):
    suffix = f"_mag_{low:.1f}-{high:.1f}"
    dict_radar_bin = filter_dict_cases_by_mag(dict_radar, case_to_mag_radar, (low, high))
    if not dict_radar_bin:
        print(f"[radar] No samples in bin {suffix}; skipping.")
        continue

    # recolor all cases for this bin; linestyle and marker still tell the cases apart
    dict_radar_bin = _recolor_all_cases(dict_radar_bin, bin_colors[ii])

    save_path = _with_suffix(
        os.path.join(global_setup.path_saved_figures, "F1_radar.pdf"),
        suffix
    )
    fig, ax = evaluation_tools.radar_plot(
        dict_radar=dict_radar_bin,
        class_names=class_names,
        title=f"Magnitude = [{mag_bin_edges[ii]}, {mag_bin_edges[ii+1]}]",
    )
    fig.savefig(save_path, bbox_inches='tight')
    plt.show()
