In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np

# Ask cuDNN for deterministic kernels and switch off its benchmark-based auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Start from matplotlib's default style and close figures left over from earlier runs.
plt.style.use('default')
plt.close('all')
# Project-wide font and rcParams overrides supplied by plotting_utils.
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# Applied last, so wherever this style sheet overlaps the overrides above, it wins.
plt.style.use('tableau-colorblind10')
# Interactive figures (ipympl backend).
%matplotlib widget

from JPAS_DA.utils import aux_tools
# Fix the random seed; which RNGs are seeded is decided inside aux_tools.set_seed.
aux_tools.set_seed(42)

In [None]:
# Model directories to evaluate; each one must contain a config.yaml.
paths_load = [os.path.join(global_setup.path_models, "09_DA")]

# ─────────────────────────────────────────────────────────────
# Read every config, then require identical 'data' sections
# ─────────────────────────────────────────────────────────────
logging.info("🔍 Validating model configs...")
configs = []
# `path` deliberately stays a module-level name after this loop; later cells may read it.
for path in paths_load:
    config_file = os.path.join(path, "config.yaml")
    config = wrapper_tools.load_and_massage_config_file(config_file, path)[1]
    configs.append(config)

# The first config is the reference all the others are checked against.
config_ref = configs[0]
for i in range(1, len(configs)):
    logging.debug(f"🔍 Comparing config 0 and config {i}")
    data_sections_match = evaluation_tools.safe_compare(configs[i]['data'], config_ref['data'])
    if not data_sections_match:
        raise ValueError(f"🚫 Data config mismatch between model 0 and model {i}")

config_data = config_ref["data"]

In [None]:
# --- Where the data lives and which catalogues to read ---
root_path = global_setup.DATA_path
list_of_datasets_to_load = ["JPAS_x_DESI_Raul", "DESI_mocks_Raul", "Ignasi"]

# Per-catalogue loading options and the seed handed to the loader, all taken from global_setup.
load_JPAS_x_DESI_Raul = global_setup.load_JPAS_x_DESI_Raul
load_DESI_mocks_Raul = global_setup.load_DESI_mocks_Raul
load_Ignasi = global_setup.load_Ignasi
random_seed_load = global_setup.default_seed

# --- Processing options ---
# Cleaning settings come from the loaded model config; split settings from global_setup.
config_dict_cleaning = config_data['cleaning_config']
dict_split_data_options = global_setup.dict_split_data_options

# --- Columns to extract ---
# Input features follow the model config; the target/metadata columns are fixed here.
keys_xx = config_data['keys_xx']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

In [None]:
# Read the requested catalogues, then clean them in place with the cleaning
# options stored in the model config.
per_dataset_options = {
    "JPAS_x_DESI_Raul": {"datasets": load_JPAS_x_DESI_Raul},
    "DESI_mocks_Raul": {"datasets": load_DESI_mocks_Raul},
    "Ignasi": {"datasets": load_Ignasi},
}
DATA = loading_tools.load_data_bundle(
    root_path=root_path,
    include=list_of_datasets_to_load,
    **per_dataset_options,
    random_seed=random_seed_load,
)
DATA = cleaning_tools.clean_data_pipeline(DATA, config=config_dict_cleaning, in_place=True)

# Cross-match the two Raul catalogues on TARGETID. The call returns seven values:
# three ID collections followed by the "outersection" and "intersection" lists of
# arrays (LoA), each given for the mocks first and for JPAS second.
targetids_mocks = DATA["DESI_mocks_Raul"]['all_pd']['TARGETID']
targetids_jpas = DATA["JPAS_x_DESI_Raul"]['all_pd']['TARGETID']
(IDs1, IDs2, IDs12,
 outer_mocks, outer_jpas,
 inter_mocks, inter_jpas) = crossmatch_tools.crossmatch_IDs_two_datasets(targetids_mocks, targetids_jpas)

Dict_LoA = {
    "intersection": {"DESI_mocks_Raul": inter_mocks, "JPAS_x_DESI_Raul": inter_jpas},
    "outersection": {"DESI_mocks_Raul": outer_mocks, "JPAS_x_DESI_Raul": outer_jpas},
}

