# Full Pupil NIRISS for Ben

In [3]:
import warnings

import webbpsf
import jax.numpy as np
import jax.random as jr
import jax.tree_util as jtu
from jax.scipy.ndimage import map_coordinates
from jax import Array
import dLux as dl
import dLuxWebbpsf as dlW
from dLux.utils import deg_to_rad as d2r
import matplotlib.pyplot as plt

from dLux.detector_layers import DetectorLayer

dLux: Jax is running in 32-bit, to enable 64-bit visit: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision


## SIAF distortion detector layer
A version of this is (I believe) built into dLuxWebbpsf, but this notebook has not yet been migrated to use it.

In [None]:
class DistortionFromSiaf:
    """
    Factory that builds an ``ApplySiafDistortion`` detector layer from a SIAF aperture.

    Note that instantiating this class does NOT return a ``DistortionFromSiaf``;
    ``__new__`` returns a fully configured ``ApplySiafDistortion`` instead.
    """

    def __new__(cls, aperture, oversample=4):
        """
        Parameters
        ----------
        aperture
            SIAF aperture providing ``Sci2IdlDeg``, ``XSciRef``/``YSciRef`` and
            ``get_polynomial_coefficients()`` (with 'Sci2IdlX'/'Sci2IdlY' entries).
        oversample : int
            Image pixels per detector pixel.
        """
        poly = aperture.get_polynomial_coefficients()
        reference_pixel = np.array([aperture.XSciRef, aperture.YSciRef])
        # The reference pixel is reused as the image centre. Note this may not be foolproof.
        centre_pixel = np.array([aperture.XSciRef, aperture.YSciRef])
        # Detector pixel scale is hard-coded. The SIAF-derived alternative,
        # np.array([aperture.XSciScale, aperture.YSciScale]).mean(), may not be foolproof either.
        detector_scale = 4*0.0164
        return ApplySiafDistortion(
            aperture.Sci2IdlDeg + 1,  # number of polynomial orders (0 .. Sci2IdlDeg)
            np.array([poly['Sci2IdlX'], poly['Sci2IdlY']]),
            reference_pixel,
            centre_pixel,
            detector_scale / oversample,  # scale of one oversampled pixel
            oversample,
        )


