# Comparison of Fitting Methods

We consider here the differences in classification results when fits are performed independently for each band versus collectively by band set (i.e., red bands and blue bands as two collective groups). We also discuss the impact of including Milky Way extinction. We choose to use the DES light curves since they have consistent internal behavior (by nature of being a cosmological data release).

#### Table of Contents:
1. <a href='#mwext'>Effect of Milky Way Extinction</a>: The impact of MW extinction on classification parameters.
1. <a href='#collfit'>Effect Of Band-By-Band Vs. Collective Fitting</a>: The impact of choosing to fit bands independently vs. collectively.


In [None]:
import sys
from pathlib import Path

from copy import deepcopy
import numpy as np
import sncosmo
import sndata
from astropy.table import Table
from matplotlib import pyplot as plt

sys.path.insert(0, '..')
from phot_class import models, utils

# Make the custom SN models and the DES bandpasses available to sncosmo by name
models.register_sources(force=True)
sndata.des.sn3yr.register_filters(force=True)

# Every figure produced by this notebook is written to this directory
fig_dir = Path('.') / 'notebook_figs' / 'ext_effects'
fig_dir.mkdir(parents=True, exist_ok=True)


## Effect of Milky Way Extinction <a id='mwext'></a>

We start by reading in the fit results and classification parameters for DES. For now, we only consider fits that are performed band-by-band. We expect the impact to be minimal since band-by-band fitting is independent of a light-curve's color.


In [None]:
def read_combined_data(path_pattern):
    """Load results computed with and without MW extinction and join them

    Args:
        path_pattern (str): Path template with a ``{}`` placeholder that is
            filled with 'no' or 'with' to point to the results computed
            without / with MW extinction

    Returns:
        A Pandas DataFrame. Columns present in both tables are suffixed with
        '_noext' (no-extinction table) and '_ext' (extinction table). If the
        tables contain a 'band' column the index is (obj_id, source, band);
        otherwise the index is obj_id and an extra 'sep' column holds the
        Euclidean distance between the (x, y) coordinates of the two tables.
    """

    template = str(path_pattern)

    # Read the no-extinction table first, then the with-extinction table
    tables = {
        kind: Table.read(template.format(kind)).to_pandas()
        for kind in ('no', 'with')
    }

    per_band = 'band' in tables['with']
    index_cols = ['obj_id', 'source', 'band'] if per_band else 'obj_id'

    combined = tables['no'].set_index(index_cols).join(
        tables['with'].set_index(index_cols),
        lsuffix='_noext',
        rsuffix='_ext',
    )

    if not per_band:
        # Shift in classification coordinates caused by including extinction
        delta_x = combined['x_noext'] - combined['x_ext']
        delta_y = combined['y_noext'] - combined['y_ext']
        combined['sep'] = np.sqrt(delta_x ** 2 + delta_y ** 2)

    return combined


In [None]:
# Results live in per-extinction sub-directories (no_ext / with_ext)
results_dir = Path('..') / 'results'
band_coords_path = results_dir / 'band_fits' / '{}_ext' / 'des_sn3yr_simple_fit_class.ecsv'
band_fits_path = results_dir / 'band_fits' / '{}_ext' / 'des_sn3yr_simple_fit_fits.ecsv'

# Classification coordinates and band-by-band fit parameters, with and without MW extinction
band_coords = read_combined_data(band_coords_path)
band_fits = read_combined_data(band_fits_path)

band_coords.head()

Next, we create overlaid plots of the classification coordinates determined with and without extinction. This allows us to inspect the shift in the coordinates visually.

