# Jupyter Notebook for Illumination Flat Field
Authors: Tony Sohn, Jo Taylor

This notebook performs analysis for NIS-011a and creates a delta flat file. 

The input is either uncal files, which are then calibrated, or cal files which are directly used for analysis. Sources are identified in the cal files, then photometry is performed on each source in each dither pointing. Sources between each pointing are matched and the differences in their flux measurements are used to derive the delta flat. 2D fits are performed on flux ratios and the evaluated 2D fit is output as the delta flat file.

# Imports 

In [None]:
# Import all needed packages
import os
import copy
import glob
import collections
import numpy as np

import jwst
from jwst import datamodels
from jwst.pipeline import Detector1Pipeline
from jwst.pipeline import Image2Pipeline
from jwst.pipeline import Image3Pipeline
import asdf

import astropy
from astropy import units as u
from astropy.nddata import Cutout2D, NDData
from astropy.stats import gaussian_sigma_to_fwhm, sigma_clipped_stats
from astropy.table import Table, hstack
from astropy.modeling import models, fitting
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.wcs.utils import pixel_to_skycoord
from astropy.coordinates import SkyCoord, match_coordinates_sky
from astropy.visualization import simple_norm
from astropy.io import fits
from astropy.wcs import WCS
from astropy.io import ascii

from scipy.stats import norm
import scipy.optimize as opt
from sklearn.metrics import mean_squared_error
from astropy.stats import sigma_clip

import photutils
from photutils import EPSFBuilder, find_peaks
from photutils.aperture import aperture_photometry, CircularAperture, CircularAnnulus, RectangularAperture, RectangularAnnulus
from photutils.detection import DAOStarFinder, IRAFStarFinder
from photutils.psf import DAOGroup, IntegratedGaussianPRF, extract_stars, IterativelySubtractedPSFPhotometry
from photutils.background import MMMBackground, SExtractorBackground, ModeEstimatorBackground, MedianBackground
from photutils.background import MADStdBackgroundRMS
from photutils.centroids import centroid_2dg
from photutils.utils import calc_total_error

%matplotlib inline
from matplotlib import style, pyplot as plt
import matplotlib.patches as patches
import matplotlib.ticker as ticker
from matplotlib.gridspec import GridSpec
# Notebook-wide matplotlib styling: large fonts and thick lines for readability.
params = {
    'legend.fontsize': '18',
    'axes.labelsize': '18',
    'axes.titlesize': '18',
    'xtick.labelsize': '18',
    'ytick.labelsize': '18',
    'lines.linewidth': 2,
    'axes.linewidth': 2,
    'animation.html': 'html5',
    'figure.figsize': (15, 15),
}
plt.rcParams.update(params)

# Colorblind-safe palette; plotting cells below index into it.
colors = [
    "#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e", "#a6cee3",
    "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c", "#fdbf6f",
    "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928",
]

In [None]:
# Record the versions of the key packages so results can be reproduced.
for _pkg_name, _pkg in (("jwst", jwst), ("astropy", astropy), ("numpy", np), ("photutils", photutils)):
    print(f'{_pkg_name}: {_pkg.__version__}')

# Read input files and calibrate if needed

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `testing`: If True, only use 3 pointings of the F150W filter. If False, full suite of files will be used.
* `start_uncal`: True if starting from uncal files and you want to calibrate them within Jupyter (this will take a LONG time!), False if starting from cal files. 
* `data_dir`: Directory that holds either uncal or cal files.
* `outdir`: Directory to write output files to.
</div>

In [None]:
# True: use only the 3-exposure F150W test subset; False: use every file in data_dir.
testing = False

# True: calibrate *_uncal.fits files here (slow); False: read existing *_cal.fits files.
start_uncal = False

# Input data location and the directory that receives every output product.
data_dir = "/ifs/jwst/wit/niriss/cap_simulations/nis011a/out_24May21/"
outdir = os.path.join(os.getcwd(), "out")

In [None]:
# Create the output directory, including any missing parent directories.
# exist_ok=True makes re-running this cell harmless.
os.makedirs(outdir, exist_ok=True)

# Maps each "<FILTER>_<PUPIL>" key to the path of its reference (zero dither
# offset) exposure; filled in by the data-loading cells.
ref_ims = {}

### Starting from uncal (can still be run if starting from cal)
This may take a while depending on the number of files.