# Split the Lists of Arrays into training, validation, and testing sets.
# Only the JPAS intersection and the mocks outersection are split here.
split_opts = dict_split_data_options
Dict_LoA_split = {
    "intersection": {
        "JPAS_x_DESI_Raul": process_dset_splits.split_LoA(
            Dict_LoA["intersection"]["JPAS_x_DESI_Raul"],
            train_ratio=split_opts["train_ratio_intersection"],
            val_ratio=split_opts["val_ratio_intersection"],
            test_ratio=split_opts["test_ratio_intersection"],
            seed=split_opts["random_seed_split_intersection"],
        ),
    },
    "outersection": {
        "DESI_mocks_Raul": process_dset_splits.split_LoA(
            Dict_LoA["outersection"]["DESI_mocks_Raul"],
            train_ratio=split_opts["train_ratio_outersection"],
            val_ratio=split_opts["val_ratio_outersection"],
            test_ratio=split_opts["test_ratio_outersection"],
            seed=split_opts["random_seed_split_outersection"],
        ),
    },
}

In [None]:
# Device on which the input tensors are created and to which both models are moved.
device = 'cpu'

In [None]:
# Model inputs (xx, float32 tensors) and target/metadata columns (yy), stored as
# xx[split][dataset] / yy[split][dataset]. The mocks use their "outersection"
# subset and the JPAS observations their "intersection" subset.
xx, yy = {}, {}
for key_dset_split in ["train", "val", "test"]:
    xx[key_dset_split] = {}
    yy[key_dset_split] = {}

    for key_dset, key_xmatch in [("DESI_mocks_Raul", "outersection"), ("JPAS_x_DESI_Raul", "intersection")]:
        LoA_, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
            block=DATA[key_dset],
            LoA=Dict_LoA_split[key_xmatch][key_dset][key_dset_split],
            keys_xx=keys_xx,
            keys_yy=keys_yy
        )
        # Stack every extracted row: the index list covers the whole flattened LoA.
        n_rows = len(np.concatenate(LoA_))
        xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(n_rows))

        xx[key_dset_split][key_dset] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
        yy[key_dset_split][key_dset] = yy_

In [None]:
# Inputs for the "Ignasi" catalogue. It is not split, so it is stored directly
# as xx["Ignasi"] / yy["Ignasi"] next to the per-split entries.
key_dset = "Ignasi"

# Each observation forms its own single-element group in the list of arrays.
n_observations = DATA[key_dset]['all_observations'].shape[0]
LoA_, xx_, yy_ = process_dset_splits.extract_from_block_by_LoA(
    block=DATA[key_dset],
    LoA=np.arange(n_observations)[:, None].tolist(),
    keys_xx=keys_xx,
    keys_yy=keys_yy
)
# Stack every extracted row: the index list covers the whole flattened LoA.
xx_batch = data_loaders.stack_features_from_dict_flattened(xx_, np.arange(len(np.concatenate(LoA_))))
xx[key_dset] = torch.tensor(xx_batch, dtype=torch.float32, device=device)
yy[key_dset] = yy_

In [None]:
# Load the checkpoints from the first model directory, i.e. the one whose config
# was taken as the reference (`config_ref`). The directory is named explicitly
# rather than read from the `path` variable left over by the config-loading loop.
path_model = paths_load[0]

_, model_encoder = save_load_tools.load_model_from_checkpoint(
    os.path.join(path_model, "model_encoder.pt"), model_building_tools.create_mlp)
_, model_downstream = save_load_tools.load_model_from_checkpoint(
    os.path.join(path_model, "model_downstream.pt"), model_building_tools.create_mlp)

# Move to the target device and switch to evaluation mode so that layers such as
# dropout / batch-norm (if the MLPs contain any) behave deterministically at inference.
# NOTE(review): assumes both objects are torch.nn.Module instances — confirm in save_load_tools.
model_encoder = model_encoder.to(device).eval()
model_downstream = model_downstream.to(device).eval()