In [None]:
def plot_coord_seperation(band_fit_df, x_cols, ycols, labels=('With MW Ext', 'No MW Ext'), sep_cutoff=.5):
    """Create overlaid plots of two sets of classification coordinates

    Both panels show the same data; the right-hand panel is zoomed in on the
    region x, y in [-5, 1]. The first coordinate set is drawn with filled
    markers and the second with open markers. Targets whose ``sep`` value is
    at least ``sep_cutoff`` get a grey line joining their two positions.

    Args:
        band_fit_df (DataFrame): The dataframe to pull data from. Must contain
            a ``sep`` column in addition to the coordinate columns.
        x_cols      (list[str]): List of two column names with x-coordinates
        ycols       (list[str]): List of two column names with y-coordinates
        labels (tuple[str, str]): Legend labels for the first / second
            coordinate set (defaults describe a with / without MW extinction
            comparison)
        sep_cutoff      (float): Minimum ``sep`` for drawing a connecting line

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    fig, axes = plt.subplots(1, 2, figsize=(20, 10))

    # The targets that moved noticeably are the same for both panels,
    # so select them once rather than once per axis
    distant_data = band_fit_df[band_fit_df.sep.ge(sep_cutoff)]
    connectors = list(zip(
        distant_data[x_cols[0]], distant_data[x_cols[1]],
        distant_data[ycols[0]], distant_data[ycols[1]],
    ))

    for axis in axes:
        axis.axvline(0, linestyle=':', color='grey')
        axis.axhline(0, linestyle=':', color='grey')
        axis.set_xlabel(r'$\chi^2_{blue}(Ia) - \chi^2_{blue}(91bg)$', fontsize=14)
        axis.set_ylabel(r'$\chi^2_{red}(Ia) - \chi^2_{red}(91bg)$', fontsize=14)

        axis.scatter(band_fit_df[x_cols[0]], band_fit_df[ycols[0]], s=25, c='C0', label=labels[0])
        axis.scatter(band_fit_df[x_cols[1]], band_fit_df[ycols[1]], s=25,
                     facecolors='none', edgecolors='black', alpha=.8, label=labels[1])

        # Join each distant target's two positions with a grey line
        for x_start, x_end, y_start, y_end in connectors:
            axis.plot([x_start, x_end], [y_start, y_end], color='grey', marker='')

    # Only the zoomed-in panel carries the legend
    axes[1].legend()
    axes[1].set_xlim(-5, 1)
    axes[1].set_ylim(-5, 1)
    return fig, axes


In [None]:
fig, axes = plot_coord_seperation(band_coords, ['x_ext', 'x_noext'], ['y_ext', 'y_noext'])

# Apply the axis limits before saving so the written PDF matches the displayed figure
axes[0].set_xlim(-300, 5)
fig.savefig(fig_dir / 'displacement_bandfit_mwextinction.pdf')


To understand why some points don't agree, we inspect the fits with/without extinction for the target with the largest separation in each subplot.

In [None]:
def plot_band_fits(fits, data):
    """Plot band-by-band fits with and without MW extinction for one target

    One column of subplots is drawn per model (left: 'hsiao_x1', right:
    'sn91bg') and one row per fitted band. Each subplot shows the observed
    fluxes in that band along with the model evaluated using the parameters
    fitted with and without MW extinction.

    Args:
        fits (DataFrame): Tabulated fit parameters indexed by
            (obj_id, source, band) with columns *params + '_ext' and
            *params + '_noext'. Rows whose band is 'all' are ignored.
        data     (Table): Observed light curve data for the desired target;
            ``data.meta['obj_id']`` selects the rows of ``fits``

    Returns:
        A matplotlib figure
        A 2D array of matplotlib axes with shape (number of bands, 2)

    Raises:
        ValueError: If ``fits`` holds no band-by-band fits for the target
    """

    obj_id = data.meta['obj_id']
    dust_kw = dict(effects=[sncosmo.F99Dust()], effect_names=['mw'], effect_frames=['obs'])

    bg_source = sncosmo.get_source('sn91bg', version='hsiao_phase')
    sn91bg = sncosmo.Model(bg_source, **dust_kw)
    hsiao = sncosmo.Model('hsiao_x1', **dust_kw)

    # Gather each model's per-band fit results, dropping the 'all' row
    band_results = []
    for model in (hsiao, sn91bg):
        fit_results = fits.loc[obj_id, model.source.name]
        band_results.append((model, fit_results[fit_results.index != 'all']))

    # Size the grid from the band fits themselves instead of assuming both
    # models share the same bands plus exactly one 'all' row each
    num_rows = max(len(results) for _, results in band_results)
    if num_rows == 0:
        raise ValueError(f'No band-by-band fits found for target {obj_id}')

    # squeeze=False keeps ``axes`` 2D even when only one band was fit
    num_cols = 2
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, num_rows * 4), squeeze=False)

    curve_labels = {'_ext': 'extinction', '_noext': 'no extinction'}
    for (model, fit_results), axes_col in zip(band_results, axes.T):
        for (band_name, row), axis in zip(fit_results.iterrows(), axes_col):

            band_data = data[data['band'] == band_name]
            axis.scatter(band_data['time'], band_data['flux'], label=band_name)
            axis.errorbar(band_data['time'], band_data['flux'],
                          yerr=band_data['fluxerr'], linestyle='', label='')

            for suffix, label in curve_labels.items():
                params = {p: row[p + suffix] for p in model.param_names}
                model.update(params)

                time = np.arange(params['t0'] - 18, params['t0'] + 90)
                model_flux = model.bandflux(band_name, time, zp=27.5, zpsys='AB')
                axis.plot(time, model_flux, label=label)

            axis.legend()

    return fig, axes


In [None]:
# Target whose classification coordinates moved the most when MW extinction was included
demo_id_1 = band_coords.sep.idxmax()

print(f'Largest Overall Separation: {demo_id_1}')
print(band_coords.loc[demo_id_1])

# Overlay the band-by-band fits with / without extinction for that target
demo_data_1 = sndata.des.sn3yr.get_data_for_id(demo_id_1)
_ = plot_band_fits(band_fits, demo_data_1)


In [None]:
# Restrict to targets inside the zoomed-in window (x, y >= -5),
# either with or without MW extinction
i1 = (band_coords.x_ext >= -5) & (band_coords.y_ext >= -5)
i2 = (band_coords.x_noext >= -5) & (band_coords.y_noext >= -5)
subset = band_coords[i1 | i2]
demo_id_2 = subset.sep.idxmax()

print('\nLargest Separation with x,y >= -5')
print(subset.loc[demo_id_2])

demo_data_2 = sndata.des.sn3yr.get_data_for_id(demo_id_2)
fig, _ = plot_band_fits(band_fits, demo_data_2)

# Save the figure returned by plot_band_fits explicitly instead of
# relying on it being matplotlib's current figure
fig.savefig(fig_dir / 'bandfit_difference_mwextinction.pdf')
plt.show()


From the light-curve plots, we can see that the inclusion of MW extinction does not directly impact the individual band fits but instead impacts the t0 estimation, which is performed using all available bands. The modified Hsiao model we are using is surprisingly robust and is not affected.

## Effect Of Band-By-Band Vs Collective Fitting <a id='collfit'></a>

Next, we can consider the impact of choosing to fit bands independently vs. collectively. We expect the difference here to be much larger than before since collective fits may be sensitive to color variations within the SNe population. As before, we start by reading in the data and plotting the two sets of classification coordinates.


In [None]:
# Collective fits use the same per-extinction directory layout as the band fits
results_dir = Path('..') / 'results'
coll_coords_path = results_dir / 'collective_fits' / '{}_ext' / 'des_sn3yr_simple_fit_class.ecsv'
coll_fits_path = results_dir / 'collective_fits' / '{}_ext' / 'des_sn3yr_simple_fit_fits.ecsv'

# Classification coordinates and collective fit parameters, with and without MW extinction
coll_coords = read_combined_data(coll_coords_path)
coll_fits = read_combined_data(coll_fits_path)

coll_coords.head()


In [None]:
# Pair every target's band-by-band coordinates with its collective-fit coordinates
all_coords = band_coords.join(coll_coords, lsuffix='_band', rsuffix='_coll')

# Distance between band and collective coordinates, both with MW extinction
all_coords['sep'] = np.sqrt(sum(
    (all_coords[f'{axis_name}_ext_band'] - all_coords[f'{axis_name}_ext_coll']) ** 2
    for axis_name in 'xy'
))

all_coords.head()


Let's compare the classification coordinates from band-by-band and collective fits (both performed with MW extinction included).

In [None]:
fig, axes = plot_coord_seperation(all_coords, ['x_ext_band', 'x_ext_coll'], ['y_ext_band', 'y_ext_coll'])

# Apply the axis limits before saving so the written PDF matches the displayed figure
axes[0].set_xlim(-300, 5)
fig.savefig(fig_dir / 'displacement_collectivefit_mwextinction.pdf')


Note that the change in coordinates is now much larger than in the previous case. We would like to know why this is, so we consider the differences in coordinates between the band and collective fits (with and without extinction).

In [None]:
def plot_comparison_of_coords(band_x, coll_x, band_y, coll_y):
    """Compare classification coordinates from band-by-band and collective fits

    The left panel plots collective-fit coordinates against band-fit
    coordinates for x (blue) and y (red), together with the 1:1 line
    (dash-dot grey) and a linear best-fit line per coordinate (dashed).
    A best-fit line is skipped when fewer than two targets have finite
    values for that coordinate. The right panel plots the per-target
    differences (band fit - collective fit) in x against those in y.

    Args:
        band_x (Series): x coordinate from band fitting
        coll_x (Series): x coordinate from collective fitting
        band_y (Series): y coordinate from band fitting
        coll_y (Series): y coordinate from collective fitting

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # Reference lines
    for axis in axes:
        axis.axvline(0, linestyle=':', c='grey')
        axis.axhline(0, linestyle=':', c='grey')

    ## Axis 0
    #########

    axes[0].scatter(band_x, coll_x, label='x Coord (Blue)', s=10)
    axes[0].scatter(band_y, coll_y, label='y Coord (Red)', s=10)

    # Span the lines over the union of the autoscaled x and y ranges so the
    # panel can be made square afterwards
    xlim = axes[0].get_xlim()
    ylim = axes[0].get_ylim()
    line_points = [min(xlim[0], ylim[0]), max(xlim[1], ylim[1])]

    # 1:1 reference line
    axes[0].plot(line_points, line_points, linestyle='-.', c='grey')

    # Linear best fit per coordinate using only targets with finite values.
    # np.polyfit needs at least two points for a line, so skip it otherwise.
    for band_vals, coll_vals, color in ((band_x, coll_x, 'C0'), (band_y, coll_y, 'C1')):
        is_finite = np.isfinite(band_vals + coll_vals)
        if np.count_nonzero(is_finite) < 2:
            continue

        best_fit = np.poly1d(np.polyfit(band_vals[is_finite], coll_vals[is_finite], 1))
        axes[0].plot(line_points, best_fit(line_points), linestyle='--', c=color)

    # Square the panel on the shared bounds and add a legend
    axes[0].set_xlim(*line_points)
    axes[0].set_ylim(*line_points)
    axes[0].set_xlabel('Band Fitting Parameter')
    axes[0].set_ylabel('Collective Fitting Parameter')
    axes[0].legend(loc='upper left')

    ## Axis 1
    #########

    axes[1].scatter(band_x - coll_x, band_y - coll_y, label='Band fit - Collective fit')
    axes[1].legend(loc='upper left')
    axes[1].set_xlabel(r'$\Delta$ x (blue)')
    axes[1].set_ylabel(r'$\Delta$ y (red)')

    return fig, axes

