In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib import colors, cm
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import utility_functions as raruma_util

In [None]:
# SED templates whose color tracks are overlaid on the data.
seds = [
    'El_B2004a',
    'Sbc_B2004a',
    'Scd_B2004a',
    'Im_B2004a',
    'SB3_B2004a',
    'SB2_B2004a',
    'ssp_25Myr_z008',
    'ssp_5Myr_z008',
]
# Second SED list; currently identical to `seds`, kept as an independent copy
# so it can be edited without touching `seds`.
seds_2 = list(seds)

# Select the input catalog and its column conventions.
dataset = 'DP1'
if dataset == 'DP1':
    # NOTE(review): absolute, user-specific path — adjust for your environment.
    input_file = '/Users/echarles/pz/data/train/dp1_matched_train.hdf5'
    band_template = "{band}_gaap1p0Mag"
    bands = 'ugrizy'
    # Alternative filter set: [f'comcam_{band}' for band in bands]
    filters = [f'DC2LSST_{band}' for band in bands]
    # One label per adjacent-band color (u-g, g-r, r-i, i-z, z-y), derived from
    # `bands` so it cannot drift out of sync with the data columns.
    labels = [f'{blue}-{red}' for blue, red in zip(bands, bands[1:])]
    redshift_col = 'redshift'
    # NOTE(review): not referenced elsewhere in the notebook as shown.
    sample = 1
else:
    raise ValueError(f"Unknown dataset {dataset!r}")

mag_labels = [f'Mag {band}' for band in bands]

In [None]:
# Load the training catalog and derive the arrays used by the plotting cells below.
data = tables_io.read(input_file)
# Column names built from `band_template` and `bands` (presumably one name per band — TODO confirm).
band_names = raruma_util.make_band_names(band_template, bands)
# 2-D magnitude array; later cells index it as mags[:, band_index].
mags = raruma_util.extract_data_to_2d_array(data, band_names)
# NOTE(review): this rebinds `colors`, shadowing the `matplotlib.colors` module
# imported in the first cell. Later cells pass this array to the plotting functions.
colors = raruma_util.adjacent_band_colors(mags)
redshifts = data[redshift_col]

In [None]:
# Per-object, per-band flag: True where the magnitude is a finite number
# (non-finite values presumably mark non-detections — TODO confirm against the catalog).
detect = np.isfinite(mags)
# Objects with a finite magnitude in every band. This is identical to the old
# `detect.sum(axis=1) == 6` for the six ugrizy bands, but does not hard-code
# the band count.
detect_6_band = detect.all(axis=1)

In [None]:
# Build template color tracks for the configured SEDs and filters. The plotting
# functions below index each value as val[0] (redshift grid), val[2][i] (color i)
# and val[3] (a number used for line colour) — layout presumed from that usage.
template_dict = raruma_util.build_template_dict(seds, filters)

