# EXPRES Wavelength Solution Characterization

At present (Oct. 31, 2019), the EXPRES wavelength solution is constructed by scanning each order, finding peaks, and fitting each of these peaks to a Gaussian.  If a peak and its fit pass a series of checks, the peak is deemed an LFC line and added to a list along with the expected wavelength of that line, which is determined from the LFC equation and an initial guess from the ThAr wavelength solution.

A 2D polynomial is then fit to pixel of line center, order, and wavelength (with wavelength being the dependent variable).  This polynomial is evaluated at all pixels to produce a wavelength solution.

This notebook sets up that framework and then explores alternative ways of constructing a wavelength solution.

In [None]:
import os
from glob import glob
import numpy as np
from numpy.polynomial.polynomial import polyvander2d, polyval2d
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.constants import c
from scipy.optimize import curve_fit, least_squares
from scipy.signal import argrelmin
from scipy.interpolate import UnivariateSpline, interp1d
from sklearn.decomposition import TruncatedSVD

In [None]:
# Load example LFC and checkpoint information
# The context manager closes the FITS file even if reading a field raises;
# .copy() detaches the spectrum and header from the file before it is closed.
with fits.open('./LFCs/LFC_190923.1071.fits') as hdus:
    spec = hdus[1].data['spectrum'].copy()
    head = hdus[0].header.copy()

# Checkpoint files include information about the fits to each line,
# their wavelength, and a few quality values
# [()] unwraps the 0-d object array returned by np.load into the stored
# object, which is indexed like a dict below.
# NOTE(review): allow_pickle=True unpickles the file contents -- only load
# checkpoint files from a trusted source.
info = np.load('./Checkpoints/LFC_190923.1071.npy',allow_pickle=True)[()]

In [None]:
# An example plot of how an LFC looks and the found line centers
# Double-width figure for the zoomed-in view of the spectrum
plt.figure(figsize=(6.4*2,4.8))
plt.title('Example Spectrum')
plt.xlabel('Pixel')
plt.ylabel('Extracted Value')
# Extracted spectrum at order index 60
plt.plot(spec[60])
# Column 1 of the per-line fit parameters is used as the line center in pixels
for i in info['params'][60][:,1]:
    plt.axvline(i,color='.75')
# Zoom in on a narrow pixel range
plt.xlim(3000,3250)

In [None]:
def readParams(file_name):
    """
    Given the file name of a check_point file,
    load in all relevant data into 1D vectors

    Parameters
    ----------
    file_name : str
        Path to a checkpoint ``.npy`` file. The stored object is indexed
        with the keys 'params', 'cov' and 'wvln', each holding one entry
        per order; an entry of None marks an order without fitted lines
        and is skipped.

    Returns
    -------
    (x, y, e, w) : tuple of ndarray
        Vectors for line center in pixel (x), order (y), error in line
        center fit in pixels (e), and wavelength of line (w)
    """
    # NOTE(review): allow_pickle=True unpickles the file contents -- only
    # load checkpoint files from a trusted source.
    info = np.load(file_name,allow_pickle=True)[()]
    params = info['params']

    # Assemble information into "fit-able" form
    # Order indices are taken from the checkpoint itself rather than from
    # the length of a notebook-level spectrum array, so the function also
    # works when no spectrum has been loaded.
    ordrs = [o for o, p in enumerate(params) if p is not None]
    # Column 1 of each per-order parameter array is the line center
    lines = [params[o][:,1] for o in ordrs]
    # Error of the line center: sqrt of the matching covariance diagonal term
    errs = [np.sqrt(cov[:,1,1]) for cov in info['cov'] if cov is not None]
    waves = [w for w in info['wvln'] if w is not None]
    # I believe, but am not sure, that the wavelengths are multiplied by order
    # to separate them from when orders overlap at the edges
    # No scaling is applied here: the wavelengths are passed through as
    # stored (zip only pairs each order with its wavelength array).
    waves = [wvln for _order, wvln in zip(ordrs,waves)]
    # One order value per line, so that y lines up element-wise with x
    ordrs = [np.ones_like(x) * m for m,x in zip(ordrs, lines)]

    x = np.concatenate(lines)
    y = np.concatenate(ordrs)
    e = np.concatenate(errs)
    w = np.concatenate(waves)
    # Note: default of pipeline includes ThAr lines, which we're not including here

    return (x,y,e,w)

