In [1]:
import webbpsf
import dLux as dl
import dLuxWebbpsf as dlW

import jax.numpy as np
import jax.random as jr

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pandas as pd

import optax
import zodiax as zdx
import jax

from dLux.utils import deg_to_rad as d2r
import dLux.utils as dlu
from astropy.io import fits
from detector_layers import DistortionFromSiaf

# Notebook-wide default: draw images with array index [0, 0] at the lower-left corner.
plt.rcParams['image.origin'] = 'lower'



dLux: Jax is running in 32-bit, to enable 64-bit visit: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision


# Building Model

In [None]:
# Build the WebbPSF NIRISS instrument and pull out its optical planes and SIAF aperture.
# Original note: "Primary mirror - note this class automatically flips about the y-axis"
# (presumably refers to the primary-mirror layer later built from planes[0] - TODO confirm).
webbpsfobj = webbpsf.NIRISS()
# NOTE(review): calc_psf() is called before pupil_mask is set and its return value is
# discarded - confirm this ordering is intentional and the call is only wanted for
# its side effects.
webbpsfobj.calc_psf()  # calculating fits files
webbpsfobj.pupil_mask = "MASK_NRM"
# SIAF aperture entry; its V3IdlYAngle is read when the detector is built.
NIS_CEN_aperture = webbpsfobj.siaf.apertures["NIS_CEN"]
# Optical system generated after the pupil mask has been switched to "MASK_NRM".
webbpsf_osys = webbpsfobj.get_optical_system()
# Ordered list of planes; planes[0], planes[-2] and planes[-1] are used below.
planes = webbpsf_osys.planes

In [None]:
# Radial orders handed to the aberrated primary layer.
radial_orders = np.array([0, 1, 2], dtype=int)

# Number of terms contributed by each radial order, taken as the difference of
# consecutive triangular numbers.
terms_per_order = np.array([
    dlu.triangular_number(order + 1) - dlu.triangular_number(order)
    for order in radial_orders
])

# Coefficient array shape: 7 rows (presumably one per mask hole - TODO confirm)
# by the total number of terms over all radial orders.
hexike_shape = (7, int(np.sum(terms_per_order)))

# Ground-truth values used to build the model: source flux and small random
# aberration coefficients drawn with a fixed PRNG key for reproducibility.
true_flux = 1e8
true_coeffs = 1e-7 * jr.normal(jr.PRNGKey(0), hexike_shape)

In [None]:
# Sampling configuration. These two values are the single source of truth: they are
# reused below instead of repeating the literals, so the wavefront size and the
# detector downsampling factor cannot drift apart from them.
npix = 1024
oversample = 4

# Pixel scale of the final WebbPSF plane (arcsec/pix) and the plane just before it,
# whose amplitude is applied as the "Mask" layer.
pscale = (planes[-1].pixelscale).to("arcsec/pix").value
pupil_plane = planes[-2]

# Optical train: aberrated primary -> flip -> mask -> propagation to the focal plane.
osys = dl.LayeredOptics(
    wf_npixels=npix,
    # Physical extent of the first plane: metres-per-pixel times its pixel count.
    diameter=planes[0].pixelscale.to("m/pix").value * planes[0].npix,
    layers=[
        (dlW.optical_layers.JWSTAberratedPrimary(
            planes[0].amplitude,
            planes[0].opd,
            radial_orders=radial_orders,
            coefficients=true_coeffs,
            AMI=True,
        ), "Pupil"),
        (dl.Flip(0), "InvertY"),
        (dl.Optic(pupil_plane.amplitude), "Mask"),
        # Output grid is 64 pixels scaled up by the oversampling factor.
        (dlW.MFT(npixels=oversample * 64, oversample=oversample, pixel_scale=pscale), "Propagator"),
    ]
)

# Point source whose remaining keyword arguments come from the saved F480M filter config.
src = dl.PointSource(flux=true_flux, **dict(np.load("filter_configs/F480M.npz")))

detector = dl.LayeredDetector(
    [
        # Rotate by minus the aperture's V3IdlYAngle, converted from degrees to radians.
        dlW.detector_layers.Rotate(-d2r(NIS_CEN_aperture.V3IdlYAngle)),
        DistortionFromSiaf(
            aperture=NIS_CEN_aperture
        ),  # TODO implement dLuxWebbpsf version
        # Downsample to detector pixel scale; the kernel matches the propagator's oversampling.
        dl.IntegerDownsample(kernel_size=oversample),
    ]
)

instrument = dl.Instrument(sources=[src], detector=detector, optics=osys)

# Reading Data

In [None]:
# Input locations: the exposure-information table and the directory of calints files.
info_path = 'data/comm_1093_exposure_info_epoch2.txt'
data_path = 'data/calints/'


def uncal_to_calints(filename):
    """Return `filename` with its last 10 characters replaced by 'calints.fits'.

    Ten is the length of 'uncal.fits', so a name ending in that suffix is mapped
    to the matching calints name. The ending is not checked: any other input
    simply loses its final 10 characters before the new suffix is appended.
    """
    n_trailing = len('uncal.fits')
    stem = filename[:-n_trailing]
    return stem + 'calints.fits'

In [None]:
# DATA
# Read the whitespace-delimited exposure table into a pandas DataFrame. The separator
# is a regex, written as a raw string so the backslash is not treated as a (invalid)
# string escape. skiprows=[1] drops the line directly below the header
# (presumably a separator/units row - TODO confirm against the file).
df = pd.read_table(info_path, sep=r'\s+', header=0, skiprows=[1])

# selecting HD-37093 target
HD37093 = df[df['TARGPROP'] == 'HD-37093']

file_idx = 0  # selecting first file
calints_name = uncal_to_calints(HD37093['DISKNAME'].iloc[file_idx])

# The HDUList is intentionally left open: the plotting cell below still reads
# HDUList['DQ'] from it.
HDUList = fits.open(data_path + calints_name)

# info() prints the summary itself and returns None, so wrapping it in print()
# would emit a stray "None" line.
HDUList.info()

# Filter name recorded for the selected exposure.
filt = HD37093['FILTER'].iloc[file_idx]

# First entry along the leading axis of the SCI extension
# (presumably the first integration - TODO confirm).
data = np.array(HDUList['SCI'].data[0])

In [None]:
# Science frame loaded from the SCI extension.
fig, ax = plt.subplots()
sci_image = ax.imshow(data, cmap='magma')
ax.set_title('JWST AMI Frame')
fig.colorbar(sci_image, ax=ax, label='Counts')
plt.show()

# Matching entry of the DQ extension. It gets its own title so the two figures
# can be told apart (both were previously titled 'JWST AMI Frame').
fig, ax = plt.subplots()
dq_image = ax.imshow(np.array(HDUList['DQ'].data[0]), cmap='gray')
ax.set_title('JWST AMI Data Quality (DQ) Flags')
fig.colorbar(dq_image, ax=ax, label='Bad Pixels')
plt.show()