In [4]:
import os

# Set CPU count for numpyro multi-chain multi-thread
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import pickle
import itertools

import jax
# Enable x64 for JAX
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as np
import jax.random as jr
from jax import jit, lax

import numpy as onp
from matplotlib import pyplot as plt
from matplotlib import colors
from astropy import units as u
from astropy.io import fits
from scipy.ndimage import binary_dilation

import zodiax as zdx
import dLux as dl
from dLux.utils import arcsec2rad as toRad
from dLux.utils import rad2arcsec as toArcsec
from dLuxWebbpsf.utils import grid_search
from dLuxWebbpsf import NIRCam

from synphot import SourceSpectrum
from synphot.models import BlackBodyNorm1D

import numpyro as npy
import numpyro.distributions as dist

import chainconsumer as cc

def norm(data):
    """Return `data` divided by the sum of all its elements."""
    total = data.sum()
    return data / total

In [3]:
# A global default colormap is intentionally left disabled:
# plt.rcParams['image.cmap'] = 'inferno'

# Global figure styling: LaTeX text rendering, serif font family, 120 dpi figures.
plt.rcParams.update({
    "text.usetex": 'true',
    "font.family": "serif",
    'figure.dpi': 120,
})

# Private copies of the stock colormaps so that the registered ones stay untouched.
cmap_inf = plt.colormaps["inferno"].copy()
cmap_sei = plt.colormaps["seismic"].copy()

# Invalid, below-range and above-range pixels are all drawn in black.
for _cmap in (cmap_inf, cmap_sei):
    _cmap.set_bad('black', 1.)
    _cmap.set_under('black')
    _cmap.set_over('black')


# Figures are written to disk by the plotting helpers, so interactive mode is switched off.
plt.ioff()

<contextlib.ExitStack at 0x7f4414747010>

### Get data from FITS files