In [None]:
def plot_colors_v_redshifts_with_templates(
    redshifts: np.ndarray,
    colors: np.ndarray,
    zmax: float=4.0,
    templates: dict|None=None,
    labels: list[str]|None=None,    
) -> Figure:
    """Plot each color against redshift, one panel (row) per color, with optional template tracks.

    Parameters
    ----------
    redshifts : np.ndarray
        Per-object redshifts (x axis of every panel); must have one entry per row of `colors`.
    colors : np.ndarray
        Per-object colors, indexed as colors[:, i]; column i is drawn in row i.
    zmax : float
        Upper redshift limit for the x axis, the histogram bins and the template tracks.
    templates : dict | None
        Optional template tracks keyed by name. Each value is indexed as val[0]
        (redshift grid), val[2][i] (color i on that grid) and val[3] (a number
        used to choose the line colour) — presumably the layout produced by
        raruma_util.build_template_dict; TODO confirm.
    labels : list[str] | None
        Optional y-axis label for each color column.

    Returns
    -------
    Figure
        Figure with one row of axes per color column.

    Raises
    ------
    ValueError
        If `colors` has no columns.
    """
    n_colors = colors.shape[-1]
    if n_colors < 1:
        raise ValueError("colors must contain at least one column")

    fig = plt.figure(figsize=(12, 16))
    # One row per color instead of a fixed 5; squeeze=False keeps indexing
    # uniform even when there is only a single color.
    axs = fig.subplots(n_colors, 1, squeeze=False)[:, 0]

    z_bins = np.linspace(0., zmax, 201)
    color_bins = np.linspace(-3., 3., 61)

    for icolor, ax in enumerate(axs):
        ax.hist2d(redshifts, colors[:, icolor], bins=(z_bins, color_bins), cmap="binary")
        ax.set_xlim(0, zmax)
        ax.set_ylim(-3., 3.)
        if templates is not None:
            for key, val in templates.items():
                # Only draw the part of the track inside the plotted redshift range.
                mask = val[0] < zmax
                ax.plot(
                    val[0][mask],
                    val[2][icolor][mask],
                    label=key,
                    c=cm.rainbow(1. - val[3] / len(templates)),
                    alpha=0.2,
                )
        ax.set_xlabel("redshift")
        if labels is not None:
            ax.set_ylabel(labels[icolor])

    return fig


In [None]:
# Color vs. redshift for every color column, with the SED template tracks overlaid.
the_fig = plot_colors_v_redshifts_with_templates(redshifts, colors, templates=template_dict, labels=labels)

In [None]:
# Write the color-vs-redshift figure to the current working directory.
the_fig.savefig('color_v_redshift.pdf')

In [None]:
def plot_colors_v_colors_with_templates(
    redshifts: np.ndarray,
    colors: np.ndarray,
    zmax: float=4.0,
    templates: dict|None=None,
    labels: list[str]|None=None,    
) -> Figure:
    """Upper-triangle grid of color-color 2-D histograms, with optional template tracks.

    Panel (row r, column c), for c >= r, shows color r on the y axis against
    color c + 1 on the x axis. Every panel in a row therefore shares the same
    y-axis color, and the axis labels match the plotted data.

    Parameters
    ----------
    redshifts : np.ndarray
        Not used; kept for signature compatibility with the color-vs-redshift plot.
    colors : np.ndarray
        Per-object colors, indexed as colors[:, i].
    zmax : float
        Only template points with redshift below this value are drawn.
    templates : dict | None
        Optional template tracks keyed by name. Each value is indexed as val[0]
        (redshift grid), val[2][i] (color i on that grid) and val[3] (a number
        used to choose the line colour) — layout presumed, TODO confirm.
    labels : list[str] | None
        Optional axis label for each color column.

    Returns
    -------
    Figure
        Figure holding an (n_colors-1) x (n_colors-1) grid; the lower triangle is hidden.

    Raises
    ------
    ValueError
        If fewer than two colors are given (no color-color pair exists).
    """
    n_colors = colors.shape[-1]
    if n_colors < 2:
        raise ValueError(f"need at least two colors for a color-color plot, got {n_colors}")
    n_panels = n_colors - 1

    fig = plt.figure(figsize=(24, 24))
    # squeeze=False keeps a 2-D Axes array even for a 1x1 grid (n_colors == 2).
    axs = fig.subplots(n_panels, n_panels, squeeze=False)
    bins = np.linspace(-1.5, 1.5, 61)

    for irow in range(n_panels):
        for jcol in range(n_panels):
            ax = axs[irow, jcol]
            if jcol < irow:
                ax.set_visible(False)
                continue
            # Row index selects the y color, column index (+1) selects the x color.
            ycolor, xcolor = irow, jcol + 1
            ax.set_xlim(-1.5, 1.5)
            ax.set_ylim(-1.5, 1.5)
            if labels is not None:
                ax.set_ylabel(labels[ycolor])
                ax.set_xlabel(labels[xcolor])
            # x data / y data now agree with the x / y labels set above
            # (the previous version plotted them transposed).
            ax.hist2d(colors[:, xcolor], colors[:, ycolor], bins=(bins, bins), cmap="binary")
            if templates is not None:
                for key, val in templates.items():
                    mask = val[0] < zmax
                    ax.plot(
                        val[2][xcolor][mask],
                        val[2][ycolor][mask],
                        label=key,
                        c=cm.rainbow(1. - val[3] / len(templates)),
                        alpha=0.5,
                    )
    return fig


