In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import wrapper_data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

from JPAS_DA.evaluation import wandb_evaluation_tools

import os
import torch
import numpy as np
import fitsio

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Sweep directory of the models trained WITHOUT domain adaptation, and how many of
# its leading runs (in the order returned by load_and_plot_sorted_sweeps) to evaluate.
N_selected_sweeps_no_DA = 1
path_wandb_sweep_no_DA = os.path.join(global_setup.path_models, "wandb_no_DA")

In [None]:
# Sweep directory of the domain-adapted models, and how many of its leading runs
# (in the order returned by load_and_plot_sorted_sweeps) to evaluate.
N_selected_sweeps_DA = 1
path_wandb_sweep_DA = os.path.join(global_setup.path_models, "wandb_DA")

In [None]:
# Run names and losses of the no-DA sweep, in the order produced by the helper; the
# leading entries are the ones selected for evaluation further down.
(
    sorted_list_sweep_names_no_DA,
    sorted_losses_no_DA,
) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_no_DA,
    max_runs_to_plot=N_selected_sweeps_no_DA,
)

In [None]:
# Run names and losses of the DA sweep, in the order produced by the helper; the
# leading entries are the ones selected for evaluation further down.
(
    sorted_list_sweep_names_DA,
    sorted_losses_DA,
) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_DA,
    max_runs_to_plot=N_selected_sweeps_DA,
)

In [None]:
# Directories of the selected (leading) no-DA runs.
paths_load_no_DA = [
    os.path.join(path_wandb_sweep_no_DA, run_name)
    for run_name in sorted_list_sweep_names_no_DA[:N_selected_sweeps_no_DA]
]

In [None]:
# Directories of the selected (leading) DA runs.
paths_load_DA = [
    os.path.join(path_wandb_sweep_DA, run_name)
    for run_name in sorted_list_sweep_names_DA[:N_selected_sweeps_DA]
]

In [None]:
# Evaluation setup for the no-DA models: loaders for the DESI-only and J-PAS-matched
# sets, returning the DESI validation and J-PAS matched test splits, with these
# per-object fields carried alongside the predictions.
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]
define_dataset_loaders_keys = ['DESI_only', 'JPAS_matched']
return_keys = [
    'val_DESI_only',
    'test_JPAS_matched',
]

In [None]:
# Evaluate the selected no-DA runs; the result is indexed by model and then by split name.
RESULTS_no_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=paths_load_no_DA,
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
    keys_yy=keys_yy,
)

In [None]:
# Evaluation setup for the DA models: only the J-PAS-matched loaders are needed,
# returning its train and test splits with the same per-object fields as above.
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]
define_dataset_loaders_keys = ['JPAS_matched']
return_keys = [
    'train_JPAS_matched',
    'test_JPAS_matched',
]

In [None]:
# Evaluate the selected DA runs; the result is indexed by model and then by split name.
RESULTS_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=paths_load_DA,
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
    keys_yy=keys_yy,
)

In [None]:
# Number of classes and their display names.
# Counted from the union of true AND predicted labels on the DESI validation split of
# the first no-DA model: counting predicted labels alone under-counts whenever the
# model never predicts some class, which would wrongly switch to the binary setup.
_val_first_no_DA = RESULTS_no_DA[next(iter(RESULTS_no_DA))]['val_DESI_only']
n_classes = len(np.union1d(_val_first_no_DA['true'], _val_first_no_DA['label']))

if n_classes == 2:
    # Binary one-vs-rest setup: high-redshift quasars vs every other SPECTYPE.
    class_names = ['QSO_high', 'no_QSO_high']
    manually_select_one_SPECTYPE_vs_rest = 'QSO_high'
else:
    class_names = global_setup.class_names
    manually_select_one_SPECTYPE_vs_rest = None
    if n_classes != len(class_names):
        logging.warning(
            "Found %d classes in the DESI validation split but global_setup.class_names has %d entries",
            n_classes, len(class_names),
        )