In [None]:
if start_uncal is True:
    # Get all uncal files; when testing, use only 3 of them.
    if testing is True:
        uncals = [os.path.join(data_dir, x) for x in ["jw01086001001_01101_00021_nis_uncal.fits", "jw01086001001_01101_00045_nis_uncal.fits", "jw01086001001_01101_00069_nis_uncal.fits"]]
    else:
        uncals = glob.glob(os.path.join(data_dir, "*_uncal.fits"))

    # Export the default Image2 parameters, then write a custom Image2 parameter
    # file with the photom, resample and background steps skipped.
    step = Image2Pipeline()
    step.export_config('calwebb_image2.asdf')
    new_config = os.path.join(outdir, "calwebb_image2_custom.asdf")
    with asdf.open("calwebb_image2.asdf") as af:
        for step_pars in af.tree["steps"]:
            if step_pars["name"] in ["photom", "resample", "background"]:
                step_pars["parameters"]["skip"] = True
        af.write_to(new_config)

    # Put all files in the data dictionary: data_d[<FILTER>_<PUPIL>][cal file] -> info
    data_d = collections.defaultdict(dict)
    for item in uncals:
        det1_out = Detector1Pipeline.call(item, save_results=True, output_dir=outdir)
        im2_result = Image2Pipeline.call(det1_out, save_results=True, config_file=new_config, output_dir=outdir)
        # Image2Pipeline can return a sequence of products (one per input
        # exposure); here there is a single input exposure, so take the first.
        if isinstance(im2_result, (list, tuple, datamodels.ModelContainer)):
            im2_out = im2_result[0]
        else:
            im2_out = im2_result
        im2_file = os.path.join(outdir, im2_out.meta.filename)
        filt = fits.getval(item, "filter")
        pupil = fits.getval(item, "pupil")
        filt = f'{filt}_{pupil}'
        photmjsr = im2_out.meta.photometry.conversion_megajanskys
        if photmjsr is None:
            # photom was skipped above, so no DN/s -> MJy/sr factor is recorded.
            # NOTE(review): assumes the data are then still in DN/s -- confirm.
            data_cps = im2_out.data
        else:
            data_cps = im2_out.data / photmjsr
        # The exposure with zero dither offset serves as the filter's reference.
        xoffset = im2_out.meta.dither.x_offset
        yoffset = im2_out.meta.dither.y_offset
        if xoffset == yoffset == 0.0:
            ref_ims[filt] = im2_file
        data_d[filt][im2_file] = {"calfile": im2_file, "cal_datamodel": im2_out, "data_cps": data_cps}

### Starting from cal files (can still be run if starting from uncal)
This may take a while depending on the number of files.

In [None]:
if start_uncal is False:
    # Collect the calibrated exposures; the testing subset is three F150W files.
    if testing is True:
        cal_names = ["jw01086001001_01101_00021_nis_cal.fits",
                     "jw01086001001_01101_00045_nis_cal.fits",
                     "jw01086001001_01101_00069_nis_cal.fits"]
        cals = [os.path.join(data_dir, name) for name in cal_names]
    else:
        cals = glob.glob(os.path.join(data_dir, "*_cal.fits"))

    # data_d[<FILTER>_<PUPIL>][cal file path] -> per-exposure info dictionary
    data_d = collections.defaultdict(dict)
    for im2_file in cals:
        filt = "{}_{}".format(fits.getval(im2_file, "filter"), fits.getval(im2_file, "pupil"))
        im2_out = datamodels.open(im2_file)
        # Undo the MJy/sr calibration to get back to count rates.
        data_cps = im2_out.data / im2_out.meta.photometry.conversion_megajanskys
        dither = im2_out.meta.dither
        # The exposure with zero dither offset serves as the filter's reference.
        if dither.x_offset == 0.0 and dither.y_offset == 0.0:
            ref_ims[filt] = im2_file
        data_d[filt][im2_file] = {"calfile": im2_file, "cal_datamodel": im2_out, "data_cps": data_cps}

### The format of the data dictionary

The data dictionary contains all necessary info for each filter and file in a nested format. Each key of `data_d` is a filter and pupil combination of the form `<FILTER>_<PUPIL>`, and its value is itself a dictionary. Each key of that filter dictionary is a file, which maps to a dictionary storing all relevant info for that file, such as the cal file's datamodel (`cal_datamodel`), the photometry for each aperture size (e.g. `phot_ap3`) and the source catalog (`sources`). So if you wanted to see the source catalog for the F150W file `file1.fits`, you would do so like this:
```
data_d['CLEAR_F150W']['file1.fits']['sources']
```

In [None]:
# Show the reference exposure of each filter with a sqrt stretch (percent=99)
# as a quick check that the data loaded correctly.
for filt in data_d:
    fname = ref_ims[filt]
    data_cps = data_d[filt][fname]["data_cps"]
    plt.xlabel("X [pix]")
    plt.ylabel("Y [pix]")
    plt.imshow(data_cps, norm=simple_norm(data_cps, 'sqrt', percent=99.), cmap="Greys", origin='lower')
    plt.title(f"{filt.upper()} reference, {os.path.basename(fname)}")
    plt.show()

# Identify Sources
This can take some time depending on number of files, up to 10s per file.

In [None]:
# Estimate the background level and noise of each image with photutils
# estimators, then detect point sources with IRAFStarFinder.
# The parameters below work generally well for NIRISS images.
bkgrms = MADStdBackgroundRMS()
mmm_bkg = MMMBackground()

# The star finder is built per image so its parameters can be tweaked for
# certain exposures; right now every image uses the same parameters.
for filt, filt_d in data_d.items():
    print(filt)
    for fname, file_d in filt_d.items():
        data_cps = file_d["data_cps"]
        std = bkgrms(data_cps)
        bkg = mmm_bkg(data_cps)
        # Detection threshold: background plus 100 times the robust noise estimate.
        starfinder = IRAFStarFinder(threshold=100*std + bkg, fwhm=2.0, minsep_fwhm=7,
                                    roundlo=0.0, roundhi=0.6, sharplo=0.6, sharphi=1.4)
        sources = starfinder(data_cps)
        file_d["sources"] = sources
        print(fname)
        # Both statistics are printed with 4 decimal places.
        print(f"\tBackground: {bkg:.4f}, StdDev: {std:.4f}, Sources: {len(sources)}")