In [None]:
# Run encoder + downstream head on every split of both Raul catalogues and keep,
# per [split][dataset]: the encoder output, the softmax probabilities and the
# arg-max predicted class.
features, probs, labels = {}, {}, {}
for key_dset_split in ["train", "val", "test"]:
    features[key_dset_split] = {}
    probs[key_dset_split] = {}
    labels[key_dset_split] = {}

    for key_dset in ["DESI_mocks_Raul", "JPAS_x_DESI_Raul"]:
        xx_input = xx[key_dset_split][key_dset]
        # Inference only: no gradients are tracked.
        with torch.no_grad():
            features_ = model_encoder(xx_input)
            logits_ = model_downstream(features_)
            probs_ = torch.nn.functional.softmax(logits_, dim=1).cpu().numpy()

        features[key_dset_split][key_dset] = features_.cpu().numpy()
        probs[key_dset_split][key_dset] = probs_
        labels[key_dset_split][key_dset] = np.argmax(probs_, axis=1)

In [None]:
# Same inference pass for the unsplit "Ignasi" catalogue; results are stored
# directly under features/probs/labels["Ignasi"].
key_dset = "Ignasi"
xx_input = xx[key_dset]

# Inference only: no gradients are tracked.
with torch.no_grad():
    features_ = model_encoder(xx_input)
    logits_ = model_downstream(features_)
    probs_ = torch.nn.functional.softmax(logits_, dim=1).cpu().numpy()

labels[key_dset] = np.argmax(probs_, axis=1)
probs[key_dset] = probs_
features[key_dset] = features_.cpu().numpy()

In [None]:
# Class names are the keys of the shared SPECTYPE encoding in global_setup.
# NOTE(review): read from global_setup rather than from the model's own cleaning
# config (config_dict_cleaning) — confirm the two encodings agree.
spectype_encoding = global_setup.config_dict_cleaning['encoding']['shared_mappings']['SPECTYPE']
class_names = list(spectype_encoding.keys())

In [None]:
# Confusion matrix and latent-space plots (true vs. predicted classes) for the
# JPAS test split.
key_dset_split = "test"
key_dset = "JPAS_x_DESI_Raul"

true_classes = yy[key_dset_split][key_dset]["SPECTYPE_int"]
latent_feats = features[key_dset_split][key_dset]
title_prefix = "DA: " + key_dset_split + " - " + key_dset

evaluation_tools.plot_confusion_matrix(
    true_classes,
    probs[key_dset_split][key_dset],
    class_names=class_names,
    cmap=plt.cm.RdYlGn
)

# Both latent plots share the same rendering options and axis window.
latent_plot_options = dict(
    class_counts=None,
    class_names=class_names,
    n_bins=128,
    sigma=2.0,
    scatter_size=0.5,
    scatter_alpha=0.1,
    xlim=(-8, 10),
    ylim=(-18, 6),
)
evaluation_tools.plot_tsne_single(
    latent_feats, true_classes,
    title=title_prefix + " - True",
    **latent_plot_options
)
evaluation_tools.plot_tsne_single(
    latent_feats, labels[key_dset_split][key_dset],
    title=title_prefix + " - Pred",
    **latent_plot_options
)

In [None]:
import matplotlib.patches as mpatches

In [None]:
# Plain scatter of the first two latent dimensions, coloured by true class, for
# the split/dataset currently selected by key_dset_split / key_dset.
y_labels = np.asarray(yy[key_dset_split][key_dset]["SPECTYPE_int"])
X_emb = features[key_dset_split][key_dset]

xlim = (-8, 10)
ylim = (-18, 6)
scatter_size = 0.1
scatter_alpha = 0.5
title = "DA: " + key_dset_split + " - " + key_dset

unique_classes = np.unique(y_labels)
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib releases removed.
cmap = plt.get_cmap("tab10")
class_color_dict = {cls: cmap(i) for i, cls in enumerate(unique_classes)}


def _legend_label(cls):
    """Return the display name for class id `cls`.

    The class id itself (not its position among the classes present) indexes
    `class_names`, so the names stay correct when a class is missing from the
    split. Falls back to "Class <id>" when no name is available.
    NOTE(review): assumes SPECTYPE_int values are 0-based indices into class_names.
    """
    idx = int(cls)
    if class_names and 0 <= idx < len(class_names):
        return class_names[idx]
    return f"Class {cls}"


