In [1]:
# Automatically reload imported modules before each cell executes, so edits
# to library source show up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures with a white ("w") facecolor.
%config InlineBackend.print_figure_kwargs={'facecolor':"w"}

In [2]:
import time
from functools import lru_cache

import numpy as np
import matplotlib.pyplot as plt
import batoid
import galsim
import ipywidgets as widgets

import danish

In [3]:
import os
import yaml

# Load the Rubin obscuration description shipped in danish's data directory
# and build the DonutFactory used by demo() to model donut images.
# The context manager closes the YAML file instead of leaking the handle.
with open(os.path.join(danish.datadir, "RubinObsc.yaml")) as obsc_file:
    obsc = yaml.safe_load(obsc_file)

factory = danish.DonutFactory(
    obsc_radii=obsc['radii'],
    obsc_centers=obsc['centers'],
    obsc_th_mins=obsc['th_mins'],
)

In [4]:
# Interactive controls for demo(). Each key must match one of demo()'s
# parameter names, because interactive_output passes the widget values to
# demo() as keyword arguments.
band_angle = dict(
    # Field position in polar form: radius and azimuth, both in degrees.
    thr = widgets.FloatSlider(value=1.67, min=0.0, max=2.0, step=0.01, description='thr (deg)'),
    ph = widgets.FloatSlider(value=0.0, min=0.0, max=360.0, step=5.0, description='phi (deg)'),
    # Detector shift along z passed to sim(); batoid lengths are in meters.
    # The two options are equal-magnitude shifts of opposite sign.
    defocus = widgets.Dropdown(options=[-0.0015, 0.0015], description='defocus (m)'),
    seed = widgets.IntText(value=57721, description='seed'),
    nphoton = widgets.IntText(value=2_000_000, step=500_000, description='nphot'),
)

In [5]:
@lru_cache
def sim(
    thx, thy, ph, defocus, seed, nphoton,
):
    """Ray-trace a defocused donut through the LSST i-band model and image it.

    Photons are launched from an annulus on a plane at the height of M1's sag
    at the pupil edge. They are perturbed by a Kolmogorov atmosphere
    (0.1 arcsec FWHM), traced through the telescope with batoid, and
    accumulated onto a 181x181 GalSim stamp. The stamp is centred on the mean
    position of the unvignetted photons. Zero-mean Gaussian noise is then
    added.

    Results are memoized by ``functools.lru_cache``, so all arguments must be
    hashable, and revisiting a widget setting reuses the earlier result.

    Parameters
    ----------
    thx, thy : float
        Field angle components in radians.
    ph : float
        Field azimuth in degrees. It is not used in the computation
        (``thx``/``thy`` already encode it); it only participates in the
        cache key.
    defocus : float
        Shift of the detector along z, in batoid length units (meters).
    seed : int
        Seed for both the numpy ``Generator`` and the GalSim ``BaseDeviate``.
    nphoton : int
        Number of photons/rays to trace.

    Returns
    -------
    image : galsim.Image
        The noisy 181x181 donut stamp.
    aberrations : array
        Zernike coefficients from ``batoid.analysis.zernikeTA`` (jmax=66) at
        750 nm. These are presumably in waves, since the caller scales them by
        the wavelength (TODO confirm).

    Raises
    ------
    ValueError
        If every ray is vignetted, so the stamp cannot be centred.
    """
    wavelength = 750e-9   # monochromatic wavelength, meters
    pixel_size = 10e-6    # detector pixel pitch, meters
    stamp_size = 181      # stamp is stamp_size x stamp_size pixels
    sky_variance = 1000   # variance of the additive Gaussian noise per pixel

    t0 = time.time()
    print()
    print()
    print("starting simulation")

    telescope = batoid.Optic.fromYaml("LSST_i.yaml")
    telescope = telescope.withGloballyShiftedOptic("Detector", (0, 0, defocus))

    aberrations = batoid.analysis.zernikeTA(
        telescope, thx, thy, wavelength,
        nrad=20, naz=120, jmax=66, eps=0.61
    )

    sensor = galsim.Sensor()
    rng = np.random.default_rng(seed)
    gsrng = galsim.BaseDeviate(seed)

    # Populate the pupil annulus uniformly in area (sample r^2, not r).
    r_outer = 8.36/2
    # Purposely underestimate the inner radius a bit.
    # Rays that miss will just be vignetted.
    r_inner = 8.36/2*0.58
    r = np.sqrt(rng.uniform(r_inner**2, r_outer**2, nphoton))
    th = rng.uniform(0, 2*np.pi, nphoton)
    x = r*np.cos(th)
    y = r*np.sin(th)
    wavelengths = np.full(x.shape, wavelength)

    # Atmospheric perturbation: Kolmogorov photon offsets are in arcsec and are
    # converted to radians, then added to the nominal field angle.
    kolm = galsim.Kolmogorov(fwhm=0.1)
    atm = kolm.shoot(nphoton, gsrng)
    dku = np.deg2rad(atm.x / 3600) + thx
    dkv = np.deg2rad(atm.y / 3600) + thy
    vx, vy, vz = batoid.utils.fieldToDirCos(dku, dkv, projection='gnomonic')

    zPupil = telescope["M1"].surface.sag(0, 0.5*telescope.pupilSize)
    z = np.full_like(x, zPupil)
    # batoid ray velocities are direction cosines scaled by 1/n of the
    # medium the rays start in.
    n = telescope.inMedium.getN(wavelengths)
    rays = batoid.RayVector(
        x, y, z,
        vx/n, vy/n, vz/n,
        t=0.0,
        wavelength=wavelengths,
        flux=1.0
    )

    telescope.trace(rays)

    alive = ~rays.vignetted
    if not np.any(alive):
        raise ValueError("all rays were vignetted; cannot center the image")

    # Focal-plane positions (meters) -> pixel coordinates; vignetted rays get
    # zero flux.
    pa = galsim.PhotonArray(nphoton)
    pa.x = rays.x/pixel_size
    pa.y = rays.y/pixel_size
    pa.flux = alive

    # Centre the stamp on the nearest pixel to the donut centroid. Rounding,
    # rather than int() truncation toward zero, avoids a sign-dependent bias.
    image = galsim.Image(stamp_size, stamp_size)
    image.setCenter(
        int(np.rint(np.mean(pa.x[alive]))),
        int(np.rint(np.mean(pa.y[alive])))
    )
    sensor.accumulate(pa, image)

    # Add zero-mean Gaussian noise standing in for background-subtracted sky.
    image.array[:] += rng.normal(scale=np.sqrt(sky_variance), size=image.array.shape)

    t1 = time.time()
    print(f"sim time: {t1-t0:.3f} sec")
    return image, aberrations

