In [116]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from glob import glob
from itertools import chain
from astropy.table import QTable, Table

In [97]:
def flatten_chain(matrix):
    """Flatten one level of nesting, e.g. [[a, b], [c]] -> [a, b, c]."""
    return [item for row in matrix for item in row]

In [43]:
# Latest observed time and number of points for every light curve in all test files.
max_t = []
num_points = []
# sorted(): glob order is arbitrary, so sort to make the concatenation order reproducible.
for lc_file in sorted(glob('preprocessed_lc/test_lcs*.npz')):
    with np.load(lc_file, allow_pickle=True) as lc_data:
        lightcurves = lc_data['lcs']
    max_t.extend(max(lightcurve.times) for lightcurve in lightcurves)
    num_points.extend(len(lightcurve.times) for lightcurve in lightcurves)
max_t = np.array(max_t)
num_points = np.array(num_points)

In [61]:
lightcurves = np.load('preprocessed_lc/test_lcs_10.npz', allow_pickle=True)['lcs']
# Class label = text before the first underscore (e.g. 'SNIa' in 'SNIa_id44217').
names = np.array([lc.name.partition('_')[0] for lc in lightcurves])

In [62]:
# Keep, for every class, the light curve with the most observed points.
test = []
for cls in sorted(set(names)):
    members = lightcurves[names == cls]
    n_obs = [len(lc.times) for lc in members]
    test.append(members[np.argmax(n_obs)])

In [63]:
# One line per representative: name, number of points, first and last phase.
for rep in test:
    obs_times = rep.times
    print(rep.name, len(obs_times), min(obs_times), max(obs_times))

AGN_id98634196 64 0.0 141.33123556581796
CaRT_id1308595 49 -12.039746543779758 119.80843894009247
ILOT_id92676146 23 -36.332970908022546 87.61387011460519
KN_id22691576 4 0.0 1.1942848682264018
PISN_id102824316 31 -66.95679012345673 39.672356414384886
SLSNI_id22664506 87 -50.59170021678538 165.92006813254994
SNII_id320467 91 -0.5584722222233217 37.39522222222149
SNIa_id44217 78 -12.74774851316762 61.107476635511595
SNIa91bg_id342818 40 -3.047698113209577 61.06226415094428
SNIax_id129179 82 -12.838706896557063 90.40215517241074
SNIbc_id56987 77 -43.37026615969717 9.065475285170775
TDE_id120886180 30 -20.225922509223395 142.79677121771363


In [121]:
# Relabel SNIIn objects, identified by their PLAsTiCC object ids.
IIn_ids = Table.read('plasticc_modelpar_042_SNIIn.csv', format='ascii')['object_id'].value
lightcurves = np.load('preprocessed_lc/test_lcs_2.npz', allow_pickle=True)['lcs']
names = np.array([lightcurve.name.split('_')[0] for lightcurve in lightcurves])
# Names look like '<class>_id<object_id>'; drop the 'id' prefix.
ids = np.array([lightcurve.name.split('_')[-1][2:] for lightcurve in lightcurves])
# `ids` are strings while the CSV column may be parsed as numbers: compare both
# as integers instead of relying on implicit str/int promotion. `np.isin`
# replaces `np.in1d`, which is deprecated since NumPy 2.0.
is_IIn = np.isin(ids.astype(np.int64), np.asarray(IIn_ids).astype(np.int64))
# np.where widens the string dtype if needed, so 'SNIIn' can never be truncated.
names = np.where(is_IIn, 'SNIIn', names)

# Keep, per class, the light curve with the most observations.
test = []
for sn in sorted(set(names)):
    inds = np.where(names == sn)[0]
    times = [len(lightcurve.times) for lightcurve in lightcurves[inds]]
    test.append(lightcurves[inds][np.argmax(times)])

