In [2]:
import os

# Set CPU count for numpyro multi-chain multi-thread
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import pickle
import itertools

import jax
# Enable x64 for JAX
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as np
import jax.random as jr
from jax import jit, lax

import numpy as onp
from matplotlib import pyplot as plt
from astropy import units as u
from astropy.io import fits
from scipy.ndimage import binary_dilation

import zodiax as zdx
import dLux as dl
from dLux.utils import arcsec2rad as toRad
from dLux.utils import rad2arcsec as toArcsec
from dLuxWebbpsf.utils import grid_search, get_pixel_positions
from dLuxWebbpsf import NIRCam

import numpyro as npy
import numpyro.distributions as dist

import chainconsumer as cc

def norm(data):
    return data/data.sum()

def sepToXy(x, y, separation, angle):
    dx = separation * np.sin(angle) / 2
    dy = separation * np.cos(angle) / 2
    return (x + dx, y + dy), (x - dx, y - dy)

def xyToSep(x1, y1, x2, y2):
    x0 = (x2 + x1) / 2
    y0 = (y2 + y1) / 2

    dx = (-x2 + x1)
    dy = (-y2 + y1)

    sepr = np.sqrt(dx**2 + dy**2)
    angle = np.arctan2(dx, dy)
    
    return (x0, y0, sepr, angle)



In [3]:
plt.rcParams["text.usetex"] = 'true'
plt.rcParams["font.family"] = "serif"
plt.rcParams['figure.dpi'] = 120

plt.ioff()

<contextlib.ExitStack at 0x7f998cb848b0>

### Get data from FITS files

In [4]:
def read_fits(path):
    hdul = fits.open(path)
    #hdul.info()

    webbpsf_header = hdul[0].header
    webbpsf_data_star = np.array(sum(hdul[1].data))
    webbpsf_data_err = np.array(sum(hdul[2].data))
    webbpsf_data_dq = np.array(sum(hdul[3].data))

    hdul.close()

    # cleanup data
    webbpsf_data_star = np.where(np.isnan(webbpsf_data_star), 0, webbpsf_data_star)

    npix_x = webbpsf_data_star.shape[0]
    npix_y = webbpsf_data_star.shape[1]
    xx, yy = np.array(np.meshgrid(np.arange(npix_x), np.arange(npix_y)))

    bri_y, bri_x = np.unravel_index(webbpsf_data_star.argmax(), webbpsf_data_star.shape)

    msize = 2
    mask_x = (xx >= (bri_x - msize)) & (xx <= (bri_x + msize))
    mask_y = (yy >= bri_y - msize) & (yy <= bri_y + msize)

    webbpsf_data_mask = np.where((mask_x & mask_y), 0, 1)
    webbpsf_data_mask = np.logical_not(binary_dilation(np.logical_not(webbpsf_data_mask))).astype(webbpsf_data_mask.dtype)
    webbpsf_data_mask *= (webbpsf_data_dq == 0) & ~((xx == 9) & (yy == 31)) # 9,31 is a hot pixel

    return webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err

### Create model

In [23]:
zernike_terms=3
n_mirrors = 18
n_coeffs = n_mirrors * zernike_terms

def get_telescope(npix, filter, detector, aperture, opd_date, *, pixel_mask = None, flux = None, contrast = 100):
    webb_options = {
        'jitter': None,   # jitter model name or None
        'jitter_sigma': 0.000,  # in arcsec per axis, default 0.007
        'add_distortion': False,
    }

    telescope = NIRCam(
                    filter = filter,
                    detector = detector,
                    aperture = aperture,
                    fft_oversample=1,
                    detector_oversample=4,
                    wavefront_downsample=8,
                    fov_pixels=npix,
                    options=webb_options,
                    phase_retrieval_terms = zernike_terms,
                    flux = flux,
                    load_opd_date=opd_date
                )
    
    spectrum = dl.Spectrum(telescope.filter_wavelengths, telescope.filter_weights)
    source = dl.BinarySource(telescope.filter_wavelengths, spectrum=spectrum, contrast=contrast)
    telescope = telescope.set(["source"], [source])

    if pixel_mask is not None:
        response_layer = dl.ApplyPixelResponse(pixel_mask)
        telescope.detector.layers["pixel_mask"] = response_layer

    return telescope

In [6]:
@zdx.filter_jit
def likelihood(model, x, y, separation, angle, expected_data):
    
    model_data_n = model.set(['source.position', 'source.separation', 'source.position_angle'],
                             [np.array([x, y]), separation, angle]).model()
    
    model_data = norm(model_data_n)
    residuals = expected_data - model_data
    # calculate the chi-squared value
    chi_squared = np.sum(residuals**2)
    # calculate the likelihood
    likelihood =  -0.5 * chi_squared
    return likelihood

