In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import loading_tools
from JPAS_DA.data import cleaning_tools
from JPAS_DA.data import crossmatch_tools
from JPAS_DA.data import process_dset_splits

import numpy as np

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib inline

In [None]:
# Data location and the JPAS / DESI dataset selections, all taken from global_setup.
root_path, load_JPAS_data, load_DESI_data = (
    global_setup.DATA_path,
    global_setup.load_JPAS_data,
    global_setup.load_DESI_data,
)

# Seed forwarded to loading_tools.load_dsets as `random_seed`.
random_seed_load = 42

In [None]:
# Load the configured JPAS and DESI datasets into a single container keyed by survey.
DATA = loading_tools.load_dsets(
    root_path=root_path,
    datasets_jpas=load_JPAS_data,
    datasets_desi=load_DESI_data,
    random_seed=random_seed_load,
)

In [None]:
# Cleaning options read from global_setup. The cleaning cell that follows reads the
# keys apply_masks, mask_indices, magic_numbers, i_band_sn_threshold and z_lim_QSO_cut.
dict_clean_data_options = global_setup.dict_clean_data_options

In [5]:
# Apply masking / cleaning to both surveys. Each cleaning option is forwarded
# under the same keyword name it has in dict_clean_data_options.
DATA = cleaning_tools.clean_and_mask_data(
    DATA=DATA,
    **{
        option_name: dict_clean_data_options[option_name]
        for option_name in (
            "apply_masks",
            "mask_indices",
            "magic_numbers",
            "i_band_sn_threshold",
            "z_lim_QSO_cut",
        )
    },
)

In [None]:
# Cross-match TARGETIDs between DESI and JPAS.
# Dict_LoA is a dictionary of Lists of Arrays (LoA): for each TARGETID, an array of
# the entries in the survey arrays associated with it, so that
# TARGETID[LoA[ii][0]] == TARGETID[LoA[ii][-1]].
# "only" holds TARGETIDs present in a single survey; "both" holds those in both.
(
    IDs_only_DESI, IDs_only_JPAS, IDs_both,
    LoA_only_DESI_tmp, LoA_only_JPAS_tmp,
    LoA_both_DESI_tmp, LoA_both_JPAS_tmp,
) = crossmatch_tools.crossmatch_IDs_two_datasets(
    DATA["DESI"]['TARGETID'], DATA["JPAS"]['TARGETID']
)
Dict_LoA = {
    "both": {"DESI": LoA_both_DESI_tmp, "JPAS": LoA_both_JPAS_tmp},
    "only": {"DESI": LoA_only_DESI_tmp, "JPAS": LoA_only_JPAS_tmp},
}
del LoA_only_DESI_tmp, LoA_only_JPAS_tmp, LoA_both_DESI_tmp, LoA_both_JPAS_tmp

In [None]:
# Sanity-check the cross-match. For each (group, survey) pair print:
#   1. the number of unique TARGETIDs (one array of entries per TARGETID),
#   2. the shape of all those entries concatenated (i.e. total rows covered),
#   3. the list of arrays itself.
for group_, survey_ in (("only", "DESI"), ("both", "DESI"), ("both", "JPAS")):
    LoA_ = Dict_LoA[group_][survey_]
    print(len(LoA_))
    # np.concatenate raises ValueError on an empty sequence; report (0,) instead.
    print(np.concatenate(LoA_).shape if len(LoA_) > 0 else (0,))
    print(LoA_)

In [None]:
# Train/validation/test fractions and RNG seeds for splitting the TARGETID groups.
# "both": TARGETIDs present in DESI and JPAS; "only_DESI": TARGETIDs only in DESI.
dict_split_data_options = {
    "train_ratio_both"             : 0.8,
    "val_ratio_both"               : 0.1,
    "test_ratio_both"              : 0.1,
    "random_seed_split_both"       : 42,
    "train_ratio_only_DESI"        : 0.8,
    "val_ratio_only_DESI"          : 0.1,
    "test_ratio_only_DESI"         : 0.1,
    "random_seed_split_only_DESI"  : 42
}

