In [None]:
import signal
import sys
from copy import deepcopy

import numpy as np
import sncosmo
from matplotlib import pyplot as plt
from sndata.csp import dr1, dr3
from tqdm import tqdm_notebook
from IPython.display import clear_output

sys.path.insert(0, '../')

from analysis_pipeline import models
from analysis_pipeline.utils import split_data
from analysis_pipeline.lc_fitting import calc_chisq


## Load and register data

In [None]:
# Fetch the CSP DR1 and DR3 data releases used throughout this notebook
for data_release in (dr1, dr3):
    data_release.download_module_data()

# Register the DR3 filters and the model sources with SNCosmo.
# force=True is presumably what allows this cell to be re-run without
# "already registered" errors — confirm against sndata / analysis_pipeline.
dr3.register_filters(force=True)
models.register_sources(force=True)


## Define utility functions

In [None]:
class timeout:
    """A SIGALRM-based timeout context manager

    Raises ``TimeoutError`` inside the ``with`` block if it runs longer than
    ``seconds``. Because it relies on ``signal.SIGALRM`` it only works on
    Unix and only from the main thread (``signal.signal`` raises
    ``ValueError`` otherwise). The previously installed SIGALRM handler is
    restored on exit.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        """A timeout context manager

        Args:
            seconds (int, float): Seconds until timeout. Fractional values are
                supported; 0 disables the timer.
            error_message  (str): The TimeoutError message on timeout
        """

        self.seconds = seconds
        self.error_message = error_message
        # SIGALRM handler that was active before __enter__ (restored on exit)
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler that aborts the ``with`` block with TimeoutError"""

        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; keep it for __exit__
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        # setitimer (unlike signal.alarm) accepts fractional seconds
        signal.setitimer(signal.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, type_, value, traceback):
        # Cancel the timer before restoring the handler so a late SIGALRM
        # cannot reach the restored handler
        signal.setitimer(signal.ITIMER_REAL, 0)

        # getsignal/signal report None for handlers not installed from Python;
        # fall back to the default action in that case
        previous = self._previous_handler
        signal.signal(
            signal.SIGALRM,
            previous if previous is not None else signal.SIG_DFL)
        self._previous_handler = None


## Define classification functions

