In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import wandb_evaluation_tools

import os
import torch
import numpy as np
from sklearn.manifold import TSNE

# Ask cuDNN for deterministic kernels and switch off its autotuner so that
# repeated runs of this notebook produce the same numbers.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Start from matplotlib's stock style and drop figures left over from a previous run.
plt.style.use('default')
plt.close('all')
# Apply the project-wide font and rcParams returned by plotting_utils.
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# Applied last, so any rcParams this style defines override the values set above.
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Sub-directory names, joined onto global_setup.path_models below, that hold
# the sweep runs of the three training schemes compared in this notebook.
# "direct" is the scheme labelled "Fully_Supervised" in the plots; "no_DA" / "DA"
# presumably mean without / with domain adaptation — confirm against the training configs.
sweeps_direct_path_name = "sweeps_direct"
sweeps_no_DA_path_name = "sweeps_no_DA"
sweeps_DA_path_name = "sweeps_DA"

In [None]:
# Resolve the three sweep directories under the shared model root.
path_wandb_sweep_direct = os.path.join(global_setup.path_models, sweeps_direct_path_name)
path_wandb_sweep_no_DA = os.path.join(global_setup.path_models, sweeps_no_DA_path_name)
path_wandb_sweep_DA = os.path.join(global_setup.path_models, sweeps_DA_path_name)

# Number of runs handed to the plotting helper (max_runs_to_plot) for each sweep.
N_selected_sweeps_direct = 1
N_selected_sweeps_no_DA = 1
N_selected_sweeps_DA = 2

# Each call plots the sweep and returns (run names, losses). Entry [0] of the
# names is the run loaded further down — presumably the best one; confirm the
# sort order in wandb_evaluation_tools.
(sorted_list_sweep_names_direct,
 sorted_losses_direct) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_direct, max_runs_to_plot=N_selected_sweeps_direct
)
(sorted_list_sweep_names_no_DA,
 sorted_losses_no_DA) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_no_DA, max_runs_to_plot=N_selected_sweeps_no_DA
)
(sorted_list_sweep_names_DA,
 sorted_losses_DA) = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_DA, max_runs_to_plot=N_selected_sweeps_DA
)

# Load the config files

In [None]:
def load_validated_configs(paths_load):
    """Load the massaged config of every run and check they share one data config.

    Parameters
    ----------
    paths_load : list of str
        Run directories; each must contain a ``config.yaml``.

    Returns
    -------
    configs : list
        One massaged config per entry of ``paths_load``, in the same order.
    config_ref
        ``configs[0]``, the reference every other run is compared against.

    Raises
    ------
    ValueError
        If the ``'data'`` section of any run differs from the reference.
    """
    logging.info("🔍 Validating model configs...")
    configs = []
    for path in paths_load:
        _, config = wrapper_tools.load_and_massage_config_file(
            os.path.join(path, "config.yaml"), path
        )
        configs.append(config)

    config_ref = configs[0]
    # Compare each run against the reference of its OWN sweep. The original
    # DA check looped over the no-DA configs, so it never inspected the DA runs.
    for i, cfg in enumerate(configs[1:], 1):
        logging.debug(f"🔍 Comparing config 0 and config {i}")
        if not evaluation_tools.safe_compare(cfg['data'], config_ref['data']):
            raise ValueError(f"🚫 Data config mismatch between model 0 and model {i}")
    return configs, config_ref


# Only the first run of each sweep is loaded, so each list has a single entry
# and the consistency loop is a no-op until more runs are added to a list.
paths_load_direct = [os.path.join(global_setup.path_models, sweeps_direct_path_name, sorted_list_sweep_names_direct[0])]
configs_direct, config_ref_direct = load_validated_configs(paths_load_direct)

paths_load_no_DA = [os.path.join(global_setup.path_models, sweeps_no_DA_path_name, sorted_list_sweep_names_no_DA[0])]
configs_no_DA, config_ref_no_DA = load_validated_configs(paths_load_no_DA)

paths_load_DA = [os.path.join(global_setup.path_models, sweeps_DA_path_name, sorted_list_sweep_names_DA[0])]
configs_DA, config_ref_DA = load_validated_configs(paths_load_DA)

# The DA run's data config drives the data loading and cleaning for all three models.
config_data = config_ref_DA["data"]

# Load the data

In [None]:
# Options for the data-loading cell below. Everything comes either from
# global_setup (project-wide defaults) or from config_data (the DA run's data config).
root_path = global_setup.DATA_path
load_JPAS_x_DESI_Raul   = global_setup.load_JPAS_x_DESI_Raul
load_DESI_mocks_Raul    = global_setup.load_DESI_mocks_Raul
load_Ignasi             = global_setup.load_Ignasi

random_seed_load = global_setup.default_seed

# Names of the three datasets requested from loading_tools.load_data_bundle.
list_of_datasets_to_load = ["JPAS_x_DESI_Raul", "DESI_mocks_Raul", "Ignasi"]

# Cleaning options taken from the loaded run config (not from global_setup).
config_dict_cleaning = config_data['cleaning_config']

dict_split_data_options = global_setup.dict_split_data_options

# Input feature keys follow the run config. The target keys are fixed here:
# SPECTYPE_int is used as the class label in the metrics cells and TARGETID
# for the crossmatch with Ignasi further down.
keys_xx = config_data['keys_xx']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

# All tensors and models in this notebook are placed on the CPU.
device = 'cpu'

In [None]:
# Load the three datasets, then clean them with the run's cleaning config.
DATA = loading_tools.load_data_bundle(
    root_path=root_path,
    include=list_of_datasets_to_load,
    JPAS_x_DESI_Raul={"datasets": load_JPAS_x_DESI_Raul},
    DESI_mocks_Raul={"datasets": load_DESI_mocks_Raul},
    Ignasi={"datasets": load_Ignasi},
    random_seed=random_seed_load,
)
DATA = cleaning_tools.clean_data_pipeline(DATA, config=config_dict_cleaning, in_place=True)

