# Fitting CO model interactively

Load up a spectrum and play with the widgets to get a good initial starting point for the fit. The next cell sets up some variables: which file to load, what the columns are called, and the wavelength region to fit. Also, choose a `scale` that brings the flux of your spectrum close to unity.

Aside from the usual modules, you'll need the `ipympl`, `george`, and `emcee` packages, which can be installed with `conda`.

In [None]:
# Input spectrum (ECSV table) and the names of its wavelength / flux columns
fitsfile = "SN2016adj_20160521_BAA_NIR_cc_dered_rest.ecsv"
wavecol, fluxcol = 'wave', 'flux'
# Wavelength window to fit (Angstroms); points strictly inside are kept
wmin, wmax = 22800, 24950
# The flux is divided by this so that the spectrum is of order unity
scale = 1e-15
# Pickle file with Peter's grid of models
pklfile = "opac_fine.pkl"

## Getting a first guess fit

Running this next cell will give you an interactive plot. Use the sliders both to see how the different parameters change the model and to find, by chi-by-eye, a good initial starting point for optimizers or MCMC inference. Note that if you don't set the `scale` parameter appropriately, you might not see anything with the default model parameters.

In [None]:
from astropy.io import fits, ascii
import COmodel
from matplotlib import pyplot as plt
import numpy as np
import ipywidgets as widgets

tab = ascii.read(fitsfile)
wave,flux = tab[wavecol].value, tab[fluxcol].value
gids = np.greater(wave, wmin) & np.less(wave, wmax)
wave,flux = wave[gids],flux[gids]

model = COmodel.FluxModel(pkl=pklfile, scale=scale)
model.set_pars([3000, 2000, -3, 1, 0, 3])

%matplotlib notebook
fig,ax = plt.subplots()
ax.plot(wave,flux/scale,'-')
ax.set_xlabel('Wavelength (Angstroms)')
ax.set_ylabel('scaled Flux')

line, = ax.plot(wave, model(wave), '-')

def _apply_param(name, value):
    """Set model parameter ``name`` to ``value`` and redraw the model curve."""
    setattr(model, name, value)
    line.set_ydata(model(wave))
    # Widget callbacks run outside normal cell execution, so request a repaint
    # explicitly; otherwise the plot may not refresh.
    fig.canvas.draw_idle()

def updateA(change):
    """Slider callback: set the CO amplitude ``A``."""
    _apply_param('A', change.new)

def updateT(change):
    """Slider callback: set the temperature ``T``."""
    _apply_param('T', change.new)

def updateV(change):
    """Slider callback: set the velocity ``vel``."""
    _apply_param('vel', change.new)

def updatelco(change):
    """Slider callback: set ``lco`` (log(CO+/CO))."""
    _apply_param('lco', change.new)

def updatez(change):
    """Slider callback: set the redshift ``z``."""
    _apply_param('z', change.new)

def updateb(change):
    """Slider callback: set the continuum level ``b``."""
    _apply_param('b', change.new)
    
# One slider per model parameter.  Each starts at the model's current value,
# so the sliders stay in sync with the initial guess passed to set_pars().
Aslider = widgets.FloatSlider(
    value=float(model.A.value), min=0, max=5.0, step=0.01,
    description='Amplitude:', disabled=False, continuous_update=True,
    orientation='horizontal', readout=True, readout_format='.2f')
Tslider = widgets.FloatSlider(
    value=float(model.T.value), min=model.T.min, max=model.T.max,
    step=(model.T.max-model.T.min)/100, description='Temperature:',
    continuous_update=True, readout=True, readout_format='.1f')
vslider = widgets.FloatSlider(
    value=float(model.vel.value), min=model.vel.min, max=model.vel.max,
    step=(model.vel.max-model.vel.min)/100, description='velocity:',
    continuous_update=True, readout=True, readout_format='.1f')
lslider = widgets.FloatSlider(
    value=float(model.lco.value), min=model.lco.min, max=model.lco.max,
    step=(model.lco.max-model.lco.min)/100, description='log(CO+/CO)',
    continuous_update=True, readout=True, readout_format='.1f')
zslider = widgets.FloatSlider(
    value=float(model.z.value), min=0, max=5e-3, step=1e-5,
    description='z:', continuous_update=True, readout=True,
    readout_format='.1e')
bslider = widgets.FloatSlider(
    value=float(model.b.value), min=0, max=10, step=0.01,
    description='continuum:', continuous_update=True, readout=True,
    readout_format='.1e')