In [37]:
def get_mask(data, x, y, size, *, dilate = False):
    """
    Build a 0/1 mask with the same shape as `data` that zeroes a square window.

    Parameters
    ----------
    data : 2D array
        Only its shape is used.
    x, y : number
        Column (x) and row (y) index of the window centre.
    size : number
        Half-width of the window; pixels with |column - x| <= size and
        |row - y| <= size are set to 0, every other pixel to 1.
    dilate : bool, keyword-only
        When True the zeroed region is grown with `binary_dilation`
        (default structuring element) before the mask is returned.

    Returns
    -------
    Integer array shaped like `data`: 0 inside the window, 1 elsewhere.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]

    # With meshgrid's default 'xy' indexing the outputs have shape
    # (len(second argument), len(first argument)), so the column range goes first.
    # Passing the row count first only works for square images.
    xx, yy = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    in_x = (xx >= (x - size)) & (xx <= (x + size))
    in_y = (yy >= (y - size)) & (yy <= (y + size))

    data_mask = np.where(in_x & in_y, 0, 1)

    if dilate:
        # Dilation acts on the *masked* (zero) region, hence the double negation.
        masked_region = binary_dilation(np.logical_not(data_mask))
        data_mask = np.logical_not(masked_region).astype(data_mask.dtype)

    return data_mask

In [38]:
def read_fits(path, masks = None):
    """
    Load one FITS file and build the pixel mask used by the fit.

    Extensions 1, 2 and 3 are each summed over their leading axis; judging by the
    variable names these are the star image, its error and the data-quality flags
    (NOTE(review): confirm against the file layout).

    Parameters
    ----------
    path : str
        Path of the FITS file.
    masks : iterable of (x, y, size), optional
        Extra square windows to mask out, forwarded to `get_mask`.

    Returns
    -------
    (header, star, mask, err)
        header : header of extension 0
        star   : summed extension 1 with NaNs replaced by 0
        mask   : 0/1 array, 0 where a pixel must be ignored
        err    : summed extension 2 (NaNs are NOT replaced)
    """
    # The context manager closes the file even if an extension is missing or unreadable.
    with fits.open(path) as hdul:
        webbpsf_header = hdul[0].header
        # np.array(...) materialises the data before the file is closed.
        webbpsf_data_star = np.array(sum(hdul[1].data))
        webbpsf_data_err = np.array(sum(hdul[2].data))
        webbpsf_data_dq = np.array(sum(hdul[3].data))

    # cleanup data
    webbpsf_data_star = np.where(np.isnan(webbpsf_data_star), 0, webbpsf_data_star)

    # Column indices (xx) and row indices (yy) shaped like the image; with the default
    # 'xy' indexing the column range must be passed first (matters for non-square images).
    n_rows, n_cols = webbpsf_data_star.shape[0], webbpsf_data_star.shape[1]
    xx, yy = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    # Brightest pixel: unravel_index returns (row, column).
    bri_y, bri_x = np.unravel_index(webbpsf_data_star.argmax(), webbpsf_data_star.shape)

    # Mask a window of half-width 2 pixels around the brightest pixel.
    webbpsf_data_mask = get_mask(webbpsf_data_star, bri_x, bri_y, 2)

    if masks is not None:
        for (m_x, m_y, m_s) in masks:
            webbpsf_data_mask *= get_mask(webbpsf_data_star, m_x, m_y, m_s)

    # Grow every masked region (dilation acts on the zero-valued pixels).
    webbpsf_data_mask = np.logical_not(binary_dilation(np.logical_not(webbpsf_data_mask))).astype(webbpsf_data_mask.dtype)

    # Also drop pixels with a non-zero summed DQ value and the hot pixel at x=9, y=31.
    webbpsf_data_mask *= (webbpsf_data_dq == 0) & ~((xx == 9) & (yy == 31)) # 9,31 is a hot pixel


    return webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err

### Create model

In [5]:
# Size of the phase-retrieval basis: 18 mirrors with 3 Zernike terms each.
n_mirrors = 18
zernike_terms = 3
n_coeffs = zernike_terms * n_mirrors

def get_telescope(npix, filter, detector, aperture, opd_date, *, pixel_mask = None, flux = None, downsample=8, source=None):
    """
    Construct the `NIRCam` model used throughout the notebook.

    Parameters
    ----------
    npix : passed to NIRCam as `fov_pixels`
    filter, detector, aperture : passed through to NIRCam under the same names
    opd_date : passed to NIRCam as `load_opd_date`
    pixel_mask : optional array; when given it is wrapped in `dl.ApplyPixelResponse`
        and stored in the detector layers under the key "pixel_mask"
    flux, source : passed through to NIRCam
    downsample : passed to NIRCam as `wavefront_downsample`

    Returns
    -------
    The constructed NIRCam object.
    """
    nircam_kwargs = dict(
        filter=filter,
        detector=detector,
        aperture=aperture,
        fft_oversample=1,
        detector_oversample=4,
        wavefront_downsample=downsample,
        fov_pixels=npix,
        options=dict(
            add_distortion=True,
            add_ipc=True,
            dl_add_rotation=True,
            dl_add_siaf=True,
            dl_add_diffusion=True,
        ),
        # Number of Zernike terms comes from the module-level basis configuration.
        phase_retrieval_terms=zernike_terms,
        flux=flux,
        load_opd_date=opd_date,
        source=source,
    )
    telescope = NIRCam(**nircam_kwargs)

    if pixel_mask is None:
        return telescope

    # Register the mask as an extra detector layer.
    telescope.detector.layers["pixel_mask"] = dl.ApplyPixelResponse(pixel_mask)
    return telescope

In [6]:
def crop(data, s = 10):
    """
    Trim a border of `s` pixels from each edge of the first two axes of `data`.

    Parameters
    ----------
    data : array with at least two axes
    s : int, default 10
        Border width in pixels. `s = 0` returns `data` unchanged (a literal
        `0:-0` slice would be empty).

    Raises
    ------
    ValueError
        If `s` is negative.
    """
    if s < 0:
        raise ValueError(f"crop border must be non-negative, got {s}")
    if s == 0:
        return data
    return data[s:-s, s:-s]

## Step 2: Make an optical model for optimisation in which the source is parametrised.

CANDID (https://github.com/amerand/CANDID), a popular interferometry data analysis code, solves this same problem by calculating the likelihood on a grid of (x, y) positions with all other parameters held constant, in order to identify the best starting position.
See that page and scroll down to ‘CHI2MAP’ and ‘FITMAP’.

In [7]:
@zdx.filter_jit
def likelihood(model, x, y, expected_data):
    """
    Score a trial source position against `expected_data`.

    The model is evaluated with 'source.position' set to (x, y), normalised with
    `norm`, cropped with `crop`, and compared to `expected_data`.

    Returns -0.5 times the sum of squared residuals (no per-pixel weighting).
    """
    position = np.array([x, y])
    shifted_model = model.set(['source.position'], [position])
    predicted = crop(norm(shifted_model.model()))

    chi_squared = np.sum((expected_data - predicted) ** 2)
    return -0.5 * chi_squared

@zdx.filter_jit
def likelihood_fast(telescope, x, y, expected_data):
    """
    Evaluate `likelihood` for many (x, y) positions.

    `x` and `y` are stacked along a new last axis and `lax.map` applies
    `likelihood` over the leading axis of the stacked array.
    """
    def at_position(xy):
        return likelihood(telescope, xy[0], xy[1], expected_data)

    positions = np.stack((x, y), axis=-1)
    return lax.map(at_position, positions)

In [8]:
def search_grid(model, expected_data):
    """
    Run `grid_search` over source positions, scoring each with `likelihood_fast`.

    The search is started at (0, 0) with a grid size of `toRad(1)`, 30 grid steps
    and 10 iterations. The three values returned by `grid_search` are passed
    through unchanged; callers use the first two as the x and y position.
    """
    evaluate = lambda xr, yr: likelihood_fast(model, xr, yr, expected_data)

    best_x, best_y, tp = grid_search(
        evaluate, 0, 0,
        grid_size = toRad(1),
        grid_steps = 30,
        niter = 10,
    )
    return best_x, best_y, tp

## Step 3: Set up HMC parameters and distributions.

In [9]:
# Model paths that the numpyro model overwrites, in the order the values are supplied.
parameters = ['source.position', 'pupil.coefficients', 'source.flux']

# Sampled phase coefficients are drawn from [-1, 1] and multiplied by this factor.
phase_scale = 1e-6

# Coefficients that are held at a constant value instead of being sampled.
# Keys are (zernike term index, mirror index); values are the fixed coefficient.
fixed_mirrors = dict.fromkeys([(0, 0), (1, 0), (2, 0)], 0)

def psf_model(data, data_err, model, *, x0, y0, coord_range, log_flux, log_flux_range):
    """
    numpyro model of the PSF: position, per-mirror phase coefficients and flux.

    Sample sites
    ------------
    x_arcsec, y_arcsec : Uniform within +/- coord_range of (x0, y0); converted with
        `toRad` before being written to 'source.position'
    p<k> : Uniform(-1, 1) for every coefficient index k not listed in
        `fixed_mirrors`; multiplied by `phase_scale`
    log_flux : Uniform within +/- log_flux_range of `log_flux`; exponentiated
        before being written to 'source.flux'
    psf : Normal(cropped model image, data_err), observed as the flattened `data`
    """
    x_lo, x_hi = x0 - coord_range, x0 + coord_range
    y_lo, y_hi = y0 - coord_range, y0 + coord_range
    x_sample = npy.sample("x_arcsec", dist.Uniform(x_lo, x_hi))
    y_sample = npy.sample("y_arcsec", dist.Uniform(y_lo, y_hi))

    # Coefficients are laid out mirror-major: index = mirror * zernike_terms + term.
    phases_data = [0] * n_coeffs
    for i in range(n_mirrors):
        for z in range(zernike_terms):
            coeff_index = i*zernike_terms + z
            key = (z, i)
            if key not in fixed_mirrors:
                unit_sample = npy.sample(f'p{coeff_index}', dist.Uniform(-1,1))
                phases_data[coeff_index] = unit_sample * phase_scale
            else:
                phases_data[coeff_index] = fixed_mirrors[key]

    flux_sample  = npy.sample("log_flux", dist.Uniform(log_flux - log_flux_range, log_flux + log_flux_range))

    position = np.array([toRad(x_sample), toRad(y_sample)])
    values = [position, np.array(phases_data), np.exp(flux_sample)]

    observed = data.flatten()
    with npy.plate("data", len(observed)):
        model_image = model.set(parameters, values).model()
        model_data = crop(model_image).flatten()
        npy.sample("psf", dist.Normal(model_data, data_err), obs=observed)
        

In [10]:
def get_results(values_out):
    """
    Collapse posterior samples into point estimates using the mean of each site.

    Parameters
    ----------
    values_out : dict
        Sample arrays keyed by site name ("x_arcsec", "y_arcsec", "log_flux", "p<k>").

    Returns
    -------
    (x, y, flux, coeffs)
        x, y   : mean position passed through `toRad`
        flux   : exp of the mean "log_flux"
        coeffs : array of n_coeffs coefficients; fixed entries come from
                 `fixed_mirrors`, the rest are mean samples times `phase_scale`
    """
    mean_vals = {}
    for name, samples in values_out.items():
        mean_vals[name] = np.mean(samples)

    # Same mirror-major layout as in the numpyro model.
    coeffs_pred_all = [0] * n_coeffs
    for i in range(n_mirrors):
        for z in range(zernike_terms):
            coeff_index = i*zernike_terms + z
            key = (z, i)
            if key not in fixed_mirrors:
                coeffs_pred_all[coeff_index] = mean_vals[f'p{coeff_index}'] * phase_scale
            else:
                coeffs_pred_all[coeff_index] = fixed_mirrors[key]

    return (
        toRad(mean_vals["x_arcsec"]),
        toRad(mean_vals["y_arcsec"]),
        np.exp(mean_vals["log_flux"]),
        np.array(coeffs_pred_all),
    )

## Step 4: Make plots.

In [11]:
def get_aber(coeffs, basis):
    """
    Weighted sum of basis maps: sum over k of coeffs[k] * basis[k].

    `coeffs` is reshaped to (n, 1, 1) so that it broadcasts against `basis`;
    both are converted to float arrays and the product is summed over axis 0.
    """
    n_terms = coeffs.shape[0]
    weights = np.asarray(coeffs.reshape(n_terms, 1, 1), dtype=float)
    modes = np.asarray(basis, dtype=float)
    return (weights * modes).sum(0)

In [12]:
def double_plot(a, b, title, filename):
    """
    Save a two-panel figure: `a` on a log colour scale (left) and `b` on a
    linear scale (right), each with a colourbar. The figure is closed afterwards.
    """
    fig = plt.figure(figsize=(8, 3))
    plt.suptitle(title)

    panels = ((a, {'norm': 'log'}), (b, {}))
    for position, (image, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image, **imshow_kwargs)
        plt.colorbar()

    plt.savefig(filename)
    plt.close(fig)

In [13]:
def _symmetric_limit(values):
    """Largest finite absolute value in `values` (NaN and +/-inf are ignored)."""
    magnitudes = np.abs(values)
    return np.nanmax(np.where(np.isfinite(magnitudes), magnitudes, np.nan))


def _residual_panel(fig, index, image, panel_title, extent, points, colorbar_label=None, **imshow_kwargs):
    """Draw panel `index` of the 2x2 residual figure: image, axis labels, colourbar, markers."""
    ax = fig.add_subplot(2, 2, index)
    im = ax.imshow(image, extent=extent, **imshow_kwargs)
    # Raw strings: "\D" is not a valid escape sequence in a normal string literal.
    ax.set_xlabel(r"$\Delta$x, arcsec", fontsize=20)
    ax.set_ylabel(r"$\Delta$y, arcsec", fontsize=20)
    colorbar = fig.colorbar(im, ax=ax)
    if colorbar_label is not None:
        colorbar.set_label(colorbar_label)
    ax.set_title(panel_title)
    for x, y, s in points:
        ax.plot(x, y, marker=s, color="red")


def plot_residuals(a, b, title, filename, points = None, text = None, pixscale = 1):
    """
    Save a 2x2 comparison figure of data `a` against model `b`.

    Panels: data (log scale), model (log scale), residual `a - b`, and relative
    residual `(a - b) / sqrt(b)`. Both residual panels use a colour range that is
    symmetric about zero.

    Parameters
    ----------
    a, b : 2D arrays of identical shape
    title : figure title
    filename : path handed to `savefig`
    points : optional iterable of (x, y, marker) drawn in red on every panel
    text : optional caption placed at figure coordinates (0.25, 0.01)
    pixscale : size of one pixel in axis units; the axes span the image
        symmetrically about zero
    """
    # Materialise once: the markers are drawn on all four panels, which would
    # exhaust a one-shot iterable after the first panel.
    points = [] if points is None else list(points)

    # imshow's extent is (left, right, bottom, top): the x range follows the
    # number of columns and the y range the number of rows.
    n_rows, n_cols = np.shape(a)[:2]
    half_width = n_cols / 2 * pixscale
    half_height = n_rows / 2 * pixscale
    extent = [-half_width, half_width, -half_height, half_height]

    residual = a - b
    # Pixels where the model is zero or negative produce inf/NaN here; they are
    # excluded from the colour limits and drawn with the colormap's "bad" colour.
    rel_resid = residual / b**0.5

    vlim = _symmetric_limit(residual)
    vlim_r = _symmetric_limit(rel_resid)

    fig = plt.figure(figsize=(12, 10))
    fig.suptitle(title)

    _residual_panel(fig, 1, a, "Data", extent, points,
                    colorbar_label="Normalized flux", norm='log', cmap=cmap_inf)
    _residual_panel(fig, 2, b, "Model", extent, points,
                    colorbar_label="Normalized flux", norm='log', cmap=cmap_inf)
    _residual_panel(fig, 3, residual, "Residual", extent, points,
                    vmin=-vlim, vmax=vlim, cmap=cmap_sei)
    _residual_panel(fig, 4, rel_resid, "Relative residual", extent, points,
                    vmin=-vlim_r, vmax=vlim_r, cmap=cmap_sei)

    if text is not None:
        fig.text(0.25, 0.01, text)

    fig.savefig(filename)
    plt.close(fig)

In [14]:
def plot_chains(values, *, keys = None, filename):
    """
    Save a ChainConsumer plot of the sampled chains.

    Parameters
    ----------
    values : dict of sample arrays keyed by parameter name
    keys : optional list of names; when given only those entries are plotted
    filename : path handed to `plt.savefig`
    """
    selected = values if keys is None else {k: values[k] for k in keys}

    consumer = cc.ChainConsumer()
    consumer.add_chain(selected)
    consumer.configure_truth(color='r', ls=":", alpha=0.8)
    consumer.configure(
        serif=True,
        shade=True,
        bar_shade=True,
        shade_alpha=0.2,
        spacing=1.,
        max_ticks=3,
    )

    fig = consumer.plotter.plot()
    fig.set_size_inches((24, 24))

    plt.savefig(filename)
    plt.close(fig)

## Run All

In [40]:
# os.makedirs() and open() do not expand '~': without expanduser the results would
# land in a literal "./~/JWST/..." directory under the current working directory.
base_path = os.path.expanduser('~/JWST/data/JWSTGO1902/calints/')
ouput_path = os.path.expanduser('~/JWST/results_single/')
pickle_dir = ouput_path

os.makedirs(ouput_path, exist_ok=True)
os.makedirs(pickle_dir, exist_ok=True)


# ((target, detector, FILTER header value), path) for every exposure to process.
files_all = [
    (('HD 135067', 'NRCB1', 'F212N'), '../data/JWSTGO1902/calints/jw01902002001_0210e_00001_nrcb1_calints.fits'),
    (('HD 135067', 'NRCB1', 'F187N'), '../data/JWSTGO1902/calints/jw01902002001_02108_00001_nrcb1_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902002001_02108_00001_nrcblong_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902002001_0210e_00001_nrcblong_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F322W2'), '../data/JWSTGO1902/calints/jw01902002001_02102_00001_nrcblong_calints.fits'),
    (('HD 135067', 'NRCB1', 'F150W2'), '../data/JWSTGO1902/calints/jw01902002001_02102_00001_nrcb1_calints.fits'),
    (('HD 136164', 'NRCB1', 'F212N'), '../data/JWSTGO1902/calints/jw01902001001_0210e_00001_nrcb1_calints.fits'),
    (('HD 136164', 'NRCB1', 'F187N'), '../data/JWSTGO1902/calints/jw01902001001_02108_00001_nrcb1_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902001001_0210e_00001_nrcblong_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902001001_02108_00001_nrcblong_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F322W2'), '../data/JWSTGO1902/calints/jw01902001001_02102_00001_nrcblong_calints.fits'),
    (('HD 136164', 'NRCB1', 'F150W2'), '../data/JWSTGO1902/calints/jw01902001001_02102_00001_nrcb1_calints.fits'),
]

# Only the paths are needed by the processing loops.
files = [f for k, f in files_all]


In [41]:
# Base font size for all figures produced below.
plt.rcParams['font.size'] = 10

In [45]:
# Sanity check: print the header-derived key and the effective filter of every file.
for fits_path in files:

    webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err = read_fits(fits_path)

    # When PUPIL is "CLEAR" the FILTER keyword is used, otherwise the PUPIL value.
    if webbpsf_header["PUPIL"] == "CLEAR":
        filter_name = webbpsf_header["FILTER"]
    else:
        filter_name = webbpsf_header["PUPIL"]

    key_names = ["TARGNAME", "DETECTOR", "FILTER", "PUPIL" ]
    target_key = ' '.join(webbpsf_header[k] for k in key_names)

    print(target_key, filter_name)

HD 135067 NRCB1 F212N CLEAR F212N
HD 135067 NRCB1 F187N CLEAR F187N
HD 135067 NRCBLONG F444W F405N F405N
HD 135067 NRCBLONG F444W F470N F470N
HD 135067 NRCBLONG F322W2 F323N F323N
HD 135067 NRCB1 F150W2 F164N F164N
HD 136164 NRCB1 F212N CLEAR F212N
HD 136164 NRCB1 F187N CLEAR F187N
HD 136164 NRCBLONG F444W F470N F470N
HD 136164 NRCBLONG F444W F405N F405N
HD 136164 NRCBLONG F322W2 F323N F323N
HD 136164 NRCB1 F150W2 F164N F164N


In [46]:
# The HMC stage is switched off by default: every iteration stops after the grid-search
# plot. Set to True to also run sampling, HMC residuals and the aberration plot.
RUN_HMC = False

for fits_path in files:

    webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err = read_fits(fits_path)

    npix = webbpsf_data_star.shape[0]
    webbpsf_data_flux = np.sum(webbpsf_data_star)

    # Masked and cropped image; every model image compared to it must be cropped too.
    webbpsf_data_masked = crop(webbpsf_data_star * webbpsf_data_mask)

    key_names = ["TARGNAME", "DETECTOR", "FILTER", "PUPIL" ]
    target_key = ' '.join(list([webbpsf_header[k] for k in key_names]))

    # Cache files: grid-search result, HMC point estimate, raw HMC samples.
    grid_search_filename = os.path.join(pickle_dir, target_key + " GRID SEARCH DATA.bin")
    hmc_filename = os.path.join(pickle_dir, target_key + " HMC DATA.bin")
    pickle_file = os.path.join(pickle_dir, target_key + ".bin")

    print('Creating model...', end="")

    spectrum = SourceSpectrum(BlackBodyNorm1D, temperature=7500)

    telescope = get_telescope(npix,
                            webbpsf_header["FILTER"] if webbpsf_header["PUPIL"] == "CLEAR" else webbpsf_header["PUPIL"],
                            webbpsf_header['DETECTOR'].replace('LONG', '5'),
                            webbpsf_header['APERNAME'],
                            webbpsf_header['DATE-BEG'],
                            pixel_mask = webbpsf_data_mask,
                            downsample = 4,
                            flux = webbpsf_data_flux,
                            source=spectrum
                            )

    print("ok")

    print('Grid search...', end="")

    grid_data = None

    if os.path.isfile(grid_search_filename):
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
        with open(grid_search_filename, "rb") as f:
            grid_data = pickle.load(f)
    else:

        pri_xr, pri_yr, tp = search_grid(telescope, norm(webbpsf_data_masked))

        grid_data = {
            'data': norm(webbpsf_data_masked),
            'model': norm(crop(telescope.set(["source.position"], [np.array([pri_xr, pri_yr])]).model())),
            'x_found': pri_xr,
            'y_found': pri_yr
        }

        with open(grid_search_filename, "wb") as f:
            pickle.dump(grid_data, f)

    # Take the position from grid_data so it is defined (and belongs to this file)
    # when the grid search was loaded from the cache instead of being recomputed.
    pri_xr = grid_data['x_found']
    pri_yr = grid_data['y_found']

    plot_residuals(grid_data['data'], grid_data['model'], target_key + " GRID SEARCH RESIDUALS", os.path.join(ouput_path, target_key + ' 2 GRID SEARCH'), pixscale=telescope.psf_pixel_scale)

    print("ok")

    if not RUN_HMC:
        continue

    print('HMC...')

    values_out = None

    if os.path.isfile(pickle_file):
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
        with open(pickle_file, "rb") as f:
            values_out = pickle.load(f)
    else:
        # Crop the uncertainty map exactly like the data so that it has one entry per
        # fitted pixel; uncropped it cannot be broadcast against the cropped model image.
        # NOTE(review): the square root treats the summed error extension as a variance - confirm.
        std = crop(np.sqrt(webbpsf_data_err * webbpsf_data_mask)).flatten()
        # Masked (zero) and invalid entries get the mean value instead of a zero/NaN scale.
        std = np.where((std == 0) | np.isnan(std), np.nanmean(std), std)

        coord_range = 3 * telescope.psf_pixel_scale
        log_flux_range = 1

        x0 = toArcsec(pri_xr)
        y0 = toArcsec(pri_yr)
        log_flux = np.log(webbpsf_data_flux)

        print(f'x0: {x0}')
        print(f'y0: {y0}')
        print(f'coord_range: {coord_range}')
        print(f'log_flux: {log_flux}')

        # Start the sampler at the grid-search position with all free phases at zero.
        initial_values = {
            "x_arcsec": x0,
            "y_arcsec": y0,
            "log_flux": np.float64(log_flux)
        }

        for z, i in itertools.product(range(zernike_terms), range(n_mirrors)):
            if (z, i) not in fixed_mirrors:
                initial_values[f'p{i*zernike_terms + z}'] = 0

        # A single chain is run; pass num_chains=jax.device_count() to use all devices.
        sampler = npy.infer.MCMC(
            npy.infer.NUTS(psf_model, init_strategy=npy.infer.init_to_value(values=initial_values), dense_mass=True),
            num_warmup=2000,
            num_samples=2000,
            progress_bar=True
        )

        sampler.run(jr.PRNGKey(0), webbpsf_data_masked, std, telescope, x0 = x0, y0 = y0, coord_range = coord_range, log_flux = log_flux, log_flux_range = log_flux_range)

        values_out = sampler.get_samples()

        with open(pickle_file, "wb") as f:
            pickle.dump(values_out, f)

        print("ok")

        print('Plotting chains...', end="")

        plot_chains(values_out, filename = os.path.join(ouput_path, target_key + " 4 CHAINS"))
        plot_chains(values_out, keys = ['x_arcsec', 'y_arcsec', 'log_flux'], filename = os.path.join(ouput_path, target_key + " 5 CHAINS SHORT"))

        print("ok")

    print('Getting residuals...', end="")

    hmc_data = None

    if os.path.isfile(hmc_filename):
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
        with open(hmc_filename, "rb") as f:
            hmc_data = pickle.load(f)
    else:
        x_predicted, y_predicted, flux_predicted, coeffs_pred_all = get_results(values_out)

        psf_found = telescope.set(
            [
                'source.position',
                'pupil.coefficients',
                'source.flux'
            ],
            [
                np.array([x_predicted, y_predicted]),
                coeffs_pred_all,
                flux_predicted
            ]).model()


        hmc_data = {
            'data': webbpsf_data_masked,
            # Cropped like the data so the residual images have matching shapes.
            'model': crop(psf_found),
            'x_found': x_predicted,
            'y_found': y_predicted,
            'flux_predicted': flux_predicted,
            'coeffs_predicted': coeffs_pred_all
        }

        with open(hmc_filename, "wb") as f:
            pickle.dump(hmc_data, f)

    plot_residuals(hmc_data['data'], hmc_data['model'], target_key + " HMC RESIDUALS", os.path.join(ouput_path, target_key + ' 3 RESIDUALS'), pixscale=telescope.psf_pixel_scale)

    print("ok")

    print('Getting aberrations...', end="")

    # Recovered phase map: coefficient-weighted sum of the pupil basis.
    rec_aberrations = get_aber(hmc_data['coeffs_predicted'], telescope.pupil.basis)

    fig = plt.figure()
    plt.title("Zernike OPD")
    plt.imshow((rec_aberrations)*telescope.pupil.transmission)

    plt.colorbar()

    plt.savefig(os.path.join(ouput_path, target_key + " 6 ABERRATIONS"))
    plt.close(fig)

    print("ok")
    print('---------------------------------------------------------------------')
    print()




Creating model...
MAST OPD query around UTC: 2023-02-19T15:18:08.320
                        MJD: 59994.6375962963

OPD immediately preceding the given datetime:
	URI:	 mast:JWST/product/R2023021903-NRCA3_FP1-1.fits
	Date (MJD):	 59994.0235
	Delta time:	 -0.6141 days

OPD immediately following the given datetime:
	URI:	 mast:JWST/product/R2023022103-NRCA3_FP1-1.fits
	Date (MJD):	 59995.9079
	Delta time:	 1.2703 days
User requested choosing OPD time closest in time to 2023-02-19T15:18:08.320, which is R2023021903-NRCA3_FP1-1.fits, delta time -0.614 days
Importing and format-converting OPD from /root/JWST/webbpsf-data/MAST_JWST_WSS_OPDs/R2023021903-NRCA3_FP1-1.fits
Backing out SI WFE and OTE field dependence at the WF sensing field point
ok
Grid search...ok
Creating model...
MAST OPD query around UTC: 2023-02-19T14:56:06.414
                        MJD: 59994.622296458336

OPD immediately preceding the given datetime:
	URI:	 mast:JWST/product/R2023021903-NRCA3_FP1-1.fits
	Date (MJD):	 59