# Inspecting the Ported 91bg Model

#### Table of Contents:
1. <a href='#model_evolution'>Model Evolution</a>: Explore modeled flux and colors over time
1. <a href='#parameter_covarience'>Parameter Covariance</a>: Demonstrates the covariance between stretch and color for 91bgs
1. <a href='#s2_parameters'>Comparison with Salt2 Parameters</a>: Comparison of simulated 91bg parameters against best-fit salt2 parameters.


In [None]:
import sys
from pathlib import Path

import numpy as np
import sncosmo
from astropy.table import Table
from matplotlib import pyplot as plt
from matplotlib.ticker import NullFormatter
from sndata.csp import dr3
from tqdm import tqdm

sys.path.append('../')
from phot_class import models
from phot_class.simulation import sncosmo_sims


In [None]:
# Every figure produced below is written into this directory
fig_dir = Path('./notebook_figs')
fig_dir.mkdir(parents=True, exist_ok=True)

# Register the sources defined in phot_class.models, then fetch the 91bg one
models.register_sources(force=True)
bg_source = sncosmo.get_source('sn91bg')

print('Source Summary:', bg_source, sep='\n\n')


## Parameter Covariance <a id='parameter_covarience'></a>

We consider the covariance between stretch and color in 91bg-like supernovae. Note that our 91bg model does not (yet) account for this covariance.


In [None]:
def plot_hexbin(x, y, xlim, ylim, fig_size=(8, 8), bins=50):
    """Plot a hexbin of y against x with marginal histograms of x and y

    Each marginal histogram is annotated with the average (dashed line)
    and the average +/- one standard deviation (dotted lines).

    Args:
      x      (ndarray): Values plotted along the x axis
      y      (ndarray): Values plotted along the y axis
      xlim     (tuple): x limits for the plot
      ylim     (tuple): y limits for the plot
      fig_size (tuple): x and y dimensions for the figure
      bins       (int): Number of bins to use

    Returns:
        A matplotlib figure
        The axis of the main subplot
        The axis of the x histogram
        The axis of the y histogram
    """

    # Define axis coordinates (as fractions of the figure)
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    # Always create a new figure. Reusing figure number 1 would draw on top
    # of (and ignore ``fig_size`` for) a figure 1 that is still open.
    fig = plt.figure(figsize=fig_size)

    ax_main = fig.add_axes(rect_scatter)
    ax_hist_x = fig.add_axes(rect_histx)
    ax_hist_y = fig.add_axes(rect_histy)

    # Remove tick labels where the histograms share an axis with the hexbin
    nullfmt = NullFormatter()
    ax_hist_x.xaxis.set_major_formatter(nullfmt)
    ax_hist_y.yaxis.set_major_formatter(nullfmt)

    # The hexbin
    ax_main.hexbin(x, y, gridsize=(bins, bins), cmap='Blues')
    ax_main.set_xlim(xlim)
    ax_main.set_ylim(ylim)

    # The histograms. Only the x histogram lines carry labels so that a
    # figure-level legend lists each line style once.
    avg_x = np.average(x)
    std_x = np.std(x)
    ax_hist_x.hist(x, bins=bins)
    ax_hist_x.axvline(avg_x, color='black', linestyle='--', label='Average')
    ax_hist_x.axvline(avg_x + std_x, color='black', linestyle=':')
    ax_hist_x.axvline(avg_x - std_x, color='black', linestyle=':',
                      label='Standard Deviation')

    avg_y = np.average(y)
    std_y = np.std(y)
    ax_hist_y.hist(y, bins=bins, orientation='horizontal')
    ax_hist_y.axhline(avg_y, color='black', linestyle='--')
    ax_hist_y.axhline(avg_y + std_y, color='black', linestyle=':')
    ax_hist_y.axhline(avg_y - std_y, color='black', linestyle=':')

    # Keep the marginal histograms aligned with the main panel
    ax_hist_x.set_xlim(ax_main.get_xlim())
    ax_hist_y.set_ylim(ax_main.get_ylim())

    return fig, ax_main, ax_hist_x, ax_hist_y
    

In [None]:
# Draw a large sample of stretch / color values with all bounds disabled.
# ``size`` is given as an int: a float such as ``1e6`` is rejected as a size
# by numpy's random samplers if it is forwarded to them.
stretch, color = sncosmo_sims.bg_stretch_color(
    size=1_000_000,
    min_stretch=-np.inf,
    min_color=-np.inf,
    max_stretch=np.inf,
    max_color=np.inf)