# Replicate every representative N_STEPS times, revealing 1..N_STEPS dense points.
N_STEPS = 300
PAD_TIME = 200 + 100  # time stamp of padded steps (200 + prep_input's default new_t_max)
lengths = flatten_chain([np.arange(1, N_STEPS + 1) for _ in range(len(test))])
ids = flatten_chain([[lightcurve.name] * N_STEPS for lightcurve in test])
sequence_len = np.max(lengths)
nfilts = 6
nfiltsp1 = nfilts + 1
n_lcs = len(test) * N_STEPS
sequence = np.zeros((n_lcs, sequence_len, nfilts * 2 + 1))

lightcurves = flatten_chain([[lightcurve] * N_STEPS for lightcurve in test])
lms = []
for i, (lightcurve, length) in enumerate(zip(lightcurves, lengths)):
    sequence[i, :length, 0] = lightcurve.dense_times[:length]
    sequence[i, :length, 1:nfiltsp1] = lightcurve.dense_lc[:length, :, 0]
    # Inflate errors by 0.01 so none is exactly zero.
    sequence[i, :length, nfiltsp1:] = lightcurve.dense_lc[:length, :, 1] + 0.01
    sequence[i, length:, 0] = PAD_TIME
    sequence[i, length:, 1:nfiltsp1] = lightcurve.abs_lim_mag
    sequence[i, length:, nfiltsp1:] = 1
    lms.append(lightcurve.abs_lim_mag)

In [5]:
def _longest_lc_per_class(lightcurves):
    """Return, per class (name text before the first '_', sorted), the light curve with the most observed times."""
    names = np.array([lightcurve.name.split('_')[0] for lightcurve in lightcurves])
    selected = []
    for cls in sorted(set(names)):
        members = lightcurves[names == cls]
        n_obs = [len(lightcurve.times) for lightcurve in members]
        selected.append(members[np.argmax(n_obs)])
    return selected