class ApplySiafDistortion(DetectorLayer):
    """
    Applies the Science-to-Ideal (Sci2Idl) distortion following webbpsf/pysiaf.

    Each output axis of the distortion is a 2D polynomial in the science-frame
    offsets (X, Y) from the reference pixel ``SciRef``::

        sum_k coeffs[k] * X**xpows[k] * Y**ypows[k]

    The image is then resampled (linear interpolation via ``map_coordinates``)
    onto the distorted grid.
    """
    degree: int  # number of polynomial orders used (terms of total order 0 .. degree - 1)
    Sci2Idl: Array  # (2, n_terms) coefficients; row 0 -> ideal X, row 1 -> ideal Y
    SciRef: Array  # (x, y) science-frame reference pixel
    sci_cen: Array  # (x, y) science-frame pixel the image is centred on
    pixel_scale: Array  # ideal-frame size of one image (oversampled) pixel
    oversample: int  # image pixels per science-frame pixel
    xpows: Array  # X exponent of each polynomial term
    ypows: Array  # Y exponent of each polynomial term

    def __init__(self,
                 degree,
                 Sci2Idl,
                 SciRef,
                 sci_cen,
                 pixel_scale,
                 oversample):
        """
        Parameters
        ----------
        degree : int
            Number of polynomial orders. Terms of total order 0 .. degree - 1 are used,
            giving degree * (degree + 1) / 2 terms.
        Sci2Idl : array-like
            Polynomial coefficients of shape (2, n_terms). Row 0 gives ideal X and
            row 1 gives ideal Y.
        SciRef : array-like
            (x, y) science-frame reference pixel that the polynomial is expanded about.
        sci_cen : array-like
            (x, y) science-frame pixel that the image is centred on.
        pixel_scale : float
            Ideal-frame size of one image pixel. This presumably uses the ideal-frame
            units (arcsec for pysiaf); confirm.
        oversample : int
            Number of image pixels per science-frame pixel.
        """
        super().__init__()
        self.degree = int(degree)
        self.Sci2Idl = np.array(Sci2Idl, dtype=float)
        self.SciRef = np.array(SciRef, dtype=float)
        self.sci_cen = np.array(sci_cen, dtype=float)
        self.pixel_scale = np.array(pixel_scale, dtype=float)
        self.oversample = int(oversample)
        self.xpows, self.ypows = self.get_pows()

    def get_pows(self):
        """
        Return the (xpows, ypows) exponent arrays for every polynomial term.

        Terms are ordered as in pysiaf. For each total order i in ``range(degree)``
        and each j in ``range(i + 1)``, the term is ``X**(i - j) * Y**j``.
        """
        xpows = [i - j for i in range(self.degree) for j in range(i + 1)]
        ypows = [j for i in range(self.degree) for j in range(i + 1)]
        return np.array(xpows, dtype=float), np.array(ypows, dtype=float)

    def __call__(self, image):
        """
        Apply the Sci2Idl distortion to ``image.image`` and return the updated image object.
        """
        new_image = self.apply_Sci2Idl_distortion(image.image)
        return image.set('image', new_image)

    def apply_Sci2Idl_distortion(self, image):
        """
        Applies the distortion from the science (i.e. images) frame to the idealised telescope frame.

        Parameters
        ----------
        image : Array
            2D image indexed (row, column) = (y, x). It does not need to be square.

        Returns
        -------
        Array
            The resampled image, with the same shape as ``image``.
        """
        # Convert the science-frame image centre to the ideal frame
        xidl_cen, yidl_cen = self.distort_coords(self.Sci2Idl[0],
                                                 self.Sci2Idl[1],
                                                 self.sci_cen[0] - self.SciRef[0],
                                                 self.sci_cen[1] - self.SciRef[1])

        # Paraxial pixel coordinates about the image centre.
        # Arrays are indexed (row, column) = (y, x), so the first axis is y.
        ny, nx = image.shape
        nx_half, ny_half = ((nx - 1) / 2., (ny - 1) / 2.)
        xlin = np.linspace(-nx_half, nx_half, nx)
        ylin = np.linspace(-ny_half, ny_half, ny)
        xarr, yarr = np.meshgrid(xlin, ylin)  # both of shape (ny, nx) == image.shape

        # Scale and shift coordinate arrays to the 'sci' frame
        xnew = xarr / self.oversample + self.sci_cen[0]
        ynew = yarr / self.oversample + self.sci_cen[1]

        # Convert the requested coordinates to 'idl' coordinates
        xnew_idl, ynew_idl = self.distort_coords(self.Sci2Idl[0],
                                                 self.Sci2Idl[1],
                                                 xnew - self.SciRef[0],
                                                 ynew - self.SciRef[1])

        # Interpolation coordinates in (row, column) order. Each axis is
        # re-centred on its own half-size so that non-square images work too.
        coords_distort = np.array([
            (ynew_idl - yidl_cen) / self.pixel_scale + ny_half,
            (xnew_idl - xidl_cen) / self.pixel_scale + nx_half,
        ])

        # Apply the distortion (order=1 -> linear interpolation)
        return map_coordinates(image, coords_distort, order=1)

    def triangular_number(self, n):
        """
        Return n * (n + 1) / 2 (true division, so the result is a float).
        """
        # TODO: Add to utils/math.py
        return n * (n + 1) / 2

    def distort_coords(self, A, B, X, Y):
        """
        Applies the distortion polynomials to the coordinates.

        Parameters
        ----------
        A, B : Array
            Coefficients (one per polynomial term) for the output X and Y axes.
        X, Y : Array or float
            Science-frame offsets from the reference pixel. Scalars are promoted
            to 2D with ``atleast_2d``.

        Returns
        -------
        Xnew, Ynew : Array
            Distorted coordinates with the (2D) shape of the promoted inputs.
        """

        # Promote shapes for float inputs
        X = np.atleast_2d(X)
        Y = np.atleast_2d(Y)

        # Exponentiate: leading axis indexes the polynomial terms
        Xpow = X[None, :, :] ** self.xpows[:, None, None]
        Ypow = Y[None, :, :] ** self.ypows[:, None, None]

        # Calculate new coordinates by summing over the terms
        Xnew = np.sum(A[:, None, None] * Xpow * Ypow, axis=0)
        Ynew = np.sum(B[:, None, None] * Xpow * Ypow, axis=0)

        return Xnew, Ynew

## Aberrated Primary Mirror
Here we define `JWSTAberratedPrimary`, a subclass of `JWSTPrimary` that additionally stores a Hexike basis and its coefficients.

In [None]:
from dLuxWebbpsf import JWSTPrimary
from abberations import generate_jwst_hexike_basis, generate_jwst_secondary_basis