# Crossmatch mocks (first dataset) against JPAS x DESI (second dataset) by TARGETID.
targetids_mocks = DATA["DESI_mocks_Raul"]['all_pd']['TARGETID']
targetids_jpas = DATA["JPAS_x_DESI_Raul"]['all_pd']['TARGETID']
(IDs1, IDs2, IDs12,
 LoA_outer_mocks, LoA_outer_jpas,
 LoA_inter_mocks, LoA_inter_jpas) = crossmatch_tools.crossmatch_IDs_two_datasets(
    targetids_mocks, targetids_jpas
)

# Lists of Arrays, keyed first by crossmatch region and then by dataset.
Dict_LoA = {
    "intersection": {},
    "outersection": {
        "DESI_mocks_Raul": LoA_outer_mocks,
        "JPAS_x_DESI_Raul": LoA_outer_jpas,
    },
}
Dict_LoA["intersection"]["DESI_mocks_Raul"] = LoA_inter_mocks
Dict_LoA["intersection"]["JPAS_x_DESI_Raul"] = LoA_inter_jpas

# Split ratios and seeds are configured separately for the two regions.
split_kwargs_intersection = dict(
    train_ratio=dict_split_data_options["train_ratio_intersection"],
    val_ratio=dict_split_data_options["val_ratio_intersection"],
    test_ratio=dict_split_data_options["test_ratio_intersection"],
    seed=dict_split_data_options["random_seed_split_intersection"],
)
split_kwargs_outersection = dict(
    train_ratio=dict_split_data_options["train_ratio_outersection"],
    val_ratio=dict_split_data_options["val_ratio_outersection"],
    test_ratio=dict_split_data_options["test_ratio_outersection"],
    seed=dict_split_data_options["random_seed_split_outersection"],
)

# Only the two (region, dataset) pairs used downstream are split:
# JPAS x DESI objects in the intersection, and mocks outside it.
Dict_LoA_split = {
    "intersection": {
        "JPAS_x_DESI_Raul": process_dset_splits.split_LoA(
            Dict_LoA["intersection"]["JPAS_x_DESI_Raul"], **split_kwargs_intersection
        ),
    },
    "outersection": {
        "DESI_mocks_Raul": process_dset_splits.split_LoA(
            Dict_LoA["outersection"]["DESI_mocks_Raul"], **split_kwargs_outersection
        ),
    },
}

In [None]:
# Build xx[dataset][split] (float32 feature tensors on `device`) and
# yy[dataset][split] (the target dict returned by the extraction helper).
extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]
xx = {}
yy = {}
for key_dset, key_xmatch in extract_dsets:
    xx[key_dset] = {}
    yy[key_dset] = {}
    for split in global_setup.splits:
        # NOTE(review): the [] default does not make a missing split safe —
        # np.concatenate raises on an empty list two lines below.
        LoA_ = Dict_LoA_split[key_xmatch][key_dset].get(split, [])
        _, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
            block=DATA[key_dset], LoA=LoA_, keys_xx=keys_xx, keys_yy=keys_yy
        )
        # Select every extracted row: indices 0..N-1, with N the total number of entries in LoA_.
        xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(len(np.concatenate(LoA_))))
        xx[key_dset][split] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
        yy[key_dset][split] = yy_


# Ignasi has no train/val/test split: every observation goes into a single "all" split.
key_dset = "Ignasi"
split = "all"
xx[key_dset] = {}
yy[key_dset] = {}
# One single-element group per observation row ([[0], [1], ...]).
# NOTE(review): unlike the loop above, LoA_ here is the helper's first return
# value rather than the LoA passed in — presumably equivalent; confirm.
LoA_, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
    block=DATA[key_dset],
    LoA=np.arange(DATA["Ignasi"]['all_observations'].shape[0])[:, None].tolist(),
    keys_xx=keys_xx,
    keys_yy=keys_yy
)
xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(len(np.concatenate(LoA_))))
xx[key_dset][split] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
yy[key_dset][split] = yy_

# Load the models

In [None]:
# part -> (checkpoint file name, label in the "not found" error, label in the log line)
_CHECKPOINT_PARTS = {
    "encoder": ("model_encoder.pt", "Encoder", "encoder"),
    "downstream": ("model_downstream.pt", "Downstream", "downstream model"),
}


def load_eval_model(run_dir, part, **loader_kwargs):
    """Load one checkpoint of a run, put the model in eval mode and move it to ``device``.

    Parameters
    ----------
    run_dir : str
        Run directory as stored in ``paths_load_*``. It already includes
        ``global_setup.path_models``, so it is NOT joined onto that root again
        (doing so duplicates the prefix whenever the root is a relative path).
    part : {"encoder", "downstream"}
        Which checkpoint of the run to load.
    **loader_kwargs
        Forwarded unchanged to ``save_load_tools.load_model_from_checkpoint``.

    Returns
    -------
    path_checkpoint : str
        Full path of the checkpoint file that was loaded.
    model
        The loaded model, in eval mode, on ``device``.

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    """
    filename, label_missing, label_log = _CHECKPOINT_PARTS[part]
    path_checkpoint = os.path.join(run_dir, filename)
    if not os.path.isfile(path_checkpoint):
        raise FileNotFoundError(f"❌ {label_missing} checkpoint not found: {path_checkpoint}")
    logging.info(f"📥 Loading {label_log} from checkpoint: {path_checkpoint}")
    _, model = save_load_tools.load_model_from_checkpoint(
        path_checkpoint, model_building_tools.create_mlp, **loader_kwargs
    )
    model.eval()
    model.to(device)
    return path_checkpoint, model


path_load_encoder_direct, model_encoder_direct = load_eval_model(paths_load_direct[0], "encoder")
path_load_downstream_direct, model_downstream_direct = load_eval_model(paths_load_direct[0], "downstream")

path_load_encoder_no_DA, model_encoder_no_DA = load_eval_model(paths_load_no_DA[0], "encoder")
path_load_downstream_no_DA, model_downstream_no_DA = load_eval_model(paths_load_no_DA[0], "downstream")

# NOTE(review): only the DA encoder is loaded with use_batchnorm=False; kept
# exactly as before — confirm it matches how that checkpoint was trained.
path_load_encoder_DA, model_encoder_DA = load_eval_model(paths_load_DA[0], "encoder", use_batchnorm=False)
path_load_downstream_DA, model_downstream_DA = load_eval_model(paths_load_DA[0], "downstream")

# Compute the results

In [None]:
extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]

