# Test Choices of K and of the Interpolation Scheme

In [None]:
import os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.time import Time
from astropy.constants import c
from scipy import interpolate
from scipy.optimize import minimize, least_squares, curve_fit
from mpfit import mpfit

from tqdm import tqdm
import seaborn as sns

from waveCal import *

In [None]:
# Gather files

# LFC
lfc_files = glob('/mnt/home/lzhao/ceph/lfc5*/LFC_*.fits')
ckpt_files = glob('/mnt/home/lzhao/ceph/ckpt5*/LFC_19*.npy')
lfc_files, lfc_times = sortFiles(lfc_files, get_mjd=True)
ckpt_files = sortFiles(ckpt_files)
num_lfc_files = len(lfc_files)

# Read one LFC exposure as a template to learn the (orders, pixels) shape.
# The context manager guarantees the FITS file is closed even if one of the
# column lookups raises; .copy() detaches the arrays from the HDU so they stay
# valid after the file is closed.
with fits.open(lfc_files[0]) as hdus:
    t_spec = hdus[1].data['spectrum'].copy()
    t_errs = hdus[1].data['uncertainty'].copy()
    t_mask = hdus[1].data['pixel_mask'].copy()
nord, npix = t_spec.shape

# Orders 45..75 inclusive; presumably the orders with usable LFC lines — confirm.
lfc_orders = range(45,76)

In [None]:
# ThAr: raw spectra (.fits) and their line-identification lists (.thid)
thar_files = glob('/mnt/home/lzhao/ceph/thar5*/ThAr_*.fits')
thid_files = glob('/mnt/home/lzhao/ceph/thid5*/ThAr_*.thid')
thar_files, thar_times = sortFiles(thar_files, get_mjd=True)

# The first ThAr exposure was taken before the LFC series began, so drop it
# from every list.
# NOTE(review): later cells pair thid_files[i] with thar_times[i] by index,
# which assumes both globs yield the same exposures in the same order — confirm.
thar_files, thar_times = thar_files[1:], thar_times[1:]
thid_files = sortFiles(thid_files)[1:]

num_thar_files = len(thar_files)

In [None]:
# Denoise the line positions of the first 100 LFC checkpoint files.
# K is the parameter this notebook is testing (presumably the number of
# principal components kept — confirm against waveCal.patchAndDenoise).
patch_dict = patchAndDenoise(
    ckpt_files[:100],
    file_times=lfc_times[:100],
    K=2,
    running_window=9,
    num_iters=25,
    return_iters=False,
    line_cutoff=0.5,
    file_cutoff=0.5,
    fast_pca=False,
    plot=False,
    verbose=True,
)

In [None]:
# Build a wavelength solution from calibration data: one 2-D interpolator per order.

def getWaveSoln(times,orders,lambs,denoised_xs):
    """Fit a (time, pixel) -> wavelength interpolator for every order.

    All four arguments are parallel arrays of identical shape, one entry per
    calibration line (e.g. the flattened output of ``makeBIGtable``).

    Returns a dict keyed by each unique value of ``orders``. Each value is an
    ``interpolate.interp2d`` object (cubic) over (time, denoised pixel) that
    returns NaN outside the fitted range. The number of lines in each order
    is printed as it is fitted.

    NOTE(review): ``scipy.interpolate.interp2d`` was deprecated in SciPy 1.10
    and removed in 1.14, so this requires an older SciPy.
    """
    for other in (orders, lambs, denoised_xs):
        assert times.shape == other.shape

    def _fit_order(sel):
        # Report how many lines feed this order's fit, then fit it.
        print(np.sum(sel))
        return interpolate.interp2d(times[sel], denoised_xs[sel], lambs[sel],
                                    kind='cubic', bounds_error=False,
                                    fill_value=np.nan)

    return {order: _fit_order(orders == order) for order in np.unique(orders)}

def getWave(times,orders,x_values,sol_dict):
    """Evaluate per-order wavelength solutions at individual (time, pixel) points.

    Parameters
    ----------
    times, orders, x_values : numpy.ndarray
        Parallel arrays of identical shape; one entry per point to evaluate.
    sol_dict : dict
        Maps every value occurring in ``orders`` to a callable ``f(t, x)``
        (e.g. the interpolators built by ``getWaveSoln``).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``x_values`` holding the wavelength
        predicted for each point.

    Raises
    ------
    KeyError
        If an order in ``orders`` has no entry in ``sol_dict``.
    """
    assert times.shape==orders.shape
    assert times.shape==x_values.shape

    # Always float: zeros_like would silently truncate wavelengths if the
    # pixel positions happened to be integers.
    lambs = np.zeros(x_values.shape, dtype=float)
    for m in np.unique(orders):
        I = orders==m
        f = sol_dict[m]
        # Evaluate point by point. Given arrays, interp2d returns the full
        # len(y) x len(x) grid over *sorted* inputs rather than pointwise
        # values, which would match lambs[I] neither in shape nor in order.
        lambs[I] = [float(np.squeeze(f(t, x))) for t, x in zip(times[I], x_values[I])]

    return lambs