@zdx.filter_jit
def likelihood_fast(telescope, x, y, separation, angle, expected_data):
    likelihood_c = lambda args : likelihood(telescope, args[0], args[1], args[2], args[3], expected_data)
    args = np.stack((x, y, separation, angle), axis=-1)
    return lax.map(likelihood_c, args)

In [40]:
def search_grid(model, data, *, iters = 20,
                pri_xr = None, pri_yr = None, sec_xr = None, sec_yr = None,
                pri_width = toRad(1), sec_width = toRad(1), plot = False
):

    @zdx.filter_jit
    def get_single(xr, yr):
        return likelihood_fast(model, xr, yr, np.zeros_like(xr), np.zeros_like(xr), data)

    @zdx.filter_jit
    def get_primary(xr, yr):
        #if sec_xr is None or sec_yr is None:
        #    return likelihood_fast(model, xr, yr, np.zeros_like(xr), np.zeros_like(xr), data)
        x, y, s, a =  xyToSep(xr, yr, sec_xr, sec_yr)
        return likelihood_fast(model, x, y, s, a, data)

    @zdx.filter_jit
    def get_secondary(xr, yr):
        x, y, s, a =  xyToSep(pri_xr, pri_yr, xr, yr)
        return likelihood_fast(model, x, y, s, a, data)
    
    steps = 30

    i = 0

    pri_xr, pri_yr, tp = grid_search(get_single, pri_xr, pri_yr, pri_width, steps, niter = 3)

    if sec_xr is None:
        sec_xr = pri_xr
        
    if sec_yr is None:
        sec_yr = pri_yr

    sec_xr, sec_yr, ts = grid_search(get_secondary, sec_xr, sec_yr, sec_width, steps, niter = 1)

    plt.show()

    while i < iters:
        print(f'Grid search step {i}')

        pri_xr, pri_yr, tp = grid_search(get_primary, pri_xr, pri_yr, pri_width, steps, niter = 1)
        
        print(f'Grid search step {i} - primary')
        sec_xr, sec_yr, ts = grid_search(get_secondary, sec_xr, sec_yr, sec_width, steps, niter = 1)
        
        print(f'Grid search step {i} - secondary')

        if plot:
            plt.subplot(1, 2, 1)
            plt.imshow(tp[0]['likelihoods'], origin='lower')
            plt.subplot(1, 2, 2)
            plt.imshow(ts[0]['likelihoods'], origin='lower')
            plt.show()
        
        pri_width = pri_width * 2. / 3.
        sec_width = sec_width * 4. / 5.
        
        i += 1

    return pri_xr, pri_yr, sec_xr, sec_yr

## Step 3: Set up HMC parameters and distributions.

In [8]:
phase_scale = 1e-6

fixed_mirrors = {
    (0,0): 0,
    (1,0): 0,
    (2,0): 0,
}

parameters = ['source.position',
              'source.separation',
              'source.position_angle',
              'source.mean_flux',
              'source.contrast',
              'pupil.coefficients'
             ]

def psf_model(data, data_err, model, *, x0, y0, s0, a0,
              coord_range, smin, log_flux, log_flux_range):
    """
    Define the numpyro function
    """
    
    x_sample = npy.sample("x_arcsec", dist.Uniform(x0 - coord_range, x0 + coord_range))
    y_sample = npy.sample("y_arcsec", dist.Uniform(y0 - coord_range, y0 + coord_range))

    separation_sample = npy.sample("sep_arcsec", dist.Uniform(smin, s0 + coord_range))
    angle_sample = npy.sample("angle", dist.Uniform(a0 - np.pi/8, a0 + np.pi/8))

    phases_data = [0] * n_coeffs
    for i, z in itertools.product(range(n_mirrors), range(zernike_terms)):
        coeff_index = i*zernike_terms + z
        if (z, i) in fixed_mirrors:
            phases_data[coeff_index] = fixed_mirrors[(z,i)]
        else:
            sample = npy.sample(f'p{coeff_index}', dist.Uniform(-1,1))
            phases_data[coeff_index] = sample * phase_scale
                            
    flux_sample  = npy.sample("log_flux", dist.Uniform(log_flux - log_flux_range, log_flux + log_flux_range))
    contrast_sample = npy.sample("contrast", dist.Uniform(100, 10000))

    values = [
        np.array([toRad(x_sample), toRad(y_sample)]),
        toRad(separation_sample),
        angle_sample,
        np.exp(flux_sample),
        contrast_sample,
        np.array(phases_data)
    ]
    
    with npy.plate("data", len(data.flatten())):
        model_data = (model.set(parameters, values).model()).flatten()        
        dist_model = dist.Normal(model_data, data_err)
        npy.sample("psf", dist_model, obs=data.flatten())
        