In [None]:
# Inspect plots like below to make sure we're picking up most of the stars
# while avoiding junk. Only the reference image of each filter is plotted.
for filt, fname in ref_ims.items():
    sources = data_d[filt][fname]["sources"]
    title = f"{filt.upper()} reference, {os.path.basename(fname)}"

    # Sharpness and roundness versus instrumental magnitude.
    fig, axes = plt.subplots(1, 2, figsize=(20, 7))
    axes[0].plot(sources['mag'], sources['sharpness'], 'o', color=colors[0], markersize=3)
    axes[0].set_xlabel('Magnitude')
    axes[0].set_ylabel('Sharpness')
    axes[0].set_title(title)

    axes[1].plot(sources['mag'], sources['roundness'], 'o', color=colors[1], markersize=3)
    axes[1].set_xlabel('Magnitude')
    axes[1].set_ylabel('Roundness')
    axes[1].set_title(title)

    fig.tight_layout()
    plt.show()

    # FWHM versus flux: flux is plotted on x and FWHM on y, labelled to match.
    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    ax.plot(sources['flux'], sources['fwhm'], 'o', color=colors[2], markersize=3)
    ax.set_xlabel("Flux")
    ax.set_ylabel("FWHM")
    ax.set_title(title)
    ax.set_xlim(-100, 1000)
    plt.show()

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `all_fwhm_lims`: The minimum and maximum limits for the FWHM for each filter, use diagnostic plots above.
* `all_flux_lims`: The minimum and maximum limits for the flux for each filter, use diagnostic plots above.
</div>

In [None]:
# Source-selection limits per <FILTER>_<PUPIL> key, as
#   ([flux_min, flux_max], [fwhm_min, fwhm_max])
# Only the lower flux limit is applied for now, so usually only it needs editing.
_selection_limits = {
    'F430M_CLEARP': ([20, 1000], [1.6, 1.9]),
    'CLEAR_F090W': ([50, 1000], [1.1, 1.6]),
    'CLEAR_F200W': ([50, 1000], [1.1, 1.6]),
    'F356W_CLEARP': ([50, 1000], [1.3, 1.7]),
    'F444W_CLEARP': ([50, 1000], [1.5, 2.0]),
    'F277W_CLEARP': ([50, 1000], [1.1, 1.6]),
    'CLEAR_F140M': ([50, 1000], [1.1, 1.6]),
    'CLEAR_F115W': ([50, 1000], [1.1, 1.6]),
    'F380M_CLEARP': ([20, 1000], [1.4, 1.8]),
    'CLEAR_F158M': ([50, 1000], [1.1, 1.8]),
    'CLEAR_F150W': ([50, 1000], [1.1, 1.6]),
    'F480M_CLEARP': ([20, 1000], [1.6, 2.0]),
}

all_flux_lims = {key: list(flux) for key, (flux, _) in _selection_limits.items()}
all_fwhm_lims = {key: list(fwhm) for key, (_, fwhm) in _selection_limits.items()}

In [None]:
# Overlay the selection box (lower flux cut, FWHM band) on the flux-FWHM
# diagram of each reference image to check the chosen limits.
for filt, fname in ref_ims.items():
    sources = data_d[filt][fname]["sources"]
    flux_lo = all_flux_lims[filt][0]
    fwhm_lo, fwhm_hi = all_fwhm_lims[filt][0], all_fwhm_lims[filt][1]
    box_color = colors[1]

    plt.figure(figsize=(15, 7))
    plt.plot(sources['flux'], sources['fwhm'], 'o', color=colors[2], markersize=3)
    plt.xlim(0, 1000)
    # Left edge at the flux cut, then the upper and lower FWHM edges.
    plt.plot([flux_lo, flux_lo], [fwhm_lo, fwhm_hi], color=box_color)
    for fwhm_edge in (fwhm_hi, fwhm_lo):
        plt.plot([flux_lo, 1000], [fwhm_edge, fwhm_edge], color=box_color)
    plt.title(f"{filt.upper()} reference, {os.path.basename(fname)}")
    plt.xlabel("Flux")
    plt.ylabel("FWHM")
    plt.show()

In [None]:
# Apply the FWHM and flux limits chosen from the plots above to every image,
# and overplot the surviving sources on each reference image to make sure
# we're dealing with actual stars.
for filt, filt_d in data_d.items():
    flux_lims = all_flux_lims[filt]
    fwhm_lims = all_fwhm_lims[filt]
    for fname, file_d in filt_d.items():
        sources = file_d["sources"]
        # Only the lower flux limit is applied; the upper one is not used.
        mask = ((sources['fwhm'] > fwhm_lims[0]) & (sources['fwhm'] < fwhm_lims[1])
                & (sources['flux'] > flux_lims[0]))
        sources_masked = sources[mask]
        file_d["sources_masked"] = sources_masked

        # Only the reference image is displayed, so only build the overlay
        # (apertures and image stretch) for it.
        if fname != ref_ims[filt]:
            continue
        positions = np.transpose((sources_masked['xcentroid'], sources_masked['ycentroid']))
        apertures = CircularAperture(positions, r=10)
        normlzd = simple_norm(file_d["data_cps"], 'sqrt', percent=99.)
        plt.imshow(file_d["data_cps"], norm=normlzd, cmap='Greys', origin='lower')
        apertures.plot(color='blue', lw=1.5, alpha=0.5)
        plt.xlabel("X [pix]")
        plt.ylabel("Y [pix]")
        plt.title(f"{filt.upper()} reference: {os.path.basename(fname)}")
        plt.show()

