## Scattered-light subtraction for KCWI data

# In [3]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as fits
from astropy.modeling import models, fitting
import shutil

# In [2]:
# Scattered-Light Subtraction with Moffat model or Ricker wavelet
# ----------------------------------------------------------------------------------------
# This function performs scattered-light subtraction for KCWI data and can be used
# instead of Stage 2 of the KCWI data reduction pipeline. It fits the scattered light
# with a 1D 2nd-order polynomial plus a Moffat1D or RickerWavelet1D profile.
# The plots it produces help judge whether a good fit was obtained, and serve as
# a reference for tweaking the parameters.
# ----------------------------------------------------------------------------------------
# Call:
#   scat_sub(int_file,gap_file,model,*dy=257,*amp=1,*cent=1000,*width=150,*level=5)
# ----------------------------------------------------------------------------------------
# Inputs:
#   Required:
#      int_file:   (str) path to file, e.g. '/data/kb190101_00011_int.fits'
#      gap_file:   (str) path to gap file, e.g. '/data/kb190101_00011_gaps.dat'
#         model:   (str) model to fit, 'moffat' or 'ricker'
#   Optional:
#            dy:   (int) number of rows for each slice, the default value is 257 (8 slices)
#           amp: (float) amplitude for Moffat/Ricker model, default amp = 1
#          cent: (float) center of the model, x position of the maximum, default cent = 1000
#         width: (float) width of the model/wavelet, default width = 150
#         level: (float) C0 in 2nd order polynomial, default level = 5
# ----------------------------------------------------------------------------------------
# Outputs:
#          plots:  original data points and fitted model
#   reduced_data:  saved as *intd.fits in the same directory as the input data
# ----------------------------------------------------------------------------------------
# Notes:
#  1. This code does not use data points from the central gap.
#  2. The gap bounds can (preferably) be chosen loosely.
#  3. Moffat model is preferred for data with short exposure (e.g. standard stars).
#     Ricker model works better for those with longer exposure.
#     However, you may still want to decide which to use after examining the data.
# ----------------------------------------------------------------------------------------
# Reference:
#  1. Rupke D., Scattered-light subtraction code
#      https://github.com/drupke/ifsred/blob/master/kcwi/ifsr_kcwiscatsub.pro
#  2. KCWI Data Reduction Pipeline, Stage 2
#      https://github.com/Keck-DataReductionPipelines/KcwiDRP
#  3. Bernstein B., Fernandez-Granda C., Deconvolution of Point Sources: A Sampling
#     Theorem and Robustness Guarantees, Comm. Pure Appl. Math., vol 72, May 2018,
#     doi:10.1002/cpa.21805
#      https://arxiv.org/pdf/1707.00808.pdf
#  4. Grundahl F., Sørensen A.N., Detection of scattered light in telescopes,
#     Astron. Astrophys. Suppl. Ser. 116, 367-371 (1996), doi:10.1051/aas:1996119
#      https://aas.aanda.org/articles/aas/pdf/1996/05/ds1089.pdf
#  5. Morrissey P., Matuszewski M., Martin D.C., et al., THE KECK COSMIC WEB
#     IMAGER INTEGRAL FIELD SPECTROGRAPH, doi:10.3847/1538-4357/aad597
#      https://arxiv.org/pdf/1807.10356.pdf
# ----------------------------------------------------------------------------------------
# Latest Update: Aug 18, 2020  Nancy(Wenmeng) Ning
# ----------------------------------------------------------------------------------------


def slice_fit(data,gap,model,yi,yf,amp,cent,width,level):
    """Fit the scattered-light profile of one horizontal slice of rows.

    For every gap ``(x0, x1)`` the rows ``yi:yf`` of ``data`` are summed
    column by column:

    * zero-width gap (``x0 == x1``): the sum of column ``x0`` is used,
      placed at ``x0``;
    * otherwise: the minimum column sum over columns ``x0:x1`` (upper bound
      exclusive) is used, placed at the integer midpoint of the gap.

    The sums, divided by the number of rows, are fitted with a Moffat1D
    (alpha=2) or RickerWavelet1D profile plus a 2nd-order polynomial,
    using astropy's LevMarLSQFitter.

    Parameters
    ----------
    data : 2D array indexed as ``data[row, column]``
    gap : iterable of ``(x0, x1)`` column bounds
    model : ``'moffat'`` or ``'ricker'``
    yi, yf : first (inclusive) and last (exclusive) row of the slice
    amp, cent, width, level : initial amplitude, centre, width (``gamma``
        for Moffat, ``sigma`` for Ricker) and polynomial ``c0``

    Returns
    -------
    x : list of column positions of the samples
    y : list of the summed (not row-averaged) samples
    m : the fitted astropy compound model

    Raises
    ------
    ValueError
        If ``model`` is not 'moffat'/'ricker' or the row range is empty.
    """
    # Validate up front. Otherwise an unknown model only fails later, with an
    # UnboundLocalError on the model variable.
    if model not in ('moffat', 'ricker'):
        raise ValueError("model must be 'moffat' or 'ricker', got %r" % (model,))
    # np.int was removed in NumPy 1.24; the builtin int does the same job.
    yi = int(yi)
    yf = int(yf)
    nrow = yf - yi
    if nrow <= 0:
        raise ValueError("empty row range [%d, %d)" % (yi, yf))

    rows = data[yi:yf]
    x = []
    y = []
    for x0, x1 in gap:
        if x1 - x0 == 0:
            x.append(int(x0))
            y.append(np.sum(rows[:, int(x0)]))
        else:
            x.append(int((x0 + x1) / 2))
            # Use the minimum column sum so stray signal inside a loosely
            # bounded gap does not bias the scattered-light level upwards.
            y.append(np.min(np.sum(rows[:, int(x0):int(x1)], axis=0)))

    if model == 'moffat':
        profile = models.Moffat1D(amplitude=amp, x_0=cent, gamma=width, alpha=2)
    else:
        profile = models.RickerWavelet1D(amplitude=amp, x_0=cent, sigma=width)
    m_init = profile + models.Polynomial1D(degree=2, c0=level)

    fit_m = fitting.LevMarLSQFitter()
    # Fit the per-row average so the model can be subtracted from each row.
    m = fit_m(m_init, x, np.divide(y, nrow))
    return x, y, m

