In [None]:
import sys

import numpy as np
import sncosmo
from astropy.table import Table
from matplotlib import pyplot as plt

sys.path.insert(0, '../')
from analysis_pipeline.data_access import sdss


## Check filter transmission curves

In [None]:
# One panel per CCD column; each panel overlays the five filter
# transmission curves registered with sncosmo for that column.
fig, axes = plt.subplots(1, 6, figsize=(24, 4))
axes[0].set_ylabel('Transmission')

# Registered band names number the columns from 1, hence start=1
for ccd_column, axis in enumerate(axes, start=1):
    for band_letter in 'ugriz':
        # Keep the filter letter and the Bandpass object under separate names
        # so the loop variable is not rebound inside its own loop.
        bandpass = sncosmo.get_bandpass(
            '91bg_proj_sdss_{}{}'.format(band_letter, ccd_column))
        axis.plot(bandpass.wave, bandpass.trans, label=band_letter)

    # Panel decorations do not depend on the band, so set them once per axis
    axis.set_title('CCD Column {}'.format(ccd_column))
    axis.set_xlabel('Wavelength (Angstrom)')  # sncosmo stores Bandpass.wave in Angstroms
    axis.legend()


## Get outlier data points

In [None]:
# Fetch the outlier collection and preview its first ten (key, value) pairs.
# NOTE(review): this reaches into a private helper
# (`_data_access_funcs._get_outliers`) — confirm there is no public accessor.
outliers = sdss._data_access_funcs._get_outliers()
outlier_items = list(outliers.items())
outlier_items[:10]


## Get SDSS published table of spectroscopically confirmed SNIa

In [None]:
# Select the spectroscopically confirmed SNe Ia from the published master table.
# In the SDSS-II SN data release (Sako et al. 2018) the 'Classification' value
# 'SNIa' marks spectroscopic confirmation, whereas 'zSNIa' marks photometric
# classifications that used a host-galaxy spectroscopic redshift. The previous
# filter on 'zSNIa' therefore did not match the intent of this cell.
master_table = sdss.master_table
is_spec_confirmed = master_table['Classification'] == 'SNIa'
spec_confirmed_sn = master_table[is_spec_confirmed]

# Preview the first ten rows
spec_confirmed_sn[:10]


## Pick an arbitrary target and look at the light curve

In [None]:
test_id = 932

# Look up the published SALT2 fit results for the chosen target
published_values = sdss.master_table[sdss.master_table['CID'] == test_id]
x0 = published_values['x0SALT2zspec'][0]
x1 = published_values['x1SALT2zspec'][0]
c = published_values['cSALT2zspec'][0]
chisq = published_values['chi2SALT2zspec'][0]
ndof = published_values['ndofSALT2zspec'][0]
chisq_norm = chisq / ndof
peak_mjd = published_values['MJDatPeakrmag'][0]

print('Published Values for CID {}:'.format(test_id))
print('x0: ', x0)
print('x1: ', x1)
print('c: ', c)
print('chisq: ', chisq)
print('ndof: ', ndof)
print('chisq_norm: ', chisq_norm)

# Load the photometry for the target and draw one panel per filter index
phot_data = sdss.get_data_for_id(test_id)

fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for filt_index, axis in enumerate(axes):
    in_band = phot_data['FILT'] == filt_index
    band_data = phot_data[in_band]
    mjd, mag = band_data['MJD'], band_data['MAG']

    # Markers first, then error bars only (no connecting line)
    axis.scatter(mjd, mag)
    axis.errorbar(mjd, mag, yerr=band_data['MERR'], linestyle='')

plt.show()

# Preview the first ten photometry rows
phot_data[:10]


## Create an SNCosmo input table

In [None]:
@np.vectorize
def band_name(filt, idccd):
    """Build the sncosmo band name for a filter index and CCD column

    Wrapped in ``np.vectorize``, so both arguments may be scalars or
    array-likes and are combined element-wise.

    Args:
        filt  (int): Index of the filter within 'ugriz' (0 for u ... 4 for z)
        idccd (int): Column number 1 through 6

    Returns:
        The band name in the form '91bg_proj_sdss_<filter letter><idccd>'
    """

    filter_letter = 'ugriz'[filt]
    return '91bg_proj_sdss_{}{}'.format(filter_letter, idccd)


In [None]:
# Constant-valued columns are sized to the number of photometry rows
num_obs = len(phot_data)
# NOTE(review): presumably the AB zero point for fluxes expressed in Jy
# (3631 Jy reference) — confirm against the units of FLUX / FLUXERR. The
# original carried a commented-out `* 1E-6` scaling on both flux columns.
zero_point = 2.5 * np.log10(3631)

# Assemble the table column by column (column order is kept for display)
input_table = Table()
input_table.meta = phot_data.meta
input_table['flag'] = phot_data['FLAG']
input_table['time'] = phot_data['MJD']
input_table['band'] = band_name(phot_data['FILT'], phot_data['IDCCD'])
input_table['zp'] = np.full(num_obs, zero_point)
input_table['flux'] = phot_data['FLUX']
input_table['fluxerr'] = phot_data['FLUXERR']
input_table['zpsys'] = np.full(num_obs, 'ab')

# Apply cuts described in Sako et al.: keep rows whose flag is below 1024
passes_flag_cut = input_table['flag'] < 1024
input_table = input_table[passes_flag_cut]

# Outliers are only reported here, not removed
if test_id in outliers:
    print('Wait you need to remove the outier data points!!!!!')

input_table.show_in_notebook(display_length=10)


## Run fit

In [None]:
print('\n\nFitting for all terms except z:')

# Build a SALT2 (version 2.0) model with the redshift fixed to the value
# stored in the input table's metadata
source = sncosmo.get_source('salt2', version='2.0')
model = sncosmo.Model(source=source)
redshift = input_table.meta['redshift']
model.set(z=redshift)

# Parameters left free in the fit; z is deliberately not among them
vparam_names = ['t0', 'x0', 'x1', 'c']
fit_options = dict(
    bounds=None,
    modelcov=True,
    phase_range=[-15, 45],
    minsnr=5,
    warn=False,
)
result, fitted_model = sncosmo.fit_lc(
    input_table, model, vparam_names, **fit_options)

# Store the chi-squared per degree of freedom alongside the other fit results
result['chi2_norm'] = result.chisq / result.ndof

sncosmo.plot_lc(input_table, model=fitted_model, errors=result.errors)
plt.show()
print(result)

