In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.evaluation import wandb_evaluation_tools

import os
import torch
import numpy as np
from sklearn.manifold import TSNE

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Rank the runs of the no-DA and DA W&B sweeps and plot the top-N of each.
N_selected_sweeps_no_DA = 1
N_selected_sweeps_DA = 2
path_wandb_sweep_no_DA = os.path.join(global_setup.path_models, "sweeps_no_DA")
path_wandb_sweep_DA = os.path.join(global_setup.path_models, "sweeps_DA")

sorted_list_sweep_names_no_DA, sorted_losses_no_DA = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_no_DA,
    max_runs_to_plot=N_selected_sweeps_no_DA,
)
sorted_list_sweep_names_DA, sorted_losses_DA = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_DA,
    max_runs_to_plot=N_selected_sweeps_DA,
)

# Load the config files

In [None]:
# Load the config of the first-ranked run of each sweep and make sure all selected
# runs of a sweep share the same data configuration.


def _load_configs_and_check_data(paths_run):
    """Load and massage `config.yaml` from each run directory in `paths_run`.

    Raises ValueError if any run's ``config['data']`` differs from the first run's
    (per ``evaluation_tools.safe_compare``). Returns the configs in the order of
    `paths_run`.
    """
    logging.info("🔍 Validating model configs...")
    configs = []
    for path_run in paths_run:
        _, config = wrapper_tools.load_and_massage_config_file(
            os.path.join(path_run, "config.yaml"), path_run
        )
        configs.append(config)
    config_ref = configs[0]
    for i, cfg in enumerate(configs[1:], 1):
        logging.debug(f"🔍 Comparing config 0 and config {i}")
        if not evaluation_tools.safe_compare(cfg['data'], config_ref['data']):
            raise ValueError(f"🚫 Data config mismatch between model 0 and model {i}")
    return configs


paths_load_no_DA = [os.path.join(global_setup.path_models, "sweeps_no_DA", sorted_list_sweep_names_no_DA[0])]
configs_no_DA = _load_configs_and_check_data(paths_load_no_DA)
config_ref_no_DA = configs_no_DA[0]

paths_load_DA = [os.path.join(global_setup.path_models, "sweeps_DA", sorted_list_sweep_names_DA[0])]
# Validate the DA runs against each other (not the no-DA list).
configs_DA = _load_configs_and_check_data(paths_load_DA)
config_ref_DA = configs_DA[0]

# The data tensors below are built from the DA run's data config and fed to BOTH
# models, so warn if the no-DA model was trained with a different data config.
if not evaluation_tools.safe_compare(config_ref_no_DA['data'], config_ref_DA['data']):
    logging.warning(
        "⚠️ Data config differs between the no-DA and DA models; the no-DA model "
        "will be evaluated on data built from the DA config."
    )

config_data = config_ref_DA["data"]

# Load the data

In [None]:
# Which catalogues to load and how: paths, per-dataset selections and seed come
# from the shared global setup.
root_path = global_setup.DATA_path
load_JPAS_x_DESI_Raul, load_DESI_mocks_Raul, load_Ignasi = (
    global_setup.load_JPAS_x_DESI_Raul,
    global_setup.load_DESI_mocks_Raul,
    global_setup.load_Ignasi,
)
random_seed_load = global_setup.default_seed
list_of_datasets_to_load = ["JPAS_x_DESI_Raul", "DESI_mocks_Raul", "Ignasi"]

# Cleaning follows the (DA) model's data config; split ratios/seeds follow the global setup.
config_dict_cleaning = config_data['cleaning_config']
dict_split_data_options = global_setup.dict_split_data_options

# Model input features (from the model config) and per-object quantities to carry along.
keys_xx = config_data['keys_xx']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

device = 'cpu'

