In [None]:
from IPython.display import display, HTML

def setup_display(width=95, fontsize=18):
    """
    Widen the notebook container and enlarge rendered-markdown text.

    Parameters
    ----------
    width : int or float
        Container width as a percentage of the browser window.
    fontsize : int or float
        Font size (px) applied to rendered markdown/HTML cells.
    """
    css_rules = (
        "<style>.container { width:" + str(width) + "% !important; }</style>",
        "<style>.rendered_html { font-size: " + str(fontsize) + "px; }</style>",
    )
    for rule in css_rules:
        display(HTML(rule))
    return None

setup_display()

import glob
import json
import os
from copy import copy, deepcopy

import lmfit
import matplotlib as mpl
import matplotlib.patheffects
import matplotlib.pyplot as plt
import numpy as np
import spaceKLIP
import synphot
import webbpsf
from astropy import units as u
from astropy.convolution import Gaussian2DKernel
from astropy.io import fits
from jwst.coron import imageregistration
from scipy import ndimage
from spaceKLIP import database
from specutils import Spectrum1D
from webbpsf_ext import image_manip

# Error text raised for malformed clim strings in quick_implot.
_CLIM_STRING_HELP = """
                    If clim is a string, it should contain a comma separating
                    two entries. These entries should be one of:
                    a) interpretable as a float, in which case they serve as the 
                    corresponding entry in the utilized clim, b) they should contain a
                    % symbol, in which case they are used as a percentile bound;
                    e.g., clim='0, 99.9%' will yield an image with a color
                    stretch spanning [0, np.nanpercentile(im, 99.9)], or c) they
                    should contain a '*' symbol, separating either of the 
                    aforementioned options, in which case they will be multiplied.
                    """


def _parse_clim_entry(entry, im):
    """Evaluates one clim string entry: a float, a percentile ('99.9%'), or a '*'-separated product of those."""
    factors = []
    for token in entry.split('*'):
        token = token.strip()
        try:
            if '%' in token:
                factors.append(np.nanpercentile(im, float(token.replace('%', ''))))
            else:
                # float() accepts '0.5', '-1', '1e-3', ... (str.isdigit() did not).
                factors.append(float(token))
        except ValueError as err:
            raise ValueError(_CLIM_STRING_HELP) from err
    return factors[0] if len(factors) == 1 else np.prod(factors)


def _parse_clim_string(clim, im):
    """Converts a 'lower, upper' clim string (see quick_implot) into a two-element list."""
    entries = [s.strip() for s in clim.split(',')]
    if len(entries) != 2:
        raise ValueError(_CLIM_STRING_HELP)
    return [_parse_clim_entry(s, im) for s in entries]


def _implot_grid_shape(imshape):
    """Returns the (nrows, ncols) panel grid implied by the shape of the image stack."""
    ndim = len(imshape)
    if ndim == 2:
        return 1, 1
    if ndim == 3:
        return 1, imshape[0]
    if ndim == 4:
        return imshape[0], imshape[1]
    raise ValueError("Argument 'im' must be a 2, 3, or 4 dimensional array")


def quick_implot(im, clim=None, clim_perc=[1.0, 99.0], cmap=None,
                 show_ticks=False, lims=None, ylims=None,
                 norm=mpl.colors.Normalize, norm_kwargs={},
                 figsize=None, panelsize=[5,5], fig_and_ax=None, extent=None,
                 show=True, tight_layout=True, alpha=1.0,
                 cbar=False, cbar_orientation='vertical',
                 cbar_kwargs={}, cbar_label=None,
                 interpolation = None, sharex=True, sharey=True,
                 save_name=None, save_kwargs={}):
    """
    Takes either a single im as "im", or a list/array and plots the images in the corresponding shape.
    e.g.
        im = [[im1,im2],
              [im3,im4],
              [im5,im6]]

    generates a 2 column, 3 row figure. In general a 2D input gives one panel, a
    3D input (n, ny, nx) gives one row of n panels, and a 4D input
    (nrows, ncols, ny, nx) gives an nrows x ncols grid (leading length-1 axes are fine).

    clim defines the upper and lower limits of the color stretch for the plot.
    If clim is a string, it should contain a comma separating
    two entries. These entries should be one of:
    a) interpretable as a float (e.g. '0', '-1.5', '1e-3'), in which case they serve as the
    corresponding entry in the utilized clim, b) they should contain a
    % symbol, in which case they are used as a percentile bound;
    e.g., clim='0, 99.9%' will yield an image with a color
    stretch spanning [0, np.nanpercentile(im, 99.9)], or c) they
    should contain a '*' symbol, separating either of the
    aforementioned options, in which case they will be multiplied
    thusly; e.g., clim='0.01*99.9%, 99.9%' would yield a plot with
    colormapping spanning two decades (i.e., maybe appropriate for
    a logarithmic norm):
    [0.01*np.nanpercentile(im, 99.9), np.nanpercentile(im, 99.9)].
    If clim is None, clim_perc is used to compute a clim instead. If
    clim_perc contains two values, these are the lower and upper limit
    percentiles (computed over the unique values of im). If only a single
    value, P, is given, a symmetric clim is generated spanning plus and minus
    the P-percentile of the absolute value of im (best used with a
    diverging/symmetric colormap, such as 'coolwarm').

    Other notable parameters
    ------------------------
    fig_and_ax : (fig, ax) tuple, optional
        Draw into an existing figure/axes instead of creating new ones.
    save_name : str, optional
        If given, ``fig`` is saved there with bbox_inches='tight' plus save_kwargs.
    show : bool
        If True, the figure is shown and None is returned.

    Returns
    -------
    None if show is True; otherwise (fig, ax), where ax is a single Axes when
    there is one panel, or a flat numpy array of Axes otherwise.

    Raises
    ------
    ValueError
        If clim is a malformed string, or im is not 2, 3, or 4 dimensional.
    """
    if isinstance(clim, str):
        clim = _parse_clim_string(clim, im)
    elif isNone(clim):
        if np.isscalar(clim_perc) or len(clim_perc) == 1:
            clim = symmetric_clim_percentile(im, clim_perc)
        else:
            clim = np.nanpercentile(np.unique(im), clim_perc)

    if isNone(ylims):
        ylims = lims

    normalization = norm(vmin=clim[0], vmax=clim[1], **norm_kwargs)

    imshape = np.shape(im)
    nrows, ncols = _implot_grid_shape(imshape)
    n_ims = nrows * ncols

    if isNone(fig_and_ax):
        if isNone(figsize):
            figsize = np.array([ncols,nrows])*np.asarray(panelsize)
        fig, ax = plt.subplots(nrows, ncols, figsize=figsize, sharex=sharex, sharey=sharey)
    else:
        fig, ax = fig_and_ax

    # Always iterate over a flat stack of 2D panels and a flat array of axes, so inputs
    # such as shape (1, ny, nx) work. np.reshape keeps masked arrays masked.
    ims = np.reshape(im, (n_ims, imshape[-2], imshape[-1]))
    axes = np.asarray(ax, dtype=object).reshape(-1)

    for ax_i, im_i in zip(axes, ims):
        implot = ax_i.imshow(im_i, origin='lower', cmap=cmap, norm=normalization, extent=extent,
                             alpha=alpha, interpolation=interpolation)
        if not show_ticks:
            ax_i.set(xticks=[], yticks=[])
        ax_i.set(xlim=lims, ylim=ylims)

    if tight_layout:
        fig.tight_layout()
    if cbar:
        colorbar = fig.colorbar(implot, ax=list(axes), orientation=cbar_orientation, **cbar_kwargs)
        colorbar.set_label(cbar_label)
    if not isNone(save_name):
        # Save this figure explicitly; plt.savefig would save whichever figure is current.
        fig.savefig(save_name, bbox_inches='tight', **save_kwargs)
    if show:
        plt.show()
        return None
    if n_ims == 1:
        return fig, axes[0]
    return fig, axes