In [None]:
ii_model = 0
# The no-DA and DA models must be evaluated on exactly the same J-PAS test objects,
# otherwise the no-DA vs DA comparisons below are not like-for-like.
targetids_test_no_DA = RESULTS_no_DA[ii_model]["test_JPAS_matched"]['TARGETID']
targetids_test_DA = RESULTS_DA[ii_model]["test_JPAS_matched"]['TARGETID']
IDs_only_1, IDs_only_2, IDs_both, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
    targetids_test_no_DA,
    targetids_test_DA,
)
assert len(IDs_only_1) == 0, f"IDs only in no DA: {IDs_only_1}"
assert len(IDs_only_2) == 0, f"IDs only in DA: {IDs_only_2}"
assert len(IDs_both) == len(targetids_test_no_DA) == len(targetids_test_DA), (
    f"Expected every TARGETID to match: {len(IDs_both)} matched, "
    f"{len(targetids_test_no_DA)} in no-DA test, {len(targetids_test_DA)} in DA test "
    "(duplicated TARGETIDs?)"
)

In [None]:
# --- External J-PAS classifications (CBM and TRANS) cross-matched with DESI DR1 ---
# The catalogue path can be overridden with the JPAS_IGNASI_CATALOG environment
# variable; the default is the original author's local copy.
path_JPAS_Ignasi = os.environ.get(
    "JPAS_IGNASI_CATALOG",
    "/home/dlopez/Documentos/0.profesional/Postdoc/USP/Projects/JPAS_Domain_Adaptation/DATA/jpas_idr_classification_xmatch_desi_dr1.fits.gz",
)
JPAS_Ignasi = fitsio.read(path_JPAS_Ignasi)

# Rows with a DESI DR1 counterpart; every per-object array below is restricted to them.
mask_in_desi_dr1 = JPAS_Ignasi["is_in_desi_dr1"]

# True SPECTYPE as plain strings; strip() removes any fixed-width padding so that the
# comparisons with "QSO" etc. below are exact.
yy_true_Ignasi = [
    str(spectype).strip()
    for spectype in np.array(JPAS_Ignasi['SPECTYPE'][mask_in_desi_dr1]).astype(np.str_)
]
REDSHIFT = np.array(JPAS_Ignasi['z'][mask_in_desi_dr1])

# Split between High and Low redshift quasars at z = z_lim_QSO_cut.
z_lim_QSO_cut = 2.1
yy_true_Ignasi = [
    ("QSO_low" if z < z_lim_QSO_cut else "QSO_high") if spectype == "QSO" else spectype
    for spectype, z in zip(yy_true_Ignasi, REDSHIFT)
]

# (optional): if manually_select_one_SPECTYPE_vs_rest is specified, collapse every
# other SPECTYPE into "no_<selected>".
if manually_select_one_SPECTYPE_vs_rest is not None:
    logging.info("├── Restricting the surveys SPECTYPE to " + str(manually_select_one_SPECTYPE_vs_rest))
    rest_label = "no_" + manually_select_one_SPECTYPE_vs_rest
    yy_true_Ignasi = [
        spectype if spectype == manually_select_one_SPECTYPE_vs_rest else rest_label
        for spectype in yy_true_Ignasi
    ]

# NOTE(review): the comparisons below assume encode_strings_to_integers assigns integers
# in the same order as the probability columns (GALAXY, QSO_high, QSO_low, STAR, or
# selected / rest in the binary case) and as `class_names` — confirm.
yy_true_Ignasi, class_mapping = cleaning_tools.encode_strings_to_integers(yy_true_Ignasi)

# Probability columns of each external classifier, one per class.
classification_keys = {
    "TRANS": ['conf_gal_TRANS', 'conf_hqso_TRANS', 'conf_lqso_TRANS', 'conf_star_TRANS'],
    "CBM": ['conf_gal_CBM', 'conf_hqso_CBM', 'conf_lqso_CBM', 'conf_star_CBM']
}
# Column-name prefix of each SPECTYPE (the suffix is the classifier name).
name_map = {
    "GALAXY": "conf_gal_",
    "QSO_high": "conf_hqso_",
    "QSO_low": "conf_lqso_",
    "STAR": "conf_star_"
}