In [None]:
xlim = (sncosmo_sims.AVG_STRETCH - .4, sncosmo_sims.AVG_STRETCH + .4)
ylim = (sncosmo_sims.AVG_COLOR - .6, sncosmo_sims.AVG_COLOR + .6)
fig, *axes = plot_hexbin(stretch, color, xlim, ylim, bins=30)
main_ax = axes[0]

# Mark the min / max of the model's ``_stretch`` and ``_color`` arrays
stretch_bounds = (min(bg_source._stretch), max(bg_source._stretch))
color_bounds = (min(bg_source._color), max(bg_source._color))
for bound in stretch_bounds:
    main_ax.axvline(bound, linestyle='-.')

main_ax.axhline(color_bounds[0], linestyle='-.')
main_ax.axhline(color_bounds[1], linestyle='-.', label='Model Limits')

main_ax.set_xlabel('Stretch (x1)', fontsize=12)
main_ax.set_ylabel('Color (c)', fontsize=12)
fig.legend(fontsize=12, loc='upper right', bbox_to_anchor=(1.05, .9))

plt.savefig(fig_dir / '91bg_x1_c_covariance.pdf')
plt.show()


## Model Evolution <a id='model_evolution'></a>

We explore the predicted flux and colors of our ported 91bg model. We arbitrarily choose to use the CSP band passes.


In [None]:
# Make the CSP DR3 filters available to sncosmo
dr3.register_filters(True)

# Fiducial CSP DR3 band passes used by the plotting helpers
unique_bands = ('csp_dr3_u', 'csp_dr3_g', 'csp_dr3_r',
                'csp_dr3_i', 'csp_dr3_B', 'csp_dr3_V')


In [None]:
def plot_flux_for_param(source, param_name, param_value, bands=None):
    """Plot modeled band fluxes over time for several values of one parameter

    One subplot is drawn per band pass, with one curve per parameter value.
    The figure is deliberately not shown here, so callers can still save it
    (e.g. with ``plt.savefig``) before calling ``plt.show``.

    Args:
        source    (Source): An sncosmo source class
        param_name   (str): Name of the param to vary
        param_value (iter): Values of the param to plot
        bands       (iter): Band passes to plot (Default: ``unique_bands``)

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    bands = unique_bands if bands is None else bands
    # Materialize once: the values are iterated again for every band
    param_value = tuple(param_value)

    model = sncosmo.Model(source)
    phase = np.arange(model.mintime(), model.maxtime())

    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    for band, axis in zip(bands, axes.flatten()):
        for p_val in param_value:
            model.update({param_name: p_val})
            # Fluxes are scaled to a zero point of 25 in the AB system
            flux = model.bandflux(band, phase, zp=25, zpsys='AB')
            axis.plot(phase, flux, label=f'{param_name} = {p_val}')

        axis.set_xlabel('Time')
        axis.set_ylabel('Flux')
        axis.set_title(band)
        axis.legend()

    fig.tight_layout()
    return fig, axes


In [None]:
color_values = (0.0, 0.25, 0.5, 0.75, 1)
plot_flux_for_param(bg_source, param_name='c', param_value=color_values)
flux_color_path = fig_dir / '91bg_flux_color_evolution.pdf'
plt.savefig(flux_color_path)
plt.show()


In [None]:
stretch_values = (0.65, 0.85, 1.05, 1.25)
plot_flux_for_param(bg_source, param_name='x1', param_value=stretch_values)
flux_stretch_path = fig_dir / '91bg_flux_stretch_evolution.pdf'
plt.savefig(flux_stretch_path)
plt.show()


In [None]:
def plot_color_for_param(source, colors, param_name, param_value):
    """Plot modeled colors over time for several values of one parameter

    One subplot is drawn per color, with one curve per parameter value.
    The figure is deliberately not shown here, so callers can still save it
    (e.g. with ``plt.savefig``) before calling ``plt.show``.

    Args:
        source    (Source): An sncosmo source class
        colors      (iter): Pairs of band names ``(band1, band2)``. The color
                            ``band1 - band2`` (AB magnitudes) is plotted for each.
        param_name   (str): Name of the param to vary
        param_value (iter): Values of the param to plot

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    # Materialize once: the values are iterated again for every color
    param_value = tuple(param_value)

    model = sncosmo.Model(source)
    phase = np.arange(model.mintime(), model.maxtime())

    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True)
    for (band1, band2), axis in zip(colors, axes.flatten()):
        for p_val in param_value:
            model.update({param_name: p_val})
            color = model.color(band1, band2, 'ab', phase)
            axis.plot(phase, color, label=f'{param_name} = {p_val}')

        axis.set_xlabel('Time')
        # The y axis holds the band1 - band2 color, not the varied parameter
        axis.set_ylabel('Color (AB mag)')
        axis.set_title(f'{band1} - {band2}')
        axis.legend()

    fig.tight_layout()
    return fig, axes


