# Classifying SNe

This notebook identifies SN91bg-like SNe and compares the results of photometric and spectroscopic classifications.

#### Table of Contents:
1. <a href='#reading_in_data'>Reading in the Data</a>: Reading in data from both the analysis pipeline and external publications.
1. <a href='#spectroscopic_classification'>Spectroscopic Classification</a>: Subtyping of spectroscopically observed targets.
1. <a href='#photometric_classification'>Photometric Classification</a>: Subtyping of photometrically observed targets.
1. <a href='#intrinsic_properties'>Intrinsic Properties</a>: Plots of fitted parameters from the classification process.
1. <a href='#host_properties'>Host Galaxy Properties</a>: Identification of trends with host galaxy mass and SSFR.


In [None]:
import math
import sys
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from astropy.cosmology import WMAP9 as wmap9
from astropy.table import Table
from matplotlib import pyplot as plt
from scipy import stats
from sklearn.utils import resample
from sndata.csp import dr1
from sndata.sdss import sako18, sako18spec

sys.path.insert(0, '../')
from phot_class import spectra as spec_class

# Make sure every data release used in this notebook is available locally
for _survey_module in (dr1, sako18spec, sako18):
    _survey_module.download_module_data()

# Output directory for figures
fig_dir = Path('./notebook_figs/classification')
fig_dir.mkdir(parents=True, exist_ok=True)

# Location of the analysis pipeline outputs
results_dir = Path('../results/').resolve()


## Reading in the Data <a id='reading_in_data'></a>

To save time later on, we read in all of the necessary data in advance. We start with spectroscopic measurements and classifications from external publications. This will allow us to perform a sanity check on our results.


In [None]:
# Read in the SDSS spectroscopic classifications published by Sako et al. 2018
sdss_master = sako18.load_table('master')
sako_classification = pd.DataFrame(
    {'obj_id': sdss_master['CID'], 'spec_class': sdss_master['Classification']}
).set_index('obj_id')
sako_classification.head()


In [None]:
folatelli_13 = dr1.load_table(6).to_pandas()

# Si II pEW values from Branch et al. 2006, one row per SN so that each id
# stays next to its own measurements: (obj_id, pw7 [λ6355], pw6 [λ5972])
_branch_06_rows = [
    ('1981B', 127, 17),
    ('1984A', 204, 23),
    ('1986G', 126, 33),
    ('1989B', 124, 20),
    ('1990N', 88, 12),
    ('1991M', 137, 19),
    ('1991T', 29, 0),
    ('1991bg', 92, 49),
    ('1992A', 107, 19),
    ('1994D', 96, 19),
    ('1994ae', 89, 7),  # Previously mistyped as '1194ae'
    ('1996X', 87, 17),
    ('1997cn', 101, 45),
    ('1998ag', 78, 12),
    ('1998bu', 94, 16),
    ('1999aw', 58, 1),
    ('1999by', 95, 46),
    ('1999ee', 82, 5),
    ('2000cx', 51, 2),
    ('2001ay', 150, 8),
    ('2001el', 95, 16),
    ('2002bf', 171, 10),
    ('2002bo', 146, 11),
    ('2002cx', 18, 0),
]
branch_06 = pd.DataFrame(_branch_06_rows, columns=['obj_id', 'pw7', 'pw6'])

branch_06.set_index('obj_id', inplace=True)
branch_06.head()


Next we read in spectroscopic measurements from our own analysis pipeline.

In [None]:
def _julian_date_scalar(date):
    """
    Convert a date string into a Julian date.

    Args:
        date (str): The date to convert in %Y-%m-%d format

    Returns:
        The Julian date (at 00:00 UT) as a float
    """

    # Midnight of 0001-01-01 (proleptic Gregorian ordinal 1) is JD 1721425.5.
    # Unlike a closed-form calendar formula, this holds for every date
    # ``datetime`` can represent.
    ordinal = datetime.strptime(date, '%Y-%m-%d').toordinal()
    return ordinal + 1721424.5


# ``otypes`` is given explicitly so an empty input returns an empty float array
# instead of raising (np.vectorize otherwise calls the function on the first
# element to infer the output type).
calc_julian_date = np.vectorize(_julian_date_scalar, otypes=[float])


def read_spec_results(path):
    """Load pEW measurements produced by the spectroscopic analysis pipeline

    Adds a ``jd`` column (the observation date as a Julian date) and a
    ``days`` column (observation JD minus the JD of peak r-band brightness
    taken from the Sako et al. 2018 spectroscopic master table).

    Args:
        path (str): Path of the ecsv file written by the pipeline

    Returns:
        A Pandas DataFrame indexed by ``obj_id`` and ``feat_name``
    """

    results = Table.read(path).to_pandas()
    results.set_index(['obj_id', 'feat_name'], inplace=True)
    results['jd'] = calc_julian_date(results.date)

    # Warnings raised while loading the master table are not relevant here
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        master = (
            sako18spec.load_table('master')
            .to_pandas()
            .rename(columns={'CID': 'obj_id'})
            .set_index('obj_id')
        )

    # MJD -> JD, then days relative to maximum
    peak_jd = master.MJDatPeakrmag + 2400000.5
    results['days'] = results.jd - peak_jd
    return results


In [None]:
spec_results_path = results_dir / 'spec_class/sdss_sako18spec_rv3_1_bin5.0_methgauss_step5.ecsv'
spec_results = read_spec_results(spec_results_path)
spec_results.head()


Finally, we read in the photometric fit results and determine the resulting classification coordinates using the `classify_targets` function. We include results from both the band-by-band and collective fitting methods, as well as fit results from both `iminuit` and `emcee`. This may take a minute since the `classify_targets` function isn't very well optimized.

In [None]:
from phot_class.classification import classify_targets


def _read_fits_and_classify(path):
    """Read light-curve fit results and compute their classification coordinates

    Args:
        path (Path): The ecsv file of fit results to read

    Returns:
        Classification coordinates as a DataFrame indexed by ``obj_id``
        Fit results as a DataFrame indexed by ``source`` and ``obj_id``
    """

    fits_table = Table.read(path)
    class_df = classify_targets(fits_table).to_pandas().set_index(['obj_id'])
    fits_df = fits_table.to_pandas().set_index(['source', 'obj_id'])
    return class_df, fits_df


# Use the shared ``results_dir`` rather than re-spelling the relative path
mcmc_coll_class, mcmc_coll_fits = _read_fits_and_classify(
    results_dir / 'collective_fits/with_ext/sdss_sako18_mcmc_fit_fits.ecsv')

mcmc_band_class, mcmc_band_fits = _read_fits_and_classify(
    results_dir / 'band_fits/with_ext/sdss_sako18_mcmc_fit_fits.ecsv')

iminuit_coll_class, iminuit_coll_fits = _read_fits_and_classify(
    results_dir / 'collective_fits/with_ext/sdss_sako18_simple_fit_fits.ecsv')

iminuit_band_class, iminuit_band_fits = _read_fits_and_classify(
    results_dir / 'band_fits/with_ext/sdss_sako18_simple_fit_fits.ecsv')

iminuit_band_fits.head()


## Spectroscopic Classification <a id='spectroscopic_classification'></a>

This section assigns Branch et al. 2006 style subtypes to SDSS spectra. This classification scheme relies on the pseudo-equivalent width (pEW) of Si II at λ5972 vs λ6355 (hereafter referred to as features pW6 and pW7).

#### Section Contents:
1. <a href='#spectral_binning'>Spectral Binning</a>: A reminder of some of the manipulations performed on each spectrum during the analysis.
1. <a href='#selecting_good_data'>Selecting the "Good" Data</a>: Selects only measurements near peak brightness and drops results from noisy spectra.
1. <a href='#branch'>Branch Classification Plot</a>: Plots Si ii λ5972 vs Si ii λ6355 and assigns subtypes to each SN.

### Spectral Binning <a id='spectral_binning'></a>

We pause for a moment to remind ourselves that the values in this notebook are calculated after each spectrum is shifted to the rest frame, corrected for Milky Way extinction, and binned to a resolution of five angstroms (unless otherwise noted). We use a Gaussian filter to reduce the resolution, although averaging or summing within each bin is also available via the `method` argument.


In [None]:
data = sako18spec.get_data_for_id('5635')
wave = data['wavelength']
flux = data['flux']
z = data.meta['z']
ra = data.meta['ra']
dec = data.meta['dec']

# Bin the flux using a Gaussian filter
bin_wave, bin_flux = spec_class.bin_spectrum(wave, flux, method='gauss')

# Correct for extinction and shift to rest frame
rest_wave, rest_flux = spec_class.correct_extinction(bin_wave, bin_flux, ra, dec, z)

# Offset each version of the spectrum vertically so all three are visible
plt.figure(figsize=(16, 8))
plt.plot(wave, flux, linewidth=.5, label='Original')
plt.plot(bin_wave, bin_flux + 5e-17, linewidth=1, label='Binned Spectrum')
plt.plot(rest_wave, rest_flux + 1e-16, linewidth=.5, label='Rest Framed Spectrum')

# ``bottom`` replaces the deprecated ``ymin`` keyword
plt.ylim(bottom=0)
plt.ylabel('Flux')
plt.xlabel(r'Wavelength ($\AA$)')
plt.legend()


### Selecting the "Good" Data <a id='selecting_good_data'></a>

We start by checking the reasons why some of our calculations failed. Note that range and index related errors are user generated errors during the analysis used to indicate cases where the spectrum was too noisy to identify the feature.



In [None]:
spec_results.loc[spec_results['pew'].isna(), 'msg'].value_counts()


We also check the number of observations available for each target. The data release includes targets with multiple observations, however, depending on how the analysis pipeline was run, duplicate observations may have already been dropped.

In [None]:
# Count the distinct spectra measured for each object. Each spectrum has one
# row per measured feature, so counting rows per spectrum id would only give
# the number of features (i.e. always the same value).
# NOTE(review): assumes ``sid`` identifies individual spectra — confirm.
spectra_per_object = spec_results.groupby(level='obj_id').sid.nunique()
spectra_per_object.hist()
plt.title('Number of spectra per object (Not including host)')
plt.ylabel('Number of Targets')
plt.xlabel('Number of Spectra')
plt.show()


Moving forward we drop results from spectra that:
1. Are not measured from the spectrum taken closest to peak brightness.
1. Do not have pW7 measurements due to their spectral range.
1. Fail visual inspection

We start with the first two conditions:

