# Inspecting the Ported 91bg Model

#### Table of Contents:
1. <a href='#parameter_covarience'>Parameter Covariance</a>: Demonstrates the covariance between stretch and color for 91bg-like SNe
1. <a href='#spectral_evolution'>Spectral Evolution</a>: Plots spectra as a function of phase
1. <a href='#photometric_evolution'>Photometric Evolution</a>: Explore modeled flux and colors over time
1. <a href='#s2_parameters'>Comparison with Salt2 Parameters</a>: Comparison of simulated salt2 parameters against best fit 91bg parameters.
1. <a href='#hsiao'>Comparison with Hsiao</a>: Comparison of the light-curve morphologies for the `sn91bg` and `hsiao` models.
1. <a href='#fits_to_data'>Fits to Real Data</a>: Comparison of the model fits to CSP data.


In [None]:
import sys
from pathlib import Path

import numpy as np
import sncosmo
import sndata
from astropy.table import Table
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator, NullFormatter
from sndata.csp import dr3
from sndata.sdss import sako18
from tqdm import tqdm

sys.path.insert(0, '../')
from phot_class import fitting, models, utils
from phot_class.simulation import sncosmo_sims


In [None]:
# Directory that every figure in this notebook is written to
fig_dir = Path('./notebook_figs/model_inspection')
fig_dir.mkdir(parents=True, exist_ok=True)

# Make the custom models available through the sncosmo registry
models.register_sources(force=True)
bg_source = sncosmo.get_source('sn91bg')

# Register the CSP DR3 and SDSS (Sako et al. 2018) band passes
for _survey in (dr3, sako18):
    _survey.register_filters(True)

print('Source Summary:\n', bg_source, sep='\n')


## Parameter Covariance <a id='parameter_covarience'></a>

We consider the covariance between stretch and color in 91bg-like supernovae. Note that our 91bg model does not intrinsically account for this covariance.


In [None]:
def plot_hexbin(x, y, xlim, ylim, fig_size=(8, 8), bins=50):
    """Plot a hexbin of ``y`` against ``x`` with marginal histograms

    A new figure is created on every call, so repeated calls never draw on
    top of a previously created figure.

    Args:
      x        (ndarray): Values plotted along the x axis
      y        (ndarray): Values plotted along the y axis
      xlim       (tuple): x limits for the plot
      ylim       (tuple): y limits for the plot
      fig_size   (tuple): x and y dimensions for the figure
      bins         (int): Number of bins for the hexbin grid and histograms

    Returns:
        A matplotlib figure
        The axis of the main subplot
        The axis of the x histogram
        The axis of the y histogram
    """

    # Axis coordinates in figure fractions: main panel bottom-left,
    # histograms above it and to its right
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    # Always create a fresh figure. ``plt.figure(1, ...)`` would return an
    # already-open figure number 1 (ignoring ``fig_size``) and stack new axes on it.
    fig = plt.figure(figsize=fig_size)
    ax_main = fig.add_axes(rect_scatter)
    ax_hist_x = fig.add_axes(rect_histx)
    ax_hist_y = fig.add_axes(rect_histy)

    # Remove tick labels that would duplicate the main panel's labels
    nullfmt = NullFormatter()
    ax_hist_x.xaxis.set_major_formatter(nullfmt)
    ax_hist_y.yaxis.set_major_formatter(nullfmt)

    # The hexbin
    ax_main.hexbin(x, y, gridsize=(bins, bins), cmap='Blues')
    ax_main.set_xlim(xlim)
    ax_main.set_ylim(ylim)

    # The histograms, marked with the mean and +/- one standard deviation.
    # Only the x-histogram lines carry labels so a figure legend lists them once.
    avg_x = np.average(x)
    std_x = np.std(x)
    ax_hist_x.hist(x, bins=bins)
    ax_hist_x.axvline(avg_x, color='black', linestyle='--', label='Average')
    ax_hist_x.axvline(avg_x + std_x, color='black', linestyle=':')
    ax_hist_x.axvline(avg_x - std_x, color='black', linestyle=':',
                      label='Standard Deviation')

    avg_y = np.average(y)
    std_y = np.std(y)
    ax_hist_y.hist(y, bins=bins, orientation='horizontal')
    ax_hist_y.axhline(avg_y, color='black', linestyle='--')
    ax_hist_y.axhline(avg_y + std_y, color='black', linestyle=':')
    ax_hist_y.axhline(avg_y - std_y, color='black', linestyle=':')

    # Keep the marginal histograms aligned with the main panel
    ax_hist_x.set_xlim(ax_main.get_xlim())
    ax_hist_y.set_ylim(ax_main.get_ylim())

    return fig, ax_main, ax_hist_x, ax_hist_y
    