# Perform aperture photometry
This can take some time, depending on number of aperture sizes and files.

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `aperture_sizes`: The list of aperture sizes to try, e.g. `[3, 5, 7, 10]`
* `annulus_offset`: The offset from the aperture size to use for the inner and outer radius of the sky annulus, e.g. `[+3, +7]`
    
Not yet implemented: use isolated bright, but non-saturated, stars to determine the best aperture size.
</div>

In [None]:
# Aperture radii to measure; each produces a phot_ap<radius> table per file.
aperture_sizes = [3]
# Sky annulus [inner, outer] radii, given as offsets from the aperture radius.
annulus_offset = [3, 7]

In [None]:
# Perform background-subtracted aperture photometry on the selected stars of
# every image, once per aperture radius; results go to file_d["phot_ap<r>"].
for filt, filt_d in data_d.items():
    for fname, file_d in filt_d.items():
        data_cps = file_d["data_cps"]
        sources_masked = file_d["sources_masked"]
        positions = np.transpose((sources_masked['xcentroid'], sources_masked['ycentroid']))
        for ap_size in aperture_sizes:
            apertures = CircularAperture(positions, r=ap_size)
            annulus_apertures = CircularAnnulus(positions,
                                                r_in=ap_size + annulus_offset[0],
                                                r_out=ap_size + annulus_offset[1])

            # Local sky level: sigma-clipped median inside each star's annulus.
            bkg_median = []
            for mask in annulus_apertures.to_mask(method='center'):
                # Fill pixels that fall off the image with NaN (not the default
                # 0) so sigma_clipped_stats ignores them instead of dragging the
                # sky estimate of stars near the detector edge toward zero.
                annulus_data = mask.multiply(data_cps, fill_value=np.nan)
                if annulus_data is None:
                    # The annulus does not overlap the image at all.
                    bkg_median.append(np.nan)
                    continue
                annulus_data_1d = annulus_data[mask.data > 0]
                _, median_sigclip, _ = sigma_clipped_stats(annulus_data_1d)
                bkg_median.append(median_sigclip)
            bkg_median = np.array(bkg_median)

            phot = aperture_photometry(data_cps, apertures)
            phot['annulus_median'] = bkg_median
            # Sky contribution inside the aperture = per-pixel sky * aperture area.
            phot['aper_bkg'] = bkg_median * apertures.area
            phot['aper_sum_bkgsub'] = phot['aperture_sum'] - phot['aper_bkg']
            file_d[f"phot_ap{ap_size}"] = phot

In [None]:
# From Kevin Volk's email sent to Tony on Mar 26, 2021
# Assumed A0V spectral template: Bohlin Sirius (2020)
#
# Whitespace-separated table of per-filter count rates and photometric
# conversion factors. Its first line is a '#' comment carrying the units;
# the next line supplies the column names. The magnitude cell below uses only
# the 'filter' and 'count_rate' columns, to build zero points.
# NOTE(review): presumably count_rate is the rate expected for the A0V
# reference source -- confirm against the original email.
niriss_info = """
#                        ADU/sec           Jy          Microns    (MJy/ster)/(ADU/s)      Jy/(ADU/s)
ins    filter         count_rate     mean_f_nu        pivot_wl      phot_mjsr          phot_jy
NIRISS F090W          2.804222e+11   8.371777e+03     0.902458      1.673899e-10       2.985419e-08
NIRISS F115W          2.347116e+11   6.508594e+03     1.149542      1.554807e-10       2.773018e-08
NIRISS F140M          7.340864e+10   4.757672e+03     1.404035      3.633885e-10       6.481080e-08
NIRISS F150W          1.391529e+11   4.346497e+03     1.493457      1.751341e-10       3.123539e-08
NIRISS F158M          6.487924e+10   3.939552e+03     1.582011      3.404590e-10       6.072129e-08
NIRISS F200W          9.570604e+10   2.755850e+03     1.992961      1.614508e-10       2.879494e-08
NIRISS F277W          5.573678e+10   1.575708e+03     2.764281      1.585104e-10       2.827052e-08
NIRISS F356W          3.667514e+10   9.791757e+02     3.593004      1.496969e-10       2.669862e-08
NIRISS F380M          6.607656e+09   8.561057e+02     3.825227      7.264468e-10       1.295627e-07
NIRISS F430M          4.290003e+09   6.923615e+02     4.283827      9.048971e-10       1.613895e-07
NIRISS F444W          2.242869e+10   6.615412e+02     4.427699      1.653776e-10       2.949531e-08
NIRISS F480M          3.839979e+09   5.582376e+02     4.815243      8.151061e-10       1.453752e-07
Guider 1              1.379334e+12   4.478894e+03     2.501078      1.684581e-11       3.247142e-09
Guider 2              1.604525e+12   4.291017e+03     2.591652      1.386584e-11       2.674322e-09
"""

# Parse the table into an astropy Table; data_start=1 makes the data begin on
# the line after the column-name line.
t = ascii.read(niriss_info, data_start=1)