In [9]:
def get_results(values_out):
    
    mean_vals = {k: np.mean(v) for k, v in values_out.items()}

    x_pred = toRad(mean_vals["x_arcsec"])
    y_pred = toRad(mean_vals["y_arcsec"])
    sep_pred = toRad(mean_vals["sep_arcsec"])
    angle_pred = mean_vals["angle"]
    contrast_pred = mean_vals["contrast"]
    flux_pred = np.exp(mean_vals["log_flux"])

    coeffs_pred_all = [0] * n_coeffs
    for i, z in itertools.product(range(n_mirrors), range(zernike_terms)):
        coeff_index = i*zernike_terms + z
        if (z, i) in fixed_mirrors:
            coeffs_pred_all[coeff_index] = fixed_mirrors[(z,i)]
        else:
            coeffs_pred_all[coeff_index] = mean_vals[f'p{coeff_index}'] * phase_scale

    coeffs_pred_all = np.array(coeffs_pred_all)

    return x_pred, y_pred, sep_pred, angle_pred, flux_pred, contrast_pred, coeffs_pred_all

## Step 4: Make plots.

In [10]:
def get_aber(coeffs, basis):
    s_coeffs = np.asarray(coeffs.reshape(coeffs.shape[0], 1, 1), dtype=float)
    s_basis = np.asarray(basis, dtype=float)
    result = s_coeffs * s_basis
    return result.sum(0)

In [11]:
def double_plot(a, b, title, filename):
    fig = plt.figure(figsize=(8, 3))
    plt.suptitle(title)
    
    plt.subplot(1, 2, 1)
    plt.imshow(a, norm='log')
    plt.colorbar()

    plt.subplot(1, 2, 2)
    plt.imshow(b)
    plt.colorbar()

    plt.savefig(filename)
    plt.close(fig)

In [63]:
def plot_residuals(a, b, title, filename, points = None, text = None):
    fig = plt.figure(figsize=(12, 10))

    plt.suptitle(title)

    plt.subplot(2, 2, 1)

    plt.imshow(a, norm='log')
    plt.colorbar()
    plt.title("Data")

    if (points is not None):
        for x, y, s in points:
            plt.plot(x, y, marker=s, color="red")

    plt.subplot(2, 2, 2)

    plt.imshow(b, norm='log')
    plt.colorbar()
    plt.title("Model")

    if (points is not None):
        for x, y, s in points:
            plt.plot(x, y, marker=s, color="red")

    delta = a - b
    vlim = np.nanmax(np.abs(delta))

    plt.subplot(2, 2, 3)
    plt.imshow(a - b, vmin=-vlim, vmax=vlim)
    plt.colorbar()
    plt.title("Residual")

    if (points is not None):
        for x, y, s in points:
            plt.plot(x, y, marker=s, color="red")
    
    residual = a - b
    rel_resid = residual / a**0.5

    plt.subplot(2, 2, 4)
    plt.imshow(rel_resid)
    plt.imshow(rel_resid, vmin=-(np.nanmax(rel_resid)*2))
    plt.colorbar()
    plt.title("Relative residual ((data - model)/data**0.5)")
    
    if points is not None:
        for x, y, s in points:
            plt.plot(x, y, marker=s, color="red")

    if text is not None:
        plt.figtext(0.25, 0.01, text)

    plt.savefig(filename)
    plt.close(fig)

In [15]:
def plot_chains(values, *, keys = None, filename):

    if keys is not None:
        values = {k: values[k] for k in keys}

    chain = cc.ChainConsumer()
    chain.add_chain(values)
    chain.configure_truth(color='r', ls=":", alpha=0.8)
    chain.configure(serif=True, shade=True, bar_shade=True, shade_alpha=0.2, spacing=1., max_ticks=3)

    fig = chain.plotter.plot()
    fig.set_size_inches((24, 24))

    plt.savefig(filename)
    plt.close(fig)

## Run All

In [42]:
fits_path = '../data/JWSTGO1902/calints/jw01902002001_0210e_00001_nrcb1_calints.fits'
ouput_path = '~/JWST/results_binary/'
pickle_dir = ouput_path

os.makedirs(ouput_path, exist_ok=True)
os.makedirs(pickle_dir, exist_ok=True)

webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err = read_fits(fits_path)

webbpsf_data_flux = np.sum(webbpsf_data_star)

webbpsf_data_masked = webbpsf_data_star * webbpsf_data_mask

npix = webbpsf_data_star.shape[0]


key_names = ["TARGNAME", "DETECTOR", "FILTER", "PUPIL" ]
target_key = ' '.join(list([webbpsf_header[k] for k in key_names]))

print(target_key)

double_plot(webbpsf_data_masked, webbpsf_data_mask, target_key, os.path.join(ouput_path, target_key + ' 1 DATA'))

print('Creating model...', end="")

telescope = get_telescope(npix,
                        webbpsf_header['FILTER'],
                        webbpsf_header['DETECTOR'].replace('LONG', '5'),
                        webbpsf_header['APERNAME'],
                        webbpsf_header['DATE-BEG'],
                        pixel_mask = webbpsf_data_mask,
                        flux = webbpsf_data_flux,
                        contrast = 1000)