In [None]:
# Color-color grid of the data with the SED template tracks overlaid.
the_fig = plot_colors_v_colors_with_templates(redshifts, colors, templates=template_dict, labels=labels)

In [None]:
# Write the color-color figure to the current working directory.
the_fig.savefig('color_v_color.pdf')

In [None]:
def plot_mag_i_v_redshift(
    redshifts: np.ndarray,
    mag_i: np.ndarray,
    zmax: float=4.0,
) -> Figure:
    """2-D histogram of i-band magnitude against redshift.

    Parameters
    ----------
    redshifts : np.ndarray
        Per-object redshifts (x axis).
    mag_i : np.ndarray
        Per-object i-band magnitudes (y axis), same length as `redshifts`.
    zmax : float
        Upper redshift limit of the x axis and the histogram bins.

    Returns
    -------
    Figure
        A single-panel 8x8 inch figure covering i = 16-26 mag.
    """
    mag_lo, mag_hi = 16, 26
    z_edges = np.linspace(0, zmax, 201)
    mag_edges = np.linspace(mag_lo, mag_hi, 101)  # 0.1 mag bins

    fig = plt.figure(figsize=(8, 8))
    ax = fig.subplots(1, 1)

    ax.set_xlim(0., zmax)
    ax.set_ylim(mag_lo, mag_hi)
    ax.set_xlabel('redshift')
    ax.set_ylabel('i [Mag]')

    ax.hist2d(redshifts, mag_i, bins=(z_edges, mag_edges), cmap="binary")
    return fig


In [None]:
# i-band magnitude vs. redshift. The i column is looked up from the configured
# band order instead of being hard-coded (it is index 3 for 'ugrizy').
the_fig = plot_mag_i_v_redshift(redshifts, mags[:, bands.index('i')])

In [None]:
# Write the i-magnitude-vs-redshift figure to the current working directory.
the_fig.savefig('mag_i_v_redshift.pdf')

In [None]:
def plot_mags(
    mags: np.ndarray,
    bands: str|list[str]='ugrizy',
) -> Figure:
    """Overlaid per-band magnitude histograms in 0.1 mag bins over 16-28 mag.

    Parameters
    ----------
    mags : np.ndarray
        Magnitudes indexed as mags[:, i]; column i is labelled bands[i].
    bands : str | list[str]
        Legend label for each column. Defaults to 'ugrizy', the previously
        hard-coded value, so existing calls are unchanged.

    Returns
    -------
    Figure
        Single-panel figure with one semi-transparent histogram per column.

    Raises
    ------
    ValueError
        If `bands` has fewer entries than `mags` has columns.
    """
    n_mags = mags.shape[-1]
    if len(bands) < n_mags:
        raise ValueError(f"need a band label for each of {n_mags} columns, got {len(bands)}")

    fig = plt.figure(figsize=(8, 8))
    axs = fig.subplots(1, 1)

    axs.set_xlim(16, 28)
    axs.set_xlabel('magnitude')
    axs.set_ylabel('Objects / [0.1 mag]')

    # 120 bins of 0.1 mag, matching the y-axis label.
    mag_bins = np.linspace(16, 28, 121)
    for i in range(n_mags):
        axs.hist(mags[:, i], bins=mag_bins, color=cm.rainbow(i / n_mags), label=bands[i], alpha=0.2)
    axs.legend()
    return fig


In [None]:
# Per-band magnitude distributions.
the_fig = plot_mags(mags)

In [None]:
# Write the magnitude-distribution figure to the current working directory.
the_fig.savefig('mags.pdf')

In [None]:
# Quick look at the magnitude array extracted from the catalog.
mags