In [None]:
# Calculate star magnitudes and sky coordinates for each aperture size, save
# each photometry table, and plot the sky positions of every file per filter.
for filt, filt_d in data_d.items():
    filt0, pup0 = filt.split("_")
    # Whichever of filter/pupil is not a CLEAR element names the bandpass.
    filtkey = pup0 if filt0 in ["CLEAR", "CLEARP"] else filt0
    # Magnitude zero point from the tabulated count rate (table above).
    count_rate = t[t['filter'] == filtkey.upper()]['count_rate'][0]
    zpt = 2.5 * np.log10(count_rate)
    for i, (fname, file_d) in enumerate(filt_d.items()):
        # The WCS does not depend on the aperture size, so fetch it once per file.
        detector_to_world = file_d["cal_datamodel"].meta.wcs.get_transform('detector', 'world')
        for ap_size in aperture_sizes:
            phot = file_d[f"phot_ap{ap_size}"]
            flux = phot['aper_sum_bkgsub']
            # NOTE(review): the origin of the 1.401 mag offset is undocumented -- confirm.
            phot['mag'] = -2.5 * np.log10(flux) + zpt - 1.401
            # 1.0857 = 2.5 / ln(10) converts a fractional flux error into
            # magnitudes; the fractional error here is taken as 1/sqrt(flux).
            phot['mag_err'] = 1.0857 * np.sqrt(flux) / flux

            # Convert the detected positions into sky coordinates (RA, Dec) in degrees
            ra, dec = detector_to_world(phot["xcenter"].value, phot["ycenter"].value)
            phot["ra_deg"] = ra
            phot["dec_deg"] = dec
            phot['coords'] = SkyCoord(ra=ra*u.degree, dec=dec*u.degree, frame='icrs')

            # Save the photometry table so the steps above need not be repeated
            # if something goes wrong later in this session.
            photfile0 = os.path.basename(fname).replace(".fits", f"_photutils_ap{ap_size}.fits")
            photfile = os.path.join(outdir, photfile0)
            phot.write(photfile, format="fits", overwrite=True)

        # Sky positions of this file's sources (from the last aperture size).
        # Cycle the palette so filters with more files than colors do not fail.
        plt.plot(ra, dec, '<', color=colors[i % len(colors)], alpha=0.5,
                 label=os.path.basename(fname))
    plt.title(filt.upper())
    plt.legend(bbox_to_anchor=(-.15, 1.35), loc="upper left")
    plt.show()

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `ap_use`: The aperture size to use for all photometry going forward.
    
Not yet implemented: use isolated bright, but non-saturated, stars to determine the best aperture size.
</div>

In [None]:
# Aperture radius whose photometry is used from here on (one of aperture_sizes).
# TODO: allow this to vary as a function of filter.
ap_use = 3
phot_ap = 'phot_ap' + str(ap_use)

# Match stars in each image pair

In [None]:
# Cross-match each dithered exposure with its filter's reference exposure by sky
# position (astropy's match_coordinates_sky) and keep the merged table in
# file_d['source_catalog']. To keep the output short, diagnostic histograms
# are drawn only for the first non-reference exposure of each filter.
for filt, filt_d in data_d.items():
    ref_im = ref_ims[filt]
    phot0 = filt_d[ref_im][phot_ap]
    # The reference exposure is not matched against itself.
    filt_d[ref_im]['source_catalog'] = np.nan
    plotted = False

    for fname, file_d in filt_d.items():
        if fname == ref_im:
            continue

        photi = file_d[phot_ap]
        # idx[k] is the row of the reference catalog phot0 nearest to row k of
        # photi, and d2d[k] is the on-sky separation of that pair.
        idx, d2d, _ = match_coordinates_sky(photi['coords'], phot0['coords'])

        ref_part = Table({
            'ra_deg_ref': phot0['ra_deg'][idx],
            'dec_deg_ref': phot0['dec_deg'][idx],
            'x_ref': phot0['xcenter'][idx],
            'y_ref': phot0['ycenter'][idx],
            'mag_ref': phot0['mag'][idx],
            'magerr_ref': phot0['mag_err'][idx],
            'aper_sum_bkgsub_ref': phot0['aper_sum_bkgsub'][idx],
        })
        dither_part = Table({
            'ra_deg': photi['ra_deg'],
            'dec_deg': photi['dec_deg'],
            'x': photi['xcenter'],
            'y': photi['ycenter'],
            'mag': photi['mag'],
            'magerr': photi['mag_err'],
            'aper_sum_bkgsub': photi['aper_sum_bkgsub'],
            'sep2d_arcsec': d2d.arcsecond,
        })
        file_d['source_catalog'] = hstack([ref_part, dither_part])

        if not plotted:
            print(filt)
            print(f"Ref:        {ref_im}")
            print(f"Comparison: {fname}")
            fig, axes = plt.subplots(1, 2, figsize=(15, 5))
            # Separation of each matched pair, to choose a good match cut.
            axes[0].hist(d2d.arcsecond, histtype='bar', facecolor=colors[0], linewidth=1.2, bins=50, range=(-0.1,1))
            axes[0].set_xlabel("Separation (arcsec)")
            axes[0].set_ylabel("N")

            # Magnitude difference of each matched pair.
            axes[1].hist(photi['mag'] - phot0['mag'][idx], histtype='bar', facecolor=colors[0], linewidth=1.2, bins=50, range=(-2,2))
            axes[1].set_xlabel("Delta Mag")
            axes[1].set_ylabel("N")
            plt.show()
            plotted = True


<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `sep_limit`: Any "matched" stars with a separation (arcsec) larger than this value will be ignored.
    
Not yet implemented: use isolated bright, but non-saturated, stars to determine the best aperture size.
</div>

In [None]:
# Largest on-sky separation (arcsec) for a match to be kept; 0.1 rarely needs changing.
sep_limit = 0.1