In [None]:
def makeBIGtable(patch_dict, times, max_num=None):
    """Flatten per-file denoised line positions into parallel 1-D columns.

    Parameters
    ----------
    patch_dict : dict
        Must contain ``'denoised_x_values'`` (2-D, files x lines),
        ``'orders'`` and ``'waves'`` (1-D, one entry per line).
    times : array-like
        One time per file (row of ``'denoised_x_values'``); lists are accepted.
    max_num : int, optional
        Use only the first ``max_num`` files (default: all of them).

    Returns
    -------
    tuple of numpy.ndarray
        ``(times, orders, lambs, denoised_xs)``, all flattened to 1-D with equal
        length, in the argument order expected by ``getWaveSoln``.
    """
    all_xs = patch_dict['denoised_x_values']
    if max_num is None:
        max_num = all_xs.shape[0]
    denoised_xs = all_xs[:max_num]

    # np.asarray so list inputs also support the [:, None] column indexing below.
    file_times = np.asarray(times)[:max_num]
    line_orders = np.asarray(patch_dict['orders'])
    line_waves = np.asarray(patch_dict['waves'])

    # Broadcast per-file values down the columns and per-line values across
    # the rows so every array matches denoised_xs element for element.
    times = np.zeros_like(denoised_xs) + file_times[:, None]
    orders = np.zeros_like(denoised_xs) + line_orders[None, :]
    lambs = np.zeros_like(denoised_xs) + line_waves[None, :]

    return times.flatten(), orders.flatten(), lambs.flatten(), denoised_xs.flatten()

In [None]:
# Display which entries patchAndDenoise returned.
patch_dict.keys()

In [None]:
# Fit a (time, pixel) -> wavelength surface per order from the first 100 LFC
# exposures, then evaluate it at the ThAr lines of one ThAr exposure.
sol_dict = getWaveSoln(*makeBIGtable(patch_dict,lfc_times,max_num=100))

test_j = 30
# Read the ThAr line list that belongs to thar_times[test_j]. ckpt_files are
# the LFC checkpoints (LFC_19*.npy, read elsewhere with readParams) and do not
# correspond to thar_times.
x,m,w = readThid(thid_files[test_j])
t = np.zeros_like(x) + thar_times[test_j]
w_fit = getWave(t,m,x,sol_dict)

## Interpolate in Time

In [None]:
# List the entries of patch_dict used by interp_coefs_and_predict below.
print(patch_dict.keys())

In [None]:
# For every ThAr exposure inside the LFC time span, predict its line
# wavelengths from the time-interpolated LFC solution and store the residuals.
thid_wfits = []
thid_x, thid_m = [], []
thid_diffs = []
thid_shift = []
thid_skipped = []  # (file_name, error) for line lists readThid could not parse

# Time span of the LFC exposures used by patchAndDenoise (the first 100).
# Computed once: it does not change between iterations.
lfc_t_min = lfc_times[:100].min()
lfc_t_max = lfc_times[:100].max()

for nfile, file_name in enumerate(tqdm(thid_files)):
    thar_time = thar_times[nfile]
    # Interpolate only; skip ThAr exposures outside the LFC span.
    if thar_time < lfc_t_min or thar_time > lfc_t_max:
        continue

    try:
        x,m,w = readThid(file_name)
    except ValueError as err:
        # Unreadable line list: record it rather than dropping it silently.
        thid_skipped.append((file_name, err))
        continue

    w_fit = interp_coefs_and_predict(thar_time,patch_dict,
                                     t_interp_deg=3, x_interp_deg=3,
                                     new_x=x, new_m=m)

    diff = w - w_fit
    thid_wfits.append(w_fit)
    thid_x.append(x)
    thid_m.append(m)
    thid_diffs.append(diff)
    # Fractional wavelength residual expressed as a velocity in m/s.
    thid_shift.append(diff/w*c.value)

In [None]:
# Pool the ThAr velocity residuals of all files and histogram the finite ones.
all_thid_shift = np.concatenate(thid_shift)
good_mask = np.isfinite(all_thid_shift)
plt.hist(all_thid_shift[good_mask], bins=50);
plt.xlabel('m/s');

In [None]:
# Median ThAr velocity residual (m/s) over finite values.
np.median(all_thid_shift[good_mask])

In [None]:
# Predict the line wavelengths of one LFC exposure from the interpolated
# solution. Index 5 lies inside the first 100 files passed to patchAndDenoise,
# so this is an in-sample check.
nfile = 5
x, m, w, e = readParams(ckpt_files[nfile])
w_fit = interp_coefs_and_predict(
    lfc_times[nfile], patch_dict,
    t_interp_deg=3,
    x_interp_deg=3,
    new_x=x,
    new_m=m,
)

In [None]:
# Velocity residuals (m/s) for lines that received a finite prediction.
good_mask = np.isfinite(w_fit)
resid = w[good_mask] - w_fit[good_mask]
rv_shift = resid / w[good_mask] * c.value
plt.hist(rv_shift, bins=50);

In [None]:
# Median LFC velocity residual (m/s).
np.median(rv_shift)

In [None]:
# Map the residuals across the detector: pixel on x, order on y, colour = RV.
plt.scatter(x[good_mask], m[good_mask], c=rv_shift,
            cmap='RdBu_r', vmin=-9, vmax=9)
plt.colorbar(label='RV [m/s]')

In [None]:
# Same exposure, but predict at the denoised line positions stored in
# patch_dict instead of the positions returned by readParams.
nfile = 5
w_fit2 = interp_coefs_and_predict(
    lfc_times[nfile], patch_dict,
    t_interp_deg=3,
    x_interp_deg=3,
    new_x=patch_dict['denoised_x_values'][nfile],
    new_m=patch_dict['orders'],
)
w2 = patch_dict['waves']

In [None]:
# Velocity residuals (m/s) of the denoised-position prediction, finite only.
good_mask2 = np.isfinite(w_fit2)
resid2 = w2[good_mask2] - w_fit2[good_mask2]
rv_shift2 = resid2 / w2[good_mask2] * c.value
plt.hist(rv_shift2, bins=50);

In [None]:
# Same residuals in wavelength units rather than velocity.
plt.hist(resid2,50);

In [None]:
# Median velocity residual (m/s) of the denoised-position prediction.
np.median(rv_shift2)