# (encoder, downstream head) per training scheme. Dict order fixes the order
# in which the schemes are evaluated for every input batch.
models_per_case = {
    "direct": (model_encoder_direct, model_downstream_direct),
    "no-DA": (model_encoder_no_DA, model_downstream_no_DA),
    "DA": (model_encoder_DA, model_downstream_DA),
}

# Result containers, indexed as results[case][dataset][split].
features = {case: {} for case in models_per_case}
probs = {case: {} for case in models_per_case}
labels = {case: {} for case in models_per_case}

# The split datasets use global_setup.splits; Ignasi only has the single "all" split.
splits_per_dset = [(dset_name, list(global_setup.splits)) for dset_name, _ in extract_dsets]
splits_per_dset.append(("Ignasi", ["all"]))

for key_dset, dset_splits in splits_per_dset:

    for case in models_per_case:
        features[case][key_dset] = {}
        probs[case][key_dset] = {}
        labels[case][key_dset] = {}

    for split in dset_splits:
        xx_input = xx[key_dset][split]

        for case, (model_encoder_, model_downstream_) in models_per_case.items():
            # Forward pass without gradients: latent features -> logits -> class probabilities.
            with torch.no_grad():
                features_ = model_encoder_(xx_input)
                logits_ = model_downstream_(features_)
                probs_ = torch.nn.functional.softmax(logits_, dim=1).cpu().numpy()
            features[case][key_dset][split] = features_.cpu().numpy()
            probs[case][key_dset][split] = probs_
            # The predicted label is the most probable class.
            labels[case][key_dset][split] = np.argmax(probs_, axis=1)

# Metrics

In [None]:
# Class names are the keys of the shared SPECTYPE encoding, in mapping order.
# NOTE(review): this reads global_setup.config_dict_cleaning, while the data were
# cleaned with config_data['cleaning_config'] from the loaded run — confirm both
# define the same SPECTYPE mapping.
class_names = list(global_setup.config_dict_cleaning["encoding"]["shared_mappings"]["SPECTYPE"].keys())

In [None]:
# One entry per curve of the F1 radar plot: true labels, predicted class
# probabilities, and the line styling passed through as "plot_kwargs".
dict_radar = {
    # The fully supervised ("direct") model is evaluated on the JPAS x DESI test
    # split, as its label says and as every other "direct" evaluation in this
    # notebook does. It previously pointed at the DESI mocks test split.
    "JPAS Obs. Fully_Supervised": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["direct"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "-", "linewidth": 1.0, "color": "k",
            "marker": "s", "markersize": 10.0, "fill_alpha": 0.01,
            "label": "JPAS Obs. Fully_Supervised"
        }
    },
    "Mocks no-DA": {
        "y_true": yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["DESI_mocks_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "--", "linewidth": 2.0, "color": "royalblue",
            "marker": "^", "markersize": 10.0, "fill_alpha": 0.05,
            "label": "Mocks no-DA"
        }
    },
    "JPAS Obs. no-DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "-.", "linewidth": 2.0, "color": "crimson",
            "marker": "X", "markersize": 10.0, "fill_alpha": 0.05,
            "label": "JPAS Obs. no-DA"
        }
    },
    # Train-split curve, shown next to the test-split DA curve below.
    "JPAS Obs. (Train) DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["train"]["SPECTYPE_int"],
        "y_pred": probs["DA"]["JPAS_x_DESI_Raul"]["train"],
        "plot_kwargs": {
            "linestyle": ":", "linewidth": 2.0, "color": "orange",
            "marker": "+", "markersize": 10.0, "fill_alpha": 0.05,
            "label": "JPAS Obs. (Train) DA"
        }
    },
    "JPAS Obs. DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "-", "linewidth": 2.0, "color": "limegreen",
            "marker": "o", "markersize": 10.0, "fill_alpha": 0.05,
            "label": "JPAS Obs. DA"
        }
    },
}

# One confusion matrix per evaluation case: (dataset, split, model scheme, title).
confusion_cases = [
    ("JPAS_x_DESI_Raul", "test", "direct", "JPAS Obs. (Fully_Supervised)"),
    ("DESI_mocks_Raul", "test", "no-DA", "Mocks (no-DA)"),
    ("JPAS_x_DESI_Raul", "test", "no-DA", "JPAS Obs. (no-DA)"),
    ("JPAS_x_DESI_Raul", "train", "DA", "JPAS Obs.-Train (DA)"),
    ("JPAS_x_DESI_Raul", "test", "DA", "JPAS Obs. (DA)"),
]
for cm_dset, cm_split, cm_case, cm_title in confusion_cases:
    evaluation_tools.plot_confusion_matrix(
        yy[cm_dset][cm_split]["SPECTYPE_int"], probs[cm_case][cm_dset][cm_split],
        class_names=class_names,
        cmap=plt.cm.RdYlGn, title=cm_title
    )

# Per-class F1 radar plot for all the cases collected in dict_radar.
fig, ax = evaluation_tools.radar_plot(
    dict_radar=dict_radar, class_names=class_names,
    title="F1 Radar Plot", figsize=(8, 8), theta_offset=np.pi / 2, # first axis at 12 o'clock
    r_ticks=(0.1, 0.3, 0.5, 0.7, 0.9), r_lim=(0.0, 1.0),
    tick_labelsize=16, radial_labelsize=12, show_legend=True,
    legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.73, 1.0), "fontsize": 9, "ncol": 1,
        "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True, "borderaxespad": 0.0,
    },
)
fig.savefig(os.path.join(global_setup.path_saved_figures, "F1_radar.pdf"), bbox_inches='tight')
# A single show() is enough; the original called it twice in a row.
plt.show()

# Short aliases for the test-split labels and predictions used repeatedly below.
y_true_mocks_test = yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"]
y_true_jpas_test = yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"]
probs_mocks_no_DA = probs["no-DA"]["DESI_mocks_Raul"]["test"]
probs_jpas_no_DA = probs["no-DA"]["JPAS_x_DESI_Raul"]["test"]
probs_jpas_DA = probs["DA"]["JPAS_x_DESI_Raul"]["test"]
probs_jpas_direct = probs["direct"]["JPAS_x_DESI_Raul"]["test"]