In [None]:
# Reject poor matches, write each crossmatched catalog to
# {outdir}/{filt}_crossmatched_{i}.fits (reference and comparison file names
# go in the primary header), and sigma-clip magnitude differences and flux
# ratios. Diagnostic plots are made for the first non-reference exposure of
# each filter only.
for filt, filt_d in data_d.items():
    ref_im = ref_ims[filt]
    plotted = False
    for i, (fname, file_d) in enumerate(filt_d.items()):
        if fname == ref_im:
            # The reference exposure is not compared with itself.
            file_d['filtered_dmag'] = np.nan
            file_d['filtered_xref'] = np.nan
            file_d['filtered_yref'] = np.nan
            continue

        ti0 = file_d['source_catalog']
        t_filtered = ti0[ti0['sep2d_arcsec'] < sep_limit]
        matched_file = os.path.join(outdir, f'{filt}_crossmatched_{i}.fits')
        t_filtered.write(matched_file, format='fits', overwrite=True)
        with fits.open(matched_file, mode='update') as hdulist:
            hdr0 = hdulist[0].header
            hdr0.set("ref_img", ref_im)
            hdr0.set("comp_img", fname)

        delx = t_filtered['x_ref'] - t_filtered['x']
        dely = t_filtered['y_ref'] - t_filtered['y']

        # Magnitude difference and flux ratio (reference / comparison) per star.
        # NaN/inf values are left in place: sigma_clip masks invalid values
        # itself, whereas replacing NaN with 0 would make it a "perfect" dmag
        # and replacing inf with ~1.8e308 can defeat the clipping entirely.
        dmag = (t_filtered['mag_ref'] - t_filtered['mag']).data
        flux_ratio = (t_filtered['aper_sum_bkgsub_ref'] / t_filtered['aper_sum_bkgsub']).data

        # getmaskarray always yields a full boolean array (never a scalar).
        dmag_keep = ~np.ma.getmaskarray(sigma_clip(dmag, sigma=3, maxiters=10, masked=True))
        filtered_dmag = dmag[dmag_keep]
        file_d['filtered_dmag'] = filtered_dmag

        ratio_keep = ~np.ma.getmaskarray(sigma_clip(flux_ratio, sigma=3, maxiters=10, masked=True))
        file_d['filtered_flux_ratio'] = flux_ratio[ratio_keep]
        file_d['filtered_xref'] = t_filtered['x_ref'][ratio_keep]
        file_d['filtered_yref'] = t_filtered['y_ref'][ratio_keep]
        file_d['filtered_deltax'] = delx[ratio_keep]
        file_d['filtered_deltay'] = dely[ratio_keep]

        if plotted:
            continue

        print(filt)
        print(f"Ref:        {ref_im}")
        print(f"Comparison: {fname}")
        fig, axes0 = plt.subplots(2, 2, figsize=(20, 20))
        axes = axes0.flatten()

        # Plot delta y vs delta x
        axes[0].plot(delx, dely, 'o', color=colors[0], markersize=3.0)
        axes[0].set_xlabel("Delta X")
        axes[0].set_ylabel("Delta Y")

        # Histogram of all finite delta mags with a Gaussian fit
        finite_dmag = dmag[np.isfinite(dmag)]
        axes[1].hist(finite_dmag, color=colors[0], density=True, linewidth=1.2, bins=500, range=[-0.2, 0.2], label="Data")
        mu1, std1 = norm.fit(finite_dmag)
        x1 = np.linspace(-.2, .2, 500)
        axes[1].plot(x1, norm.pdf(x1, mu1, std1), color=colors[1], linewidth=5, label="Gaussian Fit")
        axes[1].legend(loc="best")
        axes[1].set_xlabel("Delta Mag")
        axes[1].set_ylabel("Density")

        # Histogram of sigma-clipped delta mags with a Gaussian fit
        axes[2].hist(filtered_dmag, bins=100, density=True, color=colors[0], label="Data")
        mu, std = norm.fit(filtered_dmag)
        x = np.linspace(np.min(filtered_dmag), np.max(filtered_dmag), 500)
        axes[2].plot(x, norm.pdf(x, mu, std), color=colors[1], linewidth=5, label="Gaussian Fit")
        axes[2].axvline(0, linestyle="--", color="k")
        axes[2].legend(loc="best")
        axes[2].set_xlabel("Sigma-clipped Delta Mag")
        axes[2].set_ylabel("Density")

        # Histogram of sigma-clipped flux ratios
        axes[3].hist(file_d['filtered_flux_ratio'], bins=100, color=colors[0])
        axes[3].set_xlim(.5, 1.5)
        axes[3].set_xlabel("Flux Ratio")
        axes[3].set_ylabel("N")

        plt.show()
        plotted = True

# Combine all flux ratio measurements as a function of ref. image X & Y

In [None]:
# Create master x_ref, y_ref and flux-ratio arrays per filter by pooling the
# sigma-clipped measurements of every non-reference exposure.
# From here on, we do not loop over any files, just filters.
final_data = {}
for filt, filt_d in data_d.items():
    ref_im = ref_ims[filt]
    dithers = [file_d for fname, file_d in filt_d.items() if fname != ref_im]
    # Concatenate once per array rather than once per file. Starting from an
    # empty float array means a filter with no dithered exposures still gets
    # all three keys (as empty arrays).
    final_data[filt] = {
        'x_ref': np.concatenate([np.array([])] + [d['filtered_xref'] for d in dithers]),
        'y_ref': np.concatenate([np.array([])] + [d['filtered_yref'] for d in dithers]),
        'flux_ratio': np.concatenate([np.array([])] + [d['filtered_flux_ratio'] for d in dithers]),
    }

