In [None]:
import matplotlib.pyplot as plt
import ipywidgets
import astropy.io.fits as fits
import numpy as np
import galsim

In [None]:
def get_zk(opd, jmax=22, R_inner=0.61):
    """Least-squares fit of annular Zernike coefficients to an OPD map.

    Parameters
    ----------
    opd : 2-D numpy masked array (a plain ndarray is treated as fully unmasked)
        Masked pixels are excluded from the fit. Sample coordinates run
        linearly from -1 to 1 along each axis; the column index maps to x and
        the row index maps to y.
    jmax : int, optional
        Highest Noll index included in the basis (default 22).
    R_inner : float, optional
        Inner radius of the annular Zernike basis, relative to an outer
        radius of 1 (default 0.61).

    Returns
    -------
    zk : ndarray, shape (jmax + 1,)
        Fitted coefficients indexed by Noll j, following galsim's
        ``zernikeBasis`` convention of returning ``jmax + 1`` basis rows.
    """
    ny, nx = opd.shape
    # meshgrid's default 'xy' indexing returns arrays of shape (len(y), len(x)),
    # so x must be sampled along the columns (nx) and y along the rows (ny) for
    # the grid to line up with the OPD and its mask on non-square maps.
    xs, ys = np.meshgrid(np.linspace(-1, 1, nx), np.linspace(-1, 1, ny))
    # getmaskarray always yields a full boolean array (also for nomask / plain
    # ndarrays), so boolean indexing selects pixels instead of adding an axis.
    w = ~np.ma.getmaskarray(opd)
    basis = galsim.zernike.zernikeBasis(jmax, xs[w], ys[w], R_inner=R_inner)
    zk, *_ = np.linalg.lstsq(basis.T, np.ma.getdata(opd)[w], rcond=None)
    return zk

In [None]:
def sub_ptt(opd):
    """Subtract the fitted Noll j = 1..3 terms (piston, tip, tilt) from an OPD map.

    The coefficients come from ``get_zk(opd)``, i.e. a joint fit of all Noll
    terms; only the first four entries are then evaluated and removed.

    Parameters
    ----------
    opd : 2-D numpy masked array
        Modified IN PLACE. Masked pixels are left untouched by the masked-array
        in-place subtraction.

    Returns
    -------
    The same ``opd`` object, after subtraction.
    """
    ny, nx = opd.shape
    # Column index -> x, row index -> y, so the grid matches opd's shape even
    # for non-square maps (meshgrid 'xy' indexing returns (len(y), len(x))).
    xs, ys = np.meshgrid(np.linspace(-1, 1, nx), np.linspace(-1, 1, ny))
    zk = get_zk(opd)
    # galsim's Zernike ignores coef[0], so zk[:4] carries Noll j = 1, 2, 3.
    # R_inner must match the value used for the fit in get_zk (0.61).
    opd -= galsim.zernike.Zernike(zk[:4], R_inner=0.61)(xs, ys)
    return opd

In [None]:
def getdata(imode, ifield):
    """Load nominal-subtracted, PTT-removed mode OPDs from the three simulators.

    For each simulator directory (wfsim, wfsim_simple, ts_phosim), this reads
    ``<dir>/opd/opd_mode_{imode}_field_{ifield}.fits.gz``, masks pixels that are
    exactly zero, and subtracts ``<dir>/opd/opd_nominal_field_{ifield}.fits.gz``.
    It then removes the fitted Noll j < 4 terms with ``sub_ptt``.

    Parameters
    ----------
    imode : int
        Mode index used in the file name.
    ifield : int
        Field index used in the file names.

    Returns
    -------
    tuple of three 2-D masked arrays
        ``(wfsim_opd, wfsim_simple_opd, ts_phosim_opd)``.
    """
    return tuple(
        _load_opd(sim_dir, imode, ifield)
        for sim_dir in ("wfsim_opds", "wfsim_simple_opds", "ts_phosim_opds")
    )


def _load_opd(sim_dir, imode, ifield):
    """Read one simulator's mode OPD, mask zeros, subtract its nominal OPD and remove PTT."""
    opd = fits.getdata(f"{sim_dir}/opd/opd_mode_{imode}_field_{ifield}.fits.gz")
    # Exactly-zero pixels are treated as outside the pupil. The mask is built
    # from the raw mode OPD, before the nominal is subtracted.
    opd = np.ma.masked_array(opd, mask=(opd == 0))
    opd -= fits.getdata(f"{sim_dir}/opd/opd_nominal_field_{ifield}.fits.gz")
    return sub_ptt(opd)

In [None]:
def _rss(values):
    """Return the root-sum-square, sqrt(sum(values**2)), of an array."""
    return np.sqrt(np.sum(np.square(values)))


