# Grid Interpolation Validation

As part of porting the CMFGEN models to SNCosmo, the models are interpolated onto a uniform grid. As a simple sanity check, this notebook calculates the residuals between the original, un-interpolated model and the gridded model.

In [None]:
import sys
from pathlib import Path

import numpy as np
import sncosmo
from bokeh.plotting import save, ColumnDataSource, figure
from bokeh.layouts import column
from bokeh.models.callbacks import CustomJS
from bokeh.models.widgets import Dropdown

sys.path.insert(0, '../')
from analysis import models


In [None]:
# Register the CMFGEN sources, then load one model per mass (the mass is the
# sncosmo "version" of the source)
models.register_sources(force=True)
sources = tuple(
    sncosmo.get_source('CMFGEN', version=mass)
    for mass in (1.02, 1.04, 1.4, 1.7)
)
m102, m104, m14, m17 = sources


## Model Resolutions

Although the wavelength range for each model is constant with phase, the resolution is not. Here we summarize the wavelength range and resolution of each model, which also verifies the previous statement. We note that the wavelength range of the interpolated model is sometimes narrower than that of the original model, but this difference is small (less than an Angstrom).

In [None]:
for source in sources:
    _, interp_wave, _ = source.interpolated_model()
    original_phase, original_wave, _ = source.original_model()

    # Per-phase minimum / maximum wavelength and number of wavelength samples.
    # Values are cast to built-in floats so the tuples printed below read as
    # plain numbers (NumPy 2 scalar reprs look like ``np.float64(...)``).
    original_min = [float(np.min(w)) for w in original_wave]
    original_max = [float(np.max(w)) for w in original_wave]
    original_resolution = [len(w) for w in original_wave]

    # Spread of the per-phase minimum and maximum wavelengths over all phases
    min_original_range = min(original_min), max(original_min)
    max_original_range = min(original_max), max(original_max)

    print(f'{source.name} M = {source.version}:\n'
          f'    Phase range: {min(original_phase)} - {max(original_phase)}\n'
          f'    Phase resolution: {len(original_phase)}\n'
          f'    Original range: {min_original_range} - {max_original_range}\n'
          f'    Original resolution: {min(original_resolution)} - {max(original_resolution)}\n'
          f'    Interpolated range: {min(interp_wave)} - {max(interp_wave)}\n'
          f'    Interpolated resolution: {len(interp_wave)}\n')



## Calculating Residuals

We calculate residuals for each model as a function of wavelength using the wavelength values of the original model. Since sncosmo is not able to model flux outside the interpolated wavelength range, we specify a `max_wavelength` value when calculating residuals.

In [None]:
def calc_source_residuals(source, max_wavelength, min_wavelength=None):
    """Calculate fractional residuals between the original and interpolated models

    For every phase of the original (un-interpolated) model, the residual at
    each wavelength is ``(original_flux - grid_flux) / original_flux``, where
    ``grid_flux`` is ``source.flux`` evaluated at the original phase and
    wavelengths.

    NOTE: wavelengths where the original flux is zero produce inf/nan
    residuals (plain numpy division).

    Args:
        source         (Source): An sncosmo source providing ``original_model``
        max_wavelength  (float): Exclusive upper wavelength bound for residuals
        min_wavelength  (float): Optional exclusive lower wavelength bound
            (Default: None, i.e., no lower bound)

    Returns:
        A list of phase values
        A list of 1d wavelength arrays, one per phase
        A list of 1d flux residual arrays, one per phase
    """

    phase, wavelength, residuals = [], [], []
    for orig_phase, orig_wave, orig_flux in zip(*source.original_model()):
        # Coerce to arrays so boolean masking works even for list input
        orig_wave = np.asarray(orig_wave)
        orig_flux = np.asarray(orig_flux)

        # Only consider wavelengths inside the requested bounds; sncosmo
        # cannot model flux outside the interpolated wavelength range.
        indices = orig_wave < max_wavelength
        if min_wavelength is not None:
            indices = indices & (orig_wave > min_wavelength)

        orig_flux = orig_flux[indices]
        orig_wave = orig_wave[indices]

        # Calculate residuals for every retained wavelength at this phase
        grid_flux = source.flux(orig_phase, orig_wave)
        phase.append(orig_phase)
        wavelength.append(orig_wave)
        residuals.append((orig_flux - grid_flux) / orig_flux)

    return phase, wavelength, residuals


