In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.print_figure_kwargs={'facecolor':"w"}

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Circle
import matplotlib
import ipywidgets as widgets
import galsim

import danish

In [None]:
def _circle(u0, v0, radius, npts=500):
    """Return (u, v) arrays tracing a circle of ``radius`` centred on (u0, v0).

    ``npts`` points are spaced uniformly in angle over [0, 2*pi] with both
    endpoints included, so the trace closes on itself.
    """
    phi = np.linspace(0, 2*np.pi, npts)
    return u0 + radius*np.cos(phi), v0 + radius*np.sin(phi)


def _draw_ring(axes, u, v, Z, focal_length, pixel_scale, label, color, lw):
    """Draw a pupil-plane contour and its focal-plane image, both labeled.

    The contour (u, v) is drawn on ``axes[1]`` (the pupil-space panel).  It is
    mapped to the focal plane with ``danish.factory._pupil_to_focal``,
    divided by ``pixel_scale`` to get pixel units, and drawn on ``axes[0]``.
    """
    axes[1].plot(u, v, c=color, lw=lw)
    x, y = danish.factory._pupil_to_focal(
        u, v, Z, focal_length=focal_length
    )
    x = x / pixel_scale
    y = y / pixel_scale
    axes[0].plot(x, y, c=color, lw=lw)
    # Place each label at a fixed sample along the trace in both panels.
    axes[0].annotate(label, (x[30], y[30]))
    axes[1].annotate(label, (u[30], v[30]))


def _donut_image(Z, thx, thy, npix, R_outer, R_inner, obsc_radii, obsc_motion,
                 focal_length, pixel_scale):
    """Render an ``npix`` x ``npix`` donut with ``danish.DonutFactory``.

    Only the obscurations listed in ``obsc_radii``/``obsc_motion`` are applied.
    """
    factory = danish.DonutFactory(
        R_outer=R_outer,
        R_inner=R_inner,
        obsc_radii=obsc_radii,
        obsc_motion=obsc_motion,
        focal_length=focal_length,
        pixel_scale=pixel_scale
    )
    return factory.image(Z, thx=thx, thy=thy, npix=npix)