In [None]:
# Line centers (x), orders (y), line-center errors (e) and wavelengths (w)
# of the example exposure; the cells below reuse these names
x,y,e,w = readParams('./Checkpoints/LFC_190923.1071.npy')

In [None]:
# Position of every fitted line (pixel vs. order), colored by wavelength
plt.figure()
plt.title('Fit Parameters')
plt.xlabel('Pixel')
plt.ylabel('Order')
# NOTE(review): w is divided by order here, but readParams() passes the
# checkpoint wavelengths through unscaled -- confirm whether info['wvln']
# is already multiplied by order. Lines with y == 0 would divide by zero.
plt.scatter(x,y,c=w/y)
plt.colorbar(label='Wavelength [nm]');

In [None]:
# Position of every fitted line, colored by the error of its fitted center
# Triple-width figure with '|' markers to keep neighbouring lines apart
plt.figure(figsize=(6.4*3,4.8))
plt.title('Error in Line Centers')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.scatter(x,y,c=e,marker='|',cmap='Spectral_r')
plt.colorbar(label='Error in Pixel');

## Polynomial Fitting

In [None]:
# Polynomial fit in the pipeline
def poly_fit_2d(x, y, data, deg=9, w=None):
    """
    Calculate the 2D polynomial fit coefficients assuming that the
    1D solution in x is approximately the correct answer.

    Points whose data value or weight is NaN are excluded from the fit.

    Parameters
    ----------
    x : ndarray
        The x positions (1D)
    y : ndarray
        The y positions (1D)
    data : ndarray
        The data at each (x, y)
    deg : int or tuple
        The polynomial degree to fit. If a tuple: (deg_x, deg_y)
    w : ndarray
        A weight for each data point. The residuals are multiplied by
        ``w`` before being squared, so pass 1/sigma rather than 1/sigma**2.

    Returns
    -------
    coeffs : ndarray of shape (deg_x+1, deg_y+1), or None
        Coefficient matrix in the layout expected by ``polyval2d``
        (``coeffs[i, j]`` multiplies ``x**i * y**j``). None is returned
        when there are no (valid) points to fit.
    """
    if len(x) < 1:
        return None

    x = np.asarray(x)
    y = np.asarray(y)
    data = np.asarray(data)

    if w is None:
        w = np.ones_like(data)
    w = np.asarray(w)

    # Drop invalid points instead of giving them zero weight: NaN * 0 is
    # still NaN, so a zero weight would leave NaNs in both the initial
    # np.polyfit and the least_squares residuals.
    valid = ~(np.isnan(data) | np.isnan(w))
    if not np.all(valid):
        x, y, data, w = x[valid], y[valid], data[valid], w[valid]
        if len(x) < 1:
            return None

    if isinstance(deg, (int, np.integer)):
        deg = (deg, deg)

    deg_x, deg_y = deg

    def resid(coeffs):
        """The residual cost function for least_squares"""
        # Reshape the coefficient array into a matrix usable by polyval2d
        coeff_arr = coeffs.reshape(deg_x+1, -1)
        return (data - polyval2d(x, y, coeff_arr)) * w

    # Initialize the coefficients with the 1D polynomial fit
    # (np.polyfit returns the highest power first, polyval2d wants the lowest
    # first, hence the reversal; the new axis makes it a single y-column)
    coeffs = np.polyfit(x, data, deg=deg_x, w=w)[::-1, np.newaxis]

    # Gradually add higher order y parameters until the full 2D polynomial is fit
    for width in range(2, deg_y+2):
        # Previous solution plus one new, zero-initialized column in y
        guess = np.zeros((deg_x+1, width))
        guess[:, :-1] = coeffs
        result = least_squares(resid, guess.flatten(), method='lm')
        coeffs = result.x.reshape(deg_x+1, -1)

    return coeffs

In [None]:
# Constructing a design matrix for polynomial fitting instead
def mkBlob(x, m, deg):
    """
    Collect the shifts and scales that define the polynomial basis.

    x: pixel
    m: order
    deg: degree of polynomial

    Returns the tuple (deg, xshift, mshift, scales). xshift and mshift are
    the means of x and m; scales holds one Euclidean norm per basis term
    (x - xshift)**i * (m - mshift)**j with i + j <= deg, in the order
    i = 0..deg (outer), j = 0..deg-i (inner).
    """
    # Center both coordinates on their mean to keep the powers small
    xshift, mshift = np.mean(x), np.mean(m)
    dx = x - xshift
    dm = m - mshift
    # One norm per monomial, so every column of the design matrix can be
    # brought to about the same range of values
    monomials = (
        dx ** i * dm ** j
        for i in range(deg + 1)
        for j in range(deg + 1 - i)
    )
    scales = [np.sqrt(term.dot(term)) for term in monomials]
    # The shifts and scales have to be kept with the fit: without them
    # the fitted coefficients cannot be interpreted or reused
    return (deg, xshift, mshift, scales)
            