def symmetric_clim_percentile(arr, clim_perc=98):
    """
    Returns a symmetric color limit [-x, x], where x is the clim_perc-th percentile
    of the absolute values of the unique (sorted, de-duplicated) entries of arr.
    """
    magnitudes = np.abs(np.unique(arr))
    half_width = np.nanpercentile(magnitudes, clim_perc)
    return np.multiply(np.array([-1, 1]), half_width)

def isNone(arg):
    """
    Return True if arg is None, else False.

    Convenience shorthand: "if isNone(x)" is safe for any x, whereas "if x == None"
    can fail (e.g., for numpy arrays, where == is elementwise).
    """
    return arg is None


# Convenience functions for generating model PSF images:


def get_webbpsf_model_for_stellocentric_offset(xy, inst, spectrum=None, return_oversample=True, fov_pixels=None, osamp=None,
                                               center_psf=True, psf_shift=None):
    """
    Returns a centered WebbPSF PSF model image for a source that is offset from the coronagraph center by (xy) in arcsec
    (where a source above the coronagraph and to the right on the detector would have positive y and x respectively).

    Parameters
    ----------
    xy : sequence of 2 floats
        (x, y) offset of the source from the coronagraph center, in arcsec.
    inst : WebbPSF instrument object
        Its ``detector_position`` and ``options`` are modified for the calculation and
        restored to their original values before returning (also if calc_psf raises).
    spectrum : optional
        Source spectrum passed to ``inst.calc_psf(source=...)``.
    return_oversample : bool
        If True, use HDU extension 2 of the calc_psf output; otherwise extension 3.
    fov_pixels, osamp : optional
        Passed to ``inst.calc_psf``. If osamp is None, it is read from the output
        extension's 'OVERSAMP' header keyword.
    center_psf : bool
        If True, shift the PSF by psf_shift (scaled by osamp) before returning.
    psf_shift : array-like of 2 floats, optional
        Detector-pixel centering shift. If None (and center_psf), it is computed from an
        unocculted (image_mask=None) model via get_webbpsf_model_center_offset.

    Returns
    -------
    psf : numpy.ndarray
        The PSF image from the selected extension.
    """
    siaf_ap = inst.siaf[inst.aperturename]

    # Save the state modified below so we can return the inst object exactly as it was.
    options = deepcopy(inst.options)
    detector_position = inst.detector_position

    ind_out = 2 if return_oversample else 3
    try:
        inst.detector_position = (siaf_ap.XSciRef + xy[0]/siaf_ap.XSciScale, siaf_ap.YSciRef + xy[1]/siaf_ap.YSciScale)
        inst.options['coron_shift_x'] = -xy[0]
        inst.options['coron_shift_y'] = -xy[1]
        psf_hdul = inst.calc_psf(source=spectrum, fov_pixels=fov_pixels, oversample=osamp)
    finally:
        # Set inst back as it was, even if the PSF calculation failed.
        inst.options = options
        inst.detector_position = detector_position

    if options.get('add_ipc', True) and return_oversample and "Applied detector interpixel capacitance (IPC) model" not in psf_hdul[ind_out].header['history']:
        # Last condition above is to future proof in case WebbPSF starts applying IPC to the osamp extension too.
        ipc_hdul = webbpsf.detectors.apply_detector_ipc(psf_hdul, extname='OVERDIST')
        # NOTE(review): tolerate an implementation that modifies the HDUList in place and returns None.
        if ipc_hdul is not None:
            psf_hdul = ipc_hdul
    psf = psf_hdul[ind_out].data
    if osamp is None:
        osamp = psf_hdul[ind_out].header['OVERSAMP']

    if center_psf:
        if psf_shift is None: # Then we need to determine the requisite shift
            if inst.image_mask is None:
                psf_off = psf
            else:
                inst_off = deepcopy(inst)
                inst_off.image_mask = None
                psf_off = inst_off.calc_psf(source=spectrum, oversample=osamp, fov_pixels=fov_pixels)[ind_out].data
            psf_shift = get_webbpsf_model_center_offset(psf_off, osamp)
        # np.asarray so a list/tuple shift is scaled rather than repeated by "* osamp".
        psf = imageregistration.fourier_imshift(psf, np.asarray(psf_shift)*osamp)
    return psf