# Side-by-side TPR confusion matrices on the JPAS x DESI test split: no-DA vs DA.
evaluation_tools.compare_TPR_confusion_matrices(
    y_true_jpas_test, probs_jpas_no_DA,
    y_true_jpas_test, probs_jpas_DA,
    class_names=class_names, figsize=(10, 7),
    cmap='seismic', title='TPR Comparison: DA vs no-DA', name_1 = "no-DA", name_2 = "DA"
)

# Pairwise comparisons; `metrics` keeps only the result of the last call.
delta_f1_kwargs = dict(class_names=class_names, y_max_Delta_F1=0.3, y_min_Delta_F1=-0.3, color='k')
metrics = evaluation_tools.compare_sets_performance(
    y_true_mocks_test, probs_mocks_no_DA,
    y_true_jpas_test, probs_jpas_no_DA,
    name_1="Mocks", name_2="no-DA: JPAS Obs.", title_fontsize=20, **delta_f1_kwargs
)
metrics = evaluation_tools.compare_sets_performance(
    y_true_jpas_test, probs_jpas_no_DA,
    y_true_jpas_test, probs_jpas_DA,
    name_1="no-DA JPAS Obs.", name_2="DA", **delta_f1_kwargs
)
metrics = evaluation_tools.compare_sets_performance(
    y_true_jpas_test, probs_jpas_direct,
    y_true_jpas_test, probs_jpas_DA,
    name_1="Fully_Supervised JPAS Obs.", name_2="DA", **delta_f1_kwargs
)

# Each tuple: (y_true_1, probs_1, y_true_2, probs_2, legend label).
comparisons = [
    (y_true_mocks_test, probs_mocks_no_DA, y_true_jpas_test, probs_jpas_no_DA, "JPAS Obs. vs Mocks no-DA"),
    (y_true_jpas_test, probs_jpas_direct, y_true_jpas_test, probs_jpas_DA, "DA vs Fully_Supervised JPAS Obs."),
    (y_true_jpas_test, probs_jpas_no_DA, y_true_jpas_test, probs_jpas_DA, "DA vs no-DA JPAS Obs."),
]
fig, ax, deltas = evaluation_tools.plot_overall_deltaF1_grouped(
    comparisons,
    class_names=class_names,
    colors=["crimson", "darkorange", "limegreen"],  # one colour per comparison; extend as needed
    title=None,
    figsize=(8, 6),
    legend_kwargs={"loc":"upper right", "frameon":True, "fontsize": 12},
    save_dir=global_setup.path_saved_figures, save_format="pdf", filename="delta_F1"
)

In [None]:
# Set-up for a 2x2 grid of row-normalised confusion matrices with one shared
# colorbar: font sizes, the four evaluation cases, and the figure layout.

# ---- Font & style config (tweak if desired) ----
FS_TITLE = 24
FS_LABEL = 20
FS_TICKS = 18
FS_CELL  = 14
FS_CELL_DIAG = 14
FS_CBAR_LABEL = 20
FS_CBAR_TICKS = 16
TICK_ROT = 20

# Cells whose row-normalised value exceeds this get white text, the rest black.
threshold_color = 0.5
cmap = plt.cm.RdYlGn

# ---- Cases: (title, y_true, y_pred_probs), drawn in row-major panel order ----
cases = [
    ("Mocks (no-DA)",
     yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"],
     probs["no-DA"]["DESI_mocks_Raul"]["test"]),

    ("JPAS Obs. (no-DA)",
     yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
     probs["no-DA"]["JPAS_x_DESI_Raul"]["test"]),

    ("JPAS Obs. (Fully_Supervised)",
     yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
     probs["direct"]["JPAS_x_DESI_Raul"]["test"]),

    ("JPAS Obs. (DA)",
     yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
     probs["DA"]["JPAS_x_DESI_Raul"]["test"]),
]

# Classes & color normalization (fixed 0..1 so all panels share one color scale)
n_classes = len(class_names)
norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0)

# ---- 2x2 subpanels + thin colorbar column (compact layout) ----
fig = plt.figure(figsize=(18, 16), constrained_layout=True, dpi=150)
gs = fig.add_gridspec(
    2, 3,
    width_ratios=[1, 1, 0.05],  # thin colorbar on the right
    height_ratios=[1, 1],
    wspace=0.02, hspace=0.04
)

# All four panels share the axes of the top-left one.
ax00 = fig.add_subplot(gs[0, 0])
ax01 = fig.add_subplot(gs[0, 1], sharex=ax00, sharey=ax00)
ax10 = fig.add_subplot(gs[1, 0], sharex=ax00, sharey=ax00)
ax11 = fig.add_subplot(gs[1, 1], sharex=ax00, sharey=ax00)
axes = np.array([[ax00, ax01], [ax10, ax11]])
# The colorbar axis spans both rows of the third (thin) column.
cax  = fig.add_subplot(gs[:, 2])

# (optional) nudge constrained_layout pads a bit tighter
fig.set_constrained_layout_pads(w_pad=0.01, h_pad=0.01, wspace=0.02, hspace=0.02)

# Shared limits/ticks (reverse y to keep row 0 at the top with origin='upper')
xlim = (-0.5, n_classes - 0.5)
ylim = (n_classes - 0.5, -0.5)
ticks = np.arange(n_classes)

# Shared colorbar mappable (not tied to any single panel's image)
mappable_for_cbar = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable_for_cbar.set_array([])