In [None]:
def get_tmax_pew(spec_data):
    """Keep only pEW measurements from the spectrum taken nearest peak brightness

    For each object and feature, the measurement with the smallest absolute
    ``days`` value is kept. Measurements with a missing ``pew`` are then
    dropped, and only objects that still have a ``pW7`` measurement are
    returned.

    Args:
        spec_data (DataFrame): Measurements from the analysis pipeline indexed
            by ``obj_id`` and ``feat_name`` with ``days`` and ``pew`` columns

    Returns:
        A pandas DataFrame indexed by ``obj_id`` and ``feat_name``
    """

    # Determine what features were measured
    features = spec_data.index.get_level_values('feat_name').unique()

    data_frames = []
    for feat_name in features:
        # Copy so adding the feature column back never writes into a view of
        # ``spec_data`` (avoids SettingWithCopy issues)
        feat_data = spec_data.xs(feat_name, level='feat_name').copy()
        feat_data['feat_name'] = feat_name
        feat_data = feat_data.set_index('feat_name', append=True)

        # A stable sort keeps ties (e.g. -2 and +2 days) in their original
        # order, so every feature resolves a tie to the same spectrum
        feat_data['sort'] = feat_data.days.abs()
        feat_data = feat_data.sort_values(by='sort', kind='mergesort')
        feat_data = feat_data[~feat_data.index.duplicated(keep='first')]
        data_frames.append(feat_data.drop(columns='sort'))

    all_data = pd.concat(data_frames).dropna(subset=['pew'])

    # Keep both features, but only for objects with a pW7 measurement
    return all_data.loc[all_data.xs('pW7', level='feat_name').index]


In [None]:
spec_results_peak = get_tmax_pew(spec_results)
spec_results_peak.head()


Next we perform a visual inspection to drop results from any spectra that are particularly noisy. The analysis pipeline already requires the visual inspection of each spectrum, and many noisy spectra may have already been flagged and skipped over. However, the goal of the initial inspection is to be overly ambitious in what we can measure while acknowledging a secondary cut is required later on. We perform that cut here. We also drop spectra with H$\alpha$ lines, since that indicates a potential problem with the host galaxy subtraction.

The following few cells will plot each spectrum and ask for an input. Valid inputs are as follows:
1. `<Enter>` indicates a good spectrum. 
1. A number replots the spectrum with a new upper bound
1. `n` means the spectrum is too noisy.
1. `h` Indicates a strong H-alpha line
1. `r` Indicates the spectrum does not cover enough of the necessary wavelength range.

The `start_from` variable can be used to start the iteration process from a given object Id. Results generated by the notebook author are hardcoded below.

In [None]:
# Object ids flagged during visual inspection:
# significant H-alpha, noisy spectra, and incomplete wavelength range
ha, noisy, bad_range = [], [], []

# Object id to resume the inspection from. '99999' is higher than any
# object Id, so the default skips straight to the end.
start_from = '99999'


In [None]:
from IPython.display import clear_output

indices = spec_results_peak.index.get_level_values('obj_id').unique()
for i, idx in enumerate(sorted(indices)):
    # Object ids are strings, so this comparison (and the sort above) is lexicographic
    if idx < start_from:
        continue

    # Replot the spectrum until it is classified or accepted with <Enter>.
    # Any other input is treated as a new upper limit for the y-axis;
    # unrecognized input replots with the current limit instead of crashing.
    ymax = 5e-17
    while True:
        clear_output()
        # NOTE: ``plot_outliers`` is defined in a later cell and must be run first
        _insp_fig, _insp_axes = plot_outliers(spec_results_peak.loc[[idx]])
        _insp_axes[0].set_ylim(0, ymax)
        plt.show()
        plt.close(_insp_fig)  # Don't accumulate one open figure per spectrum

        inp = input(f'{i}/{len(indices)} ({i / len(indices) * 100:.2f}%)')
        if not inp:
            break

        if inp == 'h':
            ha.append(idx)
            break

        if inp == 'n':
            noisy.append(idx)
            break

        if inp == 'r':
            bad_range.append(idx)
            break

        try:
            ymax = float(inp)
        except ValueError:
            continue

    # Record progress so the inspection can be resumed via ``start_from``
    last_index = idx
    clear_output()
 

Here are results tabulated by the notebook author.

In [None]:
# Spectra with a significant H-alpha line (possible host subtraction issues)
ha = [
    '10096', '10805', '12844', '12856', '12874', '12927', '12950', '12977', 
    '13025', '13070', '13099', '13152', '13174', '13254', '13354', '13467', 
    '13830', '14261', '14279', '15129', '15136', '15213', '15234', '15467', 
    '16099', '16578', '16637', '16692', '16776', '16847', '17176', 
    '17220', '17389', '17568', '17605', '18375', '18485', '18697', '18903', 
    '19003', '19008', '19207', '19353', '19626', '19775', '19969', '20142', 
    '20245', '20528', '21062', '2330', '2635', '2789', '3901', '4524', '5717', 
    '5751', '6057', '6108', '6249', '6773', '7876', '8921'
]

# Spectra that are too noisy
noisy = [
    '11067', '11300', '11452', '11557', '1166', '13610', '13655', '13689', 
    '13736', '14421', '15203', '15229', '15287', '15356', '15383', '15456', 
    '16072', '16350', '16414', '16789', '17048', '18749', '18855', '18890', 
    '18927', '18959', '19230', '20106', '20144', '20184', '20227', '20345', 
    '20432', '20581', '20768', '21034', '21510', '2372', '2533', '2689', '3080', 
    '3199', '3452', '4577', '4679', '6127', '6137', '744', '7475', '7512', '762', 
    '774', '7947', '8598', '9457', '7143'
]

# Spectra that don't cover enough of the necessary wavelength range
bad_range = ['16116', '17171', '17208', '17215', '19616', '19658']


In [None]:
bad_obj_ids = set(ha + noisy + bad_range)

# Drop every feature measurement of each flagged object
good_spectra = spec_results_peak.drop(sorted(bad_obj_ids), level='obj_id')

# The index holds (obj_id, feat_name) pairs, so count unique objects explicitly
num_objects = spec_results_peak.index.get_level_values('obj_id').nunique()
print(f'{len(bad_obj_ids)} bad spectra out of {num_objects}')


In [None]:
def plot_all_spectra(
        spec_measurements, xlim=(5000, 7000), space_scale=1.5, num_columns=2):
    """Plot the spectra used for pEW measurements, stacked with vertical offsets

    Each spectrum is extinction corrected, shifted to the rest frame, binned,
    and normalized by its flux at the red edge of the pW7 feature before
    being offset from its neighbors.

    Args:
        spec_measurements (DataFrame): Measurements indexed by ``obj_id`` and
            ``feat_name``; the ``date`` column selects which spectrum to plot
        xlim              (tuple): Wavelength range to display
        space_scale       (float): Vertical offset between consecutive spectra
        num_columns         (int): Number of subplot columns

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    fig, axes = plt.subplots(1, num_columns, figsize=(8.5, 11), sharex=True)
    axes = np.atleast_1d(axes)  # A single column returns a bare Axes

    # Divide object ids into separate collections for each figure column
    spec_to_plot = list(spec_measurements.index.get_level_values('obj_id').unique())
    spec_per_col = int(np.ceil(len(spec_to_plot) / num_columns))
    spectra_cols = [
        spec_to_plot[i * spec_per_col: (i + 1) * spec_per_col]
        for i in range(num_columns)
    ]

    for obj_ids_in_column, axis in zip(spectra_cols, axes):
        axis.set_xlim(xlim)
        axis.set_xlabel('Wavelength')
        if not obj_ids_in_column:
            continue

        yticks = []
        for i, obj_id in enumerate(obj_ids_in_column):
            target_data = sako18spec.get_data_for_id(obj_id)
            target_measurements = spec_measurements.loc[obj_id]
            offset = i * space_scale

            # Keep only the SN spectra used in the classification
            date = target_measurements.date.iloc[0]
            target_data = target_data[target_data['date'] == date]
            target_data = target_data[target_data['type'] != 'Gal']

            # Correct for extinction, shift to rest frame, and bin spectrum
            wave = target_data['wavelength']
            flux = target_data['flux']
            z = target_data.meta['z']
            ra = target_data.meta['ra']
            dec = target_data.meta['dec']

            bin_wave, bin_flux = spec_class.bin_spectrum(wave, flux)
            rest_wave, rest_flux = spec_class.correct_extinction(
                bin_wave, bin_flux, ra, dec, z)

            # Scale and offset spectra to same order of magnitude
            _, pw7_end = spec_class.guess_feature_bounds(
                rest_wave, rest_flux, spec_class.line_locations['pW7'])
            scale = rest_flux[np.where(rest_wave == pw7_end)[0][0]]
            rest_flux /= scale
            rest_flux += offset

            axis.plot(rest_wave, rest_flux, lw=1, color='k')

            # Place the object Id tick label at the flux nearest the left edge
            min_idx = np.argmin(abs(rest_wave - xlim[0]))
            yticks.append(rest_flux[min_idx])

        axis.set_ylim(0, len(obj_ids_in_column) * space_scale)
        axis.set_yticks(yticks)
        axis.set_yticklabels(obj_ids_in_column)

    axes[0].set_ylabel('Candidate Id')
    fig.tight_layout()
    return fig, axes


In [None]:
obj_ids = sorted(good_spectra.index.get_level_values('obj_id').unique())

# Break object Ids into chunks and create one figure per chunk. The last
# figure also receives any remainder left over by the integer division.
num_figs = 5
chunk_size = len(obj_ids) // num_figs
for fig_num in range(num_figs):
    start = fig_num * chunk_size
    stop = len(obj_ids) if fig_num == num_figs - 1 else start + chunk_size
    fig_ids = obj_ids[start:stop]
    if not fig_ids:
        continue

    plot_all_spectra(good_spectra.loc[fig_ids])
    plt.show()


### Branch Classification Plot <a id='branch'></a>

We plot the pEW of Si ii at λ5972 vs λ6355. 


In [None]:
def get_colors(pw6, pw7):
    """Get the color of each point based on its Branch subtype

    Boundaries match the subtypes drawn by ``plot_si_ratio``:
    blue (Cool) for pw6 > 30, red (Broad Line) for pw6 <= 30 and pw7 > 105,
    green (Shallow Silicon) for pw7 < 70 (takes precedence), and black
    (Core Normal) otherwise, including any NaN measurements.

    Args:
        pw6 (ndarray): Array of EW measurements for feature 6
        pw7 (ndarray): Array of EW measurements for feature 7

    Returns:
        A 1d array of matplotlib color names
    """

    pw6 = np.asarray(pw6, dtype=float)
    pw7 = np.asarray(pw7, dtype=float)

    color = np.full(len(pw6), 'black', dtype='U10')
    color[pw6 > 30] = 'blue'
    color[(pw6 <= 30) & (pw7 > 105)] = 'red'
    color[pw7 < 70] = 'green'
    return color

def subplot_published_classes(axis):
    """Overlay published Si II pEW values (λ6355 on x, λ5972 on y)

    Branch et al. 2006 values are drawn as diamonds and Folatelli et al. 2013
    (CSP) values as triangles, each edge-colored by ``get_colors``.

    Args:
        axis (Axis): A matplotlib axis
    """

    published = (
        (branch_06.pw7, branch_06.pw6, 'D', 'Branch 2006'),
        (folatelli_13.pW7, folatelli_13.pW6, 'v', 'Folatelli 2013'),
    )

    for pw7_values, pw6_values, marker_style, source_label in published:
        axis.scatter(
            pw7_values,
            pw6_values,
            marker=marker_style,
            facecolor='none',
            edgecolor=get_colors(pw6_values, pw7_values),
            zorder=2,
            alpha=.5,
            label=source_label,
        )
    
def get_pew_above_snr(df, snr, max_days=7):
    """Return pW6/pW7 measurements for objects where both exceed a given SNR

    Only measurements taken within ``max_days`` of peak brightness are
    considered. Objects are kept only if both features pass the cut.

    Args:
        df       (DataFrame): Spectroscopic pipeline results indexed by
            ``obj_id`` and ``feat_name``
        snr          (float): Minimum signal to noise ratio (pew / pew_samperr)
        max_days     (float): Maximum absolute days from peak brightness

    Returns:
        A DataFrame indexed by ``obj_id`` with columns ``pew_pw6``,
        ``pew_samperr_pw6``, ``pew_pw7``, ``pew_samperr_pw7`` and ``date``
    """

    near_peak = df[df.days.abs() < max_days]
    good_snr = near_peak[near_peak.pew / near_peak.pew_samperr >= snr]

    # Boolean masks (rather than ``xs``) so an empty selection doesn't raise
    feat_names = good_snr.index.get_level_values('feat_name')
    pw6 = good_snr[feat_names == 'pW6'].droplevel('feat_name')[['pew', 'pew_samperr']]

    # Take the date from the pW7 row so each object appears exactly once
    # (joining the MultiIndexed ``df.date`` would repeat rows per feature)
    pw7 = good_snr[feat_names == 'pW7'].droplevel('feat_name')[['pew', 'pew_samperr', 'date']]

    return pw6.join(pw7, how='inner', lsuffix='_pw6', rsuffix='_pw7')
    

def plot_si_ratio(spec_data, snr_ratios=(1, 2, 3), plot_external_data=True):
    """Plot the pW6 vs pW7 silicon pEW values for different SNR cutoffs

    Each subplot classifies points into Branch subtypes that do not overlap:
    Shallow Silicon (pw7 < 70), Cool (pw6 > 30, pw7 >= 70),
    Broad Line (pw6 <= 30, pw7 > 105) and Core Normal
    (pw6 <= 30, 70 <= pw7 <= 105).

    Args:
        spec_data     (DataFrame): Measurements from the analysis pipeline
        snr_ratios        (tuple): SNR cutoffs to plot
        plot_external_data (bool): Whether to plot data from CSP and Branch 2006

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    # Keep only Type Ia spectra
    si_data = spec_data[spec_data.type.isin(['Ia', 'Ia-pec', 'Ia?'])].dropna(subset=['pew'])

    num_subplots = len(snr_ratios)
    fig, axes = plt.subplots(1, num_subplots, figsize=(num_subplots * 6, 6), sharex=True, sharey=True)
    flat_ax = [axes] if num_subplots == 1 else axes.flatten()
    for snr, axis in zip(snr_ratios, flat_ax):

        # Keep only data with both pw6 and pw7 measurements at or above snr
        plot_data = get_pew_above_snr(si_data, snr)

        axis.errorbar(
            x=plot_data.pew_pw7,
            y=plot_data.pew_pw6,
            xerr=plot_data.pew_samperr_pw7,
            yerr=plot_data.pew_samperr_pw6,
            linestyle='',
            ecolor='grey',
            color='grey',
            alpha=.3,
            zorder=0)

        if plot_external_data:
            subplot_published_classes(axis)

        # Isolate each subtype. The masks are mutually exclusive and cover
        # every finite point, so nothing is drawn twice or silently dropped
        # (e.g. pw6 == 30 exactly).
        pw6, pw7 = plot_data.pew_pw6, plot_data.pew_pw7
        ss = plot_data[pw7 < 70]
        cl = plot_data[(pw6 > 30) & (pw7 >= 70)]
        bl = plot_data[(pw6 <= 30) & (pw7 > 105)]
        cn = plot_data[(pw6 <= 30) & (pw7 >= 70) & (pw7 <= 105)]

        axis.scatter(cl.pew_pw7, cl.pew_pw6, color='C0', zorder=1, label='Cool')
        axis.scatter(bl.pew_pw7, bl.pew_pw6, color='C3', zorder=1, label='Broad Line')
        axis.scatter(ss.pew_pw7, ss.pew_pw6, color='C2', zorder=1, label='Shallow Silicon')
        axis.scatter(cn.pew_pw7, cn.pew_pw6, color='k', zorder=1, label='Core Normal')

        axis.set_xlabel(r'Si ii $\lambda$6355', fontsize=14)
        # The SNR cut above is inclusive, so label it as such
        axis.set_title(rf'SNR $\geq$ {snr}')

    flat_ax[0].set_xlim(-10, 500)
    flat_ax[0].set_ylim(-10, 200)
    flat_ax[0].set_ylabel(r'Si ii $\lambda$5972', fontsize=14)
    flat_ax[-1].legend()
    return fig, flat_ax


