In [1]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits

import numpy as np

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

In [2]:
import os

# Root folder containing the JPAS and DESI files listed below. Set the
# JPAS_DA_ROOT environment variable to use another location instead of editing
# this machine-specific default path.
root_path = os.environ.get(
    "JPAS_DA_ROOT",
    "/home/dlopez/Documents/Projects/JPAS_Domain_Adaptation/DATA/noise_jpas_v1/Train-Validate-Test",
)

# JPAS observations: a single "all" set (observation array + properties table).
load_JPAS_data = [{
    "name": "all",
    "npy": "JPAS_DATA_Aper_Cor_3_FLUX+NOISE.npy",
    "csv": "JPAS_DATA_PROPERTIES.csv",
    "sample_percentage": 1.0  # Optional, defaults to 1.0
}]

# DESI mocks, already split into train/val/test files; the training file is
# subsampled (sample_percentage=0.3), validation and test are kept in full.
load_DESI_data = [
{
    "name": "train",
    "npy": "mock_3_train.npy",
    "csv": "props_training.csv",
    "sample_percentage": 0.3
},
{
    "name": "val",
    "npy": "mock_3_validate.npy",
    "csv": "props_validate.csv",
    "sample_percentage": 1.0
},
{
    "name": "test",
    "npy": "mock_3_test.npy",
    "csv": "props_test.csv",
    "sample_percentage": 1.0
}
]

# Seed passed to the loader (presumably drives the `sample_percentage` subsampling).
random_seed_load = 42

In [3]:
# Load the JPAS set and the DESI train/val/test sets configured above.
# `random_seed_load` is forwarded as the loader's seed (presumably for the
# `sample_percentage` subsampling).
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

2025-05-16 12:49:23,693 - INFO - 📥 Starting full dataset loading with `load_dsets()`
2025-05-16 12:49:23,693 - INFO - ├ Loading JPAS datasets...
2025-05-16 12:49:23,694 - INFO - ├─── 📥 Starting JPAS dataset loading...
2025-05-16 12:49:23,694 - INFO - |    ├─── 🔹 Dataset: all (sample 100%)
2025-05-16 12:49:23,748 - INFO - |    |    ✔ CSV loaded: JPAS_DATA_PROPERTIES.csv (shape: (52020, 18))
2025-05-16 12:49:23,764 - INFO - |    |    ✔ NPY loaded: JPAS_DATA_Aper_Cor_3_FLUX+NOISE.npy (obs shape: (52020, 57))
2025-05-16 12:49:23,765 - INFO - ├─── ✅ Finished loading all JPAS datasets.
2025-05-16 12:49:23,766 - INFO - ├ Loading DESI datasets (splitted)...
2025-05-16 12:49:23,766 - INFO - ├─── 📥 Starting DESI dataset loading...
2025-05-16 12:49:23,766 - INFO - |    ├─── 🔹 Dataset: train
2025-05-16 12:49:24,716 - INFO - |    |    ✔ CSV loaded ((1087882, 18)), Size: 445.74 MB
2025-05-16 12:49:24,717 - INFO - |    |    ✔ NPY loaded ((1087882, 57, 3)), Size: 1488.22 MB
2025-05-16 12:49:24,724 - I

In [4]:
# Options forwarded one-to-one to cleaning_tools.clean_and_mask_data() below.
dict_clean_data_options = dict(
    # Which cleaning/masking steps to run.
    apply_masks=["unreliable", "magic_numbers", "negative_errors", "nan_values", "apply_additional_filters"],
    # Filter columns dropped from both datasets (logged as "unreliable in DESI").
    mask_indices=[0, -2],
    # Sentinel values treated as invalid.
    magic_numbers=[-99, 99],
    # i-band S/N cut used by the additional filters (logged as "S/N >= 0").
    i_band_sn_threshold=0,
    # Redshift limit for the QSO cut — exact semantics live in cleaning_tools.
    z_lim_QSO_cut=2.2,
)

