In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import wrapper_data_loaders
from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib inline

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# Class labels come from the project-wide configuration.
# Binary alternative (swap in to evaluate QSO_high vs. the rest):
#   class_names = ["QSO_high", "no_QSO_high"]
class_names = global_setup.class_names
n_classes = len(class_names)

# Evaluation splits to return, dataset loaders to build, and extra columns
# to carry into each result dict alongside the labels/probabilities.
return_keys = ['val_DESI_only', 'test_JPAS_matched']
define_dataset_loaders_keys = ['DESI_only', 'JPAS_matched']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

In [None]:
# Evaluate the baseline model(s) trained WITHOUT domain adaptation.
# Result layout (as used below): RESULTS[model_idx][split_key] -> dict with
# at least "true", "prob" and the extra columns listed in `keys_yy`.
RESULTS_no_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=[os.path.join(global_setup.path_models, "09_no_DA")],
    keys_yy=keys_yy,
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
)

In [None]:
# Evaluate the model(s) trained WITH domain adaptation, on the same splits and
# with the same extra columns as the baseline so the two can be compared directly.
RESULTS_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=[os.path.join(global_setup.path_models, "09_DA")],
    keys_yy=keys_yy,
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
)

In [None]:
# One confusion matrix per (model, evaluation split), for the baseline (no-DA)
# and the domain-adapted (DA) runs. A single loop replaces two copy-pasted ones;
# the titles produced are unchanged.
for model_label, results_by_model in (("no-DA", RESULTS_no_DA), ("DA", RESULTS_DA)):
    for model_idx, model_outputs in results_by_model.items():
        for key, result in model_outputs.items():
            evaluation_tools.plot_confusion_matrix(
                result["true"],
                result["prob"],  # model outputs (probabilities), not hard labels
                class_names=class_names,
                cmap=plt.cm.RdYlGn,
                title=f"{key.replace('_', ' ')} ({model_label} Model {model_idx})"
            )

In [None]:
# Index of the model (key within each RESULTS dict) inspected in detail below.
ii_model = 0

# Aliases for the three result sets compared here; each holds 'true' labels
# and 'prob' model outputs.
desi_val_no_DA = RESULTS_no_DA[ii_model]['val_DESI_only']
jpas_test_no_DA = RESULTS_no_DA[ii_model]['test_JPAS_matched']
jpas_test_DA = RESULTS_DA[ii_model]['test_JPAS_matched']

# Domain shift seen by the baseline model: DESI mocks vs. J-PAS observations.
evaluation_tools.compare_TPR_confusion_matrices(
    desi_val_no_DA['true'], desi_val_no_DA['prob'],
    jpas_test_no_DA['true'], jpas_test_no_DA['prob'],
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='no-DA model: JPAS test VS DESI-mocks test',
    name_1="DESI-mocks",
    name_2="JPAS-obs",
)

# Effect of domain adaptation on the same J-PAS test set.
evaluation_tools.compare_TPR_confusion_matrices(
    jpas_test_no_DA['true'], jpas_test_no_DA['prob'],
    jpas_test_DA['true'], jpas_test_DA['prob'],
    class_names=class_names,
    figsize=(10, 7),
    cmap='seismic',
    title='JPAS test: DA VS no-DA',
    name_1="no DA",
    name_2="DA",
)

In [None]:
# Two comparisons:
#   1. baseline (no-DA) model on DESI mocks vs. J-PAS observations (domain shift);
#   2. no-DA vs. DA models on the same J-PAS test set (effect of adaptation).
# Each summary gets its own name so the first is not overwritten by the second.
metrics_no_DA_DESI_vs_JPAS = evaluation_tools.compare_sets_performance(
    RESULTS_no_DA[ii_model]['val_DESI_only']['true'], RESULTS_no_DA[ii_model]['val_DESI_only']['prob'],
    RESULTS_no_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_no_DA[ii_model]['test_JPAS_matched']['prob'],
    class_names=class_names,
    name_1="DESI-mocks",
    name_2="JPAS-obs"
)