In [None]:
# Draw a large, untruncated sample of (stretch, color) pairs.
# ``size`` is given as an int because numpy's random samplers raise a
# TypeError for float sizes such as ``1e6``.
# NOTE(review): confirm how ``bg_stretch_color`` forwards ``size``.
stretch, color = sncosmo_sims.bg_stretch_color(
    size=1_000_000,
    min_stretch=-np.inf,
    min_color=-np.inf,
    max_stretch=np.inf,
    max_color=np.inf)


In [None]:
stretch_center = sncosmo_sims.AVG_STRETCH
color_center = sncosmo_sims.AVG_COLOR
xlim = (stretch_center - .4, stretch_center + .4)
ylim = (color_center - .6, color_center + .6)
fig, *axes = plot_hexbin(stretch, color, xlim, ylim, bins=30)
main_axis = axes[0]

# Dash-dotted lines mark the min / max of the source's ``_stretch`` and
# ``_color`` values (labelled 'Model Limits')
for stretch_edge in (min(bg_source._stretch), max(bg_source._stretch)):
    main_axis.axvline(stretch_edge, linestyle='-.')

main_axis.axhline(min(bg_source._color), linestyle='-.')
main_axis.axhline(max(bg_source._color), linestyle='-.', label='Model Limits')

main_axis.set_xlabel('Stretch (x1)', fontsize=12)
main_axis.set_ylabel('Color (c)', fontsize=12)
fig.legend(fontsize=12, loc='upper right', bbox_to_anchor=(1.05, .9))

plt.savefig(fig_dir / '91bg_x1_c_covariance.pdf')
plt.show()


## Spectral Evolution <a id='spectral_evolution'></a>

We plot the spectra over a range of phases in arbitrary flux units.

In [None]:
def plot_bg_spectral_evolution(phase_range, source, offset):
    """Plot a model's spectra at each phase, stacked with a vertical offset

    Args:
        phase_range (tuple): Arguments for ``np.arange`` giving the phases to
                             plot (the stop value is exclusive)
        source     (Source): Sncosmo Source to use for flux modeling
        offset      (float): The vertical offset between each spectrum

    Returns:
        A matplotlib figure
        The left (flux) axis of the figure
    """

    phase_arr = np.arange(*phase_range)
    wave_arr = np.arange(source.minwave(), source.maxwave())
    flux_arr = source.flux(phase_arr, wave_arr)

    fig, axis_left = plt.subplots(1, 1, figsize=(6, 10))
    plotted_phases = []
    spectrum_y_coords = []
    for num_spectra, (phase, flux) in enumerate(zip(phase_arr, flux_arr)):
        if not phase_range[0] <= phase < phase_range[1]:
            continue

        flux_with_offset = flux + num_spectra * offset
        # The right-hand phase label sits at the red end of each spectrum
        spectrum_y_coords.append(flux_with_offset[-1])
        plotted_phases.append(phase)
        axis_left.plot(wave_arr, flux_with_offset, label=str(phase))

    # Si II and Fe II
    si1 = axis_left.axvline(4130, linestyle='--', color='grey', alpha=1)
    si2 = axis_left.axvline(5972, linestyle='--', color='grey', alpha=1)
    si3 = axis_left.axvline(6355, linestyle='--', color='grey', alpha=1)
    fe1 = axis_left.axvline(4950, linestyle=':', color='grey', alpha=1)

    # Format axes
    axis_left.set_xlim(min(wave_arr), max(wave_arr))
    axis_left.set_xlabel(r'Wavelength ($\AA$)', fontsize=14)
    axis_left.set_title(f'{min(phase_arr)} to {max(phase_arr)} days', fontsize=14)
    axis_left.set_ylabel(r'Flux + C', fontsize=14)

    axis_right = axis_left.twinx()
    axis_right.set_ylim(axis_left.get_ylim())
    axis_right.set_yticks(spectrum_y_coords)
    # One label per plotted spectrum keeps ticks and labels paired
    axis_right.set_yticklabels([str(phase) for phase in plotted_phases])
    axis_right.set_ylabel('Phase (Days)', fontsize=14)

    # ``Tick.label`` is deprecated (removed in newer matplotlib) and only
    # reaches label1, which is hidden on the right-hand twin axis.
    # ``tick_params`` sizes the visible labels on both sides.
    axis_left.tick_params(axis='both', labelsize=14)
    axis_right.tick_params(axis='y', labelsize=14)

    fig.legend(
        (si1, si2, si3, fe1),
        (r'Si II ($4130 \AA$)',
         r'Si II ($5972 \AA$)',
         r'Si II ($6355 \AA$)',
         r'Fe II ($4950 \AA$)'),
        fontsize=14, framealpha=1)

    fig.tight_layout()
    return fig, axis_left