In [5]:
# Clean and mask the loaded data with the options above.
# NOTE(review): DATA is rebound to the cleaned result, so re-running this cell
# without re-running the loading cell cleans already-cleaned data (e.g. masks
# columns [0, -2] a second time). Re-run from the loading cell instead.
DATA = cleaning_tools.clean_and_mask_data(
    DATA=DATA,
    **{
        option_name: dict_clean_data_options[option_name]
        for option_name in (
            "apply_masks",
            "mask_indices",
            "magic_numbers",
            "i_band_sn_threshold",
            "z_lim_QSO_cut",
        )
    },
)

2025-05-16 12:49:25,742 - INFO - 🧽 Cleaning and masking data...
2025-05-16 12:49:25,742 - INFO - ├── remove_invalid_NaN_rows()
2025-05-16 12:49:25,930 - INFO - │   ├── # objects filled with NaNs in JPAS: 0(0.0%)
2025-05-16 12:49:25,930 - INFO - │   ├── # objects filled with NaNs in DESI: 505(0.06%)
2025-05-16 12:49:26,341 - INFO - ├── 🧹 Deleted cleaned DATA_clean dictionary to free memory.
2025-05-16 12:49:26,342 - INFO - ├── apply_additional_filters()
2025-05-16 12:49:26,351 - INFO - │   ├── JPAS: 52020 valid rows (S/N ≥ 0) (100.0%)
2025-05-16 12:49:26,351 - INFO - │   ├── DESI: 792095 valid rows (S/N ≥ 0) (100.0%)
2025-05-16 12:49:26,602 - INFO - │   ├── Additional filters applied successfully.
2025-05-16 12:49:26,604 - INFO - ├── Masking out indices [0, -2] (unreliable in DESI).
2025-05-16 12:49:27,078 - INFO - │   ├── Updated JPAS obs/err shape: (52020, 55)
2025-05-16 12:49:27,078 - INFO - │   ├── Updated DESI mean/err shape: (792095, 55)
2025-05-16 12:49:27,079 - INFO - ├── Checki

In [6]:
# Dictionary of Lists of Arrays (LoA). For each TARGETID, an LoA holds the array
# of associated entries in the corresponding dataset, so all entries of one group
# share the same TARGETID, e.g. TARGETID[LoA[ii][0]] == TARGETID[LoA[ii][-1]].
# "only" groups IDs present in a single dataset; "both" groups IDs present in DESI and JPAS.
Dict_LoA = {"both":{}, "only":{}}
# NOTE(review): the unpacking order below must match the return order of
# crossmatch_IDs_two_datasets(); DESI is passed as the first array, JPAS as the second.
(
    IDs_only_DESI, IDs_only_JPAS, IDs_both,
    Dict_LoA["only"]["DESI"], Dict_LoA["only"]["JPAS"],
    Dict_LoA["both"]["DESI"], Dict_LoA["both"]["JPAS"]
) = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI"]['TARGETID'], DATA["JPAS"]['TARGETID']
)

2025-05-16 12:49:28,777 - INFO - 🔍 crossmatch_IDs_two_datasets()...
2025-05-16 12:49:28,778 - INFO - ├── 🚀 Starting ID categorization process...
2025-05-16 12:49:28,787 - INFO - |    ├── 📌 Found 804570 unique IDs across 2 arrays.
2025-05-16 12:49:28,939 - INFO - |    ├── Presence matrix created with shape: (2, 804570)
2025-05-16 12:49:28,940 - INFO - |    ├── Category mask created with shape: (2, 804570)
2025-05-16 12:49:28,941 - INFO - ├── 🚀 Starting index retrieval process...
2025-05-16 12:49:28,941 - INFO - |    ├── 📌 Processing 804570 unique IDs across 2 arrays.
2025-05-16 12:49:29,230 - INFO - ├── 🚀 Starting post-processing of unique IDs across two arrays...
2025-05-16 12:49:29,245 - INFO - |    ├── Processing complete: 752550 IDs only in Array 1 (93.53%).
2025-05-16 12:49:29,245 - INFO - |    ├── Processing complete: 24788 IDs only in Array 2 (3.08%).
2025-05-16 12:49:29,246 - INFO - |    ├── Processing complete: 27232 IDs in both arrays (3.38%).
2025-05-16 12:49:29,246 - INFO - 