pixscale = toRad(telescope.psf_pixel_scale)

def toPix(value):
    ret = (value / pixscale) + npix/2 - 0.5
    return ret

def fromPix(value):
    ret = ((value + 0.5) - npix/2) * pixscale
    return ret

# double_plot(telescope.model(), webbpsf_data_mask, target_key + " SAMPLE", os.path.join(ouput_path, target_key + ' 9 SAMPLE MODEL'))

print("ok")

HD 135067 NRCB1 F212N CLEAR
Creating model...
MAST OPD query around UTC: 2023-02-19T15:18:08.320
                        MJD: 59994.6375962963

OPD immediately preceding the given datetime:
	URI:	 mast:JWST/product/R2023021903-NRCA3_FP1-1.fits
	Date (MJD):	 59994.0235
	Delta time:	 -0.6141 days

OPD immediately following the given datetime:
	URI:	 mast:JWST/product/R2023022103-NRCA3_FP1-1.fits
	Date (MJD):	 59995.9079
	Delta time:	 1.2703 days
User requested choosing OPD time closest in time to 2023-02-19T15:18:08.320, which is R2023021903-NRCA3_FP1-1.fits, delta time -0.614 days
Importing and format-converting OPD from /root/JWST/webbpsf-data/MAST_JWST_WSS_OPDs/R2023021903-NRCA3_FP1-1.fits
Backing out SI WFE and OTE field dependence at the WF sensing field point
ok


In [None]:

print('Grid search...', end="")

pri_xr, pri_yr, sec_xr, sec_yr = search_grid(telescope, norm(webbpsf_data_masked),
                                             pri_width = 32 * pixscale,
                                             sec_xr = fromPix(48), sec_yr=(-fromPix(38)),
                                             sec_width = 10 * pixscale,
                                             iters = 10, plot=True)

print("ok")

In [48]:
mean_xr, mean_yr, separation_rec, angle_rec =  xyToSep(pri_xr, pri_yr, sec_xr, sec_yr)

recovered_data = telescope.set(['source.position', 'source.separation', 'source.position_angle'],
                                [np.array([mean_xr, mean_yr]), separation_rec, angle_rec]).model()

star_points = [
    (toPix(pri_xr), toPix(-pri_yr), '*'),
    (toPix(sec_xr), toPix(-sec_yr), '.')
]

plot_residuals(norm(webbpsf_data_masked), norm(recovered_data),
                target_key + " GRID SEARCH RESIDUALS", os.path.join(ouput_path, target_key + ' 2 GRID SEARCH'),
                points = star_points,
                text = f'separation: {toArcsec(separation_rec)} arcsec, angle:{angle_rec} rad')

In [49]:
print('HMC...')
    
std = np.sqrt(webbpsf_data_err * webbpsf_data_mask).flatten()
std = np.where((std == 0) | np.isnan(std), np.nanmean(std), std)

coord_range = 3 * telescope.psf_pixel_scale
log_flux_range = 1

x0 = toArcsec(mean_xr)
y0 = toArcsec(mean_yr)
s0 = toArcsec(separation_rec)
log_flux = np.log(webbpsf_data_flux)

smin = max(0, s0 - coord_range)

print(f'x0: {x0}')
print(f'y0: {y0}')
print(f's0: {s0}')
print(f'a0: {angle_rec}')
print(f'coord_range: {coord_range}')
print(f'log_flux: {log_flux}')


HMC...
x0: 0.29017584883069636
y0: -0.1665159754268638
s0: 0.4341662467154193
a0: -1.2956704184970418
coord_range: 0.09189801
log_flux: 18.76310157775879


In [50]:

initial_values = {
    "x_arcsec": x0,
    "y_arcsec": y0,
    "sep_arcsec": s0,
    "angle": angle_rec,
    "log_flux": np.float64(log_flux)
}

for z, i in itertools.product(range(zernike_terms), range(n_mirrors)):
    if (z, i) not in fixed_mirrors:
        initial_values[f'p{i*zernike_terms + z}'] = 0

sampler = npy.infer.MCMC(
    npy.infer.NUTS(psf_model, init_strategy=npy.infer.init_to_value(values=initial_values), dense_mass=True),
    num_warmup=2000,
    num_samples=2000,
    # num_chains=jax.device_count(),
    progress_bar=True
)

sampler.run(jr.PRNGKey(0), webbpsf_data_masked, std, telescope,
            x0 = x0, y0 = y0, s0 = s0, a0 = angle_rec, log_flux = log_flux,
            coord_range = coord_range, smin = smin, log_flux_range = log_flux_range)

#sampler.run(jr.PRNGKey(0), webbpsf_data_star, std, telescope)

#sampler.print_summary()
values_out = sampler.get_samples()

print("ok")