# Axis window: explicit limits when given, otherwise the data range.
x_min = np.min(X_emb[:, 0]) if xlim is None else xlim[0]
x_max = np.max(X_emb[:, 0]) if xlim is None else xlim[1]
y_min = np.min(X_emb[:, 1]) if ylim is None else ylim[0]
y_max = np.max(X_emb[:, 1]) if ylim is None else ylim[1]

fig, ax = plt.subplots(figsize=(7, 6))
ax.set_title(title, fontsize=14)
ax.set_xlabel("t-SNE 1", fontsize=12)
ax.set_ylabel("t-SNE 2", fontsize=12)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.tick_params(labelsize=10)

# One scatter call per class so each class gets its own colour.
for cls in unique_classes:
    in_class = y_labels == cls
    ax.scatter(X_emb[in_class, 0], X_emb[in_class, 1],
               color=class_color_dict[cls], s=scatter_size, alpha=scatter_alpha)

# Patches instead of scatter handles: the tiny, translucent markers are unreadable in a legend.
legend_elements = [
    mpatches.Patch(color=class_color_dict[cls], label=_legend_label(cls))
    for cls in unique_classes
]
ax.legend(handles=legend_elements, title="Class", fontsize=10, title_fontsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Latent-space plot of the "Ignasi" catalogue, coloured by predicted class.
key_dset = "Ignasi"

ignasi_title = "DA: " + key_dset + " - Pred"
evaluation_tools.plot_tsne_single(
    features[key_dset], labels[key_dset],
    title=ignasi_title,
    class_names=class_names,
    class_counts=None,
    # Rendering options and axis window match the JPAS test plots.
    n_bins=128, sigma=2.0,
    scatter_size=0.5, scatter_alpha=0.1,
    xlim=(-8, 10), ylim=(-18, 6)
)

In [None]:
# Plain scatter of the first two latent dimensions for the catalogue selected by
# key_dset. No true labels are used here, so all points share one colour and the
# legend names the dataset instead of borrowing the first entry of class_names.
X_emb = features[key_dset]

xlim = (-8, 10)
ylim = (-18, 6)
scatter_size = 0.1
scatter_alpha = 0.1
title = "DA: " + key_dset

# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib releases removed.
point_color = plt.get_cmap("tab10")(0)

# Axis window: explicit limits when given, otherwise the data range.
x_min = np.min(X_emb[:, 0]) if xlim is None else xlim[0]
x_max = np.max(X_emb[:, 0]) if xlim is None else xlim[1]
y_min = np.min(X_emb[:, 1]) if ylim is None else ylim[0]
y_max = np.max(X_emb[:, 1]) if ylim is None else ylim[1]

fig, ax = plt.subplots(figsize=(7, 6))
ax.set_title(title, fontsize=14)
ax.set_xlabel("t-SNE 1", fontsize=12)
ax.set_ylabel("t-SNE 2", fontsize=12)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.tick_params(labelsize=10)

ax.scatter(X_emb[:, 0], X_emb[:, 1],
           color=point_color, s=scatter_size, alpha=scatter_alpha)

# Patch instead of a scatter handle: the tiny, translucent markers are unreadable in a legend.
legend_elements = [mpatches.Patch(color=point_color, label=f"{key_dset} (no true labels)")]
ax.legend(handles=legend_elements, title="Dataset", fontsize=10, title_fontsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Arguments shared by the two evaluate_results_from_load_paths calls.
define_dataset_loaders_keys = ['DESI_only', 'JPAS_matched']
return_keys = ['val_DESI_only', 'test_JPAS_matched']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

In [None]:
# Evaluate the model trained without domain adaptation (sweep run 0).
path_no_DA_run = os.path.join(global_setup.path_models, "wandb_no_DA_latent_2", "my-sweep-0")
RESULTS_no_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=[path_no_DA_run],
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
    keys_yy=keys_yy,
)

In [None]:
# Evaluate the model trained with domain adaptation (sweep run 0).
path_DA_run = os.path.join(global_setup.path_models, "wandb_DA_latent_2", "my-sweep-0")
RESULTS_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=[path_DA_run],
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
    keys_yy=keys_yy,
)