def get_webbpsf_model_center_offset(psf_off, osamp):
    """
    Returns the detector-sampled shift required to geometrically center psf_off.

    psf_off is cropped (about its center) to the size of an analytic 2D Gaussian
    (x sigma = osamp, y sigma = 2*osamp pixels) scaled to the PSF's peak. It is then
    registered against that Gaussian with imageregistration.align_array.

    Parameters
    ----------
    psf_off : numpy.ndarray
        Oversampled 2D PSF image.
    osamp : int
        Oversampling factor of psf_off.

    Returns
    -------
    shift : numpy.ndarray
        Negated (x, y) registration shift divided by osamp, i.e. in detector pixels.
        Multiply by osamp before applying it to an oversampled image.
    """
    reference = Gaussian2DKernel(x_stddev=1*osamp, y_stddev=2*osamp).array
    reference *= psf_off.max() / reference.max()
    cropped = pad_or_crop_image(psf_off, reference.shape, cval0=0)
    reg_shifts = imageregistration.align_array(reference, cropped)[1]
    return -reg_shifts[:-1] / osamp


def pad_or_crop_image(im, new_size, cent=None, new_cent=None, cval0=np.nan, nan_prop_threshold=0., zero_prop_threshold=0., prefilter=True):
    """
    Pads and/or crops a 2D image to new_size, so that position cent of im lands at new_cent of the output.

    Parameters
    ----------
    im : numpy.ndarray
        2D input image.
    new_size : sequence of 2 ints
        Output shape, (ny, nx).
    cent : sequence of 2 floats, optional
        (x, y) position in im to place at new_cent. Defaults to the geometric center of im.
    new_cent : sequence of 2 floats, optional
        (x, y) position in the output. Defaults to the geometric center of the output.
    cval0 : float
        Fill value for output pixels that fall outside im.
    nan_prop_threshold, zero_prop_threshold : float
        Only used for non-integer centers. Output pixels whose interpolated NaN (zero)
        fraction exceeds the threshold are set to NaN (zero).
    prefilter : bool
        Passed to the spline interpolation used for non-integer centers.

    Returns
    -------
    numpy.ndarray
        Image of shape new_size (a copy of im if nothing needs to change).
    """
    new_size = np.asarray(new_size)
    im_size = np.array(im.shape)
    ny, nx = im_size
    # Coerce centers to float arrays so the comparisons below are elementwise even when
    # the caller passes lists/tuples (np.all([array, bool]) is a ragged-array error).
    if cent is None:
        cent = (np.array([nx,ny])-1.)/2.
    else:
        cent = np.asarray(cent, dtype=float)

    if new_cent is None:
        new_cent = (np.array([new_size[1],new_size[0]])-1.)/2.
    else:
        new_cent = np.asarray(new_cent, dtype=float)

    if np.array_equal(new_size, im_size) and np.array_equal(cent, new_cent):
        return im.copy()

    if np.all([float(i).is_integer() for i in [*cent, *new_cent]]):
        # No need to treat nans/zeros differently if both centers are integers.
        out_im = pad_or_crop_about_pos(im, cent, new_size, new_cent=new_cent, cval=cval0, prefilter=False, order=0)

    else:
        nans = np.isnan(im)
        zeros = im == 0.
        any_zeros = np.any(zeros)
        any_nans = np.any(nans)
        # Interpolate with NaNs temporarily zeroed so they don't spread through the spline.
        im_filled = np.where(nans, 0., im) if any_nans else im
        out_im = pad_or_crop_about_pos(im_filled, cent, new_size, new_cent=new_cent, cval=cval0, prefilter=prefilter)
        # Re-impose zeros / NaNs where their interpolated masks exceed the thresholds.
        if any_zeros:
            out_zeros = pad_or_crop_about_pos(zeros.astype(float), cent, new_size, new_cent=new_cent, prefilter=False)
            out_im = np.where(out_zeros>zero_prop_threshold, 0., out_im)
        if any_nans:
            out_nans = pad_or_crop_about_pos(nans.astype(float), cent, new_size, new_cent=new_cent, prefilter=False)
            out_im = np.where(out_nans>nan_prop_threshold, np.nan, out_im)
    return out_im