In [None]:
def get_classification_coords(
    all_data, red_data, blue_data, fit_func=sncosmo.fit_lc,
    kwargs_all=None, kwargs_s2=None, kwargs_bg=None,
    show_plots=False):
    """Determine the classification coordinates for a given target

    SALT2 is fit to the all, red, and blue data with the redshift fixed to
    ``all_data.meta['redshift']``. The 91bg model is then fit to the same
    three tables with t0 fixed to the value from the SALT2 all-data fit.

    NOTE: A later cell defines ``get_classification_coords`` again with a
    different signature (its 4th positional argument is ``vparams``). When
    the notebook is run top to bottom, that later definition replaces this one.

    Args:
        all_data       (Table): Light curve data for all bands
        red_data       (Table): Light curve data for the red bands
        blue_data      (Table): Light curve data for the blue bands
        fit_func    (callable): Fitting function called like ``sncosmo.fit_lc``
        kwargs_all      (dict): Keyword arguments for every fit
        kwargs_s2       (dict): Keyword arguments for the SALT2 fits.
            Keys override ``kwargs_all``.
        kwargs_bg       (dict): Keyword arguments for the 91bg fits. Keys
            override ``kwargs_all``; a ``bounds`` entry is merged over the
            default 91bg bounds rather than replacing them.
        show_plots      (bool): Plot the fitted light curves

    Returns:
        A tuple ``(x, y)``. x is the SALT2 minus 91bg value of
        ``np.divide(*calc_chisq(...))`` for the blue data (presumably the
        reduced chi-squared), and y is the same difference for the red data.
    """

    # Merge keyword arguments. Passing overlapping keys separately would raise
    # "got multiple values for keyword argument".
    kwargs_all = kwargs_all or {}
    s2_kwargs = {**kwargs_all, **(kwargs_s2 or {})}
    bg_kwargs = {**kwargs_all, **(kwargs_bg or {})}

    # Default 91bg parameter bounds, updated with any user-supplied bounds
    bg_bounds = {'x1': [0.65, 1.25], 'c': [0, 1]}
    bg_bounds.update(bg_kwargs.pop('bounds', None) or {})

    # Parse data
    z = all_data.meta['redshift']
    vparams = ['t0', 'x0', 'x1', 'c']

    # Load models
    salt2 = sncosmo.Model('salt2')
    sn91bg = sncosmo.Model(source=sncosmo.get_source('sn91bg', version='salt2_phase'))

    # Fit salt2 model with fixed redshift
    salt2.set(z=z)
    _, norm_fit_all = fit_func(all_data, salt2, vparams, **s2_kwargs)
    _, norm_fit_red = fit_func(red_data, salt2, vparams, **s2_kwargs)
    _, norm_fit_blue = fit_func(blue_data, salt2, vparams, **s2_kwargs)

    # Fit 91bg model with t0 fixed to the salt2 all-data fit (t0 not varied)
    sn91bg.set(z=z, t0=norm_fit_all.get('t0'))
    bg_vparams = vparams[1:]
    _, bg_fit_all = fit_func(all_data, sn91bg, bg_vparams, bounds=bg_bounds, **bg_kwargs)
    _, bg_fit_red = fit_func(red_data, sn91bg, bg_vparams, bounds=bg_bounds, **bg_kwargs)
    _, bg_fit_blue = fit_func(blue_data, sn91bg, bg_vparams, bounds=bg_bounds, **bg_kwargs)

    if show_plots:
        print('Salt2')
        sncosmo.plot_lc(all_data, norm_fit_all)
        sncosmo.plot_lc(red_data, norm_fit_red)
        sncosmo.plot_lc(blue_data, norm_fit_blue)
        plt.show()

        print('91bg')
        sncosmo.plot_lc(all_data, bg_fit_all)
        sncosmo.plot_lc(red_data, bg_fit_red)
        sncosmo.plot_lc(blue_data, bg_fit_blue)
        plt.show()

    # Calculate chisq
    norm_blue_chisq = np.divide(*calc_chisq(blue_data, norm_fit_blue))
    norm_red_chisq = np.divide(*calc_chisq(red_data, norm_fit_red))
    bg_blue_chisq = np.divide(*calc_chisq(blue_data, bg_fit_blue))
    bg_red_chisq = np.divide(*calc_chisq(red_data, bg_fit_red))

    return (norm_blue_chisq - bg_blue_chisq, norm_red_chisq - bg_red_chisq)