In [None]:
# Make a 2D image of the mean flux ratio per detector pixel (all pointings
# combined). This is for diagnostic purposes only; it might help you decide
# which fit to use. Pixels without measurements are shown as 1.
# Binning approach from: https://stackoverflow.com/questions/30764955/python-numpy-create-2d-array-of-values-based-on-coordinates
npix = 2048
# Bin edges at half-integers, so bin i is centred on pixel index i and the
# image maps onto detector pixels regardless of where the stars fall.
pix_range = [[-0.5, npix - 0.5], [-0.5, npix - 0.5]]
for filt, filt_d in final_data.items():
    ratio_sum, _, _ = np.histogram2d(filt_d['y_ref'], filt_d['x_ref'], bins=(npix, npix),
                                     range=pix_range, weights=filt_d['flux_ratio'])
    counts, _, _ = np.histogram2d(filt_d['y_ref'], filt_d['x_ref'], bins=(npix, npix),
                                  range=pix_range)
    # Empty bins give 0/0 = NaN, replaced by a neutral ratio of 1 below.
    with np.errstate(divide='ignore', invalid='ignore'):
        zi = ratio_sum / counts
    ratio_im = np.nan_to_num(zi, nan=1)
    im = plt.imshow(ratio_im, origin="lower", vmin=.999, vmax=1.001)
    plt.colorbar(im, label="Flux Ratio")
    plt.title(f"{filt}, average flux ratio of all pointings")
    plt.show()

# Fit variety of models to flux ratio

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `init_guess`: Initial guess for the 2D Gaussian fit parameters. Corresponds to amplitude, xo, yo, sigma_x, sigma_y, theta, offset. If you're not interested in fitting a Gaussian, you can leave this as is.

**WARNING!!!** Fitting a 2D Gaussian is difficult. Derived fits will be inaccurate unless you provide a good initial guess for the parameters; use the plots above to estimate it.
</div>

In [None]:
# Initial 2D Gaussian guess, in twod_gaussian's parameter order.
init_guess = (
    1.2,   # amplitude
    1250,  # xo
    1250,  # yo
    200,   # sigma_x
    500,   # sigma_y
    0,     # theta
    0,     # offset
)

In [None]:
# Rotated elliptical 2D Gaussian plus constant offset, in a form usable by
# scipy.optimize.curve_fit.
# From: https://stackoverflow.com/questions/21566379/fitting-a-2d-gaussian-function-using-scipy-optimize-curve-fit-valueerror-and-m
def twod_gaussian(xy, amplitude, xo, yo, sigma_x, sigma_y, theta, offset):
    """Evaluate a rotated 2D Gaussian and return the values flattened to 1D.

    Parameters
    ----------
    xy : sequence of two arrays
        x and y coordinates at which to evaluate (same shape).
    amplitude, offset : float
        Peak height above the constant offset, and the offset itself.
    xo, yo : float
        Centre of the Gaussian.
    sigma_x, sigma_y : float
        Widths along the Gaussian's own axes.
    theta : float
        Rotation angle in radians.

    Returns
    -------
    ndarray
        ``offset + amplitude * exp(-Q)`` evaluated at every point, raveled.
    """
    dx = xy[0] - float(xo)
    dy = xy[1] - float(yo)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    sin_2t = np.sin(2*theta)
    two_var_x = 2*sigma_x**2
    two_var_y = 2*sigma_y**2

    # Coefficients of the quadratic form Q = a dx^2 + 2 b dx dy + c dy^2
    a = cos_t**2/two_var_x + sin_t**2/two_var_y
    b = sin_2t/(2*two_var_y) - sin_2t/(2*two_var_x)
    c = sin_t**2/two_var_x + cos_t**2/two_var_y
    quad = a*dx**2 + 2*b*dx*dy + c*dy**2
    return (offset + amplitude*np.exp(-quad)).ravel()

In [None]:
# Fit 2D polynomial, Legendre and Chebyshev surfaces of each order below to the
# pooled flux ratios, plus a 2D Gaussian starting from init_guess.
orders = [1, 2, 3, 4, 5]
npix = 2048
# Map detector pixel coordinates onto [-1, 1] before fitting. The Legendre and
# Chebyshev bases are defined there, and raw 0..2047 coordinates raised to
# powers up to 5 make the least-squares problem badly conditioned. The fitted
# models keep the mapping, so they are still evaluated in pixel coordinates.
pix_domain = (0, npix - 1)
unit_window = (-1, 1)
domain_kw = dict(x_domain=pix_domain, y_domain=pix_domain,
                 x_window=unit_window, y_window=unit_window)