In [None]:
# Load every requested catalogue and clean it (in place; the pipeline's return value is kept).
DATA = loading_tools.load_data_bundle(
    root_path=root_path,
    include=list_of_datasets_to_load,
    JPAS_x_DESI_Raul={"datasets": load_JPAS_x_DESI_Raul},
    DESI_mocks_Raul={"datasets": load_DESI_mocks_Raul},
    Ignasi={"datasets": load_Ignasi},
    random_seed=random_seed_load,
)
DATA = cleaning_tools.clean_data_pipeline(DATA, config=config_dict_cleaning, in_place=True)

# Crossmatch mocks against JPAS x DESI observations by TARGETID; the helper returns the
# ID sets followed by Lists of Arrays (LoA) for the non-shared and shared objects.
Dict_LoA = {"intersection": {}, "outersection": {}}
(
    IDs1, IDs2, IDs12,
    Dict_LoA["outersection"]["DESI_mocks_Raul"],
    Dict_LoA["outersection"]["JPAS_x_DESI_Raul"],
    Dict_LoA["intersection"]["DESI_mocks_Raul"],
    Dict_LoA["intersection"]["JPAS_x_DESI_Raul"],
) = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI_mocks_Raul"]['all_pd']['TARGETID'],
    DATA["JPAS_x_DESI_Raul"]['all_pd']['TARGETID'],
)

# Split each used LoA into train/val/test, with ratios and seed chosen per crossmatch subset.
Dict_LoA_split = {"intersection": {}, "outersection": {}}
for _xmatch, _dset in (("intersection", "JPAS_x_DESI_Raul"), ("outersection", "DESI_mocks_Raul")):
    Dict_LoA_split[_xmatch][_dset] = process_dset_splits.split_LoA(
        Dict_LoA[_xmatch][_dset],
        train_ratio=dict_split_data_options[f"train_ratio_{_xmatch}"],
        val_ratio=dict_split_data_options[f"val_ratio_{_xmatch}"],
        test_ratio=dict_split_data_options[f"test_ratio_{_xmatch}"],
        seed=dict_split_data_options[f"random_seed_split_{_xmatch}"],
    )
del _xmatch, _dset

In [None]:
# Build float32 feature tensors (on `device`) and label dicts for every split of the two
# crossmatched datasets, plus the whole Ignasi catalogue stored under the split "all".


def _n_rows(LoA):
    """Total number of rows referenced by a list of index arrays.

    Matches ``len(np.concatenate(LoA))`` for a non-empty list, but returns 0 for an
    empty list (np.concatenate raises there) and does not build the concatenation.
    """
    return sum(len(idx) for idx in LoA)


def _features_to_tensor(xx_dict, n_rows):
    """Stack the per-key feature dict for rows 0..n_rows-1 into a float32 tensor on `device`."""
    xx_batch = data_loaders.stack_features_from_dict_flattened(xx_dict, np.arange(n_rows))
    return torch.tensor(xx_batch, dtype=torch.float32, device=device)


extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]
xx = {}
yy = {}
for key_dset, key_xmatch in extract_dsets:
    xx[key_dset] = {}
    yy[key_dset] = {}
    for split in global_setup.splits:
        # An absent split yields an empty LoA; _n_rows handles that (whether the
        # extract/stack helpers accept empty input is not verified here).
        LoA_ = Dict_LoA_split[key_xmatch][key_dset].get(split, [])
        _, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
            block=DATA[key_dset], LoA=LoA_, keys_xx=keys_xx, keys_yy=keys_yy
        )
        xx[key_dset][split] = _features_to_tensor(xx_, _n_rows(LoA_))
        yy[key_dset][split] = yy_


# Ignasi catalogue: one single-row group per observation, kept as one split "all".
key_dset = "Ignasi"
split = "all"
xx[key_dset] = {}
yy[key_dset] = {}
LoA_, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
    block=DATA[key_dset],
    LoA=np.arange(DATA[key_dset]['all_observations'].shape[0])[:, None].tolist(),
    keys_xx=keys_xx,
    keys_yy=keys_yy
)
xx[key_dset][split] = _features_to_tensor(xx_, _n_rows(LoA_))
yy[key_dset][split] = yy_