def mkDesignMatrix(x, m, blob):
    """
    Build the design matrix of scaled 2D monomials for the points (x, m).

    x: pixel
    m: order
    blob: output of mkBlob()

    Returns an array with one row per point and one column per basis term
    (x - xshift)**i * (m - mshift)**j / scale, ordered as in mkBlob().

    Known issue: the enumeration of basis terms is duplicated from mkBlob()
    and the two must be kept in sync.
    """
    deg, xshift, mshift, scales = blob
    dx = x - xshift
    dm = m - mshift
    # Same (i, j) ordering as the scales stored in the blob
    exponents = [(i, j) for i in range(deg + 1) for j in range(deg + 1 - i)]
    columns = [
        dx ** i * dm ** j / scales[k]
        for k, (i, j) in enumerate(exponents)
    ]
    # Terms were collected as rows; transpose so each point is a row
    return np.array(columns).T

In [None]:
def fit(data, M, weights):
    """
    Solve the weighted linear least-squares problem via the normal equations.

    data: values to fit, one per row of M
    M: design matrix (rows are points, columns are basis terms)
    weights: one weight per point

    Prints the condition number of the normal matrix and returns the
    coefficient vector, one coefficient per column of M.
    """
    # Normal matrix M^T W M and right-hand side M^T W y
    normal_matrix = M.T @ (weights[:, None] * M)
    condition = np.linalg.cond(normal_matrix)
    print("fit(): condition number: {:.2e}".format(condition))
    rhs = M.T @ (weights * data)
    return np.linalg.solve(normal_matrix, rhs)

def predict(newx, newm, blob, coeffs):
    """
    Evaluate a fitted polynomial at new positions.

    newx, newm: pixel and order of the points to predict
    blob: output of mkBlob() used for the fit (shifts and scales)
    coeffs: coefficients returned by fit()

    Returns the predicted value for every (newx, newm) pair.
    """
    # The new points must go through the same shifts and scales as the
    # training data, hence the shared blob
    design = mkDesignMatrix(newx, newm, blob)
    return np.dot(design, coeffs)

### Results with `poly_fit_2d`

In [None]:
# Fit
# The weights are 1/e rather than 1/e**2: poly_fit_2d multiplies the
# residuals by w before least_squares squares them
coeffs8 = poly_fit_2d(x,y,w,deg=8,w=1/e)

In [None]:
# Residual Plot
plt.figure()
plt.title('Residual Plot (8th Deg Fit)')
plt.xlabel('Pixel')
plt.ylabel('Order')
# Model and data are both divided by order, so the ratio plotted below
# reduces to the fractional residual (model - data) / model
# NOTE(review): lines with y == 0 would give inf/NaN here -- confirm none exist
poly = polyval2d(x,y,coeffs8)/y
# Fractional residual times the speed of light, shown as a velocity
plt.scatter(x,y,c=((poly-w/y)/poly*c.value),vmin=-30,vmax=30)
plt.colorbar(label='Residual of Fit [m/s]')
plt.tight_layout()
# NOTE(review): assumes the ./Figures directory already exists
plt.savefig('./Figures/191031_deg8.png')

### Testing the Design Matrix

In [None]:
# Fit
# Shifts and scales are derived from the training points and reused below
blob = mkBlob(x, y, 8)
M = mkDesignMatrix(x, y, blob)
# Weights are 1/e**2 here because fit() applies them once in the normal equations
coeffs = fit(w, M, 1. / e ** 2)
# Evaluate the fitted polynomial back at the training lines
w_poly = predict(x, y, blob, coeffs)

In [None]:
# Residual Plot
resid = w - w_poly
# Computed but not used in this cell
# NOTE(review): resid is a wavelength difference while readParams() documents
# e as an error in pixels -- confirm the units before interpreting chi
chi = resid / e
# Fractional residual times the speed of light, shown as a velocity
plt.scatter(x,y,c=resid/w_poly*c.value,vmin=-30,vmax=30,cmap='RdBu_r')
plt.title("median residual: {:.2e} m/s".format(np.median(np.abs(resid)/w_poly*c.value)))
plt.colorbar()