Plotting the subtypes, we note that using a Gaussian filter to reduce noise makes the error bars artificially small. Values are plotted against results from existing publications.

In [None]:
fig, axes = plot_si_ratio(good_spectra, snr_ratios=(5, 10, 15, 20))
fig.suptitle('Branch Si II Subtypes')
plt.savefig(fig_dir / 'branch_subtypes.pdf')
plt.show()


For curiosity's sake, we investigate the furthest outliers in the first plot.

In [None]:
def get_outliers(spec_data, pw6_cutoff, pw7_cutoff, snr_cutoff=0):
    """Select measurements of objects with pw6 pew > pw6_cutoff and pw7 pew > pw7_cutoff

    Args:
        spec_data (DataFrame): Data to select on, indexed by ``obj_id`` and
            ``feat_name``
        pw6_cutoff    (float): The PEW cutoff for feature pw6
        pw7_cutoff    (float): The PEW cutoff for feature pw7
        snr_cutoff    (float): SNR cutoff to apply before selection

    Returns:
        A DataFrame with both feature measurements of every matching object,
        ordered by object id (empty if nothing matches)
    """

    good_snr = spec_data[spec_data.pew / spec_data.pew_samperr > snr_cutoff]

    obj_ids = good_snr.index.get_level_values('obj_id')
    feat_names = good_snr.index.get_level_values('feat_name')
    pew = good_snr.pew.to_numpy()

    # An object qualifies only if both of its features pass their cutoffs
    pw6_ids = set(obj_ids[(feat_names == 'pW6') & (pew > pw6_cutoff)])
    pw7_ids = set(obj_ids[(feat_names == 'pW7') & (pew > pw7_cutoff)])
    outliers = sorted(pw6_ids & pw7_ids)

    if not outliers:
        return good_snr.iloc[:0]

    return good_snr.loc[outliers]

def subplot_feature_pew(wave, flux, axis, feat_start, feat_end, **kwargs):
    """Shade the area between a feature's flux and its continuum

    The feature bounds must be exact elements of ``wave``. The continuum is
    computed by ``spec_class.feature_pew`` over the inclusive range.

    Args:
        wave     (ndarray): The spectrum's wavelengths
        flux     (ndarray): The flux for each wavelength
        axis        (Axis): The axis to plot on
        feat_start (float): Where the feature starts
        feat_end   (float): Where the feature ends
        Any other kwargs for ``axis.fill_between``
    """

    first = np.where(wave == feat_start)[0][0]
    stop = np.where(wave == feat_end)[0][0] + 1  # Include the end point
    feature_wave, feature_flux = wave[first:stop], flux[first:stop]

    continuum, _, _ = spec_class.feature_pew(feature_wave, feature_flux)
    axis.fill_between(feature_wave, feature_flux, continuum, alpha=.75, zorder=0, **kwargs)
    

def plot_outliers(outlier_data):
    """Plot the spectra behind a collection of pEW measurements

    Each object gets its own subplot showing the rest-framed spectrum, the
    Gaussian-binned spectrum, and the shaded pEW of every measured feature.

    Args:
        outlier_data (DataFrame): PEW measurements indexed by ``obj_id`` and
            ``feat_name`` with ``date``, ``feat_start`` and ``feat_end`` columns

    Returns:
        A matplotlib Figure
        An array of matplotlib axes

    Raises:
        ValueError: If ``outlier_data`` is empty
    """

    obj_ids = outlier_data.index.get_level_values('obj_id').unique()
    if len(obj_ids) == 0:
        raise ValueError('No measurements to plot')

    fig, axes = plt.subplots(len(obj_ids), 1, figsize=(15, 4 * len(obj_ids)), sharex=True)
    axes = np.atleast_1d(axes)

    for obj_id, axis in zip(obj_ids, axes):
        obj_measurements = outlier_data.loc[obj_id]
        target_data = sako18spec.get_data_for_id(obj_id)
        target_data.sort('wavelength')

        # Keep only the SN spectra used in the classification. Only the date
        # is needed, so objects missing one feature (e.g. pW6) still plot.
        date = obj_measurements.date.iloc[0]
        target_data = target_data[target_data['date'] == date]
        target_data = target_data[target_data['type'] != 'Gal']

        # Correct for extinction and shift to rest frame, then bin
        wave = target_data['wavelength']
        flux = target_data['flux']
        z = target_data.meta['z']
        ra = target_data.meta['ra']
        dec = target_data.meta['dec']
        rest_wave, rest_flux = spec_class.correct_extinction(wave, flux, ra, dec, z)
        bin_wave, bin_flux = spec_class.bin_spectrum(rest_wave, rest_flux, method='gauss')

        axis.plot(rest_wave, rest_flux, lw=1, label='Restframed Spectrum', color='grey', alpha=.75)
        for feat_name, row in obj_measurements.iterrows():
            subplot_feature_pew(bin_wave, bin_flux, axis, row.feat_start, row.feat_end, label=feat_name)

        axis.plot(bin_wave, bin_flux, lw=2, label='Binned Spectrum', color='k')
        axis.set_xlabel('Wavelength')
        axis.set_ylabel('Flux')
        axis.set_ylim(0, 3e-17)
        axis.set_title(f'Object Id: {obj_id}')

    # Axes share x, so one legend and one x-range on the bottom plot suffice
    axes[-1].legend()
    axes[-1].set_xlim(3000, 7500)

    return fig, axes


In [None]:
# Objects with a large pW6 pEW (no cut applied to pW7)
outliers = get_outliers(good_spectra, pw6_cutoff=75, pw7_cutoff=0)
outliers


In [None]:
_ = plot_outliers(outlier_data=outliers)  # Discard the returned figure and axes


We also plot a few of the dropped spectra

In [None]:
fig, axes = plot_outliers(spec_results_peak.loc[['6304', '7947', '7876']])