# ---- Draw panels ----
# For each case: build the confusion matrix from the arg-max predictions,
# show its row-normalised version as a heatmap, and annotate every cell.
for idx, (ax, (title, yy_true, yy_pred_P)) in enumerate(zip(axes.ravel(), cases)):
    r, c = divmod(idx, 2)  # row, col in the 2x2 grid

    yy_true = np.asarray(yy_true).astype(int)
    yy_pred = np.argmax(yy_pred_P, axis=1).astype(int)

    # Confusion matrix including all classes; pairs with an out-of-range
    # true or predicted label are ignored.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    valid = (
        (yy_true >= 0) & (yy_true < n_classes)
        & (yy_pred >= 0) & (yy_pred < n_classes)
    )
    # np.add.at accumulates repeated (true, pred) pairs, unlike cm[t, p] += 1 with fancy indexing.
    np.add.at(cm, (yy_true[valid], yy_pred[valid]), 1)

    # Row-normalized proportions. `out=` is required together with `where=`:
    # without it, rows with zero support are left as uninitialized memory,
    # which would leak into both the heatmap and the text colour below.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm, row_sums,
        out=np.zeros(cm.shape, dtype=float),
        where=row_sums != 0,
    )

    # Heatmap
    ax.imshow(cm_percent, interpolation='nearest', cmap=cmap, norm=norm, origin='upper')

    # Shared axes styling
    ax.set_xlim(xlim); ax.set_ylim(ylim)
    ax.set_xticks(ticks); ax.set_yticks(ticks)
    ax.set_xticklabels(class_names, fontsize=FS_TICKS)
    ax.set_yticklabels(class_names, fontsize=FS_TICKS)
    ax.set_aspect('equal', adjustable='box')
    ax.set_title(title, fontsize=FS_TITLE, pad=20)

    # Per-class metrics for diagonal annotation (0.0 wherever the denominator is zero)
    tp = np.diag(cm).astype(float)
    predicted_totals = cm.sum(axis=0)  # tp + fp per class
    true_totals = cm.sum(axis=1)       # tp + fn per class
    precision = np.divide(tp, predicted_totals, out=np.zeros(n_classes), where=predicted_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros(n_classes), where=true_totals > 0)
    pr_sum = precision + recall
    f1 = np.divide(2 * precision * recall, pr_sum, out=np.zeros(n_classes), where=pr_sum > 0)

    # Cell annotations (counts + row-%; diagonal shows TPR/PPV/F1)
    for i in range(n_classes):
        for j in range(n_classes):
            count   = cm[i, j]
            percent = cm_percent[i, j] * 100
            text_color = "white" if cm_percent[i, j] > threshold_color else "black"

            if i == j:
                text = (f"{count}\n"
                        f"TPR:{recall[i]*100:.1f}% "
                        f"\nPPV:{precision[i]*100:.1f}% "
                        f"\nF1:{f1[i]:.2f}")
                ax.text(j, i, text, ha="center", va="center",
                        color=text_color, fontsize=FS_CELL_DIAG, fontweight='bold', linespacing=1.2)
            else:
                text = f"{count}\n{percent:.1f}%"
                ax.text(j, i, text, ha="center", va="center",
                        color=text_color, fontsize=FS_CELL, linespacing=1.2)

    # --- Outer-edge labels only (compact layout) ---
    # Y-labels only on the left column
    if c == 0:
        ax.set_ylabel('True Label', fontsize=FS_LABEL, labelpad=6)
    else:
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_ylabel("")

    # X-labels only on the bottom row
    if r == 1:
        ax.set_xlabel('Predicted Label', fontsize=FS_LABEL, labelpad=6)
        plt.setp(ax.get_xticklabels(), rotation=TICK_ROT, ha="right", rotation_mode="anchor")
    else:
        plt.setp(ax.get_xticklabels(), visible=False)
        ax.set_xlabel("")

# ---- Shared colorbar on the RIGHT ----
cbar = fig.colorbar(mappable_for_cbar, cax=cax)
cbar.set_label("True-label (row) normalized ratio", fontsize=FS_CBAR_LABEL)
cbar.ax.tick_params(labelsize=FS_CBAR_TICKS)
cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])

# Save/show. Save through the figure object rather than pyplot's implicit
# "current figure", so the right figure is written no matter which one is active.
out_path = os.path.join(global_setup.path_saved_figures, "confusion_matrices.pdf")
fig.savefig(out_path, bbox_inches="tight")
plt.show()

# Latents

In [None]:
# Encoder outputs on the test splits to embed: no-DA encoder on source (mocks)
# and target (JPAS x DESI), plus the DA encoder on the same target split.
feat_dict = {
    "latents_no_DA_Source": features["no-DA"]["DESI_mocks_Raul"]["test"],
    "latents_no_DA_Target": features["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    "latents_DA_Target": features["DA"]["JPAS_x_DESI_Raul"]["test"]
}

# The result is read below with a "_tSNE" suffix appended to each input key.
# NOTE(review): whether the keys are embedded jointly or independently depends
# on evaluation_tools.tsne_per_key — confirm before comparing positions across keys.
latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)

In [None]:
# Common plot window for every t-SNE panel in this cell.
xlim = (-150, 150)
ylim = (-150, 150)

# Aliases for the embeddings and the matching test-split labels.
tsne_no_DA_source = latents_tSNE['latents_no_DA_Source_tSNE']
tsne_no_DA_target = latents_tSNE['latents_no_DA_Target_tSNE']
tsne_DA_target = latents_tSNE['latents_DA_Target_tSNE']
y_source_test = yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"]
y_target_test = yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"]

# Keyword arguments shared by all calls of each plotting helper.
overlay_kwargs = dict(
    class_names=class_names,
    marker_val="o", marker_test="^",
    alpha_val=0.7, alpha_test=0.7,
    xlim=xlim, ylim=ylim,
    seed=137,
    edgecolor=None, linewidths=0.0,
)
scatter_kwargs = dict(
    class_counts=None,
    class_names=class_names,
    n_bins=128, sigma=2.0,
    xlim=xlim, ylim=ylim,
)
density_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,    # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
    xlim=xlim, ylim=ylim,
)

# --- no-DA encoder: source (mocks) against target (JPAS x DESI obs.) ---
evaluation_tools.plot_latents_scatter_val_test(
    X_val=tsne_no_DA_source, y_val=y_source_test,
    X_test=tsne_no_DA_target, y_test=y_target_test,
    title="Latents no-DA: Source (Mocks) vs Target (JPAS x DESI obs.)",
    size_val=14, size_test=14,
    subsample=4000,
    legend_split_1="Source (Mocks) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) no-DA",
    **overlay_kwargs
)
evaluation_tools.plot_latents_scatter(
    tsne_no_DA_source, y_source_test,
    title="Latents no-DA: Source (Mocks)",
    scatter_size=0.003, scatter_alpha=0.3,
    **scatter_kwargs
)
evaluation_tools.plot_latent_density_2d(
    tsne_no_DA_source,
    title="Latents no-DA: Source (Mocks)",
    **density_kwargs
)
evaluation_tools.plot_latents_scatter(
    tsne_no_DA_target, y_target_test,
    title="Latents no-DA: Target (JPAS x DESI obs.)",
    scatter_size=0.1, scatter_alpha=0.5,
    **scatter_kwargs
)
evaluation_tools.plot_latent_density_2d(
    tsne_no_DA_target,
    title="Latents no-DA: Target (JPAS x DESI obs.)",
    **density_kwargs
)

