### Add the rectified images together and write out as a fits file

In [8]:
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.io import ascii, fits
from astropy.visualization import MinMaxInterval, AsinhStretch, HistEqStretch, ImageNormalize
from scipy import interpolate
from scipy.optimize import curve_fit
import seaborn as sns
icefire = sns.color_palette("icefire", as_cmap=True)

%matplotlib widget

In [9]:
# path to the directory containing the combined rectified and cal files
# NOTE: the second assignment overrides the first, so only the last
# uncommented line takes effect -- comment out the one that does not apply
path = '/Users/jpw/Analysis/NIRSPEC/iSHELL/240107/'
path = '/Volumes/JPW_2TB/iSHELL/230630/'

# we only need to access the rectified files in this notebook so hardwire this in
path = path + 'rectified/'

In [10]:
def check_alignment(source, path, fitsfiles, order=106):
    """Overplot the median row profile of one order for each rectified file.

    For every file, the rows belonging to `order` are cut out of the flux
    image (extension 0), the NaN-ignoring median is taken along axis 1, and
    the resulting profile is plotted against row index. The figure is saved
    as path + source + '_alignment.png'.

    Parameters
    ----------
    source : str
        Used as the figure title and as the stem of the output png.
    path : str
        Prefix prepended to each file name and to the output png name;
        it is concatenated as-is, so it needs its trailing separator.
    fitsfiles : list of str
        Names of the rectified fits files to compare.
    order : int, optional
        Order to extract. Row offsets are counted from order 99
        (presumably the first order in the image -- TODO confirm).
    """
    j1 = 30           # starting point of first order -- figured out by hand
    dj_AB = 121       # width of each order -- this should match the size of the number of rows in the order extension
    dj_blank = 30     # gap between orders -- this is figured out by eye and assumed to be the same for all orders
    j0 = j1 + (dj_AB + dj_blank) * (order - 99)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_xlabel(r"Row (pixels)", fontsize=12)
    ax.set_ylabel(r"Flux (Jy)", fontsize=12)
    plt.suptitle(source)
    for f in fitsfiles:
        # the context manager closes the file even if reading the data fails
        with fits.open(path+f) as hdu:
            order_flux = hdu[0].data[j0:j0+dj_AB, :]
            # collapse to a 1-D profile while the file is still open
            im_median = np.nanmedian(order_flux, axis=1)

        # str.strip() removes any of the given *characters* from both ends,
        # not the substring, so it can eat into the real file name;
        # remove the extension and the 'rectified' tag explicitly instead
        label = f
        if label.endswith('.fits'):
            label = label[:-len('.fits')]
        if label.startswith('rectified'):
            label = label[len('rectified'):]
        if label.endswith('rectified'):
            label = label[:-len('rectified')]
        ax.plot(np.arange(im_median.size), im_median, label=label)

    ax.legend()
    plt.tight_layout()
    plt.savefig(path+source+'_alignment.png')

In [11]:
def weighted_average(path, fitsfiles):
    """Combine rectified images with inverse-variance weights.

    Each file is expected to hold the flux image in extension 0 and the
    variance image in extension 1. With weights w = 1/var the function
    returns sum(w*flux)/sum(w) and 1/sum(w).

    Parameters
    ----------
    path : str
        Prefix prepended to each file name; it is concatenated as-is, so
        it needs its trailing separator.
    fitsfiles : iterable of str
        Names of the rectified fits files to combine.

    Returns
    -------
    flux, var : array, array
        The weighted-mean flux image and its variance.

    Raises
    ------
    ValueError
        If `fitsfiles` is empty.

    Notes
    -----
    Pixels whose variance is zero or NaN in any input are not masked:
    they propagate as inf/NaN into the result.
    """
    flux_sum = None
    weights_sum = None
    for f in fitsfiles:
        # the context manager closes the file even if reading the data fails;
        # the arithmetic is done inside it so only new arrays leave the block
        with fits.open(path+f) as hdu:
            w = 1 / hdu[1].data
            weighted_flux = w * hdu[0].data

        if weights_sum is None:
            flux_sum = weighted_flux
            weights_sum = w
        else:
            flux_sum += weighted_flux
            weights_sum += w

    if weights_sum is None:
        # without this the return below would fail with an opaque error
        raise ValueError('weighted_average needs at least one fits file to combine')

    return flux_sum/weights_sum, 1/weights_sum