def prep_input(input_lc_file, new_t_max=100.0, filler_err=1.0,
               save=False, load=False, outdir=None, prep_file=None,
               date=None):
    """
    Prep input file for fitting.

    For every class in ``input_lc_file`` the light curve with the most
    observations is kept and replicated 300 times, row ``m`` revealing its
    first ``m`` dense GP points (m = 1..300). Remaining steps are padded
    with the limiting magnitude and ``filler_err``.

    Parameters
    ----------
    input_lc_file : str
        Path to an ``.npz`` file whose ``'lcs'`` entry is an object array of
        light curves (each needs ``name``, ``times``, ``dense_times``,
        ``dense_lc`` and ``abs_lim_mag``).
    new_t_max : float
        Padded steps get the time stamp ``200 + new_t_max``.
    filler_err : float
        Error written into padded steps.
    save : bool
        If True, save ``bandmin``/``bandmax`` to
        ``outdir + 'prep_' + date + '.npz'`` and ``outdir + 'prep.npz'``.
    load : bool
        If True, read ``bandmin``/``bandmax`` from ``prep_file`` instead of
        computing them from the data.
    outdir : str, optional
        Prefix (typically a directory ending in '/') for saved files;
        ``None`` means the current directory.
    prep_file : str
        ``.npz`` file with saved normalisation, used when ``load`` is True.
    date : str, optional
        Tag for the dated prep file; defaults to today's date as YYYYMMDD.

    Returns
    -------
    sequence : numpy.ndarray, shape (n_lcs, sequence_len, 2*nfilts+1)
        Per step: time, normalised negated magnitudes, magnitude errors.
    outseq : numpy.ndarray, shape (n_lcs, sequence_len, 2)
        Per step: time and the light curve's limiting magnitude.
    ids : list of str
        Light-curve name for every row of ``sequence``.
    sequence_len : int
        Number of time steps per row (300).
    nfilts : int
        Number of filters, taken from ``dense_lc``.

    Raises
    ------
    ValueError
        If ``input_lc_file`` contains no light curves.
    """
    import time

    with np.load(input_lc_file, allow_pickle=True) as lc_data:
        lightcurves = lc_data['lcs']
    test = _longest_lc_per_class(lightcurves)
    if not test:
        raise ValueError('no light curves found in ' + str(input_lc_file))

    sequence_len = 300
    nfilts = np.shape(test[0].dense_lc)[1]
    nfiltsp1 = nfilts + 1
    n_lcs = len(test) * sequence_len

    # Row k*sequence_len + (m-1) holds representative k with its first m dense points.
    lengths = np.tile(np.arange(1, sequence_len + 1), len(test))
    ids = [lightcurve.name for lightcurve in test for _ in range(sequence_len)]
    expanded = [lightcurve for lightcurve in test for _ in range(sequence_len)]

    sequence = np.zeros((n_lcs, sequence_len, nfilts * 2 + 1))
    lms = []
    for i, (lightcurve, length) in enumerate(zip(expanded, lengths)):
        sequence[i, :length, 0] = lightcurve.dense_times[:length]
        sequence[i, :length, 1:nfiltsp1] = lightcurve.dense_lc[:length, :, 0]
        # Inflate errors by 0.01 so none is exactly zero.
        sequence[i, :length, nfiltsp1:] = lightcurve.dense_lc[:length, :, 1] + 0.01
        sequence[i, length:, 0] = 200 + new_t_max
        sequence[i, length:, 1:nfiltsp1] = lightcurve.abs_lim_mag
        sequence[i, length:, nfiltsp1:] = filler_err
        lms.append(lightcurve.abs_lim_mag)

    # Flip because who needs negative magnitudes
    sequence[:, :, 1:nfiltsp1] = -1.0 * sequence[:, :, 1:nfiltsp1]

    if load:
        with np.load(prep_file) as prep_data:
            bandmin = prep_data['bandmin']
            bandmax = prep_data['bandmax']
    else:
        bandmin = np.min(sequence[:, :, 1:nfiltsp1])
        bandmax = np.max(sequence[:, :, 1:nfiltsp1])

    # Min-max scale all magnitude channels with one shared range.
    sequence[:, :, 1:nfiltsp1] = (sequence[:, :, 1:nfiltsp1] - bandmin) / (bandmax - bandmin)

    new_lms = np.reshape(np.repeat(lms, sequence_len), (len(lms), -1))

    outseq = np.reshape(sequence[:, :, 0], (len(sequence), sequence_len, 1)) * 1.0
    outseq = np.dstack((outseq, new_lms))
    if save:
        if date is None:
            date = time.strftime('%Y%m%d')
        prefix = outdir if outdir is not None else ''
        np.savez(prefix + 'prep_' + str(date) + '.npz', bandmin=bandmin, bandmax=bandmax)
        np.savez(prefix + 'prep.npz', bandmin=bandmin, bandmax=bandmax)
    return sequence, outseq, ids, sequence_len, nfilts

In [None]:
lightcurves = np.load('preprocessed_lc/train_lcs.npz', allow_pickle=True)['lcs']
# Skip float placeholders stored instead of light curves; isinstance also
# catches numpy float scalars (np.float64), which `type(x) == float` misses.
gind = [j for j, lightcurve in enumerate(lightcurves) if not isinstance(lightcurve, float)]
lightcurves = lightcurves[gind]
lengths = [len(lightcurve.times) for lightcurve in lightcurves]
ids = [lightcurve.name for lightcurve in lightcurves]
sequence_len = np.max(lengths)
# dense_lc stacks GP magnitudes and errors per filter, so axis 1 counts filters.
nfilts = np.shape(lightcurves[0].dense_lc)[1]
nfiltsp1 = nfilts + 1
n_lcs = len(lightcurves)
# convert from LC format to list of arrays: per step [time, nfilts mags, nfilts errors]
sequence = np.zeros((n_lcs, sequence_len, nfilts * 2 + 1))

In [None]:
# Padding constants, matching the prep_input defaults. At top level they were
# otherwise only defined as prep_input parameters (NameError on a fresh kernel).
new_t_max = 100.0
filler_err = 1.0