def pad_or_crop_about_pos(im, pos, new_size, new_cent=None, cval=np.nan, order=3, mode='constant', prefilter=True):
    """
    Resamples im onto a (ny_new, nx_new) grid so that input position pos=(x, y) lands at new_cent=(x, y).

    Uses scipy.ndimage.map_coordinates. Arrays with more than two dimensions are
    processed image-by-image over their last two axes.

    Parameters
    ----------
    im : numpy.ndarray
        Input image (ny, nx) or stack of images (..., ny, nx).
    pos : sequence of 2 floats
        (x, y) position in the input to place at new_cent.
    new_size : sequence of 2 ints
        Output image shape, (ny_new, nx_new).
    new_cent : sequence of 2 floats, optional
        (x, y) output position; defaults to the geometric center of the output.
    cval, order, mode, prefilter :
        Passed to scipy.ndimage.map_coordinates.

    Returns
    -------
    numpy.ndarray
        Shape (ny_new, nx_new), or (..., ny_new, nx_new) for stacked input.
    """
    ny, nx = im.shape[-2:]
    ny_new, nx_new = new_size
    if new_cent is None:
        new_cent = (np.array([nx_new,ny_new])-1.)/2.

    xg, yg = np.meshgrid(np.arange(nx_new, dtype=np.float64), np.arange(ny_new, dtype=np.float64))
    xg -= (new_cent[0]-pos[0])
    yg -= (new_cent[1]-pos[1])
    coords = np.array([yg, xg])

    if np.ndim(im) == 2:
        return ndimage.map_coordinates(im, coords, order=order, mode=mode, cval=cval, prefilter=prefilter)

    lead_shape = im.shape[:-2]
    im_reshaped = im.reshape((-1, ny, nx))
    # Output planes must have the *new* image size, not the input size.
    im_out = np.zeros((im_reshaped.shape[0], int(ny_new), int(nx_new)), dtype=im.dtype)
    for i, plane in enumerate(im_reshaped):
        im_out[i] = ndimage.map_coordinates(plane, coords, order=order, mode=mode, cval=cval, prefilter=prefilter)
    return im_out.reshape((*lead_shape, int(ny_new), int(nx_new)))


# Functions to use during optimization


def c_to_c_osamp(c, osamp):
    """
    Maps detector-pixel coordinate(s) c onto a grid oversampled by osamp,
    i.e. c*osamp + (osamp-1)/2.
    """
    half_pixel_offset = 0.5*(osamp-1)
    return np.asarray(c)*osamp + half_pixel_offset


def model_rescale_factor(A, B, sig=None, mask=None):
    """
    Determines the value of scalar c such that:
        chi^2 = sum [ (A-c*B)^2 / sig^2 ]
    is minimized.

    Only entries where both A*B/sig^2 and B^2/sig^2 are non-NaN contribute. A NaN in
    A, B, or sig therefore removes that entry from both the numerator and the denominator.

    Parameters
    ----------
    A : numpy.ndarray
        Array of measurements
    B : numpy.ndarray
        Array of model values. Shape must match A.
    sig : numpy.ndarray, optional
        The 1 sigma uncertainty for the measurements of A. Shape must match A.
    mask : numpy.ndarray, optional
        A boolean mask with False for entries of A, B, and sig not to be
        utilized, and True for entries that are. Its shape must match the
        final axes of A. Defaults to None.

    Returns
    -------
    c : float
        The scaling factor to multiply the model (B) by to achieve the minimum chi^2
        for measurements (A) having the given uncertainties (sig).

    Raises
    ------
    ValueError
        If the shapes of A, B, sig, or mask are inconsistent.
    """
    if np.shape(A) != np.shape(B):
        raise ValueError("A and B must be arrays of the same shape!")
    if sig is not None:
        if np.shape(A) != np.shape(sig):
            raise ValueError("A, B, and sig must be arrays of the same shape if sig is specified!")
        sig = np.asarray(sig)
    else:
        sig = 1
    A = np.asarray(A)
    B = np.asarray(B)
    if mask is not None:
        if np.shape(mask)[-2:] != np.shape(A)[-2:]:
            raise ValueError("If provided, mask's shape must match the final axes of A, B, and sig!")
        mask = np.asarray(mask)
        A, B = A[..., mask], B[..., mask]
        if np.ndim(sig) != 0:
            sig = sig[..., mask]
    numer = np.asarray(A * B / (sig ** 2))
    denom = np.asarray((B ** 2) / (sig ** 2))
    # Sum over the same set of entries in numerator and denominator; independent nansums
    # would keep B^2 terms whose A is NaN and bias c.
    valid = ~(np.isnan(numer) | np.isnan(denom))
    return np.sum(numer[valid]) / np.sum(denom[valid])


def dist_to_pt(pt, nx=201, ny=201, dtype=np.float32):
    """
    Returns an (ny, nx) array whose value at [j, i] is the Euclidean distance
    of pixel (x=i, y=j) from pt=(x, y).
    """
    dx = np.arange(nx, dtype=dtype) - pt[0]
    dy = (np.arange(ny, dtype=dtype) - pt[1])[:, np.newaxis]
    return np.sqrt(dx**2 + dy**2)

In [None]:
# Build the source spectrum for the reference star that is passed to WebbPSF (choose however you prefer; likely minimal impact either way).
# Note: only the spectral shape matters, since the model brightness is rescaled to minimize residuals with the data anyway.
specfile = 'bt-nextgen_teff_4800_logg_4.0_feh_0.0_spec.dat'
swave, sflux = np.loadtxt(specfile, unpack=True)

flux_unit = u.Watt / u.m**2 / u.micron
spec1d = Spectrum1D(spectral_axis=swave*u.micron, flux=sflux*flux_unit)

spectrum = synphot.spectrum.SourceSpectrum.from_spectrum1d(spec1d)
spectrum.meta['name'] = 'source'

