In [None]:
%load_ext autoreload

import logging
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits
from JPAS_DA.data import data_loaders

import numpy as np

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
# Matplotlib setup. The order of these statements matters: each step layers
# its settings on top of the previous ones.

# Reset to matplotlib's built-in defaults and close any figures that are still open.
plt.style.use('default')
plt.close('all')
# Project settings: `font` is unpacked as keyword arguments for the 'font' rc
# group and `rcnew` is merged into rcParams.
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
# Applied after the project settings, so where the two overlap the style sheet wins.
plt.style.use('tableau-colorblind10')
# Interactive figure backend for the notebook (provided by the ipympl package).
%matplotlib widget

In [None]:
# Directory containing the .npy / .csv files listed below.
# The location is machine specific, so it can be overridden without editing the
# notebook by setting the JPAS_DA_DATA_ROOT environment variable; when the
# variable is not set, the original path is used.
root_path = os.environ.get(
    "JPAS_DA_DATA_ROOT",
    "/home/dlopez/Documentos/0.profesional/Postdoc/USP/Projects/JPAS_photozs/DATA/Train-Validate-Test",
)

# JPAS files to load: one entry per dataset, each naming an .npy file and a .csv file.
load_JPAS_data = [
    {
        "name": "all",
        "npy": "JPAS_DATA_Aper_Cor_3_FLUX+NOISE.npy",
        "csv": "JPAS_DATA_PROPERTIES.csv",
        "sample_percentage": 1.0,  # Optional, defaults to 1.0
    },
]

# DESI files to load, one entry per pre-defined split.
# Note that only the "train" entry uses a sample_percentage below 1.0.
load_DESI_data = [
    {
        "name": "train",
        "npy": "mock_3_train.npy",
        "csv": "props_training.csv",
        "sample_percentage": 0.3,
    },
    {
        "name": "val",
        "npy": "mock_3_validate.npy",
        "csv": "props_validate.csv",
        "sample_percentage": 1.0,
    },
    {
        "name": "test",
        "npy": "mock_3_test.npy",
        "csv": "props_test.csv",
        "sample_percentage": 1.0,
    },
]

# Seed handed to the loader as `random_seed` (presumably drives the
# sample_percentage sub-sampling -- confirm in loading_tools.load_dsets).
random_seed_load = 42

In [None]:
# Load the JPAS and DESI datasets declared in the configuration cell.
# The result is indexed by survey name further down (DATA["DESI"], DATA["JPAS"]).
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

In [None]:
# Options for cleaning_tools.clean_and_mask_data; every key is the name of one
# of that function's keyword arguments.
dict_clean_data_options = dict(
    apply_masks=[
        "unreliable",
        "magic_numbers",
        "negative_errors",
        "nan_values",
        "apply_additional_filters",
    ],
    mask_indices=[0, -2],
    magic_numbers=[-99, 99],
    i_band_sn_threshold=0,
    z_lim_QSO_cut=2.2,
)

In [None]:
# Clean and mask the loaded catalogues.
# dict_clean_data_options holds exactly the keyword arguments of
# clean_and_mask_data (apply_masks, mask_indices, magic_numbers,
# i_band_sn_threshold, z_lim_QSO_cut), so it is unpacked directly.
# NOTE: DATA is rebound to the cleaned result; re-run the loading cell first if
# the raw data is needed again.
DATA = cleaning_tools.clean_and_mask_data(DATA=DATA, **dict_clean_data_options)

In [None]:
# Cross-match the DESI and JPAS catalogues on TARGETID.
# The call returns seven items: the IDs found only in DESI, only in JPAS and in
# both, followed by four Lists of Arrays (LoA).
(
    IDs_only_DESI, IDs_only_JPAS, IDs_both,
    LoA_only_DESI, LoA_only_JPAS,
    LoA_both_DESI, LoA_both_JPAS,
) = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI"]['TARGETID'], DATA["JPAS"]['TARGETID']
)

# Dictionary of Lists of Arrays (LoA): for each TARGETID, the array entries
# associated with it, e.g. TARGETID[LoA[ii][0]] == TARGETID[LoA[ii][-1]].
Dict_LoA = {
    "both": {"DESI": LoA_both_DESI, "JPAS": LoA_both_JPAS},
    "only": {"DESI": LoA_only_DESI, "JPAS": LoA_only_JPAS},
}

In [None]:
# Train / validation / test fractions and shuffling seeds used when splitting
# the Lists of Arrays: one group for objects present in both surveys ("both")
# and one for objects present only in DESI ("only_DESI").
dict_split_data_options = dict(
    # objects present in both surveys
    train_ratio_both=0.8,
    val_ratio_both=0.1,
    test_ratio_both=0.1,
    random_seed_split_both=42,
    # objects present only in DESI
    train_ratio_only_DESI=0.8,
    val_ratio_only_DESI=0.1,
    test_ratio_only_DESI=0.1,
    random_seed_split_only_DESI=42,
)