# Load the models

In [None]:
# Load the encoder and downstream (classifier head) MLPs of the selected no-DA and DA
# runs, move them to `device`, and switch them to inference mode.


def _load_encoder_and_downstream(path_run):
    """Load `model_encoder.pt` and `model_downstream.pt` from a run directory.

    Both are rebuilt with `model_building_tools.create_mlp`, moved to `device` and put in
    eval mode so any dropout / batch-norm layers behave deterministically at inference.
    Returns (encoder, downstream).
    """
    _, encoder = save_load_tools.load_model_from_checkpoint(
        os.path.join(path_run, "model_encoder.pt"), model_building_tools.create_mlp)
    _, downstream = save_load_tools.load_model_from_checkpoint(
        os.path.join(path_run, "model_downstream.pt"), model_building_tools.create_mlp)
    # nn.Module.to() and .eval() act in place and return the module itself.
    return encoder.to(device).eval(), downstream.to(device).eval()


model_encoder_no_DA, model_downstream_no_DA = _load_encoder_and_downstream(paths_load_no_DA[0])
model_encoder_DA, model_downstream_DA = _load_encoder_and_downstream(paths_load_DA[0])

# Compute the results

In [None]:
# Run both models (no-DA and DA) on every split of every dataset. Results are NumPy
# arrays indexed as <result>[model tag][dataset][split]: encoder latents (`features`),
# softmax class probabilities (`probs`) and argmax predictions (`labels`).
extract_dsets = [
    ("DESI_mocks_Raul", "outersection"),
    ("JPAS_x_DESI_Raul", "intersection")
]
models_by_tag = {
    "no-DA": (model_encoder_no_DA, model_downstream_no_DA),
    "DA": (model_encoder_DA, model_downstream_DA),
}
features = {tag: {} for tag in models_by_tag}
probs = {tag: {} for tag in models_by_tag}
labels = {tag: {} for tag in models_by_tag}


def _predict(encoder, head, inputs):
    """Forward `inputs` through `encoder` then `head` with autograd disabled.

    Returns (latents, softmax over dim 1, argmax over axis 1), all as NumPy arrays.
    """
    with torch.no_grad():
        latents = encoder(inputs)
        class_probs = torch.nn.functional.softmax(head(latents), dim=1).cpu().numpy()
    return latents.cpu().numpy(), class_probs, np.argmax(class_probs, axis=1)


def _predict_dataset(key_dset, splits):
    """Fill features/probs/labels[tag][key_dset][split] for each split, no-DA before DA."""
    for tag in models_by_tag:
        features[tag][key_dset] = {}
        probs[tag][key_dset] = {}
        labels[tag][key_dset] = {}
    for split in splits:
        for tag, (encoder, head) in models_by_tag.items():
            (features[tag][key_dset][split],
             probs[tag][key_dset][split],
             labels[tag][key_dset][split]) = _predict(encoder, head, xx[key_dset][split])


for key_dset, key_xmatch in extract_dsets:
    _predict_dataset(key_dset, global_setup.splits)

key_dset = "Ignasi"
split = "all"
_predict_dataset(key_dset, [split])

# Metrics

In [None]:
# Class names, ordered as the keys of the SPECTYPE encoding in the global cleaning config.
# NOTE(review): the data were cleaned with config_data['cleaning_config'] (the model's
# config); confirm both define the same SPECTYPE mapping so names align with SPECTYPE_int.
spectype_encoding = global_setup.config_dict_cleaning["encoding"]["shared_mappings"]["SPECTYPE"]
class_names = [name for name in spectype_encoding.keys()]