yy_pred_P = {}
for model_key, feature_keys in classification_keys.items():
    # Shape [N_samples, N_classes]: one column per entry of feature_keys.
    model_probs = np.array([JPAS_Ignasi[key][mask_in_desi_dr1] for key in feature_keys]).T

    if manually_select_one_SPECTYPE_vs_rest is None:
        # Standard multiclass case
        yy_pred_P[model_key] = model_probs
    else:
        # Binary case: column 0 = P(selected class), column 1 = summed P(all others).
        idx_positive = feature_keys.index(name_map[manually_select_one_SPECTYPE_vs_rest] + model_key)
        idx_rest = np.delete(np.arange(len(feature_keys)), idx_positive)
        yy_pred_P[model_key] = np.stack(
            [model_probs[:, idx_positive], model_probs[:, idx_rest].sum(axis=1)], axis=1
        )

assert yy_pred_P["TRANS"].shape == yy_pred_P["CBM"].shape, "TRANS and CBM probabilities differ in shape"
assert yy_pred_P["CBM"].shape[0] == len(yy_true_Ignasi), "Probabilities and labels differ in length"

# Match our J-PAS test objects to the external catalogue by TARGETID.
targetids_Ignasi_dr1 = np.array(JPAS_Ignasi["TARGETID"][mask_in_desi_dr1])
IDs_only_1, IDs_only_2, IDs_both, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
    RESULTS_no_DA[ii_model]["test_JPAS_matched"]['TARGETID'],
    targetids_Ignasi_dr1,
)
idxs_both_me = np.concatenate(idxs_both_1)
idxs_both_Ignasi = np.concatenate(idxs_both_2)

# External-catalogue labels and probabilities restricted to the matched objects.
yy_true_Ignasi_crossmatch = yy_true_Ignasi[idxs_both_Ignasi]
yy_pred_P_Ignasi_crossmatch_CBM = yy_pred_P["CBM"][idxs_both_Ignasi]
yy_pred_P_Ignasi_crossmatch_TRANS = yy_pred_P["TRANS"][idxs_both_Ignasi]

In [None]:
# Confusion matrices of the external classifiers on the objects matched to our J-PAS
# test split; `confusion_matrix` ends up holding the TRANS result.
for _method_title, _method_probs in (
    ("CBM", yy_pred_P_Ignasi_crossmatch_CBM),
    ("TRANS", yy_pred_P_Ignasi_crossmatch_TRANS),
):
    confusion_matrix = evaluation_tools.plot_confusion_matrix(
        yy_true_Ignasi_crossmatch,
        _method_probs,
        class_names=class_names,
        cmap=plt.cm.RdYlGn,
        title=_method_title,
    )

In [None]:
# Confusion matrix of every no-DA model on every returned split.
for model_idx, model_outputs in RESULTS_no_DA.items():
    for key, result in model_outputs.items():
        pretty_split = key.replace('_', ' ')
        evaluation_tools.plot_confusion_matrix(
            result["true"],
            result["prob"],
            class_names=class_names,
            cmap=plt.cm.RdYlGn,
            title=f"{pretty_split} (no-DA Model {model_idx})",
        )

In [None]:
# Confusion matrix of every DA model on the J-PAS matched test split.
for model_idx, model_outputs in RESULTS_DA.items():
    test_split = model_outputs["test_JPAS_matched"]
    evaluation_tools.plot_confusion_matrix(
        test_split["true"],
        test_split["prob"],
        class_names=class_names,
        cmap=plt.cm.RdYlGn,
        title=f"test JPAS matched (DA Model {model_idx})",
    )