def f(
    th, ph, show_M1, show_obscurations, show_flux, **kwargs
):
    """Plot a pixel grid and optical contours in pixel space and pupil space.

    A two-panel figure is built.  The left panel is pixel space, and the right
    panel is pupil space.  The pixel grid is mapped into the pupil with
    ``danish.factory._focal_to_pupil``.  Optional M1 edges and obscuration
    contours are mapped from the pupil into pixel space with
    ``danish.factory._pupil_to_focal``.

    Parameters
    ----------
    th, ph : float
        Field-angle radius and azimuth, both in degrees.
    show_M1 : bool
        Draw the M1 outer and inner edges.  These are centred and do not
        move with field angle.
    show_obscurations : bool
        Draw every entry of ``obsc_radii``, shifted in the pupil by
        ``-obsc_motion * field_angle``.
    show_flux : bool
        Overlay per-pixel relative flux as coloured disks on the pixel panel.
    **kwargs
        Must contain ``z4`` ... ``z22``.  These are Zernike coefficients in
        units of ``wavelength``; they are multiplied by it below.

    Returns
    -------
    None
        The figure is created with ``plt.subplots`` but not returned.
        Displaying it is left to the caller / notebook backend.
    """
    # Setup for Rubin observatory
    R_outer = 4.18
    R_inner = R_outer * 0.61
    wavelength = 750e-9
    focal_length = 10.31
    pixel_scale = 50e-6  # 5x larger-than-life
    npix = 37
    no2 = (npix-1)//2

    # Per-obscuration circle radius, and pupil-plane shift per radian of field
    # angle.  The shift is applied below as u0 = -motion*thx, v0 = -motion*thy.
    obsc_radii = {
        'M1_inner': 2.5580033095346875,
        'M2_outer': 4.502721059044802,
        'M2_inner': 2.3698531889709487,
        'M3_outer': 5.4353949343626216,
        'M3_inner': 1.1919725733251365,
        'L1_entrance': 7.692939426566589,
        'L1_exit': 8.103064894823262,
        'L2_entrance': 10.746925431763076,
        'L2_exit': 11.548732622162085,
        'Filter_entrance': 28.06952057721957,
        'Filter_exit': 30.895257933242576,
        'L3_entrance': 54.5631834759912,
        'L3_exit': 114.76715786850136
    }
    obsc_motion = {
        'M1_inner': 0.1517605552388959,
        'M2_outer': 16.818667026561727,
        'M2_inner': 16.818667026561727,
        'M3_outer': 53.2113063872138,
        'M3_inner': 53.2113063872138,
        'L1_entrance': 131.69949884635324,
        'L1_exit': 137.51151184228345,
        'L2_entrance': 225.63931108752732,
        'L2_exit': 236.8641351903567,
        'Filter_entrance': 801.6598843836333,
        'Filter_exit': 879.4647343264201,
        'L3_entrance': 1594.7432961792515,
        'L3_exit': 3328.637595923783
    }

    # Zernike terms 0..3 stay zero; z4..z22 come from the caller (in waves).
    coef = np.zeros(23)
    for i in range(4, 23):
        coef[i] = kwargs[f"z{i}"]
    coef *= wavelength
    Z = galsim.zernike.Zernike(coef, R_outer=R_outer, R_inner=R_inner)

    # Create axes
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 8))
    for ax in axes:
        ax.set_aspect('equal')
    axes[0].set_title("Pixel space")
    axes[1].set_title("Pupil space")
    axes[0].set_xlabel("Pixels")
    axes[0].set_ylabel("Pixels")
    axes[1].set_xlabel("Meters")
    axes[1].set_ylabel("Meters")
    axes[0].set_xlim(-no2-4, no2+4)
    axes[0].set_ylim(-no2-4, no2+4)
    axes[1].set_xlim(-5.5, 5.5)
    axes[1].set_ylim(-5.5, 5.5)

    # Pixel grid: one horizontal and one vertical line per pixel edge.
    xgrid = []
    ygrid = []
    for x in np.arange(-no2-0.5, no2+0.5+0.1):
        xgrid.append(np.linspace(-no2-0.5, no2+0.5, 50))
        ygrid.append(np.ones(50)*x)
        xgrid.append(np.ones(50)*x)
        ygrid.append(np.linspace(-no2-0.5, no2+0.5, 50))
    xgrid = np.array(xgrid)
    ygrid = np.array(ygrid)
    axes[0].plot(xgrid.T, ygrid.T, c='k', lw=1)

    # Project the pixel grid to the pupil.
    ugrid, vgrid = danish.factory._focal_to_pupil(
        xgrid*pixel_scale, ygrid*pixel_scale,
        Z, focal_length=focal_length
    )
    axes[1].plot(ugrid.T, vgrid.T, c='k', lw=1)

    # Field angle: degrees -> radians, then Cartesian components.
    th = np.deg2rad(th)
    ph = np.deg2rad(ph)
    thx, thy = th*np.cos(ph), th*np.sin(ph)

    if show_M1:
        # M1 edges use the same radii as the Zernike annulus and are drawn
        # centred.
        u, v = _circle(0.0, 0.0, R_outer)
        _draw_ring(axes, u, v, Z, focal_length, pixel_scale, 'M1_outer', 'm', 2)
        u, v = _circle(0.0, 0.0, R_inner)
        _draw_ring(axes, u, v, Z, focal_length, pixel_scale, 'M1_inner', 'r', 2)

    if show_obscurations:
        for name, radius in obsc_radii.items():
            # M1_inner has already been drawn above when show_M1 is set.
            if show_M1 and name == 'M1_inner':
                continue
            motion = obsc_motion[name]
            u, v = _circle(-motion*thx, -motion*thy, radius)
            color = 'r' if 'inner' in name else 'm'
            _draw_ring(axes, u, v, Z, focal_length, pixel_scale, name, color, 1)

    if show_flux:
        # Pixel-centre coordinates, in pixels.
        x = np.arange(-no2, no2+0.1)
        x, y = np.meshgrid(x, x)

        if show_obscurations or show_M1:
            if show_obscurations:
                radii, motion = obsc_radii, obsc_motion
            else:
                # Only M1 is shown, so vignette by its central hole alone.
                radii = {'M1_inner': obsc_radii['M1_inner']}
                motion = {'M1_inner': obsc_motion['M1_inner']}
            flux = _donut_image(
                Z, thx, thy, npix, R_outer, R_inner, radii, motion,
                focal_length, pixel_scale
            )
        else:
            # No vignetting: map pixel centres back to the pupil and weight
            # by 1/|Z.hessian|, normalised to a peak of 1.
            # NOTE(review): Z.hessian is assumed to be the determinant of the
            # wavefront Hessian, i.e. proportional to the Jacobian of the
            # pupil->focal map.  A Jacobian enters a change of variables as
            # its absolute value, so np.abs keeps flux non-negative where
            # the map folds.
            u, v = danish.factory._focal_to_pupil(
                x*pixel_scale, y*pixel_scale, Z, focal_length=focal_length
            )
            flux = 1/np.abs(Z.hessian(u, v))
            flux /= np.max(flux)

        # One small disk per pixel, coloured by relative flux.
        patches = []
        colors = []
        for x_, y_, f_ in zip(x.ravel(), y.ravel(), flux.ravel()):
            patches.append(Circle((x_, y_), 0.4))
            colors.append(f_)

        p = PatchCollection(patches, cmap=matplotlib.cm.Purples)
        p.set_array(np.array(colors))
        p.set_clim([0.0, 1.5])
        axes[0].add_collection(p)

    return None

