# 2D PSF Fitting
This is a notebook to determine the feasibility of an optimal 1D spectral extraction routine that fits monochromatic 2D PSFs to the traces in a SOSS frame.

In [1]:
from copy import copy
from bokeh.plotting import show, figure
from bokeh.io import output_notebook
from hotsoss import plotting as plt
import numpy as np
# from specialsoss import SossExposure, SimExposure, CV3Exposure, sosstrace
from specialsoss import fitpsf
import webbpsf
output_notebook()

In [2]:
# Make model
ns = webbpsf.NIRISS()
ns.filter = 'CLEAR'
ns.pupil_mask = 'GR700XD'
psf_input = ns.calc_psf(oversample=4)[0].data.T

In [3]:
def psf2D(amplitude=1, psf=psf_input, col_cutoff=0.025, row_cutoff=0.0001, snr=10, fill=0, flat=False, plot=False):
    """
    A function that scales the linearized 2D PSF to the given amplitude
    """
    if psf is None:
        
        # Make model
        ns = webbpsf.NIRISS()
        ns.filter = 'CLEAR'
        ns.pupil_mask = 'GR700XD'
        psf = ns.calc_psf(oversample=1)[0].data.T
        
    # Plot 1D PSF in each column
    if plot:
        fig1d = figure()
    psf_mask = copy(psf.T)
    dx, dy = psf_mask.shape
    signal_cols = []
    for n, col in enumerate(psf.T):
        total = np.nansum(col)
        if total > col_cutoff:
            signal, = np.where(col > row_cutoff)
            start, end = signal[0], signal[-1]
            psf_mask[n, :start] = fill
            psf_mask[n, end:] = fill
            signal_cols.append(n)
            if plot:
                fig1d.line(np.arange(dy), psf_mask[n])

        else:
            psf_mask[n] = fill
    
    psf_mask *= np.random.normal(loc=psf_mask, scale=psf_mask / snr)

    # Plot 2D psf
    if plot:
        fig2d = figure()
        fig2d.image(image=[psf_mask.T], dw=psf_mask.shape[0], dh=psf_mask.shape[1], x=0, y=0)
        show(fig1d)
        show(fig2d)

    # Get 2D PSF width
    signal_width = max(signal_cols) - min(signal_cols)
        
    psf_flat = psf_mask.flatten()
    
    return psf_flat if flat else psf_mask.T * amplitude
    
psf_temp = psf2D(plot=True)

In [4]:
# Make 10 overlapping PSFs to see how complex it is
amplitudes = np.array([10, 9, 8, 7, 9, 8, 7, 6, 5, 4]) ** 3
N = len(amplitudes)
tx, ty = psf_temp.shape
trace = np.zeros((ty, tx + (N * 2)))
for n, amp in enumerate(amplitudes):
    trace[:, n:n + tx] += psf2D(amplitude=amp)
    
tracefig = figure()
tracefig.image(image=[trace], x=0, y=0, dw=trace.shape[0], dh=ty)
show(tracefig)

In [5]:
# See how complicated 7 overlapping flat PSFs are
flat_psfs = [psf2D(amp, flat=True) for amp in amplitudes]
flat_total = np.sum(flat_psfs, axis=0)

flatfig = figure()
flatfig.line(np.arange(len(p)), flat_total, color='black', width=2)
for p in flat_psfs:
    flatfig.line(np.arange(len(p)), p, alpha=0.2)
show(flatfig)

NameError: name 'p' is not defined