In [1]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders
from JPAS_DA.utils import plotting_utils
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib widget

In [2]:
# Directory containing the Train-Validate-Test .npy/.csv files. On another machine, set the
# JPAS_DA_DATA_ROOT environment variable instead of editing this path.
root_path = os.environ.get(
    "JPAS_DA_DATA_ROOT",
    "/home/dlopez/Documents/Projects/JPAS_Domain_Adaptation/DATA/noise_jpas_v1/Train-Validate-Test",
)

# One dict per named dataset: observations (.npy), properties (.csv) and the fraction of rows
# to load (`sample_percentage`, 1.0 = all rows; logged as "sample 100%").
load_JPAS_data = [
    {
        "name": "all",
        "npy": "JPAS_DATA_Aper_Cor_3_FLUX+NOISE.npy",
        "csv": "JPAS_DATA_PROPERTIES.csv",
        "sample_percentage": 1.0,  # Optional, defaults to 1.0
    },
]

load_DESI_data = [
    {
        "name": "train",
        "npy": "mock_3_train.npy",
        "csv": "props_training.csv",
        "sample_percentage": 0.3,  # only 30% of the training mock is loaded
    },
    {
        "name": "val",
        "npy": "mock_3_validate.npy",
        "csv": "props_validate.csv",
        "sample_percentage": 1.0,
    },
    {
        "name": "test",
        "npy": "mock_3_test.npy",
        "csv": "props_test.csv",
        "sample_percentage": 1.0,
    },
]

# Seed passed to load_dsets (presumably drives the `sample_percentage` subsampling).
random_seed_load = 42

In [3]:
# Load the JPAS catalogue and the DESI train/val/test mocks configured above.
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

2025-05-16 12:50:49,821 - INFO - 📥 Starting full dataset loading with `load_dsets()`
2025-05-16 12:50:49,821 - INFO - ├ Loading JPAS datasets...
2025-05-16 12:50:49,821 - INFO - ├─── 📥 Starting JPAS dataset loading...
2025-05-16 12:50:49,821 - INFO - |    ├─── 🔹 Dataset: all (sample 100%)
2025-05-16 12:50:49,863 - INFO - |    |    ✔ CSV loaded: JPAS_DATA_PROPERTIES.csv (shape: (52020, 18))
2025-05-16 12:50:49,874 - INFO - |    |    ✔ NPY loaded: JPAS_DATA_Aper_Cor_3_FLUX+NOISE.npy (obs shape: (52020, 57))
2025-05-16 12:50:49,874 - INFO - ├─── ✅ Finished loading all JPAS datasets.
2025-05-16 12:50:49,875 - INFO - ├ Loading DESI datasets (splitted)...
2025-05-16 12:50:49,875 - INFO - ├─── 📥 Starting DESI dataset loading...
2025-05-16 12:50:49,876 - INFO - |    ├─── 🔹 Dataset: train
2025-05-16 12:50:50,796 - INFO - |    |    ✔ CSV loaded ((1087882, 18)), Size: 445.74 MB
2025-05-16 12:50:50,797 - INFO - |    |    ✔ NPY loaded ((1087882, 57, 3)), Size: 1488.22 MB
2025-05-16 12:50:50,803 - I

In [4]:
# Options for cleaning_tools.clean_and_mask_data; each key is passed to the keyword argument
# of the same name.
dict_clean_data_options = dict(
    apply_masks=["unreliable", "magic_numbers", "negative_errors", "nan_values", "apply_additional_filters"],
    mask_indices=[0, -2],     # columns dropped from obs/err (logged as "unreliable in DESI")
    magic_numbers=[-99, 99],  # sentinel values handled by the "magic_numbers" mask
    i_band_sn_threshold=0,    # i-band S/N cut of the additional filters (logged as "S/N ≥ 0")
    z_lim_QSO_cut=2.2,        # redshift limit used for the QSO cut — TODO confirm semantics
)