def fit_plot(nslice,dy,result,name,xn=2048):
    """Plot the gap samples and the fitted model of every slice, two panels per row.

    Parameters
    ----------
    nslice : int
        Number of slices; one panel each. An odd count leaves the last
        grid cell empty, and that cell is removed.
    dy : int or sequence of int
        Row count used to turn the summed samples in ``result.y`` into
        per-row averages. A scalar applies to every slice. A sequence of
        ``nslice`` values gives each slice its own height, so that a shorter
        final slice is scaled the same way ``slice_fit`` scaled it for the fit.
    result : record array
        Has fields ``x``, ``y``, ``model`` (one entry per slice).
    name : str
        Figure title.
    xn : int, optional
        Number of detector columns over which the model curve is drawn.
        The default of 2048 is the previously hard-coded value.
    """
    nrows = int(np.ceil(nslice / 2))
    # squeeze=False always yields a 2D axes array, whatever the grid size.
    fig, grid = plt.subplots(nrows=nrows, ncols=2, figsize=(16, 5 * nrows), squeeze=False)
    fig.suptitle(name, fontsize=14, y=0.9)
    panels = grid.ravel()
    for ax in panels[nslice:]:
        ax.remove()

    heights = np.broadcast_to(np.asarray(dy), (nslice,))
    columns = np.arange(0, xn)
    for i in range(nslice):
        panels[i].scatter(result.x[i], np.divide(result.y[i], heights[i]))
        panels[i].plot(columns, result.model[i](columns))
    fig.show()

def scat_sub(int_file,gap_file,model,dy=257,amp=1,cent=1000,width=150,level=5,**kargs):
    """Subtract a fitted scattered-light model from a KCWI ``*int*`` image.

    Processing steps:

    1. Copy the input to a file whose name has its last ``'int'`` replaced
       by ``'intd'`` (e.g. ``kb190101_00011_int.fits`` ->
       ``kb190101_00011_intd.fits``). The copy is opened in update mode;
       the image is read from HDU 0.
    2. Cut the image into horizontal slices of ``dy`` rows. The last slice
       takes whatever rows remain.
    3. Fit each slice with ``slice_fit`` and plot the fits with ``fit_plot``.
    4. Subtract each slice's model, evaluated at every column, from every
       row of that slice.

    Parameters
    ----------
    int_file : str
        Path to the input FITS file.
    gap_file : str
        Text file with one gap per line: two whitespace-separated integer
        column bounds. Blank lines are ignored.
    model : str
        ``'moffat'`` or ``'ricker'``.
    dy : int
        Rows per slice.
    amp, cent, width, level : float
        Initial guesses forwarded to ``slice_fit``.
    **kargs
        Accepted for backward compatibility; ignored.

    Raises
    ------
    ValueError
        Raised for an unknown ``model``, a non-positive ``dy``, an
        ``int_file`` without ``'int'`` in its name, or a gap file with no
        gaps. All of these are checked before the output file is created.
    """
    if model not in ('moffat', 'ricker'):
        raise ValueError("model must be 'moffat' or 'ricker', got %r" % (model,))
    if dy <= 0:
        raise ValueError("dy must be positive, got %r" % (dy,))

    # Rename only the last 'int' (the '_int' suffix). str.replace would also
    # rewrite any 'int' in the directory path, e.g. '/data/print/...'.
    head, sep, tail = int_file.rpartition("int")
    if not sep:
        raise ValueError("int_file name must contain 'int', got %r" % (int_file,))
    file_r = head + "intd" + tail
    name = int_file.split("_int")[0] + '--' + model

    # Read the gap file. The with-statement closes the handle, and blank
    # lines (e.g. a trailing empty line) are skipped.
    with open(gap_file, "r") as fh:
        fields = [line.split() for line in fh if line.strip()]
    if not fields:
        raise ValueError("no gaps found in %r" % (gap_file,))
    gap = np.array([[int(f[0]), int(f[1])] for f in fields], dtype=float)

    # Copy only after every input has been validated, so a bad call
    # leaves no stray output file behind.
    shutil.copy2(int_file, file_r)

    with fits.open(file_r, mode='update') as hdu:
        data = hdu[0].data

        # Slice boundaries: [i*dy, (i+1)*dy), with the last one clipped to yn.
        yn, xn = np.shape(data)
        nslice = int(np.ceil(yn / dy))
        row = np.empty((nslice, 2), dtype=int)
        for i in range(nslice):
            row[i] = [i * dy, min((i + 1) * dy, yn)]

        # Fit every slice before modifying data, then plot the results.
        result = np.recarray((nslice,), dtype=[('x', 'O'), ('y', 'O'), ('model', 'O')])
        for i, (y0, y1) in enumerate(row):
            result.x[i], result.y[i], result.model[i] = slice_fit(
                data, gap, model, y0, y1, amp, cent, width, level)
        fit_plot(nslice, dy, result, name)

        # Subtract the scattered light. Closing an HDUList opened with
        # mode='update' writes the modified data back to file_r.
        columns = np.arange(xn, dtype=int)
        for i, (y0, y1) in enumerate(row):
            data[y0:y1, :] -= result.model[i](columns)