sample: 100%|██████████| 4000/4000 [3:34:46<00:00,  3.22s/it, 31 steps of size 1.11e-02. acc. prob=0.90]    


ok


In [52]:

print('Pickling...', end="")

pickle_file = target_key + ".bin"
pickle_path = os.path.join(pickle_dir, pickle_file)

with open(pickle_path, "wb") as f:
    pickle.dump(values_out, f)

print("ok")

print('Plotting chains...', end="")

plot_chains(values_out, filename = os.path.join(ouput_path, target_key + " 4 CHAINS"))
plot_chains(values_out, keys = ['x_arcsec', 'y_arcsec', 'log_flux', 'contrast'], filename = os.path.join(ouput_path, target_key + " 5 CHAINS SHORT"))

print("ok")



Pickling...ok
Plotting chains...



ok
Getting residuals...ok
Getting aberrations...ok


In [64]:
print('Getting residuals...', end="")

x_pred, y_pred, sep_pred, angle_pred, flux_pred, contrast_pred, coeffs_pred_all = get_results(values_out)

psf_found = telescope.set(
    [
        'source.position',
        'source.separation',
        'source.position_angle',
        'source.mean_flux',
        'source.contrast',
        'pupil.coefficients',
    ],
    [
        np.array([x_pred, y_pred]),
        sep_pred,
        angle_pred,
        flux_pred,
        contrast_pred,
        coeffs_pred_all
    ]).model()

(pri_xp, pri_yp), (sec_xp, sec_yp) = sepToXy(x_pred, y_pred, sep_pred, angle_pred)

star_points = [
    (toPix(pri_xp), toPix(-pri_yp), '*'),
    (toPix(sec_xp), toPix(-sec_yp), '.')
]

plot_residuals(webbpsf_data_masked, psf_found,
                target_key + " HMC RESIDUALS", os.path.join(ouput_path, target_key + ' 3 RESIDUALS'),
                points = star_points,
                text = f'separation: {toArcsec(sep_pred)} arcsec, angle:{angle_pred} rad, contrast: {contrast_pred}')

print("ok")

print('Getting aberrations...', end="")

rec_aberrations = get_aber(coeffs_pred_all, telescope.pupil.basis)

fig = plt.figure()
plt.title("Zernike OPD")
plt.imshow((rec_aberrations)*telescope.pupil.transmission)

plt.colorbar()

plt.savefig(os.path.join(ouput_path, target_key + " 6 ABERRATIONS"))
plt.close(fig)

print("ok")

Getting residuals...ok
Getting aberrations...ok


In [13]:
base_path = '../data/JWSTGO1902/calints/' #concat/'
ouput_path = './results_binary/'
pickle_dir = ouput_path

os.makedirs(ouput_path, exist_ok=True)
os.makedirs(pickle_dir, exist_ok=True)


files_all = [
    # (('HD 136164', 'NRCB1', 'F212N'), '../data/JWSTGO1902/calints/jw01902001001_0210e_00001_nrcb1_calints.fits'),
    (('HD 136164', 'NRCB1', 'F187N'), '../data/JWSTGO1902/calints/jw01902001001_02108_00001_nrcb1_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902001001_0210e_00001_nrcblong_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902001001_02108_00001_nrcblong_calints.fits'),
    (('HD 136164', 'NRCBLONG', 'F322W2'), '../data/JWSTGO1902/calints/jw01902001001_02102_00001_nrcblong_calints.fits'),
    #(('HD 136164', 'NRCB1', 'F150W2'), '../data/JWSTGO1902/calints/jw01902001001_02102_00001_nrcb1_calints.fits'),
    (('HD 135067', 'NRCB1', 'F187N'), '../data/JWSTGO1902/calints/jw01902002001_02108_00001_nrcb1_calints.fits'),
    (('HD 135067', 'NRCB1', 'F212N'), '../data/JWSTGO1902/calints/jw01902002001_0210e_00001_nrcb1_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902002001_02108_00001_nrcblong_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F444W'), '../data/JWSTGO1902/calints/jw01902002001_0210e_00001_nrcblong_calints.fits'),
    (('HD 135067', 'NRCBLONG', 'F322W2'), '../data/JWSTGO1902/calints/jw01902002001_02102_00001_nrcblong_calints.fits'),
    #(('HD 135067', 'NRCB1', 'F150W2'), '../data/JWSTGO1902/calints/jw01902002001_02102_00001_nrcb1_calints.fits'),
]

files = [f for k, f in files_all]