# --- Target (JPAS x DESI obs.): no-DA encoder against DA encoder ---
evaluation_tools.plot_latents_scatter_val_test(
    X_val=tsne_no_DA_target, y_val=y_target_test,
    X_test=tsne_DA_target, y_test=y_target_test,
    title="Latents Target (JPAS x DESI obs.): no-DA vs DA",
    size_val=8, size_test=8,
    subsample=None,
    legend_split_1="Target (JPAS x DESI obs.) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) DA",
    **overlay_kwargs
)
evaluation_tools.plot_latents_scatter(
    tsne_DA_target, y_target_test,
    title="Latents DA: Target (JPAS x DESI obs.)",
    scatter_size=0.1, scatter_alpha=0.5,
    **scatter_kwargs
)
evaluation_tools.plot_latent_density_2d(
    tsne_DA_target,
    title="Latents DA: Target (JPAS x DESI obs.)",
    **density_kwargs
)

In [None]:
# DA-encoder features for the crossmatched JPAS x DESI test split and for the
# full Ignasi ("all JPAS") sample.
feat_dict = {
    "latents_DA_JPAS_x_DESI": features["DA"]["JPAS_x_DESI_Raul"]["test"],
    "latents_DA_all_JPAS": features["DA"]["Ignasi"]["all"],
}

# Request the same subsample size for both sets: the number of rows in the
# JPAS x DESI test split.
subsample = features["DA"]["JPAS_x_DESI_Raul"]["test"].shape[0]

# NOTE: this rebinds latents_tSNE, replacing the embeddings of the earlier t-SNE cell.
latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample={"latents_DA_JPAS_x_DESI": subsample, "latents_DA_all_JPAS": subsample},
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)

In [None]:
# Common plot window for the two density maps.
xlim = (-100, 100)
ylim = (-100, 100)

# Both maps are drawn with identical settings so they can be compared directly.
density_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,    # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
    xlim=xlim, ylim=ylim,
)

evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_JPAS_x_DESI_tSNE'],
    title="Latents JPAS x DESI",
    **density_kwargs
)
evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_all_JPAS_tSNE'],
    title="Latents all JPAS",
    **density_kwargs
)

In [None]:
# Crossmatch the JPAS x DESI test-split TARGETIDs against the Ignasi catalogue.
# By analogy with the first crossmatch call in this notebook, the returned tuple
# is presumably (IDs1, IDs2, common IDs, outersection 1, outersection 2,
# intersection 1, intersection 2) — confirm in crossmatch_tools.
_, _, IDs, _, _, LoA_intersection_JPAS_x_DESI_Raul_test, LoA_intersection_Ignasi_all = crossmatch_tools.crossmatch_IDs_two_datasets(
    yy['JPAS_x_DESI_Raul']['test']['TARGETID'], DATA['Ignasi']['all_pd']['TARGETID']
)
# Second pass: look up the common IDs in the full JPAS x DESI table, keeping only
# the sixth return value (the lookup into the first argument, i.e. DATA rows).
_, _, _, _, _, LoA_intersection_JPAS_x_DESI_Raul_DATA_test, _ = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA['JPAS_x_DESI_Raul']['all_pd']['TARGETID'], IDs
)

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# ───── Config ─────
# Keys into DATA for the two catalogues whose observations are compared.
survey_jpas   = "JPAS_x_DESI_Raul"
survey_ignasi = "Ignasi"
# Maximum number of cross-matched IDs drawn at random before the threshold filter is applied.
NN_plot = 1000
# Seeded generator used for the random draw of IDs.
rng = np.random.default_rng(0)
# When True the flux panel uses a logarithmic y-axis.
USE_LOG_Y = True

# Filtering params
RATIO_DEV_THR = 50.     # keep if max |ratio - 1| > this (ratio = log10(Ignasi)/log10(JPAS))
MIN_VALID_PTS = 1        # require at least this many finite ratio points

# ---- helpers to normalize index containers ----
def iter_indices(x):
    """Yield every index contained in ``x`` as a plain Python ``int``.

    Accepted inputs:
      * ``None``                      -> yields nothing
      * ``list`` / ``tuple``          -> one value per element, ``None`` elements skipped
      * ``np.ndarray`` of any rank    -> flattened first, so 0-d arrays yield one value
      * anything else (scalar)        -> a single value, ``int(x)``

    Raises whatever ``int()`` raises for an element that cannot be converted.
    """
    if x is None:
        return
    if isinstance(x, np.ndarray):
        # A 0-d array is an ndarray but cannot be iterated (TypeError);
        # ravel() turns it into a 1-element array and leaves 1-d input unchanged.
        x = x.ravel()
    if isinstance(x, (list, tuple, np.ndarray)):
        for v in x:
            if v is None:
                continue
            yield int(v)
    else:
        yield int(x)

def first_index(x):
    """Return the first index produced by ``iter_indices(x)``, or ``None`` when it yields nothing."""
    return next(iter_indices(x), None)

def compute_ratio_for_id(idx_ID):
    """Return log10(Ignasi) / log10(JPAS) per filter for one cross-matched ID.

    ``idx_ID`` is a position into the intersection lookups. Only the first
    occurrence of the ID in each survey is used. Returns ``None`` when either
    survey has no occurrence for this ID.

    Observations are floored at 1e-12 before taking log10. Entries where the
    JPAS log is exactly 0 are set to NaN in the denominator, so the ratio is
    NaN there instead of dividing by zero.
    """
    row_jpas = first_index(LoA_intersection_JPAS_x_DESI_Raul_DATA_test[idx_ID])
    row_ignasi = first_index(LoA_intersection_Ignasi_all[idx_ID])
    if row_jpas is None or row_ignasi is None:
        return None

    def _floored_log10(survey, row):
        # log10 of one observation vector, with values floored at 1e-12.
        observation = np.asarray(DATA[survey]['all_observations'][row], dtype=float)
        return np.log10(np.clip(observation, 1e-12, None))

    log_jpas = _floored_log10(survey_jpas, row_jpas)
    log_ignasi = _floored_log10(survey_ignasi, row_ignasi)

    # Mask exact zeros in the denominator with NaN.
    log_jpas_nonzero = np.where(log_jpas == 0, np.nan, log_jpas)
    return log_ignasi / log_jpas_nonzero