lms = []
for i, lightcurve in enumerate(lightcurves):
    n_obs = lengths[i]
    # Interpolate the dense GP magnitudes (channel 0) and errors (channel 1)
    # of every filter onto the observed times -> arrays of shape (n_obs, nfilts).
    pred = np.column_stack([
        np.interp(lightcurve.times, lightcurve.dense_times, lightcurve.dense_lc[:, f, 0])
        for f in range(nfilts)])
    err_pred = np.column_stack([
        np.interp(lightcurve.times, lightcurve.dense_times, lightcurve.dense_lc[:, f, 1])
        for f in range(nfilts)])
    sequence[i, 0:n_obs, 0] = lightcurve.times
    sequence[i, 0:n_obs, 1:nfiltsp1] = pred
    sequence[i, 0:n_obs, nfiltsp1:] = err_pred + 0.01
    # Padded steps: time beyond the last observation, limiting mag, filler error.
    sequence[i, n_obs:, 0] = np.max(lightcurve.times) + new_t_max
    sequence[i, n_obs:, 1:nfiltsp1] = lightcurve.abs_lim_mag
    sequence[i, n_obs:, nfiltsp1:] = filler_err
    lms.append(lightcurve.abs_lim_mag)

In [None]:
EXAMPLE_IDX = 3644  # training light curve used for the resampling demo
lightcurve = lightcurves[EXAMPLE_IDX]

In [None]:
# Show only the shape: the full array is n_lcs x sequence_len x (2*nfilts+1).
sequence.shape

In [None]:
# Replace row `i` of `sequence` with a resampled realisation of its light curve.
# NOTE: `resample` is defined in a later cell of this notebook; run it first.
i = 3644
lightcurve = lightcurves[i]
new_z, new_mag = resample(lightcurve)
# Rescale phases from the catalogued redshift to the resampled one.
new_time = lightcurve.times * (1. + lightcurve.redshift) / (1. + new_z)
sequence[i, 0:lengths[i], 0] = new_time
sequence[i, 0:lengths[i], 1:nfiltsp1] = new_mag.T
dense_err = lightcurve.dense_lc[:, :, 1]
err_pred = np.column_stack([
    np.interp(lightcurve.times, lightcurve.dense_times, dense_err[:, f])
    for f in range(dense_err.shape[1])])
sequence[i, 0:lengths[i], nfiltsp1:] = err_pred + 0.01
sequence[i, lengths[i]:, 0] = np.max(new_time) + new_t_max
sequence[i, lengths[i]:, 1:nfiltsp1] = lightcurve.abs_lim_mag
sequence[i, lengths[i]:, nfiltsp1:] = filler_err
# Overwrite rather than append so `lms` stays aligned with the rows of `sequence`.
lms[i] = lightcurve.abs_lim_mag

In [None]:
 # Restore row `i` of `sequence` to the original (non-resampled) light curve.
dense_mag = lightcurve.dense_lc[:, :, 0]
dense_err = lightcurve.dense_lc[:, :, 1]
n_filters = dense_mag.shape[1]
pred = np.column_stack([
    np.interp(lightcurve.times, lightcurve.dense_times, dense_mag[:, f])
    for f in range(n_filters)])
err_pred = np.column_stack([
    np.interp(lightcurve.times, lightcurve.dense_times, dense_err[:, f])
    for f in range(n_filters)])
sequence[i, 0:lengths[i], 0] = lightcurve.times
sequence[i, 0:lengths[i], 1:nfiltsp1] = pred
sequence[i, 0:lengths[i], nfiltsp1:] = err_pred + 0.01
sequence[i, lengths[i]:, 0] = np.max(lightcurve.times) + new_t_max
sequence[i, lengths[i]:, 1:nfiltsp1] = lightcurve.abs_lim_mag
sequence[i, lengths[i]:, nfiltsp1:] = filler_err
# Overwrite rather than append so `lms` stays aligned with the rows of `sequence`.
lms[i] = lightcurve.abs_lim_mag