In [None]:
for f in files:

    fits_path = f # base_path + 'jw01902002001_0210e_00001_nrcblong_calints.fits'

    webbpsf_header, webbpsf_data_star, webbpsf_data_mask, webbpsf_data_err = read_fits(fits_path)

    webbpsf_data_flux = np.sum(webbpsf_data_star)
    
    webbpsf_data_masked = webbpsf_data_star * webbpsf_data_mask

    npix = webbpsf_data_star.shape[0]
    

    key_names = ["TARGNAME", "DETECTOR", "FILTER", "PUPIL" ]
    target_key = ' '.join(list([webbpsf_header[k] for k in key_names]))
    
    grid_search_filename = os.path.join(pickle_dir, target_key + " GRID SEARCH DATA.bin")
    hmc_filename = os.path.join(pickle_dir, target_key + " HMC DATA.bin")
    pickle_file = os.path.join(pickle_dir, target_key + ".bin")

    print(target_key)

    double_plot(webbpsf_data_masked, webbpsf_data_mask, target_key, os.path.join(ouput_path, target_key + ' 1 DATA'))

    print('Creating model...', end="")

    telescope = get_telescope(npix,
                            webbpsf_header['FILTER'],
                            webbpsf_header['DETECTOR'].replace('LONG', '5'),
                            webbpsf_header['APERNAME'],
                            webbpsf_header['DATE-BEG'],
                            pixel_mask = webbpsf_data_mask,
                            flux = webbpsf_data_flux)
    
    pixscale = toRad(telescope.psf_pixel_scale)

    def toPix(value):
        ret = (value / pixscale) + npix/2 - 0.5
        return ret

    double_plot(telescope.model(), webbpsf_data_mask, target_key + " SAMPLE", os.path.join(ouput_path, target_key + ' 9 SAMPLE MODEL'))

    print("ok")

    print('Grid search...', end="")
    grid_data = None

    if os.path.isfile(grid_search_filename):
        with open(grid_search_filename, "rb") as f:
            grid_data = pickle.load(f)
    else:

        pri_xr, pri_yr, sec_xr, sec_yr = search_grid(telescope, norm(webbpsf_data_masked))
        mean_xr, mean_yr, separation_rec, angle_rec =  xyToSep(pri_xr, pri_yr, sec_xr, sec_yr)
        
        recovered_data = telescope.set(['source.position', 'source.separation', 'source.position_angle'],
                                   [np.array([mean_xr, mean_yr]), separation_rec, angle_rec]).model()

        grid_data = {
            'data': norm(webbpsf_data_masked),
            'model': norm(telescope.set(["source.position"], [np.array([pri_xr, pri_yr])]).model()),
            'x_found': mean_xr,
            'y_found': mean_yr,
            'sep_found': separation_rec,
            'angle_found': angle_rec
        }

        with open(os.path.join(pickle_dir, target_key + " GRID SEARCH DATA.bin"), "wb") as f:
            pickle.dump(grid_data, f)


    (pri_xr, pri_yr), (sec_xr, sec_yr) =  sepToXy(grid_data['x_found'], grid_data['x_found'], grid_data['sep_found'], grid_data['angle_found'])
     
    star_points = [
        (toPix(pri_xr), toPix(-pri_yr), '*'),
        (toPix(sec_xr), toPix(-sec_yr), '.')
    ]

    plot_residuals(norm(webbpsf_data_masked), norm(recovered_data),
                target_key + " GRID SEARCH RESIDUALS", os.path.join(ouput_path, target_key + ' 2 GRID SEARCH'),
                points = star_points,
                text = f'separation: {toArcsec(separation_rec)} arcsec, angle:{angle_rec} rad')

    print("ok")


    print('HMC...')

    values_out = None

    if os.path.isfile(pickle_file):
        with open(pickle_file, "rb") as f:
            values_out = pickle.load(f)
    else:
        std = np.sqrt(webbpsf_data_err * webbpsf_data_mask).flatten()
        std = np.where((std == 0) | np.isnan(std), np.nanmean(std), std)

        coord_range = 3 * telescope.psf_pixel_scale
        log_flux_range = 1

        x0 = toArcsec(mean_xr)
        y0 = toArcsec(mean_yr)
        s0 = toArcsec(separation_rec)
        log_flux = np.log(webbpsf_data_flux)
        
        smin = max(0, s0 - coord_range)

        print(f'x0: {x0}')
        print(f'y0: {y0}')
        print(f's0: {s0}')
        print(f'a0: {angle_rec}')
        print(f'coord_range: {coord_range}')
        print(f'log_flux: {log_flux}')

        initial_values = {
            "x_arcsec": x0,
            "y_arcsec": y0,
            "sep_arcsec": s0,
            "angle": angle_rec,
            "log_flux": np.float64(log_flux)
        }

        for z, i in itertools.product(range(zernike_terms), range(n_mirrors)):
            if (z, i) not in fixed_mirrors:
                initial_values[f'p{i*zernike_terms + z}'] = 0

        sampler = npy.infer.MCMC(
            npy.infer.NUTS(psf_model, init_strategy=npy.infer.init_to_value(values=initial_values), dense_mass=True),
            num_warmup=2000,
            num_samples=2000,
            # num_chains=jax.device_count(),
            progress_bar=True
        )

        sampler.run(jr.PRNGKey(0), webbpsf_data_masked, std, telescope,
                    x0 = x0, y0 = y0, s0 = s0, a0 = angle_rec, log_flux = log_flux,
                    coord_range = coord_range, smin = smin, log_flux_range = log_flux_range)
        
        #sampler.run(jr.PRNGKey(0), webbpsf_data_star, std, telescope)

        #sampler.print_summary()
        values_out = sampler.get_samples()

        with open(pickle_file, "wb") as f:
            pickle.dump(values_out, f)

        print("ok")

        print('Plotting chains...', end="")

        plot_chains(values_out, filename = os.path.join(ouput_path, target_key + " 4 CHAINS"))
        plot_chains(values_out, keys = ['x_arcsec', 'y_arcsec', 'sep_arcsec', 'angle', 'log_flux', 'contrast'], filename = os.path.join(ouput_path, target_key + " 5 CHAINS SHORT"))

        print("ok")
    
    
    print('Getting residuals...', end="")

    x_pred, y_pred, sep_pred, angle_pred, flux_pred, contrast_pred, coeffs_pred_all = get_results(values_out)

    psf_found = telescope.set(
        [
            'source.position',
            'source.separation',
            'source.position_angle',
            'source.mean_flux',
            'source.contrast',
            'pupil.coefficients',
        ],
        [
            np.array([x_pred, y_pred]),
            sep_pred,
            angle_pred,
            flux_pred,
            contrast_pred,
            coeffs_pred_all
        ]).model()

    (pri_xp, pri_yp), (sec_xp, sec_yp) = sepToXy(x_pred, y_pred, sep_pred, angle_pred)

    star_points = [
        (toPix(pri_xp), toPix(-pri_yp), '*'),
        (toPix(sec_xp), toPix(-sec_yp), '.')
    ]

    plot_residuals(webbpsf_data_masked, psf_found,
                   target_key + " HMC RESIDUALS", os.path.join(ouput_path, target_key + ' 3 RESIDUALS'),
                   points = star_points,
                   text = f'separation: {toArcsec(sep_pred)} arcsec, angle:{angle_pred} rad, contrast: {contrast_pred}')

    print("ok")

    print('Getting aberrations...', end="")

    rec_aberrations = get_aber(coeffs_pred_all, telescope.pupil.basis)

    fig = plt.figure()
    plt.title("Zernike OPD")
    plt.imshow((rec_aberrations)*telescope.pupil.transmission)

    plt.colorbar()

    plt.savefig(os.path.join(ouput_path, target_key + " 6 ABERRATIONS"))
    plt.close(fig)

    print("ok")




