In [8]:
from desispec.io import read_spectra
from desitrip.preproc import rebin_flux, rescale_flux

from glob import glob

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from astropy.table import Table

import os
import platform

# Use a 14 pt default font for every matplotlib figure in this notebook.
mpl.rcParams['font.size'] = 14

In [9]:
def condition_spectra(coadd_files, truth_files):
    """Read DESI spectra, rebin to a subsampled logarithmic wavelength grid, and rescale.

    Coadd and truth files are paired by position: the i-th truth file supplies
    the 'TRUEZ' column (read from its 'TRUTH' table) that is handed to
    ``rebin_flux`` together with the 'brz' wavelength, flux and inverse
    variance of the i-th coadd file.

    No filtering of NaN or all-zero spectra is performed here.

    Parameters
    ----------
    coadd_files : list or ndarray
        List of FITS files on disk with DESI spectra.
    truth_files : list or ndarray
        Truth files, in the same order as ``coadd_files``.

    Returns
    -------
    fluxes : ndarray or None
        Rescaled fluxes of all files, concatenated along the first axis.
        None if no files were given.

    Raises
    ------
    ValueError
        If ``coadd_files`` and ``truth_files`` have different lengths. A plain
        ``zip`` would silently drop the unmatched files and could pair spectra
        with the wrong truth table.
    """
    coadd_files = list(coadd_files)
    truth_files = list(truth_files)
    if len(coadd_files) != len(truth_files):
        raise ValueError(
            'coadd_files and truth_files must have the same length '
            '(got {} and {})'.format(len(coadd_files), len(truth_files)))

    # Collect per-file results and concatenate once at the end; concatenating
    # inside the loop re-copies everything accumulated so far on each pass.
    rescaled = []

    for cf, tf in zip(coadd_files, truth_files):
        spectra = read_spectra(cf)
        wave = spectra.wave['brz']
        flux = spectra.flux['brz']
        ivar = spectra.ivar['brz']

        truth = Table.read(tf, 'TRUTH')
        truez = truth['TRUEZ']

        # Rebin onto 150 logarithmic bins spanning 2500-9500 (wavelength units
        # as used by rebin_flux). Only the rebinned flux is used; the rebinned
        # wavelength grid and inverse variance are discarded.
        _, reflux, _ = rebin_flux(wave, flux, ivar, truez,
                                  minwave=2500., maxwave=9500.,
                                  nbins=150, log=True, clip=True)

        # Rescale so that each spectrum is normalized between 0 and 1
        # (per the original author's note on rescale_flux).
        rescaled.append(rescale_flux(reflux))

    if not rescaled:
        return None

    return np.concatenate(rescaled)

In [11]:
# Host spectra: truth and coadd files are globbed separately and sorted so
# that the two lists pair up by position inside condition_spectra.
host_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/hosts/*truth.fits')
host_truth.sort()
host_coadd = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/hosts/*coadd.fits')
host_coadd.sort()
host_flux = condition_spectra(host_coadd, host_truth)

In [14]:
# SN Ia spectra (sn_ia/hsiao directory): sort both listings so truth and
# coadd files pair up by position inside condition_spectra.
snia_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ia/hsiao/*truth.fits')
snia_truth.sort()
snia_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ia/hsiao/*coadd.fits')
snia_files.sort()
snia_flux = condition_spectra(snia_files, snia_truth)

In [15]:
# SN Ib spectra, gathered from every subdirectory of sn_ib; sort both
# listings so truth and coadd files pair up by position.
snib_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ib/*/*truth.fits')
snib_truth.sort()
snib_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ib/*/*coadd.fits')
snib_files.sort()
snib_flux = condition_spectra(snib_files, snib_truth)

In [13]:
# SN Ic spectra, gathered from every subdirectory of sn_ic; sort both
# listings so truth and coadd files pair up by position.
snic_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ic/*/*truth.fits')
snic_truth.sort()
snic_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_ic/*/*coadd.fits')
snic_files.sort()
snic_flux = condition_spectra(snic_files, snic_truth)

In [16]:
# SN IIn spectra, gathered from every subdirectory of sn_iin; sort both
# listings so truth and coadd files pair up by position.
sniin_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iin/*/*truth.fits')
sniin_truth.sort()
sniin_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iin/*/*coadd.fits')
sniin_files.sort()
sniin_flux = condition_spectra(sniin_files, sniin_truth)

In [17]:
# Spectra from every subdirectory of sn_iilp; sort both listings so truth
# and coadd files pair up by position.
sniilp_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iilp/*/*truth.fits')
sniilp_truth.sort()
sniilp_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iilp/*/*coadd.fits')
sniilp_files.sort()
sniilp_flux = condition_spectra(sniilp_files, sniilp_truth)

In [18]:
# Spectra from every subdirectory of sn_iip; sort both listings so truth
# and coadd files pair up by position.
sniip_truth = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iip/*/*truth.fits')
sniip_truth.sort()
sniip_files = glob('/global/cfs/projectdirs/desi/science/td/sim/bgs/150s/sn_iip/*/*coadd.fits')
sniip_files.sort()
sniip_flux = condition_spectra(sniip_files, sniip_truth)