In [None]:
# Path to directory containing the fits files to use (edit this for your own data).
wdir = '/Users/kdlawso1/jwst/hd141569a/nircam/new_data_231108/stage2/'

# os.path.join works whether or not wdir ends with a path separator.
fitsfiles = np.sort(glob.glob(os.path.join(wdir, '*_calints.fits')))
if len(fitsfiles) == 0:
    raise FileNotFoundError('No *_calints.fits files found in {}'.format(wdir))

# Make a spaceklip database for these files
Database = database.Database(wdir)
Database.read_jwst_s012_data(datapaths=fitsfiles,
                             psflibpaths=None,
                             bgpaths=None)

print('\nConcatenations:\n')

for key in Database.obs:
    display(Database.obs[key])

print('\nKeys:\n')
print(list(Database.obs.keys()))

In [None]:
# Select which concatenation to work with (by position in the database) and load its reference-star exposures.
concats = list(Database.obs.keys())

concat_key = concats[0]
db_tab = Database.obs[concat_key]

expmask = db_tab['TYPE'] == 'REF'
expfiles = db_tab['FITSFILE'][expmask]

h0 = fits.getheader(expfiles[0])
h1 = fits.getheader(expfiles[0], ext=1)

# Median-combine the integrations of each exposure. The uncertainty of the median is taken as the
# uncertainty of the mean times sqrt(pi*(2n+1)/(4n)), with n the number of finite integrations per pixel.
im_list, err_list = [], []
for f in expfiles:
    imints = fits.getdata(f, ext=1)
    errints = fits.getdata(f, ext=2)
    n = np.sum(np.isfinite(imints), axis=0)
    sig_mean = np.sqrt(np.nansum(errints**2, axis=0))/n
    sig_med_over_sig_mean = np.sqrt(np.pi*(2*n+1)/(4*n))
    im_list.append(np.nanmedian(imints, axis=0))
    err_list.append(sig_mean * sig_med_over_sig_mean)

imcube = np.asarray(im_list)
errcube = np.asarray(err_list)

ref_tab = db_tab[expmask]
filt = ref_tab['FILTER'][0]
pxscale = ref_tab['PIXSCALE'][0]*u.arcsec/u.pixel

# lambda/D in pixels, used as an approximate FWHM (CWAVEL assumed to be in microns, matching d_eff's conversion).
lam = ref_tab['CWAVEL'][0]
d_eff = 5.2*u.meter
fwhm = ((np.rad2deg(lam/(d_eff.to(u.micron).value))*u.deg).to(u.arcsec)/pxscale).value # pixels

dithers = np.asarray([ref_tab['XOFFSET'], ref_tab['YOFFSET']]).T

ref_tab

In [None]:
# Initialize the WebbPSF instrument object and pre-compute the PSF center offset to save time during optimization
inst = webbpsf.setup_sim_to_match_file(expfiles[0])

inst.options['coron_shift_x'] = 0.
inst.options['coron_shift_y'] = 0.

# Don't add modeled IPC if the header reports the IPC step as COMPLETE (presumably already corrected in the data).
if 'S_IPC' in h0 and h0['S_IPC'] == 'COMPLETE':
    inst.options['add_ipc'] = False

# Compute the synthetic PSF center offset from an unocculted PSF and plot to verify
center_osamp = 4  # oversampling used only for this centering calculation
inst_off = deepcopy(inst)
inst_off.image_mask = None
psf_off = inst_off.calc_psf(source=None, oversample=center_osamp, fov_pixels=35)[2].data
psf_shift = get_webbpsf_model_center_offset(psf_off, osamp=center_osamp)

resource_dir = os.path.join(spaceKLIP.__path__[0], 'resources')
with open(os.path.join(resource_dir, 'crpix_jarron.json'), 'r') as file:
    crpix_jarron = json.load(file)

with open(os.path.join(resource_dir, 'filter_shifts_jarron.json'), 'r') as file:
    filt_offsets = json.load(file)

c_coron_ta = np.array(crpix_jarron[inst.aperturename])-1 # Coronagraph position for the TA filter (-1: presumably 1-based CRPIX -> 0-based)

c_coron = c_coron_ta + np.array(filt_offsets[filt]) # Coronagraph position adjusted for filter dependent offset

c = (np.array(psf_off.shape[::-1])-1.)/2.

fig, axes = quick_implot([psf_off, imageregistration.fourier_imshift(psf_off, psf_shift*center_osamp)], show=False, clim_perc=[1,99.995])

for ax, label in zip(axes, ['Not Centered', 'Centered']):
    ax.scatter(*c, marker='+', c='red', s=150)
    ax.set_title(label)

plt.show()

Next, we define our objective function, which:

- takes an lmfit `Parameters` object, along with the necessary positional and keyword arguments;
- generates the corresponding model; and
- returns the absolute uncertainty-weighted residuals between the model and the data (or the model itself if `return_model=True`).