In [None]:
def get_classification_coords(
    all_data, red_data, blue_data, vparams=None, kwargs_s2=None, kwargs_bg=None,
    show_plots=False, fit_func=sncosmo.fit_lc):
    """Determine the classification coordinates for a given target

    SALT2 is first fit to all data to determine t0. That t0 is then used as
    the starting value for the SALT2 red/blue fits and is held fixed for all
    91bg fits. If ``'z'`` is not in ``vparams``, both models use the redshift
    from ``all_data.meta['redshift']``.

    NOTE: This definition replaces an earlier cell's
    ``get_classification_coords``, which takes ``fit_func`` as its 4th
    positional argument. Pass the data tables positionally and everything
    else by keyword.

    Args:
        all_data       (Table): Light curve data for all bands
        red_data       (Table): Light curve data for the red bands
        blue_data      (Table): Light curve data for the blue bands
        vparams    (list[str]): Parameters to vary
            (default ``['t0', 'x0', 'x1', 'c']``). t0 is never varied in
            the 91bg fits.
        kwargs_s2       (dict): Keyword arguments for the SALT2 fits
        kwargs_bg       (dict): Keyword arguments for the 91bg fits
        show_plots      (bool): Plot the fitted light curves
        fit_func    (callable): Fitting function called like ``sncosmo.fit_lc``

    Returns:
        A tuple ``(x, y)``. x is the SALT2 minus 91bg value of
        ``np.divide(*calc_chisq(...))`` for the blue data (presumably the
        reduced chi-squared), and y is the same difference for the red data.
    """

    # Avoid mutable defaults; copy so the caller's sequence is never touched
    vparams = ['t0', 'x0', 'x1', 'c'] if vparams is None else list(vparams)
    kwargs_s2 = kwargs_s2 or {}
    kwargs_bg = kwargs_bg or {}

    salt2 = sncosmo.Model('salt2')
    sn91bg = sncosmo.Model('sn91bg')

    # If we are not fitting for z, set it in the models
    if 'z' not in vparams:
        z = all_data.meta['redshift']
        salt2.set(z=z)
        sn91bg.set(z=z)

    # Fit salt2 model to all data and determine t0 (by name, not position)
    _, norm_fit_all = fit_func(all_data, salt2, vparams, **kwargs_s2)
    t0 = norm_fit_all.get('t0')

    # Set t0 for remaining fits
    salt2.set(t0=t0)
    sn91bg.set(t0=t0)

    _, norm_fit_red = fit_func(red_data, salt2, vparams, **kwargs_s2)
    _, norm_fit_blue = fit_func(blue_data, salt2, vparams, **kwargs_s2)

    # 91bg fits keep t0 fixed at the salt2 all-data value
    bg_vparams = [p for p in vparams if p != 't0']
    _, bg_fit_all = fit_func(all_data, sn91bg, bg_vparams, **kwargs_bg)
    _, bg_fit_red = fit_func(red_data, sn91bg, bg_vparams, **kwargs_bg)
    _, bg_fit_blue = fit_func(blue_data, sn91bg, bg_vparams, **kwargs_bg)

    if show_plots:
        print('Salt2')
        sncosmo.plot_lc(all_data, norm_fit_all)
        sncosmo.plot_lc(red_data, norm_fit_red)
        sncosmo.plot_lc(blue_data, norm_fit_blue)
        plt.show()

        print('91bg')
        sncosmo.plot_lc(all_data, bg_fit_all)
        sncosmo.plot_lc(red_data, bg_fit_red)
        sncosmo.plot_lc(blue_data, bg_fit_blue)
        plt.show()

    # Calculate chisq
    norm_blue_chisq = np.divide(*calc_chisq(blue_data, norm_fit_blue))
    norm_red_chisq = np.divide(*calc_chisq(red_data, norm_fit_red))
    bg_blue_chisq = np.divide(*calc_chisq(blue_data, bg_fit_blue))
    bg_red_chisq = np.divide(*calc_chisq(red_data, bg_fit_red))

    return (norm_blue_chisq - bg_blue_chisq, norm_red_chisq - bg_red_chisq)


In [None]:
# Sanity-check the classification of a single DR3 target, with plots
survey = dr3
bands = survey.band_names
leff = survey.lambda_effective
all_data_table = survey.get_data_for_id('2004dt', True)

# Unpack in the same (red, blue) order used by `classify_survey`, and pass the
# subsets by keyword so the red/blue mapping is explicit. The first table
# returned by split_data goes to `red_data`, exactly as before.
# NOTE(review): confirm the return order of analysis_pipeline.utils.split_data;
# this cell previously named its outputs (blue, red) but passed them into the
# (red_data, blue_data) slots.
red_data_table, blue_data_table = split_data(all_data_table, bands, leff)
get_classification_coords(
    all_data_table,
    red_data=red_data_table,
    blue_data=blue_data_table,
    show_plots=True,
    kwargs_bg={'bounds': {'x1': [0.65, 1.25], 'c': [0, 1]}})


