In [None]:
# Automatically reload edited modules (e.g. danish) before each cell executes.
%load_ext autoreload
%autoreload 2
# Render inline figures with a white face color.
%config InlineBackend.print_figure_kwargs={'facecolor':"w"}

In [None]:
import time
from functools import lru_cache

import numpy as np
import matplotlib.pyplot as plt
import batoid
import galsim
import ipywidgets as widgets
from scipy.optimize import least_squares

import danish

In [None]:
import os
import yaml

# Rubin pupil-obscuration description bundled with danish. The `with` block
# closes the file handle once the YAML has been parsed.
with open(os.path.join(danish.datadir, "RubinObsc.yaml")) as obsc_file:
    obsc = yaml.safe_load(obsc_file)

# Donut forward-model factory. Lengths are presumably in meters: R_outer is
# half of the 8.36 pupil diameter used by the simulation, and pixel_scale is
# the 10e-6 divisor the simulation uses to turn ray positions into pixels.
factory = danish.DonutFactory(
    mask_params=obsc,
    R_outer=4.18,
    R_inner=2.5498,
    focal_length=10.31,
    pixel_scale=10e-6,
)

# Representative wavelength (in meters) for each LSST band. Used both for the
# reference Zernike evaluation and as the center of the simulated photon
# wavelength distribution.
wavelength_dict = {
    'u': 365e-9,
    'g': 480e-9,
    'r': 625e-9,
    'i': 770e-9,
    'z': 910e-9,
    'y': 960e-9,
}

In [None]:
# Interactive controls for the demo. The keys must match the keyword
# parameters of `demo`, because this dict is handed to
# `widgets.interactive_output`, which calls `demo(**{key: widget.value})`.
band_angle = dict(
    band = widgets.Dropdown(options=['u', 'g', 'r', 'i', 'z', 'y'], value='r', description='band'),
    thr = widgets.FloatSlider(value=1.67, min=0.0, max=2.0, step=0.01, description='thr (deg)'),
    ph = widgets.FloatSlider(value=0.0, min=0.0, max=360.0, step=5.0, description='phi (deg)'),
    fwhm = widgets.FloatSlider(value=0.1, min=0.0, max=1.0, step=0.01, description='FWHM (arcsec)'),
    defocus = widgets.Dropdown(options=[-0.0015, 0.0015], value=-0.0015, description='defocus (m)'),
    seed = widgets.IntText(value=57721, description='seed'),
    nphoton = widgets.IntText(value=2_000_000, step=500_000, description='nphot'),
    # FloatText has no min/max traits, so the intended [-100, 100] range was
    # never enforced. BoundedFloatText clamps the value to that range.
    center_offset = widgets.BoundedFloatText(value=0.0, min=-100.0, max=100.0, step=1.0, description='center offset')
)