In [None]:
def my_interact(f, controls, compute_button):
    """Return an Output area that re-runs ``f`` each time ``compute_button`` is clicked.

    Adapted from ``ipywidgets.interaction.interactive_output``.  Instead of
    observing each control's ``value``, the current widget values are read
    only when the button is clicked.  ``f`` is also invoked once right away,
    so the output area is populated on creation.

    Parameters
    ----------
    f : callable
        Called as ``f(**{name: widget.value for name, widget in controls.items()})``.
    controls : dict
        Maps keyword-argument names of ``f`` to widgets exposing ``.value``.
    compute_button : ipywidgets.Button
        Button whose clicks trigger a recompute.

    Returns
    -------
    ipywidgets.Output
        The output area in which ``f``'s output is rendered.
    """
    output_area = widgets.Output()
    interaction = widgets.interaction

    def _recompute(_event):
        # Read every control's current value first.
        values = dict((name, ctl.value) for name, ctl in controls.items())
        interaction.show_inline_matplotlib_plots()
        with output_area:
            # wait=True: keep the old output until new output arrives.
            interaction.clear_output(wait=True)
            f(**values)
            interaction.show_inline_matplotlib_plots()

    compute_button.on_click(_recompute)
    interaction.show_inline_matplotlib_plots()
    _recompute(None)
    return output_area

In [None]:
# Field-angle controls, in degrees (``f`` converts them to radians).
th = widgets.FloatSlider(
    value=1.67, min=0.0, max=2.0, step=0.01,
    description='Field radius', layout={'width':'250px'}
)
ph = widgets.FloatSlider(
    value=0.0, min=0.0, max=360.0, step=5.0,
    description='Field azimuth', layout={'width':'250px'}
)

# Display toggles, created in the order they are unpacked.
show_M1, show_obscurations, show_flux = (
    widgets.Checkbox(description=label, layout={'width':'250px'})
    for label in ('show M1', 'show obscurations', 'show flux')
)
compute_button = widgets.Button(description="Compute")

# Zernike coefficients Z4..Z22, in units of the wavelength.  Z4 gets a wider
# range and a nonzero default; Z5..Z22 start at zero.
zernikes = [
    widgets.BoundedFloatText(
        value=38.5, min=-40.0, max=40.0, step=0.1, description="Z4",
        layout={'width':"150px"}
    )
] + [
    widgets.BoundedFloatText(
        value=0.0, min=-3.0, max=3.0, step=0.1, description=f"Z{j}",
        layout={'width':'150px'}
    )
    for j in range(5, 23)
]

# Map each keyword argument of ``f`` to the widget that supplies its value.
all_widgets = dict(
    th=th, ph=ph,
    show_M1=show_M1,
    show_obscurations=show_obscurations,
    show_flux=show_flux,
)
all_widgets.update({f"z{n}": w for n, w in enumerate(zernikes, start=4)})

output = my_interact(f, all_widgets, compute_button)

In [None]:
# Layout: a column of field/toggle controls beside four columns of Zernike
# inputs (5, 5, 4 and 5 entries), with the plot output underneath.
controls_column = widgets.VBox(
    [th, ph, show_M1, show_obscurations, show_flux, compute_button],
)
zernike_columns = [
    widgets.VBox(zernikes[start:stop])
    for start, stop in ((None, 5), (5, 10), (10, 14), (14, None))
]
widgets.VBox([widgets.HBox([controls_column, *zernike_columns]), output])