In [None]:
# Adjacent band pairs: (u, g), (g, r), ... plotted as u - g, g - r, ...
colors = tuple(zip(unique_bands[:-1], unique_bands[1:]))
plot_color_for_param(bg_source, colors, param_name='c', param_value=(0.0, 0.25, 0.5, 0.75, 1))
plt.savefig(fig_dir / '91bg_color_evolution.pdf')
plt.show()


In [None]:
stretch_grid = (0.65, 0.85, 1.05, 1.25)
plot_color_for_param(bg_source, colors, param_name='x1', param_value=stretch_grid)
stretch_color_path = fig_dir / '91bg_stretch_evolution.pdf'
plt.savefig(stretch_color_path)
plt.show()


## Comparison with Salt2 Parameters <a id='s2_parameters'></a>


We consider the correlation between salt2 and 91bg fit parameters - specifically between stretch and color. We do this by simulating light curves with our 91bg model, fitting each light curve with salt2, and then plotting the relationship between the input (91bg) and output (salt2) parameters.

First we create a table of simulated observations that defines the cadence of our simulated light curves.

In [None]:
# The same epochs (-10 to 29 in steps of 3) are observed once per band.
# Repeating the Python list (not multiplying an array) concatenates it 4 times.
epochs = np.arange(-10, 30, 3).tolist()
time = epochs * 4
bands = np.concatenate([np.full(len(epochs), 'sdss' + band) for band in 'ugri'])

table_length = len(time)
observations = Table({
    'time': time,
    'band': bands,
    'gain': np.ones(table_length),
    'skynoise': np.zeros(table_length),
    'zp': np.full(table_length, 27.5),
    'zpsys': np.full(table_length, 'ab'),
})

observations


Next, we generate a dictionary of parameters for each light curve we want to simulate, and run the simulations.

In [None]:
# Grid of 91bg parameters to simulate at a fixed redshift. Values are
# rounded to two decimals to strip floating point noise from np.arange.
z = 0
dx = .1
stretch = np.arange(0.65, 1.25 + dx, dx)
color = np.arange(0, 1 + dx, dx)
bg_param_list = [
    {'z': z, 'x1': np.round(x1, 2), 'c': np.round(c, 2)}
    for x1 in stretch
    for c in color
]

sn91bg = sncosmo.Model('sn91bg')
light_curves = sncosmo.realize_lcs(observations, sn91bg, bg_param_list)


Each of these light curves is then fit with the salt2 model (fit results are recomputed on every run rather than cached to file). We also convert the salt2 and 91bg model parameters from lists of dictionaries to astropy tables.

In [None]:
# Fit every simulated light curve with salt2 at the same fixed redshift
salt2 = sncosmo.Model('salt2')
salt2.set(z=z)

s2_param_list = []
bounds = {'c': [-5, 5], 'x1': [-5, 5]}  # Keep the fitter from running away
for light_curve in tqdm(light_curves):
    fit_result, _ = sncosmo.fit_lc(
        light_curve, salt2, ['x0', 't0', 'x1', 'c'], bounds=bounds)
    s2_param_list.append(dict(zip(fit_result.param_names, fit_result.parameters)))

s2_parameters = Table(rows=s2_param_list)
bg_parameters = Table(rows=bg_param_list)


Finally, we plot the relationship between the parameters.