In [None]:
# Number of classes the model distinguishes. Taken from the width of the
# probability matrix: counting distinct *predicted* labels undercounts whenever
# the model never predicts one of the classes on this set.
probs_reference = np.asarray(RESULTS_no_DA[0]['val_DESI_only']['prob'])
if probs_reference.ndim == 2:
    n_classes = probs_reference.shape[1]
else:
    # Fallback for non-matrix outputs: distinct predicted labels, as before.
    n_classes = len(np.unique(RESULTS_no_DA[0]['val_DESI_only']['label']))

# Binary models use the QSO-high / rest naming; otherwise the project-wide class names.
if n_classes == 2:
    class_names = ['QSO_high', 'no_QSO_high']
else:
    class_names = global_setup.class_names

In [None]:
# One confusion matrix per (model, evaluated set), first for the no-DA results
# and then for the DA results. `tag` only changes the figure title.
for tag, results in (("no-DA", RESULTS_no_DA), ("DA", RESULTS_DA)):
    for model_idx, model_outputs in results.items():
        for key, result in model_outputs.items():
            yy_true = result["true"]
            yy_pred = result["prob"]

            evaluation_tools.plot_confusion_matrix(
                yy_true,
                yy_pred,
                class_names=class_names,
                cmap=plt.cm.RdYlGn,
                title=f"{key.replace('_', ' ')} ({tag} Model {model_idx})"
            )

In [None]:
# Index of the model (within each RESULTS dict) used for the comparisons below.
ii_model = 0

no_DA_mocks_val = RESULTS_no_DA[ii_model]['val_DESI_only']
no_DA_jpas_test = RESULTS_no_DA[ii_model]['test_JPAS_matched']
DA_jpas_test = RESULTS_DA[ii_model]['test_JPAS_matched']

# Domain gap of the no-DA model: mocks (validation split) vs. JPAS observations (test split).
# The title names the validation split because that is the mocks set actually loaded.
evaluation_tools.compare_TPR_confusion_matrices(
    no_DA_mocks_val['true'],
    no_DA_mocks_val['prob'],
    no_DA_jpas_test['true'],
    no_DA_jpas_test['prob'],
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='no-DA model: JPAS test VS DESI-mocks val',
    name_1="DESI-mocks",
    name_2="JPAS-obs",
)

# Effect of domain adaptation on the JPAS test set: no-DA vs. DA model.
evaluation_tools.compare_TPR_confusion_matrices(
    no_DA_jpas_test['true'],
    no_DA_jpas_test['prob'],
    DA_jpas_test['true'],
    DA_jpas_test['prob'],
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='JPAS test: DA VS no-DA',
    name_1="no DA",
    name_2="DA",
)

In [None]:
# Each comparison keeps its own result; previously the second call overwrote the first.
# Domain gap of the no-DA model: mocks (validation) vs. JPAS observations (test).
metrics_mocks_vs_obs = evaluation_tools.compare_sets_performance(
    RESULTS_no_DA[ii_model]['val_DESI_only']['true'], RESULTS_no_DA[ii_model]['val_DESI_only']['prob'],
    RESULTS_no_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_no_DA[ii_model]['test_JPAS_matched']['prob'],
    class_names=class_names,
    name_1="DESI-mocks",
    name_2="JPAS-obs"
)

# Effect of domain adaptation on the JPAS test set.
metrics_no_DA_vs_DA = evaluation_tools.compare_sets_performance(
    RESULTS_no_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_no_DA[ii_model]['test_JPAS_matched']['prob'],
    RESULTS_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_DA[ii_model]['test_JPAS_matched']['prob'],
    class_names=class_names,
    name_1="no DA",
    name_2="DA"
)

# `metrics` keeps pointing at the last comparison, as it did before.
metrics = metrics_no_DA_vs_DA

# Latent plots

In [None]:
from sklearn.manifold import TSNE

# Unpack the three evaluated sets (latent features, true class, predicted class,
# class probabilities, r-band flux). Suffix `_mocks` = DESI mocks (validation),
# `_obs` = JPAS observations (test). The JPAS fluxes get `_obs` names so they no
# longer overwrite, or masquerade as, the mocks fluxes.
res_no_DA_mocks = RESULTS_no_DA[ii_model]['val_DESI_only']
latent_positions_no_DA_mocks = res_no_DA_mocks['features']
true_no_DA_mocks = res_no_DA_mocks['true']
labels_no_DA_mocks = res_no_DA_mocks['label']
prob_no_DA_mocks = res_no_DA_mocks['prob']
DESI_FLUX_R_no_DA_mocks = res_no_DA_mocks['DESI_FLUX_R']