# Each spectrum has a very different flux scale, so set limits individually
for axis, ymax in zip(axes, (.03, 6e-17, 4e-16)):
    axis.set_ylim(0, ymax)
    axis.set_xlabel('')

axes[-1].set_xlabel(r'Wavelength ($\AA$)')

fig.savefig(fig_dir / 'Dropped_spectra.pdf')


## Photometric Classification <a id='photometric_classification'></a>

With the spectroscopic classifications in hand, we move on to the photometric data.

#### Section Contents:
1. <a href='#exploration_of_failed_fits'>Exploration of Failed Fits</a>: A preliminary exploration of the data.
1. <a href='#classification'>Classification</a>: Classifying photometrically observed targets.
1. <a href='#photometric_vs_spectroscopic'>Photometric V.S. Spectroscopic Classifications</a>: Compares the photometric and spectroscopic results


### Exploration of Failed Fits <a id='exploration_of_failed_fits'></a>

We perform a cursory investigation of any fits that have failed to converge. To start, we note the unique error messages raised in the band and collective fit results.

In [None]:
def get_failed_fits(fits_df):
    """Select the fits whose message reports a failure

    Args:
        fits_df (DataFrame): Fit results with a ``message`` column

    Returns:
        The rows of ``fits_df`` whose message contains "failed" (any case)
    """

    # ``na=False`` treats missing messages as non-failures instead of
    # producing NaNs that cannot be used as a boolean mask
    failed_fits = fits_df.message.str.contains('failed', case=False, na=False)
    return fits_df[failed_fits]
    

In [None]:
_fit_results_by_label = {
    'Iminuit band fit': iminuit_band_fits,
    'Iminuit collective fit': iminuit_coll_fits,
    'MCMC band fit': mcmc_band_fits,
    'MCMC collective fit': mcmc_coll_fits,
}

for _position, (_fit_label, _fit_df) in enumerate(_fit_results_by_label.items()):
    _separator = '\n' if _position else ''  # Blank line between sections
    print(f'{_separator}{_fit_label} error messages:\n')
    print(get_failed_fits(_fit_df).message.unique())

The SNR error is not concerning so long as it occurs an equal number of times in the band and collective fits.

In [None]:
snr_err_msg = 'No data points with S/N > 5.0. Initial guessing failed.'

# Boolean masks of fits that failed only because of low S/N
band_snr_indices = iminuit_band_fits['message'] == snr_err_msg
collective_snr_indices = iminuit_coll_fits['message'] == snr_err_msg

num_band_snr_errors = sum(band_snr_indices)
num_collective_snr_errors = sum(collective_snr_indices)
equal_errors = num_band_snr_errors == num_collective_snr_errors

print('Equal Number of SNR errors:', equal_errors)


We drop the SNR errors for now and look at the distribution of the remaining errors across bands and models.

In [None]:
# Exclude fits that failed only due to the low S/N error
iminuit_band_fits_goodsnr = iminuit_band_fits.loc[iminuit_band_fits['message'] != snr_err_msg]
iminuit_collective_fits_goodsnr = iminuit_coll_fits.loc[iminuit_coll_fits['message'] != snr_err_msg]


In [None]:
band_failed_nosnr = get_failed_fits(iminuit_band_fits_goodsnr)

print('Band by band error distribution\n')

# Second to last character of the band name
print('By fitted band (set):')
failed_band_keys = band_failed_nosnr.band.str[-2]
print(failed_band_keys.value_counts())

print('\nNumber of failed fits per source:')
failures_by_source = band_failed_nosnr.droplevel(1).index
print(failures_by_source.value_counts())

print('\nNumber of failed fits per object (num_failures number_targets):')
failures_per_object = band_failed_nosnr.index.value_counts()
print(failures_per_object.value_counts())


### Classification <a id='classification'></a>

We apply the classification to the fitted light curves.

In [None]:
def calc_delta_chisq(fits_df):
    """Calculate the difference in reduced chisq between the two models' overall fits

    Only rows with ``band == 'all'`` are used. The result is the Hsiao reduced
    chi-squared minus the SN91bg reduced chi-squared.

    Args:
        fits_df (DataFrame): Pipeline fit results indexed by ``source`` then
            ``obj_id``

    Returns:
        A pandas series indexed by ``obj_id``
    """

    def reduced_chisq(source):
        source_fits = fits_df.loc[source]
        all_band_fits = source_fits[source_fits.band == 'all']
        return all_band_fits.chisq / all_band_fits.ndof

    return reduced_chisq('hsiao_x1') - reduced_chisq('sn91bg')

def plot_chisq_scatter(fits_df):
    """Scatter the reduced chisq of the Hsiao fits against the SN91bg fits

    Points are grouped by the sign of ``calc_delta_chisq``. A dashed one-to-one
    line is drawn and both axes are log scaled.

    Args:
        fits_df (DataFrame): DataFrame of fit results indexed by ``source``
            and ``obj_id``

    Returns:
        A matplotlib figure
        A matplotlib axis
    """

    delta_chi = calc_delta_chisq(fits_df)
    groups = (
        (delta_chi[delta_chi < 0].index, r'$\Delta\chi^2 < 0$'),
        (delta_chi[delta_chi > 0].index, r'$\Delta\chi^2 > 0$'),
    )

    # Reduced chisq of the collective ('all' band) fit for each model
    collective_fits = fits_df[fits_df.band == 'all']
    reduced = {}
    for column, source in (('chisq_hs', 'hsiao_x1'), ('chisq_bg', 'sn91bg')):
        model_fits = collective_fits.loc[source]
        reduced[column] = model_fits.chisq / model_fits.ndof

    chisq = pd.DataFrame(reduced)

    fig, axis = plt.subplots(1, 1, figsize=(7 / 2, 7 / 2))
    for group_index, group_label in groups:
        subset = chisq.reindex(group_index)
        axis.scatter(subset.chisq_hs, subset.chisq_bg, s=5, alpha=.2, label=group_label)

    one_to_one = axis.get_ylim()
    axis.plot(one_to_one, one_to_one, linestyle='--', color='grey')
    axis.set_xscale('log')
    axis.set_yscale('log')
    axis.set_xlabel(r'Reduced Hsiao $\chi^2$')
    axis.set_ylabel(r'Reduced SN91bg $\chi^2$')
    axis.legend(framealpha=1)
    return fig, axis


In [None]:
chisq_scat_fig, chisq_scat_axis = plot_chisq_scatter(iminuit_coll_fits)

# Both axes are log scaled, so only the upper limits are set; a lower limit
# of zero is invalid on a log axis and would be ignored with a warning
chisq_scat_axis.set_xlim(right=1e3)
chisq_scat_axis.set_ylim(top=1e3)

chisq_scat_fig.savefig(fig_dir / 'chisq_scatter.pdf', bbox_inches='tight')
plt.show()


In [None]:
def get_sako_pec():
    """Get objects flagged as peculiar in Sako+ 2018

    Returns:
        Rows of the Sako+ 2018 master table (indexed by ``CID``) whose
        ``Notes`` value is greater than 1
    """

    # Warnings raised while loading the master table are not relevant here
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        master = sako18.load_table('master').to_pandas(index='CID')

    notes = master['Notes']
    with_notes = master.reindex(notes[notes.notna()].index)
    return with_notes[with_notes['Notes'] > 1]


def create_border_hist(axis, padding=0, xpos='top', ypos='right'):
    """Create axes for plotting border histograms

    Each histogram axis spans the full width (or height) of ``axis`` and has a
    thickness of 35% of the bordered dimension, separated from ``axis`` by
    ``padding`` in figure coordinates. The new axes are added to the same
    figure as ``axis`` and inherit its x (or y) limits.

    Args:
        axis     (Axis): The matplotlib axis to border
        padding (float): Spacing between the main and bordering axes
        xpos      (str): Put the x histogram on the 'top' or 'bottom'
        ypos      (str): Put the y histogram on the 'left' or 'right'

    Returns:
        Axis for the x (top or bottom) histogram
        Axis for the y (left or right) histogram

    Raises:
        ValueError: If ``xpos`` or ``ypos`` is not a supported position
    """

    if xpos not in ('top', 'bottom'):
        raise ValueError(f"xpos must be 'top' or 'bottom', not {xpos!r}")

    if ypos not in ('left', 'right'):
        raise ValueError(f"ypos must be 'left' or 'right', not {ypos!r}")

    fig = axis.figure
    axis_pos = axis.get_position()
    axis_width = axis_pos.x1 - axis_pos.x0
    axis_height = axis_pos.y1 - axis_pos.y0
    histx_height = .35 * axis_height
    histy_width = .35 * axis_width

    # Both sides use the same thickness (the bottom/left variants previously
    # also added ``padding`` to the histogram size)
    if xpos == 'top':
        histx_bottom = axis_pos.y1 + padding
    else:
        histx_bottom = axis_pos.y0 - padding - histx_height

    histx = fig.add_axes([axis_pos.x0, histx_bottom, axis_width, histx_height])
    histx.set_xlim(axis.get_xlim())
    histx.tick_params(direction='in', labelbottom=(xpos == 'bottom'))

    if ypos == 'right':
        histy_left = axis_pos.x1 + padding
    else:
        histy_left = axis_pos.x0 - padding - histy_width

    histy = fig.add_axes([histy_left, axis_pos.y0, histy_width, axis_height])
    histy.set_ylim(axis.get_ylim())
    histy.tick_params(direction='in', labelleft=(ypos == 'left'))

    return histx, histy


