# Fit Inspection

This notebook inspects fit results and allows the reader to manually specify new priors for any bad fits. Note that manually specifying new priors will **NOT** update the fit results saved to file; new fit results can only be tabulated by re-running the analysis pipeline.


In [None]:
import sys
from copy import deepcopy

import sncosmo
from matplotlib import pyplot as plt
from sndata.csp import dr3
from sndata.des import sn3yr
from sndata.sdss import sako18

sys.path.insert(0, '../')
from analysis_pipeline import get_priors, get_fit_results, save_priors
from analysis_pipeline.lc_fitting import split_data, fit_lc
from analysis_pipeline.models import register_sources

register_sources(force=True)


We first ensure that the data are available on the local machine and that the filters for each survey are registered with SNCosmo.

In [None]:
# Download data and register SNCosmo filters
for data in (dr3, sn3yr, sako18):
    data.download_module_data()
    data.register_filters(force=True)

# Define models
salt2_4 = sncosmo.Model('salt2')
sn91bg_p = sncosmo.Model(source=sncosmo.get_source('sn91bg', version='salt2_phase'))
sn91bg_c = sncosmo.Model(source=sncosmo.get_source('sn91bg', version='color_interpolation'))


Next we define a function for viewing fit results.

In [None]:
def plot_lightcurves(obj_id, module, model, fit_results):
    """Plot the full, blue, and red light curves of an object with their fitted models

    Args:
        obj_id                    (str): Survey specific object id
        module                 (module): SNData submodule for a survey's data release
        model                   (Model): The model to plot
        fit_results (list[DataFrame]): Pipeline fit results for the full, blue,
            and red light curves, in that order
    """

    # Full light curve, followed by its blue and red subsets
    data = [module.get_data_for_id(obj_id, format_sncosmo=True)]
    data += split_data(data[0], module.band_names, module.lambda_effective)
    for data_table, df in zip(data, fit_results):
        obj_results = df.loc[obj_id]

        # Copy the model so the caller's instance is not modified
        model_this = deepcopy(model)
        params = {p: obj_results[p] for p in model_this.param_names}
        model_this.update(params)

        sncosmo.plot_lc(data_table, model_this)
        plt.show()
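
`plot_lightcurves` pairs each light-curve subset with one results DataFrame and pulls the object's row with `df.loc[obj_id]`. If an object is missing from a results table (for example, because its fit failed), `.loc` raises a `KeyError`. The following is a minimal sketch of a more forgiving lookup, assuming the results are indexed by object ID with one column per model parameter; the helper name `get_fit_params` and the sample values are hypothetical.

```python
import pandas as pd


def get_fit_params(results: pd.DataFrame, obj_id: str, param_names):
    """Return {param: value} for obj_id, or None if the object has no fit.

    Hypothetical helper: assumes `results` is indexed by object ID and has
    one column per model parameter.
    """
    if obj_id not in results.index:
        return None
    row = results.loc[obj_id]
    return {p: float(row[p]) for p in param_names}


# Hypothetical fit results for two objects (illustrative values only)
results = pd.DataFrame(
    {'z': [0.01, 0.02], 't0': [1000.0, 1010.0], 'x0': [1e-5, 2e-5],
     'x1': [0.5, -1.0], 'c': [0.1, 0.0]},
    index=pd.Index(['obj_a', 'obj_b'], name='obj_id'),
)

params = get_fit_params(results, 'obj_a', ['z', 't0', 'x0', 'x1', 'c'])
print(params['t0'])                                   # 1000.0
print(get_fit_params(results, 'missing_id', ['z']))  # None
```

Returning `None` lets the caller skip plotting a subset with no fit instead of halting the inspection loop.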


## Inspecting fit results

We specify the data module, light-curve model, and number of varied fit parameters whose results we want to inspect (4 when the redshift is held fixed, 5 when it is also fit). We then build an iterator over the available object IDs. To skip ahead and start inspecting at a given object ID, set the `start_at_id` variable.

In [None]:
# Specify values in this code block
module_to_inspect = dr3
start_at_id = None
model_to_inspect = sn91bg_c
params_to_inspect = 4

# Read in data and build iterator of object ids
priors = get_priors(module_to_inspect, model_to_inspect)
fits = get_fit_results(module_to_inspect, model_to_inspect, params_to_inspect)
ids = module_to_inspect.get_available_ids()
if start_at_id:
    ids = ids[ids.index(start_at_id):]

ids = iter(ids)
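
If `start_at_id` is not one of the available IDs, `ids.index(start_at_id)` raises a bare `ValueError`. A minimal sketch of the same skip-ahead logic with a more readable error message is below; the function name `build_id_iterator` and the sample IDs are hypothetical, and the input is assumed to be a list of ID strings.

```python
from typing import Iterable, Iterator, Optional


def build_id_iterator(ids: Iterable[str], start_at_id: Optional[str] = None) -> Iterator[str]:
    """Return an iterator over ids, optionally starting at start_at_id.

    Raises a ValueError with a readable message if start_at_id is not in ids.
    """
    ids = list(ids)  # copy into a list so .index() and slicing are available
    if start_at_id is not None:
        if start_at_id not in ids:
            raise ValueError(f'{start_at_id!r} is not an available object ID')
        ids = ids[ids.index(start_at_id):]
    return iter(ids)


# Hypothetical object IDs
available = ['id_001', 'id_002', 'id_003', 'id_004']
it = build_id_iterator(available, start_at_id='id_003')
print(next(it))  # id_003
print(next(it))  # id_004
```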


The following cell retrieves the next ID and plots the current fit results. If the fits look good, rerun the cell to move on to the next object. Otherwise, continue to the cells below.

In [None]:
# Plot fits
current_id = next(ids)
print(f'Inspecting {current_id}')
plot_lightcurves(current_id, module_to_inspect, model_to_inspect, fits)
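
Once every object has been inspected, `next(ids)` raises `StopIteration`. Passing a default to `next` avoids the traceback. The sketch below shows the pattern with a hypothetical `advance` helper and placeholder IDs; in the notebook, the same `next(ids, None)` check could guard the call to `plot_lightcurves`.

```python
from typing import Iterator, Optional


def advance(ids: Iterator[str]) -> Optional[str]:
    """Return the next object ID, or None when the iterator is exhausted."""
    current_id = next(ids, None)  # the default avoids a StopIteration traceback
    if current_id is None:
        print('All objects have been inspected')
    else:
        print(f'Inspecting {current_id}')
    return current_id


ids = iter(['id_001', 'id_002'])  # hypothetical object IDs
advance(ids)  # prints "Inspecting id_001"
advance(ids)  # prints "Inspecting id_002"
advance(ids)  # prints "All objects have been inspected"
```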


Take note of the priors used in the fit.

In [None]:
# Print priors
priors_this = priors.loc[current_id]
priors_dict = dict(priors_this)
print(priors_this)


Next we pick the portion of the light curve that was fit poorly, update the priors, and look at the new fit results.

In [None]:
data_to_fit = 'all'  # 'all', 'blue', or 'red'

# Modify any values as you see fit:
#priors_dict['t0_min'] = 1000
#priors_dict['t0_max'] = 1010
#priors_dict['c_max'] = 245
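
The `'blue'` and `'red'` options refit only part of the light curve, as returned by the pipeline's `split_data(data, band_names, lambda_effective)`. The exact splitting rule is not shown in this notebook. As an illustration only, the sketch below assumes a single wavelength cutoff: bands with an effective wavelength below it count as blue, the rest as red. The function name, the 5500 Å cutoff, and the wavelength values are all hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def split_by_wavelength(
    rows: List[Dict],
    band_names: Sequence[str],
    lambda_effective: Sequence[float],
    cutoff: float = 5500.0,
) -> Tuple[List[Dict], List[Dict]]:
    """Split photometry rows into (blue, red) by band effective wavelength.

    Assumes each row is a dict with a 'band' key and that a single cutoff
    wavelength separates blue and red bands (hypothetical rule).
    """
    lambda_eff = dict(zip(band_names, lambda_effective))
    blue = [row for row in rows if lambda_eff[row['band']] < cutoff]
    red = [row for row in rows if lambda_eff[row['band']] >= cutoff]
    return blue, red


# Illustrative band names and effective wavelengths in Angstroms
band_names = ['u', 'g', 'r', 'i']
lambda_effective = [3600.0, 4700.0, 6200.0, 7600.0]
observations = [
    {'band': 'u', 'flux': 1.2},
    {'band': 'g', 'flux': 3.4},
    {'band': 'r', 'flux': 2.8},
    {'band': 'i', 'flux': 1.9},
]
blue, red = split_by_wavelength(observations, band_names, lambda_effective)
print([row['band'] for row in blue])  # ['u', 'g']
print([row['band'] for row in red])   # ['r', 'i']
```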


In [None]:
# This cell reruns fits with the new priors and plots the results.
# There is no need to edit this cell.

# Get the specified bands of the lightcurve
data = module_to_inspect.get_data_for_id(current_id, format_sncosmo=True)
blue, red = split_data(
    data,
    module_to_inspect.band_names, 
    module_to_inspect.lambda_effective)

data = {'all': data, 'blue': blue, 'red': red}[data_to_fit]

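# Set initial values in model. The redshift is the first model parameter,
# so varying 4 parameters holds z fixed and varying 5 also fits it.
model_this = deepcopy(model_to_inspect)
n_params = len(model_to_inspect.param_names)
vparams = model_to_inspect.param_names[n_params - params_to_inspect:]
model_this.update({p: priors_dict[p] for p in vparams})
if 'z' not in vparams:
    # Use the redshift from the data table's metadata when z is not fit
    model_this.set(z=data.meta['redshift'])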

# Specify fitting bounds from the '<param>_min' and '<param>_max' priors
bounds = {
    p: (priors_dict[f'{p}_min'], priors_dict[f'{p}_max'])
    for p in model_to_inspect.param_names}

fit_results = fit_lc(data, model_this, vparams, bounds=bounds)

# Elements 2 through 6 of the fit_lc output are z, t0, x0, x1, and c
z = fit_results[2]
t0 = fit_results[3]
x0 = fit_results[4]
x1 = fit_results[5]
c = fit_results[6]
model_this.set(z=z, t0=t0, x0=x0, x1=x1, c=c)

sncosmo.plot_lc(data, model_this)
plt.show()
print(priors_dict)
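
When editing priors by hand it is easy to move a bound past the initial guess (for example, raising `t0_min` above the prior value of `t0`). The sketch below collects the varied parameters, initial values, and bounds from a flat priors dictionary and checks them before fitting. It follows the cell above for `vparams` and the initial values, but, unlike that cell, it only bounds the varied parameters. The function `build_fit_inputs` and all prior values are hypothetical.

```python
from typing import Dict, List, Tuple


def build_fit_inputs(
    param_names: List[str], priors: Dict[str, float], n_varied: int
) -> Tuple[List[str], Dict[str, float], Dict[str, Tuple[float, float]]]:
    """Derive varied parameters, initial values, and bounds from flat priors.

    The last `n_varied` parameters are varied, each prior value is used as
    the initial guess, and bounds come from '<param>_min' / '<param>_max'.
    """
    vparams = param_names[len(param_names) - n_varied:]
    initial = {p: priors[p] for p in vparams}
    bounds = {p: (priors[f'{p}_min'], priors[f'{p}_max']) for p in vparams}

    # Catch edits that produce inverted bounds or an out-of-range initial guess
    for p, (low, high) in bounds.items():
        if low > high:
            raise ValueError(f'{p}_min ({low}) is greater than {p}_max ({high})')
        if not low <= initial[p] <= high:
            raise ValueError(f'initial {p}={initial[p]} is outside [{low}, {high}]')
    return vparams, initial, bounds


# Hypothetical priors for a SALT2-like parameter set
param_names = ['z', 't0', 'x0', 'x1', 'c']
priors = {
    'z': 0.02, 'z_min': 0.0, 'z_max': 0.1,
    't0': 1005.0, 't0_min': 1000.0, 't0_max': 1010.0,
    'x0': 1e-5, 'x0_min': 0.0, 'x0_max': 1.0,
    'x1': 0.0, 'x1_min': -5.0, 'x1_max': 5.0,
    'c': 0.0, 'c_min': -0.5, 'c_max': 1.0,
}
vparams, initial, bounds = build_fit_inputs(param_names, priors, 4)
print(vparams)       # ['t0', 'x0', 'x1', 'c']
print(bounds['t0'])  # (1000.0, 1010.0)
```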


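If the new fits look good (enough), save the priors to file. The final argument to `save_priors` is a short description of the change (here, "Increased t0_min"); update it to describe your own edits.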

In [None]:
save_priors(current_id, module_to_inspect, model_to_inspect, priors_dict, "Increased t0_min")
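
A hand-written note for `save_priors` can drift from the edits actually made. The sketch below uses a hypothetical helper, `describe_prior_changes`, that compares the original and modified priors and summarizes the differences. It treats two NaN values as equal, since missing priors read from a table may be stored as NaN. In this notebook it could be called as `describe_prior_changes(dict(priors_this), priors_dict)`, because `priors_this` is not modified when `priors_dict` is edited.

```python
import math
from typing import Any, Dict


def _same(a: Any, b: Any) -> bool:
    """Compare two prior values, treating NaN == NaN as unchanged."""
    if isinstance(a, float) and isinstance(b, float) and math.isnan(a) and math.isnan(b):
        return True
    return a == b


def describe_prior_changes(original: Dict[str, Any], updated: Dict[str, Any]) -> str:
    """Summarize which prior values differ between original and updated."""
    changes = [
        f'{key}: {original.get(key)} -> {value}'
        for key, value in updated.items()
        if not _same(original.get(key), value)
    ]
    return '; '.join(changes) if changes else 'No changes'


# Hypothetical priors before and after editing
original = {'t0_min': 990.0, 't0_max': 1010.0, 'c_max': 1.0}
updated = dict(original, t0_min=1000.0)
print(describe_prior_changes(original, updated))  # t0_min: 990.0 -> 1000.0
```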