In [None]:
@lru_cache
def sim(
    thx, thy, fwhm, defocus, seed, nphoton, band
):
    """Simulate a defocused Rubin donut by photon-shooting through batoid.

    The result is memoized with ``lru_cache``, so re-rendering the widgets with
    unchanged simulation inputs skips the ray trace. On a cache hit the *same*
    image and aberration objects are returned, so callers must treat them as
    read-only.

    Parameters
    ----------
    thx, thy : float
        Field-angle components in radians.
    fwhm : float
        FWHM of the Kolmogorov atmospheric PSF, in arcsec.
    defocus : float
        Shift applied to the detector along z, in meters.
    seed : int
        Seed for both the numpy and galsim random generators.
    nphoton : int
        Number of photons to shoot.
    band : str
        LSST band; selects ``LSST_{band}.yaml`` and the central wavelength
        from ``wavelength_dict``.

    Returns
    -------
    image : galsim.Image
        181x181 stamp centered on the centroid of the unvignetted photons,
        with zero-mean Gaussian background noise of variance 1000 added.
    aberrations : numpy.ndarray
        Zernike coefficients (up to j=66) of the defocused telescope, in nm.

    Raises
    ------
    RuntimeError
        If every ray is vignetted, so the donut position is undefined.
    """
    t0 = time.time()
    print()
    print()
    print("starting simulation")

    pixel_size = 10e-6  # detector pixel size (m), same as factory's pixel_scale
    npix = 181          # stamp size in pixels
    sky_var = 1000      # per-pixel variance of the added background noise

    # Telescope for this band, defocused by shifting the detector.
    telescope = batoid.Optic.fromYaml(f"LSST_{band}.yaml")
    telescope = telescope.withGloballyShiftedOptic("Detector", (0, 0, defocus))

    # Reference aberrations at this field angle; eps is R_inner/R_outer.
    wavelength = wavelength_dict[band]
    aberrations = batoid.analysis.zernikeTA(
        telescope, thx, thy, wavelength,
        nrad=20, naz=120, jmax=66, eps=2.5498/4.18, focal_length=10.31
    ) * wavelength * 1e9  # in nm

    sensor = galsim.Sensor()

    # Both generators use `seed`: numpy draws pupil positions, wavelengths and
    # background noise; galsim shoots the atmospheric PSF.
    rng = np.random.default_rng(seed)
    gsrng = galsim.BaseDeviate(seed)

    # Populate the pupil uniformly in area over an annulus.
    r_outer = 8.36/2
    # Purposely underestimate the inner radius a bit.
    # Rays that miss will just be vignetted.
    r_inner = 8.36/2*0.58
    r = np.sqrt(rng.uniform(r_inner**2, r_outer**2, nphoton))
    th = rng.uniform(0, 2*np.pi, nphoton)
    x = r*np.cos(th)
    y = r*np.sin(th)
    # Broadband light approximated by a Gaussian with 15% fractional width.
    wavelengths = rng.normal(loc=wavelength, scale=0.15*wavelength, size=nphoton)

    # Per-photon atmospheric deflection, added to the nominal field angle.
    kolm = galsim.Kolmogorov(fwhm=fwhm)
    pa = galsim.PhotonArray(nphoton)
    kolm._shoot(pa, gsrng)
    dku = np.deg2rad(pa.x / 3600) + thx  # arcsec -> rad
    dkv = np.deg2rad(pa.y / 3600) + thy  # arcsec -> rad

    vx, vy, vz = batoid.utils.fieldToDirCos(dku, dkv, projection='gnomonic')
    # Launch every ray from the height of M1's sag at the pupil edge.
    zPupil = telescope["M1"].surface.sag(0, 0.5*telescope.pupilSize)
    z = np.full_like(x, zPupil)
    # Ray velocities are scaled by 1/n of the incident medium (batoid convention).
    n = telescope.inMedium.getN(wavelengths)
    rays = batoid.RayVector(
        x, y, z,
        vx/n, vy/n, vz/n,
        t=0.0,
        wavelength=wavelengths,
        flux=1.0
    )

    telescope.trace(rays)

    unvignetted = ~rays.vignetted
    if not np.any(unvignetted):
        # np.mean of an empty selection would give NaN and int(NaN) an opaque error.
        raise RuntimeError(
            "sim: every ray was vignetted; cannot locate the donut"
        )

    # Focal-plane positions (m) -> pixels; vignetted photons get zero flux.
    pa = galsim.PhotonArray(nphoton)
    pa.x = rays.x/pixel_size
    pa.y = rays.y/pixel_size
    pa.flux = unvignetted

    # Center the stamp on the centroid of the surviving photons.
    image = galsim.Image(npix, npix)
    image.setCenter(
        int(np.mean(pa.x[unvignetted])),
        int(np.mean(pa.y[unvignetted]))
    )
    sensor.accumulate(pa, image)

    # Add background.
    image.array[:] += rng.normal(scale=np.sqrt(sky_var), size=(npix, npix))

    t1 = time.time()
    print(f"sim time: {t1-t0:.3f} sec")
    return image, aberrations

# Re-running this cell redefines `sim` with an empty cache; clearing it makes that explicit.
sim.cache_clear()