# Fail fast if a subset's train/val/test fractions do not partition the data
# (tolerance absorbs floating-point rounding, e.g. 0.8 + 0.1 + 0.1).
for subset_ in ("both", "only_DESI"):
    ratio_sum_ = sum(
        dict_split_data_options[f"{part}_ratio_{subset_}"] for part in ("train", "val", "test")
    )
    assert abs(ratio_sum_ - 1.0) < 1e-9, f"train/val/test ratios for '{subset_}' must sum to 1, got {ratio_sum_}"

In [None]:
# Split the Lists of Arrays into training, validation, and testing sets.
# Both "both" lists share the same ratios and seed, presumably so that (given their
# aligned ordering from the cross-match) each TARGETID lands in the same split for
# JPAS and DESI; that alignment is asserted in a later cell.
# Each entry: (group, survey, suffix of the matching keys in dict_split_data_options).
Dict_LoA_split = {"both": {}, "only": {}}
for group_, survey_, options_suffix_ in (
    ("both", "JPAS", "both"),
    ("both", "DESI", "both"),
    ("only", "DESI", "only_DESI"),
):
    Dict_LoA_split[group_][survey_] = process_dset_splits.split_LoA(
        Dict_LoA[group_][survey_],
        train_ratio = dict_split_data_options[f"train_ratio_{options_suffix_}"],
        val_ratio = dict_split_data_options[f"val_ratio_{options_suffix_}"],
        test_ratio = dict_split_data_options[f"test_ratio_{options_suffix_}"],
        seed = dict_split_data_options[f"random_seed_split_{options_suffix_}"]
    )

In [None]:
# Verify that the "both" splits are aligned between surveys: in every split, the
# jj-th JPAS group and the jj-th DESI group must reference entries that all carry
# one and the same TARGETID.
# NOTE(review): assumes DATA[...]["TARGETID"] is positionally indexable like a numpy
# array (np.asarray is a no-op for ndarrays) — confirm against loading_tools.
TARGETID_JPAS_all_ = np.asarray(DATA["JPAS"]["TARGETID"])
TARGETID_DESI_all_ = np.asarray(DATA["DESI"]["TARGETID"])
for key_dset, LoA_split_JPAS_ in Dict_LoA_split["both"]["JPAS"].items():
    LoA_split_DESI_ = Dict_LoA_split["both"]["DESI"][key_dset]
    assert len(LoA_split_JPAS_) == len(LoA_split_DESI_), "Both datasets must have the same number unique TARGETIDs in each of training, validation, and testing sets."
    for idxs_JPAS_, idxs_DESI_ in zip(LoA_split_JPAS_, LoA_split_DESI_):
        ids_JPAS_ = TARGETID_JPAS_all_[idxs_JPAS_]
        ids_DESI_ = TARGETID_DESI_all_[idxs_DESI_]
        reference_ID_ = ids_JPAS_[0]
        # Check every JPAS entry of the group (not only the first) and every DESI entry.
        assert np.all(ids_JPAS_ == reference_ID_) and np.all(ids_DESI_ == reference_ID_), "Both datasets must have the same TARGETIDs in each of training, validation, and testing sets."
        # Report TARGETIDs with more than two DESI entries.
        if len(idxs_DESI_) > 2:
            print("TARGETID JPAS:", reference_ID_, "TARGETID DESI:", ids_DESI_[-1])

In [None]:
# Keys pulled out of each dataset by the extraction helpers:
# xx -> presumably model inputs, yy -> presumably labels / identifiers (confirm).
keys_xx = [
    'OBS',
    'ERR',
    'MORPHTYPE_int',
]
keys_yy = [
    'SPECTYPE_int',
    'TARGETID',
]

In [None]:
# Split to extract ("train", "val" or "test").
key_dset = "train"

# DESI "only" + "both" groups merged into one set.
LoA_combined, xx_combined, yy_combined = process_dset_splits.extract_and_combine_DESI_data(
    Dict_LoA_split["only"]["DESI"][key_dset],
    Dict_LoA_split["both"]["DESI"][key_dset],
    DATA["DESI"],
    keys_xx,
    keys_yy,
)
# Each group on its own: DESI-only, and the DESI / JPAS sides of the cross-matched set.
LoA_DESI_only, xx_DESI_only, yy_DESI_only = process_dset_splits.extract_data_using_LoA(
    Dict_LoA_split["only"]["DESI"][key_dset], DATA["DESI"], keys_xx, keys_yy
)
LoA_DESI, xx_DESI, yy_DESI = process_dset_splits.extract_data_using_LoA(
    Dict_LoA_split["both"]["DESI"][key_dset], DATA["DESI"], keys_xx, keys_yy
)
LoA_JPAS, xx_JPAS, yy_JPAS = process_dset_splits.extract_data_using_LoA(
    Dict_LoA_split["both"]["JPAS"][key_dset], DATA["JPAS"], keys_xx, keys_yy
)

