In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import wrapper_data_loaders

from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools
from JPAS_DA.wrapper_wandb import wrapper_tools
from JPAS_DA.evaluation import evaluation_tools

from JPAS_DA.evaluation import wandb_evaluation_tools

import os
import torch
import numpy as np

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# How many of the top-ranked domain-adaptation sweeps to load and evaluate.
N_selected_sweeps = 2
# Root folder of the W&B DA sweeps; each sweep name is a sub-path of it.
path_wandb_sweep_DA = os.path.join(global_setup.path_models, "wandb_DA")

In [None]:
# Rank the DA sweeps under path_wandb_sweep_DA (presumably best-first by loss)
# and plot at most N_selected_sweeps of them.
sorted_list_sweep_names_DA, sorted_losses_DA = wandb_evaluation_tools.load_and_plot_sorted_sweeps(
    path_wandb_sweep_DA,
    max_runs_to_plot=N_selected_sweeps,
)

In [None]:
# Full paths of the first N_selected_sweeps ranked sweeps (slicing tolerates
# fewer sweeps being available).
paths_load_DA = [
    os.path.join(path_wandb_sweep_DA, selected_sweep)
    for selected_sweep in sorted_list_sweep_names_DA[:N_selected_sweeps]
]

In [None]:
# Split(s) to score, dataset loader(s) to build, and extra per-object fields
# to return next to the predictions.
return_keys = ['test_JPAS_matched']
define_dataset_loaders_keys = ['JPAS_matched']
keys_yy = ["SPECTYPE_int", "TARGETID", "DESI_FLUX_R"]

# Label set from the project configuration. A custom list such as
# ["QSO_high", "no_QSO_high"] can be substituted here for a different labelling.
class_names = global_setup.class_names
n_classes = len(class_names)

In [None]:
# Load each selected DA model and evaluate it on the requested split(s).
# Later cells read RESULTS_DA[model_idx][split_key]["true" / "prob" / keys_yy field].
RESULTS_DA = evaluation_tools.evaluate_results_from_load_paths(
    paths_load=paths_load_DA,
    return_keys=return_keys,
    define_dataset_loaders_keys=define_dataset_loaders_keys,
    keys_yy=keys_yy,
)

In [None]:
# One confusion matrix per (DA model, evaluated split) pair.
for model_idx, model_outputs in RESULTS_DA.items():
    for split_key, split_result in model_outputs.items():
        split_title = f"{split_key.replace('_', ' ')} (DA Model {model_idx})"
        evaluation_tools.plot_confusion_matrix(
            split_result["true"],
            split_result["prob"],
            class_names=class_names,
            cmap=plt.cm.RdYlGn,
            title=split_title,
        )

In [None]:
# Magnitude-binning configuration for the per-bin evaluation.
# NOTE(review): magnitude_key names a DESI *flux* column; presumably
# add_magnitude_bins_to_results converts it to a magnitude
# (22.5 - 2.5*log10(flux)) before applying mag_bin_edges — confirm.
magnitude_key = "DESI_FLUX_R"
mag_bin_edges = (17, 19, 21, 22, 22.5)
output_key = "MAG_BIN_ID"

# Consecutive edge pairs -> one (lo, hi) range per bin.
magnitude_ranges = list(zip(mag_bin_edges[:-1], mag_bin_edges[1:]))
colors = ['blue', 'green', 'orange', 'red']
colormaps = [plt.cm.Blues, plt.cm.Greens, plt.cm.YlOrBr, plt.cm.Reds]

# Fail fast if the edges are edited into an unsorted order, or if more bins
# are configured than there are colours / colormaps to draw them with.
assert all(lo < hi for lo, hi in magnitude_ranges), "mag_bin_edges must be strictly increasing"
assert len(colors) >= len(magnitude_ranges), "need one colour per magnitude bin"
assert len(colormaps) >= len(magnitude_ranges), "need one colormap per magnitude bin"

In [None]:
# Attach a magnitude-bin id to every evaluated object, stored under output_key
# in each RESULTS_DA[model_idx][split_key] entry.
RESULTS_DA = evaluation_tools.add_magnitude_bins_to_results(
    RESULTS_DA,
    magnitude_key=magnitude_key,
    mag_bin_edges=mag_bin_edges,
    output_key=output_key,
)

In [None]:
# Histogram of DESI R-band magnitudes for one evaluated model/split, coloured
# by the configured magnitude ranges.
# Choose the model and split from what RESULTS_DA actually contains so the
# cell runs on a fresh kernel (it is built only from the splits in return_keys).
ii_model = next(iter(RESULTS_DA))
hist_split_key = next(iter(RESULTS_DA[ii_model]))

flux_r = np.asarray(RESULTS_DA[ii_model][hist_split_key]['DESI_FLUX_R'], dtype=float)
# m = 22.5 - 2.5*log10(flux). Non-positive or NaN fluxes have no finite
# magnitude and would break automatic histogram ranges, so only positive
# fluxes are kept (masks_all therefore indexes the kept objects).
magnitudes_val_DESI = -2.5 * np.log10(flux_r[flux_r > 0]) + 22.5

masks_all = plotting_utils.plot_histogram_with_ranges_multiple(
    magnitudes_val_DESI, ranges=magnitude_ranges, colors=colors, bins=42,
    x_label="DESI Magnitude (R)",
    title=f"DESI R-band Magnitudes ({hist_split_key.replace('_', ' ')}, DA Model {ii_model})"
)

In [None]:
# Confusion matrices per magnitude bin, for every (DA model, split) pair.
bin_labels = [f"{lo}–{hi}" for lo, hi in magnitude_ranges]
num_bins = len(magnitude_ranges)

# Sweep DA models
for model_idx, model_outputs in RESULTS_DA.items():
    for key, result in model_outputs.items():
        # Bin ids live under output_key (set by add_magnitude_bins_to_results);
        # coerce to an array so the element-wise comparison below builds a mask.
        mag_bins = np.asarray(result[output_key])
        yy_true_all = result["true"]
        yy_pred_all = result["prob"]

        for bin_id in range(num_bins):
            mask = mag_bins == bin_id
            if not mask.any():
                continue  # Skip empty bins

            evaluation_tools.plot_confusion_matrix(
                yy_true_all[mask],
                yy_pred_all[mask],
                class_names=class_names,
                # Cycle colormaps so configuring extra bins cannot raise IndexError.
                cmap=colormaps[bin_id % len(colormaps)],
                title=f"{key.replace('_', ' ')} (DA {model_idx}) | Mag {bin_labels[bin_id]}"
            )