# SOSS aperture investigation
This notebook aims to characterize the loss of signal in the SOSS PSF as a function of wavelength and extraction aperture size. To do this, I'll load the monochromatic PSFs and sum their signal within narrower and narrower boxes in the cross-dispersion direction. Then I'll plot the fraction of the total signal retained as a function of wavelength for each aperture size (in pixels).

In [241]:
from copy import copy
from mirage.psf import soss_trace as st
import numpy as np
from astropy.io import fits
from scipy.interpolate import interp1d
import os
import itertools
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import Range1d
import bokeh.palettes as bpal
from hotsoss import locate_trace as lt
# Render Bokeh figures inline in this notebook
output_notebook()

def color_gen(colormap='viridis', key=None, n=10):
    """Color generator for Bokeh plots

    Parameters
    ----------
    colormap: str, sequence
        The name of a palette in ``bokeh.palettes`` (a dict of palettes,
        a palette function, or a plain list/tuple of colors), or a
        list/tuple of color hex values
    key: hashable, optional
        For named palettes that are dicts, the key of the palette to use;
        defaults to the first key of the dict
    n: int
        The number of colors requested from named palettes that are
        functions (e.g. ``'viridis'``); ignored otherwise

    Returns
    -------
    generator
        An endless generator cycling through the palette colors

    Raises
    ------
    TypeError
        If ``colormap`` is neither a bokeh palette name nor a list/tuple
        (raised on the first ``next()`` call, since this is a generator)
    ValueError
        If the resolved palette contains no colors
    """
    if isinstance(colormap, str) and colormap in dir(bpal):
        palette = getattr(bpal, colormap)

        if isinstance(palette, dict):
            # Dict palettes (e.g. bokeh's Viridis) hold several palettes; default to the first
            if key is None:
                key = next(iter(palette))
            palette = palette[key]

        elif callable(palette):
            palette = palette(n)

        # Named palettes such as 'Viridis256' are plain sequences and are usable as-is
        elif not isinstance(palette, (list, tuple)):
            raise TypeError("pallette must be a bokeh palette name or a sequence of color hex values.")

    elif isinstance(colormap, (list, tuple)):
        palette = colormap

    else:
        raise TypeError("pallette must be a bokeh palette name or a sequence of color hex values.")

    # itertools.cycle over nothing would end the generator silently on the first next()
    if len(palette) == 0:
        raise ValueError("palette contains no colors.")

    yield from itertools.cycle(palette)

In [236]:
# Locate the SOSS CLEAR PSF file under the directory named by $MIRAGE_DATA
PSF_DIR = os.path.join(os.environ['MIRAGE_DATA'], 'niriss/soss_psfs/')
file = os.path.join(PSF_DIR, 'SOSS_CLEAR_PSF.fits')

# Extension 1 holds the wavelength of each PSF plane; the primary HDU holds
# the PSF cube, whose last two axes are swapped here
wave = fits.getdata(file, ext=1)
cube = np.swapaxes(fits.getdata(file), -1, -2)

# Cubic interpolation between the monochromatic PSF planes along axis 0
psfs = interp1d(wave, cube, kind=3, axis=0)

In [259]:
def wave_psf(wavelength, aper=10, cutoff=0.005, plot=False):
    """
    Get the SOSS PSF at a wavelength and a copy trimmed to a cross-dispersion aperture

    The aperture keeps rows ``center - aper`` through ``center + aper - 1``
    (a box ``2 * aper`` rows tall, clipped to the array), where ``center`` is
    the middle row of the PSF array (row 38 for a 76-row PSF). Pixels below
    ``cutoff`` are also zeroed in the trimmed copy.

    Parameters
    ----------
    wavelength: float
        Wavelength at which to evaluate the PSF (same units as the PSF cube's
        wavelength axis -- presumably microns, TODO confirm)
    aper: int
        Half-height of the extraction box in pixels
    cutoff: float
        Threshold below which trimmed pixels are set to zero; also passed on
        to ``st.get_SOSS_psf``
    plot: bool
        Show the trimmed PSF with the aperture edges marked in red

    Returns
    -------
    psf: np.ndarray
        The full PSF as returned by ``st.get_SOSS_psf``
    trimmed: np.ndarray
        A copy of ``psf`` with pixels outside the aperture or below ``cutoff``
        set to zero
    """
    psf = st.get_SOSS_psf(wavelength, filt='CLEAR', psfs=psfs, cutoff=cutoff, plot=False)
    nrows, ncols = psf.shape[0], psf.shape[1]
    center = nrows // 2

    # Clip the box edges to the array so the plotted lines mark the real bounds
    lower = max(0, center - aper)
    upper = min(nrows, center + aper)

    trimmed = psf.copy()
    trimmed[upper:, :] = 0
    trimmed[:lower, :] = 0
    trimmed[trimmed < cutoff] = 0

    if plot:
        fig = figure()
        # In a Bokeh image, dw spans the columns (x) and dh spans the rows (y)
        fig.image([trimmed], x=0, y=0, dw=ncols, dh=nrows)
        fig.line([0, ncols], [upper]*2, color='red')
        fig.line([0, ncols], [lower]*2, color='red')
        show(fig)

    return psf, trimmed