In [None]:
# Draw one noisy realisation of the GP light curve at the observed times:
# each (filter, time) magnitude ~ Normal(GP mean, GP error).
dense_mag = lightcurve.dense_lc[:, :, 0]
dense_err = lightcurve.dense_lc[:, :, 1]
pred = np.array([np.interp(lightcurve.times, lightcurve.dense_times, dense_mag[:, f])
                 for f in range(dense_mag.shape[1])])
err_pred = np.array([np.interp(lightcurve.times, lightcurve.dense_times, dense_err[:, f])
                     for f in range(dense_err.shape[1])])
# One vectorised draw instead of a frozen scipy distribution per point;
# result has shape (n_filters, len(lightcurve.times)).
new_mag = stats.norm.rvs(loc=pred, scale=err_pred)

In [None]:
fig, ax = plt.subplots(figsize=(6, 5))

colors = ['darkblue', 'green', 'red', 'orange', 'purple', 'black']
dense_fluxes = lightcurve.dense_lc[:, :, 0]
dense_errs = lightcurve.dense_lc[:, :, 1]
# Loop over `band` (not `i`) so this cell does not overwrite the row index `i`
# used by the resampling cells.
for band, color in enumerate(colors):
    # GP mean with a 1-sigma band, plus the resampled points at observed times.
    ax.plot(lightcurve.dense_times, dense_fluxes[:, band], color=color, lw=2)
    ax.fill_between(lightcurve.dense_times,
                    dense_fluxes[:, band] + dense_errs[:, band],
                    dense_fluxes[:, band] - dense_errs[:, band],
                    alpha=0.1, color=color)
    ax.scatter(lightcurve.times, new_mag[band], color=color, marker='s')
ax.set_xlim(-20, 30)
# Magnitudes: brighter is smaller, so invert the y axis.
ax.invert_yaxis()
ax.tick_params(labelsize=14)
ax.set_xlabel('Phase (days)', fontsize=18)
ax.set_ylabel('Absolute Magnitude', fontsize=18)
ax.set_title(lightcurve.name, fontsize=20);

In [None]:
def resample(lightcurve, random_state=None):
    """
    Draw a random realisation of a light curve's redshift and magnitudes.

    Parameters
    ----------
    lightcurve : LightCurve
        Needs ``redshift``, ``redshift_err``, ``times``, ``dense_times`` and
        ``dense_lc`` (GP magnitudes in ``[:, :, 0]``, errors in ``[:, :, 1]``).
    random_state : None, int, numpy.random.Generator or RandomState, optional
        Source of randomness passed to scipy's ``rvs``. An int is turned into
        a single Generator shared by both draws. ``None`` uses numpy's global
        state (the original behaviour).

    Returns
    -------
    new_z : float
        Redshift drawn from N(redshift, redshift_err) truncated to [0, 10].
        Equal to ``redshift`` when ``redshift_err`` is None or non-positive.
    new_mag : numpy.ndarray, shape (n_filters, len(lightcurve.times))
        Magnitudes drawn from N(GP mean, GP error) at the observed times.
    """
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.default_rng(random_state)

    lo, high = 0, 10
    loc, scale = lightcurve.redshift, lightcurve.redshift_err
    if scale is None or scale <= 0:
        # Nothing to resample (and (lo - loc) / scale would divide by zero).
        new_z = loc
    else:
        a, b = (lo - loc) / scale, (high - loc) / scale
        new_z = stats.truncnorm(loc=loc, scale=scale, a=a, b=b).rvs(random_state=random_state)

    dense_mag = lightcurve.dense_lc[:, :, 0]
    dense_err = lightcurve.dense_lc[:, :, 1]
    n_filters = dense_mag.shape[1]
    # Interpolate the dense GP onto the observed times: shape (n_filters, n_times).
    pred = np.array([np.interp(lightcurve.times, lightcurve.dense_times, dense_mag[:, f])
                     for f in range(n_filters)])
    err_pred = np.array([np.interp(lightcurve.times, lightcurve.dense_times, dense_err[:, f])
                         for f in range(n_filters)])

    # One vectorised draw instead of a frozen distribution per point.
    new_mag = stats.norm.rvs(loc=pred, scale=err_pred, random_state=random_state)
    return new_z, new_mag