# Wire each slider to the callback that updates the matching parameter.
for _slider, _callback in ((Aslider, updateA), (Tslider, updateT),
                           (vslider, updateV), (lslider, updatelco),
                           (zslider, updatez), (bslider, updateb)):
    _slider.observe(_callback, 'value')

# Last expression in the cell, so the notebook displays the slider panel.
widgets.VBox([Aslider, Tslider, vslider, lslider, zslider, bslider])

## Masking the Data

In this next cell, the residuals of the model fit are plotted. Now is the time to remove bad data. You can do this in two ways: interactively set the sigma-clipping threshold (default 10 sigma) with the slider, or mask wavelength ranges by clicking and dragging across them on the plot.

In [None]:
from astropy.modeling import fitting
from matplotlib.widgets import SpanSelector
import warnings

# Threshold for keeping data, in units of the robust (MAD-based) sigma
thresh = 10


def _robust_sigma(r):
    """Return a MAD-based estimate of the standard deviation of ``r``.

    1.4826 * median(|r - median(r)|) is the Gaussian-consistent scale of the
    median absolute deviation, which is insensitive to large outliers.
    NaNs are ignored.
    """
    return 1.4826*np.nanmedian(np.absolute(r - np.nanmedian(r)))


def _clip(r, sig, level):
    """Return True where ``r`` is non-finite or ``|r| > level*sig``."""
    with np.errstate(invalid='ignore'):
        return ~np.isfinite(r) | (np.absolute(r) > level*sig)


# Unweighted optimized fit to all finite data.  astropy fitters return a
# fitted *copy*; ``model`` itself is not modified by the fitter.
fitter = fitting.LevMarLSQFitter()
finite = np.isfinite(flux)
opt = fitter(model, wave[finite], flux[finite]/scale)

# Initial sigma-clip on the residuals of that fit
resid = flux/scale - opt(wave)
sigma = _robust_sigma(resid)
clipmask = _clip(resid, sigma, thresh)

# Refit using the clipped data and constant white-noise weights (1/sigma)
good = ~clipmask
opt = fitter(opt, wave[good], flux[good]/scale,
             weights=np.full(np.count_nonzero(good), 1.0/sigma))
# Copy the optimized parameters back into ``model``, because later cells
# evaluate ``model`` rather than ``opt``.
model.set_pars(opt.parameters)
resid = flux/scale - opt(wave)
sigma = _robust_sigma(resid[good])
clipmask = _clip(resid, sigma, thresh)

# ``mask`` (True = excluded) is the union of the sigma-clip mask and the
# hand-selected spans.  It is always updated *in place*, so later cells that
# read the global ``mask`` see the current selection.
usermask = np.zeros(wave.shape, dtype=bool)
mask = clipmask | usermask

fig, ax = plt.subplots()
ax.plot(wave, resid, '-', color='k', alpha=0.1)     # all residuals, faint
ax.set_xlabel('Wavelength (Angstroms)')
ax.set_ylabel('scaled Flux Residuals')
ax.axhline(0, color='C1')
siglin1 = ax.axhline(sigma, ls='--', color='C1')
siglin2 = ax.axhline(-sigma, ls='--', color='C1')
# Kept data; masked points are NaN so they appear as gaps, not bridged lines
datline, = ax.plot(wave, np.where(mask, np.nan, resid), '-', color='C0')


def _refresh_mask():
    """Recombine the masks in place and redraw the kept-data line."""
    mask[:] = clipmask | usermask
    datline.set_ydata(np.where(mask, np.nan, resid))
    fig.canvas.draw_idle()


def on_Select(xmin, xmax):
    """SpanSelector callback: exclude data with xmin < wave < xmax."""
    usermask[(wave > xmin) & (wave < xmax)] = True
    _refresh_mask()


def updateThresh(change):
    """Slider callback: re-clip the residuals at the new threshold."""
    global thresh
    thresh = change.new
    clipmask[:] = _clip(resid, sigma, thresh)
    _refresh_mask()


ThreshSlider = widgets.FloatSlider(
    value=thresh, min=1, max=20, step=0.1,
    description='threshold:', continuous_update=False, readout=True)
ThreshSlider.observe(updateThresh, 'value')
# Keep a reference to the selector, or it can be garbage-collected and stop
# responding.
span = SpanSelector(ax, on_Select, "horizontal",
                    props=dict(alpha=0.1, facecolor='red'), interactive=False,
                    drag_from_anywhere=True)

