# Inspecting the Ported 91bg Model

#### Table of Contents:
1. <a href='#model_evolution'>Model Evolution</a>: Explore modeled flux and colors over time
1. <a href='#parameter_covarience'>Parameter Covariance</a>: Demonstrates the covariance between stretch and color for 91bg-like supernovae


In [None]:
import sys
from copy import deepcopy

import numpy as np
import sncosmo
from matplotlib import pyplot as plt
from matplotlib.ticker import NullFormatter
from sndata.csp import dr3

sys.path.append('../')
from phot_class import models
from phot_class.simulation import sncosmo_sims


In [None]:
# Register the sources provided by phot_class.models, then fetch the 91bg
# source from sncosmo by name.
# NOTE(review): force=True presumably replaces any earlier registration of
# the same source name -- confirm against phot_class.models.
models.register_sources(force=True)
bg_source = sncosmo.get_source('sn91bg')

# Header, a blank line, then the summary produced by printing the source
print('Source Summary:\n', bg_source, sep='\n')


## Parameter Covariance <a id='parameter_covarience'></a>


In [None]:
def plot_hexbin(x, y, xlim, ylim, fig_size=(8, 8), bins=50):
    """Plot a hexbin of y against x with marginal histograms

    The histogram of ``x`` is drawn above the hexbin and the histogram of
    ``y`` to its right. Each histogram marks the average of its data with a
    dashed line and one standard deviation on either side with dotted lines.
    Only the lines on the x histogram carry legend labels.

    Args:
      x      (ndarray): Values plotted along the x axis
      y      (ndarray): Values plotted along the y axis
      xlim     (tuple): x limits for the plot
      ylim     (tuple): y limits for the plot
      fig_size (tuple): x and y dimensions for the figure
      bins       (int): Number of bins to use for the hexbin grid (in both
                        directions) and for each histogram

    Returns:
        A matplotlib figure
        The axis of the main subplot
        The axis of the x histogram
        The axis of the y histogram
    """

    # Define axis coordinates (fractions of the figure width / height)
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    # Instantiate a new figure on every call. Requesting a fixed figure
    # number would reuse an already open figure, stacking new axes on top of
    # the old ones and ignoring ``fig_size``.
    fig = plt.figure(figsize=fig_size)

    ax_main = fig.add_axes(rect_scatter)
    ax_hist_x = fig.add_axes(rect_histx)
    ax_hist_y = fig.add_axes(rect_histy)

    # Remove labels from histogram axes along the edge shared with the hexbin
    nullfmt = NullFormatter()
    ax_hist_x.xaxis.set_major_formatter(nullfmt)
    ax_hist_y.yaxis.set_major_formatter(nullfmt)

    # The hexbin
    ax_main.hexbin(x, y, gridsize=(bins, bins), cmap='Blues')
    ax_main.set_xlim(xlim)
    ax_main.set_ylim(ylim)

    # The histograms
    avg_x = np.average(x)
    std_x = np.std(x)
    ax_hist_x.hist(x, bins=bins)
    ax_hist_x.axvline(avg_x, color='black', linestyle='--', label='Average')
    ax_hist_x.axvline(avg_x + std_x, color='black', linestyle=':')
    ax_hist_x.axvline(avg_x - std_x, color='black', linestyle=':',
                      label='Standard Deviation')

    avg_y = np.average(y)
    std_y = np.std(y)
    ax_hist_y.hist(y, bins=bins, orientation='horizontal')
    ax_hist_y.axhline(avg_y, color='black', linestyle='--')
    ax_hist_y.axhline(avg_y + std_y, color='black', linestyle=':')
    ax_hist_y.axhline(avg_y - std_y, color='black', linestyle=':')

    # Keep the histograms aligned with the range shown in the main panel
    ax_hist_x.set_xlim(ax_main.get_xlim())
    ax_hist_y.set_ylim(ax_main.get_ylim())

    return fig, ax_main, ax_hist_x, ax_hist_y
    

In [None]:
# Draw stretch / color values with every bound set to +/- infinity, i.e.
# without restricting the range of either parameter.
# NOTE(review): size is passed as a float (1e6); confirm that
# bg_stretch_color accepts a non-integer size.
stretch, color = sncosmo_sims.bg_stretch_color(
    size=1e6,
    min_stretch=-np.inf, max_stretch=np.inf,
    min_color=-np.inf, max_color=np.inf,
)


In [None]:
# Window the plot around the simulation module's average stretch and color
xlim = (sncosmo_sims.AVG_STRETCH - .4, sncosmo_sims.AVG_STRETCH + .4)
ylim = (sncosmo_sims.AVG_COLOR - .6, sncosmo_sims.AVG_COLOR + .6)
fig, *axes = plot_hexbin(x=stretch, y=color, xlim=xlim, ylim=ylim, bins=30)

# axes[0] is the main hexbin panel
axes[0].set_xlabel('Stretch (x1)', fontsize=12)
axes[0].set_ylabel('Color (c)', fontsize=12)