In [None]:
# F1 radar plot: per-class F1 of the no-DA / DA models (thin = individual models,
# thick = mean over models) and of the external CBM / TRANS classifiers.
from sklearn.metrics import f1_score
from matplotlib.lines import Line2D

# === CONFIG ===
class_count = len(class_names)
# Class ids scored by f1_score. Passing them explicitly guarantees one score per class
# (0 for a class absent from both truth and predictions). Otherwise f1_score(average=None)
# only scores the labels it sees, and the curve no longer lines up with `angles`.
class_ids = np.arange(class_count)
angles = np.linspace(0, 2 * np.pi, class_count, endpoint=False).tolist()
angles += angles[:1]  # repeat the first angle to close the polygon
radius_box = 1.05
text_fontsize = 14
linewidth_model = 1
linewidth_mean = 3
tick_labelsize = 18
radial_labelsize = 14
legend_fontsize = 16
title_fontsize = 20
title_pad = 50
figsize = (10, 10)


def _per_class_f1(yy_true, yy_pred_labels):
    """Per-class F1 for class ids 0..class_count-1 (0 where it is undefined)."""
    return f1_score(yy_true, yy_pred_labels, labels=class_ids, average=None, zero_division=0)


def _closed(values):
    """Return `values` as a list with its first element appended, closing a radar curve."""
    return np.asarray(values).tolist() + [values[0]]


# === Prepare plot ===
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
ax.set_thetagrids(np.degrees(angles[:-1]), class_names, fontsize=tick_labelsize)
ax.set_ylim(0, 1)
ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels([str(y) for y in [0.2, 0.4, 0.6, 0.8, 1.0]], fontsize=radial_labelsize)

# === Style mapping ===
plot_config = [
    {
        'label': 'no-DA (DESI)',
        'results': RESULTS_no_DA,
        'models': list(RESULTS_no_DA.keys()),
        'dataset': 'val_DESI_only',
        'color': 'crimson',
        'linestyle': '--'
    },
    {
        'label': 'no-DA (JPAS)',
        'results': RESULTS_no_DA,
        'models': list(RESULTS_no_DA.keys()),
        'dataset': 'test_JPAS_matched',
        'color': 'crimson',
        'linestyle': '-'
    },
    {
        'label': 'DA (JPAS)',
        'results': RESULTS_DA,
        'models': list(RESULTS_DA.keys()),
        'dataset': 'test_JPAS_matched',
        'color': 'limegreen',
        'linestyle': '-'
    },
    {
        'label': 'DA (JPAS train)',
        'results': RESULTS_DA,
        'models': list(RESULTS_DA.keys()),
        'dataset': 'train_JPAS_matched',
        'color': 'limegreen',
        'linestyle': '--'
    }
]

legend_handles = []

# === Plot grouped models (RESULTS_no_DA, RESULTS_DA) ===
text_box_angles = np.linspace(0, 2 * np.pi, len(plot_config), endpoint=False) + np.pi / 6
for i, cfg in enumerate(plot_config):
    f1_scores_all = []
    for model_idx in cfg['models']:
        data = cfg['results'][model_idx][cfg['dataset']]
        f1 = _per_class_f1(data['true'], data['label'])
        f1_scores_all.append(f1)
        ax.plot(angles, _closed(f1), color=cfg['color'], linestyle=cfg['linestyle'], linewidth=linewidth_model)

    # Thick curve: per-class F1 averaged over the models of this configuration.
    f1_mean = np.mean(np.stack(f1_scores_all), axis=0)
    ax.plot(angles, _closed(f1_mean), color=cfg['color'], linestyle=cfg['linestyle'], linewidth=linewidth_mean)

    # Macro F1 = unweighted mean over classes of the averaged curve.
    macro_f1 = np.mean(f1_mean)
    text_obj = ax.text(
        text_box_angles[i], radius_box,
        f"{cfg['label']}\nF1={macro_f1:.2f}",
        color=cfg['color'],
        fontsize=text_fontsize,
        ha="center", va="center",
        bbox=dict(facecolor='white', edgecolor=cfg['color'], boxstyle='round,pad=0.4', lw=1.5)
    )
    text_obj.get_bbox_patch().set_linestyle(cfg['linestyle'])

    legend_handles.append(Line2D([0], [0], color=cfg['color'], linestyle=cfg['linestyle'], lw=2, label=cfg['label']))