# Quick look at one monochromatic PSF trimmed to a 12 px half-height box
fl, tr = wave_psf(wavelength=1, aper=12, plot=True)

In [260]:
def aper_sum(order=1, aper=10, nwaves=100, cutoff=0.005, plot=False):
    """
    Get the fraction of PSF signal inside an aperture as a function of wavelength

    Parameters
    ----------
    order: int
        Spectral order whose wavelength range (from ``lt.trace_wavelengths``)
        is sampled
    aper: int
        Aperture half-height in pixels, passed to ``wave_psf``
    nwaves: int
        Number of evenly spaced wavelengths to sample
    cutoff: float
        Pixel threshold passed to ``wave_psf``
    plot: bool
        Plot the signal fraction against wavelength

    Returns
    -------
    waves: np.ndarray
        The sampled wavelengths, shape (nwaves,)
    signals: np.ndarray
        Sum of the trimmed PSF divided by the sum of the full PSF at each
        wavelength, shape (nwaves,)
    """
    # Span the trace's wavelength range, ignoring NaN entries
    wave = lt.trace_wavelengths(order)
    waves = np.linspace(np.nanmin(wave), np.nanmax(wave), nwaves)

    # Keep only the per-wavelength sums instead of full PSF cubes; this also
    # avoids assuming a particular PSF shape
    kept = np.empty(nwaves)
    total = np.empty(nwaves)
    for n, w in enumerate(waves):
        full, trimmed = wave_psf(w, aper, cutoff=cutoff, plot=False)
        kept[n] = np.sum(trimmed)
        total[n] = np.sum(full)
    signals = kept / total

    if plot:
        fig = figure(x_axis_label='Wavelength', y_axis_label='Fraction of PSF signal in aperture')
        fig.scatter(waves, signals)
        show(fig)

    return waves, signals

# Signal fraction vs wavelength for a single 15 px half-height aperture
waves, signals = aper_sum(nwaves=100, aper=15, plot=True)

In [264]:
def aperture_comparison(apertures=(12, 14, 16, 18, 20, 22, 24, 26), nwaves=100, cutoff=0.005, order=1):
    """
    Plot the signal fraction versus wavelength for a range of aperture sizes

    Parameters
    ----------
    apertures: iterable of int
        Aperture half-heights in pixels, each passed as ``aper`` to ``aper_sum``
        (the legend labels show these half-heights)
    nwaves: int
        Number of wavelengths sampled per aperture
    cutoff: float
        Pixel threshold passed to ``aper_sum``
    order: int
        Spectral order whose wavelength range is sampled

    Raises
    ------
    ValueError
        If ``apertures`` is empty
    """
    # Materialize so generators work and emptiness can be checked up front
    apertures = list(apertures)
    if not apertures:
        raise ValueError("apertures must contain at least one aperture size")

    colors = color_gen()
    fig = figure(x_axis_label='Wavelength', y_axis_label='Fraction of PSF signal in aperture')
    for aper in apertures:
        waves, signals = aper_sum(order=order, aper=aper, nwaves=nwaves, cutoff=cutoff, plot=False)
        fig.scatter(waves, signals, color=next(colors), legend_label='{}px'.format(aper))

    # Reference line at 100% of the signal
    fig.line(waves, np.ones_like(waves), color='red')
    fig.legend.location = "bottom_center"
    show(fig)
    
# Compare the default aperture sizes with a very low pixel cutoff
aperture_comparison(cutoff=1e-4)