# Chi-Squared Calculation Comparison

This notebook compares the chi-squared values reported by SNCosmo (from `sncosmo.fit_lc` and `sncosmo.chisq`) against chi-squared values calculated manually.

In [None]:
%matplotlib inline

import sys
from warnings import warn

import numpy as np
import sncosmo
from SNData.csp import dr3
from SNData.des import sn3yr
from astropy.table import Table, join
from matplotlib import pyplot as plt

sys.path.insert(0, '../')
from analysis_pipeline import get_fit_results, get_priors, SN91bgSource

# Make sure each survey's data release is available locally and that its
# filters are registered with sncosmo (force=True passed to register_filters).
for survey in (dr3, sn3yr):
    survey.download_module_data()
    survey.register_filters(force=True)


## Manually Calculating Chi-Squared

In [None]:
def calc_chisq(data, model):
    """Calculate the chi-squared for a given data table and model

    Only points whose model flux is strictly positive are used. The input
    table is not modified.

    If the table has both ``zp`` and ``zpsys`` columns, the model flux is
    scaled to the same zero point as the observed flux (as
    ``sncosmo.chisq`` does). Otherwise the unscaled model band flux is used.

    Args:
        data  (Table): An SNCosmo input table
        model (Model): An SNCosmo Model

    Returns:
        The un-normalized chi-squared
        The number of data points used in the calculation
    """

    # Without the zero point the model flux is on a different scale from
    # the observed flux, and the residuals would be meaningless.
    colnames = getattr(data, 'colnames', ())
    if 'zp' in colnames and 'zpsys' in colnames:
        zp, zpsys = data['zp'], data['zpsys']
    else:
        zp = zpsys = None

    # One vectorized call instead of one bandflux call per data point
    model_flux = np.atleast_1d(
        model.bandflux(data['band'], data['time'], zp=zp, zpsys=zpsys)
    )

    # Keep only points with positive model flux, using a local mask so the
    # caller's table is left untouched
    keep = model_flux > 0
    flux = np.asarray(data['flux'])[keep]
    fluxerr = np.asarray(data['fluxerr'])[keep]

    chisq = np.sum(((model_flux[keep] - flux) / fluxerr) ** 2)
    return chisq, int(np.count_nonzero(keep))


In [None]:
# Fit a single target with SALT2. The per-target priors supply both the
# starting parameter values and the fit bounds.
test_id = '2007S'
model = sncosmo.Model('salt2')
data = dr3.get_sncosmo_input(test_id)
priors = get_priors(dr3, model).loc[test_id]

model.set(z=data.meta['redshift'])
bounds = dict()
for p in model.param_names:
    bounds[p] = (priors[f'{p}_min'], priors[f'{p}_max'])
    # NOTE(review): this loop also runs for 'z', so the redshift set above is
    # replaced by priors['z'] -- confirm the two agree.
    model.update({p: priors[p]})

# Drop the CSP near-infrared bands *before* fitting, so that fit_lc,
# sncosmo.chisq and calc_chisq are all evaluated on the same photometry.
for band in ('csp_dr3_Y', 'csp_dr3_J', 'csp_dr3_Ydw', 'csp_dr3_H'):
    data = data[data['band'] != band]

fit_result, fit_model = sncosmo.fit_lc(data, model, ['t0', 'x0', 'x1', 'c'], bounds=bounds)

# fit_lc documents ``fit_model`` as a copy of the model carrying the
# best-fit parameters, so plot and evaluate chi-squared with it rather
# than with the starting ``model``.
sncosmo.plot_lc(data, fit_model)

# All values printed below are reduced chi-squared (chisq / dof)
full_sncosmo_chisq = fit_result['chisq'] / fit_result['ndof']
print('Chi-Squared from fit_lc:', full_sncosmo_chisq)

func_chisq = sncosmo.chisq(data, fit_model) / fit_result['ndof']
print('Chi-Squared from `sncosmo.chisq`:', func_chisq)

# Degrees of freedom = points actually used minus the number of fitted params
man_chisq, num_points = calc_chisq(data, fit_model)
man_dof = num_points - len(fit_result['vparam_names'])
print('Manual Chi-squared using all data:', man_chisq / man_dof)


In [None]:
# Fit each band independently and compare the SNCosmo and manual reduced
# chi-squared values per band.
band_sncosmo_chisq = dict()
band_manual_chisq = dict()
n_fit_params = 4  # t0, x0, x1, c

# Sorted so that the band order (and the printed dicts) is reproducible
for band in sorted(set(data['band'])):
    band_data = data[data['band'] == band]
    try:
        band_fit_result, fitted_model = sncosmo.fit_lc(band_data, model, ['t0', 'x0', 'x1', 'c'])

    except RuntimeError as err:
        # Report the failure instead of silently dropping the band
        warn(f'Fit failed for band {band}: {err}')
        continue

    # Calculate chi-squared; dof = points used minus fitted parameters
    man_chisq, num_points = calc_chisq(band_data, fitted_model)
    man_dof = num_points - n_fit_params
    band_ndof = band_fit_result['ndof']
    if band_ndof <= 0 or man_dof <= 0:
        # Too few points in this band for a meaningful reduced chi-squared
        warn(f'Skipping band {band}: not enough degrees of freedom')
        continue

    band_manual_chisq[band] = man_chisq / man_dof
    band_sncosmo_chisq[band] = band_fit_result['chisq'] / band_ndof

print('SNCosmo Chi-Squared:')
print(band_sncosmo_chisq)

print('\nManual Chi-Squared:')
print(band_manual_chisq)

# NOTE: these sums add per-band *reduced* chi-squared values; they are a
# rough diagnostic, not the chi-squared of a joint fit.
print('\nSNCosmo Chi-Squared summed over bands:')
print(sum(band_sncosmo_chisq.values()))

print('\nManual Chi-Squared summed over bands:')
print(sum(band_manual_chisq.values()))