In [7]:
# Train/validation/test fractions and split seeds, given separately for objects
# matched in both surveys ("both") and for DESI-only objects ("only_DESI").
dict_split_data_options = dict(
    train_ratio_both=0.8,
    val_ratio_both=0.1,
    test_ratio_both=0.1,
    random_seed_split_both=42,
    train_ratio_only_DESI=0.8,
    val_ratio_only_DESI=0.1,
    test_ratio_only_DESI=0.1,
    random_seed_split_only_DESI=42,
)

In [8]:
# Split the Lists of Arrays into training, validation, and testing sets.
def _split_LoA_with_options(LoA, options_suffix):
    """Split ``LoA`` with the ratios and seed configured for one group.

    Reads ``train_ratio_<suffix>``, ``val_ratio_<suffix>``, ``test_ratio_<suffix>``
    and ``random_seed_split_<suffix>`` from ``dict_split_data_options`` and
    forwards them to ``process_dset_splits.split_LoA``.

    Args:
        LoA: list of index arrays (one per TARGETID) to split.
        options_suffix: option-key suffix, e.g. ``"both"`` or ``"only_DESI"``.

    Returns:
        Whatever ``process_dset_splits.split_LoA`` returns (indexed by split
        name, e.g. "train", further down in this notebook).
    """
    return process_dset_splits.split_LoA(
        LoA,
        train_ratio=dict_split_data_options[f"train_ratio_{options_suffix}"],
        val_ratio=dict_split_data_options[f"val_ratio_{options_suffix}"],
        test_ratio=dict_split_data_options[f"test_ratio_{options_suffix}"],
        seed=dict_split_data_options[f"random_seed_split_{options_suffix}"],
    )


# The matched ("both") JPAS and DESI lists are split with the same ratios and
# seed; the alignment check below asserts their splits match group-by-group.
# Calls run in the same order as before: both/JPAS, both/DESI, only/DESI.
Dict_LoA_split = {
    "both": {
        "JPAS": _split_LoA_with_options(Dict_LoA["both"]["JPAS"], "both"),
        "DESI": _split_LoA_with_options(Dict_LoA["both"]["DESI"], "both"),
    },
    "only": {
        "DESI": _split_LoA_with_options(Dict_LoA["only"]["DESI"], "only_DESI"),
    },
}

2025-05-16 12:49:29,280 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:49:29,284 - INFO - ├── Finished splitting.
2025-05-16 12:49:29,285 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:49:29,289 - INFO - ├── Finished splitting.
2025-05-16 12:49:29,289 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:49:29,408 - INFO - ├── Finished splitting.