# All three families are linear in their coefficients, so solve them directly
# with linear least squares instead of iterating with Levenberg-Marquardt.
linear_fitter = fitting.LinearLSQFitter()
for filt, filt_d in final_data.items():
    actualx = filt_d['x_ref']
    actualy = filt_d['y_ref']
    actualz = filt_d['flux_ratio']
    fts = {}
    for order in orders:
        poly = models.Polynomial2D(degree=order, **domain_kw)
        leg = models.Legendre2D(x_degree=order, y_degree=order, **domain_kw)
        cheby = models.Chebyshev2D(x_degree=order, y_degree=order, **domain_kw)
        fts[f'polynomial order={order}'] = linear_fitter(poly, actualx, actualy, actualz)
        fts[f'legendre order={order}'] = linear_fitter(leg, actualx, actualy, actualz)
        fts[f'chebyshev order={order}'] = linear_fitter(cheby, actualx, actualy, actualz)
    try:
        popt, pcov = opt.curve_fit(twod_gaussian, (actualx, actualy), actualz, p0=init_guess)
    except RuntimeError as err:
        # curve_fit raises RuntimeError when it does not converge; keep the
        # other fits rather than aborting every filter.
        print(f"{filt}: 2D Gaussian fit failed ({err}); it is not available for this filter")
    else:
        fts['2d gaussian'] = popt
    filt_d['fts'] = fts

In [None]:
# Evaluate each fit at the measured positions of its own filter and report
# goodness-of-fit statistics.
x = y = np.arange(2048)  # detector pixel grid, also used by the 3D plotting cell
for filt, filt_d in final_data.items():
    print(filt)
    # Use this filter's own measurements; the fitting loop's actualx/actualy/
    # actualz only hold the last filter's data once that loop finishes.
    actualx = filt_d['x_ref']
    actualy = filt_d['y_ref']
    actualz = filt_d['flux_ratio']
    for fitname, fit in filt_d['fts'].items():
        if fitname == "2d gaussian":
            fit_z = twod_gaussian((actualx, actualy), *fit)
        else:
            fit_z = fit(actualx, actualy)
        # RMSE as sqrt(MSE); mean_squared_error's squared= keyword was removed
        # in recent scikit-learn releases.
        rmse = np.sqrt(mean_squared_error(actualz, fit_z))
        # Pearson-style chi-square with the model value as the expectation.
        chi2 = np.sum((actualz - fit_z)**2 / fit_z)
        print(fitname)
        print(f'\tRMSE:{rmse:.5f}, Chi2: {chi2:.4f}')
    print('')

<div class="alert alert-block alert-warning">
<u><b>USER INPUT</b></u>
    
You need to define the following variables:
* `final_fitname`: The name of the fit to adopt for the delta flat. This assumes the same fit will be used for all filters.
</div>

In [None]:
# Fit adopted for every filter's delta flat; must be one of the fit names printed above.
final_fitname = 'polynomial order=3'

In [None]:
# Create a 3D plot of the flux ratios and overplot the fit surface
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D

xmesh, ymesh = np.meshgrid(x, y)
for filt,filt_d in final_data.items():
    print(filt)
    fts = filt_d['fts']
    final_fit = fts[final_fitname]
    # Evaluate the fit on an integer grid for plotting below    
    if final_fitname == "2d gaussian":    
        zmesh = twod_gaussian((xmesh, ymesh), *final_fit)
        zmesh = zmesh.reshape(2048, 2048)
    else:
        zmesh = final_fit(xmesh, ymesh)
    filt_d['fit_zmesh'] = zmesh
    fig = plt.figure(figsize=(15,15)) 
    ax = fig.add_subplot(111,projection='3d')
    ax.scatter(actualx, actualy, actualz, marker='o', s=15, c=colors[0])
    downsample = 10
    ax.plot_surface(xmesh, ymesh, zmesh, rcount=downsample, ccount=downsample, color=colors[1], 
                    shade=False, alpha=0.3)
    ax.set_zlim(0, 2)
    ax.set_title(filt)
    ax.set_xlabel("X pixel")
    ax.set_ylabel("Y pixel")
    ax.set_zlabel("Flux ratio")
    plt.show()

In [None]:
# You have to do this to undo the 3D plotting capabilities
%matplotlib inline
from matplotlib import style, pyplot as plt
params={'legend.fontsize':'18','axes.labelsize':'18',
        'axes.titlesize':'18','xtick.labelsize':'18',
        'ytick.labelsize':'18','lines.linewidth':2,
        'axes.linewidth':2,'animation.html': 'html5',
        'figure.figsize':(15,15)}
plt.rcParams.update(params)

In [None]:
# Display each filter's fitted surface with a diverging colormap whose limits
# are symmetric about a flux ratio of 1.
for filt, filt_d in final_data.items():
    surface = filt_d['fit_zmesh']
    half_span = np.max(np.abs([1. - np.min(surface), 1. - np.max(surface)]))
    im = plt.imshow(surface, origin='lower', vmin=1.0 - half_span, vmax=1.0 + half_span, cmap="PiYG")
    plt.colorbar(im, label="Flux Ratio")
    plt.title(f'{filt} Delta Flat')
    plt.show()

# Create delta flat file for each filter

In [None]:
# Package each filter's fitted surface into a delta-flat FITS file: the image,
# normalised to unit mean, goes in extension 1, and FILTER/PUPIL go in the
# primary header.
for filt, filt_d in final_data.items():
    im = filt_d['fit_zmesh']
    # Normalise by the mean over the full array. (An alternative, not used
    # here, is the mean of the central region im[974:1075, 974:1075].)
    nrmlzd = im / np.average(im)

    outname = os.path.join(outdir, f'{filt}_deltaflat.fits')

    filt0, pup0 = filt.split("_")
    hdr0 = fits.Header()
    hdr0['FILTER'] = filt0
    hdr0['PUPIL'] = pup0
    hdulist = fits.HDUList([fits.PrimaryHDU(header=hdr0), fits.ImageHDU(nrmlzd)])

    hdulist.writeto(outname, overwrite=True)
    print(f'Wrote file: {outname}')