In [None]:
spectral_fig, _spectral_axis = plot_bg_spectral_evolution(
    phase_range=(-10, 11), source=bg_source, offset=0.06)
spectral_fig.savefig(fig_dir / '91bg_spectral_evolution.pdf')
plt.show()


In [None]:
salt2 = sncosmo.get_source('salt2')
salt2_fig, _salt2_axis = plot_bg_spectral_evolution(
    phase_range=(-10, 11), source=salt2, offset=2e-13)
salt2_fig.savefig(fig_dir / 'salt2_spectral_evolution.pdf')
plt.show()


## Photometric Evolution <a id='photometric_evolution'></a>

We explore the predicted flux and colors of our ported 91bg model. We arbitrarily choose to use the CSP band passes.


In [None]:
# A set of fiducial band passes.
unique_bands = tuple(f'csp_dr3_{band}' for band in 'ugriBV')


In [None]:
def plot_flux_for_param(source, param_name, param_value):
    """Plot modeled light curves in each fiducial band for several param values

    One subplot is drawn per band in the module-level ``unique_bands``.

    Args:
        source       (Source): An sncosmo source class
        param_name      (str): Name of the param to vary
        param_value (iterable): Values of the param to plot

    Returns:
        A matplotlib figure
        A 2x3 array with the axes of the figure
    """

    # Materialize once: the values are reused for every band, so a
    # generator would otherwise be exhausted after the first subplot
    param_values = tuple(param_value)

    model = sncosmo.Model(source)
    phase = np.arange(model.mintime(), model.maxtime())
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    for band, axis in zip(unique_bands, axes.flatten()):
        for p_val in param_values:
            model.update({param_name: p_val})
            flux = model.bandflux(band, phase, zp=25, zpsys='AB')
            axis.plot(phase, flux, label=f'{param_name} = {p_val}')

        # Per-axis formatting only needs to happen once per band
        axis.set_xlabel('Time')
        axis.set_ylabel('Flux')
        axis.set_title(band)
        axis.legend()

    return fig, axes


In [None]:
plot_flux_for_param(
    source=bg_source, param_name='c', param_value=(0.0, 0.25, 0.5, 0.75, 1))
plt.savefig(fig_dir.joinpath('91bg_flux_color_evolution.pdf'))
plt.show()


In [None]:
plot_flux_for_param(
    source=bg_source, param_name='x1', param_value=(0.65, 0.85, 1.05, 1.25))
plt.savefig(fig_dir.joinpath('91bg_flux_stretch_evolution.pdf'))
plt.show()


In [None]:
def plot_color_for_param(source, colors, param_name, param_value):
    """Plot modeled colors over time for several values of one param

    Args:
        source       (Source): An sncosmo source class
        colors        (tuple): Pairs of band names ``(band1, band2)``; one
                               subplot (up to six) is drawn per pair
        param_name      (str): Name of the param to vary
        param_value (iterable): Values of the param to plot

    Returns:
        A matplotlib figure
        A 2x3 array with the axes of the figure
    """

    # Reused for every color pair, so materialize in case it is a generator
    param_values = tuple(param_value)

    model = sncosmo.Model(source)
    phase = np.arange(model.mintime(), model.maxtime())
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True)
    for (band1, band2), axis in zip(colors, axes.flatten()):
        for p_val in param_values:
            model.update({param_name: p_val})
            color = model.color(band1, band2, 'ab', phase)
            axis.plot(phase, color, label=f'{param_name} = {p_val}')

        axis.set_xlabel('Time')
        # The plotted quantity is an AB color, not the varied parameter
        axis.set_ylabel('Color (AB mag)')
        axis.set_title(f'{band1} - {band2}')
        axis.legend()

    fig.tight_layout()
    return fig, axes


In [None]:
# Each fiducial band paired with the next one, e.g. (u, g), (g, r), ...
colors = tuple(zip(unique_bands[:-1], unique_bands[1:]))
plot_color_for_param(
    source=bg_source, colors=colors, param_name='c',
    param_value=(0.0, 0.25, 0.5, 0.75, 1))