def plot_classification(class_df, fits_df, border_bins, padding, xpos='top', ypos='right', fig=None, axis=None):
    """Plot classification results with bordering histograms

    Objects are split into two colored populations by the sign of
    ``calc_delta_chisq(fits_df)``. Objects where it is exactly zero are not
    drawn. Objects flagged as peculiar by Sako+ 2018 (``get_sako_pec``) are
    removed from both populations and drawn instead as open black markers.
    Each ``Notes`` flag gets its own labelled marker style, so callers can
    add a legend.

    Args:
        class_df  (DataFrame): Classification coordinates (columns ``x`` and ``y``)
        fits_df   (DataFrame): DataFrame of fit results
        border_bins (ndarray): Bins for the border histograms; their min/max
                               also set the x-axis range
        padding       (float): Spacing between the main and bordering axes
        xpos            (str): Put the x histogram on the 'top' or 'bottom'
        ypos            (str): Put the y histogram on the 'left' or 'right'
        fig          (Figure): Optionally use an existing figure
        axis           (Axis): Optionally use an existing axis

    Returns:
        The matplotlib figure
        The main matplotlib axis
        A list with the x and y border histogram axes
    """

    if fig is None or axis is None:
        fig, axis = plt.subplots(1, 1, figsize=(7 / 2, 7 / 2))

    x_label = r'$\chi^2_{blue}$ (Ia) - $\chi^2_{blue}$ (91bg)'
    y_label = r'$\chi^2_{red}$ (Ia) - $\chi^2_{red}$ (91bg)'
    markers = {2: 's', 3: '^', 4: 'o', 5: 'v'}
    labels = {2: '91bg', 3: '00cx', 4: '02ci', 5: '02cx'}

    delta_chi = calc_delta_chisq(fits_df)
    chi_lt0 = delta_chi[delta_chi < 0]
    chi_gt0 = delta_chi[delta_chi > 0]

    # Peculiar objects are removed here and over-plotted separately below
    sako_pec = get_sako_pec()
    all_data_lt = class_df.reindex(chi_lt0.index).drop(sako_pec.index, errors='ignore')
    all_data_gt = class_df.reindex(chi_gt0.index).drop(sako_pec.index, errors='ignore')

    axis.scatter(all_data_lt.x, all_data_lt.y, s=10, alpha=.7)
    axis.scatter(all_data_gt.x, all_data_gt.y, s=10, color='C1', alpha=.7)

    axis.axvline(0, color='grey', linestyle='--')
    axis.axhline(0, color='grey', linestyle='--')
    axis.set_xlabel(x_label, fontsize=12, labelpad=10)
    axis.set_ylabel(y_label, fontsize=12)
    axis.set_xlim(min(border_bins), max(border_bins))
    axis.set_ylim(-50, 50)

    for flag_type, flag_data in sako_pec.groupby('Notes'):
        plt_data = class_df.reindex(flag_data.index)
        # Flags without a dedicated style get a generic (fillable) marker
        # instead of raising a KeyError
        marker = markers.get(flag_type, 'D')
        label = labels.get(flag_type, f'Notes = {flag_type:g}')
        axis.scatter(plt_data.x, plt_data.y, s=20,
                     marker=marker, zorder=9,
                     color='k', label=label, facecolor='none')

    histx, histy = create_border_hist(axis, padding, xpos=xpos, ypos=ypos)
    histx.hist([all_data_lt.x, all_data_gt.x], bins=border_bins, stacked=True)
    histy.hist([all_data_lt.y, all_data_gt.y], bins=border_bins,
               stacked=True, orientation='horizontal')

    histx.set_xlim(axis.get_xlim())
    histy.set_ylim(axis.get_ylim())
    histx.set_yscale('log')
    histy.set_xscale('log')
    return fig, axis, [histx, histy]


def plot_classification_with_subtypes():
    """Plot classification results next to Sako+ 2018 spectroscopic subtypes

    The top-right panel shows ``plot_classification`` for the iminuit
    band-by-band results, with border histograms. The other three panels
    show ``iminuit_coll_class`` coordinates for objects that the Sako+ 2018
    master table classifies as Ia, II and Ibc respectively. Points are
    colored by their exact ``Classification`` value.

    Returns:
        The matplotlib figure
        A 2x2 array of matplotlib axes
    """

    fig, axes = plt.subplots(2, 2, figsize=(7.5, 7.5), sharex=True, sharey=True)

    # NOTE(review): this panel uses the band-by-band fits while the other
    # panels use ``iminuit_coll_class``. Confirm the mix is intended.
    _, _, (histx, histy) = plot_classification(
        class_df=iminuit_band_class,
        fits_df=iminuit_band_fits,
        border_bins=np.arange(-100, 101, 5),  # inclusive upper edge so [95, 100] is binned
        padding=.01,
        fig=fig,
        axis=axes[0, 1]
    )

    x_label = r'$\chi^2_{blue}$ (Ia) - $\chi^2_{blue}$ (91bg)'
    y_label = r'$\chi^2_{red}$ (Ia) - $\chi^2_{red}$ (91bg)'
    axes[0, 0].set_ylabel(y_label)
    axes[0, 1].set_xlabel('')
    axes[0, 1].set_ylabel('')
    axes[1, 0].set_ylabel(y_label)
    axes[1, 0].set_xlabel(x_label)
    axes[1, 1].set_xlabel(x_label)

    # The border histogram counts are on log scales, where a limit of 0 is
    # invalid, so only cap the upper count. ``histy`` is horizontal, so its
    # counts run along its x-axis.
    histx.set_ylim(top=1e4)
    histy.set_xlim(right=1e4)

    categories = [
        ['SNIa', 'pSNIa', 'zSNIa', 'SNIa?'],
        ['SNII', 'pSNII', 'zSNII'],
        ['pSNIbc', 'zSNIbc', 'SNIc', 'SNIb', 'SNIbc']
    ]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sdss_pd = sdss_master.to_pandas()
        sdss_pd['obj_id'] = sdss_pd['CID']
        sdss_pd.set_index('obj_id', inplace=True)
        class_with_sdss = iminuit_coll_class.join(sdss_pd)

    for axis, cat in zip(axes.flatten()[[0, 2, 3]], categories):
        # Reference lines and legend are drawn once per panel, not per subtype
        axis.axvline(0, color='grey', linestyle='--')
        axis.axhline(0, color='grey', linestyle='--')

        category_data = class_with_sdss[np.isin(class_with_sdss.Classification, cat)]
        for st, data in category_data.groupby('Classification'):
            axis.scatter(data.x, data.y, label=st, alpha=.75, s=20)

        if not category_data.empty:
            axis.legend()

    # ``plt.xlim``/``plt.ylim`` would act on the most recently created axes
    # (a border histogram). Set the shared main-axis limits explicitly and
    # keep the histograms aligned with them.
    axes[0, 0].set_xlim(-100, 100)
    axes[0, 0].set_ylim(-50, 50)
    histx.set_xlim(axes[0, 1].get_xlim())
    histy.set_ylim(axes[0, 1].get_ylim())

    return fig, axes



In [None]:
def get_sako_pec():
    """Get objects flagged as peculiar in Sako+ 2018

    Returns:
        Rows of the Sako+ 2018 master table (indexed by ``CID``) whose
        ``Notes`` flag is greater than 1
    """

    with warnings.catch_warnings():
        # Silence warnings raised while loading / converting the table
        warnings.simplefilter("ignore")
        sako_data = sako18.load_table('master').to_pandas(index='CID')

    # Missing flags are NaN, and ``NaN > 1`` is False, so unflagged rows drop out
    return sako_data[sako_data.Notes > 1]


def create_border_hist(axis, padding=0, xpos='top', ypos='right'):
    """Create axes for plotting border histograms

    Both new axes are added to the figure that owns ``axis``. The x histogram
    spans the width of ``axis`` and is 35% of its height. The y histogram
    spans the height of ``axis`` and is 35% of its width. Each is separated
    from ``axis`` by ``padding`` (figure-fraction units), whichever side it
    is placed on.

    Args:
        axis     (Axis): The matplotlib axis to border
        padding (float): Spacing between the main and bordering axes
        xpos      (str): Put the x histogram on the 'top' or 'bottom'
        ypos      (str): Put the y histogram on the 'left' or 'right'

    Returns:
        Axis for the x histogram (above or below ``axis``)
        Axis for the y histogram (left or right of ``axis``)

    Raises:
        ValueError: If ``xpos`` or ``ypos`` is not a supported position
    """

    if xpos not in ('top', 'bottom'):
        raise ValueError(f"xpos must be 'top' or 'bottom', not {xpos!r}")

    if ypos not in ('left', 'right'):
        raise ValueError(f"ypos must be 'left' or 'right', not {ypos!r}")

    axis_pos = axis.get_position()
    axis_width = axis_pos.x1 - axis_pos.x0
    axis_height = axis_pos.y1 - axis_pos.y0
    histx_height = .35 * axis_height
    histy_width = .35 * axis_width

    # Keep the same thickness on either side; only the offset changes
    if xpos == 'top':
        histx_bottom = axis_pos.y1 + padding

    else:
        histx_bottom = axis_pos.y0 - padding - histx_height

    if ypos == 'right':
        histy_left = axis_pos.x1 + padding

    else:
        histy_left = axis_pos.x0 - padding - histy_width

    # Add to the axis' own figure rather than whichever figure is current
    fig = axis.figure
    histx = fig.add_axes([axis_pos.x0, histx_bottom, axis_width, histx_height])
    histx.set_xlim(axis.get_xlim())
    histx.tick_params(direction='in', labelbottom=(xpos == 'bottom'))

    histy = fig.add_axes([histy_left, axis_pos.y0, histy_width, axis_height])
    histy.set_ylim(axis.get_ylim())
    histy.tick_params(direction='in', labelleft=(ypos == 'left'))

    return histx, histy


def plot_classification(class_df, fits_df, border_bins, padding, xpos='top', ypos='right', fig=None, axis=None):
    """Plot classification results with bordering histograms

    Objects are split into two colored populations by the sign of
    ``calc_delta_chisq(fits_df)``. Objects where it is exactly zero are not
    drawn. Objects flagged as peculiar by Sako+ 2018 (``get_sako_pec``) are
    removed from both populations and drawn instead as open black markers.
    Each ``Notes`` flag gets its own labelled marker style, so callers can
    add a legend.

    Args:
        class_df  (DataFrame): Classification coordinates (columns ``x`` and ``y``)
        fits_df   (DataFrame): DataFrame of fit results
        border_bins (ndarray): Bins for the border histograms; their min/max
                               also set the x-axis range
        padding       (float): Spacing between the main and bordering axes
        xpos            (str): Put the x histogram on the 'top' or 'bottom'
        ypos            (str): Put the y histogram on the 'left' or 'right'
        fig          (Figure): Optionally use an existing figure
        axis           (Axis): Optionally use an existing axis

    Returns:
        The matplotlib figure
        The main matplotlib axis
    """

    if fig is None or axis is None:
        fig, axis = plt.subplots(1, 1, figsize=(7 / 2, 7 / 2))

    x_label = r'$\chi^2_{blue}$ (Ia) - $\chi^2_{blue}$ (91bg)'
    y_label = r'$\chi^2_{red}$ (Ia) - $\chi^2_{red}$ (91bg)'
    markers = {2: 's', 3: '^', 4: 'o', 5: 'v'}
    labels = {2: '91bg', 3: '00cx', 4: '02ci', 5: '02cx'}

    delta_chi = calc_delta_chisq(fits_df)
    chi_lt0 = delta_chi[delta_chi < 0]
    chi_gt0 = delta_chi[delta_chi > 0]

    # Peculiar objects are removed here and over-plotted separately below
    sako_pec = get_sako_pec()
    all_data_lt = class_df.reindex(chi_lt0.index).drop(sako_pec.index, errors='ignore')
    all_data_gt = class_df.reindex(chi_gt0.index).drop(sako_pec.index, errors='ignore')

    axis.scatter(all_data_lt.x, all_data_lt.y, s=10, alpha=.7)
    axis.scatter(all_data_gt.x, all_data_gt.y, s=10, color='C1', alpha=.7)

    axis.axvline(0, color='grey', linestyle='--')
    axis.axhline(0, color='grey', linestyle='--')
    axis.set_xlabel(x_label, fontsize=12, labelpad=10)
    axis.set_ylabel(y_label, fontsize=12)
    axis.set_xlim(min(border_bins), max(border_bins))
    axis.set_ylim(-50, 50)

    for flag_type, flag_data in sako_pec.groupby('Notes'):
        plt_data = class_df.reindex(flag_data.index)
        # Flags without a dedicated style get a generic (fillable) marker
        # instead of raising a KeyError
        marker = markers.get(flag_type, 'D')
        label = labels.get(flag_type, f'Notes = {flag_type:g}')
        axis.scatter(plt_data.x, plt_data.y, s=20,
                     marker=marker, zorder=9,
                     color='k', label=label, facecolor='none')

    histx, histy = create_border_hist(axis, padding, xpos=xpos, ypos=ypos)
    histx.hist([all_data_lt.x, all_data_gt.x], bins=border_bins, stacked=True)
    histy.hist([all_data_lt.y, all_data_gt.y], bins=border_bins,
               stacked=True, orientation='horizontal')

    histx.set_xlim(axis.get_xlim())
    histy.set_ylim(axis.get_ylim())
    histx.set_yscale('log')
    histy.set_xscale('log')
    return fig, axis


