# Implementing the Patch PCA Method on ThAr Exposures

In [None]:
import os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.time import Time
from astropy.constants import c
from scipy import interpolate
from scipy.optimize import minimize, least_squares, curve_fit
from mpfit import mpfit

from tqdm import tqdm
import seaborn as sns

from waveCal import *

In [None]:
# Gather ThAr exposure files and read the array geometry from one of them
thar_files = glob('/mnt/home/lzhao/ceph/thar5*/ThAr_*.fits')
num_files = len(thar_files)
print(f'Number of files: {num_files}')
if num_files == 0:
    raise FileNotFoundError('No ThAr .fits files matched the glob pattern')

# The context manager closes the FITS file even if reading a column fails;
# .copy() detaches the arrays from the file so they survive the close.
with fits.open(thar_files[0]) as hdus:
    t_spec = hdus[1].data['spectrum'].copy()
    t_errs = hdus[1].data['uncertainty'].copy()
    t_mask = hdus[1].data['pixel_mask'].copy()
# presumably (number of orders, pixels per order) — confirm against the data model
nord, npix = t_spec.shape

In [None]:
# ThAr line-fit (.thid) files: one line list per exposure
thid_pattern = '/mnt/home/lzhao/ceph/thid5*/ThAr_*.thid'
thid_files = glob(thid_pattern)

In [None]:
# Sort files by date: the timestamp is the text after the last '_' in the
# file name, with the 5-character extension ('.thid' / '.fits') removed.
def _file_time(path):
    """Return the numeric timestamp encoded in a ThAr file name."""
    return float(os.path.basename(path).split('_')[-1][:-5])

thid_files = np.array(thid_files)[np.argsort([_file_time(f) for f in thid_files])]
thar_files = np.array(thar_files)[np.argsort([_file_time(f) for f in thar_files])]

In [None]:
# Load in all observed wavelengths into a big dictionary:
# wavedict[order] -> sorted array of distinct wavelengths seen in any
# readable .thid file for that order.
order_list = range(40,70)
per_order = {}
for file_name in thid_files:
    try:
        x, m, w = readThid(file_name)
    except ValueError:
        # readThid could not parse this file; skip it
        continue
    for nord in order_list:
        per_order.setdefault(nord, []).append(w[m == nord])
wavedict = {
    nord: np.unique(np.concatenate([np.array([])] + chunks))
    for nord, chunks in per_order.items()
}

In [None]:
# Reformat mode dictionary into flat vectors: entry j describes line j by its
# wavelength (waves[j]) and the order it belongs to (orders[j]).
waves  = np.array([], dtype=float)
orders = np.array([], dtype=int)
for m, order_waves in wavedict.items():
    waves = np.concatenate((waves, order_waves))
    # np.full(..., dtype=int) keeps `orders` integer; zeros_like(float) + m
    # would silently upcast the whole vector to float.
    orders = np.concatenate((orders, np.full(len(order_waves), m, dtype=int)))

In [None]:
# Load in x values to match order/wavelength lines.
# x_values[i, j] is the fitted pixel center of line j (orders[j], waves[j]) in
# exposure i, or NaN when that line was not found or the file was unreadable.
x_values = np.full((len(thid_files), len(waves)), np.nan)
line_keys = [(float(o), float(wv)) for o, wv in zip(orders, waves)]
for i, file_name in enumerate(tqdm(thid_files)):
    try:
        x, m, w = readThid(file_name)
    except ValueError:
        continue
    # Build an (order, wavelength) -> center lookup once per file instead of
    # rescanning the whole file for every line. If a line is listed more than
    # once, keep the first match rather than assigning a multi-element array
    # to a scalar slot, which NumPy rejects.
    centers = {}
    for xi, mi, wi in zip(x, m, w):
        centers.setdefault((float(mi), float(wi)), xi)
    x_values[i] = [centers.get(key, np.nan) for key in line_keys]