# === Plot CBM and TRANS (external classifiers, cross-matched J-PAS objects) ===
f1_cbm = _per_class_f1(yy_true_Ignasi_crossmatch, np.argmax(yy_pred_P_Ignasi_crossmatch_CBM, axis=1))
f1_trans = _per_class_f1(yy_true_Ignasi_crossmatch, np.argmax(yy_pred_P_Ignasi_crossmatch_TRANS, axis=1))
macro_cbm = np.mean(f1_cbm)
macro_trans = np.mean(f1_trans)

for ext_name, ext_f1, ext_macro, ext_color, ext_text_angle in (
    ("CBM", f1_cbm, macro_cbm, 'purple', np.pi / 2 - 0.3),
    ("TRANS", f1_trans, macro_trans, 'gold', np.pi / 2 + np.pi / 3),
):
    ax.plot(angles, _closed(ext_f1), color=ext_color, linestyle='-', linewidth=linewidth_mean)
    ax.text(
        ext_text_angle, radius_box,
        f"{ext_name}\nF1={ext_macro:.2f}",
        color=ext_color,
        fontsize=text_fontsize,
        ha="center", va="center",
        bbox=dict(facecolor='white', edgecolor=ext_color, boxstyle='round,pad=0.4', lw=1.5)
    )
    legend_handles.append(Line2D([0], [0], color=ext_color, linestyle='-', lw=2, label=ext_name))

# === Add legend ===
ax.legend(handles=legend_handles, bbox_to_anchor=(0.85, 1.15),
          ncol=3, fontsize=legend_fontsize - 2, title_fontsize=legend_fontsize, fancybox=True, shadow=True)

plt.tight_layout()
plt.show()

In [None]:
# For every no-DA model, compare its performance on the DESI validation split
# ("DESI-mocks") with the J-PAS matched test split ("JPAS-obs").
# Iterate over the dict keys directly, as the other cells do. An enumerate() counter
# used as a key is only correct when the keys are exactly 0..N-1, and it also
# overwrote the notebook-level `ii_model`.
# `metrics` holds the last model's result; all results are kept per model key.
metrics_no_DA_DESI_vs_JPAS = {}
for model_idx, model_outputs in RESULTS_no_DA.items():
    desi_val = model_outputs['val_DESI_only']
    jpas_test = model_outputs['test_JPAS_matched']

    evaluation_tools.compare_TPR_confusion_matrices(
        desi_val['true'],
        desi_val['prob'],
        jpas_test['true'],
        jpas_test['prob'],
        class_names=class_names,
        figsize=(10, 7),
        cmap='seismic',
        title='no-DA model ' + str(model_idx) + ': JPAS test VS DESI-mocks test',
        name_1="DESI-mocks",
        name_2="JPAS-obs",
    )

    metrics = evaluation_tools.compare_sets_performance(
        desi_val['true'], desi_val['prob'],
        jpas_test['true'], jpas_test['prob'],
        class_names=class_names,
        name_1="DESI-mocks",
        name_2="JPAS-obs"
    )
    metrics_no_DA_DESI_vs_JPAS[model_idx] = metrics

In [None]:
ii_model = 0

# Same J-PAS matched test split, scored by the no-DA model and by the DA model.
test_split_no_DA = RESULTS_no_DA[ii_model]['test_JPAS_matched']
test_split_DA = RESULTS_DA[ii_model]['test_JPAS_matched']