In [None]:
coll_fig, coll_axis = plot_classification(
    class_df=iminuit_band_class,
    fits_df=iminuit_band_fits,
    border_bins=np.arange(-110, 111, 5),
    padding=.05
)

# Each fit configuration gets its own file so the classification cells do not
# overwrite one another
coll_fig.savefig(fig_dir / 'iminuit_band_classification.pdf', bbox_inches='tight')
plt.show()


In [None]:
coll_fig, coll_axis = plot_classification(
    class_df=iminuit_coll_class,
    fits_df=iminuit_coll_fits,
    border_bins=np.arange(-110, 111, 5),
    padding=.05
)

# Each fit configuration gets its own file so the classification cells do not
# overwrite one another
coll_fig.savefig(fig_dir / 'iminuit_collective_classification.pdf', bbox_inches='tight')
plt.show()


In [None]:
coll_fig, coll_axis = plot_classification(
    class_df=mcmc_band_class,
    fits_df=mcmc_band_fits,
    border_bins=np.arange(-110, 111, 5),
    padding=.05
)

# Each fit configuration gets its own file so the classification cells do not
# overwrite one another
coll_fig.savefig(fig_dir / 'mcmc_band_classification.pdf', bbox_inches='tight')
plt.show()


In [None]:
coll_fig, coll_axis = plot_classification(
    class_df=mcmc_coll_class,
    fits_df=mcmc_coll_fits,
    border_bins=np.arange(-110, 111, 5),
    padding=.05
)

# Each fit configuration gets its own file so the classification cells do not
# overwrite one another
coll_fig.savefig(fig_dir / 'mcmc_collective_classification.pdf', bbox_inches='tight')
plt.show()


In [None]:
def plot_classification_with_subtypes(classification_coords):
    """Plot classification coordinates split by Sako+ 2018 spectroscopic type

    Each of the four panels shows one family of Sako+ 2018 ``Classification``
    values (Ia, II, Ibc, Unknown), colored by exact classification. Objects
    flagged as peculiar by Sako+ 2018 are left out of the colored scatter and
    over-plotted as open black markers keyed by their ``Notes`` flag.

    Args:
        classification_coords (DataFrame): Classification coordinates with
            columns ``x`` and ``y``, indexed by object id

    Returns:
        The matplotlib figure
        A 2x2 array of matplotlib axes
    """

    fig, axes = plt.subplots(2, 2, figsize=(7.5, 7.5), sharex=True, sharey=True)

    x_label = r'$\chi^2_{blue}$ (Ia) - $\chi^2_{blue}$ (91bg)'
    y_label = r'$\chi^2_{red}$ (Ia) - $\chi^2_{red}$ (91bg)'
    axes[0, 0].set_ylabel(y_label)
    axes[0, 1].set_xlabel('')
    axes[0, 1].set_ylabel('')
    axes[1, 0].set_ylabel(y_label)
    axes[1, 0].set_xlabel(x_label)
    axes[1, 1].set_xlabel(x_label)

    categories = [
        ['SNIa', 'pSNIa', 'zSNIa', 'SNIa?'],
        ['SNII', 'pSNII', 'zSNII'],
        ['pSNIbc', 'zSNIbc', 'SNIc', 'SNIb', 'SNIbc'],
        ['Unknown']
    ]
    markers = {2: 's', 3: '^', 4: 'o', 5: 'v'}

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sdss_pd = sako18.load_table('master').to_pandas()
        sdss_pd['obj_id'] = sdss_pd['CID']
        sdss_pd.set_index('obj_id', inplace=True)
        class_with_sdss = classification_coords.join(sdss_pd)

    sako_pec = get_sako_pec()
    for axis, cat in zip(axes.flatten(), categories):
        # Reference lines are drawn once per panel, including empty panels
        axis.axvline(0, color='grey', linestyle='--', zorder=0)
        axis.axhline(0, color='grey', linestyle='--', zorder=0)

        category_data = class_with_sdss[np.isin(class_with_sdss.Classification, cat)]
        for st, data in category_data.groupby('Classification'):
            # Return a new frame instead of mutating the groupby slice in place
            data = data.drop(sako_pec.index, errors='ignore')
            axis.scatter(data.x, data.y, label=f'{st} ({len(data)})', alpha=.7, s=10)

        if not category_data.empty:
            axis.legend(loc='upper left', framealpha=1)

        for flag_type, flag_data in sako_pec.groupby('Notes'):
            plt_data = category_data.reindex(flag_data.index)
            axis.scatter(plt_data.x, plt_data.y, s=20,
                         marker=markers.get(flag_type, 'D'), zorder=9,
                         color='k', facecolor='none')

    # The axes share x and y, so these limits and ticks apply to every panel
    plt.xlim(-110, 110)
    plt.ylim(-55, 55)

    # Label every other tick to avoid crowding
    ylabels = np.arange(-50, 51, 10)
    ylabels_str = np.array(ylabels, dtype='str')
    ylabels_str[::2] = ''
    plt.yticks(ylabels, ylabels_str)

    xlabels = np.arange(-100, 101, 20)
    xlabels_str = np.array(xlabels, dtype='str')
    xlabels_str[::2] = ''
    plt.xticks(xlabels, xlabels_str)

    plt.subplots_adjust(wspace=.1, hspace=.1)
    return fig, axes


In [None]:
# Title each figure so the four otherwise identical-looking plots can be told apart
classification_runs = (
    ('Iminuit Band-by-Band Fitting', iminuit_band_class),
    ('Iminuit Collective Fitting', iminuit_coll_class),
    ('MCMC Band-by-Band Fitting', mcmc_band_class),
    ('MCMC Collective Fitting', mcmc_coll_class),
)

for fit_title, classification_data in classification_runs:
    plot_classification_with_subtypes(classification_data)
    plt.suptitle(fit_title)
    plt.show()


### Photometric vs. Spectroscopic Classifications <a id='photometric_vs_spectroscopic'></a>

In [None]:
def get_sako_pec():
    """Get objects flagged as peculiar in Sako+ 2018

    Returns:
        Rows of the Sako+ 2018 master table (indexed by ``CID``) whose
        ``Notes`` flag is greater than 1
    """

    with warnings.catch_warnings():
        # Silence warnings raised while loading / converting the table
        warnings.simplefilter("ignore")
        sako_data = sako18.load_table('master').to_pandas(index='CID')

    # Missing flags are NaN, and ``NaN > 1`` is False, so unflagged rows drop out
    return sako_data[sako_data.Notes > 1]


# Keep objects with Notes == 2 (the flag the plotting helpers label '91bg')
pec = get_sako_pec()
pec = pec[pec.Notes == 2.0]
pec


In [None]:
def plot_classification_with_spec(phot_class, spec_measurements):
    """Plot photometric classification coordinates colored by spectroscopic subtype

    All objects in ``phot_class`` are drawn in grey. Objects with
    spectroscopic measurements are over-plotted in the color returned by
    ``get_colors``. Objects flagged ``Notes == 2`` (91bg) by Sako+ 2018 are
    drawn as open black squares.

    Args:
        phot_class        (DataFrame): Classification coordinates with columns ``x`` and ``y``
        spec_measurements (DataFrame): Spectroscopic measurements with ``pew_pw6`` and
                                       ``pew_pw7`` columns. The input is not modified.

    Returns:
        The matplotlib figure
        The matplotlib axis
    """

    fig, axis = plt.subplots(1, 1, figsize=(7, 7))
    x_label = r'$\chi^2_{blue}$ (Ia) - $\chi^2_{blue}$ (91bg)'
    y_label = r'$\chi^2_{red}$ (Ia) - $\chi^2_{red}$ (91bg)'
    markers = dict(blue='s', red='^', black='o', green='v')
    labels = dict(blue='Cool', red='Broad Line', green='Shallow Silicon', black='Core Normal')

    # ``assign`` returns a copy, so the caller's DataFrame does not gain a ``color`` column
    spec_measurements = spec_measurements.assign(
        color=get_colors(spec_measurements.pew_pw6, spec_measurements.pew_pw7)
    )
    plot_data = phot_class.join(spec_measurements)

    axis.scatter(plot_data.x, plot_data.y, s=10, alpha=.5, color='grey')
    for color, data in plot_data.groupby('color'):
        axis.scatter(data.x, data.y, s=14, color=color, label=labels[color], marker=markers[color])

    # Look up the Sako+ 2018 91bg objects here rather than relying on a
    # notebook-level ``pec`` variable. Label them so the open squares are not
    # confused with the "Cool" markers in the legend.
    sako_pec = get_sako_pec()
    sako_results = phot_class.reindex(sako_pec[sako_pec.Notes == 2].index)
    axis.scatter(sako_results.x, sako_results.y, facecolor='none', edgecolor='black',
                 marker='s', label='Sako+ 2018 91bg')

    axis.axvline(0, color='grey', linestyle='--')
    axis.axhline(0, color='grey', linestyle='--')
    axis.set_xlabel(x_label, fontsize=12, labelpad=10)
    axis.set_ylabel(y_label, fontsize=12)
    axis.set_xlim(-75, 75)
    axis.set_ylim(-50, 50)
    axis.legend()

    return fig, axis


In [None]:
# Iminuit band-by-band coordinates vs. spectroscopic subtypes (all SNR)
coll_fig, coll_axis = plot_classification_with_spec(
    iminuit_band_class,
    get_pew_above_snr(good_spectra, 0),
)
coll_axis.set_title('Iminuit Band-by-Band Fitting')
plt.show()


In [None]:
# Iminuit collective coordinates vs. spectroscopic subtypes (all SNR)
coll_fig, coll_axis = plot_classification_with_spec(
    iminuit_coll_class,
    get_pew_above_snr(good_spectra, 0),
)
coll_axis.set_title('Iminuit Collective Fitting')
plt.show()