metrics_JPAS_no_DA_vs_DA = evaluation_tools.compare_sets_performance(
    RESULTS_no_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_no_DA[ii_model]['test_JPAS_matched']['prob'],
    RESULTS_DA[ii_model]['test_JPAS_matched']['true'], RESULTS_DA[ii_model]['test_JPAS_matched']['prob'],
    class_names=class_names,
    name_1="no DA",
    name_2="DA"
)

# Backward-compatible alias: `metrics` still refers to the last comparison.
metrics = metrics_JPAS_no_DA_vs_DA

In [None]:
# NOTE(review): disabled cell. It expects latent-feature arrays
# `features_val_no_DA`, `features_val_DA`, `features_test_no_DA` and
# `features_test_DA`, none of which is defined in the visible cells of this
# notebook, so uncommenting it as-is raises NameError.
# TODO: extract per-object features from the loaded models before re-enabling.
#
# Purpose: fit ONE shared 2-D t-SNE on all four feature sets stacked together
# (so the embeddings live in a common space), then slice the projection back
# into the four sets using the cumulative row offsets i0..i4.
# from sklearn.manifold import TSNE

# # === Stack all feature representations together ===
# n_val_no_DA = features_val_no_DA.shape[0]
# n_val_DA = features_val_DA.shape[0]
# n_test_no_DA = features_test_no_DA.shape[0]
# n_test_DA = features_test_DA.shape[0]

# X_all = np.vstack([
#     features_val_no_DA,
#     features_val_DA,
#     features_test_no_DA,
#     features_test_DA
# ])

# # === Perform shared t-SNE projection ===
# tsne = TSNE(n_components=2, perplexity=30, init='pca', random_state=42)
# X_all_tsne = tsne.fit_transform(X_all)

# # === Split back to original domains ===
# i0 = 0
# i1 = i0 + n_val_no_DA
# i2 = i1 + n_val_DA
# i3 = i2 + n_test_no_DA
# i4 = i3 + n_test_DA

# X_val_no_DA_tsne   = X_all_tsne[i0:i1]
# X_val_DA_tsne      = X_all_tsne[i1:i2]
# X_test_no_DA_tsne  = X_all_tsne[i2:i3]
# X_test_DA_tsne     = X_all_tsne[i3:i4]

In [None]:
# NOTE(review): disabled cell. Besides the t-SNE arrays from the (also disabled)
# projection cell, it needs `yy_true_val_no_DA`, `yy_true_val_DA`, `yy_true_test`
# and `dset_test`, none of which is defined in the visible cells of this notebook.
# It also passes `global_setup.class_names` directly rather than the notebook-level
# `class_names`; the two differ if the binary alternative in the config cell is used.
#
# Purpose: compare validation vs. test embeddings side by side, once for the
# no-DA model and once for the DA model. n_bins/sigma/scatter_* are forwarded to
# plot_tsne_comparison_single_pair (presumably density-map and scatter styling —
# TODO confirm).
# evaluation_tools.plot_tsne_comparison_single_pair(
#     X_val_no_DA_tsne, yy_true_val_no_DA,
#     X_test_no_DA_tsne, yy_true_test,
#     dset_test.class_counts,
#     class_names=global_setup.class_names,
#     title_set1="No DA - Validation",
#     title_set2="No DA - Test",
#     n_bins=128,
#     sigma=2.0,
#     scatter_size=1,
#     scatter_alpha=1.0
# )

# evaluation_tools.plot_tsne_comparison_single_pair(
#     X_val_DA_tsne, yy_true_val_DA,
#     X_test_DA_tsne, yy_true_test,
#     dset_test.class_counts,
#     class_names=global_setup.class_names,
#     title_set1="DA - Validation",
#     title_set2="DA - Test",
#     n_bins=128,
#     sigma=2.0,
#     scatter_size=1,
#     scatter_alpha=1.0
# )

In [None]:
# Magnitude binning configuration.
# `magnitude_key` names a FLUX column, while the edges are in magnitudes; the
# flux -> magnitude conversion presumably happens inside
# add_magnitude_bins_to_results (TODO confirm).
magnitude_key = "DESI_FLUX_R"
mag_bin_edges = (17, 19, 21, 22, 22.5)
output_key = "MAG_BIN_ID"