plt.savefig(fig_dir.joinpath('91bg_color_evolution.pdf'))
plt.show()


In [None]:
plot_color_for_param(
    source=bg_source, colors=colors, param_name='x1',
    param_value=(0.65, 0.85, 1.05, 1.25))
plt.savefig(fig_dir.joinpath('91bg_stretch_evolution.pdf'))
plt.show()


## Comparison with Salt2 Parameters <a id='s2_parameters'></a>


We consider the correlation between salt2 and 91bg fit parameters, specifically between stretch and color. We do this by simulating light curves with our 91bg model, fitting each light curve with salt2, and then plotting the relationship between the input (91bg) and output (salt2) parameters.

First we create a table of simulated observations that defines the cadence of our simulated light curves.

In [None]:
# The same 3-day cadence from -10 to 27 days is repeated for each LSST band
epochs = np.arange(-10, 30, 3).tolist()
time = epochs * 6
bands = np.repeat(['LSST' + band for band in 'ugrizy'], len(time) // 6)

table_length = len(time)
observations = Table({
    'time': time,
    'band': bands,
    'gain': np.ones(table_length),
    'skynoise': np.zeros(table_length),
    'zp': np.full(table_length, 27.5),
    'zpsys': np.full(table_length, 'ab'),
})

observations


Next, we generate a dictionary of parameters for each light curve we want to simulate, and run the simulations.

In [None]:
# Grid of 91bg (stretch, color) values to simulate, all at redshift zero
z = 0
dx = .1

# ``np.arange(start, stop + dx, dx)`` with a float step can gain or drop its
# end point through round-off, so each grid is built from an explicit
# number of points instead
n_stretch = int(round((1.25 - 0.65) / dx)) + 1
n_color = int(round((1 - 0) / dx)) + 1
stretch = np.linspace(0.65, 1.25, n_stretch)
color = np.linspace(0, 1, n_color)

# Stretch varies in the outer loop and color in the inner loop
bg_param_list = [
    {'z': z, 'x1': np.round(x1, 2), 'c': np.round(c, 2)}
    for x1 in stretch
    for c in color
]

sn91bg = sncosmo.Model('sn91bg')
light_curves = sncosmo.realize_lcs(observations, sn91bg, bg_param_list)


Each of these light curves is then fit with the salt2 model (the fit results are kept in memory only; nothing is cached to file). For convenience, we also typecast the salt2 and 91bg model parameters from lists of dictionaries to astropy tables.

In [None]:
salt2 = sncosmo.Model('salt2')
salt2.set(z=z)

bounds = {'c': [-3, 3], 'x1': [-6, 5]}  # Keeps the fit from running away
s2_param_list = []
for lc in tqdm(light_curves):
    result, fm = sncosmo.fit_lc(lc, salt2, ['x0', 'x1', 'c'], bounds=bounds)
    s2_param_list.append(dict(zip(result.param_names, result.parameters)))

s2_parameters = Table(rows=s2_param_list)
bg_parameters = Table(rows=bg_param_list)


Finally, we plot the relationship between the parameters.

In [None]:
def scatter_plot_parameters(simulated_params, fit_params):
    """Scatter fitted parameters against simulated parameters

    The left panel compares colors and is colored by the fitted stretch. The
    right panel compares stretches and is colored by the fitted color.

    Args:
        simulated_params (Table): A table of simulated parameter values
        fit_params       (Table): A table of fitted parameter values whose
                                  rows align with ``simulated_params``

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    color_points = axes[0].scatter(
        simulated_params['c'], fit_params['c'], c=fit_params['x1'])
    axes[0].set_xlabel('SN91bg Color', fontsize=14)
    axes[0].set_ylabel('Salt2 Color', fontsize=14)
    # Without a colorbar the point colors cannot be interpreted
    fig.colorbar(color_points, ax=axes[0], label='Salt2 Stretch')

    stretch_points = axes[1].scatter(
        simulated_params['x1'], fit_params['x1'], c=fit_params['c'])
    axes[1].set_xlabel('SN91bg Stretch', fontsize=14)
    axes[1].set_ylabel('Salt2 Stretch', fontsize=14)
    fig.colorbar(stretch_points, ax=axes[1], label='Salt2 Color')

    fig.tight_layout()
    return fig, axes


In [None]:
scatter_plot_parameters(simulated_params=bg_parameters, fit_params=s2_parameters)
plt.savefig(fig_dir.joinpath('salt2_fit_of_91bg_scatter.pdf'))
plt.show()


In [None]:
def subplot_imshow(x, y, z, axis, **kwargs):
    """Draw a 2d grid of values with ``imshow`` spanning the range of x and y

    Args:
        x     (array): x values; only their min and max are used (extent)
        y     (array): y values; only their min and max are used (extent)
        z     (array): 2d array of values whose rows follow y and whose
                       columns follow x (drawn with ``origin='lower'``)
        axis   (Axis): matplotlib axis to plot on
        Any other keyword arguments (e.g. ``vmin``, ``vmax``, ``cmap``) are
        passed to ``axis.imshow`` and override the defaults below

    Returns:
        The image object returned by ``axis.imshow``
    """

    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)
    axis.set_xlim(x_min, x_max)
    axis.set_ylim(y_min, y_max)

    # Defaults that callers may override (previously passing e.g. ``cmap``
    # raised a duplicate-keyword TypeError)
    imshow_kwargs = {
        'origin': 'lower',
        'interpolation': 'bilinear',
        'aspect': 'auto',
        'cmap': 'bwr',
    }
    imshow_kwargs.update(kwargs)

    return axis.imshow(z, extent=[x_min, x_max, y_min, y_max], **imshow_kwargs)

def _grid_from_columns(x, y, values):
    """Arrange scattered (x, y, value) samples into a 2d grid for ``imshow``

    Rows follow the ascending unique ``y`` values and columns the ascending
    unique ``x`` values, matching ``imshow(..., origin='lower')``. Cells
    with no sample are NaN.
    """

    x_vals, x_idx = np.unique(np.asarray(x), return_inverse=True)
    y_vals, y_idx = np.unique(np.asarray(y), return_inverse=True)
    grid = np.full((len(y_vals), len(x_vals)), np.nan)
    grid[y_idx.ravel(), x_idx.ravel()] = np.asarray(values, dtype=float)
    return grid


def imshow_parameters(simulated_params, fit_params):
    """Plot fitted salt2 parameters over the grid of simulated 91bg parameters

    Args:
        simulated_params (Table): A table of simulated parameter values
        fit_params       (Table): A table of fitted parameter values whose
                                  rows align with ``simulated_params``

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    x = simulated_params['x1']
    y = simulated_params['c']

    # Grids are built from sorted unique values. Iterating a ``set`` gave an
    # arbitrary row order, and grouping by x1 put stretch on the rows even
    # though imshow rows correspond to the y (color) axis.
    z1 = _grid_from_columns(x, y, fit_params['c'])
    im1 = subplot_imshow(x, y, z1, axes[0])
    axes[0].set_xlabel('91bg Stretch', fontsize=14)
    axes[0].set_ylabel('91bg Color', fontsize=14)
    axes[0].set_title('Salt2 Color', fontsize=14)

    cbar_ax_1 = fig.add_axes([0.41, 0.15, 0.05, 0.7])
    fig.colorbar(im1, cax=cbar_ax_1)

    z2 = _grid_from_columns(x, y, fit_params['x1'])
    im2 = subplot_imshow(x, y, z2, axes[1])
    axes[1].set_xlabel('91bg Stretch', fontsize=14)
    axes[1].set_ylabel('91bg Color', fontsize=14)
    axes[1].set_title('Salt2 Stretch', fontsize=14)

    cbar_ax_2 = fig.add_axes([0.92, 0.15, 0.05, 0.7])
    fig.colorbar(im2, cax=cbar_ax_2)

    plt.subplots_adjust(wspace=0.9)

    return fig, axes


In [None]:
imshow_parameters(simulated_params=bg_parameters, fit_params=s2_parameters)
plt.savefig(fig_dir.joinpath('salt2_fit_of_91bg_color.pdf'))
plt.show()


## Comparison with Hsiao <a id='hsiao'></a>

Comparing with Salt2 allows us to think of terms like "stretch" and "color" on a common footing. However, in practice we are more interested in the Hsiao model since it has a broader model range. We here consider a custom version of the sncosmo `hsiao` model where a stretch parameter `x1` has been added.


In [None]:
hsiao_x1_source = sncosmo.get_source('hsiao_x1')
plot_flux_for_param(
    source=hsiao_x1_source, param_name='x1',
    param_value=(-.5, -.25, 0, .25, .5))
plt.savefig(fig_dir.joinpath('hsaio_flux_stretch_evolution.pdf'))
plt.show()


In [None]:
def compare_model_to_hsiao(source):
    """Compare a model's normalized light curves against a fitted Hsiao model

    A light curve is simulated from ``source`` on the cadence of the
    module-level ``observations`` table, and the ``hsiao_x1`` model is fit to
    it. The peak-normalized band fluxes of both are then plotted.

    Args:
        source (Source or Model): The sncosmo source or model to compare

    Returns:
        A matplotlib figure
        An array with the axes of the figure
    """

    # ``realize_lcs`` needs a Model; wrap a bare Source if one is given
    if isinstance(source, sncosmo.Model):
        model = source
    else:
        model = sncosmo.Model(source)

    # Simulate from the model being compared (previously this always used
    # the global ``sn91bg`` model, whatever ``source`` was)
    lc = sncosmo.realize_lcs(
        observations,
        model,
        [dict(zip(model.param_names, model.parameters))]
    )[0]

    _, fitted_hsiao = sncosmo.fit_lc(
        lc,
        sncosmo.Model(sncosmo.get_source('hsiao_x1')),
        ['amplitude', 'x1'],
        bounds={'x1': (-.5, .5)})

    bands = ('LSSTu', 'LSSTg', 'LSSTr', 'LSSTi', 'LSSTz', 'LSSTy')

    phase = np.arange(-20, 50)
    fig, axes = plt.subplots(2, 3, figsize=(12, 8), sharex=True)
    for band, axis in zip(bands, axes.flatten()):
        hflux_fitted = fitted_hsiao.bandflux(band, phase)
        sflux = source.bandflux(band, phase)

        # Normalize each light curve to its peak
        hflux_fitted = hflux_fitted / np.max(hflux_fitted)
        sflux = sflux / np.max(sflux)

        axis.plot(phase, hflux_fitted, label=f'Normal SNIa ({band[-1]}-band)')
        axis.plot(phase, sflux, label=f'SN91bg-like ({band[-1]}-band)')

        axis.set_ylim(0, axis.get_ylim()[-1])
        axis.legend(loc='upper right')

    for axis in axes[:, 0]:
        axis.set_ylabel('Normalized Flux', fontsize=14)

    for axis in axes[-1]:
        axis.set_xlabel('Phase', fontsize=14)

    # Lay out after the labels exist so they are accounted for
    fig.tight_layout()
    return fig, axes


In [None]:
fig, axes = compare_model_to_hsiao(source=sn91bg)
fig.savefig(fig_dir.joinpath('hsaio_vs_91bg_morphology.pdf'))
plt.show()


## Fits to Real Data <a id='fits_to_data'></a>

We fit both models to real data for a normal SN Ia (2012fr) and a 91bg-like SN (2005ke).

In [None]:
# CSP zero points
zero_point = {
    'csp_dr3_u': 12.986,
    'csp_dr3_g': 15.111,
    'csp_dr3_r': 14.902,
    'csp_dr3_i': 14.545,
    'csp_dr3_B': 14.328,
    'csp_dr3_V0': 14.437,
    'csp_dr3_V1': 14.393,
    'csp_dr3_V': 14.439,
    'csp_dr3_Y': 13.921,
    'csp_dr3_J': 13.836,
    'csp_dr3_Jrc2': 13.836,
    'csp_dr3_H': 13.51,
    'csp_dr3_Ydw': 13.77,
    'csp_dr3_Jdw': 13.866,
    'csp_dr3_Hdw': 13.502
}

# CSP AB offsets
instrument_offsets = {
    'csp_dr3_u': -0.06,
    'csp_dr3_g': -0.02,
    'csp_dr3_r': -0.01,
    'csp_dr3_i': 0,
    'csp_dr3_B': -0.13,
    'csp_dr3_V': -0.02,
    'csp_dr3_V0': -0.02,
    'csp_dr3_V1': -0.02,
    'csp_dr3_Y': 0.63,
    'csp_dr3_J': 0.91,
    'csp_dr3_Jrc2': 0.90,
    'csp_dr3_H': 1.34,
    'csp_dr3_Ydw': 0.64,
    'csp_dr3_Jdw': 0.90,
    'csp_dr3_Hdw': 1.34
}

def fetch_2012fr_data():
    """Fetch photometry for SN 2012fr from the Open Supernova Catalog (OSC)

    Only rows with OSC ``source == '1'`` in the u, g, r, i, z and B bands are
    kept. Magnitudes are shifted with the module-level ``instrument_offsets``
    and converted to flux using the module-level ``zero_point`` values.

    Returns:
        An astropy Table with columns time (OSC time minus 53000), band,
        flux, fluxerr, zp and zpsys, plus ``redshift`` and ``obj_id`` in
        its ``meta``
    """

    # Fetch data from the OSC
    obj_id = 'SN2012fr'
    osc_data = sndata.query_osc_photometry(obj_id)

    # Keep only optical data from Contreras et al. (2018)
    osc_data = osc_data[osc_data['source'] == '1']
    # NOTE(review): 'V' is not in this list, so the V -> V1 mapping below is
    # never exercised, and any 'z' rows would raise KeyError (no csp_dr3_z
    # zero point). Confirm the intended band selection.
    osc_data = osc_data[np.isin(osc_data['band'], list('ugrizB'))]

    # Drop columns with no data. Unmasked columns may have no ``mask``
    # attribute, so they are simply kept instead of raising AttributeError.
    empty_cols = [
        col.name for col in osc_data.itercols()
        if getattr(col, 'mask', None) is not None and np.all(col.mask)
    ]
    osc_data.remove_columns(empty_cols)

    # Format data for use with sncosmo
    data = Table(
        names=['time', 'band', 'flux', 'fluxerr', 'zp', 'zpsys'],
        dtype=[float, 'U10', float, float, float, 'U2']
    )

    for row in osc_data:
        band = 'csp_dr3_' + row['band']
        if band == 'csp_dr3_V':
            band += '1'

        zp = zero_point[band]
        ab_mag = float(row['magnitude']) + instrument_offsets[band]
        ab_mag_err = float(row['e_magnitude'])
        obs_time = float(row['time'])

        flux = 10 ** ((ab_mag - zp) / -2.5)
        # Linear error propagation: d(flux) = ln(10) / 2.5 * flux * d(mag)
        flux_err = np.log(10) * flux * ab_mag_err / 2.5

        data.add_row([obs_time, band, flux, flux_err, zp, 'AB'])

    data['time'] -= 53000
    data.meta['redshift'] = float(osc_data.meta['redshift'][0]['value'])
    data.meta['obj_id'] = obj_id
    return data


In [None]:
def convert_to_zp(data, zp):
    """Rescale ``flux`` and ``fluxerr`` so they are expressed at zero point ``zp``

    The input is left untouched; a converted copy is returned.

    Args:
        data (Table): Table with ``flux``, ``fluxerr`` and ``zp`` columns
        zp   (float): The target zero point

    Returns:
        A copy of ``data`` (same type) with rescaled fluxes and ``zp`` set
    """

    converted = data.copy()
    # Ratio of the flux at the new zero point to the flux at the current one
    scale = 10 ** ((converted['zp'] - zp) / -2.5)
    for column in ('flux', 'fluxerr'):
        converted[column] *= scale

    converted['zp'] = zp
    return converted


In [None]:
def compare_fits_to_data(data, is_bg=False):
    """Plot fits of the hsiao_x1 and 91bg models to blue and red band data

    Args:
        data (Table): The data to fit. ``meta`` must contain ``obj_id`` and
            ``redshift``. All rows are assumed to share one zero point
            (e.g. after ``convert_to_zp``); it is used for the magnitude axis.
        is_bg (bool): Whether the target is a 91bg. If so, the 91bg model is
            drawn solid and hsiao_x1 dashed; otherwise the reverse.

    Returns:
        A matplotlib figure
        An array of matplotlib axes
    """

    # Fit data in the red and blue bands
    fit_results = fitting.run_collective_fits(
        obj_id=data.meta['obj_id'],
        data=data,
        fit_func=sncosmo.fit_lc,
        band_names=dr3.band_names,
        lambda_eff=dr3.lambda_effective,
        priors_hs={'z': data.meta['redshift']},
        priors_bg={'z': data.meta['redshift']},
        kwargs_hs={'bounds': {'x1': (-.5, .5)}},
        kwargs_bg={'bounds': {'x1': (0.65, 1.25), 'c': (0, 1)}}
    )

    # Separate data and fit results into red and blue
    fit_results = fit_results.to_pandas().set_index(['band', 'source'])
    collective_params = fit_results.loc['blue'], fit_results.loc['red']
    collective_data = utils.split_data(
        data,
        dr3.band_names,
        dr3.lambda_effective,
        data.meta['redshift'])

    # Define models
    sn91bg_source = sncosmo.get_source('sn91bg', version='hsiao_phase')
    sn91bg = sncosmo.Model(source=sn91bg_source)
    hsiao = sncosmo.Model(source='hsiao_x1')

    # Some arrays used in plotting
    phase_range = min(data['time']) - 5, min(data['time']) + 75
    phase_arr = np.arange(*phase_range)

    # Zero point used for the flux <-> magnitude tick conversion
    ref_zp = data['zp'][0]

    # Line styles: the model matching the target type is drawn solid
    default_style = {'linestyle': '-', 'alpha': 1}
    alt_style = {'linestyle': '--', 'alpha': .7}
    hsiao_style = alt_style if is_bg else default_style
    sn91bg_style = default_style if is_bg else alt_style
    markers = '.^sv.^sv.^sv'

    fig, axes = plt.subplots(1, 2, figsize=(7, 7 / 2), sharey=True)
    plot_data = zip(axes, collective_params, collective_data)
    for axis, bandset_params, bandset_data in plot_data:
        hsiao_params = bandset_params.loc['hsiao_x1']
        hsiao.update({p: hsiao_params[p] for p in hsiao.param_names})

        sn91bg_params = bandset_params.loc['sn91bg']
        sn91bg.update({p: sn91bg_params[p] for p in sn91bg.param_names})

        for i, band_data in enumerate(bandset_data.group_by('band').groups):
            band_name = band_data['band'][0]
            zp = band_data['zp'][0]
            zpsys = band_data['zpsys'][0]
            band_color = f'C{i}'

            axis.scatter(
                x=band_data['time'],
                y=band_data['flux'],
                label=band_name.split('_')[-1],
                color=band_color,
                marker=markers[i % len(markers)],
                s=5
            )

            axis.errorbar(
                x=band_data['time'],
                y=band_data['flux'],
                yerr=band_data['fluxerr'],
                linestyle='',
                color=band_color,
            )

            hsiao_flux = hsiao.bandflux(band_name, phase_arr, zp, zpsys)
            axis.plot(phase_arr, hsiao_flux, color=band_color, **hsiao_style)

            sn91bg_flux = sn91bg.bandflux(band_name, phase_arr, zp, zpsys)
            axis.plot(phase_arr, sn91bg_flux, color=band_color, **sn91bg_style)

    # Fix the shared flux limits *before* copying them onto the magnitude
    # axes. Setting them afterwards misaligned the magnitude ticks.
    axes[0].set_ylim(bottom=0)

    # Magnitude tick positions expressed in flux units at ``ref_zp``
    max_mag = 17 if is_bg else 15
    major_mag_labels = np.arange(10, max_mag)
    major_flux_labels = 10 ** ((major_mag_labels - ref_zp) / -2.5)
    minor_mag_labels = np.arange(10, max_mag + 1, .25)
    minor_flux_labels = 10 ** ((minor_mag_labels - ref_zp) / -2.5)

    for axis in axes:
        axis.set_xlim(*phase_range)
        axis.set_xlabel('Observed Date (MJD - 53000)', fontsize=10, labelpad=10)
        axis.tick_params(labelsize=8)
        axis.xaxis.set_minor_locator(MultipleLocator(5))
        axis.xaxis.set_major_locator(MultipleLocator(10))
        axis.legend()

        mag_axis = axis.twinx()
        mag_axis.set_yticks(major_flux_labels)
        mag_axis.set_yticklabels(major_mag_labels)
        mag_axis.set_yticks(minor_flux_labels, minor=True)
        mag_axis.set_ylim(axis.get_ylim())

    # Only the right-most magnitude axis carries a label
    mag_axis.set_ylabel('Magnitude', rotation=270, fontsize=10, labelpad=20)
    axes[0].set_ylabel(f'Flux (ZP = {ref_zp})', labelpad=10)
    axes[1].set_yticklabels(axes[0].get_yticklabels())

    return fig, axes


In [None]:
norm_data = convert_to_zp(fetch_2012fr_data(), zp=15)
normal_fig, _ = compare_fits_to_data(norm_data, is_bg=False)
normal_fig.tight_layout()
normal_fig.savefig(fig_dir.joinpath('fit_of_normal_sn.pdf'))
plt.show()


In [None]:
# Bands left out of the 2005ke fit
excluded_bands = [
    'csp_dr3_Ydw',
    'csp_dr3_Y',
    'csp_dr3_H',
    'csp_dr3_J',
    'csp_dr3_V',
    'csp_dr3_V0',
]

bg_data = dr3.get_data_for_id('2005ke')
bg_data = bg_data[~np.isin(bg_data['band'], excluded_bands)]
bg_data['time'] -= 2453000.5

bg_data.remove_columns(['mag', 'mag_err'])
bg_data = convert_to_zp(bg_data, zp=15)

fig, axes = compare_fits_to_data(bg_data, is_bg=True)
plt.tight_layout()
plt.savefig(fig_dir.joinpath('fit_of_91bg.pdf'))
plt.show()