In [None]:
# MCMC band-by-band coordinates vs. spectroscopic subtypes (all SNR)
coll_fig, coll_axis = plot_classification_with_spec(
    mcmc_band_class,
    get_pew_above_snr(good_spectra, 0),
)
coll_axis.set_title('MCMC Band-by-Band Fitting')
plt.show()


In [None]:
# MCMC collective coordinates vs. spectroscopic subtypes (all SNR)
coll_fig, coll_axis = plot_classification_with_spec(
    mcmc_coll_class,
    get_pew_above_snr(good_spectra, 0),
)
coll_axis.set_title('MCMC Collective Fitting')
plt.show()


## Intrinsic Properties <a id='intrinsic_properties'></a>

We consider the distribution of fit parameters.

In [None]:
def plot_param_histogram(fits_df, source, fit_type, param):
    """Plot a histogram of fit parameters

    Only rows of ``fits_df.loc[source]`` whose ``band`` column equals
    ``'all'`` are included.

    Args:
        fits_df (DataFrame): Fit results; ``fits_df.loc[source]`` must select
                             the rows for ``source``
        source        (str): Name of the model to display results for
        fit_type      (str): Label for the kind of fit ("band" or "collective");
                             only used in the plot title
        param         (str): Name of the parameter to plot

    Returns:
        The matplotlib figure
        The matplotlib axis
    """

    # Select data to plot
    fits_using_source = fits_df.loc[source]
    hist_data = fits_using_source[fits_using_source['band'] == 'all'][param]

    fig, axis = plt.subplots(1, 1, figsize=(7, 7))
    axis.hist(hist_data, bins=20)

    # Keep only the text before the first underscore (presumably to keep the
    # title free of characters matplotlib's text rendering treats specially)
    latex_safe_source = source.split("_")[0]
    axis.set_title(param + f' Distribution ({latex_safe_source} - {fit_type} Fits)'.title())
    axis.set_xlabel(param)
    axis.set_ylabel('Combined number of targets')
    return fig, axis


In [None]:
for source in ('hsiao_x1', 'sn91bg'):
    for param in ('x1', 'c'):
        # 'c' is not plotted for the hsiao_x1 model
        if source == 'hsiao_x1' and param == 'c':
            continue

        for fit_type, fit_data in zip(('band', 'collective'), (iminuit_band_fits, iminuit_coll_fits)):
            plot_param_histogram(fit_data, source, fit_type, param)
            # Prefix the fitter name so the MCMC figures do not overwrite these
            plt.savefig(fig_dir / f'iminuit_{param}_{source}_{fit_type}_fits.pdf'.lower())
            plt.show()


In [None]:
for source in ('hsiao_x1', 'sn91bg'):
    for param in ('x1', 'c'):
        # 'c' is not plotted for the hsiao_x1 model
        if source == 'hsiao_x1' and param == 'c':
            continue

        for fit_type, fit_data in zip(('band', 'collective'), (mcmc_band_fits, mcmc_coll_fits)):
            plot_param_histogram(fit_data, source, fit_type, param)
            # Prefix the fitter name so these do not overwrite the iminuit figures
            plt.savefig(fig_dir / f'mcmc_{param}_{source}_{fit_type}_fits.pdf'.lower())
            plt.show()


### Host Galaxy Properties <a id='host_properties'></a>

We start with some bookkeeping and create DataFrames for various subsets of the host galaxy data.

In [None]:
host_photometry = pd.DataFrame({
    'obj_id': sdss_master['CID'],
    'host_id': sdss_master['objIDHost'],  # Host galaxy object ID in SDSS DR8 Database
    'ra': sdss_master['RAhost'],  # Right ascension of galaxy host (degrees)
    'dec': sdss_master['DEChost'],  # Declination of galaxy host (degrees)
    'dist': sdss_master['separationhost'],  # Angular distance from SN to host (arcsec)
    'distnorm': sdss_master['DLRhost'],  # Normalized distance from SN to host (dDLR)
    'z_KF': sdss_master['zphothost'],  # Host photometric redshift (KF algorithm)
    'z_KF_err': sdss_master['zphoterrhost'],  # zphothost uncertainty
    'z_RF': sdss_master['zphotRFhost'],  # Host photometric redshift (RF algorithm)
    'z_RF_err': sdss_master['zphotRFerrhost'],  # zphotRFhost uncertainty
    'u_mag': sdss_master['dereduhost'],  # Host galaxy u-band magnitude (dereddened)
    'u_mag_err': sdss_master['erruhost'],  # Host galaxy u-band magnitude uncertainty
    'g_mag': sdss_master['deredghost'],  # Host galaxy g-band magnitude (dereddened)
    'g_mag_err': sdss_master['errghost'],  # Host galaxy g-band magnitude uncertainty
    'r_mag': sdss_master['deredrhost'],  # Host galaxy r-band magnitude (dereddened)
    'r_mag_err': sdss_master['errrhost'],  # Host galaxy r-band magnitude uncertainty
    'i_mag': sdss_master['deredihost'],  # Host galaxy i-band magnitude (dereddened)
    'i_mag_err': sdss_master['errihost'],  # Host galaxy i-band magnitude uncertainty
    'z_mag': sdss_master['deredzhost'],  # Host galaxy z-band magnitude (dereddened)
    'z_mag_err': sdss_master['errzhost'],  # Host galaxy z-band magnitude uncertainty
})
host_photometry.set_index('obj_id', inplace=True)

# Galaxy parameters calculated with FSPS
fsps_params = pd.DataFrame({
    'obj_id': sdss_master['CID'],
    'logmass': sdss_master['logMassFSPS'],  # FSPS log(M), M = galaxy mass (M_sun)
    'logmass_lo': sdss_master['logMassloFSPS'],  # FSPS lower limit of uncertainty in log(M)
    'logmass_hi': sdss_master['logMasshiFSPS'],  # FSPS upper limit of uncertainty in log(M)
    'logssfr': sdss_master['logSSFRFSPS'],  # FSPS log(sSFR), sSFR = SFR / M (specific star-formation rate)
    'logssfr_lo': sdss_master['logSSFRloFSPS'],  # FSPS lower limit of uncertainty in log(sSFR)
    'logssfr_hi': sdss_master['logSSFRhiFSPS'],  # FSPS upper limit of uncertainty in log(sSFR)
    'age': sdss_master['ageFSPS'],  # FSPS galaxy age (Gyr)
    'age_lo': sdss_master['ageloFSPS'],  # FSPS lower limit of uncertainty in age
    'age_hi': sdss_master['agehiFSPS'],  # FSPS upper limit of uncertainty in age
    'rchisq': sdss_master['minredchi2FSPS'],  # Reduced chi-squared of best FSPS template fit
})
fsps_params.set_index('obj_id', inplace=True)

# Galaxy parameters calculated with PÉGASE.2
pegase_params = pd.DataFrame({
    'obj_id': sdss_master['CID'],
    'logmass': sdss_master['logMassPEGASE'],  # PÉGASE.2 log(M), M = galaxy mass (M_sun)
    'logmass_lo': sdss_master['logMassloPEGASE'],  # PÉGASE.2 lower limit of uncertainty in log(M)
    'logmass_hi': sdss_master['logMasshiPEGASE'],  # PÉGASE.2 upper limit of uncertainty in log(M)
    # NOTE: despite the ``logssfr`` names, the next three columns are filled
    # from the ``logSFR*`` columns, i.e. log(SFR) rather than log(sSFR)
    'logssfr': sdss_master['logSFRPEGASE'],  # PÉGASE.2 log(SFR), SFR = galaxy star-formation rate (M_sun / yr)
    'logssfr_lo': sdss_master['logSFRloPEGASE'],  # PÉGASE.2 lower limit of uncertainty in log(SFR)
    'logssfr_hi': sdss_master['logSFRhiPEGASE'],  # PÉGASE.2 upper limit of uncertainty in log(SFR)
    'age': sdss_master['agePEGASE'],  # PÉGASE.2 galaxy age (Gyr)
    # NOTE(review): source column is ``minchi2PEGASE``; confirm whether it is reduced
    'rchisq': sdss_master['minchi2PEGASE'],  # Chi-squared of best PÉGASE.2 fit
})
pegase_params.set_index('obj_id', inplace=True)


We also determine the projected distance of each SN from its host, in kiloparsecs.

In [None]:
# Get the heliocentric spectroscopic redshift of each SN, indexed by object id
sdss_table_2 = sako18.load_table(2)
sdss_table_2['obj_id'] = sdss_table_2['CID']
redshift = sdss_table_2['obj_id', 'zspecHelio'].to_pandas('obj_id')

# Convert the SN-host angular separation from arcsec to arcmin
host_photometry['arcmin'] = host_photometry.dist / 60

# The astropy ``kpc_*_per_arcmin`` methods take a *redshift*, not an angle.
# Scale each SN's angular separation by the kpc-per-arcmin at that SN's
# redshift. The proper (physical) scale is used because this is a separation
# within a galaxy. Objects without a positive redshift get NaN.
sn_redshift = redshift['zspecHelio'].reindex(host_photometry.index)
has_z = sn_redshift > 0
kpc_per_arcmin = wmap9.kpc_proper_per_arcmin(sn_redshift[has_z].to_numpy()).to_value('kpc / arcmin')
host_photometry['kpc'] = np.nan
host_photometry.loc[has_z, 'kpc'] = host_photometry.loc[has_z, 'arcmin'].to_numpy() * kpc_per_arcmin

In [None]:
def plot_host_property_distribution(col_name, bg_ids, normal_ids, *data_frames, **kwargs):
    """Plot normalized histograms of a host galaxy property for 91bg-like and normal SNe

    One panel is drawn per data frame. Both histograms use ``density=True``.
    Each panel title shows the p-value of a k-sample Anderson-Darling test
    comparing the two populations. SciPy clips that value to [0.001, 0.25],
    so values at either bound are shown as inequalities. If the test cannot
    be run (e.g. an empty sample), the title reads "p = n/a".

    Args:
        col_name           (str): The name of the value to plot
        bg_ids          (Series): Object Ids of 91bg like SN
        normal_ids      (Series): Object Ids of normal SN
        *data_frames (DataFrame): Data frames with host galaxy data
        **kwargs                : Extra arguments for the histogram of normal SNe
                                  (e.g. ``bins``); the 91bg histogram reuses its bins

    Returns:
        A matplotlib figure
        An array of matplotlib axes (one per data frame)
    """

    # squeeze=False keeps a 2D array even for a single panel
    fig, axes = plt.subplots(1, len(data_frames), squeeze=False)
    axes = axes.flatten()

    for df, axis in zip(data_frames, axes):
        plot_data = df[col_name]
        bg_data = plot_data.reindex(bg_ids).dropna()
        normal_data = plot_data.reindex(normal_ids).dropna()

        try:
            sig_level = stats.anderson_ksamp([normal_data, bg_data]).significance_level

        except ValueError:
            # Raised e.g. when one of the samples has no observations
            sig_level = None

        _, bins, _ = axis.hist(
            normal_data,
            density=True,
            label=f'Normal ({len(normal_data)})',
            **kwargs
        )

        axis.hist(bg_data,
                  fill=False,
                  hatch='///',
                  density=True,
                  histtype='step',
                  label=f'91bg ({len(bg_data)})',
                  bins=bins
                  )

        if sig_level is None:
            title = '(p = n/a)'

        elif sig_level >= 0.25:
            title = r'(p $\geq$ 0.25)'

        elif sig_level <= 0.001:
            title = r'(p $\leq$ 0.001)'

        else:
            title = f'(p = {sig_level:.2})'

        axis.set_title(title)

    axes[-1].legend()
    plt.tight_layout()
    return fig, axes
        

