# Training and Validation Tests

In [None]:
import os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.time import Time
from astropy.constants import c
from scipy import interpolate
import pickle
from mpfit import mpfit

from tqdm import tqdm
import seaborn as sns

from waveCal import *

## Gather Files

In [None]:
# LFC: collect the LFC exposures (.fits) and the matching checkpoint files (.npy)
lfc_files = glob('/mnt/home/lzhao/ceph/lfc5*/LFC_*.fits')
ckpt_files = glob('/mnt/home/lzhao/ceph/ckpt5*/LFC_19*.npy')
print(len(lfc_files))
# sortFiles comes from waveCal; with get_mjd=True it also returns one time per file
# (presumably MJD, judging by the keyword -- confirm against waveCal).
lfc_files, lfc_times = sortFiles(lfc_files, get_mjd=True)
ckpt_files = sortFiles(ckpt_files)
num_lfc_files = len(lfc_files)

# Read the first exposure as a template to learn the 2-D array shape.
# The context manager closes the file even if one of the columns is missing;
# .copy() keeps the arrays usable after the file has been closed.
with fits.open(lfc_files[0]) as hdus:
    t_spec = hdus[1].data['spectrum'].copy()
    t_errs = hdus[1].data['uncertainty'].copy()
    t_mask = hdus[1].data['pixel_mask'].copy()
nord, npix = t_spec.shape

# Order indices 45 through 75 (the range end is exclusive)
lfc_orders = range(45,76)

In [None]:
# ThAr: collect the ThAr exposures (.fits) and their line-ID files (.thid)
thar_files = glob('/mnt/home/lzhao/ceph/thar5*/ThAr_*.fits')
thid_files  = glob('/mnt/home/lzhao/ceph/thid5*/ThAr_*.thid')
print(len(thar_files))

# Element 0 of every sorted sequence is dropped: the first file is from before LFCs
thar_files, thar_times = (
    seq[1:] for seq in sortFiles(thar_files, get_mjd=True)
)
thid_files = sortFiles(thid_files)[1:]
num_thar_files = len(thar_files)

## Separate Training and Validation Sets

In [None]:
# Fixed seed so the same exposures are held out on every run
np.random.seed(0)
# Hold out one tenth of the exposures. Drawing from [0, N-2) and adding 1
# keeps the first and the last exposure out of the validation set.
n_valid = num_lfc_files // 10
valid_idx = 1 + np.random.choice(num_lfc_files - 2, n_valid, replace=False)

# Training set: everything that was not held out, ordered by time
time_sort = np.argsort(np.delete(lfc_times, valid_idx))
lfc_train = np.delete(ckpt_files, valid_idx)[time_sort]
lfc_times_train = np.delete(lfc_times, valid_idx)[time_sort]

# Validation set: the held-out exposures, ordered by time
time_sort = np.argsort(lfc_times[valid_idx])
lfc_valid = ckpt_files[valid_idx][time_sort]
lfc_times_valid = lfc_times[valid_idx][time_sort]