def id_passes_threshold(idx_ID, thr=RATIO_DEV_THR):
    """Tell whether the ratio curve of ``idx_ID`` deviates from 1 by more than ``thr``.

    Returns ``False`` when no ratio can be computed for the ID or when fewer
    than ``MIN_VALID_PTS`` ratio values are finite. Otherwise returns whether
    the largest ``|ratio - 1|`` over the finite values exceeds ``thr``.
    """
    ratio = compute_ratio_for_id(idx_ID)
    if ratio is None:
        return False

    finite_ratio = ratio[np.isfinite(ratio)]
    if finite_ratio.size < MIN_VALID_PTS:
        return False

    largest_deviation = np.nanmax(np.abs(finite_ratio - 1.0))
    return bool(largest_deviation > thr)

# ───── sampling ─────
# IDs and both lookups are indexed with the same position below, so their lengths must agree.
assert len(IDs) == len(LoA_intersection_JPAS_x_DESI_Raul_DATA_test) == len(LoA_intersection_Ignasi_all), \
    "Lengths of IDs and intersection lookups must match."

# Draw at most NN_plot distinct positions into IDs.
NN_eff = min(NN_plot, len(IDs))
sample_all = rng.choice(len(IDs), NN_eff, replace=False)

# Keep only the sampled positions whose ratio curve exceeds the deviation threshold.
filtered_idxs = list(filter(id_passes_threshold, sample_all))
if len(filtered_idxs) == 0:
    # Nothing passed: the figure below is still created but stays empty.
    print(f"ℹ️ No IDs exceeded threshold (|ratio-1| > {RATIO_DEV_THR}). Nothing to plot.")

# One colour per retained ID (at least one so the colormap call is always valid);
# one linestyle per survey.
n_id_colors = max(1, len(filtered_idxs))
colors = plt.cm.plasma(np.linspace(0.08, 0.92, n_id_colors))
linestyles = {survey_jpas: "--", survey_ignasi: "-"}

# Two stacked panels sharing the x-axis: fluxes on top, ratio below.
fig, (ax, ax_ratio) = plt.subplots(
    2, 1, figsize=(9, 8), height_ratios=[3, 1], sharex=True, gridspec_kw={'hspace': 0.06}
)
ax.set_ylabel(r'Flux [arb. units]', fontsize=18)
ax_ratio.set_xlabel(r'$\mathrm{Filter~Index}$', fontsize=18)
ax_ratio.set_ylabel(r'$\log_{10}(\mathrm{Ignasi}) / \log_{10}(\mathrm{JPAS})$', fontsize=13)

# Proxy artists for the legends: grey lines encode the survey linestyle,
# coloured lines (filled in by the plotting loop) encode the ID.
survey_handles = [
    mpl.lines.Line2D([0], [0], color="gray", linestyle=linestyles[survey], lw=2, label=legend_label)
    for survey, legend_label in ((survey_jpas, "JPAS×DESI"), (survey_ignasi, "Ignasi"))
]
id_handles = []

# ───── plotting ─────
for j, idx_ID in enumerate(filtered_idxs):
    color = colors[j]
    id_handles.append(
        mpl.lines.Line2D([0], [0], color=color, lw=3, label="ID " + str(IDs[idx_ID]))
    )

    # Top panel: draw every occurrence of this ID, JPAS first and then Ignasi,
    # each with its survey linestyle and the colour assigned to the ID.
    occurrences_by_survey = (
        (survey_jpas, LoA_intersection_JPAS_x_DESI_Raul_DATA_test[idx_ID]),
        (survey_ignasi, LoA_intersection_Ignasi_all[idx_ID]),
    )
    for survey, occurrences in occurrences_by_survey:
        for row in iter_indices(occurrences):
            flux = np.asarray(DATA[survey]['all_observations'][row], dtype=float)
            ax.plot(np.arange(flux.size), flux, linestyle=linestyles[survey], lw=2.0, marker='o', ms=3.0,
                    color=color, alpha=0.9)

    # Bottom panel: ratio built from the first occurrence in each survey.
    ratio = compute_ratio_for_id(idx_ID)
    if ratio is not None:
        ax_ratio.plot(np.arange(ratio.size), ratio, color=color, lw=1.8, alpha=0.95)

# ───── styling ─────
ax.tick_params(axis='both', labelsize=12)
# Reference line at ratio = 1 (identical log-fluxes in both surveys).
ax_ratio.axhline(1.0, ls='--', lw=1.0, color='black', alpha=0.6)
ax_ratio.tick_params(axis='both', labelsize=11)

leg0 = ax.legend(handles=survey_handles, loc='upper left', fontsize=12,
                 fancybox=True, shadow=True, title="Survey", title_fontsize=13)

if id_handles:
    # A second ax.legend() call replaces the Axes' current legend, so the survey
    # legend has to be pinned as a plain artist first. It is pinned only in this
    # branch, and the second legend is not pinned at all: a legend that is both
    # the current legend and an added artist is drawn twice.
    ax.add_artist(leg0)
    leg1 = ax.legend(handles=id_handles, loc='upper right', fontsize=11,
                     fancybox=True, shadow=True, title=f"Sampled IDs (|ratio-1| > {RATIO_DEV_THR})", title_fontsize=12)

if USE_LOG_Y:
    ax.set_yscale("log")

plt.tight_layout()
plt.show()


In [None]:
# Row positions of the cross-matched objects in each latent array, flattened
# from the per-ID lookups into one index vector per survey.
rows_raul_test = np.concatenate(LoA_intersection_JPAS_x_DESI_Raul_test)
rows_ignasi_all = np.concatenate(LoA_intersection_Ignasi_all)

# DA latents restricted to the objects found in both catalogues.
feat_dict = {
    "latents_DA_Raul_test_x_Ignasi": features["DA"]["JPAS_x_DESI_Raul"]["test"][rows_raul_test],
    "latents_DA_Ignasi_x_Raul_test": features["DA"]["Ignasi"]["all"][rows_ignasi_all],
}