In [9]:
# Verify that the matched ("both") JPAS and DESI splits are aligned: both must
# have the same split names and, in each split, group jj of JPAS and group jj of
# DESI must point at rows that all carry one and the same TARGETID.
# Groups with more than two DESI rows are printed for inspection.
assert Dict_LoA_split["both"]["JPAS"].keys() == Dict_LoA_split["both"]["DESI"].keys(), "Both datasets must have the same training, validation, and testing sets."
for key_dset, LoA_JPAS_groups in Dict_LoA_split["both"]["JPAS"].items():
    LoA_DESI_groups = Dict_LoA_split["both"]["DESI"][key_dset]
    assert len(LoA_JPAS_groups) == len(LoA_DESI_groups), "Both datasets must have the same number unique TARGETIDs in each of training, validation, and testing sets."
    for idx_JPAS, idx_DESI in zip(LoA_JPAS_groups, LoA_DESI_groups):
        # Check every row of both groups, not just the first JPAS row.
        ids_JPAS = {DATA["JPAS"]["TARGETID"][idx] for idx in idx_JPAS}
        ids_DESI = {DATA["DESI"]["TARGETID"][idx] for idx in idx_DESI}
        assert len(ids_JPAS) == 1 and ids_JPAS == ids_DESI, "Both datasets must have the same TARGETIDs in each of training, validation, and testing sets."
        if len(idx_DESI) > 2:
            print("TARGETID JPAS:", next(iter(ids_JPAS)), "TARGETID DESI:", next(iter(ids_DESI)))

TARGETID JPAS: 39633293729072001 TARGETID DESI: 39633293729072001
TARGETID JPAS: 39633275194443316 TARGETID DESI: 39633275194443316
TARGETID JPAS: 39633297344563253 TARGETID DESI: 39633297344563253
TARGETID JPAS: 39633290071638383 TARGETID DESI: 39633290071638383
TARGETID JPAS: 39633290067445314 TARGETID DESI: 39633290067445314
TARGETID JPAS: 39633282656112085 TARGETID DESI: 39633282656112085
TARGETID JPAS: 39633267661475062 TARGETID DESI: 39633267661475062
TARGETID JPAS: 39633293712296354 TARGETID DESI: 39633293712296354
TARGETID JPAS: 39633286355488472 TARGETID DESI: 39633286355488472
TARGETID JPAS: 39633290042279158 TARGETID DESI: 39633290042279158
TARGETID JPAS: 39633282689666769 TARGETID DESI: 39633282689666769
TARGETID JPAS: 39633290046475536 TARGETID DESI: 39633290046475536
TARGETID JPAS: 39633293691323183 TARGETID DESI: 39633293691323183
TARGETID JPAS: 39633286376457480 TARGETID DESI: 39633286376457480
TARGETID JPAS: 39633278965122779 TARGETID DESI: 39633278965122779


In [10]:
# Keys extracted as inputs (xx) and as targets/identifiers (yy) from each split.
keys_xx = "OBS ERR MORPHTYPE_int".split()
keys_yy = "SPECTYPE_int TARGETID".split()

In [11]:
# Build the training split: the combined DESI set (DESI-only + matched groups),
# plus the matched DESI and matched JPAS subsets, all with the same keys.
key_dset = "train"

LoA_combined, xx_combined, yy_combined = process_dset_splits.extract_and_combine_DESI_data(
    Dict_LoA_split["only"]["DESI"][key_dset],
    Dict_LoA_split["both"]["DESI"][key_dset],
    DATA["DESI"],
    keys_xx,
    keys_yy,
)
LoA_DESI, xx_DESI, yy_DESI = process_dset_splits.extract_data_matched(
    Dict_LoA_split["both"]["DESI"][key_dset],
    DATA["DESI"],
    keys_xx,
    keys_yy,
)
LoA_JPAS, xx_JPAS, yy_JPAS = process_dset_splits.extract_data_matched(
    Dict_LoA_split["both"]["JPAS"][key_dset],
    DATA["JPAS"],
    keys_xx,
    keys_yy,
)