In [None]:
# Fixed seed so the same exposures are held out on every run
np.random.seed(0)
# Hold out one tenth of the ThAr exposures; the +1 offset on a draw from
# [0, N-2) keeps the first and the last exposure out of the validation set.
valid_idx = np.random.choice(num_thar_files-2, num_thar_files//10, replace=False)+1

# Training set: every .thid file that was not held out, ordered by time
thar_train = np.delete(thid_files,valid_idx)
thar_times_train = np.delete(thar_times,valid_idx)
time_sort = np.argsort(thar_times_train)
thar_train = thar_train[time_sort]
thar_times_train = thar_times_train[time_sort]

# Validation set: the held-out .thid files, ordered by time
thar_valid = thid_files[valid_idx]
thar_times_valid = thar_times[valid_idx]
time_sort = np.argsort(thar_times_valid)
# Reorder the validation files themselves. Indexing thar_train here would
# silently replace the validation set with training files.
thar_valid = thar_valid[time_sort]
thar_times_valid = thar_times_valid[time_sort]

## Load Patch Dictionaries

In [None]:
# Load the previously saved training patch dictionaries (LFC checkpoints and
# ThAr line IDs). `with` closes each file handle as soon as it has been read.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open('./191205_ckptPatch9_train.pkl','rb') as pkl_file:
    ckpt_patch_train = pickle.load(pkl_file)
with open('./191205_thidPatch15_train.pkl','rb') as pkl_file:
    thid_patch_train = pickle.load(pkl_file)

## LFC Validation Test

In [None]:
# Evaluate the training-set wavelength solution at every validation time
# (t_intp_deg=3 is passed through to waveCal.evalWaveSol).
denoised_xs = evalWaveSol(lfc_times_valid, ckpt_patch_train, t_intp_deg=3)
# Independent copies of the patch's 'orders' and 'waves' entries
m, w = (ckpt_patch_train[key].copy() for key in ('orders', 'waves'))

In [None]:
# For every validation checkpoint file: read its fitted line parameters,
# predict wavelengths from the training-only solution, and record the
# fractional difference from the file's own wavelengths scaled by c.
lfc_fits = []
ckpt_x = []
ckpt_m = []
ckpt_w = []
shift_chunks = []  # one array of shifts per processed file, joined after the loop
for file_num, file_name in enumerate(tqdm(lfc_valid)):
    try:
        newx,newm,neww,newe = readParams(file_name)
    except ValueError as err:
        # Report the skip instead of hiding it. NOTE: a skipped file leaves no
        # entry in lfc_fits / ckpt_*, so positions in those lists stop matching
        # positions in lfc_valid from this point on.
        tqdm.write(f'Skipping {file_name}: {err}')
        continue

    w_fit = interp_train_and_predict(newx, newm,
                                     denoised_xs[file_num], m, w,
                                     e=newe, interp_deg=3)

    ckpt_x.append(newx)
    ckpt_m.append(newm)
    ckpt_w.append(neww)
    lfc_fits.append(w_fit)
    # Only lines with a finite prediction contribute to the pooled shifts
    good_mask = np.isfinite(w_fit)
    shift_chunks.append((w_fit[good_mask]-neww[good_mask])/neww[good_mask]*c.value)

# Join once at the end rather than re-copying the growing array every iteration
lfc_shifts = (np.concatenate(shift_chunks) if shift_chunks
              else np.array([],dtype=float))

In [None]:
# Pool the shifts from all validation exposures and keep only |shift| < 25
rv_shift = lfc_shifts.flatten()
rv_shift = rv_shift[np.abs(rv_shift) < 25]
shift_mean = np.mean(rv_shift)
shift_median = np.median(rv_shift)

# Histogram of the pooled shifts with the mean and median marked
plt.figure()
plt.hist(rv_shift, bins=50);
plt.axvline(shift_mean, color='r', label='Mean: {:.3} m/s'.format(shift_mean))
plt.axvline(shift_median, color='g', label='Median: {:.3} m/s'.format(shift_median))
plt.title(f'LFC Training and Validation: All {len(lfc_times_valid)} Validation Exposures')
plt.xlabel('Predicted - Fit [m/s]')
plt.ylabel('Frequency')
plt.legend()
plt.tight_layout()
plt.savefig('./Figures/191205_lfcTnV.png')

# Scatter of the clipped distribution
print(np.std(rv_shift))

In [None]:
# Plot the first two 'ec' columns of the LFC training patch against time,
# together with their cubic interpolation, and mark every validation time.
problem_idx = 38  # validation exposure highlighted as the 'Problem Child'

plt.figure(figsize=(6.4*2,4.8))
plt.xlabel('Time [mjd]')
plt.ylabel('PCA Coefficient')
for idx, t_valid in enumerate(lfc_times_valid):
    # Label only the first line so the legend gets exactly one entry,
    # without depending on a loop variable after the loop has finished.
    legend_kw = {'label': 'Validation Times'} if idx == 0 else {}
    plt.axvline(t_valid, color='.75', **legend_kw)
# Guard the hard-coded index so a shorter validation set does not raise
if problem_idx < len(lfc_times_valid):
    plt.axvline(lfc_times_valid[problem_idx],color=sns.color_palette()[4],label='Problem Child')

palette = sns.color_palette()
train_times = ckpt_patch_train['times']
# Dense time grid spanning the first to the last validation time
x = np.linspace(lfc_times_valid[0],lfc_times_valid[-1],1000)
for ec_num in (0, 1):
    coeffs = ckpt_patch_train['ec'][:,ec_num]
    plt.plot(train_times,coeffs,'.-',color=palette[ec_num],label=f'EC {ec_num}')
    # Outside the training time span the interpolant returns NaN (not an error)
    f = interpolate.interp1d(train_times,coeffs,kind='cubic',
                             bounds_error=False,fill_value=np.nan)
    plt.plot(x,f(x),color=palette[ec_num+2],label=f'EC {ec_num} Interp')

plt.legend(loc=2)
plt.tight_layout()
plt.savefig('./Figures/191205_intpBad.png')

In [None]:
# Overlay one step histogram of (predicted - fit) per validation exposure,
# then redraw one hard-coded exposure in red on top.
hist_bins = np.arange(-25,26,2.5)  # shared bin edges, built once
problem_idx = 38                   # exposure redrawn in red

plt.figure()
plt.xlabel('Predicted - Fit [m/s]')
plt.ylabel('Frequency')
colors = sns.color_palette('plasma',len(lfc_times_valid))
# zip stops at the shortest sequence, so exposures that produced no entry in
# lfc_fits cannot cause an IndexError here.
for t, fit_w, ref_w, color in zip(lfc_times_valid, lfc_fits, ckpt_w, colors):
    rv_shift = (fit_w - ref_w)/ref_w*c.value
    plt.hist(rv_shift,hist_bins,histtype='step',color=color)

# Guard the hard-coded index so a shorter validation set does not raise
if problem_idx < len(lfc_fits):
    resid = lfc_fits[problem_idx] - ckpt_w[problem_idx]
    rv_shift = resid/ckpt_w[problem_idx]*c.value
    plt.hist(rv_shift,hist_bins,histtype='step',color='r')
plt.xlim(-25,25)
plt.tight_layout()

In [None]:
# Histogram of (predicted - fit) for a single LFC validation exposure.
exp_idx = 1  # position of the exposure within lfc_fits / ckpt_w

plt.figure()
# Title with the time of the exposure actually plotted, not a loop variable
# left over from an earlier cell.
# NOTE(review): assumes lfc_fits lines up one-to-one with lfc_times_valid,
# i.e. no validation file was skipped when lfc_fits was built -- confirm.
plt.title('LFC Training and Validation: Exp {}'.format(Time(lfc_times_valid[exp_idx],format='mjd').isot))
plt.xlabel('Predicted - Fit [m/s]')
plt.ylabel('Frequency')
resid = lfc_fits[exp_idx] - ckpt_w[exp_idx]
rv_shift = resid/ckpt_w[exp_idx]*c.value
plt.hist(rv_shift,50);
# nan-aware statistics: w_fit may contain non-finite predictions
shift_mean = np.nanmean(rv_shift)
shift_median = np.nanmedian(rv_shift)
plt.axvline(shift_mean,color='r',label='Mean: {:.3} m/s'.format(shift_mean))
plt.axvline(shift_median,color='g',label='Median: {:.3} m/s'.format(shift_median))
plt.legend()
plt.xlim(-25,25)
plt.tight_layout()
plt.savefig('./Figures/191205_lfcTnV6.png')
print(np.nanstd(rv_shift))

## ThAr Validation Test

In [None]:
# Evaluate the ThAr training-set wavelength solution at every ThAr validation
# time (t_intp_deg=3 is passed through to waveCal.evalWaveSol).
denoised_xs = evalWaveSol(thar_times_valid, thid_patch_train, t_intp_deg=3)
# Independent copies of the patch's 'orders' and 'waves' entries
m, w = (thid_patch_train[key].copy() for key in ('orders', 'waves'))

In [None]:
# For every validation .thid file: read its identified lines, predict their
# wavelengths from the training-only solution, and record the fractional
# difference from the file's own wavelengths scaled by c.
thar_fits = []
thid_x = []
thid_m = []
thid_w = []
shift_chunks = []  # one array of shifts per processed file, joined after the loop
for file_num, file_name in enumerate(tqdm(thar_valid)):
    try:
        newx,newm,neww = readThid(file_name)
    except ValueError as err:
        tqdm.write(f'Skipping {file_name} (could not be read): {err}')
        continue

    try:
        w_fit = interp_train_and_predict(newx, newm,
                                         denoised_xs[file_num], m, w,
                                         e=None, interp_deg=3)
    except Exception as err:
        # Still best-effort, but a failure is now reported, and
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        tqdm.write(f'Skipping {file_name} (prediction failed): {err!r}')
        continue

    # NOTE: a skipped file leaves no entry in thar_fits / thid_*, so positions
    # in those lists stop matching positions in thar_valid after a skip.
    thid_x.append(newx)
    thid_m.append(newm)
    thid_w.append(neww)
    thar_fits.append(w_fit)
    # Only lines with a finite prediction contribute to the pooled shifts
    good_mask = np.isfinite(w_fit)
    shift_chunks.append((w_fit[good_mask]-neww[good_mask])/neww[good_mask]*c.value)

# Join once at the end rather than re-copying the growing array every iteration
thar_shifts = (np.concatenate(shift_chunks) if shift_chunks
               else np.array([],dtype=float))

In [None]:
# Pool the ThAr shifts and select the ones with |shift| < 2000
rv_shift = thar_shifts.flatten()
innie_mask = np.abs(rv_shift) < 2000
inliers = rv_shift[innie_mask]

# Histogram of the selected shifts
plt.figure()
plt.hist(inliers, bins=50);
plt.title('ThAr Training and Validation')
plt.xlabel('Predicted - Fit [m/s]')
plt.ylabel('Frequency')
plt.tight_layout()
plt.savefig('./Figures/191205_tharTnV.png')

# Scatter of the selected shifts
print(np.std(inliers))

In [None]:
# Histogram of (predicted - fit) for a single ThAr validation exposure.
exp_idx = 2  # position of the exposure within thar_fits / thid_w

plt.figure()
# Title with the ThAr exposure's own time; the stale loop variable `t`
# holds an LFC validation time, not a ThAr one.
# NOTE(review): assumes thar_fits lines up one-to-one with thar_times_valid,
# i.e. no validation file was skipped when thar_fits was built -- confirm.
plt.title('ThAr Training and Validation: Exp {}'.format(Time(thar_times_valid[exp_idx],format='mjd').isot))
plt.xlabel('Predicted - Fit [m/s]')
plt.ylabel('Frequency')
resid = thar_fits[exp_idx] - thid_w[exp_idx]
rv_shift = resid/thid_w[exp_idx]*c.value
plt.hist(rv_shift,50);
# nan-aware statistics: w_fit may contain non-finite predictions
shift_mean = np.nanmean(rv_shift)
shift_median = np.nanmedian(rv_shift)
plt.axvline(shift_mean,color='r',label='Mean: {:.3} m/s'.format(shift_mean))
plt.axvline(shift_median,color='g',label='Median: {:.3} m/s'.format(shift_median))
plt.legend()
plt.xlim(-25,25)
plt.tight_layout()
plt.savefig('./Figures/191205_tharTnV11.png')
print(np.nanstd(rv_shift))