In [5]:
# Clean and mask both surveys; the option keys match clean_and_mask_data's keyword arguments.
# NOTE(review): DATA is rebound to the cleaned result, so re-running this cell without first
# re-running the loading cell would clean already-masked data again (presumably dropping the
# `mask_indices` columns a second time).
DATA = cleaning_tools.clean_and_mask_data(DATA=DATA, **dict_clean_data_options)

2025-05-16 12:50:51,816 - INFO - 🧽 Cleaning and masking data...
2025-05-16 12:50:51,816 - INFO - ├── remove_invalid_NaN_rows()
2025-05-16 12:50:52,002 - INFO - │   ├── # objects filled with NaNs in JPAS: 0(0.0%)
2025-05-16 12:50:52,003 - INFO - │   ├── # objects filled with NaNs in DESI: 505(0.06%)
2025-05-16 12:50:52,466 - INFO - ├── 🧹 Deleted cleaned DATA_clean dictionary to free memory.
2025-05-16 12:50:52,466 - INFO - ├── apply_additional_filters()
2025-05-16 12:50:52,475 - INFO - │   ├── JPAS: 52020 valid rows (S/N ≥ 0) (100.0%)
2025-05-16 12:50:52,476 - INFO - │   ├── DESI: 792095 valid rows (S/N ≥ 0) (100.0%)
2025-05-16 12:50:52,739 - INFO - │   ├── Additional filters applied successfully.
2025-05-16 12:50:52,741 - INFO - ├── Masking out indices [0, -2] (unreliable in DESI).
2025-05-16 12:50:53,219 - INFO - │   ├── Updated JPAS obs/err shape: (52020, 55)
2025-05-16 12:50:53,219 - INFO - │   ├── Updated DESI mean/err shape: (792095, 55)
2025-05-16 12:50:53,220 - INFO - ├── Checki

In [6]:
# Cross-match DESI (first array) and JPAS (second array) by TARGETID.
# Dict_LoA holds Lists of Arrays (LoA): for each TARGETID, the associated entries in that
# survey's arrays, e.g. TARGETID[LoA[ii][0]] == TARGETID[LoA[ii][-1]].
#   Dict_LoA["only"][survey] -> TARGETIDs present in that survey only
#   Dict_LoA["both"][survey] -> TARGETIDs present in both surveys
Dict_LoA = dict(both={}, only={})
(IDs_only_DESI, IDs_only_JPAS, IDs_both,
 Dict_LoA["only"]["DESI"], Dict_LoA["only"]["JPAS"],
 Dict_LoA["both"]["DESI"], Dict_LoA["both"]["JPAS"]) = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI"]["TARGETID"],
    DATA["JPAS"]["TARGETID"],
)

2025-05-16 12:50:54,966 - INFO - 🔍 crossmatch_IDs_two_datasets()...
2025-05-16 12:50:54,967 - INFO - ├── 🚀 Starting ID categorization process...
2025-05-16 12:50:54,976 - INFO - |    ├── 📌 Found 804570 unique IDs across 2 arrays.
2025-05-16 12:50:55,125 - INFO - |    ├── Presence matrix created with shape: (2, 804570)
2025-05-16 12:50:55,127 - INFO - |    ├── Category mask created with shape: (2, 804570)
2025-05-16 12:50:55,127 - INFO - ├── 🚀 Starting index retrieval process...
2025-05-16 12:50:55,127 - INFO - |    ├── 📌 Processing 804570 unique IDs across 2 arrays.
2025-05-16 12:50:55,406 - INFO - ├── 🚀 Starting post-processing of unique IDs across two arrays...
2025-05-16 12:50:55,419 - INFO - |    ├── Processing complete: 752550 IDs only in Array 1 (93.53%).
2025-05-16 12:50:55,420 - INFO - |    ├── Processing complete: 24788 IDs only in Array 2 (3.08%).
2025-05-16 12:50:55,420 - INFO - |    ├── Processing complete: 27232 IDs in both arrays (3.38%).
2025-05-16 12:50:55,420 - INFO - 