In [None]:
# Classification metrics: source test (no-DA), target test (no-DA and DA) and target
# train (DA). Shows confusion matrices, an F1 radar plot, a TPR comparison, per-set
# performance comparisons and an overall ΔF1 figure (saved to disk).
dict_radar = {
    "Source-Test (Mocks) no-DA": {
        "y_true": yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["DESI_mocks_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "--", "linewidth": 2.0, "color": "k",
            "marker": "+", "markersize": 8.0,
            "label": "Source-Test (Mocks) no-DA"
        }
    },
    "Target-Test (JPAS x DESI) no-DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "-.", "linewidth": 2.0, "color": "k",
            "marker": "x", "markersize": 8.0,
            "label": "Target-Test (JPAS x DESI) no-DA"
        }
    },
    "Target-Training (JPAS x DESI) DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["train"]["SPECTYPE_int"],
        "y_pred": probs["DA"]["JPAS_x_DESI_Raul"]["train"],
        "plot_kwargs": {
            "linestyle": ":", "linewidth": 2.0, "color": "k",
            "marker": "^", "markersize": 8.0,
            "label": "Target-Training (JPAS x DESI) DA"
        }
    },
    "Target-Test (JPAS x DESI) DA": {
        "y_true": yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
        "y_pred": probs["DA"]["JPAS_x_DESI_Raul"]["test"],
        "plot_kwargs": {
            "linestyle": "-", "linewidth": 2.0, "color": "k",
            "marker": "o", "markersize": 8.0,
            "label": "Target-Test (JPAS x DESI) DA"
        }
    },
}

# (y_true, predicted probabilities, title) for each confusion matrix, in display order.
confusion_matrix_cases = [
    (yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["DESI_mocks_Raul"]["test"],
     "no-DA Source-Test (Mocks)"),
    (yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
     "no-DA Target-Test (JPAS x DESI - observations)"),
    (yy["JPAS_x_DESI_Raul"]["train"]["SPECTYPE_int"], probs["DA"]["JPAS_x_DESI_Raul"]["train"],
     "DA Target-Train (JPAS x DESI - observations)"),
    (yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["DA"]["JPAS_x_DESI_Raul"]["test"],
     "DA Target-Test (JPAS x DESI - observations)"),
]
for y_true_cm, y_pred_cm, title_cm in confusion_matrix_cases:
    evaluation_tools.plot_confusion_matrix(
        y_true_cm, y_pred_cm,
        class_names=class_names,
        cmap=plt.cm.RdYlGn, title=title_cm
    )

fig, ax = evaluation_tools.radar_plot(
    dict_radar=dict_radar, class_names=class_names,
    title="F1 Radar Plot", figsize=(8, 8), theta_offset=np.pi / 2, # first axis at 12 o'clock
    r_ticks=(0.1, 0.3, 0.5, 0.7, 0.9), r_lim=(0.0, 1.0),
    tick_labelsize=16, radial_labelsize=12, show_legend=True,
    legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.73, 1.0), "fontsize": 9, "ncol": 1,
        "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True, "borderaxespad": 0.0,
    },
)
plt.show()

evaluation_tools.compare_TPR_confusion_matrices(
    yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["DA"]["JPAS_x_DESI_Raul"]["test"],
    class_names=class_names, figsize=(10, 7),
    cmap='seismic', title='TPR Comparison: DA vs no-DA', name_1 = "no-DA", name_2 = "DA"
)

# Keep both comparisons under distinct names; reusing one name would discard the first.
metrics_source_vs_target_noDA = evaluation_tools.compare_sets_performance(
    yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["DESI_mocks_Raul"]["test"],
    yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    class_names=class_names, name_1="Source-Test (Mocks)", name_2="no-DA: Target-Test (JPAS x DESI obs.)",
    y_max_Delta_F1=0.3, y_min_Delta_F1=-0.3, title_fontsize=20, color='k'
)
metrics_target_noDA_vs_DA = evaluation_tools.compare_sets_performance(
    yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"], probs["DA"]["JPAS_x_DESI_Raul"]["test"],
    class_names=class_names, name_1="no-DA (Target-Test: JPAS x DESI obs.)", name_2="DA",
    y_max_Delta_F1=0.3, y_min_Delta_F1=-0.3, color='k'
)
# Alias kept for code that reads `metrics`: it refers to the no-DA vs DA comparison.
metrics = metrics_target_noDA_vs_DA