# Consecutive edge pairs -> [(17, 19), (19, 21), (21, 22), (22, 22.5)].
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

# Later cells pair each range with a color / colormap by position; fail fast here
# if the edges are edited without extending the palettes (or are not increasing).
assert all(lo < hi for lo, hi in magnitude_ranges), "mag_bin_edges must be strictly increasing"
assert len(colors) >= len(magnitude_ranges), "need one color per magnitude bin"
assert len(colormaps) >= len(magnitude_ranges), "need one colormap per magnitude bin"

In [None]:
# Attach a magnitude-bin index (stored under `output_key`) to every
# model/split entry of the no-DA results; the returned dict replaces the old one.
# NOTE(review): re-running this cell is only safe if the helper recomputes the
# bins from `magnitude_key` rather than depending on a previous call — TODO confirm.
RESULTS_no_DA = evaluation_tools.add_magnitude_bins_to_results(
    RESULTS_no_DA,
    output_key=output_key,
    mag_bin_edges=mag_bin_edges,
    magnitude_key=magnitude_key,
)

In [None]:
# Same magnitude binning for the DA results, with identical edges, so per-bin
# metrics are comparable between the two models.
RESULTS_DA = evaluation_tools.add_magnitude_bins_to_results(
    RESULTS_DA,
    output_key=output_key,
    mag_bin_edges=mag_bin_edges,
    magnitude_key=magnitude_key,
)

In [None]:
# Flux -> magnitude: m = ZP - 2.5 * log10(F). A zero point of 22.5 matches fluxes
# in nanomaggies (TODO confirm the units of DESI_FLUX_R).
MAG_ZERO_POINT = 22.5

flux_val_DESI = np.asarray(RESULTS_no_DA[ii_model]['val_DESI_only']['DESI_FLUX_R'], dtype=float)
# Zero flux gives +inf and negative flux gives NaN; silence the RuntimeWarnings
# here and handle the non-finite values explicitly below.
with np.errstate(divide='ignore', invalid='ignore'):
    magnitudes_val_DESI = -2.5 * np.log10(flux_val_DESI) + MAG_ZERO_POINT

# Non-finite magnitudes cannot be histogrammed (they break automatic range detection).
finite_mag = np.isfinite(magnitudes_val_DESI)
n_dropped = int(finite_mag.size - np.count_nonzero(finite_mag))
if n_dropped:
    logging.warning("Dropping %d objects with non-positive DESI_FLUX_R from the magnitude histogram", n_dropped)

# NOTE(review): if masks_all holds per-object masks, they index the finite subset
# `magnitudes_val_DESI[finite_mag]`, not the full array.
masks_all = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_val_DESI[finite_mag], ranges=magnitude_ranges, colors=colors, bins=42,
    x_label="DESI Magnitude (R)",
    title="DESI R-band Magnitudes by Dataset Split and Loader"
)

In [None]:
bin_labels = [f"{lo}–{hi}" for lo, hi in magnitude_ranges]
num_bins = len(magnitude_ranges)

# Per-magnitude-bin confusion matrices for every (model, split), for both the
# no-DA and DA runs. One sweep replaces two copy-pasted ones; titles are unchanged.
for model_label, results_by_model in (("no-DA", RESULTS_no_DA), ("DA", RESULTS_DA)):
    for model_idx, model_outputs in results_by_model.items():
        for key, result in model_outputs.items():
            mag_bins = result["MAG_BIN_ID"]
            yy_true_all = result["true"]
            yy_pred_all = result["prob"]

            for bin_id in range(num_bins):
                mask = mag_bins == bin_id
                if np.sum(mask) == 0:
                    continue  # no objects in this magnitude bin

                evaluation_tools.plot_confusion_matrix(
                    yy_true_all[mask],
                    yy_pred_all[mask],
                    class_names=class_names,
                    # Wrap around instead of raising if more bins than colormaps are configured.
                    cmap=colormaps[bin_id % len(colormaps)],
                    title=f"{key.replace('_', ' ')} ({model_label} {model_idx}) | Mag {bin_labels[bin_id]}"
                )