## Polynomial Fitting vs. Interpolation
We test both fitting the line centers, orders, and wavelengths to a polynomial and using the line centers to interpolate a wavelength solution across the rest of the CCD.

We start with `[poly/interp]_train_and_predict` functions that take in some training data (x, m, data, ...) used to construct a model.  We then use this model to make predictions for new x and m values.  This allows us to compare the predictions against the actual data at the new x and m values.

In [None]:
def poly_train_and_predict(newx, newm, x, m, data, weights, deg):
    """
    Fit a 2D polynomial to training lines and evaluate it at new positions.

    newx, newm: pixel and order of the points to predict
    x, m, data: pixel, order and value of the training points
    weights: one weight per training point, passed to fit()
    deg: degree of polynomial

    Returns the predicted value for every (newx, newm) pair.
    """
    # The blob carries the training shifts/scales into the prediction step
    blob = mkBlob(x, m, deg)
    coeffs = fit(data, mkDesignMatrix(x, m, blob), weights)
    return predict(newx, newm, blob, coeffs)

def interp_train_and_predict(newx, newm, x, m, data, orders=range(86)):
    """
    Predict values at new positions by linear interpolation within each order.

    newx, newm: pixel and order of the points to predict
    x, m, data: pixel, order and value of the training points
    orders: iterable of order numbers to process

    Returns a float array shaped like newx. A point is NaN when it cannot
    be interpolated: it lies outside the pixel range of the training lines
    of its order, its order has no training lines, or its order is not
    listed in `orders`.
    """
    newx = np.asarray(newx, dtype=float)
    newm = np.asarray(newm)
    x = np.asarray(x, dtype=float)
    m = np.asarray(m)
    data = np.asarray(data, dtype=float)

    # Start from NaN so that "no prediction" can never be mistaken for a
    # wavelength of zero; a float array is required to hold NaN at all
    prediction = np.full(newx.shape, np.nan)
    for r in orders:
        Inew = newm == r
        if not np.any(Inew):
            continue
        I = m == r
        if not np.any(I):
            # np.interp raises on an empty set of sample points; leave
            # the points of this order as NaN instead
            continue
        # np.interp requires increasing sample positions and does not
        # check for it, so sort the training lines of this order
        srt = np.argsort(x[I], kind='stable')
        prediction[Inew] = np.interp(newx[Inew], x[I][srt], data[I][srt],
                                     left=np.nan,right=np.nan)
    return prediction

In [None]:
# Two-fold cross-validation of the interpolation: every second line is
# predicted from the remaining lines and vice versa
# `even` holds the parity of the line index (1 for odd indices)
even = np.arange(len(x)) % 2
IA = even == 1
IB = even == 0

w_interp = np.zeros_like(w)
# Fill each half with predictions trained on the other half
for target, train in ((IA, IB), (IB, IA)):
    w_interp[target] = interp_train_and_predict(x[target], y[target],
                                                x[train], y[train], w[train])

In [None]:
# Residuals of the cross-validated interpolation.
# The prediction array built in the cell above is named w_interp.
resid = w - w_interp
# Computed but not used in this cell
chi = resid / e
plt.figure()
# nanmedian: lines outside the interpolation range have NaN predictions
plt.title("median residual: {:.2e} m/s".format(np.nanmedian(np.abs(resid)/w*c.value)))
plt.xlabel('Pixel')
plt.ylabel('Order')
# Fractional residual times the speed of light, shown as a velocity
plt.scatter(x,y,c=resid/w*c.value,vmin=-30,vmax=30,cmap='RdBu_r')
plt.colorbar(label='Residual [m/s]')

The structure that was so concerning in the polynomial residual plots has disappeared!  We will proceed using the interpolation method for finding new wavelength solutions.

## LFC Changes with Exposure
We want to characterize how much the LFC changes from exposure to exposure with the ultimate goal of being able to predict how one LFC will look by using another (though more realistically we mean some "fiducial" LFC).  To do this, we first characterize how well one LFC exposure can straight up predict another one.  The hope is this will lead to a low-dimensional variation that can be fit using PCA.

For our first experiment, we try and predict exposures that are separated by:
1. a night
1. a month
1. a significant shift in the instrument