In [None]:
def demo(
    thr, ph, fwhm, defocus, seed, nphoton, band, center_offset=0.0
):
    """Simulate a donut, fit it with danish, and plot data/model/residuals.

    Parameters
    ----------
    thr, ph : float
        Field radius and position angle, in degrees.
    fwhm : float
        Atmospheric FWHM (arcsec) used by the simulation.
    defocus : float
        Detector shift in meters, forwarded to ``sim``.
    seed, nphoton, band
        Forwarded unchanged to ``sim``.
    center_offset : float
        Offset of the starting centroid guess, applied along the position
        angle ``ph``. It is used both for the "Initial Model" panel and as the
        starting point of the fit.
    """
    thx = np.deg2rad(thr)*np.cos(np.deg2rad(ph))
    thy = np.deg2rad(thr)*np.sin(np.deg2rad(ph))

    image, aberrations = sim(
        thx, thy, fwhm, defocus, seed, nphoton, band,
    )
    flux = np.sum(image.array)

    # Starting centroid guess, displaced by `center_offset` along the field angle.
    dx = np.cos(np.deg2rad(ph))*center_offset
    dy = np.sin(np.deg2rad(ph))*center_offset

    z_terms = np.arange(4, 29)  # fitted Zernike indices j = 4..28
    nz = len(z_terms)
    fitter = danish.SingleDonutModel(
        factory,
        z_ref=aberrations * 1e-9,  # nm -> m
        z_terms=z_terms,
        thx=thx, thy=thy,
        npix=181, bkg_order=0
    )

    # Starting parameters: flux, dx, dy, fwhm guess, one coefficient per fitted
    # Zernike term, and one background term (bkg_order=0). The same vector
    # seeds the optimizer and renders the "Initial Model" panel, so the plot
    # shows exactly where the fit starts.
    fwhm_guess = 1.0
    x0 = [flux, dx, dy, fwhm_guess] + [0.0]*nz + [0.0]

    t0 = time.time()
    im2 = fitter.model(**fitter.unpack_params(x0))
    t1 = time.time()
    print(f"geo time: {t1-t0:.3f} sec")

    # Per-pixel background variance; matches the noise added by `sim`.
    sky_level = 1000
    t2 = time.time()
    opt = least_squares(
        fitter.chi, x0, jac=fitter.jac,
        ftol=1e-3, xtol=1e-3, gtol=1e-3,
        max_nfev=20,
        args=(image.array, sky_level)
    )
    fit = fitter.unpack_params(opt.x)
    im3 = fitter.model(**fit)
    t3 = time.time()
    print(f"fit time: {t3-t2:.3f} sec")

    # Six-row table of the fitted coefficients; each row steps j by 6.
    # Assumes z_terms is a contiguous range.
    j_min, j_max = int(z_terms[0]), int(z_terms[-1])
    out = ""
    for j0 in range(j_min, j_min + 6):
        for j in range(j0, j_max + 1, 6):
            v = fit["z_fit"][j - j_min]
            zstr = f"Z{j}"
            out += f"{zstr:>4s}: {v*1e9:7.1f} nm    "
        out += "\n"
    print("Fitted aberrations:")
    print(out)

    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))
    axes[0, 0].imshow(image.array)
    axes[0, 1].imshow(im2)
    axes[0, 2].imshow(im2 - image.array)
    axes[1, 0].imshow(image.array)
    axes[1, 1].imshow(im3)
    axes[1, 2].imshow(im3 - image.array)
    axes[0, 0].set_title("Simulated")
    axes[0, 1].set_title("Initial Model")
    axes[0, 2].set_title("Initial Residual")
    axes[1, 0].set_title("Simulated")
    axes[1, 1].set_title("Fitted Model")
    axes[1, 2].set_title("Fitted Residual")
    plt.show()

In [None]:
# Merge every control group into a single name -> widget mapping for
# interactive_output (currently only `band_angle`).
all_widgets = {
    name: control
    for group in (band_angle,)
    for name, control in group.items()
}

# Re-run `demo` whenever a control changes; its prints and figures land in `output`.
output = widgets.interactive_output(demo, all_widgets)
display(
    widgets.VBox([
        widgets.HBox([widgets.VBox(list(band_angle.values()))]),
        output,
    ])
)