class JWSTAberratedPrimary(JWSTPrimary, dl.optical_layers.BasisLayer):
    """
    Child class of JWSTPrimary which adds the functionality to store a Hexike basis and coefficients.

    If ``secondary_coefficients`` is given, ``basis`` and ``coefficients`` are dicts
    with keys ``'primary'`` and ``'secondary'``. Otherwise they are plain arrays for
    the primary mirror only.
    """

    def __init__(
            self,
            transmission: Array,
            opd: Array,
            coefficients: Array | list | None = None,
            radial_orders: Array | list | None = None,
            noll_indices: Array | list | None = None,
            secondary_coefficients: Array | list | None = None,
            secondary_radial_orders: Array | list | None = None,
            secondary_noll_indices: Array | list | None = None,
            AMI: bool = False,
            mask: bool = False,
    ):
        """
        Parameters
        ----------
        transmission: Array
            The Array of transmission values to be applied to the input
            wavefront. Its first dimension sets the basis size (npix).
        opd : Array
            The Array of OPD values to be applied to the input wavefront.
        coefficients : Array
            The coefficients to be applied to the primary Hexike basis vectors.
            Although the default is None, a value is required. It is converted
            with ``np.array``, which presumably rejects None; confirm.
        radial_orders : Array
            The radial orders of the zernike polynomials to be used for the
            aberrations. Input of [0, 1] would give [Piston, Tilt X, Tilt Y],
            [1, 2] would be [Tilt X, Tilt Y, Defocus, Astig X, Astig Y], etc.
            The order must be increasing but does not have to be consecutive.
            If you want to specify specific zernikes across radial orders the
            noll_indices argument should be used instead.
        noll_indices : Array
            The zernike noll indices to be used for the aberrations. [1, 2, 3]
            would give [Piston, Tilt X, Tilt Y], [2, 3, 4] would be [Tilt X,
            Tilt Y, Defocus]. Takes precedence over radial_orders.
        secondary_coefficients : Array, optional
            Coefficients for the secondary-mirror basis. If given, a secondary
            basis is generated and summed with the primary one.
        secondary_radial_orders : Array, optional
            As radial_orders, for the secondary-mirror basis.
        secondary_noll_indices : Array, optional
            As noll_indices, for the secondary-mirror basis. Takes precedence
            over secondary_radial_orders.
        AMI : bool
            Whether to use the AMI segments or not.
        mask : bool
            Whether to apodise the basis with the AMI mask or not. Recommended is False.
        """
        npix: int = transmission.shape[0]
        super().__init__(transmission=transmission, opd=opd)

        # Dealing with the radial_orders and noll_indices arguments
        if radial_orders is not None and noll_indices is not None:
            warnings.warn("Both radial_orders and noll_indices provided. Using noll_indices.")
            radial_orders = None

        primary_basis = generate_jwst_hexike_basis(
            radial_orders=radial_orders,
            noll_indices=noll_indices,
            npix=npix,
            AMI=AMI,
            mask=mask,
        )

        if secondary_radial_orders is not None and secondary_noll_indices is not None:
            warnings.warn("Both secondary_radial_orders and secondary_noll_indices provided. Using "
                          "secondary_noll_indices.")
            secondary_radial_orders = None

        if secondary_coefficients is not None:
            secondary_basis = generate_jwst_secondary_basis(
                radial_orders=secondary_radial_orders,
                noll_indices=secondary_noll_indices,
                npix=npix,
            )
            # Store arrays, not raw lists. A list would become a pytree node, so
            # `calculate` would receive a Python list, which JAX functions reject.
            self.coefficients = {'primary': np.array(coefficients),
                                 'secondary': np.array(secondary_coefficients)}
            self.basis = {'primary': np.array(primary_basis),
                          'secondary': np.array(secondary_basis)}

        else:
            self.coefficients = np.array(coefficients)
            self.basis = np.array(primary_basis)

    @property
    def basis_opd(self):
        """
        Returns the OPD calculated from the basis and coefficients.

        When both primary and secondary bases are present, their OPDs are summed.
        """
        # tree_map pairs each basis leaf with its coefficients (arrays or dict of arrays)
        opds = jtu.tree_map(self.calculate, self.basis, self.coefficients)
        return np.array(jtu.tree_leaves(opds)).sum(0)

    def __call__(self, wavefront):
        """
        Apply the transmission (renormalised) and the static plus basis OPD to ``wavefront``.
        """
        # Apply transmission and normalise
        amplitude = wavefront.amplitude * self.transmission
        amplitude = amplitude / np.linalg.norm(amplitude)

        total_opd = self.opd + self.basis_opd

        # Apply phase
        phase = wavefront.phase + wavefront.wavenumber * total_opd

        # Update and return
        return wavefront.set(["amplitude", "phase"], [amplitude, phase])