In [7]:
# Train/val/test fractions and seeds for the two TARGETID groups: objects observed by both
# surveys ("both") and objects observed only by DESI ("only_DESI").
dict_split_data_options = dict(
    # matched DESI + JPAS objects
    train_ratio_both=0.8,
    val_ratio_both=0.1,
    test_ratio_both=0.1,
    random_seed_split_both=42,
    # DESI-only objects
    train_ratio_only_DESI=0.8,
    val_ratio_only_DESI=0.1,
    test_ratio_only_DESI=0.1,
    random_seed_split_only_DESI=42,
)

In [8]:
# Split each list of arrays (one entry per TARGETID) into train/val/test subsets.
# The matched DESI and JPAS LoAs deliberately share ratios AND seed: assuming split_LoA is
# deterministic for a given length and seed, the same shared TARGETIDs land in the same
# split for both surveys (the TARGETID check further down verifies this).
# Each job: (group, survey, suffix of the dict_split_data_options keys to use).
split_jobs = [
    ("both", "JPAS", "both"),
    ("both", "DESI", "both"),
    ("only", "DESI", "only_DESI"),
]
Dict_LoA_split = {"both": {}, "only": {}}
for group, survey, suffix in split_jobs:
    Dict_LoA_split[group][survey] = process_dset_splits.split_LoA(
        Dict_LoA[group][survey],
        train_ratio=dict_split_data_options[f"train_ratio_{suffix}"],
        val_ratio=dict_split_data_options[f"val_ratio_{suffix}"],
        test_ratio=dict_split_data_options[f"test_ratio_{suffix}"],
        seed=dict_split_data_options[f"random_seed_split_{suffix}"],
    )

# Both matched LoAs have one entry per shared TARGETID, so each split must contain the
# same number of entries for DESI and JPAS; a mismatch means the splits are misaligned.
for key_dset in ["train", "val", "test"]:
    n_desi = len(Dict_LoA_split["both"]["DESI"][key_dset])
    n_jpas = len(Dict_LoA_split["both"]["JPAS"][key_dset])
    assert n_desi == n_jpas, (
        f"{key_dset}: {n_desi} matched DESI TARGETIDs vs {n_jpas} matched JPAS TARGETIDs"
    )

2025-05-16 12:50:55,452 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:50:55,456 - INFO - ├── Finished splitting.
2025-05-16 12:50:55,457 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:50:55,460 - INFO - ├── Finished splitting.
2025-05-16 12:50:55,461 - INFO - ├── ✂️ Splitting list of arrays (LoA) into train/val/test subsets...
2025-05-16 12:50:55,577 - INFO - ├── Finished splitting.


In [9]:
# Columns extracted into the loader inputs (xx) ...
keys_xx = [
    'OBS',
    'ERR',
    'MORPHTYPE_int',
]
# ... and into the targets / identifiers (yy).
keys_yy = [
    'SPECTYPE_int',
    'TARGETID',
]
# Whether DataLoader standardizes inputs (means/stds are taken from the DESI_combined train loader).
normalize = True

In [10]:
# Build one DataLoader per (dataset, split).
# Normalization statistics come from the DESI_combined *train* loader and are reused for every
# later loader (DESI and JPAS, all splits), so all inputs share one scaling. This is why
# "train" must stay the first split in the loop below.
dset_loaders = {"DESI_combined": {}, "DESI_matched": {}, "JPAS_matched": {}}
provided_normalization = None  # only the DESI_combined train loader is built without provided stats
for key_dset in ["train", "val", "test"]:
    logging.info(f"⚙️ Preparing split: {key_dset}")

    # DESI combined (only + matched)
    logging.info("├── DESI_combined")
    LoA, xx, yy = process_dset_splits.extract_and_combine_DESI_data(
        Dict_LoA_split["only"]["DESI"][key_dset],
        Dict_LoA_split["both"]["DESI"][key_dset],
        DATA["DESI"], keys_xx, keys_yy
    )
    dset_loaders["DESI_combined"][key_dset] = data_loaders.DataLoader(
        xx, yy, normalize=normalize, provided_normalization=provided_normalization
    )
    if key_dset == "train":
        provided_normalization = (
            dset_loaders["DESI_combined"]["train"].means,
            dset_loaders["DESI_combined"]["train"].stds
        )

    # Matched objects: identical extraction for both surveys (DESI first, then JPAS).
    for survey in ("DESI", "JPAS"):
        logging.info(f"├── {survey}_matched")
        LoA, xx, yy = process_dset_splits.extract_data_matched(
            Dict_LoA_split["both"][survey][key_dset],
            DATA[survey], keys_xx, keys_yy
        )
        dset_loaders[f"{survey}_matched"][key_dset] = data_loaders.DataLoader(
            xx, yy, normalize=normalize, provided_normalization=provided_normalization
        )