# Mark the smallest and largest stretch / color values stored on the source.
# Only the last line is labelled so the legend gets a single entry for all four.
axes[0].axvline(min(bg_source._stretch), linestyle='-.')
axes[0].axvline(max(bg_source._stretch), linestyle='-.')
axes[0].axhline(min(bg_source._color), linestyle='-.')
axes[0].axhline(max(bg_source._color), linestyle='-.', label='Model Limits')

fig.legend(loc='upper right', bbox_to_anchor=(1.05, .9), fontsize=12)

plt.show()


## Model Evolution <a id='model_evolution'></a>

We explore the predicted flux and colors of our ported 91bg model. The choice of filters is arbitrary; we use the CSP DR3 bandpasses.


In [None]:
# Register the CSP DR3 filters so the band names below can be used by name.
# NOTE(review): the positional True is presumably a "force" flag that
# overwrites existing registrations -- confirm against sndata.
dr3.register_filters(True)

# A set of fiducial band passes.
# The order matters: the flux plots assign one subplot per entry in this
# order, and the color plots pair each band with the one listed after it.
unique_bands = (
    'csp_dr3_u',
    'csp_dr3_g',
    'csp_dr3_r',
    'csp_dr3_i',
    'csp_dr3_B',
    'csp_dr3_V'
)


In [None]:
def plot_flux_for_param(source, param_name, param_value):
    """Plot modeled light curves for several values of a model param

    One subplot is drawn for each band in the module-level ``unique_bands``
    tuple, filling a 2 x 3 grid (bands beyond the sixth are not plotted).
    Each subplot shows one flux curve per value in ``param_value``.

    Args:
        source        (Source): An sncosmo source class
        param_name       (str): Name of the param to vary
        param_value (iterable): Values of the param to plot
    """

    # Materialize the values so a one-shot iterable (e.g. a generator) is
    # not exhausted after the first band.
    param_values = tuple(param_value)

    model = sncosmo.Model(source)

    # Sample the model in unit steps over the time range it reports
    phase = np.arange(model.mintime(), model.maxtime())
    _, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    for band, axis in zip(unique_bands, axes.flatten()):
        for p_val in param_values:
            model.update({param_name: p_val})
            flux = model.bandflux(band, phase, zp=25, zpsys='AB')
            axis.plot(phase, flux, label=f'{param_name} = {p_val}')

        # Decorate each axis once, after all of its curves are drawn, so the
        # legend picks up every curve without being rebuilt per curve.
        axis.set_xlabel('Time')
        axis.set_ylabel('Flux')
        axis.set_title(band)
        axis.legend()

    plt.show()


In [None]:
# Flux in each band as the color parameter 'c' is varied
plot_flux_for_param(
    source=bg_source,
    param_name='c',
    param_value=(0.0, 0.25, 0.5, 0.75, 1),
)


In [None]:
# Flux in each band as the stretch parameter 'x1' is varied
plot_flux_for_param(
    source=bg_source,
    param_name='x1',
    param_value=(0.65, 0.85, 1.05, 1.25),
)


In [None]:
def plot_color_for_param(source, colors, param_name, param_value):
    """Plot modeled color curves for several values of a model param

    One subplot is drawn for each band pair in ``colors``, filling a 2 x 3
    grid (pairs beyond the sixth are not plotted). Each subplot shows one
    ``band1 - band2`` color curve per value in ``param_value``.

    Args:
        source        (Source): An sncosmo source class
        colors      (iterable): Pairs of band names as (band1, band2)
        param_name       (str): Name of the param to vary
        param_value (iterable): Values of the param to plot
    """

    # Materialize the values so a one-shot iterable (e.g. a generator) is
    # not exhausted after the first band pair.
    param_values = tuple(param_value)

    model = sncosmo.Model(source)

    # Sample the model in unit steps over the time range it reports
    phase = np.arange(model.mintime(), model.maxtime())
    _, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True)
    for (band1, band2), axis in zip(colors, axes.flatten()):
        for p_val in param_values:
            model.update({param_name: p_val})
            color = model.color(band1, band2, 'ab', phase)
            axis.plot(phase, color, label=f'{param_name} = {p_val}')

        # Decorate each axis once, after all of its curves are drawn, so the
        # legend picks up every curve without being rebuilt per curve.
        axis.set_xlabel('Time')
        axis.set_ylabel('Color')
        axis.set_title(f'{band1} - {band2}')
        axis.legend()

    plt.tight_layout()
    plt.show()


In [None]:
# Pair every band with the band listed immediately after it
colors = tuple(zip(unique_bands, unique_bands[1:]))

# Colors for each band pair as the color parameter 'c' is varied
plot_color_for_param(
    source=bg_source,
    colors=colors,
    param_name='c',
    param_value=(0.0, 0.25, 0.5, 0.75, 1),
)


In [None]:
# Colors for each band pair as the stretch parameter 'x1' is varied
plot_color_for_param(
    source=bg_source,
    colors=colors,
    param_name='x1',
    param_value=(0.65, 0.85, 1.05, 1.25),
)