In [None]:
def classify_survey(survey, obj_ids, fitting_func=None, timeout_seconds=30, **coord_kwargs):
    """Compute classification coordinates for every target in a survey

    Each target's light curve is split into red and blue subsets with
    ``split_data`` and passed to ``get_classification_coords``. Targets whose
    fit raises an exception, or runs longer than ``timeout_seconds``, get NaN
    coordinates. The reason for the failure is printed.

    Args:
        survey           (module): An sndata survey module (e.g. ``dr3``)
        obj_ids       (list[str]): Object ids to classify
        fitting_func   (callable): Optional fitting function. When not None it
            is forwarded to ``get_classification_coords`` as ``fit_func``; the
            active definition must accept that keyword, otherwise the
            resulting TypeError is reported and the target gets NaN.
        timeout_seconds     (int): Time limit for each target's fits
        **coord_kwargs: Extra keyword arguments for ``get_classification_coords``

    Returns:
        A list of x coordinates and a list of y coordinates, one per target
    """

    bands = survey.band_names
    leff = survey.lambda_effective

    call_kwargs = dict(coord_kwargs)
    if fitting_func is not None:
        call_kwargs['fit_func'] = fitting_func

    x, y = [], []
    for obj_id in tqdm_notebook(obj_ids):
        all_data_table = survey.get_data_for_id(obj_id, True)
        red_data_table, blue_data_table = split_data(all_data_table, bands, leff)

        try:
            with timeout(timeout_seconds):
                # Only the data tables are positional. Passing the fitting
                # function as a 4th positional argument lands in `vparams` in
                # one of the definitions above.
                x_this, y_this = get_classification_coords(
                    all_data_table,
                    red_data=red_data_table,
                    blue_data=blue_data_table,
                    **call_kwargs)

        except Exception as err:
            # Skip failed or timed-out fits, but report why. KeyboardInterrupt
            # is not caught, so the loop can still be interrupted.
            print('{}: {}: {}'.format(obj_id, type(err).__name__, err))
            x_this, y_this = np.nan, np.nan

        x.append(x_this)
        y.append(y_this)

    return x, y


## Run classifications

In [None]:
# Keep only the DR1 targets that have a spectroscopic classification
dr1_table_1 = dr1.load_table(1)
type_column = dr1_table_1['Type']
is_typed = ~type_column.mask  # masked entries have no listed type
classifications = dr1_table_1[is_typed]['SN', 'Type']
classifications.show_in_notebook(display_length=10)


In [None]:
# Classify every spectroscopically typed DR1 target using its DR3 photometry,
# and store the coordinates alongside the spectroscopic types.
# FIXME(review): `simple_fit` is not defined in any cell of this notebook, so
# this line raises NameError on a fresh kernel. Define it, or pass an existing
# fitting function, before running.
simple_x, simple_y = classify_survey(dr3, classifications['SN'], simple_fit)
classifications['simple_x'] = simple_x
classifications['simple_y'] = simple_y


In [None]:
# Inspect the y (red-data) classification coordinates returned by classify_survey
simple_y

In [None]:
# Scatter the classification coordinates of each spectroscopic class.
# x: blue-data chisq difference, y: red-data chisq difference (SALT2 - 91bg).
fig, ax = plt.subplots(figsize=(10, 10))
for snclass in sorted(set(classifications['Type'])):  # sorted -> stable legend
    plot_data = classifications[classifications['Type'] == snclass]
    # classify_survey results are stored in the simple_x / simple_y columns;
    # the table has no 'x' / 'y' columns
    ax.scatter(plot_data['simple_x'], plot_data['simple_y'], label=snclass)

ax.axvline(0, linestyle='--', color='grey')
ax.axhline(0, linestyle='--', color='grey')
ax.set_xlabel(r'$\chi^2_{blue}(Ia) - \chi^2_{blue}(91bg)$', fontsize=14)
ax.set_ylabel(r'$\chi^2_{red}(Ia) - \chi^2_{red}(91bg)$', fontsize=14)
ax.legend();


In [None]:
# Re-run the classification for 2007N using only the optical bands
# (drop every DR3 band whose name contains 'Y', 'J', or 'H')
data_table = dr3.get_data_for_id('2007N', True)
ignore_bands = [b for b in dr3.band_names if ('Y' in b or 'J' in b or 'H' in b)]
for b in ignore_bands:
    data_table = data_table[~(data_table['band'] == b)]

# get_classification_coords expects the red and blue data tables after the
# full table (not band names / wavelengths), so split the data first, in the
# same (red, blue) order used by classify_survey. Use dr3's attributes directly
# instead of relying on `bands` / `leff` from an earlier cell.
# NOTE(review): assumes split_data tolerates bands missing from the table.
red_data_2007n, blue_data_2007n = split_data(
    data_table, dr3.band_names, dr3.lambda_effective)
get_classification_coords(
    data_table,
    red_data=red_data_2007n,
    blue_data=blue_data_2007n,
    show_plots=True)
    