fig, ax = evaluation_tools.plot_overall_deltaF1_two_comparisons(
    y_true_src   = yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"],
    y_pred_src_noDA = probs["no-DA"]["DESI_mocks_Raul"]["test"],
    y_true_tgt   = yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"],
    y_pred_tgt_noDA = probs["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    y_pred_tgt_DA   = probs["DA"]["JPAS_x_DESI_Raul"]["test"],
    class_names=class_names,
    title=None,
    colors=("crimson", "limegreen"),
    labels=("JPAS x DESI no-DA VS Mocks no-DA", "JPAS x DESI DA VS no-DA"),
    ylim=(-0.3, 0.3),
    legend_kwargs={"loc":"upper right", "frameon":True, "fancybox":True, "shadow":True, "fontsize":18},
    save_dir=global_setup.path_saved_figures, save_format="png", save_dpi=250, filename="deltaF1"
)


# Latents

In [None]:
# t-SNE of test-split latents: no-DA encoder on source and target, DA encoder on target.
# NOTE(review): if tsne_per_key fits an independent t-SNE per key, the resulting
# coordinates do not share a frame across keys — confirm before overlaying them.
feat_dict = dict(
    latents_no_DA_Source=features["no-DA"]["DESI_mocks_Raul"]["test"],
    latents_no_DA_Target=features["no-DA"]["JPAS_x_DESI_Raul"]["test"],
    latents_DA_Target=features["DA"]["JPAS_x_DESI_Raul"]["test"],
)
tsne_options = dict(
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)
latents_tSNE = evaluation_tools.tsne_per_key(feat_dict, **tsne_options)

In [None]:
# Plot the t-SNE latents from the previous cell on shared axes: source vs target
# without DA, each embedding's scatter + density, then target without vs with DA.
# NOTE(review): if tsne_per_key embeds each key separately, overlays compare only
# cluster shapes, not positions — confirm.
xlim = ylim = (-150, 150)

# Settings shared by every 2-D density plot in this cell.
density_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,  # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
)


def _scatter_then_density(key_tsne, y_true, title, scatter_size, scatter_alpha):
    """Plot one embedding as a class-coloured scatter, then as a 2-D density map.

    Returns whatever `plot_latent_density_2d` returns.
    """
    evaluation_tools.plot_latents_scatter(
        latents_tSNE[key_tsne], y_true,
        class_counts=None,
        class_names=class_names,
        title=title,
        n_bins=128, sigma=2.0,
        scatter_size=scatter_size, scatter_alpha=scatter_alpha,
        xlim=xlim, ylim=ylim
    )
    return evaluation_tools.plot_latent_density_2d(
        latents_tSNE[key_tsne], title=title, **density_kwargs, xlim=xlim, ylim=ylim
    )


y_source_test = yy["DESI_mocks_Raul"]["test"]["SPECTYPE_int"]
y_target_test = yy["JPAS_x_DESI_Raul"]["test"]["SPECTYPE_int"]

evaluation_tools.plot_latents_scatter_val_test(
    X_val=latents_tSNE['latents_no_DA_Source_tSNE'], y_val=y_source_test,
    X_test=latents_tSNE['latents_no_DA_Target_tSNE'], y_test=y_target_test,
    class_names=class_names,
    title="Latents no-DA: Source (Mocks) vs Target (JPAS x DESI obs.)",
    marker_val="o", marker_test="^",
    size_val=14, size_test=14, alpha_val=0.7, alpha_test=0.7,
    xlim=xlim, ylim=ylim,
    subsample=4000, seed=137,
    edgecolor=None, linewidths=0.0,
    legend_split_1="Source (Mocks) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) no-DA"
)
_scatter_then_density('latents_no_DA_Source_tSNE', y_source_test,
                      "Latents no-DA: Source (Mocks)", 0.003, 0.3)