evaluation_tools.compare_TPR_confusion_matrices(
    test_split_no_DA['true'],
    test_split_no_DA['prob'],
    test_split_DA['true'],
    test_split_DA['prob'],
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='JPAS test: no-DA VS DA',
    name_1="no-DA",
    name_2="DA",
)

metrics = evaluation_tools.compare_sets_performance(
    test_split_no_DA['true'], test_split_no_DA['prob'],
    test_split_DA['true'], test_split_DA['prob'],
    class_names=class_names,
    name_1="no-DA",
    name_2="DA",
)

In [None]:
# Magnitude binning used by the per-bin evaluations below.
# NOTE(review): magnitude_key names a flux column; the flux-to-magnitude conversion is
# presumably done inside add_magnitude_bins_to_results — confirm.
magnitude_key = "DESI_FLUX_R"
mag_bin_edges = (17, 19, 21, 22, 22.5)
output_key = "MAG_BIN_ID"

# Consecutive (low, high) edge pairs, one per bin.
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))
# One line colour and one colormap per bin, indexed by bin id.
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]
assert len(colors) >= len(magnitude_ranges) and len(colormaps) >= len(magnitude_ranges), (
    f"Need one color and one colormap per magnitude bin: {len(magnitude_ranges)} bins, "
    f"{len(colors)} colors, {len(colormaps)} colormaps"
)

In [None]:
# Attach a magnitude-bin id (under `output_key`) to the no-DA results; the per-bin
# cells below read it back from each split as result["MAG_BIN_ID"].
RESULTS_no_DA = evaluation_tools.add_magnitude_bins_to_results(
    RESULTS_no_DA,
    output_key=output_key,
    mag_bin_edges=mag_bin_edges,
    magnitude_key=magnitude_key,
)

In [None]:
# Attach a magnitude-bin id (under `output_key`) to the DA results; the per-bin
# cells below read it back from each split as result["MAG_BIN_ID"].
RESULTS_DA = evaluation_tools.add_magnitude_bins_to_results(
    RESULTS_DA,
    output_key=output_key,
    mag_bin_edges=mag_bin_edges,
    magnitude_key=magnitude_key,
)

In [None]:
# Reload all modules (except those excluded with %aimport) once now, so edits to the
# JPAS_DA package made since the kernel started are picked up by the following cells.
%autoreload

In [None]:
def _flux_to_mag(flux):
    """Magnitude with zero point 22.5: m = 22.5 - 2.5 * log10(flux)."""
    return -2.5 * np.log10(flux) + 22.5


# Splits to histogram, keyed by (survey, split name).
_splits_for_histogram = {
    ("DESI", "val_DESI_only"): RESULTS_no_DA[ii_model]['val_DESI_only'],
    ("JPAS", "train_JPAS_matched"): RESULTS_DA[ii_model]['train_JPAS_matched'],
    ("JPAS", "test_JPAS_matched"): RESULTS_DA[ii_model]['test_JPAS_matched'],
}
magnitudes_plot = {
    split_key: _flux_to_mag(split['DESI_FLUX_R'])
    for split_key, split in _splits_for_histogram.items()
}
labels_plot = {
    split_key: split['true']
    for split_key, split in _splits_for_histogram.items()
}

masks_all, stats_all = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_plot,
    ranges=magnitude_ranges,
    colors=colors,
    bins=42,
    x_label="DESI Magnitude (R)",
    title="DESI R-band Magnitudes",
    labels_dict=labels_plot,
    class_names=class_names,
    pct_decimals=0,
    annotate_mode='text',
    legend_fontsize=12,
)

In [None]:
# Human-readable label of each magnitude bin, e.g. "17–19".
bin_labels = [f"{edge_lo}–{edge_hi}" for edge_lo, edge_hi in magnitude_ranges]
num_bins = len(bin_labels)