In [None]:
# Where are we missing lines? One figure per order: every (wavelength, exposure)
# cell is coloured by its line center, and missing entries are overplotted in red.
exp_index = np.arange(len(thid_files), dtype=float)
for m in order_list:
    in_order = orders == m
    order_x = x_values[:, in_order]
    wave_grid, exp_grid = np.meshgrid(waves[in_order], exp_index)
    missing = np.isnan(order_x)

    plt.figure()
    plt.title(f'Order {m}')
    plt.scatter(wave_grid, exp_grid, c=order_x, s=1)
    plt.colorbar(label='Line Center [px]')
    plt.scatter(wave_grid[missing], exp_grid[missing], s=.5, c='r')
    plt.xlabel('Wavelength')
    plt.ylabel('Exposure Number-ish');

In [None]:
colors = sns.color_palette('RdYlBu', len(order_list))
is_missing = np.isnan(x_values)

# Number of missing lines per exposure, one curve per order
plt.figure()
for color, nord in zip(colors, order_list):
    plt.plot(is_missing[:, orders == nord].sum(axis=1), color=color)
plt.xlabel('Exposure')

# Number of exposures missing each line, against wavelength
plt.figure(figsize=(6.4*2, 4.8))
for color, nord in zip(colors, order_list):
    in_order = orders == nord
    plt.plot(waves[in_order], is_missing[:, in_order].sum(axis=0), color=color)
plt.xlabel('Wavelength')

In [None]:
# Get rid of bad wavelengths: centers below 1 px are treated as failed fits,
# then any line missing in 30% or more of the exposures is dropped.
x_values[x_values < 1] = np.nan
missing_frac = np.isnan(x_values).mean(axis=0)
good_lines = missing_frac < 0.3

# Trim every per-line array consistently
orders, waves = orders[good_lines], waves[good_lines]
x_values = x_values[:, good_lines]

In [None]:
np.isnan(x_values).sum()  # total missing entries remaining

In [None]:
# Get rid of bad exposures: drop any exposure missing half or more of the lines
missing_per_exp = np.isnan(x_values).mean(axis=1)
good_exps = missing_per_exp < 0.5
print(thid_files[~good_exps])

# Trim the exposure axis
x_values = x_values[good_exps]
exp_list = thid_files[good_exps]

# Remember which entries were missing before any patching happens
bad_mask = np.isnan(x_values)
print(waves.shape, exp_list.shape, x_values.shape, bad_mask.shape)

In [None]:
np.isnan(x_values).sum()  # missing entries left after dropping bad exposures

In [None]:
colors = sns.color_palette('RdYlBu', len(order_list))
nan_entries = np.isnan(x_values)

# Missing lines per (kept) exposure, one curve per order
plt.figure()
for color, nord in zip(colors, order_list):
    plt.plot(nan_entries[:, orders == nord].sum(axis=1), color=color)
plt.xlabel('Exposure')

# Missing exposures per (kept) line, against wavelength
plt.figure(figsize=(6.4*2, 4.8))
for color, nord in zip(colors, order_list):
    sel = orders == nord
    plt.plot(waves[sel], nan_entries[:, sel].sum(axis=0), color=color)
plt.xlabel('Wavelength')

In [None]:
# Where are we missing lines (after trimming)? Same per-order scatter as
# before, now indexed by the kept exposures in exp_list.
exp_index = np.arange(len(exp_list), dtype=float)
for m in order_list:
    sel = orders == m
    centers_in_order = x_values[:, sel]
    wave_grid, exp_grid = np.meshgrid(waves[sel], exp_index)
    gaps = np.isnan(centers_in_order)

    plt.figure()
    plt.title(f'Order {m}')
    plt.scatter(wave_grid, exp_grid, c=centers_in_order, s=1)
    plt.colorbar(label='Line Center [px]')
    plt.scatter(wave_grid[gaps], exp_grid[gaps], s=.5, c='r')
    plt.xlabel('Wavelength')
    plt.ylabel('Exposure Number-ish');