In [None]:
# Here, we select four exposures separated by different times
# Going by the file names: 1 and 2 share the date tag 190923,
# 3 is tagged 190905 and 4 is tagged 191031
x1,m1,e1,w1 = readParams('./Checkpoints/LFC_190923.1071.npy')
x2,m2,e2,w2 = readParams('./Checkpoints/LFC_190923.1151.npy')
x3,m3,e3,w3 = readParams('./Checkpoints/LFC_190905.1062.npy')
x4,m4,e4,w4 = readParams('./Checkpoints/LFC_191031.1062.npy')

In [None]:
# Prediction over a night
# Train on exposure 1 and predict the line wavelengths of exposure 2,
# first with the per-order interpolation ...
w_interp2 = interp_train_and_predict(x2,m2,x1,m1,w1)

resid2 = w2 - w_interp2
# Computed but not used in this cell
chi = resid2 / e2
plt.figure()
# Fractional residual times the speed of light, shown as a velocity
plt.scatter(x2,m2,c=resid2/w2*c.value,vmin=-50,vmax=50,cmap='RdBu_r')
plt.title('BON -> EON')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [m/s]')
plt.tight_layout()
# NOTE(review): assumes the ./Figures directory already exists
plt.savefig('./Figures/1911101_bon_eon.png')

# ... then with an 8th-degree polynomial using inverse-variance weights.
# NOTE: this rebinds w_poly from the design-matrix test earlier on.
w_poly = poly_train_and_predict(x2,m2,x1,m1,w1,1/e1**2,8)

resid2p = w2 - w_poly
chi = resid2p / e2
plt.figure()
plt.scatter(x2,m2,c=resid2p/w2*c.value,vmin=-50,vmax=50,cmap='RdBu_r')
plt.title('BON -> EON: Poly')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [m/s]')
plt.tight_layout()
plt.savefig('./Figures/1911101_bon_eon_poly.png')

In [None]:
# Prediction over a month
# Train on exposure 1 and predict the line wavelengths of exposure 3,
# first with the per-order interpolation ...
w_interp3 = interp_train_and_predict(x3,m3,x1,m1,w1)

resid3 = w3 - w_interp3
# Computed but not used in this cell
chi = resid3 / e3
plt.figure()
# The velocity residual is divided by 500 and labelled in pixels
# NOTE(review): presumably ~500 m/s per pixel -- confirm the conversion
plt.scatter(x3,m3,c=resid3/w3*c.value/500,vmin=0,vmax=6,cmap='Reds')
plt.title('Sept. 23 -> Sept. 05')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [pixels]')
plt.tight_layout()
# NOTE(review): assumes the ./Figures directory already exists
plt.savefig('./Figures/1911101_190923_190905.png')

# ... then with an 8th-degree polynomial. The weights are the inverse
# variance 1/e1**2, as in the other polynomial fits of this notebook.
w_poly3 = poly_train_and_predict(x3,m3,x1,m1,w1,1/e1**2,8)

resid3p = w3 - w_poly3
chi = resid3p / e3
plt.figure()
plt.scatter(x3,m3,c=resid3p/w3*c.value/500,vmin=0,vmax=6,cmap='Reds')
plt.title('Sept. 23 -> Sept. 05')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [pixels]')
plt.tight_layout()
plt.savefig('./Figures/1911101_190923_190905_poly.png')

In [None]:
# prediction over a significant shift in the instrument
# Train on exposure 1 and predict the line wavelengths of exposure 4,
# first with the per-order interpolation ...
w_interp4 = interp_train_and_predict(x4,m4,x1,m1,w1)

resid4 = w4 - w_interp4
# Computed but not used in this cell
chi = resid4 / e4
plt.figure()
# The velocity residual is divided by 500 and labelled in pixels
# NOTE(review): presumably ~500 m/s per pixel -- confirm the conversion
plt.scatter(x4,m4,c=resid4/w4*c.value/500,vmin=-6,vmax=0,cmap='Blues_r')
plt.title('Sept. 23 -> Oct. 31')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [pixels]')
plt.tight_layout()
# NOTE(review): assumes the ./Figures directory already exists
plt.savefig('./Figures/1911101_190923_191031.png')

# ... then with an 8th-degree polynomial and inverse-variance weights, so
# that the figure saved as *_poly.png really shows the polynomial prediction.
w_poly4 = poly_train_and_predict(x4,m4,x1,m1,w1,1/e1**2,8)