In [None]:
# Band vs. collective coordinates when both fits include MW extinction
fig, axes = plot_comparison_of_coords(
    band_x=all_coords['x_ext_band'],
    band_y=all_coords['y_ext_band'],
    coll_x=all_coords['x_ext_coll'],
    coll_y=all_coords['y_ext_coll'],
)
fig.suptitle('Collective vs. Band Fit Parameters With MW Extinction')
fig.savefig(fig_dir / 'collective_v_band_parameters_with_extinction.pdf')
plt.show()


In [None]:
# Band vs. collective coordinates when neither fit includes MW extinction
fig, axes = plot_comparison_of_coords(
    band_x=all_coords['x_noext_band'],
    band_y=all_coords['y_noext_band'],
    coll_x=all_coords['x_noext_coll'],
    coll_y=all_coords['y_noext_coll'],
)
fig.suptitle('Collective vs. Band Fit Parameters Without MW Extinction')
plt.show()



Since more positive coordinates indicate a more 91bg-like target, we see that not correcting for extinction biases classifications towards 91bg-like targets, likely because the uncorrected light curves have redder colors. We can validate this by visually inspecting the fit results.

In [None]:
def plot_collective_fits(fits, data, band_names, lambda_eff, suffix='_ext'):
    """Plot collective fits of the Hsiao and 91bg models for one target

    For each model, three light-curve plots are shown in turn: the fit to
    all bands, the fit to the blue band set and the fit to the red band set.
    The blue / red data split comes from ``utils.split_data`` using the
    redshift stored in the model's 'all' fit.

    Args:
        fits  (DataFrame): Tabulated fit parameters indexed by
            (obj_id, source, band set) with columns *params + suffix
        data      (Table): Observed light curve data for the desired target
        band_names (iter): List of all bands available in the survey
        lambda_eff (iter): The effective wavelength of each band in band_names
        suffix      (str): Suffix appended to parameter names when indexing
            columns ('_ext' or '_noext')
    """

    obj_id = data.meta['obj_id']
    dust_kw = dict(effects=[sncosmo.F99Dust()], effect_names=['mw'], effect_frames=['obs'])

    bg_source = sncosmo.get_source('sn91bg', version='hsiao_phase')
    sn91bg = sncosmo.Model(bg_source, **dust_kw)
    hsiao = sncosmo.Model('hsiao_x1', **dust_kw)

    for model in (hsiao, sn91bg):
        source_name = model.source.name
        redshift = fits.loc[obj_id, source_name, 'all']['z' + suffix]
        blue_data, red_data = utils.split_data(data, band_names, lambda_eff, redshift, cutoff=float('inf'))

        tables_by_set = {'all': data, 'blue': blue_data, 'red': red_data}
        for band_set, data_table in tables_by_set.items():
            fit_results = fits.loc[obj_id, source_name, band_set]

            params = {}
            for param_name in model.param_names:
                params[param_name] = fit_results[param_name + suffix]
            model.update(params)

            sncosmo.plot_lc(data_table, model)
            plt.show()


In [None]:
# Target whose band-fit and collective-fit coordinates (with MW extinction) differ the most
demo_id_3 = all_coords.sep[np.isfinite(all_coords.sep)].idxmax()
demo_data_3 = sndata.des.sn3yr.get_data_for_id(demo_id_3)

print('ID:', demo_id_3)
print(all_coords.loc[demo_id_3])

print('Impact of extinction on band-by-band fitting')
plot_band_fits(band_fits, demo_data_3)
plt.show()


In [None]:
# Collective fits for the demo target using the parameters fitted with MW extinction
print('Collective fits with extinction included')
plot_collective_fits(coll_fits, demo_data_3, sndata.des.sn3yr.band_names, sndata.des.sn3yr.lambda_effective)


In [None]:
# Collective fits for the demo target using the parameters fitted without MW extinction
print('Collective fits without extinction included')
plot_collective_fits(coll_fits, demo_data_3, sndata.des.sn3yr.band_names, sndata.des.sn3yr.lambda_effective, suffix='_noext')