In [12]:
def plot_combined_image(source, flux, var, outdir=None):
    """Show the combined flux and variance images side by side and save a pdf.

    Parameters
    ----------
    source : str
        Used as the figure title and as the stem of the output file.
    flux : 2-D array
        Combined flux image, displayed with a fixed -15..15 colour range.
    var : 2-D array
        Combined variance image, displayed with a fixed 0..2 colour range.
    outdir : str, optional
        Prefix for the output file (concatenated as-is, so it needs its
        trailing separator). Defaults to the notebook-level `path`, which
        is what this function always used before the parameter existed.

    The figure is written to outdir + source + '_combined.pdf'.
    """
    if outdir is None:
        # fall back on the notebook global so existing calls keep working
        outdir = path

    fig = plt.figure(figsize=(13.5, 7))

    ax1 = plt.subplot(121)
    flux_norm = ImageNormalize(flux, vmin=-15, vmax=15)
    im1 = ax1.imshow(flux, origin='lower', norm=flux_norm, cmap=icefire)
    cbar1 = fig.colorbar(im1,  ax=ax1, pad=0.01, aspect=50)
    cbar1.set_label('DN/s', fontsize = 13, labelpad=1)
    ax1.set_title('Combined Flux')

    ax2 = plt.subplot(122)
    var_norm = ImageNormalize(var, vmin=0, vmax=2)
    im2 = ax2.imshow(var, origin='lower', norm=var_norm, cmap='magma')
    cbar2 = fig.colorbar(im2,  ax=ax2, pad=0.01, aspect=50)
    cbar2.set_label('(DN/s)$^2$', fontsize = 13, labelpad=1)
    ax2.set_title('Combined Variance')

    # the y label is only set on the left panel; both panels get an x label
    ax1.set_xlabel('Pixel', fontsize = 13, labelpad=10)
    ax1.set_ylabel('Pixel', fontsize = 13, labelpad=10)
    ax2.set_xlabel('Pixel', fontsize = 13, labelpad=10)

    plt.suptitle(source)
    plt.tight_layout()
    plt.savefig(outdir+source+'_combined.pdf', bbox_inches='tight', dpi=400)

In [13]:
def write_combined_fits(source, path, fitsfiles, flux, var):
    """Write the combined flux and variance images to a two-extension fits file.

    The primary HDU holds `flux` and an image extension holds `var`. The
    primary header is a copy of the extension-0 header of the first file
    in `fitsfiles`, with HISTORY cards appended that list the files that
    were combined. The output is path + source + '_rectified.fits' and
    any existing file of that name is overwritten.

    Parameters
    ----------
    source : str
        Stem of the output file name.
    path : str
        Prefix prepended to the input and output file names; it is
        concatenated as-is, so it needs its trailing separator.
    fitsfiles : list of str
        Names of the rectified fits files that went into the combination.
    flux, var : array
        Combined flux and variance images.

    Raises
    ------
    ValueError
        If `fitsfiles` is empty, since there is then no header to copy.
    """
    if not fitsfiles:
        raise ValueError('write_combined_fits needs at least one rectified file to copy the header from')

    # the context manager closes the file even if reading the header fails;
    # copy the header so the additions below cannot touch the opened file
    with fits.open(path+fitsfiles[0]) as hdu:
        hd0 = hdu[0].header.copy()

    # record what was combined; HISTORY cards cannot clash with existing keywords
    hd0.add_history('Combined from ' + str(len(fitsfiles)) + ' rectified file(s):')
    for f in fitsfiles:
        hd0.add_history(f)

    hdu_flux = fits.PrimaryHDU(flux, header=hd0)
    hdu_var = fits.ImageHDU(var)
    hdu_list = fits.HDUList([hdu_flux, hdu_var])
    hdu_list.writeto(path+source+'_rectified.fits', overwrite=True)

In [14]:
%%capture

# read in the set of rectified files to be combined source by source
# you have to manually create sourcelist.txt in the same directory as the rectified fits files
#
# layout of sourcelist.txt as parsed below:
#   a line starting with '----' opens each source section, and one more
#   such line is needed after the last section to close it;
#   within a section the first line is the source name, the second the
#   wavecal entry, and every following line is one rectified fits file

with open(path+'sourcelist.txt') as f:
    # strip each line so trailing blanks or '\r' do not end up in file names
    lines = [line.strip() for line in f.read().split('\n')]

# find the breakpoints between sources
nbreak = [nline for nline, line1 in enumerate(lines) if line1.startswith('----')]

# read in the source and the associated list of rectified files to combine
for start, stop in zip(nbreak[:-1], nbreak[1:]):
    source = lines[start+1]
    wavecal = lines[start+2]     # read for completeness; not used in this cell
    # skip blank lines so they are not mistaken for file names
    fitsfiles = [name for name in lines[start+3:stop] if name]
    if not fitsfiles:
        raise ValueError('no rectified files listed for source ' + repr(source) + ' in sourcelist.txt')

    # probably don't need to check alignment since each rectified file was produced from the same wavecal
    check_alignment(source, path, fitsfiles, order=106)

    flux, var = weighted_average(path, fitsfiles)
    plot_combined_image(source, flux, var)
    write_combined_fits(source, path, fitsfiles, flux, var)