_scatter_then_density('latents_no_DA_Target_tSNE', y_target_test,
                      "Latents no-DA: Target (JPAS x DESI obs.)", 0.1, 0.5)

evaluation_tools.plot_latents_scatter_val_test(
    X_val=latents_tSNE['latents_no_DA_Target_tSNE'], y_val=y_target_test,
    X_test=latents_tSNE['latents_DA_Target_tSNE'], y_test=y_target_test,
    class_names=class_names,
    title="Latents Target (JPAS x DESI obs.): no-DA vs DA",
    marker_val="o", marker_test="^",
    size_val=8, size_test=8, alpha_val=0.7, alpha_test=0.7,
    xlim=xlim, ylim=ylim,
    subsample=None, seed=137,
    edgecolor=None, linewidths=0.0,
    legend_split_1="Target (JPAS x DESI obs.) no-DA",
    legend_split_2="Target (JPAS x DESI obs.) DA"
)
_scatter_then_density('latents_DA_Target_tSNE', y_target_test,
                      "Latents DA: Target (JPAS x DESI obs.)", 0.1, 0.5)

In [None]:
# Sanity check: latent shapes of the JPAS x DESI test split and of the full Ignasi catalogue.
for latents_ in (features["DA"]["JPAS_x_DESI_Raul"]["test"], features["DA"]["Ignasi"]["all"]):
    print(latents_.shape)

In [None]:
# t-SNE of DA latents: JPAS x DESI test split vs. the whole Ignasi JPAS catalogue,
# each subsampled to the size of the JPAS x DESI test split.
feat_dict = dict(
    latents_DA_JPAS_x_DESI=features["DA"]["JPAS_x_DESI_Raul"]["test"],
    latents_DA_all_JPAS=features["DA"]["Ignasi"]["all"],
)

subsample = len(feat_dict["latents_DA_JPAS_x_DESI"])

latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample={key: subsample for key in feat_dict},
    random_state=137,
    tsne_kwargs={"perplexity": 100},
    return_all_key=None,
)

In [None]:
# Density maps of the two DA embeddings from the previous cell, on identical axes.
xlim = ylim = (-100, 100)

density_kwargs = dict(
    density_method="hist",  # or "kde"
    bins=256, sigma=2.0,  # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k", contour_linewidths=0.4, contour_label_fontsize=7, contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
)

evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_JPAS_x_DESI_tSNE'], title="Latents JPAS x DESI",
    **density_kwargs, xlim=xlim, ylim=ylim
)
evaluation_tools.plot_latent_density_2d(
    latents_tSNE['latents_DA_all_JPAS_tSNE'], title="Latents all JPAS",
    **density_kwargs, xlim=xlim, ylim=ylim
)

# TODO: Investigate the discrepancy between the Ignasi JPAS catalogue and Raul's JPAS x DESI data (check the crossmatched spectra)

# Magnitude cuts

In [None]:
# Magnitude-binning configuration: DESI r-band bin edges, the consecutive
# (low, high) ranges they define, and one colour / colormap per range.
magnitude_key = "DESI_FLUX_R"
output_key = "MAG_BIN_ID"
mag_bin_edges = (17, 19, 21, 22, 22.5)

magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