res_no_DA_obs = RESULTS_no_DA[ii_model]['test_JPAS_matched']
latent_positions_no_DA_obs = res_no_DA_obs['features']
true_no_DA_obs = res_no_DA_obs['true']
labels_no_DA_obs = res_no_DA_obs['label']
prob_no_DA_obs = res_no_DA_obs['prob']
DESI_FLUX_R_no_DA_obs = res_no_DA_obs['DESI_FLUX_R']

res_DA_obs = RESULTS_DA[ii_model]['test_JPAS_matched']
latent_positions_DA_obs = res_DA_obs['features']
true_DA_obs = res_DA_obs['true']
labels_DA_obs = res_DA_obs['label']
prob_DA_obs = res_DA_obs['prob']
DESI_FLUX_R_DA_obs = res_DA_obs['DESI_FLUX_R']

# === Stack all feature representations together ===
# A single t-SNE fit over the stacked sets puts all three in the same embedding.
X_all = np.vstack([
    latent_positions_no_DA_mocks,
    latent_positions_no_DA_obs,
    latent_positions_DA_obs
])

# === Perform shared t-SNE projection ===
tsne = TSNE(n_components=2, perplexity=30, init='pca', random_state=42)
X_all_tsne = tsne.fit_transform(X_all)

# === Split back to original domains ===
# Cut after the first and after the second set; the remainder is the third.
split_points = np.cumsum([
    latent_positions_no_DA_mocks.shape[0],
    latent_positions_no_DA_obs.shape[0],
])
(latent_positions_no_DA_mocks_tsne,
 latent_positions_no_DA_obs_tsne,
 latent_positions_DA_obs_tsne) = np.split(X_all_tsne, split_points)

In [None]:
# One plot per evaluated set in the shared t-SNE embedding, coloured by true class.
# Each set is drawn once; the cell used to repeat the same three plots verbatim.
shared_tsne_panels = [
    (latent_positions_no_DA_mocks_tsne, true_no_DA_mocks, "No DA - Mocks"),
    (latent_positions_no_DA_obs_tsne, true_no_DA_obs, "No DA - Data"),
    (latent_positions_DA_obs_tsne, true_DA_obs, "DA - Data"),
]
for embedding, true_classes, panel_title in shared_tsne_panels:
    evaluation_tools.plot_tsne_single(
        embedding, true_classes,
        class_counts=None,
        class_names=None,
        title=panel_title,
        n_bins=128,
        sigma=2.0,
        scatter_size=0.1,
        scatter_alpha=0.1,
        xlim=None,
        ylim=None
    )

In [None]:
# Shared-embedding plots for the model selected by ii_model, coloured by true class.
# The embeddings are read from the latent_positions_*_tsne arrays produced by the
# shared t-SNE cell; the X_*_tsne names used here before are not defined anywhere
# in the notebook and raised NameError.
evaluation_tools.plot_tsne_single(
    latent_positions_no_DA_mocks_tsne, RESULTS_no_DA[ii_model]['val_DESI_only']['true'],
    class_counts=None,
    class_names=None,
    title="No DA - Mocks",
    n_bins=128,
    sigma=2.0,
    scatter_size=0.1,
    scatter_alpha=0.1,
    xlim=None,
    ylim=None
)

evaluation_tools.plot_tsne_single(
    latent_positions_no_DA_obs_tsne, RESULTS_no_DA[ii_model]['test_JPAS_matched']['true'],
    class_counts=None,
    class_names=None,
    title="No DA - Data",
    n_bins=128,
    sigma=2.0,
    scatter_size=0.1,
    scatter_alpha=0.1,
    xlim=None,
    ylim=None
)

evaluation_tools.plot_tsne_single(
    latent_positions_DA_obs_tsne, RESULTS_DA[ii_model]['test_JPAS_matched']['true'],
    class_counts=None,
    class_names=None,
    title="DA - Data",
    n_bins=128,
    sigma=2.0,
    scatter_size=0.1,
    scatter_alpha=0.1,
    xlim=None,
    ylim=None
)