In [6]:
def demo(
    thr, ph, defocus, seed, nphoton,
):
    """Simulate a donut, model it with danish, and plot both plus the residual.

    Parameters
    ----------
    thr : float
        Field radius in degrees.
    ph : float
        Field azimuth in degrees.
    defocus, seed, nphoton
        Forwarded unchanged to ``sim``.

    Shows a three-panel figure: the simulated stamp, the danish model, and
    the model rescaled to the simulated total flux minus the simulation.
    """
    # Polar field position (degrees) -> Cartesian field angles (radians).
    # Keep this exact arithmetic, because sim() is memoized on these floats.
    thr_rad = np.deg2rad(thr)
    ph_rad = np.deg2rad(ph)
    thx = thr_rad*np.cos(ph_rad)
    thy = thr_rad*np.sin(ph_rad)

    image, aberrations = sim(
        thx, thy, ph, defocus, seed, nphoton,
    )

    t0 = time.time()
    # Scale the Zernike coefficients by the 750 nm simulation wavelength before
    # handing them to danish.
    im2 = factory.image(
        aberrations=aberrations*750e-9,
        thx=thx, thy=thy
    )
    t1 = time.time()
    print(f"geo time: {t1-t0:.3f} sec")

    # The model is flipped in both axes before comparison. This is presumably
    # an orientation-convention difference between danish and the GalSim stamp
    # (TODO confirm).
    sim_array = image.array
    model = im2[::-1, ::-1]
    residual = model/np.sum(im2)*np.sum(sim_array) - sim_array

    fig, axes = plt.subplots(ncols=3, figsize=(10, 5))
    panels = (
        (sim_array, "simulation (batoid)"),
        (model, "model (danish)"),
        (residual, "flux-matched model - sim"),
    )
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
    plt.show()

In [7]:
# Merge every control group into a single name -> widget mapping. The keys
# become keyword arguments to demo(), which re-runs whenever a widget changes.
all_widgets = {
    name: control
    for group in (band_angle,)
    for name, control in group.items()
}
output = widgets.interactive_output(demo, all_widgets)

# Layout: a column of controls above the captured demo() output.
display(
    widgets.VBox([
        widgets.HBox([widgets.VBox(list(band_angle.values()))]),
        output,
    ])
)

VBox(children=(HBox(children=(VBox(children=(FloatSlider(value=1.67, description='thr (deg)', max=2.0, step=0.…