In [None]:
def scatter_plot_parameters(simulated_params, fit_params):
    """Scatter fitted parameter values against simulated parameter values

    Draws one panel for color (``c``) and one for stretch (``x1``). Row i of
    ``fit_params`` is assumed to be the fit of simulation i.

    Args:
        simulated_params (Table): A table of simulated parameter values
        fit_params       (Table): A table of fitted parameter values

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for axis, column, title in zip(axes, ('c', 'x1'), ('Color', 'Stretch')):
        axis.scatter(simulated_params[column], fit_params[column])
        axis.set_title(title)
        axis.set_xlabel('Simulated')
        axis.set_ylabel('Fitted')

    # Previously documented but never returned
    return fig, axes


In [None]:
# Simulated values come from the 91bg model, fitted values from salt2
scatter_plot_parameters(simulated_params=bg_parameters, fit_params=s2_parameters)
plt.savefig(fig_dir / 'salt2_fit_of_91bg_scatter.pdf')
plt.show()


In [None]:
def subplot_cotours(x, y, z, axis, **kwargs):
    """Display a 2d grid of values as an image on an existing axis

    Despite the name, no contour lines are drawn. The grid is rendered with
    ``axis.imshow`` using bilinear interpolation and (by default) the
    diverging ``'bwr'`` colormap.

    Args:
        x     (array): x values; only their min and max are used (image extent)
        y     (array): y values; only their min and max are used (image extent)
        z     (array): A 2d array of values, rows along y and columns along x
        axis   (Axis): matplotlib axis to plot on
        Any other keyword arguments for ``axis.imshow`` (e.g. ``vmin``,
        ``vmax``, or ``cmap`` to override the default colormap)

    Returns:
        The image object returned by ``axis.imshow``
    """

    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)
    axis.set_xlim(x_min, x_max)
    axis.set_ylim(y_min, y_max)

    # Let callers override the colormap without a duplicate-keyword TypeError
    kwargs.setdefault('cmap', 'bwr')

    im = axis.imshow(
        z,
        origin='lower',  # first row of z is drawn at the bottom (y_min)
        interpolation='bilinear',
        extent=[x_min, x_max, y_min, y_max],
        aspect='auto',
        **kwargs
    )

    return im

def grid_fit_values(simulated_params, fit_params, column):
    """Arrange fitted values on the grid of simulated (x1, c) values

    Row i of ``fit_params`` is assumed to be the fit of simulation i in
    ``simulated_params``.

    Args:
        simulated_params (Table): Simulated parameters with ``x1`` and ``c`` columns
        fit_params       (Table): Fitted parameters, row-aligned with ``simulated_params``
        column             (str): Name of the fitted column to arrange

    Returns:
        The sorted unique simulated stretch values (grid columns / x axis)
        The sorted unique simulated color values (grid rows / y axis)
        A 2d array of shape (number of colors, number of stretches) holding the
          fitted values; grid cells without a simulation are NaN
    """

    sim_stretch = np.asarray(simulated_params['x1'])
    sim_color = np.asarray(simulated_params['c'])
    stretch_vals = np.unique(sim_stretch)  # np.unique returns sorted values
    color_vals = np.unique(sim_color)

    grid = np.full((len(color_vals), len(stretch_vals)), np.nan)
    row_idx = np.searchsorted(color_vals, sim_color)
    col_idx = np.searchsorted(stretch_vals, sim_stretch)
    grid[row_idx, col_idx] = np.asarray(fit_params[column], dtype=float)
    return stretch_vals, color_vals, grid


def imshow_parameters(simulated_params, fit_params):
    """Plot fitted parameters against simulated parameters as a color plot

    Both panels are images over the simulated (stretch, color) grid. The left
    panel is colored by the fitted color, the right by the fitted stretch.

    Args:
        simulated_params (Table): A table of simulated parameter values
        fit_params       (Table): A table of fitted parameter values

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Grid rows follow the sorted simulated color (y axis) and columns the
    # sorted simulated stretch (x axis), matching the image extent. Building
    # rows by iterating over a ``set`` gives an arbitrary row order.
    x, y, z1 = grid_fit_values(simulated_params, fit_params, 'c')
    im1 = subplot_cotours(x, y, z1, axes[0])
    axes[0].set_xlabel('Simulated Stretch')
    axes[0].set_ylabel('Simulated Color')
    axes[0].set_title('Fitted Color')

    cbar_ax_1 = fig.add_axes([0.41, 0.15, 0.05, 0.7])
    fig.colorbar(im1, cax=cbar_ax_1)

    _, _, z2 = grid_fit_values(simulated_params, fit_params, 'x1')
    im2 = subplot_cotours(x, y, z2, axes[1])
    axes[1].set_xlabel('Simulated Stretch')
    axes[1].set_ylabel('Simulated Color')
    axes[1].set_title('Fitted Stretch')

    cbar_ax_2 = fig.add_axes([0.92, 0.15, 0.05, 0.7])
    fig.colorbar(im2, cax=cbar_ax_2)

    # Leave room for the manually placed colorbar between the panels
    plt.subplots_adjust(wspace=0.9)

    return fig, axes


In [None]:
# Simulated values come from the 91bg model, fitted values from salt2
imshow_parameters(simulated_params=bg_parameters, fit_params=s2_parameters)
plt.savefig(fig_dir / 'salt2_fit_of_91bg_color.pdf')
plt.show()