In [None]:
# Split every List of Arrays into training, validation and testing subsets.
# The JPAS and DESI lists of the matched ("both") objects are split with the
# same ratios and the same seed; the DESI-only list has its own ratios and seed.
# Each split result is indexed by "train" / "val" / "test" further down.
Dict_LoA_split = {
    "both": {
        survey: process_dset_splits.split_LoA(
            Dict_LoA["both"][survey],
            train_ratio=dict_split_data_options["train_ratio_both"],
            val_ratio=dict_split_data_options["val_ratio_both"],
            test_ratio=dict_split_data_options["test_ratio_both"],
            seed=dict_split_data_options["random_seed_split_both"],
        )
        for survey in ("JPAS", "DESI")
    },
    "only": {
        "DESI": process_dset_splits.split_LoA(
            Dict_LoA["only"]["DESI"],
            train_ratio=dict_split_data_options["train_ratio_only_DESI"],
            val_ratio=dict_split_data_options["val_ratio_only_DESI"],
            test_ratio=dict_split_data_options["test_ratio_only_DESI"],
            seed=dict_split_data_options["random_seed_split_only_DESI"],
        ),
    },
}

In [None]:
# Catalogue fields handed to the extraction helpers: keys_xx selects the entries
# returned as `xx`, keys_yy the entries returned as `yy`.
keys_xx = ["OBS", "ERR", "MORPHTYPE_int"]
keys_yy = ["SPECTYPE_int", "TARGETID"]
# Forwarded to every data_loaders.DataLoader as its `normalize` argument.
normalize = True

In [None]:
# Build one DataLoader per (dataset, split).
# "train" is processed first: the means / stds of its DESI_combined loader are
# stored in `provided_normalization` and passed to every loader created
# afterwards (the matched train loaders and all val / test loaders), so all of
# them share the same statistics.
dset_loaders = {name: {} for name in ("DESI_combined", "DESI_matched", "JPAS_matched")}
provided_normalization = None  # nothing to provide until the first training loader exists
for key_dset in ("train", "val", "test"):
    logging.info(f"⚙️ Preparing split: {key_dset}")

    # DESI combined: DESI-only objects together with the DESI side of the matched objects
    logging.info("├── DESI_combined")
    LoA, xx, yy = process_dset_splits.extract_and_combine_DESI_data(
        Dict_LoA_split["only"]["DESI"][key_dset],
        Dict_LoA_split["both"]["DESI"][key_dset],
        DATA["DESI"],
        keys_xx,
        keys_yy,
    )
    loader_combined = data_loaders.DataLoader(
        xx,
        yy,
        normalize=normalize,
        provided_normalization=provided_normalization,
    )
    dset_loaders["DESI_combined"][key_dset] = loader_combined
    if key_dset == "train":
        provided_normalization = (loader_combined.means, loader_combined.stds)

    # Matched objects, first the DESI side and then the JPAS side
    for survey in ("DESI", "JPAS"):
        matched_name = f"{survey}_matched"
        logging.info(f"├── {matched_name}")
        LoA, xx, yy = process_dset_splits.extract_data_matched(
            Dict_LoA_split["both"][survey][key_dset],
            DATA[survey],
            keys_xx,
            keys_yy,
        )
        dset_loaders[matched_name][key_dset] = data_loaders.DataLoader(
            xx,
            yy,
            normalize=normalize,
            provided_normalization=provided_normalization,
        )

In [None]:
# Consistency check of the matched loaders. For every split this prints the
# split name, the TARGETID array shapes (DESI, then JPAS), the shapes of their
# unique values (DESI, then JPAS) and whether both sides contain the same set
# of TARGETIDs.
for key_dset in ("train", "val", "test"):
    print(key_dset)
    tmp_DESI_TARGETID = dset_loaders["DESI_matched"][key_dset].yy["TARGETID"]
    tmp_JPAS_TARGETID = dset_loaders["JPAS_matched"][key_dset].yy["TARGETID"]

    for target_ids in (tmp_DESI_TARGETID, tmp_JPAS_TARGETID):
        print(target_ids.shape)

    # np.unique already returns its result sorted, so the two arrays can be
    # compared element-wise without sorting them again.
    unique_ids_DESI = np.unique(tmp_DESI_TARGETID)
    unique_ids_JPAS = np.unique(tmp_JPAS_TARGETID)
    for unique_ids in (unique_ids_DESI, unique_ids_JPAS):
        print(unique_ids.shape)

    print(np.array_equal(unique_ids_DESI, unique_ids_JPAS))
    print()