plt.show()
ThreshSlider



## Modeling the Noise

The initial modeling assumed constant white noise over the entire spectrum. Now we try to be a bit more careful. In this next cell, we model the noise as a Gaussian Process with the `Matern32` kernel. This allows for correlated noise with amplitude `alpha` and length scale `kscale`. A white-noise component of `sigma` is retained. We use a simple optimizer here rather than MCMC, since we only need the values of `alpha` and `kscale` that best describe the data.

In [None]:
import george
from george import kernels
from scipy.optimize import minimize

# Start with a correlation length of ~10 pixels, which is likely to be close
kscale = np.mean(np.diff(wave))*10
resids = flux[~mask]/scale - model(wave[~mask])
alpha = np.var(resids)
print("Initial kernel params:  alpha={}, scale={}".format(alpha, kscale))
# george kernels are parameterized by a *metric*, which for a 1-D stationary
# kernel is the squared length scale, so pass kscale**2.
gp = george.GP(alpha*kernels.Matern32Kernel(kscale**2))
gp.compute(wave[~mask], flux[~mask]*0+sigma)

# initial starting point
init = gp.get_parameter_vector()

def mlnprob(p):
    '''Return the negative log-probability of the residuals (for minimize()).'''
    gp.set_parameter_vector(p)
    return -gp.log_likelihood(resids, quiet=True) - gp.log_prior()

gp_p = minimize(mlnprob, init)
# minimize() leaves gp at its last trial point; reset it to the optimum.
gp.set_parameter_vector(gp_p.x)

print(gp_p)
# The parameter vector holds log(alpha) and log(metric); metric is the squared
# length scale, so report its square root to match kscale above.
print("Final kernel params: alpha={}, scale={}".format(
    np.exp(gp_p.x[0]), np.sqrt(np.exp(gp_p.x[1]))))

## Visualizing the noise

This next cell, when run, will plot the best-fit CO model with the white noise `sigma` and the correlated noise side-by-side, for comparison. This can take a little while to run the first time, so you can skip it without affecting the rest of the notebook's execution.

In [None]:

# Simulated spectra: the current model plus white noise (left), and plus
# correlated noise drawn from the optimized GP kernel (right).
best = model(wave)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].plot(wave, best + np.random.normal(0, sigma, size=len(wave)), '-')
axes[0].plot(wave, best, '-', zorder=20)
axes[0].set_xlabel('Wavelength (Angstroms)')
axes[0].set_ylabel('Simulated Flux + (white noise)')

gp.set_parameter_vector(gp_p.x)
# Draw from the GP *prior* at every wavelength.  Conditioning on zero
# residuals (sample_conditional) would give the reduced posterior scatter,
# not the correlated noise the kernel actually describes.
axes[1].plot(wave, best + gp.sample(wave), '-')
axes[1].plot(wave, best, '-', zorder=20)
axes[1].set_xlabel('Wavelength (Angstroms)')
axes[1].set_ylabel('Simulated Flux + (correlated noise)')


## Fitting with Correlated noise

Now for the main event: use `emcee` both to optimize and to infer the uncertainties in the parameters. Because we now multiply the residuals vector by the inverse of the covariance matrix, this takes significantly longer than the usual white-noise chi-square. In the call to `emcee.EnsembleSampler` below, you can replace `lnprobCorr` with `lnprobWhite` to use white noise only and speed up the computation.

Also, the level and scale of the correlated noise (`alpha` and `kscale`) are kept *fixed*. One could instead make them free variables and use `emcee` to infer their values along with the parameters of interest. But that means *inverting* an NxN matrix at each iteration, which can make the problem unreasonably slow.

In [None]:
import multiprocessing as mp
import emcee
from scipy.stats import norm
# Worker pools use the 'fork' start method, so worker processes inherit the
# notebook's globals (model, gp, data, mask) without pickling them.
# NOTE: 'fork' is not available on Windows.
_fork_context = mp.get_context(method='fork')
Pool = _fork_context.Pool

