In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import utility_functions as raruma_util

In [None]:
# SED template names whose model colors are overlaid on the data plots.
seds = [
    'El_B2004a',
    'Sbc_B2004a',
    'Scd_B2004a',
    'Im_B2004a',
    'SB3_B2004a',
    'SB2_B2004a',
    'ssp_25Myr_z008',
    'ssp_5Myr_z008',
]
# Independent copy of the same template list (kept for cells that may refer to it).
seds_2 = list(seds)

# Per-dataset settings:
#   input_file    -- catalog read with tables_io
#   band_template -- column-name pattern, formatted with each band letter
#   bands         -- photometric bands present, blue to red
#   redshift_col  -- name of the redshift column
#   sample        -- plotting stride (every `sample`-th object is plotted)
# NOTE(review): absolute, user-specific paths; adjust for your machine.
DATASET_CONFIGS = {
    'Rubin': dict(
        input_file='/Users/echarles/pz/sandbox_data/roman_rubin_9925.hdf5',
        band_template='LSST_obs_{band}',
        bands='ugrizy',
        redshift_col='redshift',
        sample=100,
    ),
    'ComCam': dict(
        input_file='/Users/echarles/pz/data/test/com_cam_secured_matched_test.hdf5',
        band_template='{band}_cModelMag',
        bands='ugrizy',
        redshift_col='redshift',
        sample=1,
    ),
    'DC2': dict(
        input_file='/Users/echarles/pz/data/test/dc2_run2p2i_dr6_test_dered.hdf5',
        band_template='mag_{band}_lsst',
        bands='grizy',
        redshift_col='redshift_true',
        sample=1,
    ),
    'HSC': dict(
        input_file='/Users/echarles/pz/data/test/hsc_pdr3_wide_test_curated.hdf5',
        band_template='HSC{band}_cmodel_dered',
        bands='grizy',
        redshift_col='specz_redshift',
        sample=1,
    ),
    'LSST': dict(
        input_file='/Users/echarles/pz/sandbox_data/cdfs_matched_hst_dereddened.hdf5',
        band_template="{band}_sersicMag",
        bands='ugrizy',
        redshift_col='redshift',
        sample=1,
    ),
    'DP1': dict(
        input_file='/Users/echarles/pz/data/train/dp1_matched_train.hdf5',
        band_template="{band}_gaap1p0Mag",
        bands='ugrizy',
        redshift_col='redshift',
        sample=1,
    ),
}

dataset = 'DP1'
if dataset not in DATASET_CONFIGS:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(DATASET_CONFIGS)}")

_config = DATASET_CONFIGS[dataset]
input_file = _config['input_file']
band_template = _config['band_template']
bands = _config['bands']
redshift_col = _config['redshift_col']
sample = _config['sample']

# Every dataset uses the DC2 LSST filter curves for the template overlays.
# (Alternative previously tried for DP1: 'comcam_{band}'.)
filters = [f'DC2LSST_{band}' for band in bands]

# One label per adjacent-band color, e.g. 'ugrizy' -> ['u-g', 'g-r', 'r-i', 'i-z', 'z-y'].
# Assumes raruma_util.adjacent_band_colors returns mag[i] - mag[i+1] in `bands` order -- confirm.
labels = [f'{blue}-{red}' for blue, red in zip(bands, bands[1:])]

mag_labels = [f'Mag {band}' for band in bands]

In [None]:
# Read the catalog for the selected dataset.
data = tables_io.read(input_file)

# Magnitude matrix built from the per-band columns named by `band_template`,
# then the colors derived from it by raruma_util.adjacent_band_colors.
band_names = raruma_util.make_band_names(band_template, bands)
mags = raruma_util.extract_data_to_2d_array(data, band_names)
colors = raruma_util.adjacent_band_colors(mags)

# Redshift column used as the plotting target.
redshifts = data[redshift_col]

In [None]:
# Detection mask with the same shape as `mags`: True where the magnitude is a real
# number, False for NaN or +/-inf entries.
detect = ~(np.isnan(mags) | np.isinf(mags))

In [None]:
# Number of objects brighter than the magnitude limit in the i band.
# Columns of `mags` are assumed to follow the order of `bands`; look the column up by
# name, because a fixed index 3 is 'i' only for 'ugrizy' (it is 'z' for 'grizy').
# NaN (undetected) entries compare False and are therefore not counted.
i_band_mag_limit = 26.0
(mags[:, bands.index('i')] < i_band_mag_limit).sum()

In [None]:
# Rows (objects) with a finite magnitude in every band. For the 6-band datasets this
# matches the old `detect.sum(axis=1) == 6`. Unlike that test, it also works for
# the 5-band 'grizy' datasets, where `== 6` selected nothing. The name is kept for
# the cells that use it.
detect_6_band = detect.all(axis=1)

In [None]:
# Template data for the overlays, built from the SED names in `seds`
# and the filter-curve names in `filters`.
template_dict = raruma_util.build_template_dict(
    seds,
    filters,
)

In [None]:
# Colors vs. redshift for every `sample`-th object (redshift axis capped at 3),
# with the templates from `template_dict` overlaid.
_ = raruma_plot.plot_colors_v_redshifts_with_templates(
    redshifts[::sample],
    colors[::sample],
    zmax=3.,
    templates=template_dict,
    labels=labels,
)

In [None]:
# Color-color plot for objects detected in every band, thinned to every 100th object.
# The redshifts get the same mask and stride as the colors so both arrays stay aligned
# object-by-object. Previously the full, unmasked `redshifts` array was passed, so
# its length and ordering did not match `colors`.
color_color_stride = 100
_ = raruma_plot.plot_colors_v_colors_with_templates(
    np.asarray(redshifts)[detect_6_band][::color_color_stride],
    colors[detect_6_band][::color_color_stride],
    templates=template_dict,
    labels=labels,
)

In [None]:
# Magnitude in each band vs. redshift for every `sample`-th object.
# Both arrays are strided identically so features and targets stay aligned.
# Previously only `redshifts` was strided, which mismatched lengths when sample != 1.
_ = raruma_plot.plot_features_target_scatter(
    mags[::sample],
    redshifts[::sample],
    labels=mag_labels,
)

In [None]:
# Distribution of every magnitude column, labelled with `mag_labels`.
_ = raruma_plot.plot_feature_histograms(
    mags,
    labels=mag_labels,
)