# No subsampling here: every cross-matched row is embedded.
latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)

In [None]:
# 2-D t-SNE embeddings of the cross-matched objects, one array per survey.
A = latents_tSNE['latents_DA_Raul_test_x_Ignasi_tSNE']
B = latents_tSNE['latents_DA_Ignasi_x_Raul_test_tSNE']

plt.figure(figsize=(7.5, 6))

# Each series carries a label so the legend identifies which survey a marker belongs to.
plt.scatter(
    A[:, 0], A[:, 1],
    s=14, marker='X', alpha=0.9, edgecolors='none', color='royalblue',
    label='JPAS x DESI (Raul) test'
)

plt.scatter(
    B[:, 0], B[:, 1],
    s=5, marker='o', alpha=0.4, edgecolors='none', color='crimson',
    label='Ignasi (all)'
)

plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.title('t-SNE of DA latents: cross-matched objects')
# markerscale enlarges the legend markers, which are otherwise as small as the plotted points.
plt.legend(loc='best', markerscale=2.0, frameon=True)
plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.25)
plt.tight_layout()
plt.show()

# Magnitude cuts

In [None]:
# Reload project modules now so edits made on disk are picked up
# (the autoreload extension is loaded in the first cell of the notebook).
%autoreload

In [None]:
# Column names used for the magnitude cuts.
magnitude_key = "DESI_FLUX_R"
output_key = "MAG_BIN_ID"

# Bin edges in magnitude; consecutive pairs form the ranges
# (17, 19), (19, 21), (21, 22), (22, 22.5).
mag_bin_edges = (17, 19, 21, 22, 22.5)
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))

# One colour and one colormap per magnitude range.
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

In [None]:
# Dataset splits shown in the histogram, keyed by (dataset label, split label).
# Insertion order is kept as the order of the entries.
magnitude_splits = {
    ("Mocks", "Test"): yy["DESI_mocks_Raul"]["test"],
    ("JPAS x DESI", "Train"): yy["JPAS_x_DESI_Raul"]["train"],
    ("JPAS x DESI", "Test"): yy["JPAS_x_DESI_Raul"]["test"],
}

# Flux -> magnitude, m = -2.5 * log10(flux) + 22.5, read from the column named by
# `magnitude_key` (set in the configuration cell) instead of a repeated hard-coded name.
# NOTE(review): the 22.5 zero point presumably assumes fluxes in nanomaggies, and
# non-positive fluxes give NaN/inf here — confirm both upstream.
magnitudes_plot = {
    entry: -2.5 * np.log10(split[magnitude_key]) + 22.5
    for entry, split in magnitude_splits.items()
}
# Integer class label per object, used for the per-class statistics.
labels_plot = {
    entry: split['SPECTYPE_int']
    for entry, split in magnitude_splits.items()
}

# Histogram with the magnitude ranges overlaid; returns the masks and statistics
# reused by the cells below.
masks_magnitudes, stats_magnitudes = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_plot, ranges=magnitude_ranges, colors=colors, bins=42,
    x_label="DESI Magnitude (R)", title="DESI R-band Magnitudes",
    labels_dict=labels_plot, class_names=class_names, pct_decimals=0,
    annotate_mode='text', legend_fontsize=12
)

# Line style, marker and legend label for each entry of stats_magnitudes.
entry_plot_styles = {
    ("Mocks", "Test"): ("--", "+", "Source-Test (Mocks) no-DA"),
    ("JPAS x DESI", "Test"): ("-.", "x", "Target-Test (JPAS x DESI) no-DA"),
    ("JPAS x DESI", "Train"): (":", "^", "Target-Training (JPAS x DESI) DA"),
}
for entry_key, (entry_linestyle, entry_marker, entry_label) in entry_plot_styles.items():
    stats_magnitudes[entry_key]["plot_kwargs"] = {
        "linestyle": entry_linestyle, "marker": entry_marker, "markersize": 8.0,
        "label": entry_label,
    }

# Frame settings shared by both legends.
legend_frame_kwargs = {"frameon": True, "fancybox": True, "shadow": True}

fig, ax = plotting_utils.plot_per_class_counts_together(
    stats_magnitudes,
    yscale="log",
    figsize=(12, 6),
    title="Per-class counts vs. magnitude (all entries together)",
    title_fontsize=18,  # title size
    class_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.0, 1.0),
        "fontsize": 10, "title": "Classes", **legend_frame_kwargs
    },
    entry_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.65, 1.0),
        "fontsize": 9, "ncol": 1, "title": "Evaluation Cases", **legend_frame_kwargs
    },
    legend_outside=True,
)
plt.show()

In [None]:
# Reload project modules again before the evaluation call below,
# so the latest on-disk version of the evaluation code is used.
%autoreload

In [None]:
# Legend settings for the radar plots.
mag_bin_radar_legend_kwargs = {
    "loc": "upper left", "bbox_to_anchor": (0.73, 1.0), "fontsize": 9, "ncol": 1,
    "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True, "borderaxespad": 0.0,
}

# Layout of the radar plots.
mag_bin_radar_kwargs = {
    "figsize": (8, 8),
    "theta_offset": np.pi / 2,
    "r_ticks": (0.1, 0.3, 0.5, 0.7, 0.9),
    "r_lim": (0.0, 1.0),
    "tick_labelsize": 16,
    "radial_labelsize": 12,
    "show_legend": True,
    "legend_kwargs": mag_bin_radar_legend_kwargs,
    "title_pad": 20,
}

# Produce the evaluation plots for every magnitude bin and write them to disk
# as PDF files instead of showing them.
evaluation_tools.evaluate_all_plots_by_mag_bins(
    masks_magnitudes=masks_magnitudes,
    yy=yy,
    probs=probs,
    class_names=class_names,
    dict_radar_styles=dict_radar,   # only used to copy the linestyles/markers
    radar_title_base="F1 Radar Plot",
    radar_kwargs=mag_bin_radar_kwargs,
    colors=colors,
    colormaps=colormaps,
    show=False,                 # do not open figure windows
    save_dir=global_setup.path_saved_figures,  # output directory for the figures
    save_format="pdf",
    save_dpi=250,
    close_after_save=True,      # free memory in large loops
)