def lnprior(p):
    """Return the flat log-prior for the full parameter vector ``p``.

    ``p`` is ordered like ``model.param_names`` (the walkers start from
    ``opt.parameters`` and ``model.set_pars(p)`` receives the whole vector).
    Each free parameter must lie within ``model.bounds``.  Fixed parameters
    are skipped by *name*, so they do not shift the positions of the others.

    Returns 0 inside the bounds and -inf outside.
    """
    for par, val in zip(model.param_names, p):
        if model.fixed[par]:
            continue
        pmin, pmax = model.bounds[par]
        if pmin is not None and val < pmin:
            return -np.inf
        if pmax is not None and val > pmax:
            return -np.inf
    return 0

# Fix the kernel at the optimized hyper-parameters and factorize its
# covariance once at the unmasked wavelengths; log_likelihood reuses it.
gp.set_parameter_vector(gp_p.x)
gp.compute(wave[~mask], flux[~mask]*0+sigma)
def lnprobCorr(p):
    """Log-posterior of ``p`` using the correlated-noise (GP) likelihood.

    Returns -inf when ``p`` is outside the prior bounds.  Otherwise sets the
    model parameters and scores the scaled-flux residuals with the GP.
    """
    logprior = lnprior(p)
    if np.isfinite(logprior):
        model.set_pars(p)
        keep = ~mask
        residual = (flux[keep]/scale - model(wave[keep]))
        return logprior + gp.log_likelihood(residual, quiet=True)
    return -np.inf

def lnprobWhite(p):
    """Log-posterior of ``p`` assuming independent Gaussian noise ``sigma``.

    Returns -inf when ``p`` is outside the prior bounds.
    """
    logprior = lnprior(p)
    if np.isfinite(logprior):
        model.set_pars(p)
        keep = ~mask
        residual = (flux[keep]/scale - model(wave[keep]))
        loglike = np.sum(norm.logpdf(residual, loc=0, scale=sigma))
        return logprior + loglike
    return -np.inf

nwalkers = 200
nsteps = 1000
ndim = len(opt.parameters)
# Start every walker in a tiny ball around the optimized parameters.
p0 = opt.parameters + 1e-8*np.random.randn(nwalkers, ndim)
with Pool() as pool:
    # Replace lnprobCorr with lnprobWhite if you want to just do the white noise errors.
    # pool=pool is required for emcee to actually distribute the
    # log-probability calls over the worker processes.
    sampler2 = emcee.EnsembleSampler(nwalkers, ndim, lnprobCorr, pool=pool)
    sampler2.run_mcmc(p0, nsteps, progress=True)


## Visualizing the results.

The next cell plots the traces of the MCMC runs, which lets you choose the "burn-in" time. Choose an iteration after which you think the chains have settled into their final distributions.

In [None]:
# Un-flattened chain; emcee returns shape (nsteps, nwalkers, ndim) here.
samples = sampler2.get_chain(flat=False)
COmodel.plotTraces(samples, model)

Now, set the correct burn-in time (default is 500 iterations). Then run the rest of the cells to get the results and diagnostic plots. Note that the best-fit values of the CO amplitude (`A`) and continuum level (`b`) are in units of `scale` $\times\ \mathrm{erg\,s^{-1}\,cm^{-2}\,\AA^{-1}}$.

In [None]:
Nburn = 500
nsteps_total = samples.shape[0]
if not 0 <= Nburn < nsteps_total:
    raise ValueError("Nburn={} must be between 0 and the chain length ({})"
                     .format(Nburn, nsteps_total))

# Display formats per parameter; unknown parameter names fall back to 4 sig figs.
fmts = {"T":"{:.0f}", "vel":"{:.0f}", "lco":"{:.2f}", "A":"{:.2f}", "z":"{:.5f}","b":"{:.2f}"}
kept = samples[Nburn:,:,:]
# Median and standard deviation over all post-burn-in steps and walkers.
meds = np.median(kept, axis=(0,1))
stds = np.std(kept, axis=(0,1))
for par, med, std in zip(model.param_names, meds, stds):
    pfmt = fmts.get(par, "{:.4g}")
    fmt = "{}:  " + pfmt + " +/- " + pfmt
    print(fmt.format(par, med, std))

In [None]:
# Corner plot of the post-burn-in samples, passing the per-parameter medians.
post_burn = samples[Nburn:]
fig = COmodel.plotCorner(post_burn, meds, model)

In [None]:
# Plot the median-parameter fit against the scaled data, with a constant
# per-point error of sigma (same shape as flux).
scaled_flux = flux/scale
flux_err = flux*0 + sigma
fig = COmodel.plotFit(wave, scaled_flux, flux_err, meds, model)