In [None]:
import batoid
import os
import yaml
import numpy as np
from IPython.display import display
from ipywidgets import interact, interactive_output
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# galsim is optional: the Zernike cells further down check `has_galsim` and
# are skipped when it cannot be imported.  Only a failed import is treated as
# "not available"; any other error during import is allowed to surface.
try:
    import galsim
except ImportError:
    has_galsim = False
else:
    has_galsim = True

In [None]:
# Load the fiducial Subaru HSC optical design shipped with batoid.
HSC_fn = os.path.join(batoid.datadir, "HSC", "HSC.yaml")
# yaml.load() without an explicit Loader is unsafe and is a TypeError on
# PyYAML >= 6; safe_load is enough for a plain configuration mapping.  The
# `with` block also closes the file handle.
with open(HSC_fn) as config_file:
    config = yaml.safe_load(config_file)
fiducial_telescope = batoid.parse.parse_optic(config['opticalSystem'])

In [None]:
def spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax):
    """Draw a spot diagram (ray landing positions) for one field angle on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to trace.
    wavelength : float
        Wavelength in nm; converted to meters before tracing.
    theta_x, theta_y : float
        Field angle in degrees.
    logscale : float
        Sets the plot half-width to ``1.5 * 10**logscale`` microns.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    dirCos = batoid.utils.gnomicToDirCos(np.deg2rad(theta_x), np.deg2rad(theta_y))
    # Rays fill the annular pupil: outer radius pupilSize/2, inner radius
    # scaled by pupilObscuration.  The z direction cosine is negated,
    # presumably so rays travel towards -z into the telescope (confirm with
    # batoid.circularGrid); 48 and 192 are the grid sampling counts.
    rays = batoid.circularGrid(
        telescope.dist, telescope.pupilSize/2, telescope.pupilSize/2*telescope.pupilObscuration,
        dirCos[0], dirCos[1], -dirCos[2],
        48, 192, wavelength*1e-9, 1.0, telescope.inMedium
    )

    telescope.traceInPlace(rays)
    rays.trimVignettedInPlace()
    # Center the spots on their centroid, then convert meters -> microns.
    spots = np.vstack([rays.x, rays.y])
    spots -= np.mean(spots, axis=1)[:, None]
    spots *= 1e6

    halfwidth = 1.5*10**logscale
    ax.scatter(spots[0], spots[1], s=1, alpha=0.5)
    ax.set_xlim(-halfwidth, halfwidth)
    ax.set_ylim(-halfwidth, halfwidth)
    # Comma followed by a thin space between the two angles.
    ax.set_title(r"$\theta_x = {:4.2f},\,\theta_y = {:4.2f}$".format(theta_x, theta_y))
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")

In [None]:
def wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Draw the wavefront map across the pupil for one field angle on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to evaluate.
    wavelength : float
        Wavelength in nm; converted to meters for batoid.
    theta_x, theta_y : float
        Field angle in degrees.
    ax : matplotlib.axes.Axes
        Axes to draw on; a colorbar is attached to it.
    """
    nx = 128  # pupil sampling passed to batoid.wavefront
    wf = batoid.wavefront(
        telescope, np.deg2rad(theta_x), np.deg2rad(theta_y), wavelength*1e-9,
        nx=nx
    )
    # The image spans the full pupil diameter, centred on the axis.
    wfplot = ax.imshow(
        wf.array,
        extent=np.r_[-1, 1, -1, 1]*telescope.pupilSize/2
    )
    # Titled like the PSF panels so every panel in the grid is labelled.
    ax.set_title("Wavefront")
    ax.set_xlabel("meters")
    ax.set_ylabel("meters")
    plt.colorbar(wfplot, ax=ax)