2025-05-16 12:50:55,586 - INFO - ⚙️ Preparing split: train
2025-05-16 12:50:55,587 - INFO - ├── DESI_combined
2025-05-16 12:50:55,587 - INFO - |    ├── 🔧 extract_and_combine_DESI_data()
2025-05-16 12:50:55,587 - INFO - |    ├── Extracting features and labels from DESI-only subset...
2025-05-16 12:50:57,326 - INFO - |    ├── Extracting features and labels from DESI-matched subset...
2025-05-16 12:50:57,492 - INFO - |    ├── Applied index shift of 611345 to matched DESI group to ensure uniqueness
2025-05-16 12:50:57,562 - INFO - |    ├── Finished extract_and_combine_DESI_data()
2025-05-16 12:50:57,567 - INFO - ├── 💿 Initializing DataLoader object with 633698 samples...
2025-05-16 12:50:57,970 - INFO - ├── DESI_matched
2025-05-16 12:50:57,970 - INFO - |    ├── 🔧 extract_data_matched()
2025-05-16 12:50:57,970 - INFO - |    ├── Extracting features and labels from matched dataset...
2025-05-16 12:50:58,026 - INFO - |    ├── Finished extract_data_matched()
2025-05-16 12:50:58,046 - INFO - ├──

In [11]:
# Sanity checks on the matched loaders:
#  * in every split, DESI and JPAS contain exactly the same set of TARGETIDs (DESI can hold
#    several rows per TARGETID, so only the unique sets are compared);
#  * no TARGETID appears in more than one split (no train/val/test leakage).
# Assertions (instead of printed booleans) make a violation stop "Run All".
unique_ids_by_split = {}
for key_dset in ["train", "val", "test"]:
    desi_ids = dset_loaders["DESI_matched"][key_dset].yy["TARGETID"]
    jpas_ids = dset_loaders["JPAS_matched"][key_dset].yy["TARGETID"]

    # np.unique already returns sorted values, so they can be compared directly.
    desi_unique = np.unique(desi_ids)
    jpas_unique = np.unique(jpas_ids)
    same_ids = np.array_equal(desi_unique, jpas_unique)

    print(key_dset)
    print(f"  rows       : DESI {desi_ids.shape}, JPAS {jpas_ids.shape}")
    print(f"  unique IDs : DESI {desi_unique.shape}, JPAS {jpas_unique.shape}")
    print(f"  same IDs   : {same_ids}")
    print()

    assert same_ids, f"{key_dset}: matched DESI and JPAS TARGETID sets differ"
    unique_ids_by_split[key_dset] = desi_unique

for split_a, split_b in [("train", "val"), ("train", "test"), ("val", "test")]:
    shared_ids = np.intersect1d(
        unique_ids_by_split[split_a], unique_ids_by_split[split_b], assume_unique=True
    )
    assert shared_ids.size == 0, (
        f"{shared_ids.size} TARGETIDs appear in both {split_a} and {split_b}"
    )

train
(22353,)
(21785,)
(21785,)
(21785,)
True

val
(2792,)
(2723,)
(2723,)
(2723,)
True

test
(2793,)
(2724,)
(2724,)
(2724,)
True