In [None]:
def webbpsf_model_objective(p, imcube, dithers, inst, fov_pix, spectrum, c_coron, mask, sig=None, osamp=2, return_model=False, psf_shift=None):
    """
    lmfit objective: builds a WebbPSF model for every dither position and compares it with the data.

    Parameters
    ----------
    p : lmfit.Parameters
        Must contain 'xshear', 'yshear', 'defocus' and 'pupil_rotation'. These are mapped onto
        inst.options 'pupil_shift_x', 'pupil_shift_y', 'defocus_waves' and 'pupil_rotation'.
        It must also contain 'xsourceoffset' and 'ysourceoffset': the source offset (arcsec)
        from the coronagraph center at zero dither.
    imcube : numpy.ndarray, shape (nI, ny, nx)
        Data images, one per dither.
    dithers : array-like, shape (nI, 2)
        Per-image (x, y) dither offsets, added to the source offset (same units, arcsec).
    inst : WebbPSF instrument object
        Deep-copied before modification; the caller's object is not changed.
    fov_pix : int
        Field of view (pixels) for each WebbPSF calculation.
    spectrum :
        Source spectrum passed to WebbPSF.
    c_coron : array-like of 2 floats
        (x, y) pixel position of the coronagraph center in imcube.
    mask : boolean numpy.ndarray, shape (ny, nx)
        Pixels used for the flux rescaling and in the returned residuals.
    sig : numpy.ndarray, optional
        1-sigma uncertainties matching imcube; ones if not given.
    osamp : int
        Oversampling of the model before binning back to detector pixels.
    return_model : bool
        If True, return the rescaled model cube instead of residuals.
    psf_shift : array-like of 2 floats, optional
        Detector-pixel centering shift for the model PSF (default 0).

    Returns
    -------
    numpy.ndarray
        The model cube if return_model, otherwise |data - model| / sig at the mask
        pixels, with shape (nI, number of True pixels in mask).
    """
    global counter
    if psf_shift is None:
        psf_shift = 0.

    siaf_ap = inst.siaf[inst.aperturename]
    arcsec_per_pix = np.array([siaf_ap.XSciScale, siaf_ap.YSciScale])

    inst_opt = deepcopy(inst)
    inst_opt.options['pupil_shift_x'] = p['xshear'].value
    inst_opt.options['pupil_shift_y'] = p['yshear'].value
    inst_opt.options['defocus_waves'] = p['defocus'].value
    inst_opt.options['pupil_rotation'] = p['pupil_rotation'].value

    xy0 = np.array([p['xsourceoffset'].value, p['ysourceoffset'].value])

    nI, ny, nx = imcube.shape

    modelcube = np.zeros_like(imcube)
    for i in range(nI):
        xy = xy0 + dithers[i]
        # To minimize interpolations: center_psf=False, then apply the center shift when padding below.
        psf0 = get_webbpsf_model_for_stellocentric_offset(xy, inst_opt, spectrum=spectrum,
                                                          return_oversample=True,
                                                          fov_pixels=fov_pix, osamp=osamp,
                                                          center_psf=False)

        xy_px = c_coron + xy/arcsec_per_pix

        psf_osamp = pad_or_crop_image(psf0, new_size=np.array([ny,nx])*osamp, new_cent=c_to_c_osamp(xy_px+psf_shift, osamp), cval0=0.,
                                      nan_prop_threshold=1e-8, zero_prop_threshold=1e-8)

        modelcube[i] = image_manip.frebin(psf_osamp, scale=1./osamp)

    # A single flux scale for all dithers, fit over the mask.
    modelcube *= model_rescale_factor(imcube, modelcube, mask=mask, sig=sig)

    if return_model:
        return modelcube

    # Progress counter for interactive use; start at 1 if the global was never initialized.
    try:
        counter += 1
    except NameError:
        counter = 1
    print('Models evaluated: {0: <16}'.format(counter), end='\r')

    if sig is None:
        sig = np.ones_like(imcube)

    res = ((imcube - np.nan_to_num(modelcube))/sig)[..., mask]

    # Absolute values: lmfit's scalar minimizers (e.g. 'powell') reduce this array by sum of squares.
    return np.abs(res)

In [None]:
# First: estimate the stellar offset by registering an initial model against each dither image

# Initialize an lmfit Parameters object for optimization.
p = lmfit.Parameters()

# Arguments for each param added are (respectively):
# the parameter name, the initial value, min allowed value, max allowed value, and whether or not it will be varied.
p.add('defocus', value=0.0, min=-0.4, max=0.4, vary=True) # setting for inst.options['defocus_waves']
p.add('xshear', value=0.0, min=-0.05, max=0.05, vary=True) # setting for inst.options['pupil_shift_x']
p.add('yshear', value=0.0, min=-0.05, max=0.05, vary=True) # setting for inst.options['pupil_shift_y']
p.add('pupil_rotation', value=0., min=-5., max=5., vary=True) # setting for inst.options['pupil_rotation']

# Offset of the source center from the coronagraph center ***in arcsec*** for the center of dither pattern
# In other words: the offset of the source relative to the coronagraph center if dither == [0,0]
p.add('xsourceoffset', value=0.)
p.add('ysourceoffset', value=0.)
# Note: we assume that the dither movement is perfect. I.e., we just fit for one coronagraph offset rather than one per dither position

rmap = dist_to_pt(c_coron, imcube.shape[2], imcube.shape[1])
rmax = 40*fwhm # Max distance from the coronagraph center to consider in our goodness of fit calculation (should only matter if other sources are present in the data)
# Generate a boolean mask that is True wherever we want to consider in our goodness of fit for each iteration.
mask = (rmap < rmax) & np.all(np.isfinite(imcube), axis=0)

# WebbPSF's fov_pixels is an integer pixel count; add a few pixels beyond the mask region to be safe with interpolation.
fov_pix = int(np.ceil(rmax*2+5))

sig = errcube

modelcube0 = webbpsf_model_objective(p, imcube, dithers, inst, fov_pix, spectrum, c_coron, mask, sig=sig, return_model=True, psf_shift=psf_shift)