In [None]:
# ------- Colors and class names -------
# The display names live in `latent_class_names` so that the `class_names` list
# passed to the confusion-matrix cells is not replaced by a dict.
# NOTE(review): colours and names assume the 4-class integer encoding 0..3.
class_colors = {0: 'crimson', 1: 'limegreen', 2: 'royalblue', 3: 'yellow'}
latent_class_names = {0: 'Galaxy', 1: 'QSO high-z', 2: 'QSO low-z', 3: 'Star'}
alpha = 0.5

# ------- Datasets (title, features, true, pred, prob) -------
datasets = [
    ("No-DA Mocks", latent_positions_no_DA_mocks, true_no_DA_mocks, labels_no_DA_mocks, prob_no_DA_mocks),
    ("No-DA Data", latent_positions_no_DA_obs, true_no_DA_obs, labels_no_DA_obs, prob_no_DA_obs),
    ("DA Data", latent_positions_DA_obs, true_DA_obs, labels_DA_obs, prob_DA_obs),
]

# ------- Utility function to scale point size -------
def compute_point_size(n_points, min_size=1, max_size=100, ref_points=30000):
    """Return a marker size that shrinks as the number of points grows.

    The size is ``ref_points / n_points`` clamped to ``[min_size, max_size]``.
    An empty dataset (``n_points <= 0``) gets ``max_size`` instead of raising
    ZeroDivisionError.
    """
    if n_points <= 0:
        return max_size
    default_size = ref_points / n_points
    return max(min_size, min(max_size, default_size))

# ------- Utility plotting functions -------
def scatter_by_class(ax, features, y, colors, point_size, alpha=alpha):
    """Scatter the first two feature columns on `ax`, one colour per class in `y`.

    `colors` maps each class value found in `y` to a matplotlib colour.
    `alpha` defaults to the value of the module-level `alpha` at definition time.
    """
    for cls in np.unique(y):
        mask = y == cls
        ax.scatter(features[mask, 0], features[mask, 1],
                   s=point_size, c=colors[cls], alpha=alpha, linewidths=0)

def scatter_by_prob(ax, features, prob, cls_idx, point_size, alpha=alpha):
    """Scatter the first two feature columns on `ax`, coloured by `prob[:, cls_idx]`.

    Uses the 'bwr' colormap on a fixed [0, 1] scale so that panels are comparable.
    `alpha` defaults to the value of the module-level `alpha` at definition time.
    """
    ax.scatter(features[:, 0], features[:, 1],
               s=point_size, c=prob[:, cls_idx],
               cmap='bwr', vmin=0.0, vmax=1.0, alpha=alpha, linewidths=0)

# ----- Global axis limits -----
x_min, x_max = -3., 3.
y_min, y_max = -4., 2.

# ----- Figure: (2 + number of classes) rows x N columns -----
# Row 0: true classes, row 1: predicted classes, then one probability row per
# class column of `prob` (the row count was hard-coded for four classes).
n_cols = len(datasets)
n_prob_rows = datasets[0][4].shape[1]
n_rows = 2 + n_prob_rows
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows),
                         sharex=True, sharey=True)
axes = np.array(axes).reshape(n_rows, n_cols)

for col, (title, feats, y_true, y_pred, prob) in enumerate(datasets):
    # Compute point size per dataset
    ps = compute_point_size(len(feats))

    # Column title (big, spanning the column)
    axes[0, col].set_title(title, fontsize=16, pad=20)

    # Row 0: TRUE labels. Axes are shared, so these limits apply to every panel.
    ax_top = axes[0, col]
    scatter_by_class(ax_top, feats, y_true, class_colors, ps)
    ax_top.set_xlim(x_min, x_max)
    ax_top.set_ylim(y_min, y_max)

    # Row 1: PREDICTED labels
    ax_pred = axes[1, col]
    scatter_by_class(ax_pred, feats, y_pred, class_colors, ps)

    # Remaining rows: probability of each class
    for cls_idx in range(min(n_prob_rows, prob.shape[1])):
        ax_prob = axes[2 + cls_idx, col]
        scatter_by_prob(ax_prob, feats, prob, cls_idx, ps)