In [6]:
class LightCurve(object):
    """Photometric light curve of a single transient.

    Holds per-observation arrays (``times``, ``fluxes``, ``flux_errs``,
    ``filters`` and, after :meth:`get_abs_mags`, ``abs_mags`` /
    ``abs_mags_err``) plus per-object metadata, with helpers to sort, cut,
    shift, convert to absolute magnitudes and GP-densify the curve.
    """
    def __init__(self, name, times, fluxes, flux_errs, filters,
                 zpt=0, mwebv=0, redshift=None, redshift_err=None,
                 lim_mag=None, obj_type=None):

        self.name = name
        self.times = times
        self.fluxes = fluxes
        self.flux_errs = flux_errs
        self.filters = filters
        self.zpt = zpt
        self.mwebv = mwebv
        self.redshift = redshift
        self.redshift_err = redshift_err
        self.lim_mag = lim_mag
        self.obj_type = obj_type

        # Filled in by get_abs_mags().
        self.abs_mags = None
        self.abs_mags_err = None
        self.abs_lim_mag = None

    def _apply_index(self, gind):
        """Index every per-observation array with ``gind`` so they stay aligned."""
        self.times = self.times[gind]
        self.fluxes = self.fluxes[gind]
        self.flux_errs = self.flux_errs[gind]
        self.filters = self.filters[gind]
        if self.abs_mags is not None:
            self.abs_mags = self.abs_mags[gind]
            self.abs_mags_err = self.abs_mags_err[gind]

    def sort_lc(self):
        """Sort all per-observation arrays by time."""
        self._apply_index(np.argsort(self.times))

    def find_peak(self, tpeak_guess):
        """
        Estimate the time of peak brightness.

        Considers points within 1000 days of ``tpeak_guess``, preferring
        those with flux S/N > 3. The peak is the point with the smallest
        absolute magnitude when ``abs_mags`` is set, otherwise the one with
        the largest flux. Returns ``tpeak_guess`` if no point qualifies.
        """
        near = np.abs(self.times - tpeak_guess) < 1000.0
        gind = np.where(near & (self.fluxes / self.flux_errs > 3.0))
        if len(gind[0]) == 0:
            gind = np.where(near)
        if len(gind[0]) == 0:
            return tpeak_guess
        if self.abs_mags is not None:
            return self.times[gind][np.argmin(self.abs_mags[gind])]
        # Without magnitudes, use the brightest point by flux.
        return self.times[gind][np.argmax(self.fluxes[gind])]

    def cut_lc(self, limit_before=100, limit_after=200):
        """Keep only observations with -limit_before < time < limit_after."""
        self._apply_index(np.where((self.times > -limit_before) &
                                   (self.times < limit_after)))

    def shift_lc(self, t0=0):
        """Shift times so that ``t0`` becomes zero."""
        self.times = self.times - t0

    def correct_time_dilation(self):
        """Convert times to the rest frame by dividing by (1 + redshift)."""
        self.times = self.times / (1. + self.redshift)

    def add_LC_info(self, zpt=27.5, mwebv=0.0, redshift=0.0, redshift_err=0.0,
                    lim_mag=25.0, obj_type='-'):
        """Set zero point, MW E(B-V), redshift (and error), limiting mag and type."""
        self.zpt = zpt
        self.mwebv = mwebv
        self.redshift = redshift
        self.redshift_err = redshift_err
        self.lim_mag = lim_mag
        self.obj_type = obj_type

    def get_abs_mags(self, replace_nondetections=True, mag_err_fill=1.0):
        """
        Convert flux into absolute magnitude.

        Applies the zero point, the distance modulus from ``cosmo``, a
        ``2.5 log10(1+z)`` k-correction and Milky Way extinction (G23 law,
        Rv=3.1, E(B-V)=``mwebv``). Also sets ``self.abs_lim_mag``, the
        limiting magnitude converted with the same distance modulus and
        k-correction (no extinction term).

        Parameters
        ----------
        replace_nondetections : bool
            Replace non-finite magnitudes/errors, and magnitudes fainter than
            ``self.lim_mag``, with the absolute limiting magnitude.
        mag_err_fill : float
            Error assigned to the replaced points.

        Returns
        -------
        abs_mags, abs_mags_err : numpy.ndarray
            Absolute magnitudes and their errors (also stored on ``self``).
        """
        # NOTE(review): G23, u and cosmo are not imported in this notebook's
        # import cell; presumably dust_extinction's G23, astropy.units and an
        # astropy cosmology defined elsewhere — confirm.
        lsst_filters = {'0': 3740., '1': 4870., '2': 6250., '3': 7700., '4': 8900., '5': 10845.}
        ext = G23(Rv=3.1)
        reddening = -2.5 * np.log10(ext.extinguish([lsst_filters[str(filt)] for filt
                                                    in self.filters.astype(int)] * u.AA,
                                                   Ebv=self.mwebv))
        k_correction = 2.5 * np.log10(1. + self.redshift)
        dist = cosmo.luminosity_distance([self.redshift]).value[0]  # returns dist in Mpc
        dist_mod = 5. * np.log10(dist * 1e6 / 10.0)

        self.abs_mags = -2.5 * np.log10(self.fluxes) + self.zpt - dist_mod + \
            k_correction - reddening
        self.abs_mags_err = np.abs((2.5 / np.log(10)) * (self.flux_errs / self.fluxes))

        # Computed outside the `if` so self.abs_lim_mag is set even when
        # replace_nondetections is False.
        abs_lim_mag = self.lim_mag - dist_mod + k_correction

        if replace_nondetections:
            # NOTE(review): this compares absolute magnitudes with the apparent
            # self.lim_mag; abs_lim_mag may have been intended — confirm.
            gind = np.where((np.isnan(self.abs_mags)) |
                            np.isinf(self.abs_mags) |
                            np.isnan(self.abs_mags_err) |
                            np.isinf(self.abs_mags_err) |
                            (self.abs_mags > self.lim_mag))

            self.abs_mags[gind] = abs_lim_mag
            self.abs_mags_err[gind] = mag_err_fill
        self.abs_lim_mag = abs_lim_mag

        return self.abs_mags, self.abs_mags_err

    def make_dense_LC(self, nfilts=6):
        """
        Fit a GP to the light curve and store a dense interpolated version.

        The GP (``run_gp``, defined outside this notebook section) is fit to
        magnitudes relative to ``abs_lim_mag``. Sets ``self.dense_lc``
        (``[:, :, 0]`` magnitudes, ``[:, :, 1]`` sqrt of the GP variance),
        ``self.dense_times`` and ``self.gp_mags``.

        Parameters
        ----------
        nfilts : int
            Currently unused; kept for backward compatibility.

        Returns
        -------
        gp : object
            GP object returned by ``run_gp``.
        gp_mags : numpy.ndarray
            Magnitudes relative to the limiting magnitude used for the fit.
        """
        gp_mags = self.abs_mags - self.abs_lim_mag
        print(self.name)

        pred, pred_var, gp, times = run_gp(self.times, self.filters, gp_mags, self.abs_mags_err)
        pred = pred.T
        pred_var = pred_var.T
        # NOTE(review): placeholder value, not the GP object — confirm whether anything reads it.
        self.gp = [1, 2, 3]

        dense_fluxes = pred + self.abs_lim_mag
        dense_errs = np.sqrt(pred_var)

        self.dense_lc = np.dstack((dense_fluxes, dense_errs))
        self.dense_times = times

        self.gp_mags = gp_mags
        return gp, gp_mags