resid4p = w4 - w_poly4
chi = resid4p / e4
plt.figure()
plt.scatter(x4,m4,c=resid4p/w4*c.value/500,vmin=-6,vmax=0,cmap='Blues_r')
plt.title('Sept. 23 -> Oct. 31')
plt.xlabel('Pixel')
plt.ylabel('Order')
plt.colorbar(label='Residuals [pixels]')
plt.tight_layout()
plt.savefig('./Figures/1911101_190923_191031_poly.png')

## Implement PCA on Background Variations

In [None]:
# Define a window with reasonable lines
# (So as to not overwhelm the PCA with noise)
# Line positions of exposure 4, with the chosen window outlined in red
plt.plot(x4,m4,'.')
# Window limits in order (y) and pixel (x); reused by the grid set-up below
ymin,ymax = 40, 75
xmin,xmax = 500,7000
# Closed rectangle through the four corners of the window
plt.plot([xmin,xmax,xmax,xmin,xmin],[ymin,ymin,ymax,ymax,ymin],'r-')

In [None]:
# Pixel/order grid over which to recover the interpolated wavelength solution
# np.arange excludes the upper limit, so xmax and ymax themselves are not
# part of the grid
x_range = np.arange(xmin, xmax).astype(float)
y_range = np.arange(ymin, ymax).astype(float)
# Flattened coordinates: one (x_grid[k], y_grid[k]) pair per grid point
x_grid, y_grid = (axis.flatten() for axis in np.meshgrid(x_range, y_range))

In [None]:
# Get wavelength solution for all LFC files
# NOTE(review): glob() returns paths in arbitrary order, so the rows of
# w_fit_array are not guaranteed to be in time order
cpt_files = glob('./Checkpoints/LFC_190923*.npy')
# One row per checkpoint file, one column per grid point
w_fit_array = np.empty((len(cpt_files),len(x_grid)))
for i, file_name in enumerate(cpt_files):
    # NOTE: this rebinds the notebook-level x, e and w of the earlier cells
    x,m,e,w = readParams(file_name)
    # Grid points outside the line range of their order come back as NaN
    w_fit_array[i] = interp_train_and_predict(x_grid,y_grid,x,m,w)

In [None]:
# Sanity check: (number of checkpoint files, number of grid points)
w_fit_array.shape

In [None]:
# Drop grid points (columns) with too few finite predictions:
# a grid point is kept when more than 3 exposures have a finite value there
okay = np.isfinite(w_fit_array).sum(axis=0) > 3
# Apply the same cut to the data and to both grid coordinate vectors
w_fit_array, x_grid, y_grid = w_fit_array[:, okay], x_grid[okay], y_grid[okay]
# Masks of finite / non-finite entries of the trimmed array
good = np.isfinite(w_fit_array)
bad = ~good
# Number of grid points that were removed
print(f"We're Not Okay: {np.sum(~okay)}")

In [None]:
# Sanity check: the masks and the data must share one shape, and x_grid
# must have one entry per column
print(good.shape, bad.shape, x_grid.shape, w_fit_array.shape)

In [None]:
# Mean wavelength of every grid point over all exposures, ignoring NaNs
mean_w_fit = np.nanmean(w_fit_array, axis=0)
# Fill the remaining non-finite entries with the mean of their grid point.
# THIS IS TERRIBLE: filled entries carry no information of their own.
# NOTE(review): it is only harmless if no bad points are left at this stage;
# the cut above keeps grid points with more than 3 finite values, so
# confirm that `bad` is actually empty.
foo = np.zeros_like(w_fit_array) + mean_w_fit
np.copyto(w_fit_array, foo, where=bad)

In [None]:
# Decompose the mean-subtracted wavelength solutions into 5 components;
# random_state is fixed so repeated runs give the same result
svd = TruncatedSVD(n_components=5,n_iter=7,random_state=42)
svd.fit(w_fit_array - mean_w_fit[None, :])
# One row per component, one value per grid point
vv = svd.components_

In [None]:
# Mean wavelength solution over the retained grid points
plt.scatter(x_grid, y_grid, c=mean_w_fit)

In [None]:
# Plot principal components
# Index of the component to show; vv[k] has one value per grid point
k = 0
plt.figure()
plt.scatter(x_grid, y_grid, c=vv[k])
plt.title("eigenvector {:d}".format(k))
plt.colorbar()

In [None]:
# Second and third components, each in its own figure
for k in (1, 2):
    plt.figure()
    # vv[k] has one value per grid point
    plt.scatter(x_grid, y_grid, c=vv[k])
    plt.title("eigenvector {:d}".format(k))
    plt.colorbar()