# Confusion matrix of every no-DA model, on every split, in every non-empty magnitude bin.
for model_idx, model_outputs in RESULTS_no_DA.items():
    for key, result in model_outputs.items():
        bin_ids_of_objects = result["MAG_BIN_ID"]
        true_labels = result["true"]
        probabilities = result["prob"]
        pretty_split = key.replace('_', ' ')

        for bin_id in range(num_bins):
            in_bin = bin_ids_of_objects == bin_id
            if not np.any(in_bin):
                continue  # nothing to show for an empty bin

            evaluation_tools.plot_confusion_matrix(
                true_labels[in_bin],
                probabilities[in_bin],
                class_names=class_names,
                cmap=colormaps[bin_id],
                title=f"{pretty_split} (no-DA {model_idx}) | Mag {bin_labels[bin_id]}",
            )

In [None]:
# Per-magnitude-bin confusion matrices of each DA model on the J-PAS matched test split,
# followed by the external CBM and TRANS classifiers on the same objects.
# Per-bin cross-match results use bin-local names. This keeps the full-sample globals
# (yy_true_Ignasi_crossmatch, yy_pred_P_Ignasi_crossmatch_CBM/_TRANS), which the
# overall confusion-matrix and F1 radar cells read, from being replaced by the last
# bin's subset.
targetids_Ignasi_dr1 = np.array(JPAS_Ignasi["TARGETID"][JPAS_Ignasi["is_in_desi_dr1"]])

for model_idx, model_outputs in RESULTS_DA.items():
    result = model_outputs["test_JPAS_matched"]
    mag_bins = result["MAG_BIN_ID"]
    yy_true_all = result["true"]
    yy_pred_all = result["prob"]
    TARGETIDs = result["TARGETID"]

    for bin_id in range(num_bins):
        mask = mag_bins == bin_id
        if np.sum(mask) == 0:
            continue  # Skip empty bins
        bin_title = f"Mag {bin_labels[bin_id]}"

        evaluation_tools.plot_confusion_matrix(
            yy_true_all[mask],
            yy_pred_all[mask],
            class_names=class_names,
            cmap=colormaps[bin_id],
            title=f"test JPAS matched (DA {model_idx}) | {bin_title}"
        )

        # Only the last output (indices into the external catalogue) is needed here.
        *_, idxs_both_2_bin = crossmatch_tools.crossmatch_IDs_two_datasets(
            TARGETIDs[mask], targetids_Ignasi_dr1
        )
        # np.concatenate raises on an empty list, so handle "no match" explicitly.
        idxs_bin_Ignasi = (
            np.concatenate(idxs_both_2_bin) if len(idxs_both_2_bin) > 0 else np.array([], dtype=int)
        )
        if idxs_bin_Ignasi.size == 0:
            logging.info(f"No external-catalogue match for DA {model_idx} in {bin_title}; skipping CBM/TRANS")
            continue

        yy_true_Ignasi_bin = yy_true_Ignasi[idxs_bin_Ignasi]
        for method_key in ("CBM", "TRANS"):
            evaluation_tools.plot_confusion_matrix(
                yy_true_Ignasi_bin,
                yy_pred_P[method_key][idxs_bin_Ignasi],
                class_names=class_names,
                cmap=colormaps[bin_id],
                title=f"{method_key} | {bin_title}"
            )

In [None]:
# General config (reused from existing context)
class_count = len(class_names)
angles = np.linspace(0, 2 * np.pi, class_count, endpoint=False).tolist()
angles += angles[:1]
radius_box = 1.05
text_fontsize = 14
linewidth_model = 1
linewidth_mean = 3
tick_labelsize = 14
radial_labelsize = 12
legend_fontsize = 12
title_fontsize = 14
title_pad = 30
num_bins = len(bin_labels)
n_cols = 1
n_rows = int(np.ceil(num_bins / n_cols))
figsize = (n_cols * 5, n_rows * 5)

fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, subplot_kw=dict(polar=True))
axs = axs.flatten()