In [None]:
# Combined DESI ("only" + "both") summary for the selected split, one value per line:
# group count, concatenated-entries shape, then keys and len/shape of selected arrays.
print(
    len(LoA_combined),
    np.concatenate(LoA_combined).shape,
    xx_combined.keys(),
    len(xx_combined['OBS']),
    xx_combined['OBS'].shape,
    len(xx_combined['MORPHTYPE_int']),
    xx_combined['MORPHTYPE_int'].shape,
    yy_combined.keys(),
    len(yy_combined['SPECTYPE_int']),
    yy_combined['SPECTYPE_int'].shape,
    len(yy_combined['TARGETID']),
    yy_combined['TARGETID'].shape,
    sep="\n",
)

In [None]:
# DESI-only summary for the selected split, one value per line:
# group count, concatenated-entries shape, then keys and len/shape of selected arrays.
print(
    len(LoA_DESI_only),
    np.concatenate(LoA_DESI_only).shape,
    xx_DESI_only.keys(),
    len(xx_DESI_only['OBS']),
    xx_DESI_only['OBS'].shape,
    len(xx_DESI_only['MORPHTYPE_int']),
    xx_DESI_only['MORPHTYPE_int'].shape,
    yy_DESI_only.keys(),
    len(yy_DESI_only['SPECTYPE_int']),
    yy_DESI_only['SPECTYPE_int'].shape,
    len(yy_DESI_only['TARGETID']),
    yy_DESI_only['TARGETID'].shape,
    sep="\n",
)

In [None]:
# JPAS side of the cross-matched ("both") set for the selected split, one value per
# line: group count, concatenated-entries shape, then keys and len/shape of arrays.
print(
    len(LoA_JPAS),
    np.concatenate(LoA_JPAS).shape,
    xx_JPAS.keys(),
    len(xx_JPAS['OBS']),
    xx_JPAS['OBS'].shape,
    len(xx_JPAS['MORPHTYPE_int']),
    xx_JPAS['MORPHTYPE_int'].shape,
    yy_JPAS.keys(),
    len(yy_JPAS['SPECTYPE_int']),
    yy_JPAS['SPECTYPE_int'].shape,
    len(yy_JPAS['TARGETID']),
    yy_JPAS['TARGETID'].shape,
    sep="\n",
)

In [None]:
# DESI side of the cross-matched ("both") set for the selected split, one value per
# line: group count, concatenated-entries shape, then keys and len/shape of arrays.
print(
    len(LoA_DESI),
    np.concatenate(LoA_DESI).shape,
    xx_DESI.keys(),
    len(xx_DESI['OBS']),
    xx_DESI['OBS'].shape,
    len(xx_DESI['MORPHTYPE_int']),
    xx_DESI['MORPHTYPE_int'].shape,
    yy_DESI.keys(),
    len(yy_DESI['SPECTYPE_int']),
    yy_DESI['SPECTYPE_int'].shape,
    len(yy_DESI['TARGETID']),
    yy_DESI['TARGETID'].shape,
    sep="\n",
)

In [None]:
assert len(LoA_DESI) == len(LoA_JPAS), "Both datasets must have the same number unique TARGETIDs"
assert np.unique(yy_JPAS['TARGETID']).shape == np.unique(yy_DESI['TARGETID']).shape, "Both datasets must have the same number unique TARGETIDs"

# Spot-check a few groups by eye: within each group, the printed JPAS and DESI
# TARGETIDs should all agree. The window is arbitrary; it is clipped to the number
# of available groups so a small split does not raise IndexError.
inspect_group_start, inspect_group_stop = 45, 50
for ii in range(inspect_group_start, min(inspect_group_stop, len(LoA_JPAS))):
    for idx_ in LoA_JPAS[ii]:
        print("JPAS TARGETID", yy_JPAS['TARGETID'][idx_])
    for idx_ in LoA_DESI[ii]:
        print("DESI TARGETID", yy_DESI['TARGETID'][idx_])
    print()