In [None]:
# (dataset label, split label) -> (yy dataset key, yy split key), in plotting order.
_mag_sources = {
    ("Mocks", "Test"): ("DESI_mocks_Raul", "test"),
    ("JPAS x DESI", "Train"): ("JPAS_x_DESI_Raul", "train"),
    ("JPAS x DESI", "Test"): ("JPAS_x_DESI_Raul", "test"),
}
# Flux -> magnitude with a 22.5 zero point.
# NOTE(review): non-positive fluxes become +inf/NaN here; check how the histogram helper treats them.
magnitudes_plot = {
    entry: -2.5 * np.log10(yy[dset][split_]['DESI_FLUX_R']) + 22.5
    for entry, (dset, split_) in _mag_sources.items()
}
labels_plot = {
    entry: yy[dset][split_]['SPECTYPE_int']
    for entry, (dset, split_) in _mag_sources.items()
}
masks_magnitudes, stats_magnitudes = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_plot, ranges=magnitude_ranges, colors=colors, bins=42,
    x_label="DESI Magnitude (R)", title="DESI R-band Magnitudes",
    labels_dict=labels_plot, class_names=class_names, pct_decimals=0,
    annotate_mode='text', legend_fontsize=12
)

# Line styles matching the radar-plot evaluation cases.
_entry_styles = {
    ("Mocks", "Test"): {
        "linestyle": "--", "marker": "+", "markersize": 8.0,
        "label": "Source-Test (Mocks) no-DA",
    },
    ("JPAS x DESI", "Test"): {
        "linestyle": "-.", "marker": "x", "markersize": 8.0,
        "label": "Target-Test (JPAS x DESI) no-DA",
    },
    ("JPAS x DESI", "Train"): {
        "linestyle": ":", "marker": "^", "markersize": 8.0,
        "label": "Target-Training (JPAS x DESI) DA",
    },
}
for entry_, style_ in _entry_styles.items():
    stats_magnitudes[entry_]["plot_kwargs"] = style_

fig, ax = plotting_utils.plot_per_class_counts_together(
    stats_magnitudes,
    yscale="log",
    figsize=(12, 6),
    title="Per-class counts vs. magnitude (all entries together)",
    title_fontsize=18,
    class_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.0, 1.0),
        "fontsize": 10, "title": "Classes", "frameon": True, "fancybox": True, "shadow": True
    },
    entry_legend_kwargs={
        "loc": "upper left", "bbox_to_anchor": (0.65, 1.0),
        "fontsize": 9, "ncol": 1, "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True
    },
    legend_outside=True,
)
plt.show()

In [None]:
# Re-run the evaluation plots separately for each magnitude bin (masks_magnitudes),
# saving figures to global_setup.path_saved_figures instead of displaying them.
radar_kwargs_mag_bins = {
    "figsize": (8, 8),
    "theta_offset": np.pi / 2,
    "r_ticks": (0.1, 0.3, 0.5, 0.7, 0.9),
    "r_lim": (0.0, 1.0),
    "tick_labelsize": 16,
    "radial_labelsize": 12,
    "show_legend": True,
    "legend_kwargs": {
        "loc": "upper left", "bbox_to_anchor": (0.73, 1.0), "fontsize": 9, "ncol": 1,
        "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True, "borderaxespad": 0.0,
    },
    "title_pad": 20,
}
evaluation_tools.evaluate_all_plots_by_mag_bins(
    masks_magnitudes=masks_magnitudes,
    yy=yy,
    probs=probs,
    class_names=class_names,
    dict_radar_styles=dict_radar,  # only used to copy the linestyles/markers
    radar_title_base="F1 Radar Plot",
    radar_kwargs=radar_kwargs_mag_bins,
    colors=colors,
    colormaps=colormaps,
    show=False,
    save_dir=global_setup.path_saved_figures,
    save_format="png",
    save_dpi=250,
    close_after_save=True,  # free figure memory inside the per-bin loop
)

In [None]:
# NOTE(review): scratch arithmetic on hard-coded numbers with no stated provenance — TODO document what they are or remove this cell.
6108.36 - (3423.27 + 2380.32)

In [None]:
# NOTE(review): scratch arithmetic on hard-coded numbers with no stated provenance — TODO document what they are or remove this cell.
8217.21 - 5185.78