# Salt2 Fit Results

This notebook inspects the distribution and correlation of Salt2 parameters fit to the SDSS SN sample.

#### Table of Contents:
1. <a href='#loading_data'>Loading the Data</a>: Read in data and drop bad fits.
1. <a href='#fit_results'>Investigating Fit Results</a>: Plots of various fit parameters.
1. <a href='#with_class'>Trends with Classification</a>: Plots of various fit parameters color coded by classification.


In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sncosmo
from astropy.table import Table
from corner import corner
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import NullFormatter
from scipy import stats
from sndata.sdss import sako18

# Output directory for every figure saved by this notebook (created if missing)
fig_dir = Path('notebook_figs', 'salt2_fits')
fig_dir.mkdir(parents=True, exist_ok=True)


## Loading the Data <a id='loading_data'></a>

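We load both the published fit results from Sako et al. 2018 (SALT 2.0) and the results of the fits that we ran ourselves (SALT 2.4). From the published table we keep only objects with an Ia-like classification (`SNIa`, `SNIa?`, `pSNIa`, or `zSNIa`).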

In [None]:
# Published classifications that we treat as SNe Ia
sn_ia_classes = ['SNIa', 'SNIa?', 'pSNIa', 'zSNIa']

# Suppress any warnings raised while reading / converting the tables (scoped to this block)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    salt2_fits = Table.read('../results/sdss_salt2_fits.ecsv').to_pandas(index='obj_id')
    sako_fits = sako18.load_table('master').to_pandas(index='CID')

    is_sn_ia = sako_fits['Classification'].isin(sn_ia_classes)
    sako_sneia = sako_fits.loc[is_sn_ia]

    # Left join: every one of our fits is kept; published columns are NaN for non-Ia objects
    combined = salt2_fits.join(sako_sneia)


In [None]:
# Preview the first rows of our SALT 2.4 fit results
salt2_fits.head(n=5)


Because the B - V color at $t_{max}$ is not included in the fit results, we compute it for each fitted SN from its best-fit SALT2 model.

In [None]:
def peak_color(model, params, band1='standard::B', band2='standard::V', magsys='AB', rest_frame=True):
    """Return the ``band1 - band2`` color of a model at the time ``params['t0']``

    Args:
        model  (sncosmo.Model): Model to evaluate (updated in place with ``params``)
        params          (dict): Model parameter values keyed by parameter name
        band1            (str): Name of the first band
        band2            (str): Name of the second band
        magsys           (str): Name of the magnitude system
        rest_frame      (bool): Evaluate with the redshift set to zero (rest-frame
            color) instead of at the fitted redshift (observer-frame color)

    Returns:
        The color, or ``np.nan`` if ``t0`` is NaN
    """

    t0 = params['t0']
    if np.isnan(t0):
        return np.nan

    model.update(params)
    if rest_frame:
        # At the fitted redshift, the observer-frame B and V bands sample bluer
        # rest-frame light as z increases, so that color is not comparable to
        # rest-frame quantities like c, b_max, or delta_15. Setting z = 0 gives
        # the rest-frame color; time t0 is still zero phase.
        model.set(z=0)

    return model.color(band1, band2, magsys, t0)


salt2 = sncosmo.Model('salt2')
salt2_fits['max_color'] = [
    peak_color(salt2, {p: row[p] for p in salt2.param_names})
    for _obj_id, row in salt2_fits.iterrows()
]


As a simple sanity check, we compare our fit parameters against the published values. We don't expect identical results since different SALT versions were used, but we do expect the two sets of values to be clearly correlated.

In [None]:
def compare_salt_versions(combined_data):
    """Scatter our SALT 2.4 fit parameters against the Sako et al. 2018 values

    One panel per parameter (stretch x1, then color c). Each panel has
    symmetric axis limits and a dashed grey one-to-one reference line.

    Args:
        combined_data (DataFrame): Table with our ``x1`` / ``c`` columns and the
            published ``x1SALT2flat`` / ``cSALT2flat`` columns

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    fig, axes = plt.subplots(1, 2, figsize=(18, 9))
    one_to_one = [-10, 10]

    # (published column, our column, panel title, axis half-width, y label or None)
    panels = [
        ('x1SALT2flat', 'x1', r'Stretch (x$_1$)', 6, 'SALT 2.4'),
        ('cSALT2flat', 'c', 'Color (c)', 1, None),
    ]

    for ax, (sako_col, our_col, title, limit, ylabel) in zip(axes, panels):
        ax.scatter(combined_data[sako_col], combined_data[our_col])
        ax.set_xlabel('Sako18 SALT 2.0')
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        ax.set_title(title)
        ax.set_ylim(-limit, limit)
        ax.set_xlim(-limit, limit)
        ax.plot(one_to_one, one_to_one, c='grey', linestyle='--')

    return fig, axes


In [None]:
# Bind the results to descriptive names; assigning to `_` would shadow
# IPython's last-output variable for the rest of the session.
version_fig, version_axes = compare_salt_versions(combined)


Some fits failed to converge and ran into the parameter boundary. We drop these from our sample, along with any fits that have missing (NaN) values.

In [None]:
# Columns used in the rest of the notebook; rows missing any of them are dropped
fit_columns = ['z', 'c', 'x1', 'x0', 't0', 'chisq', 'ndof', 'b_max', 'delta_15', 'max_color']

# Fits with c >= C_UPPER_CUT or x1 >= X1_UPPER_CUT are treated as having run into the fit boundary.
# NOTE(review): only the upper limits are cut; fits pinned at the lower limits are kept — confirm intended.
C_UPPER_CUT = .495
X1_UPPER_CUT = 4.95

# Optional cut on the reduced chi-squared (chisq / ndof). None disables it (e.g. set to 5 to enable).
MAX_REDUCED_CHISQ = None

complete_fits = salt2_fits[fit_columns].dropna()
keep = (complete_fits.c < C_UPPER_CUT) & (complete_fits.x1 < X1_UPPER_CUT)
if MAX_REDUCED_CHISQ is not None:
    # Mask is built on the same frame it indexes, so no index re-alignment is needed
    keep &= complete_fits.chisq / complete_fits.ndof <= MAX_REDUCED_CHISQ

good_fits = complete_fits[keep]
good_fits.head()


## Investigating Fit Results <a id='fit_results'></a>

We look at trends within the fitted model parameters.


In [None]:
def scatter_plot(x, y, c=None, contour=False, class_by=None, x_cutoff=0, y_cutoff=0, coords=None):
    """Create a scatter plot with bordering histograms

    The joint panel shows one of three views:
    - plain points,
    - points colored by ``c``, with a horizontal color bar below the plot,
    - points grouped by quadrant of a classification space (``class_by``).
    In the classification view, objects with x > x_cutoff and y < y_cutoff,
    or lying exactly on a cutoff, are not drawn.

    Args:
        x         (Series): x values to plot (must be a pandas Series indexed by
            object id when ``class_by`` is given)
        y         (Series): y values to plot (same requirements as ``x``)
        c        (ndarray): Optional values for a colorbar
        contour     (bool): Whether to overlay KDE contours of ``x`` and ``y``
        class_by     (str): Classify by "collective" or "band" fits
        x_cutoff   (float): The x cutoff for bg classifications
        y_cutoff   (float): The y cutoff for bg classifications
        coords (DataFrame): Classification coordinates with ``x_<class_by>``
            and ``y_<class_by>`` columns. Defaults to the global ``class_coords``.

    Returns:
        A seaborn JointGrid
        A matplotlib axis for the color bar, if c is provided
    """

    # x / y are passed as keywords: seaborn >= 0.12 rejects them positionally
    joint_plot = sns.jointplot(x=x, y=y, height=8)
    scatter_ax = joint_plot.ax_joint
    cbar_ax = None

    if c is not None:
        # Redraw the joint scatter with points colored by ``c``
        scatter_ax.cla()
        cmap = LinearSegmentedColormap.from_list('mycmap', ['blue', 'lightgrey', 'red'])
        points = scatter_ax.scatter(x, y, c=c, cmap=cmap, s=12)

        # Place a thin horizontal color bar below the joint axes
        ax_pos = scatter_ax.get_position()
        ax_pos.y0 -= .2
        ax_pos.y1 = ax_pos.y0 + .05
        cbar_ax = plt.gcf().add_axes(ax_pos)
        plt.colorbar(points, cax=cbar_ax, orientation='horizontal')

    elif class_by is not None:
        if coords is None:
            coords = class_coords  # Global defined in the "Trends with Classification" section

        x_coord = coords['x_' + class_by]
        y_coord = coords['y_' + class_by]
        sn91bg = coords[(x_coord > x_cutoff) & (y_coord > y_cutoff)].index
        normal = coords[(x_coord < x_cutoff) & (y_coord < y_cutoff)].index
        pec = coords[(x_coord < x_cutoff) & (y_coord > y_cutoff)].index

        # Ids missing from x / y reindex to NaN, which matplotlib does not draw
        scatter_ax.cla()
        scatter_ax.scatter(x.reindex(normal), y.reindex(normal),
                           color=sns.color_palette('Paired')[0], s=12, label='Q3')
        scatter_ax.scatter(x.reindex(sn91bg), y.reindex(sn91bg), color='C3', marker='^', label='Q1')
        scatter_ax.scatter(x.reindex(pec), y.reindex(pec), color='C1', marker='s', label='Q2')
        scatter_ax.legend()

    # Applied in every view (previously skipped when ``c`` was given)
    if contour:
        sns.kdeplot(x=x, y=y, ax=scatter_ax, color='black', alpha=.8)

    if cbar_ax is not None:
        return joint_plot, cbar_ax

    return joint_plot


In [None]:
# Stretch vs. color for all good fits, with KDE contours
stretch_color = scatter_plot(x=good_fits.c, y=good_fits.x1, contour=True)
joint_ax = stretch_color.ax_joint
joint_ax.set(xlabel='Color Excess (c)', ylabel=r'Stretch (x$_1$)')
plt.savefig(fig_dir / 'stretch_color.pdf')


In [None]:
# Stretch vs. peak B magnitude, colored by the SALT2 color parameter
stretch_bmax, cbar = scatter_plot(x=good_fits.b_max, y=good_fits.x1, c=good_fits.c)
joint_ax = stretch_bmax.ax_joint
joint_ax.set(xlabel='Peak B band mag', ylabel=r'Stretch (x$_1$)')
cbar.set_xlabel('Color Excess (c)')
plt.savefig(fig_dir / 'stretch_bmax.pdf')


In [None]:
# Color vs. peak B magnitude, colored by the SALT2 stretch parameter
color_bmax, cbar = scatter_plot(x=good_fits.b_max, y=good_fits.c, c=good_fits.x1)
joint_ax = color_bmax.ax_joint
joint_ax.set(xlabel='Peak B band mag', ylabel=r'Color Excess (c)')
cbar.set_xlabel(r'Stretch (x$_1$)')
plt.savefig(fig_dir / 'color_bmax.pdf')


In [None]:
# Peak B magnitude vs. decline rate, colored by the SALT2 color parameter
delta15_bmax, cbar = scatter_plot(x=good_fits.delta_15, y=good_fits.b_max, c=good_fits.c)
joint_ax = delta15_bmax.ax_joint
joint_ax.set_xlabel(r'$\Delta$ m(B)')
joint_ax.set_ylabel('Peak B band mag')

# Magnitudes: invert so brighter (more negative) values sit at the top
joint_ax.set_ylim(-21, -14.5)
joint_ax.invert_yaxis()

cbar.set_xlabel('Color Excess (c)')
plt.savefig(fig_dir / 'delta15_bmax.pdf')


## Trends with Classification <a id='with_class'></a>

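We repeat plots similar to those in the previous section, but color code the points by classification. Objects are grouped by the quadrant they occupy in the classification coordinate space, using the x and y cutoffs defined below.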

In [None]:
# Classification coordinates from the band-by-band and collective fits
band_path = '../results/band_fits/with_ext/sdss_sako18_simple_fit_class.ecsv'
collective_path = '../results/collective_fits/with_ext/sdss_sako18_simple_fit_class.ecsv'

band_class, collective_class = (
    Table.read(path).to_pandas(index='obj_id')
    for path in (band_path, collective_path)
)

# Columns present in both tables receive a `_collective` / `_band` suffix
class_coords = collective_class.join(band_class, lsuffix='_collective', rsuffix='_band')
class_coords.head()


In [None]:
# Quadrant cutoffs in the classification coordinate space (reused by later cells)
x_cut = .5
y_cut = 1

stretch_color_collective = scatter_plot(
    good_fits.c,
    good_fits.x1,
    class_by='collective',
    x_cutoff=x_cut,
    y_cutoff=y_cut,
)

# x holds the color (c) and y the stretch (x1); the labels were previously swapped
stretch_color_collective.ax_joint.set_xlabel('Color Excess (c)')
stretch_color_collective.ax_joint.set_ylabel(r'Stretch (x$_1$)')
plt.savefig(fig_dir / 'stretch_color_collective.pdf')
plt.show()


In [None]:
# Peak B magnitude vs. decline rate, grouped by classification quadrant
delta15_bmax_collective = scatter_plot(
    x=good_fits.delta_15,
    y=good_fits.b_max,
    class_by='collective',
    x_cutoff=x_cut,
    y_cutoff=y_cut,
)

joint_ax = delta15_bmax_collective.ax_joint
joint_ax.set(xlabel=r'$\Delta$ m(B)', ylabel='Peak B band mag')

# Magnitudes: invert so brighter (more negative) values sit at the top
joint_ax.set_ylim(-21, -14.5)
joint_ax.invert_yaxis()
plt.savefig(fig_dir / 'delta15_bmax_collective.pdf')


In [None]:
def density_estimation(m1, m2, num_points=100):
    """Evaluate a 2D Gaussian KDE of two samples on a regular grid

    The grid spans the range of ``m1`` along the first axis and the range of
    ``m2`` along the second axis, so each variable gets ``num_points`` samples
    over its own range.

    Args:
        m1 (array-like): Values of the first (x) variable
        m2 (array-like): Values of the second (y) variable, same length as ``m1``
        num_points (int): Number of grid points along each axis

    Returns:
        Grid x coordinates, grid y coordinates, and the KDE evaluated on the
        grid, each with shape ``(num_points, num_points)``
    """

    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)

    # Separate range per axis. A shared min/max over both variables would
    # stretch each axis over the other variable's values, leaving only a few
    # grid points where the data actually are (e.g. x1 vs. B_max).
    steps = complex(0, num_points)
    X, Y = np.mgrid[m1.min():m1.max():steps, m2.min():m2.max():steps]

    positions = np.vstack([X.ravel(), Y.ravel()])
    kernel = stats.gaussian_kde(np.vstack([m1, m2]))
    Z = np.reshape(kernel(positions), X.shape)
    return X, Y, Z

def plot_corner(data, columns, labels, ranges, class_by='collective', x_cutoff=0, y_cutoff=0,
                skip_contours=((4, 0),), coords=None):
    """Create a corner plot of fit parameters color coded by classification quadrant

    Lower-triangle panels overlay points grouped by quadrant of the
    classification space, plus KDE contours of all points. Diagonal panels
    show histograms with a dashed line at the mean and dotted lines at the
    mean +/- one standard deviation.

    Args:
        data      (DataFrame): The data to use when plotting, indexed by object id
        columns        (list): The names of the columns to plot
        labels         (list): Axis labels, one per column
        ranges         (list): Plotting limits for each column
        class_by        (str): Classify by "collective" or "band" fits
        x_cutoff      (float): The x cutoff for bg classifications
        y_cutoff      (float): The y cutoff for bg classifications
        skip_contours (iterable): (row, column) indices of lower-triangle panels
            on which no KDE contours are drawn
        coords    (DataFrame): Classification coordinates with ``x_<class_by>``
            and ``y_<class_by>`` columns. Defaults to the global ``class_coords``.

    Returns:
        A matplotlib figure
    """

    if coords is None:
        coords = class_coords  # Global defined in the "Trends with Classification" section

    x_coord = coords['x_' + class_by]
    y_coord = coords['y_' + class_by]
    normal = coords[(x_coord < x_cutoff) & (y_coord < y_cutoff)].index
    sn91bg = coords[(x_coord > x_cutoff) & (y_coord > y_cutoff)].index
    pec = coords[(x_coord < x_cutoff) & (y_coord > y_cutoff)].index

    # Create initial plot (histograms on the diagonal)
    plot_data = np.array(data[columns])
    fig = corner(
        plot_data,
        labels=labels,
        figsize=(7, 7),
        range=ranges,
        hist_kwargs={'histtype': 'bar', 'color': 'C0'},
        plot_contours=False,
        show_titles=True,
        plot_density=False,
    )

    # Axes form a num_cols x num_cols grid; axes[i, j] is row i, column j
    num_cols = len(columns)
    axes = np.reshape(fig.axes, (num_cols, num_cols))
    skip = {(int(row), int(col)) for row, col in skip_contours}

    # Overlay custom scatter plots and contours on the lower triangle
    for i, j in zip(*np.tril_indices(num_cols, -1)):
        axis = axes[i, j]
        x = data[columns[j]]
        y = data[columns[i]]

        # Label only one panel so the figure legend lists each group once
        # (labels starting with an underscore are excluded from legends)
        label_prefix = '' if (i == 1 and j == 0) else '_'
        axis.scatter(
            x.reindex(normal),
            y.reindex(normal),
            s=1,
            label=label_prefix + f'Quadrant 3 (x < {x_cutoff}, y < {y_cutoff})')

        if (int(i), int(j)) not in skip:
            X, Y, Z = density_estimation(x, y)
            axis.contour(X, Y, Z, colors='k', alpha=.5)

        axis.scatter(
            x.reindex(pec),
            y.reindex(pec),
            s=15,
            marker='^',
            color='C3',
            label=label_prefix + f'Quadrant 2 (x < {x_cutoff}, y > {y_cutoff})',
            zorder=9
        )

        axis.scatter(
            x.reindex(sn91bg),
            y.reindex(sn91bg),
            s=15,
            marker='s',
            color='C1',
            label=label_prefix + f'Quadrant 1 (x > {x_cutoff}, y > {y_cutoff})',
            zorder=10
        )

    # Mean (dashed) and mean +/- 1 std (dotted) lines on the diagonal histograms,
    # computed from ``data`` (previously read from the global ``good_fits``)
    diagonal_plots = axes[np.diag_indices(num_cols)]
    for col_name, axis in zip(columns, diagonal_plots):
        col_data = data[col_name]
        mean = col_data.mean()
        std = col_data.std()

        axis.axvline(mean, color='k', alpha=.8, linestyle='--')
        axis.axvline(mean + std, color='k', alpha=.8, linestyle=':')
        axis.axvline(mean - std, color='k', alpha=.8, linestyle=':')

    plt.subplots_adjust(wspace=.1, hspace=.1)
    return fig


In [None]:
fig = plot_corner(
    data=good_fits,
    columns=['x1', 'c', 'max_color', 'b_max', 'delta_15'],
    # Math-mode subscript so x1 renders like the other plots (r'x_1' shows a literal underscore)
    labels=[r'x$_1$', 'c', r'B - V ($t_{max}$)', r'B$_{max}$', r'$\Delta$ m(B)'],
    ranges=((-6, 6), (-.6, .6), (-.5, 2), (-16, -21), (0, 3)),
    # Same quadrant cutoffs as the classification scatter plots above
    x_cutoff=x_cut,
    y_cutoff=y_cut,
)

fig.legend(loc=(.32, .85))
fig.savefig(fig_dir / 'params_corner_plot.pdf')