HD 136164 NRCB1 F187N CLEAR
Creating model...
MAST OPD query around UTC: 2023-02-19T13:16:12.957
                        MJD: 59994.552927743054

OPD immediately preceding the given datetime:
	URI:	 mast:JWST/product/R2023021903-NRCA3_FP1-1.fits
	Date (MJD):	 59994.0235
	Delta time:	 -0.5295 days

OPD immediately following the given datetime:
	URI:	 mast:JWST/product/R2023022103-NRCA3_FP1-1.fits
	Date (MJD):	 59995.9079
	Delta time:	 1.3550 days
User requested choosing OPD time closest in time to 2023-02-19T13:16:12.957, which is R2023021903-NRCA3_FP1-1.fits, delta time -0.529 days
Importing and format-converting OPD from /Users/uqitroit/Dev/optics/webbpsf-data/MAST_JWST_WSS_OPDs/R2023021903-NRCA3_FP1-1.fits
Backing out SI WFE and OTE field dependence at the WF sensing field point
ok
Grid search...Grid search step 0
Grid search step 0 - primary
Grid search step 0 - secondary
Grid search step 1
Grid search step 1 - primary
Grid search step 1 - secondary
Grid search step 2
Grid search st

sample: 100%|█| 4000/4000 [4:19:36<00:00,  3.89s/it, 63 steps of size 1.06e-02. 


ok
Pickling...ok
Plotting chains...



ok
Getting residuals...ok
Getting aberrations...ok
HD 136164 NRCBLONG F444W F470N


INFO:webbpsf:NIRCam aperture name updated to NRCA1_FULL
INFO:webbpsf:NIRCam pixel scale switched to 0.062909 arcsec/pixel for the long wave channel.
INFO:webbpsf:NIRCam aperture name updated to NRCA5_FULL
INFO:webbpsf:NIRCam aperture name updated to NRCB5_FULL


Creating model...
MAST OPD query around UTC: 2023-02-19T13:36:52.086
                        MJD: 59994.567269513886

OPD immediately preceding the given datetime:
	URI:	 mast:JWST/product/R2023021903-NRCA3_FP1-1.fits
	Date (MJD):	 59994.0235
	Delta time:	 -0.5438 days

OPD immediately following the given datetime:
	URI:	 mast:JWST/product/R2023022103-NRCA3_FP1-1.fits
	Date (MJD):	 59995.9079
	Delta time:	 1.3406 days
User requested choosing OPD time closest in time to 2023-02-19T13:36:52.086, which is R2023021903-NRCA3_FP1-1.fits, delta time -0.544 days
Importing and format-converting OPD from /Users/uqitroit/Dev/optics/webbpsf-data/MAST_JWST_WSS_OPDs/R2023021903-NRCA3_FP1-1.fits
Backing out SI WFE and OTE field dependence at the WF sensing field point


INFO:webbpsf:NIRCam aperture name updated to NRCA1_FULL
INFO:webbpsf:NIRCam aperture name updated to NRCA3_FULL
INFO:webbpsf:NIRCam aperture name updated to NRCA3_FP1
INFO:poppy:OPD from /Users/uqitroit/Dev/optics/webbpsf-data/NIRCam/OPD/wss_target_phase_fp1.fits: Loaded OPD from /Users/uqitroit/Dev/optics/webbpsf-data/NIRCam/OPD/wss_target_phase_fp1.fits
INFO:poppy:No info supplied on amplitude transmission; assuming uniform throughput = 1
INFO:webbpsf:Creating optical system model:
INFO:poppy:Initialized OpticalSystem: JWST+NIRCam
INFO:poppy:JWST Entrance Pupil: Loaded amplitude transmission from /Users/uqitroit/Dev/optics/webbpsf-data/jwst_pupil_RevW_npix1024.fits.gz
INFO:poppy:JWST Entrance Pupil: Loaded OPD from /Users/uqitroit/Dev/optics/webbpsf-data/JWST_OTE_OPD_cycle1_example_2022-07-30.fits
INFO:webbpsf:Loading field dependent model parameters from /Users/uqitroit/Dev/optics/webbpsf-data/NIRCam/OPD/field_dep_table_nircam.fits
INFO:webbpsf:Calculating field-dependent OTE OPD at

ok
Grid search...Grid search step 0
Grid search step 0 - primary
Grid search step 0 - secondary
Grid search step 1
Grid search step 1 - primary
Grid search step 1 - secondary
Grid search step 2
Grid search step 2 - primary
Grid search step 2 - secondary
Grid search step 3
Grid search step 3 - primary
Grid search step 3 - secondary
Grid search step 4
Grid search step 4 - primary
Grid search step 4 - secondary
Grid search step 5
Grid search step 5 - primary
Grid search step 5 - secondary
Grid search step 6
Grid search step 6 - primary
Grid search step 6 - secondary
Grid search step 7
Grid search step 7 - primary
Grid search step 7 - secondary
Grid search step 8
Grid search step 8 - primary
Grid search step 8 - secondary
Grid search step 9
Grid search step 9 - primary
Grid search step 9 - secondary
Grid search step 10
Grid search step 10 - primary
Grid search step 10 - secondary
Grid search step 11
Grid search step 11 - primary
Grid search step 11 - secondary
Grid search step 12
Grid sear

In [None]:
pickle_path = './results_binary/HD 136164 NRCB1 F212N CLEAR.bin'

with open(pickle_path, "rb") as f:
    values_out = pickle.load(f)

print('Getting residuals...', end="")

x_pred, y_pred, sep_pred, angle_pred, flux_pred, contrast_pred, coeffs_pred_all = get_results(values_out)

psf_found = telescope.set(
    [
        'source.position',
        'source.separation',
        'source.position_angle',
        'source.mean_flux',
        'source.contrast',
        'pupil.coefficients',
    ],
    [
        np.array([x_pred, y_pred]),
        sep_pred,
        angle_pred,
        flux_pred,
        contrast_pred,
        coeffs_pred_all
    ]).model()

(pri_xp, pri_yp), (sec_xp, sec_yp) = sepToXy(x_pred, y_pred, sep_pred, angle_pred)

star_points = [
    (toPix(pri_xp), toPix(-pri_yp), '*'),
    (toPix(sec_xp), toPix(-sec_yp), '.')
]

plot_residuals(webbpsf_data_masked, psf_found,
               target_key + " HMC RESIDUALS", os.path.join(ouput_path, target_key + ' 3 RESIDUALS'),
               points = star_points,
               text = f'separation: {toArcsec(sep_pred)} arcsec, angle:{angle_pred} rad, contrast: {contrast_pred}')

print("ok")

print('Getting aberrations...', end="")

rec_aberrations = get_aber(coeffs_pred_all, telescope.pupil.basis)

fig = plt.figure()
plt.title("Zernike OPD")
plt.imshow((rec_aberrations)*telescope.pupil.transmission)

plt.colorbar()

plt.savefig(os.path.join(ouput_path, target_key + " 6 ABERRATIONS"))
plt.close(fig)

print("ok")