In [None]:
import batoid
import os
import yaml
import numpy as np
from IPython.display import display
from ipywidgets import interact, interactive_output
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# galsim is optional: the cells guarded by ``if has_galsim:`` are skipped
# when it is not installed. It is not referenced directly in this notebook,
# only probed for availability.
try:
    import galsim  # noqa: F401
except ImportError:
    # Catch only a missing/broken install; a bare ``except`` would also
    # swallow KeyboardInterrupt and SystemExit.
    has_galsim = False
else:
    has_galsim = True

In [None]:
# Load the fiducial (unperturbed) DECam optical system shipped with batoid.
DECAM_fn = os.path.join(batoid.datadir, "DECam", "DECam.yaml")
# ``yaml.load`` without an explicit Loader is deprecated (and an error in
# PyYAML >= 6); ``safe_load`` is sufficient for a plain config file. The
# ``with`` block also makes sure the file handle is closed.
with open(DECAM_fn) as decam_yaml:
    config = yaml.safe_load(decam_yaml)
fiducial_telescope = batoid.parse.parse_optic(config['opticalSystem'])

In [None]:
def spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax):
    """Scatter-plot the spot diagram for one field angle on ``ax``.

    Rays on a circular pupil grid are traced through ``telescope``,
    vignetted rays are dropped, and the surviving focal-plane positions are
    plotted relative to their centroid, in microns.

    Parameters
    ----------
    telescope : batoid optic
        Must provide ``dist``, ``pupil_size``, ``pupil_obscuration``,
        ``inMedium`` and ``trace``.
    wavelength : float
        Wavelength in nm (converted to meters via ``*1e-9``).
    theta_x, theta_y : float
        Field angles in degrees.
    logscale : float
        Axis half-range is ``10**logscale`` microns.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    # Direction cosines of the incoming beam.
    dir_x = np.sin(theta_x*np.pi/180)
    dir_y = np.sin(theta_y*np.pi/180)
    dir_z = -np.sqrt(1.0 - dir_x**2 - dir_y**2)

    half_pupil = telescope.pupil_size/2
    grid = batoid.circularGrid(
        telescope.dist, half_pupil, half_pupil*telescope.pupil_obscuration,
        dir_x, dir_y, dir_z,
        48, 192, wavelength*1e-9, telescope.inMedium
    )

    traced, _ = telescope.trace(grid)
    kept = batoid.trimVignetted(traced)

    # Positions relative to the centroid, converted meters -> microns.
    positions = np.vstack([kept.x, kept.y])
    offsets = (positions - positions.mean(axis=1, keepdims=True)) * 1e6
    x_um, y_um = offsets

    ax.scatter(x_um, y_um, s=1, alpha=0.5)
    half_range = 10**logscale
    ax.set_xlim(-half_range, half_range)
    ax.set_ylim(-half_range, half_range)
    ax.set_title(r"$\theta_x = {:4.2f}\,,\theta_y = {:4.2f}$".format(theta_x, theta_y))
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")

In [None]:
def wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax, nx=128):
    """Show the wavefront across the pupil for one field angle on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system; ``pupil_size`` sets the plotted extent.
    wavelength : float
        Wavelength in nm (converted to meters via ``*1e-9``).
    theta_x, theta_y : float
        Field angles, passed unchanged to ``batoid.wavefront``.
        NOTE(review): ``spotPlot`` converts these from degrees itself, while
        here they go straight through -- confirm the units
        ``batoid.wavefront`` expects.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    nx : int, optional
        Samples per side across the pupil (default 128). The result of
        ``batoid.wavefront`` is reshaped to ``(nx, nx)``.
    """
    wavefront = batoid.wavefront(
        telescope, wavelength*1e-9,
        theta_x=theta_x, theta_y=theta_y,
        nx=nx,
    ).reshape(nx, nx)
    half_pupil = telescope.pupil_size/2
    wfplot = ax.imshow(wavefront, extent=np.r_[-1, 1, -1, 1]*half_pupil)
    # Title/labels so this panel is identifiable like the PSF panels.
    ax.set_title("Wavefront")
    ax.set_xlabel("pupil x (m)")
    ax.set_ylabel("pupil y (m)")
    plt.colorbar(wfplot, ax=ax)

In [None]:
def fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax, nx=32,
               microns_per_arcsec=15/0.27):
    """Compute an FFT PSF for one field angle and show it on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system passed to ``batoid.fftPSF``.
    wavelength : float
        Wavelength in nm (converted to meters via ``*1e-9``).
    theta_x, theta_y : float
        Field angles, passed unchanged to ``batoid.fftPSF``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    nx : int, optional
        Pupil sampling passed to ``batoid.fftPSF`` (default 32).
    microns_per_arcsec : float, optional
        Focal-plane plate scale used to express the image extent in microns.
        The default ``15/0.27`` matches ``huygensPSFPlot``. Previously this
        panel used ``10/0.2``, so the two PSF panels disagreed in scale.
    """
    scale, fftPSF = batoid.fftPSF(telescope, wavelength*1e-9, theta_x, theta_y, nx=nx)
    nxout = fftPSF.shape[0]
    # Normalize to unit total flux (out of place, so any dtype works).
    fftPSF = fftPSF / np.sum(fftPSF)
    # ``scale`` is treated as radians per output pixel (206265 arcsec/rad),
    # then converted arcsec -> microns.
    half_width = scale*nxout/2*206265*microns_per_arcsec
    fftplot = ax.imshow(
        fftPSF,
        extent=np.r_[-1, 1, -1, 1]*half_width
    )
    ax.set_title("FFT PSF")
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")
    plt.colorbar(fftplot, ax=ax)

In [None]:
def huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax, nx=32,
                   microns_per_arcsec=15/0.27):
    """Compute a Huygens PSF for one field angle and show it on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Must provide ``dist``, ``pupil_size``, ``pupil_obscuration``,
        ``inMedium`` and ``trace``.
    wavelength : float
        Wavelength in nm (converted to meters via ``*1e-9``).
    theta_x, theta_y : float
        Field angles in degrees.
    ax : matplotlib.axes.Axes
        Axes to draw on (previously ignored in favour of pyplot's current
        axes).
    nx : int, optional
        The PSF is sampled on a ``2*nx`` by ``2*nx`` grid (default 32).
    microns_per_arcsec : float, optional
        Focal-plane plate scale (default ``15/0.27``, the value this notebook
        has always used here).
    """
    # Angular sample size lambda / (2 * pupil_size), in radians.
    scale = wavelength*1e-9/(telescope.pupil_size*2)

    # Direction cosines of the incoming beam.
    xcos = np.sin(theta_x*np.pi/180)
    ycos = np.sin(theta_y*np.pi/180)
    zcos = -np.sqrt(1.0 - xcos**2 - ycos**2)
    orig_rays = batoid.circularGrid(
        telescope.dist, telescope.pupil_size/2, telescope.pupil_size/2*telescope.pupil_obscuration,
        xcos, ycos, zcos,
        24, 48, wavelength*1e-9, telescope.inMedium
    )

    # Center the image grid on the centroid of the unvignetted traced rays.
    traced_rays, _ = telescope.trace(orig_rays)
    goodRays = batoid.trimVignetted(traced_rays)
    xmean = np.mean(goodRays.x)
    ymean = np.mean(goodRays.y)

    L = 2*scale*nx*206265  # full image width in arcsec (206265 arcsec/rad)
    L *= microns_per_arcsec  # arcsec -> microns
    dx = L / nx

    # Focal-plane sample positions in meters (L is in microns, hence 1e-6).
    # NOTE(review): the quarter-dx shift presumably centers samples on
    # pixels -- confirm.
    xs = np.linspace(xmean-L/2*1e-6, xmean+L/2*1e-6, nx*2)
    ys = np.linspace(ymean-L/2*1e-6, ymean+L/2*1e-6, nx*2)
    xs, ys = np.meshgrid(xs, ys)
    xs -= dx*1e-6/4
    ys -= dx*1e-6/4

    huygensPSF = batoid.huygensPSF(telescope, xs=xs, ys=ys, zs=None, rays=orig_rays, saveRays=False)
    huygensPSF /= np.sum(huygensPSF)

    # Draw on the supplied axes, not whatever pyplot considers "current".
    huygensplot = ax.imshow(
        huygensPSF,
        extent=np.r_[-1, 1, -1, 1]*L/2
    )
    ax.set_title("Huygens PSF")
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")
    plt.colorbar(huygensplot, ax=ax)

In [None]:
# Control widgets, grouped by purpose. Keys must match the parameter names
# of the plotting callback ``f``.
# LaTeX descriptions use raw strings: "\l", "\m" and "\p" are invalid escape
# sequences in normal strings (SyntaxWarning on Python 3.12+). The resulting
# text is unchanged.

# Which panels to draw.
what = dict(
    do_spot = widgets.Checkbox(value=True, description='Spot'),
    do_wavefront = widgets.Checkbox(value=True, description='Wavefront'),
    do_fftPSF = widgets.Checkbox(value=True, description='FFT PSF'),
    do_huygensPSF = widgets.Checkbox(value=True, description='Huygens PSF')
)
# Wavelength, field position and spot-diagram zoom.
where = dict(
    wavelength=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0, description=r"$\lambda$ (nm)"),
    theta_x=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=-0.5, description=r"$\theta_x (deg)$"),
    theta_y=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=0.0, description=r"$\theta_y (deg)$"),
    logscale=widgets.FloatSlider(min=1, max=3, step=0.1, value=1, description="scale")
)
# Rigid-body perturbation of a single optic.
perturb = dict(
    optic=widgets.Dropdown(
        options=fiducial_telescope.itemDict.keys(),
        value='BlancoDECam.DECam'
    ),
    dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dx ($mm$)"),
    dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dy ($mm$)"),
    dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0, description=r"dz ($\mu m$)"),
    dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_x$ (arcmin)"),
    dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_y$ (arcmin)"),
)

def f(do_spot, do_wavefront, do_fftPSF, do_huygensPSF,
    wavelength, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy, logscale, **kwargs):
    """Widget callback: perturb one optic and draw the selected panels.

    ``optic`` is shifted by (dx, dy) in mm and dz in microns, then rotated
    by dthx/dthy arcminutes. Panels whose ``do_*`` flag is set are drawn
    side by side in the order spot, wavefront, FFT PSF, Huygens PSF.
    Nothing is drawn if no flag is set.
    """
    shift = batoid.Vec3(dx*1e-3, dy*1e-3, dz*1e-6)
    shifted = fiducial_telescope.withGloballyShiftedOptic(optic, shift)
    tilt = batoid.RotX(dthx*np.pi/180/60)*batoid.RotY(dthy*np.pi/180/60)
    telescope = shifted.withLocallyRotatedOptic(optic, tilt)
    telescope.dist = 20.0
    telescope.pupil_size = 4.1
    telescope.pupil_obscuration = 0.1
    telescope.sphereRadius = 10.5

    # (enabled?, drawer) in display order; each drawer takes the target axes.
    panels = [
        (do_spot, lambda ax: spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax)),
        (do_wavefront, lambda ax: wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_fftPSF, lambda ax: fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_huygensPSF, lambda ax: huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
    ]
    drawers = [draw for wanted, draw in panels if wanted]
    if not drawers:
        return

    count = len(drawers)
    fig, axes = plt.subplots(ncols=count, figsize=(4*count, 4), squeeze=False)
    for target, draw in zip(axes.ravel(), drawers):
        draw(target)
    fig.tight_layout()

# Flatten the widget groups into one name -> widget mapping for
# ``interactive_output``; insertion order follows what, where, perturb.
all_widgets = {}
for group in (what, where, perturb):
    all_widgets.update(group)

output = interactive_output(f, all_widgets)
controls = widgets.HBox([
    widgets.VBox(list(what.values())),
    widgets.VBox(list(where.values())),
    widgets.VBox(list(perturb.values())),
])
display(widgets.VBox([controls, output]))

In [None]:
if has_galsim:
    @interact(wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                          description=r"$\lambda$ (nm)"),
              theta_x=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=-0.5,
                                          description=r"$\theta_x (deg)$"),
              theta_y=widgets.FloatSlider(min=-1.1,max=1.1,step=0.1,value=0.0,
                                          description=r"$\theta_y (deg)$"),
              optic=widgets.Dropdown(
                  options=fiducial_telescope.itemDict.keys(),
                  value='BlancoDECam.DECam'
              ),
              dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dx ($mm$)"),
              dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dy ($mm$)"),
              dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0,
                                     description=r"dz ($\mu m$)"),
              dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_x$ (arcmin)"),
              dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_y$ (arcmin)"))
    def zernike(wavelen, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy):
        """Print Zernike coefficients (up to jmax=22) for a perturbed telescope.

        ``optic`` is shifted by (dx, dy) in mm and dz in microns, then rotated
        by dthx/dthy arcminutes. The coefficients at field angle
        (theta_x, theta_y) are printed as a two-column table.
        """
        telescope = (fiducial_telescope
                .withGloballyShiftedOptic(optic, batoid.Vec3(dx*1e-3, dy*1e-3, dz*1e-6))
                .withLocallyRotatedOptic(
                        optic,
                        batoid.RotX(dthx*np.pi/180/60)*batoid.RotY(dthy*np.pi/180/60)
                )
        )

        telescope.dist = 20.0
        telescope.pupil_size = 4.1
        telescope.sphereRadius = 10.5

        z = batoid.zernike(telescope, wavelen*1e-9, theta_x, theta_y, jmax=22, eps=0.1, nx=128)
        # Assumes z is indexed by Noll index with z[0] unused. Split indices
        # 1..nterms evenly into two columns. The previous loop
        # (range(1, len(z)//2) paired with i+12) silently skipped Z11 and Z12.
        nterms = len(z) - 1
        nrows = (nterms + 1)//2
        for i in range(1, nrows + 1):
            j = i + nrows
            if j <= nterms:
                print("{:6d}   {:6.3f}      {:6d}  {:6.3f}".format(i, z[i], j, z[j]))
            else:
                print("{:6d}   {:6.3f}".format(i, z[i]))

In [None]:
if has_galsim:
    @interact(wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                          description=r"$\lambda$ (nm)"),
              zindex=widgets.IntSlider(min=4,max=22,value=9),
              optic=widgets.Dropdown(
                  options=fiducial_telescope.itemDict.keys(),
                  value='BlancoDECam.DECam'
              ),
              dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dx ($mm$)"),
              dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dy ($mm$)"),
              dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0,
                                     description=r"dz ($\mu m$)"),
              dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_x$ (arcmin)"),
              dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_y$ (arcmin)"))
    def zFoV(wavelen, zindex, optic, dx, dy, dz, dthx, dthy):
        """Map Zernike coefficient ``zindex`` across the circular field of view.

        ``optic`` is shifted by (dx, dy) in mm and dz in microns, then rotated
        by dthx/dthy arcminutes. Coefficients are evaluated on a 20x20 grid
        of field angles within 1.1 degrees of the axis. Grid points outside
        that radius are left blank (NaN).
        """
        telescope = (fiducial_telescope
                     .withGloballyShiftedOptic(optic, batoid.Vec3(dx*1e-3, dy*1e-3, dz*1e-6))
                     .withLocallyRotatedOptic(
                             optic,
                             batoid.RotX(dthx*np.pi/180/60)*batoid.RotY(dthy*np.pi/180/60)
                     )
        )

        telescope.dist = 20.0
        telescope.pupil_size = 4.1
        telescope.sphereRadius = 10.5

        field_radius = 1.1  # deg; matches the field-angle slider range
        ngrid = 20
        thetas = np.linspace(-field_radius, field_radius, ngrid)

        # Row index <-> theta_y and column index <-> theta_x, so that imshow
        # with origin='lower' puts theta_x on the horizontal axis. The old
        # img[ix, iy] with the default origin put theta_x on the vertical
        # axis, upside down.
        img = np.full((ngrid, ngrid), np.nan)
        for ix, thx in enumerate(thetas):
            for iy, thy in enumerate(thetas):
                if np.hypot(thx, thy) > field_radius:
                    continue  # outside the field: leave NaN, drawn blank
                z = batoid.zernike(telescope, wavelen*1e-9, thx, thy, jmax=22, eps=0.1)
                img[iy, ix] = z[zindex]

        # The extent spans pixel edges: grid centers +/- half a step. It was
        # previously +/-1.7 deg, which does not match the +/-1.1 deg grid.
        edge = field_radius + (thetas[1] - thetas[0])/2
        fig, ax = plt.subplots()
        image = ax.imshow(img, vmin=-0.5, vmax=0.5, cmap='Spectral_r', origin='lower',
                          extent=np.r_[-1, 1, -1, 1]*edge)
        ax.set_title("$Z_{{{}}}$".format(zindex))
        ax.set_xlabel(r"$\theta_x$ (deg)")
        ax.set_ylabel(r"$\theta_y$ (deg)")
        fig.colorbar(image, ax=ax)
        plt.show()