def calc_residuals(source_list, max_wavelength=67e4):
    """Calculate residuals for a list of sources

    Args:
        source_list (list[Source]): A list of sncosmo sources
        max_wavelength     (float): Upper wavelength bound for
            residuals calculation (Default: 67e4)

    Returns:
        A dictionary mapping ``f'{source.name}_{source.version}'`` to the
        ``(phase, wavelength, residuals)`` tuple returned by
        ``calc_source_residuals`` for that source. Sources sharing the same
        name and version map to the same key, so the last one wins.
    """

    return {
        f'{s.name}_{s.version}': calc_source_residuals(s, max_wavelength)
        for s in source_list
    }


In [None]:
# Residuals for every loaded model, keyed by "<name>_<version>"
residuals = calc_residuals(source_list=sources)


## Examining Residuals

The models span a large number of phases and wavelengths, making it difficult to visualize all of the available data at once, even for a single model. Instead, we save an interactive figure of the residuals for each source, in which the user can select a specific phase value.

In [None]:
def save_residual_plot(residual, out_path, title=None, initial_phase=0.75):
    """Save interactive figure of residuals for a given model to file

    Output is saved in .html format. A dropdown menu selects which phase of
    the model is displayed.

    Args:
        residual       (tuple): Return of ``calc_source_residuals``
        out_path        (Path): Output file path
        title            (str): Figure title (Default: ``out_path.stem``)
        initial_phase  (float): Phase shown when the figure is first opened.
            Falls back to the first available phase if the model has no
            such phase (Default: 0.75)

    Raises:
        ValueError: If ``residual`` contains no phases
    """

    # Format data in a bokeh friendly way, keyed by the phase as a string
    data = {str(phase): {'wave': wave, 'resid': resid} for phase, wave, resid in zip(*residual)}
    if not data:
        raise ValueError('No residual data to plot')

    # Not every model is guaranteed to contain the requested phase; fall back
    # to the first available phase instead of raising a KeyError.
    initial_key = str(initial_phase)
    if initial_key not in data:
        initial_key = next(iter(data))

    cd_source = ColumnDataSource(data=data[initial_key])

    # Add dropdown button for selecting phase (keys are already strings)
    phase_menu = [(p, p) for p in data]
    dropdown = Dropdown(label='Model Phase', menu=phase_menu)
    # NOTE(review): ``Dropdown.value`` only exists in older Bokeh releases;
    # newer releases use the "menu_item_click" event. Confirm the target Bokeh version.
    callback = CustomJS(args=dict(source=cd_source, data=data), code="""
        var phase = cb_obj.value
        source.data['wave'] = data[phase]['wave']
        source.data['resid'] = data[phase]['resid']
        source.change.emit();
    """)

    dropdown.js_on_change('value', callback)

    # Title each figure after its own model instead of a fixed model name
    if title is None:
        title = out_path.stem

    # ``height`` pairs with ``width``; ``plot_height`` was removed in Bokeh 3
    s = figure(width=450, height=300, title=title)
    s.circle('wave', 'resid', source=cd_source, size=2, alpha=0.5)
    layout = column(dropdown, s)
    save(layout, filename=str(out_path), title=out_path.stem)


In [None]:
# Guard against the "Run All Cells" option: flip this flag to True to
# actually write one interactive residual figure per model.
SAVE_FIGURES = False
if SAVE_FIGURES:
    figure_dir = Path('./figures')
    figure_dir.mkdir(parents=True, exist_ok=True)

    for source_name, resid in residuals.items():
        out_path = figure_dir.joinpath(f'{source_name}.html')
        print(f'Creating {out_path} ...')
        save_residual_plot(resid, out_path)

    print('Done')