# Hide tick labels for internal subplots
for i in range(n_rows):
    for j in range(n_cols):
        ax = axes[i, j]
        if i < n_rows - 1:  # Hide x tick labels for rows above bottom row
            ax.set_xticklabels([])
        if j > 0:  # Hide y tick labels for columns beyond first
            ax.set_yticklabels([])

# Set axis labels only on outer plots
for ax in axes[-1, :]:  # Bottom row x-labels
    ax.set_xlabel("Feature 1")
for ax in axes[:, 0]:   # Leftmost column y-labels
    ax.set_ylabel("Feature 2")

# Legend for classes (only those that have a probability row)
legend_keys = [k for k in sorted(latent_class_names.keys()) if k < n_prob_rows]
class_lines = [mpl.lines.Line2D([0], [0], color=class_colors[k], marker='o',
                                linestyle='', markersize=8, label=latent_class_names[k])
               for k in legend_keys]
fig.legend(class_lines, [latent_class_names[k] for k in legend_keys],
           loc='upper right', title="Classes", fontsize=12)

plt.tight_layout()
plt.show()

In [None]:
# Keys: the results entry the bins are computed from, and the entry they are written to.
magnitude_key = "DESI_FLUX_R"
output_key = "MAG_BIN_ID"

# Bin edges and the (low, high) range of each consecutive pair of edges.
mag_bin_edges = (17, 19, 21, 22, 22.5)
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))

# One colour and one colormap per magnitude bin, in bin order.
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

In [None]:
# Add the magnitude-bin entry to the no-DA results; the returned object replaces the original.
mag_bin_options_no_DA = dict(
    magnitude_key=magnitude_key,
    mag_bin_edges=mag_bin_edges,
    output_key=output_key,
)
RESULTS_no_DA = evaluation_tools.add_magnitude_bins_to_results(RESULTS_no_DA, **mag_bin_options_no_DA)

In [None]:
# Add the magnitude-bin entry to the DA results; the returned object replaces the original.
mag_bin_options_DA = dict(
    magnitude_key=magnitude_key,
    mag_bin_edges=mag_bin_edges,
    output_key=output_key,
)
RESULTS_DA = evaluation_tools.add_magnitude_bins_to_results(RESULTS_DA, **mag_bin_options_DA)

In [None]:
# Convert the r-band fluxes of the mocks validation set to magnitudes
# (m = -2.5 * log10(flux) + 22.5) and histogram them with the bin ranges highlighted.
flux_r_val_DESI = RESULTS_no_DA[ii_model]['val_DESI_only']['DESI_FLUX_R']
magnitudes_val_DESI = -2.5 * np.log10(flux_r_val_DESI) + 22.5

histogram_options = dict(
    ranges=magnitude_ranges,
    colors=colors,
    bins=42,
    x_label="DESI Magnitude (R)",
    title="DESI R-band Magnitudes by Dataset Split and Loader",
)
masks_all = plotting_utils.plot_histogram_with_ranges_multiple(magnitudes_val_DESI, **histogram_options)

In [None]:
bin_labels = [f"{lo}–{hi}" for lo, hi in magnitude_ranges]
num_bins = len(magnitude_ranges)

# One confusion matrix per (model, evaluated set, magnitude bin), first for the
# no-DA results and then for the DA results. `tag` only changes the figure title.
for tag, results in (("no-DA", RESULTS_no_DA), ("DA", RESULTS_DA)):
    for model_idx, model_outputs in results.items():
        for key, result in model_outputs.items():
            # Bin ids are read from the entry written by add_magnitude_bins_to_results.
            mag_bins = result[output_key]
            yy_true_all = result["true"]
            yy_pred_all = result["prob"]

            for bin_id in range(num_bins):
                mask = mag_bins == bin_id
                if not np.any(mask):
                    continue  # Skip empty bins

                evaluation_tools.plot_confusion_matrix(
                    yy_true_all[mask],
                    yy_pred_all[mask],
                    class_names=class_names,
                    cmap=colormaps[bin_id],
                    title=f"{key.replace('_', ' ')} ({tag} {model_idx}) | Mag {bin_labels[bin_id]}"
                )