In [None]:
x_cut, y_cut = .5, .5

# Restrict to spectroscopic Ia types plus unclassified objects
ia_categories = ['SNIa', 'pSNIa', 'zSNIa', 'SNIa?', 'Unknown']
ia_indices = sdss_master['CID'][np.isin(sdss_master['Classification'], ia_categories)]
ia_subtyped = iminuit_coll_class.reindex(ia_indices)

# 91bg-like objects lie above both cuts; normal objects lie below both
bg_like = ia_subtyped.query('x > @x_cut and y > @y_cut').index
normal = ia_subtyped.query('x < @x_cut and y < @y_cut').index
print(len(bg_like))


In [None]:
fig, axes = plot_host_property_distribution('logmass', bg_like, normal, fsps_params)

# The histograms use density=True, so the y-axis is a normalized density
axes[0].set_ylabel('Normalized Density', fontsize=16)
axes[0].set_title('FSPS ' + axes[0].get_title())
for axis in axes:
    axis.set_xlabel(r'$\log(M / M_\odot)$', fontsize=16)

plt.savefig(fig_dir / 'collective_fits_mass.pdf', bbox_inches='tight')
plt.show()


In [None]:
# ``pegase_params.logssfr`` is filled from ``logSFRPEGASE``, i.e. log(SFR).
# NOTE(review): the > 0 cut also drops galaxies with SFR < 1 M_sun / yr;
# confirm it is intended (e.g. to remove sentinel values).
ssfr_data = pegase_params[pegase_params.logssfr > 0]
fig, axes = plot_host_property_distribution('logssfr', bg_like, normal, ssfr_data)

axes[0].set_xlabel(r'$\log$(SFR) (PÉGASE.2)', fontsize=16)
# The histograms use density=True, so the y-axis is a normalized density
axes[0].set_ylabel('Normalized Density', fontsize=16)

plt.savefig(fig_dir / 'collective_fits_ssfr.pdf', bbox_inches='tight')
plt.show()


In [None]:
fig, axes = plot_host_property_distribution('kpc', bg_like, normal, host_photometry)

axes[0].set_xlabel('Distance to Host Center (kpc)', fontsize=16)
# The histograms use density=True, so the y-axis is a normalized density
axes[0].set_ylabel('Normalized Density', fontsize=16)

plt.savefig(fig_dir / 'collective_fits_distance.pdf', bbox_inches='tight')
plt.show()


In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_pvalue(class_df, data, max_x=3.5, max_y=3.5, x_cutoff=0, y_cutoff=0, size=30):
    """Plot the p-value for host galaxy data as a function of cutoff value

    For every (x, y) cut on a ``size`` x ``size`` grid, objects with
    ``x > xcut`` and ``y > ycut`` form the 91bg-like sample, and objects with
    ``x < xcut`` and ``y < ycut`` form the normal sample. Their host property
    values are compared with ``scipy.stats.anderson_ksamp``. Grid points
    where the test cannot be run (e.g. an empty sample) are left blank.

    Args:
        class_df (DataFrame): Classification coordinates (columns ``x`` and ``y``)
        data        (Series): Host galaxy property values indexed by object id
        max_x        (float): Maximum for x-axis range
        max_y        (float): Maximum for y-axis range
        x_cutoff     (float): x classification boundary to mark on the plot
        y_cutoff     (float): y classification boundary to mark on the plot
        size           (int): Number of values to sample in x and y direction

    Returns:
        A matplotlib figure
        The matplotlib axis holding the p-value map
    """

    x_arr = np.linspace(0, max_x, size)
    y_arr = np.linspace(0, max_y, size)

    # Pad the extent by half a step so each pixel is centered on its grid point
    dx = (x_arr[1] - x_arr[0]) / 2
    dy = (y_arr[1] - y_arr[0]) / 2
    extent = [x_arr[0] - dx, x_arr[-1] + dx, y_arr[0] - dy, y_arr[-1] + dy]

    # sig_arr[i, j] is the p-value for the cuts (x_arr[i], y_arr[j])
    sig_arr = np.full((size, size), np.nan)
    for i, xcut in enumerate(x_arr):
        for j, ycut in enumerate(y_arr):
            test_bg = class_df[(class_df.x > xcut) & (class_df.y > ycut)].index
            test_normal = class_df[(class_df.x < xcut) & (class_df.y < ycut)].index
            test_bg_data = data.reindex(test_bg).dropna()
            test_normal_data = data.reindex(test_normal).dropna()
            try:
                with warnings.catch_warnings():
                    # Silence the per-grid-point warnings about clipped p-values
                    warnings.simplefilter('ignore')
                    sig_arr[i, j] = stats.anderson_ksamp(
                        [test_normal_data, test_bg_data]).significance_level

            except ValueError:
                # e.g. a sample is empty at extreme cuts; leave this pixel blank
                pass

    fig, axis = plt.subplots(1, 1, figsize=(7 / 2, 7))
    divider = make_axes_locatable(axis)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    # imshow maps the first array index to rows (the y-axis), so transpose
    # to put the x cuts along the horizontal axis
    im = axis.imshow(sig_arr.T, extent=extent, origin='lower', cmap='Blues')
    fig.colorbar(im, cax=cax, orientation='vertical')
    axis.axvline(x_cutoff, linestyle='--', alpha=.8, color='red')
    axis.axhline(y_cutoff, linestyle='--', alpha=.8, color='red')

    axis.set_xlabel('x Cutoff', fontsize=10)
    axis.set_ylabel('y Cutoff', fontsize=10)

    return fig, axis


In [None]:
# Use the same Ia + Unknown sample (``ia_subtyped``) that defines ``bg_like``
# and ``normal`` for the histograms, so the maps are computed on that sample
_, fsps_axis = plot_pvalue(ia_subtyped, fsps_params.logmass, x_cutoff=x_cut, y_cutoff=y_cut, size=50)
fsps_axis.set_title('Host Mass (FSPS)')
plt.savefig(fig_dir / 'fsps_mass_pvalue.pdf', bbox_inches='tight')
plt.show()

# ``pegase_params.logssfr`` is filled from ``logSFRPEGASE`` (log SFR, not sSFR)
_, ssfr_axis = plot_pvalue(ia_subtyped, ssfr_data.logssfr, x_cutoff=x_cut, y_cutoff=y_cut)
ssfr_axis.set_title('SFR (PÉGASE.2)')
plt.savefig(fig_dir / 'ssfr_pvalue.pdf', bbox_inches='tight')
plt.show()

_, dist_axis = plot_pvalue(ia_subtyped, host_photometry.dist, x_cutoff=x_cut, y_cutoff=y_cut)
dist_axis.set_title('Host Distance (arcsec)')
plt.savefig(fig_dir / 'dist_pvalue.pdf', bbox_inches='tight')
plt.show()


In [None]:
def plot_num_points(class_df, data, max_x=3.5, max_y=3.5, x_cutoff=0, y_cutoff=0, size=30):
    """Plot the number of 91bgs as a function of the x and y cutoffs

    For every (x, y) cut on a ``size`` x ``size`` grid, count the objects with
    ``x > xcut`` and ``y > ycut`` that also have a non-null value in
    ``data``. The color scale saturates at 50.

    Args:
        class_df (DataFrame): Classification coordinates (columns ``x`` and ``y``)
        data        (Series): Host galaxy property values indexed by object id;
                              objects with missing values are not counted
        max_x        (float): Maximum for x-axis range
        max_y        (float): Maximum for y-axis range
        x_cutoff     (float): x classification boundary to mark on the plot
        y_cutoff     (float): y classification boundary to mark on the plot
        size           (int): Number of values to sample in x and y direction

    Returns:
        A matplotlib figure
        The matplotlib axis holding the count map
    """

    x_arr = np.linspace(0, max_x, size)
    y_arr = np.linspace(0, max_y, size)

    # Pad the extent by half a step so each pixel is centered on its grid point
    dx = (x_arr[1] - x_arr[0]) / 2
    dy = (y_arr[1] - y_arr[0]) / 2
    extent = [x_arr[0] - dx, x_arr[-1] + dx, y_arr[0] - dy, y_arr[-1] + dy]

    # count_arr[i, j] is the count for the cuts (x_arr[i], y_arr[j])
    count_arr = np.zeros((size, size))
    for i, xcut in enumerate(x_arr):
        for j, ycut in enumerate(y_arr):
            test_bg = class_df[(class_df.x > xcut) & (class_df.y > ycut)].index
            count_arr[i, j] = len(data.reindex(test_bg).dropna())

    # imshow and contour expect rows to run along y, so plot the transpose
    count_img = count_arr.T

    fig, axis = plt.subplots(1, 1, figsize=(7 / 2, 7))
    divider = make_axes_locatable(axis)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    im = axis.imshow(count_img, extent=extent, origin='lower', cmap='Blues', vmin=0, vmax=50)
    fig.colorbar(im, cax=cax, orientation='vertical')
    axis.axvline(x_cutoff, linestyle='--', alpha=.8, color='red')
    axis.axhline(y_cutoff, linestyle='--', alpha=.8, color='red')

    axis.set_ylabel('y Cutoff', fontsize=10)
    axis.set_xlabel('x Cutoff', fontsize=10)
    axis.set_title('Number of 91bg points', fontsize=10)
    axis.contour(x_arr, y_arr, count_img, levels=np.arange(0, 51, 10), colors='k')
    axis.contour(x_arr, y_arr, count_img, levels=np.arange(5, 26, 5), colors='k', linestyles=':')

    # Counts above vmax share the top color, hence the "50+" label
    cax.set_yticks(np.arange(0, 51, 10))
    cax.set_yticklabels(['0', '10', '20', '30', '40', '50+'])

    return fig, axis


In [None]:
# Count candidates in the same Ia + Unknown sample used to define ``bg_like``
plot_num_points(ia_subtyped, fsps_params.logmass, x_cutoff=x_cut, y_cutoff=y_cut, size=50)
plt.savefig(fig_dir / 'num_points.pdf', bbox_inches='tight')
plt.show()