for bin_id in range(num_bins):
    ax = axs[bin_id]
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), class_names, fontsize=tick_labelsize)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels([str(y) for y in [0.2, 0.4, 0.6, 0.8, 1.0]], fontsize=radial_labelsize)
    ax.set_title(f"Mag {bin_labels[bin_id]}", fontsize=title_fontsize, pad=title_pad)

    # Plot configs
    plot_config = [
        {
            'label': 'no-DA (DESI)',
            'results': RESULTS_no_DA,
            'dataset': 'val_DESI_only',
            'color': 'crimson',
            'linestyle': '--'
        },
        {
            'label': 'no-DA (JPAS)',
            'results': RESULTS_no_DA,
            'dataset': 'test_JPAS_matched',
            'color': 'crimson',
            'linestyle': '-'
        },
        {
            'label': 'DA (JPAS)',
            'results': RESULTS_DA,
            'dataset': 'test_JPAS_matched',
            'color': 'limegreen',
            'linestyle': '-'
        },
        {
            'label': 'DA (JPAS train)',
            'results': RESULTS_DA,
            'dataset': 'train_JPAS_matched',
            'color': 'limegreen',
            'linestyle': '--'
        }
    ]

    # Loop over no-DA and DA configs
    for cfg in plot_config:
        f1_scores_all = []
        for model_idx in cfg['results'].keys():
            data = cfg['results'][model_idx][cfg['dataset']]
            mask = data['MAG_BIN_ID'] == bin_id
            if np.sum(mask) == 0:
                continue
            yy_true = data['true'][mask]
            yy_pred = data['label'][mask]
            f1 = f1_score(yy_true, yy_pred, average=None, zero_division=0)
            f1_scores_all.append(f1)
            f1_plot = f1.tolist() + [f1[0]]
            ax.plot(angles, f1_plot, color=cfg['color'], linestyle=cfg['linestyle'], linewidth=linewidth_model)

        if f1_scores_all:
            f1_mean = np.mean(np.stack(f1_scores_all), axis=0)
            f1_mean_plot = f1_mean.tolist() + [f1_mean[0]]
            ax.plot(angles, f1_mean_plot, color=cfg['color'], linestyle=cfg['linestyle'], linewidth=linewidth_mean)

    # CBM + TRANS crossmatch in this bin
    test_data = RESULTS_DA[0]["test_JPAS_matched"]
    mask_test = test_data["MAG_BIN_ID"] == bin_id
    TARGETIDs = test_data["TARGETID"][mask_test]
    if np.sum(mask_test) > 0:
        yy_true_bin = test_data["true"][mask_test]

        IDs_only_1, IDs_only_2, IDs_both, idxs_only_1, idxs_only_2, idxs_both_1, idxs_both_2 = crossmatch_tools.crossmatch_IDs_two_datasets(
            TARGETIDs, np.array(JPAS_Ignasi["TARGETID"][JPAS_Ignasi["is_in_desi_dr1"]])
        )
        idxs_both_Ignasi = np.concatenate(idxs_both_2)
        if len(idxs_both_Ignasi) > 0:
            yy_true_Ignasi_bin = yy_true_Ignasi[idxs_both_Ignasi]
            f1_cbm = f1_score(yy_true_Ignasi_bin, np.argmax(yy_pred_P["CBM"][idxs_both_Ignasi], axis=1), average=None, zero_division=0)
            f1_trans = f1_score(yy_true_Ignasi_bin, np.argmax(yy_pred_P["TRANS"][idxs_both_Ignasi], axis=1), average=None, zero_division=0)
            f1_cbm_plot = f1_cbm.tolist() + [f1_cbm[0]]
            f1_trans_plot = f1_trans.tolist() + [f1_trans[0]]
            ax.plot(angles, f1_cbm_plot, color='purple', linestyle='-', linewidth=linewidth_mean)
            ax.plot(angles, f1_trans_plot, color='gold', linestyle='-', linewidth=linewidth_mean)

# Remove empty subplots
for i in range(num_bins, len(axs)):
    fig.delaxes(axs[i])

fig.tight_layout()
plt.show()