In [34]:
# Merge the sn_iip and sn_iilp samples into a single class (sn_iip rows
# first, then sn_iilp). The earlier empty-list initialisation was a dead
# store that the concatenation immediately overwrote, so it is dropped.
sniip_merge_flux = np.concatenate((sniip_flux, sniilp_flux))

In [35]:
# Number of spectra per class (first axis) and number of wavelength bins
# (second axis). Each unpacking requires a 2-D array and rebinds `nbins`,
# so the value left in `nbins` is the one from sniip_merge_flux.
nhost, nbins  = host_flux.shape
nsnia, nbins  = snia_flux.shape
nsnib, nbins  = snib_flux.shape
nsnic, nbins  = snic_flux.shape
nsniin, nbins = sniin_flux.shape
nsniip_merge, nbins = sniip_merge_flux.shape
# Display the per-class counts followed by the bin count.
nhost, nsnia, nsnib, nsnic, nsniin, nsniip_merge, nbins

(9969, 9964, 9958, 8269, 9949, 19910, 150)

In [37]:
# Stack all classes along the first axis, in the same order as `labels`
# below, and append a trailing axis of length 1 so that x has shape
# (n_spectra, nbins, 1).
x = np.reshape(
    np.concatenate((
        host_flux,
        snia_flux,
        snib_flux,
        snic_flux,
        sniin_flux,
        sniip_merge_flux,
    )),
    (-1, nbins, 1),
)

# Human-readable class names, one per block of rows stacked into x.
labels = ['Host', 'SN Ia', 'SN Ib', 'SN Ic', 'SN IIn', 'SN IIP Merged']
ntypes = len(labels)

# # Convert y-label array to appropriate categorical array
# from tensorflow.keras.utils import to_categorical

# y = to_categorical(
#         np.concatenate([np.full(nhost, 0), 
#                         np.full(nsnia, 1),
#                         np.full(nsnib, 2),
#                         np.full(nsnic, 3),
#                         np.full(nsniin, 4),
#                         np.full(nsniip_merge, 5),
#                        ]))

In [38]:
import pickle

# Serialize the stacked flux array x to disk with pickle.
with open(r'/global/u2/a/awasserm/merged_data.data', mode='wb') as pkl_file:
    pickle.dump(x, pkl_file)

In [39]:
# Sanity check: display (number of spectra, number of bins) of the merged
# sn_iip + sn_iilp sample.
sniip_merge_flux.shape

(19910, 150)

In [40]:
# Cross-check the merged row count against the two samples it was built from.
# Computed from the arrays themselves rather than hard-coded counts, so the
# check stays valid if the input files change.
len(sniip_flux) + len(sniilp_flux)

19910