# The negated registration shift (pixels) of each model onto its data image, converted to arcsec.
xysourceoffsets = []
for i in range(len(dithers)):
    regres = imageregistration.align_array(np.nan_to_num(imcube[i]), modelcube0[[i]], mask=mask)
    xysourceoffsets.append(-(regres[1][0, :-1]*u.pix * pxscale).value)

xysourceoffsets = np.array(xysourceoffsets)
xysourceoffset = np.nanmedian(xysourceoffsets, axis=0)

p.add('xsourceoffset', value=xysourceoffset[0], vary=False)
p.add('ysourceoffset', value=xysourceoffset[1], vary=False)

p

In [None]:
# Now: optimize using only the PSF wings (between 20 FWHM and rmax from the coronagraph center), varying only pupil_rotation.

mask = (rmap > 20*fwhm) & (rmap < rmax) & np.all(np.isfinite(imcube), axis=0)

# 'counter' only drives the progress print-out, to reassure you that the code isn't hung up.
counter = 0

for name in ('defocus', 'xshear', 'yshear'):
    p[name].vary = False

quick_implot(np.where(mask, imcube[0], np.nan))

# Powell optimization; see the lmfit documentation for other methods.
res = lmfit.minimize(webbpsf_model_objective, p, method='powell',
                     args=(imcube, dithers, inst, fov_pix, spectrum, c_coron, mask),
                     kws={'sig': sig, 'osamp': 2, 'psf_shift': psf_shift})

res

In [None]:
# Adopt the best-fit pupil rotation, then optimize defocus and shear over the full mask
# (the source offsets could also be freed here to be safe).
p.add('pupil_rotation', value=res.params['pupil_rotation'].value, vary=False)

for name in ('defocus', 'xshear', 'yshear'):
    p[name].vary = True

mask = (rmap < rmax) & np.all(np.isfinite(imcube), axis=0)

# 'counter' only drives the progress print-out, to reassure you that the code isn't hung up.
counter = 0

quick_implot(np.where(mask, imcube[0], np.nan))

fit_kws = dict(sig=sig, osamp=2, psf_shift=psf_shift)

# Powell optimization; see the lmfit documentation for other methods.
res = lmfit.minimize(webbpsf_model_objective, p, method='powell',
                     args=(imcube, dithers, inst, fov_pix, spectrum, c_coron, mask),
                     kws=fit_kws)

# Re-generate the final best-fit model (with a 351-pixel FOV) for inspection
modelcube = webbpsf_model_objective(res.params, imcube, dithers, inst, 351, spectrum, c_coron, mask, return_model=True, **fit_kws)

res

In [None]:
# Plot data, model, and residuals (each divided by the uncertainty) for every dither: one row per dither.

w = rmax

snr_stack = np.array([imcube/errcube, modelcube/errcube, (imcube-modelcube)/errcube])
pim = np.where(mask, snr_stack.transpose((1,0,2,3)), np.nan)

fig, axes = quick_implot(pim, lims=c_coron[[0]]+[-w,w], ylims=c_coron[[1]]+[-w,w],
                         cmap='coolwarm', clim_perc=99.9, show_ticks=True, show=False)

for ax, title in zip(axes[:3], ['Data', 'Model', 'Residuals']):
    ax.set_title(title)

for ax in axes:
    ax.grid(c='k', alpha=0.25)

plt.show()

In [None]:
# Save the best-fit WebbPSF parameters (excluding the source offsets) for this configuration.
offset_names = ('xsourceoffset', 'ysourceoffset')
pnames = [key for key in res.params.keys() if key not in offset_names]
pvals = np.array([res.params[key].value for key in pnames])

fname = 'JWST_{}_{}_{}_{}_{}_{}_WebbPSF_params.txt'.format(h0['INSTRUME'], h0['DETECTOR'], h0['filter'], h0['pupil'], h0['coronmsk'], h0['SUBARRAY'])
# os.path.join inserts a separator if output_dir lacks a trailing one.
f_out = os.path.join(Database.output_dir, fname)
np.savetxt(f_out, pvals[np.newaxis], header=' '.join(pnames), fmt='%.5f')

print(f'Output written to:\n{f_out}')

# Now load the F360M data and fit the defocus, keeping all other parameters the same

In [None]:
# Select the second concatenation and load its reference-star exposures (same steps as for the first).
concats = list(Database.obs.keys())

concat_key = concats[1]
db_tab = Database.obs[concat_key]

expmask = db_tab['TYPE'] == 'REF'
expfiles = db_tab['FITSFILE'][expmask]

h0 = fits.getheader(expfiles[0])
h1 = fits.getheader(expfiles[0], ext=1)

# Median-combine the integrations of each exposure. The uncertainty of the median is taken as the
# uncertainty of the mean times sqrt(pi*(2n+1)/(4n)), with n the number of finite integrations per pixel.
im_list, err_list = [], []
for f in expfiles:
    imints = fits.getdata(f, ext=1)
    errints = fits.getdata(f, ext=2)
    n = np.sum(np.isfinite(imints), axis=0)
    sig_mean = np.sqrt(np.nansum(errints**2, axis=0))/n
    sig_med_over_sig_mean = np.sqrt(np.pi*(2*n+1)/(4*n))
    im_list.append(np.nanmedian(imints, axis=0))
    err_list.append(sig_mean * sig_med_over_sig_mean)

imcube = np.asarray(im_list)
errcube = np.asarray(err_list)

ref_tab = db_tab[expmask]
filt = ref_tab['FILTER'][0]
pxscale = ref_tab['PIXSCALE'][0]*u.arcsec/u.pixel

