In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import astropy.io.fits as fits
import batoid
import galsim
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import yaml

import batoid_rubin

%matplotlib widget

In [None]:
def colorbar(mappable):
    """Attach a colorbar in a slim axes appended to the right of ``mappable``'s axes.

    The new axes is carved out of the parent axes with ``make_axes_locatable``
    (5% of its width, padding 0.05). Pyplot's current axes is saved first and
    restored afterwards, so later ``plt.*`` calls still target whatever axes
    the caller had active.

    Returns the created colorbar.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    saved_axes = plt.gca()
    parent = mappable.axes
    bar_axes = make_axes_locatable(parent).append_axes(
        "right", size="5%", pad=0.05
    )
    cb = parent.figure.colorbar(mappable, cax=bar_axes)
    # Creating the colorbar axes may change pyplot's current axes; put it back.
    plt.sca(saved_axes)
    return cb

In [None]:
def ptt(arr, eps=0.0):
    """Subtract the least-squares piston/tip/tilt fit from a 2-D wavefront map.

    The map is placed on the square [-1, 1] x [-1, 1]: columns run along x,
    rows along y. Annular Zernike terms up to Noll index 3 are fit over the
    unmasked pixels only, and the fitted surface is subtracted everywhere.

    Parameters
    ----------
    arr : 2-D numpy array or masked array
        Wavefront map. Masked pixels are excluded from the fit. A plain
        ndarray (or a masked array with no mask) is treated as fully valid.
    eps : float
        Inner radius of the annular Zernike basis (the outer radius is 1).

    Returns
    -------
    Array of the same shape as ``arr`` with the fitted terms removed. Masked
    input yields masked output.
    """
    nrow, ncol = arr.shape
    # x follows the columns and y the rows, so both grids have arr's
    # (nrow, ncol) shape -- also for non-square maps.
    x, y = np.meshgrid(
        np.linspace(-1.0, 1.0, ncol),
        np.linspace(-1.0, 1.0, nrow),
    )
    # getmaskarray always returns a full boolean array, even for `nomask`
    # or plain ndarrays, so boolean indexing below is well-defined.
    good = ~np.ma.getmaskarray(arr)

    basis = galsim.zernike.zernikeBasis(
        3, x[good], y[good], R_inner=eps, R_outer=1.0
    )
    # Fit on the plain data of the valid pixels.
    coefs, *_ = np.linalg.lstsq(basis.T, np.ma.getdata(arr)[good], rcond=None)
    fit = galsim.zernike.Zernike(coefs[:4], R_inner=eps, R_outer=1.0)
    return arr - fit(x, y)

In [None]:
# Field angles in degrees (converted with np.deg2rad in the comparison widget),
# indexed by the widget's `ifield`.
with open("fieldXY.yaml") as f:
    data = yaml.safe_load(f)
field_x, field_y = (np.array(data[axis]) for axis in ("x", "y"))

In [None]:
# Monochromatic wavelength in meters (OPDs in waves are later converted to
# microns via `wavelength * 1e6`), and the unperturbed reference telescope
# that every perturbed OPD is differenced against.
wavelength = 500e-9
fiducial = batoid.Optic.fromYaml("LSST_g_500.yaml")

In [None]:
# Directories holding batoid_rubin's FEA and legacy bending-mode data. They
# default to the original author's checkout; set the environment variables to
# point elsewhere instead of editing the notebook.
FEA_DIR = os.environ.get(
    "BATOID_RUBIN_FEA_DIR",
    "/Users/josh/src/batoid_rubin/scripts/fea/",
)
BEND_DIR = os.environ.get(
    "BATOID_RUBIN_BEND_DIR",
    "/Users/josh/src/batoid_rubin/scripts/bend_legacy/",
)
# Inner radius passed to `ptt` for the annular piston/tip/tilt fit
# (presumably the pupil's central-obscuration ratio -- TODO confirm).
PTT_EPS = 0.61
# Pupil grid size (nx) for the batoid wavefronts.
OPD_NX = 255


# NOTE: helpers are defined before the decorated function because
# `ipywidgets.interact` invokes `f` immediately with the default widget values.
def _load_phosim_opd(name, ifield):
    """Load the ts_phosim OPD for scenario `name` at field `ifield`, minus nominal.

    Returns a masked array. Pixels where the perturbed OPD is exactly 0.0 are
    masked (presumably these lie outside the pupil).
    """
    perturbed = fits.getdata(f"phosim/{name}/opd_{name}_{ifield}.fits.gz")
    nominal = fits.getdata(f"../nominal/phosim/opd_nominal_field_{ifield}.fits.gz")
    return np.ma.masked_array(perturbed - nominal, mask=(perturbed == 0.0))


def _build_telescope(subsys, zenith_angle, rotation_angle, doT):
    """Apply the requested subsystem perturbations to `fiducial` and build it.

    Angles are in degrees. M1M3 and M2 thermal perturbations are applied only
    when `doT` is true. The camera bulk temperature is applied whenever the
    camera is included.
    """
    zenith = np.deg2rad(zenith_angle)
    builder = batoid_rubin.LSSTBuilder(fiducial, fea_dir=FEA_DIR, bend_dir=BEND_DIR)
    if subsys in ("M1M3", "All"):
        builder = builder.with_m1m3_gravity(zenith).with_m1m3_lut(zenith)
        if doT:
            builder = builder.with_m1m3_temperature(
                m1m3_TBulk=0.0902,
                m1m3_TxGrad=-0.0894,
                m1m3_TyGrad=-0.1973,
                m1m3_TzGrad=-0.0316,
                m1m3_TrGrad=0.0187,
            )
    if subsys in ("M2", "All"):
        builder = builder.with_m2_gravity(zenith)
        if doT:
            builder = builder.with_m2_temperature(
                m2_TzGrad=-0.0675,
                m2_TrGrad=-0.1416,
            )
    if subsys in ("Cam", "All"):
        # NOTE(review): the camera temperature ignores `doT`; presumably
        # 6.5650 matches the default used for the phosim runs -- confirm.
        builder = (
            builder
            .with_camera_gravity(zenith, np.deg2rad(rotation_angle))
            .with_camera_temperature(camera_TBulk=6.5650)
        )
    return builder.build()


def _compute_batoid_opd(telescope, ifield):
    """Return the batoid OPD of `telescope` relative to `fiducial` at field `ifield`.

    The result is expressed in phosim's conventions as a masked array in
    microns.
    """
    # batoid -> phosim: negate the input theta_x and fliplr the output image.
    thx = -np.deg2rad(field_x[ifield])
    thy = np.deg2rad(field_y[ifield])
    opd = batoid.wavefront(telescope, thx, thy, wavelength, nx=OPD_NX).array
    opd0 = batoid.wavefront(fiducial, thx, thy, wavelength, nx=OPD_NX).array
    opd -= opd0
    # Sign flip -- presumably phosim's OPD sign convention is opposite batoid's.
    opd *= -1
    opd = np.fliplr(opd)
    # batoid in waves => microns (wavelength is in meters).
    opd *= wavelength * 1e6
    return opd


def _plot_opd_comparison(ts_phosim_opd, batoid_opd, scale, dscale):
    """Show ts_phosim, batoid and (batoid - ts_phosim) side by side.

    The color range is the 99th percentile of |batoid OPD| over unmasked
    pixels, times 2**scale. The difference panel uses 1% of that range,
    further multiplied by 2**dscale.
    """
    vmax = np.quantile(np.abs(np.ma.compressed(batoid_opd)), 0.99) * 2**scale
    fig, axes = plt.subplots(ncols=3, figsize=(8, 3), sharex=True, sharey=True)
    panels = [
        (ts_phosim_opd, vmax, "ts_phosim"),
        (batoid_opd, vmax, "batoid"),
        (batoid_opd - ts_phosim_opd, 0.01 * vmax * 2**dscale, "b - ph"),
    ]
    for ax, (image, lim, title) in zip(axes, panels):
        colorbar(ax.imshow(image, vmin=-lim, vmax=lim, cmap='seismic'))
        ax.set_title(title)
        ax.set_aspect('equal')
    fig.tight_layout()
    plt.show()
    return fig


@ipywidgets.interact(
    zr=ipywidgets.Dropdown(
        options=[
            "0, 0, False",
            "0, 0, True",
            "45, 0, False",
            "45, 45, False",
            "30, -30, False",
            "30, -30, True"
        ],
        index=0,
    ),
    subsys=ipywidgets.Dropdown(
        options=['M1M3', 'M2', 'Cam', 'All'], index=0
    ),
    # Bound the field index by the loaded field grid instead of a literal.
    ifield=ipywidgets.BoundedIntText(value=0, min=0, max=len(field_x) - 1),
    scale=ipywidgets.BoundedFloatText(value=0, min=-10, max=10),
    dscale=ipywidgets.BoundedFloatText(value=0, min=-6, max=6),
    doptt=ipywidgets.Checkbox()
)
def f(zr, subsys, ifield, scale, dscale, doptt):
    """Interactively compare ts_phosim and batoid OPDs for one perturbation scenario.

    Parameters
    ----------
    zr : str
        "zenith, rotation, doT": zenith and camera rotation angles in integer
        degrees, and whether thermal perturbations are included.
    subsys : {"M1M3", "M2", "Cam", "All"}
        Which subsystem(s) are perturbed.
    ifield : int
        Index into the field grid loaded from fieldXY.yaml.
    scale, dscale : float
        log2 adjustments of the color range for the OPD panels and for the
        difference panel respectively.
    doptt : bool
        If true, remove piston/tip/tilt from both OPDs before plotting.
    """
    zenith_str, rotation_str, doT_str = zr.replace(",", "").split()
    zenith_angle = int(zenith_str)
    rotation_angle = int(rotation_str)
    doT = doT_str == "True"
    name = f"z{zenith_angle}_r{rotation_angle}_T_{doT}_{subsys}"

    ts_phosim_opd = _load_phosim_opd(name, ifield)
    telescope = _build_telescope(subsys, zenith_angle, rotation_angle, doT)
    batoid_opd = _compute_batoid_opd(telescope, ifield)

    if doptt:
        batoid_opd = ptt(batoid_opd, PTT_EPS)
        ts_phosim_opd = ptt(ts_phosim_opd, PTT_EPS)

    _plot_opd_comparison(ts_phosim_opd, batoid_opd, scale, dscale)