In [None]:
# Patch bad data with a running mean over neighbouring exposures: each missing
# entry becomes the NaN-ignoring mean of the same line in exposures
# i-half_size .. i+half_size, clipped to the available exposures.
half_size = 4
num_exps = x_values.shape[0]
# Window means come from the *unpatched* data, so values filled in for earlier
# exposures do not leak into the windows of later ones.
raw_x_values = x_values.copy()
# Fallback for a line whose whole window is NaN; leaving NaN there would make
# the later SVD fail.
line_means = np.nanmean(raw_x_values, axis=0)
for i in range(num_exps):
    # The window runs along the exposure axis (axis 0), so clip by num_exps
    lo, hi = max(i - half_size, 0), min(i + half_size + 1, num_exps)
    window = raw_x_values[lo:hi]
    n_valid = np.sum(~np.isnan(window), axis=0)
    run_mean = np.where(n_valid > 0,
                        np.nansum(window, axis=0) / np.maximum(n_valid, 1),
                        line_means)
    x_values[i][bad_mask[i]] = run_mean[bad_mask[i]]

In [None]:
# Iterative PCA: alternate between an SVD of the mean-subtracted data and
# re-patching the originally missing entries (bad_mask) with a rank-K
# reconstruction, recording the state after every iteration.
num_iters = 50
num_pca_comps = 2  # K: principal components used to rebuild the patched entries

num_exps, num_lines = x_values.shape
# With full_matrices=False, vv has shape (min(num_exps, num_lines), num_lines),
# which only equals x_values.shape when exposures <= lines.
num_comps = min(num_exps, num_lines)
iter_x_values = np.zeros((num_iters, num_exps, num_lines))
iter_vvs = np.zeros((num_iters, num_comps, num_lines))

for i in tqdm(range(num_iters)):
    # Redefine mean from the current (partially patched) data
    mean_x_values = np.mean(x_values, axis=0)
    # Run PCA
    uu, ss, vv = np.linalg.svd(x_values - mean_x_values, full_matrices=False)
    iter_vvs[i] = vv

    # Repatch bad data with the rank-K PCA reconstruction
    pca_patch = (uu[:, :num_pca_comps] * ss[:num_pca_comps]) @ vv[:num_pca_comps]
    x_values[bad_mask] = (pca_patch + mean_x_values)[bad_mask]
    iter_x_values[i] = x_values.copy()

In [None]:
# How do the eigenvectors compare with each iteration?
# 2x3 grid: panel k shows eigenvector k from every PCA iteration, split by
# order and plotted against fractional position within that order.
num_panels = 6
fig, axes = plt.subplots(2, 3, figsize=(6.4*3, 4.8*2))
axes = axes.ravel()
for k, ax in enumerate(axes):
    ax.set_title(f'Eigenvector {k}')
    ax.set_xlabel('Fraction of Order')
    if k % 3 == 0:  # left-hand column only
        ax.set_ylabel('Eigenvector Value')

colors = sns.color_palette("RdYlBu", len(order_list))
# Per-order masks and x-coordinates do not change between iterations
order_masks = [orders == nord for nord in order_list]
order_fracs = [np.linspace(0, 1, np.sum(mask)) for mask in order_masks]
# Guard against having fewer components than panels
n_vecs = min(num_panels, iter_vvs.shape[1])
for i in tqdm(range(num_iters)):
    for k in range(n_vecs):
        for mask, frac, color in zip(order_masks, order_fracs, colors):
            axes[k].plot(frac, iter_vvs[i][k][mask], color=color)
os.makedirs('./Figures', exist_ok=True)
plt.savefig('./Figures/191121_ThAreigenVs.png')

In [None]:
# Are the bad pixel values converging? Each curve follows one originally missing
# entry across PCA iterations, normalized by its value after the final iteration.
patched_history = iter_x_values[:, bad_mask]
plt.figure()
plt.title('Convergence in Bad Pixel Values')
plt.ylabel('Normalized Pixel Value')
plt.xlabel('Iteration')
plt.plot(patched_history / patched_history[-1], '.-', alpha=0.3);
plt.savefig('./Figures/191121_ThArbadPixConvergence.png')

In [None]:
# State of the SVD spectrum: natural log of the leading singular values (ss)
# left over from the last iteration of the iterative-PCA cell.
n_show = min(16, len(ss))  # there may be fewer than 16 singular values
plt.figure()
plt.title('SVD Singular Values')
plt.xlabel('Element Number')
plt.ylabel('ln(Singular Value)')
plt.step(np.arange(n_show), np.log(ss[:n_show]))
os.makedirs('./Figures', exist_ok=True)
plt.savefig('./Figures/191121_ThArssStep.png')