In [None]:
def fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Render the FFT-based PSF for a single field angle into ``ax``.

    ``wavelength`` is in nm and ``theta_x``/``theta_y`` are in degrees; they
    are converted to meters and radians before calling ``batoid.fftPSF``.
    The image is normalised to unit sum, drawn with a micron extent, and
    given a colorbar attached to ``ax``.
    """
    psf = batoid.fftPSF(
        telescope, np.deg2rad(theta_x), np.deg2rad(theta_y), wavelength*1e-9, nx=32
    )
    # The primitive vectors should be (nearly) multiples of [1,0] and [0,1].
    # A negative multiplier would make the PSF look upside-down compared with
    # the spot diagram, so flip both axes in that case.
    if psf.primitiveVectors[0, 0] < 0:
        psf.array = psf.array[::-1, ::-1]

    # Pixel size from the area of the primitive-vector cell; the plotted
    # half-extent is half the image width, converted to microns.
    pixel_size = np.sqrt(np.abs(np.linalg.det(psf.primitiveVectors)))
    half_extent = pixel_size*psf.array.shape[0]/2*1e6
    psf.array /= np.sum(psf.array)

    image = ax.imshow(psf.array, extent=np.r_[-1, 1, -1, 1]*half_extent)
    ax.set_title("FFT PSF")
    ax.set_xlabel("micron")
    ax.set_ylabel("micron")
    plt.colorbar(image, ax=ax)

In [None]:
def huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax):
    """Draw the Huygens-integration PSF for one field angle on ``ax``.

    Parameters
    ----------
    telescope : batoid optic
        Optical system to evaluate.
    wavelength : float
        Wavelength in nm; converted to meters for batoid.
    theta_x, theta_y : float
        Field angle in degrees.
    ax : matplotlib.axes.Axes
        Axes to draw on; a colorbar is attached to it.
    """
    nx = 32
    huygensPSF = batoid.huygensPSF(telescope, np.deg2rad(theta_x), np.deg2rad(theta_y),
                                   wavelength*1e-9, nx=nx)
    # We should be very close to primitive vectors that are a multiple of
    # [1,0] and [0,1].  If the multiplier is negative though, then this will
    # make it look like our PSF is upside-down.  So we check for this here and
    # invert if necessary.  This will make it easier to compare with the spot
    # diagram, for instance
    if huygensPSF.primitiveVectors[0, 0] < 0:
        huygensPSF.array = huygensPSF.array[::-1, ::-1]

    # Normalise to unit sum; pixel size is sqrt of the primitive-cell area.
    huygensPSF.array /= np.sum(huygensPSF.array)
    scale = np.sqrt(np.abs(np.linalg.det(huygensPSF.primitiveVectors)))
    nxout = huygensPSF.array.shape[0]

    # Draw on the axes we were given, not pyplot's "current" axes, so the
    # panel ends up in the right place regardless of plotting order.
    huygensplot = ax.imshow(
        huygensPSF.array,
        extent=np.r_[-1, 1, -1, 1]*scale*nxout/2*1e6
    )
    ax.set_title("Huygens PSF")
    ax.set_xlabel("micron")
    ax.set_ylabel("micron")
    plt.colorbar(huygensplot, ax=ax)

In [None]:
# Widget groups for the interactive panel: which plots to draw (`what`), the
# observation setup (`where`), and the rigid-body perturbation applied to one
# optic (`perturb`).  LaTeX descriptions containing \lambda, \mu or \phi are
# raw strings so the backslashes are not parsed as invalid escape sequences.
what = dict(
    do_spot = widgets.Checkbox(value=True, description='Spot'),
    do_wavefront = widgets.Checkbox(value=True, description='Wavefront'),
    do_fftPSF = widgets.Checkbox(value=True, description='FFT PSF'),
    do_huygensPSF = widgets.Checkbox(value=True, description='Huygens PSF')
)
where = dict(
    wavelength=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0, description=r"$\lambda$ (nm)"),
    theta_x=widgets.FloatSlider(min=-0.75,max=0.75,step=0.15,value=-0.45, description="$\\theta_x (deg)$"),
    theta_y=widgets.FloatSlider(min=-0.75,max=0.75,step=0.15,value=0.0, description="$\\theta_y (deg)$"),
    logscale=widgets.FloatSlider(min=1, max=3, step=0.1, value=1, description="scale")
)
perturb = dict(
    optic=widgets.Dropdown(
        # A concrete list rather than a live dict_keys view.
        options=list(fiducial_telescope.itemDict.keys()),
        value='SubaruHSC.HSC'
    ),
    dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dx ($mm$)"),
    dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0, description="dy ($mm$)"),
    dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0, description=r"dz ($\mu m$)"),
    dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_x$ (arcmin)"),
    dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0, description=r"d$\phi_y$ (arcmin)"),
)

def f(do_spot, do_wavefront, do_fftPSF, do_huygensPSF,
    wavelength, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy, logscale, **kwargs):
    """Interactive callback: perturb one optic and draw the selected panels.

    The optic named ``optic`` is shifted globally by ``dx``/``dy`` (mm) and
    ``dz`` (microns), then rotated locally by RotX(dthx)·RotY(dthy) with the
    angles given in arcminutes.  Each enabled ``do_*`` flag adds one panel,
    left to right, to a single-row figure.  Extra keyword arguments are
    accepted and ignored.
    """
    rotation = batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
    telescope = fiducial_telescope.withGloballyShiftedOptic(
        optic, [dx*1e-3, dy*1e-3, dz*1e-6]
    ).withLocallyRotatedOptic(optic, rotation)

    # (enabled?, how to draw it) for each panel, in display order.
    panels = [
        (do_spot, lambda ax: spotPlot(telescope, wavelength, theta_x, theta_y, logscale, ax)),
        (do_wavefront, lambda ax: wavefrontPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_fftPSF, lambda ax: fftPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
        (do_huygensPSF, lambda ax: huygensPSFPlot(telescope, wavelength, theta_x, theta_y, ax)),
    ]
    nplot = sum(flag for flag, _ in panels)
    if not nplot > 0:
        return

    fig, axes = plt.subplots(ncols=nplot, figsize=(4*nplot, 4), squeeze=False)
    enabled = (draw for flag, draw in panels if flag)
    for ax, draw in zip(axes.ravel(), enabled):
        draw(ax)
    fig.tight_layout()

# Merge the three widget groups into one name -> widget mapping for f.
all_widgets = dict(what)
for group in (where, perturb):
    all_widgets.update(group)

output = interactive_output(f, all_widgets)
# One column of controls per group, side by side, with the plots underneath.
controls = widgets.HBox([widgets.VBox(list(group.values())) for group in (what, where, perturb)])
display(widgets.VBox([controls, output]))

In [None]:
if has_galsim:
    # Interactive table of Zernike coefficients at a single field point.
    # LaTeX descriptions containing \lambda, \mu or \phi are raw strings so
    # the backslashes are not parsed as invalid escape sequences.
    @interact(wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                          description=r"$\lambda$ (nm)"),
              theta_x=widgets.FloatSlider(min=-0.75,max=0.75,step=0.15,value=-0.45,
                                          description="$\\theta_x (deg)$"),
              theta_y=widgets.FloatSlider(min=-0.75,max=0.75,step=0.15,value=0.0,
                                          description="$\\theta_y (deg)$"),
              optic=widgets.Dropdown(
                  options=list(fiducial_telescope.itemDict.keys()),
                  value='SubaruHSC.HSC'
              ),
              dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dx ($mm$)"),
              dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dy ($mm$)"),
              dz=widgets.FloatSlider(min=-100, max=100, step=1, value=0.0,
                                     description=r"dz ($\mu m$)"),
              dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_x$ (arcmin)"),
              dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_y$ (arcmin)"))
    def zernike(wavelen, theta_x, theta_y, optic, dx, dy, dz, dthx, dthy):
        """Print the Zernike coefficients of the perturbed telescope in two columns.

        ``wavelen`` is in nm, ``theta_x``/``theta_y`` in degrees, ``dx``/``dy``
        in mm, ``dz`` in microns and ``dthx``/``dthy`` in arcminutes; all are
        converted to meters/radians before calling batoid.
        """
        telescope = (fiducial_telescope
                     .withGloballyShiftedOptic(optic, [dx*1e-3, dy*1e-3, dz*1e-6])
                     .withLocallyRotatedOptic(
                             optic,
                             batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
                     )
        )
        # NOTE(review): eps=0.1 differs from the eps=0.231 used for the
        # field-of-view maps below; confirm which annulus ratio is intended.
        z = batoid.zernike(telescope, np.deg2rad(theta_x), np.deg2rad(theta_y), wavelen*1e-9,
                           jmax=22, eps=0.1, nx=128)
        # Print indices 1..len(z)-1 as two side-by-side columns (index 0 is
        # not printed).  The row count follows len(z) instead of a fixed
        # offset, so changing jmax cannot index past the end; with an odd
        # number of coefficients the last row has only one column.
        nrows = len(z)//2
        for i in range(1, nrows+1):
            j = i + nrows
            if j < len(z):
                print("{:6d}   {:7.3f}      {:6d}  {:7.3f}".format(i, z[i], j, z[j]))
            else:
                print("{:6d}   {:7.3f}".format(i, z[i]))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.pyplot import subplots

def Zpyramid(xs, ys, zs, figsize=(13, 8), vmin=-1, vmax=1, vdim=True,
             s=5, title=None, filename=None, **kwargs):
    """Plot Zernike coefficients over the field as a pyramid of scatter panels.

    Each panel shows one Noll index ``j`` (starting at Z4), placed in row
    ``n - 2`` by radial order ``n`` and in a column set by azimuthal order
    ``m``.  After layout, rows of one parity are nudged by half a panel so the
    rows interleave into a pyramid.

    Parameters
    ----------
    xs, ys : array_like
        Scatter coordinates of each sample (field positions).
    zs : ndarray
        Coefficients, with ``zs[j-4]`` holding Noll index ``j`` for every
        sample.  Needs at least two rows, because panels Z4 and Z5 are used
        to measure the column spacing.
    figsize : tuple
        Passed to ``plt.figure``.
    vmin, vmax : float
        Colour limits.  If ``vdim`` is true they are divided by the radial
        order ``n`` of each panel.
    s : float
        Scatter marker size.
    title : str, optional
        Figure suptitle.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs
        Forwarded to ``plt.figure``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure.  Displaying it is left to the caller (e.g. ``plt.show()``).
    """
    jmax = zs.shape[0]+3
    nmax, _ = galsim.zernike.noll_to_zern(jmax)

    # One row per radial order n = 2..nmax; two spare columns leave room for
    # the half-panel stagger applied after layout.
    nrow = nmax - 1
    ncol = nrow + 2
    gridspec = GridSpec(nrow, ncol)

    def shift(pos, amt):
        # Bounding box `pos` moved horizontally by `amt` (figure fraction).
        return [pos.x0+amt, pos.y0, pos.width, pos.height]

    def shiftAxes(axes, amt):
        for ax in axes:
            ax.set_position(shift(ax.get_position(), amt))

    fig = plt.figure(figsize=figsize, **kwargs)
    axes = {}
    shiftLeft = []
    shiftRight = []
    for j in range(4, jmax+1):
        n, m = galsim.zernike.noll_to_zern(j)
        # Row from n; column from m, centred on ncol//2.
        if n%2 == 0:
            row, col = n-2, m//2 + ncol//2
        else:
            row, col = n-2, (m-1)//2 + ncol//2
        subplotspec = gridspec.new_subplotspec((row, col))
        axes[j] = fig.add_subplot(subplotspec)
        axes[j].set_aspect('equal')
        # Choose which parity of rows gets nudged, depending on nrow's parity.
        if nrow%2==0 and n%2==0:
            shiftLeft.append(axes[j])
        if nrow%2==1 and n%2==1:
            shiftRight.append(axes[j])

    cbar = {}
    for j, ax in axes.items():
        n, _ = galsim.zernike.noll_to_zern(j)
        ax.set_title("Z{}".format(j))
        if vdim:
            _vmin = vmin/n
            _vmax = vmax/n
        else:
            _vmin = vmin
            _vmax = vmax
        scat = ax.scatter(xs, ys, c=zs[j-4], s=s, linewidths=0.5, cmap='Spectral_r',
                          rasterized=True, vmin=_vmin, vmax=_vmax)
        cbar[j] = fig.colorbar(scat, ax=ax)
        ax.set_xticks([])
        ax.set_yticks([])

    if title:
        fig.suptitle(title, x=0.1)

    fig.tight_layout()
    # Half the horizontal distance between neighbouring panels Z4 and Z5.
    amt = 0.5*(axes[4].get_position().x0 - axes[5].get_position().x0)
    shiftAxes(shiftLeft, -amt)
    shiftAxes(shiftRight, amt)

    # Colorbars follow their panels.
    shiftAxes([cbar[j].ax for j in cbar.keys() if axes[j] in shiftLeft], -amt)
    shiftAxes([cbar[j].ax for j in cbar.keys() if axes[j] in shiftRight], amt)

    if filename:
        fig.savefig(filename)

    # Return the figure rather than calling fig.show(): on a non-GUI backend
    # such as the notebook's inline backend, Figure.show() only emits a
    # warning.  The caller decides how to display it.
    return fig

In [None]:
if has_galsim:
    # Interactive field-of-view maps of Zernike coefficients Z4..Z21.
    # LaTeX descriptions containing \lambda, \mu or \phi are raw strings so
    # the backslashes are not parsed as invalid escape sequences.
    @interact(wavelen=widgets.FloatSlider(min=300.0,max=1100.0,step=25.0,value=625.0,
                                          description=r"$\lambda$ (nm)"),
              optic=widgets.Dropdown(
                  options=list(fiducial_telescope.itemDict.keys()),
                  value='SubaruHSC.HSC'
              ),
              dx=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dx ($mm$)"),
              dy=widgets.FloatSlider(min=-0.2, max=0.2, step=0.05, value=0.0,
                                     description="dy ($mm$)"),
              dz=widgets.FloatSlider(min=-500, max=500, step=10, value=0.0,
                                     description=r"dz ($\mu m$)"),
              dthx=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_x$ (arcmin)"),
              dthy=widgets.FloatSlider(min=-1, max=1, step=0.1, value=0.0,
                                       description=r"d$\phi_y$ (arcmin)"),
              do_resid=widgets.Checkbox(value=False, description="residual?"))
    def zFoV(wavelen, optic, dx, dy, dz, dthx, dthy, do_resid):
        """Map Zernike coefficients across the field for the perturbed telescope.

        Coefficients are evaluated on a 15x15 grid of field angles spanning
        +/-0.75 deg, keeping only points within 0.74 deg of the centre.  With
        ``do_resid`` the fiducial telescope's coefficients are subtracted and
        a tighter colour range is used.  Units: ``wavelen`` nm, ``dx``/``dy``
        mm, ``dz`` microns, ``dthx``/``dthy`` arcminutes.
        """
        telescope = (fiducial_telescope
                .withGloballyShiftedOptic(optic, [dx*1e-3, dy*1e-3, dz*1e-6])
                .withLocallyRotatedOptic(
                        optic,
                        batoid.RotX(dthx*np.pi/180/60).dot(batoid.RotY(dthy*np.pi/180/60))
                )
        )

        # Residuals relative to the fiducial design get a narrower colour range.
        if do_resid:
            vmin, vmax = -0.05, 0.05
        else:
            vmin, vmax = -0.3, 0.3

        field_angles = np.linspace(-0.75, 0.75, 15)  # degrees, used for x and y
        zs = []
        thxplot = []
        thyplot = []
        for thx in field_angles:
            for thy in field_angles:
                # Skip the corners of the square grid (outside a 0.74 deg radius).
                if np.hypot(thx, thy) > 0.74:
                    continue
                thx_rad = np.deg2rad(thx)
                thy_rad = np.deg2rad(thy)
                z = batoid.zernike(telescope, thx_rad, thy_rad, wavelen*1e-9,
                                   jmax=21, eps=0.231, nx=16)
                if do_resid:
                    z -= batoid.zernike(fiducial_telescope, thx_rad, thy_rad, wavelen*1e-9,
                                        jmax=21, eps=0.231, nx=16)
                thxplot.append(thx)
                thyplot.append(thy)
                zs.append(z)

        # Rows = coefficient index, columns = field point.  Drop rows 0..3 so
        # the first row passed on is Z4, as Zpyramid expects.
        zs = np.array(zs).T
        Zpyramid(np.array(thxplot), np.array(thyplot), zs[4:], vmin=vmin, vmax=vmax)
        plt.show()