# Building dLux model

First, pull the optical planes and the NIS_CEN SIAF aperture that we need from WebbPSF.

In [2]:
# Primary mirror - note this class automatically flips about the y-axis
# Requires $WEBBPSF_PATH to point at the WebbPSF data files (this cell raised
# OSError when it was unset).
webbpsfobj = webbpsf.NIRISS()
# NOTE(review): the returned PSF is discarded; this presumably runs only for its
# side effects ("calculating fits files"). Confirm that it is actually needed.
webbpsfobj.calc_psf()  # calculating fits files
# NOTE(review): this selects the NRM (AMI) pupil mask, although the notebook is
# titled "Full Pupil". Confirm this is intended. It is set before
# get_optical_system(), presumably so that the mask is included in the planes.
webbpsfobj.pupil_mask = "MASK_NRM"
NIS_CEN_aperture = webbpsfobj.siaf.apertures["NIS_CEN"]
webbpsf_osys = webbpsfobj.get_optical_system()
# Later cells use planes[0] as the primary pupil, planes[-2] as the pupil-mask
# plane and planes[-1] for the detector pixel scale.
planes = webbpsf_osys.planes

OSError: Environment variable $WEBBPSF_PATH is not set!

 ***********  ERROR  ******  ERROR  ******  ERROR  ******  ERROR  ***********
 *                                                                          *
 *  WebbPSF requires several data files to operate.                         *
 *  These files could not be located automatically at this time, or this    *
 *  version of the software requires a newer set of reference files than    *
 *  you have installed.  For more details see:                              *
 *                                                                          *
 *        https://webbpsf.readthedocs.io/en/stable/installation.html        *
 *                                                                          *
 *  under "Installing the Required Data Files".                             *
 *  WebbPSF will not be able to function properly until the appropriate     *
 *  reference files have been downloaded to your machine and installed.     *
 *                                                                          *
 ****************************************************************************


Initialise the model parameters: the Hexike radial orders, the true source flux, and randomly drawn true aberration coefficients.

In [None]:
radial_orders = np.array([0, 1, 2], dtype=int)
# Radial order n contributes n + 1 Zernike/Hexike terms (T(n + 1) - T(n) for
# triangular numbers T), so the number of terms is sum(n + 1) over the orders.
# NOTE(review): the leading 7 is presumably the number of segments in the basis;
# with AMI=False used for the primary below, confirm the basis expects 7.
hexike_shape = (7, int(np.sum(radial_orders + 1)))

true_flux = 1e6
true_coeffs = 1e-7 * jr.normal(jr.PRNGKey(0), hexike_shape)

In [None]:
# Wavefront size and oversampling; both are used consistently below so that
# changing either one updates every layer that depends on it.
npix = 1024
oversample = 4
pscale = (planes[-1].pixelscale).to("arcsec/pix").value
pupil_plane = planes[-2]

osys = dl.LayeredOptics(
    wf_npixels=npix,
    diameter=planes[0].pixelscale.to("m/pix").value * planes[0].npix,
    layers=[
        (JWSTAberratedPrimary(
            planes[0].amplitude,
            planes[0].opd,
            radial_orders=radial_orders,
            coefficients=true_coeffs,
            AMI=False,  # False for the full pupil
        ), "Pupil"),
        (dl.Flip(0), "InvertY"),
        (dl.Optic(pupil_plane.amplitude), "Mask"),
        # 64 detector pixels, sampled `oversample` times finer
        (dlW.MFT(npixels=oversample * 64, oversample=oversample, pixel_scale=pscale), "Propagator"),
    ]
)

# Use a context manager so the .npz file handle is closed once it has been read
with np.load("filter_configs/F480M.npz") as filter_config:
    src = dl.PointSource(flux=true_flux, **dict(filter_config))

detector = dl.LayeredDetector(
    [
        dlW.detector_layers.Rotate(-d2r(NIS_CEN_aperture.V3IdlYAngle)),
        DistortionFromSiaf(
            aperture=NIS_CEN_aperture,
            oversample=oversample,
        ),  # TODO implement dLuxWebbpsf version
        dl.IntegerDownsample(kernel_size=oversample),  # Downsample to detector pixel scale
    ]
)

instrument = dl.Instrument(sources=[src], detector=detector, optics=osys)

## PSF time
surely

In [None]:
# Model the PSF through the full instrument (optics + detector) and display it
psf = instrument.model()
plt.imshow(psf)