# lambda/D in pixels, used as an approximate FWHM (CWAVEL assumed to be in microns, matching d_eff's conversion).
lam = ref_tab['CWAVEL'][0]
d_eff = 5.2*u.meter
fwhm = ((np.rad2deg(lam/(d_eff.to(u.micron).value))*u.deg).to(u.arcsec)/pxscale).value # pixels

dithers = np.asarray([ref_tab['XOFFSET'], ref_tab['YOFFSET']]).T

ref_tab

In [None]:
# Initialize the WebbPSF instrument object and pre-compute the PSF center offset to save time during optimization
inst = webbpsf.setup_sim_to_match_file(expfiles[0])

inst.options['coron_shift_x'] = 0.
inst.options['coron_shift_y'] = 0.

# Don't add modeled IPC if the header reports the IPC step as COMPLETE (presumably already corrected in the data).
if 'S_IPC' in h0 and h0['S_IPC'] == 'COMPLETE':
    inst.options['add_ipc'] = False

# Compute the PSF center offset from an unocculted PSF and plot for sanity
center_osamp = 4  # oversampling used only for this centering calculation
inst_off = deepcopy(inst)
inst_off.image_mask = None
psf_off = inst_off.calc_psf(source=None, oversample=center_osamp, fov_pixels=35)[2].data
psf_shift = get_webbpsf_model_center_offset(psf_off, osamp=center_osamp)

resource_dir = os.path.join(spaceKLIP.__path__[0], 'resources')
with open(os.path.join(resource_dir, 'crpix_jarron.json'), 'r') as file:
    crpix_jarron = json.load(file)

with open(os.path.join(resource_dir, 'filter_shifts_jarron.json'), 'r') as file:
    filt_offsets = json.load(file)

c_coron_ta = np.array(crpix_jarron[inst.aperturename])-1 # Coronagraph position for the TA filter (-1: presumably 1-based CRPIX -> 0-based)

c_coron = c_coron_ta + np.array(filt_offsets[filt]) # Coronagraph position adjusted for filter dependent offset

c = (np.array(psf_off.shape[::-1])-1.)/2.

fig, axes = quick_implot([psf_off, imageregistration.fourier_imshift(psf_off, psf_shift*center_osamp)], show=False, clim_perc=[1,99.995])

for ax, label in zip(axes, ['Not Centered', 'Centered']):
    ax.scatter(*c, marker='+', c='red', s=150)
    ax.set_title(label)

plt.show()

In [None]:
# Fitting defocus only (should be much faster), keeping every other parameter fixed at its previous best-fit value
p = deepcopy(res.params)
for key in p.keys():
    p[key].vary = False

p['defocus'].value = 0.
p['defocus'].vary = True

rmap = dist_to_pt(c_coron, imcube.shape[2], imcube.shape[1])
rmax = 40*fwhm # Max distance from the coronagraph center to consider in our goodness of fit calculation (should only matter if other sources are present in the data)

# Generate a boolean mask that is True wherever we want to consider in our goodness of fit for each iteration.
mask = (rmap < rmax) & np.all(np.isfinite(imcube), axis=0)

# WebbPSF's fov_pixels is an integer pixel count; add a few pixels beyond the mask region to be safe with interpolation.
fov_pix = int(np.ceil(rmax*2+5))

sig = errcube

# 'counter' is just used to make print statements to reassure you that the code isn't hung up.
counter = 0

quick_implot(np.where(mask, imcube[0], np.nan))

# Run the optimization procedure using the Powell method. See lmfit documentation for other options
res2 = lmfit.minimize(webbpsf_model_objective, p, method='powell',
                      args=(imcube, dithers, inst, fov_pix, spectrum, c_coron, mask),
                      kws=dict(sig=sig, osamp=2, psf_shift=psf_shift))

# Re-generate the best-fit model (with a 351-pixel FOV) for inspection
modelcube = webbpsf_model_objective(res2.params, imcube, dithers, inst, 351, spectrum, c_coron, mask, sig=sig, osamp=2, return_model=True, psf_shift=psf_shift)

res2

In [None]:
# Plot data, model, and residuals (each divided by the uncertainty) for every dither: one row per dither.
w = rmax

snr_stack = np.array([imcube/errcube, modelcube/errcube, (imcube-modelcube)/errcube])
pim = np.where(mask, snr_stack.transpose((1,0,2,3)), np.nan)

fig, axes = quick_implot(pim, lims=c_coron[[0]]+[-w,w], ylims=c_coron[[1]]+[-w,w],
                         cmap='coolwarm', clim_perc=99.9, show_ticks=True, show=False)

for ax, title in zip(axes[:3], ['Data', 'Model', 'Residuals']):
    ax.set_title(title)

for ax in axes:
    ax.grid(c='k', alpha=0.25)

plt.show()

In [None]:
# Save the best-fit WebbPSF parameters (excluding the source offsets) for this configuration.
offset_names = ('xsourceoffset', 'ysourceoffset')
pnames = [key for key in res2.params.keys() if key not in offset_names]
pvals = np.array([res2.params[key].value for key in pnames])

fname = 'JWST_{}_{}_{}_{}_{}_{}_WebbPSF_params.txt'.format(h0['INSTRUME'], h0['DETECTOR'], h0['filter'], h0['pupil'], h0['coronmsk'], h0['SUBARRAY'])
# os.path.join inserts a separator if output_dir lacks a trailing one.
f_out = os.path.join(Database.output_dir, fname)

np.savetxt(f_out, pvals[np.newaxis], header=' '.join(pnames), fmt='%.5f')

print(f'Output written to:\n{f_out}')