@ipywidgets.interact(
    imode=ipywidgets.BoundedIntText(value=0, min=0, max=49),
    ifield=ipywidgets.BoundedIntText(value=0, min=0, max=34),
)
def f(imode, ifield):
    """Compare the three simulators' OPDs for one (mode, field) against sensM.

    Figure, top row: the OPDs returned by ``getdata`` for wfsim, wfsim_simple
    and ts_phosim. They share a symmetric colour scale set by ts_phosim's
    maximum absolute value.

    Figure, bottom row: each OPD minus the OPD synthesised from the sensM
    coefficients for this (field, mode). The colour scale is 1% of the top
    row's.

    Printed output:
    - A table over Noll j = 4..22 (values scaled by 1e3) of the sensM
      coefficient, each simulator's fitted coefficient, and each
      simulator's difference from sensM.
    - An RSS summary row.
    - A fractional residual row, RSS(fit - sensM) / RSS(sensM).
    """
    wfsim_opd, wfsim_simple_opd, ts_phosim_opd = getdata(imode, ifield)
    sim_names = ("wfsim", "wfsim_simple", "ts_phosim")
    sim_opds = (wfsim_opd, wfsim_simple_opd, ts_phosim_opd)

    # sensM is indexed [ifield, :, imode]. The middle axis is read as the
    # coefficients of Noll j = 4..22 (see the j-4 indexing in the table below).
    # NOTE(review): confirm this layout against how sensM.fits is written.
    sensM = fits.getdata("sensM.fits")
    sensM_zk = sensM[ifield, :, imode]

    # OPD predicted by sensM, evaluated on the unmasked pixels of ts_phosim.
    # Column index -> x, row index -> y, so the grid matches non-square maps too.
    Z = galsim.zernike.Zernike([0]*4 + sensM_zk.tolist(), R_inner=0.61)
    ny, nx = ts_phosim_opd.shape
    xs, ys = np.meshgrid(np.linspace(-1, 1, nx), np.linspace(-1, 1, ny))
    w = ~np.ma.getmaskarray(ts_phosim_opd)
    # Build the masked array explicitly rather than relying on zeros_like
    # to carry the mask over.
    sensM_opd = np.ma.masked_array(
        np.zeros_like(np.ma.getdata(ts_phosim_opd)), mask=~w
    )
    sensM_opd[w] = Z(xs[w], ys[w])

    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))
    vmax = np.max(np.abs(ts_phosim_opd))
    resid_vmax = 0.01 * vmax
    for col, (name, sim_opd) in enumerate(zip(sim_names, sim_opds)):
        top = axes[0, col].imshow(sim_opd, vmin=-vmax, vmax=vmax, cmap='seismic')
        fig.colorbar(top, ax=axes[0, col])
        axes[0, col].set_title(name)
        bottom = axes[1, col].imshow(
            sim_opd - sensM_opd, vmin=-resid_vmax, vmax=resid_vmax, cmap='seismic'
        )
        fig.colorbar(bottom, ax=axes[1, col])
        axes[1, col].set_title(f"{name} - sensM")
    fig.tight_layout()
    plt.show()

    wfsim_zk = get_zk(wfsim_opd)
    wfsim_simple_zk = get_zk(wfsim_simple_opd)
    ts_phosim_zk = get_zk(ts_phosim_opd)

    print(" j        sensM          full        simple     ts_phosim      d(full)     d(simple)    d(ts_phosim)")
    print("="*100)
    for j in range(4, 23):
        ref = sensM_zk[j-4]
        out = f"{j:2d} {ref*1e3:12.4f}  "
        out += f"{wfsim_zk[j]*1e3:12.4f}  {wfsim_simple_zk[j]*1e3:12.4f}  {ts_phosim_zk[j]*1e3:12.4f}"
        out += f" {(wfsim_zk[j]-ref)*1e3:12.4f}  {(wfsim_simple_zk[j]-ref)*1e3:12.4f}"
        out += f"    {(ts_phosim_zk[j]-ref)*1e3:12.4f}"
        print(out)
    print("="*100)

    # Summary rows use Noll j >= 4 only; j < 4 was removed by sub_ptt in getdata.
    fitted_zk = [sim_zk[4:] for sim_zk in (wfsim_zk, wfsim_simple_zk, ts_phosim_zk)]
    resid_zk = [sim_zk - sensM_zk for sim_zk in fitted_zk]

    out = "rss"
    out += f"{_rss(sensM_zk*1e3):12.4f}"
    out += f"  {_rss(fitted_zk[0]*1e3):12.4f}"
    out += f"  {_rss(fitted_zk[1]*1e3):12.4f}"
    out += f"  {_rss(fitted_zk[2]*1e3):12.4f}"
    out += f" {_rss(resid_zk[0]*1e3):12.4f}"
    out += f"  {_rss(resid_zk[1]*1e3):12.4f}"
    out += f"    {_rss(resid_zk[2]*1e3):12.4f}"
    print(out)

    sensM_rss = _rss(sensM_zk)
    out = "rssr/rss"
    out += " "*50
    out += f"{_rss(resid_zk[0])/sensM_rss:12.4f}"
    out += f"  {_rss(resid_zk[1])/sensM_rss:12.4f}"
    out += f"    {_rss(resid_zk[2])/sensM_rss:12.4f}"
    print(out)