2025-05-16 12:49:29,447 - INFO - |    ├── 🔧 extract_and_combine_DESI_data()
2025-05-16 12:49:29,447 - INFO - |    ├── Extracting features and labels from DESI-only subset...
2025-05-16 12:49:31,201 - INFO - |    ├── Extracting features and labels from DESI-matched subset...
2025-05-16 12:49:31,365 - INFO - |    ├── Applied index shift of 611345 to matched DESI group to ensure uniqueness
2025-05-16 12:49:31,436 - INFO - |    ├── Finished extract_and_combine_DESI_data()
2025-05-16 12:49:31,440 - INFO - |    ├── 🔧 extract_data_matched()
2025-05-16 12:49:31,441 - INFO - |    ├── Extracting features and labels from matched dataset...
2025-05-16 12:49:31,496 - INFO - |    ├── Finished extract_data_matched()
2025-05-16 12:49:31,497 - INFO - |    ├── 🔧 extract_data_matched()
2025-05-16 12:49:31,497 - INFO - |    ├── Extracting features and labels from matched dataset...
2025-05-16 12:49:31,531 - INFO - |    ├── Finished extract_data_matched()


In [12]:
# Size report for the combined DESI training data: number of TARGETID groups,
# total rows across groups, then length/shape of each extracted key (one per line).
print(
    len(LoA_combined),
    np.concatenate(LoA_combined).shape,
    xx_combined.keys(),
    len(xx_combined['OBS']),
    xx_combined['OBS'].shape,
    len(xx_combined['MORPHTYPE_int']),
    xx_combined['MORPHTYPE_int'].shape,
    yy_combined.keys(),
    len(yy_combined['SPECTYPE_int']),
    yy_combined['SPECTYPE_int'].shape,
    len(yy_combined['TARGETID']),
    yy_combined['TARGETID'].shape,
    sep="\n",
)

623825
(633698,)
dict_keys(['OBS', 'ERR', 'MORPHTYPE_int'])
633698
(633698, 55)
633698
(633698,)
dict_keys(['SPECTYPE_int', 'TARGETID'])
633698
(633698,)
633698
(633698,)


In [13]:
# Size report for the matched JPAS training data: number of TARGETID groups,
# total rows across groups, then length/shape of each extracted key (one per line).
print(
    len(LoA_JPAS),
    np.concatenate(LoA_JPAS).shape,
    xx_JPAS.keys(),
    len(xx_JPAS['OBS']),
    xx_JPAS['OBS'].shape,
    len(xx_JPAS['MORPHTYPE_int']),
    xx_JPAS['MORPHTYPE_int'].shape,
    yy_JPAS.keys(),
    len(yy_JPAS['SPECTYPE_int']),
    yy_JPAS['SPECTYPE_int'].shape,
    len(yy_JPAS['TARGETID']),
    yy_JPAS['TARGETID'].shape,
    sep="\n",
)

21785
(21785,)
dict_keys(['OBS', 'ERR', 'MORPHTYPE_int'])
21785
(21785, 55)
21785
(21785,)
dict_keys(['SPECTYPE_int', 'TARGETID'])
21785
(21785,)
21785
(21785,)


In [14]:
# The matched JPAS and DESI extractions must cover exactly the same objects:
# same number of TARGETID groups and the same set of unique TARGETIDs (not just
# the same count of them).
assert len(LoA_DESI) == len(LoA_JPAS), "Both datasets must have the same number unique TARGETIDs"
assert np.array_equal(np.unique(yy_JPAS['TARGETID']), np.unique(yy_DESI['TARGETID'])), "Both datasets must have the same unique TARGETIDs"

# Spot-check a few groups by printing their JPAS and DESI TARGETIDs side by side.
for ii in range(45, 50):
    for idx in LoA_JPAS[ii]:
        print("JPAS TARGETID", yy_JPAS['TARGETID'][idx])
    for idx in LoA_DESI[ii]:
        print("DESI TARGETID", yy_DESI['TARGETID'][idx])
    print()

JPAS TARGETID 39633132139316722
DESI TARGETID 39633132139316722

JPAS TARGETID 39633290067446186
DESI TARGETID 39633290067446186

JPAS TARGETID 39633136572695155
DESI TARGETID 39633136572695155

JPAS TARGETID 39633282664499214
DESI TARGETID 39633282664499214
DESI TARGETID 39633282664499214

JPAS TARGETID 39633132118346665
DESI TARGETID 39633132118346665

