In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib widget

import os
import numpy as np
import yaml
from tqdm.notebook import tqdm
import galsim
import batoid
import wfsim
import matplotlib.pyplot as plt

from lsst.ts.wep.cwfs.Algorithm import Algorithm
from lsst.ts.wep.cwfs.CompensableImage import CompensableImage
from lsst.ts.wep.cwfs.Instrument import Instrument
from lsst.ts.wep.Utility import (
    CamType,
    DefocalType,
    getConfigDir,
    getModulePath
)

In [None]:
# Fixed seed (the first 40 decimal digits of the Euler-Mascheroni constant) so
# that every draw made from `rng` below -- the DOF perturbations, the star
# temperature, simulator construction, star rendering and background -- is
# reproducible on Restart & Run All.
SEED = 5772156649015328606065120900824024310421
rng = np.random.default_rng(SEED)

In [None]:
# Detector pixel size; presumably metres (10 micron pixels), since batoid spot
# coordinates are divided by it below to express them in pixels.
pixel_scale = 10e-6

# Optical model: the LSST r-band throughput (wavelengths in nm), the matching
# batoid telescope description, and wfsim's factory for perturbed copies of it.
bandpass = galsim.Bandpass("LSST_r.dat", wave_type="nm")
fiducial_telescope = batoid.Optic.fromYaml("LSST_r.yaml")
factory = wfsim.SSTFactory(fiducial_telescope)

In [None]:
# Observing conditions handed to the simulators.  These are plausible
# placeholder values, not measured site conditions.
observation = dict(
    zenith=30 * galsim.degrees,
    raw_seeing=0.7 * galsim.arcsec,  # seeing at zenith, 500 nm
    wavelength=bandpass.effective_wavelength,
    exptime=15.0,       # seconds
    temperature=293.0,  # Kelvin
    pressure=69.0,      # kPa
    H2O_pressure=1.0,   # kPa
)

In [None]:
# Atmosphere-construction options passed to the simulators.  The screen size
# and scale are presumably in metres (TODO confirm against wfsim); `nproc` is
# the number of CPUs used to build the phase screens in parallel.
atm_kwargs = dict(screen_size=819.2, screen_scale=0.1, nproc=6)

In [None]:
# Random perturbation of all 50 telescope degrees of freedom (DOF).
# (Alternative used earlier: start from np.zeros(50) and set dof[40:44] = 0.2
# to activate only some M2 bending modes.)
N_DOF = 50
dof = rng.normal(scale=0.1, size=N_DOF)
# Zero out the hexafoil modes, which aren't currently fit well.
dof[[28, 45, 46]] = 0
telescope = factory.get_telescope(dof=dof)  # telescope perturbed by `dof`

In [None]:
# Spot diagrams of the perturbed telescope with the detector shifted by +0.0015
# (1.5 mm, the extra-focal configuration used below), at the field centre and
# four points 1.5 degrees off-axis.  Spot coordinates are divided by
# pixel_scale to display them in pixels.
# Loop variables are deliberately NOT named thx/thy: those globals hold the
# star's field angle (in radians) used by the crop and CWFS cells, and
# re-running this cell must not clobber them.
extra_focal_telescope = telescope.withGloballyShiftedOptic("Detector", [0, 0, 0.0015])
spot_field_angles_deg = [(0, 0), (-1.5, 0), (1.5, 0), (0, -1.5), (0, 1.5)]

fig, axes = plt.subplots(nrows=1, ncols=len(spot_field_angles_deg), figsize=(8, 1.5))
for ax, (thx_deg, thy_deg) in zip(axes, spot_field_angles_deg):
    sx, sy = batoid.spot(
        extra_focal_telescope,
        np.deg2rad(thx_deg), np.deg2rad(thy_deg),
        bandpass.effective_wavelength*1e-9,
        nx=128
    )
    ax.scatter(sx/pixel_scale, sy/pixel_scale, s=0.1, alpha=0.5)
    ax.set_aspect("equal")
    ax.set_title(f"({thx_deg}, {thy_deg}) deg", fontsize=8)
plt.tight_layout()
plt.show()

In [None]:
# Run this cell to discard the cached simulators, so that the next cell
# rebuilds them (and therefore reconstructs the atmosphere).
if "intra_simulator" in globals():
    del intra_simulator
    del extra_simulator

In [None]:
# Build the intra-/extra-focal simulators, or reuse existing ones.
# HACK: building a SimpleSimulator recomputes the atmosphere, so when the
# simulators already exist only their telescope is swapped for the current
# perturbed one and their images are zeroed.  Run the previous cell to force
# a full rebuild.
# NOTE(review): with reuse, the atmosphere and the rng state depend on the
# notebook's execution history, so a re-run can differ from a fresh kernel.
# (Other SimpleSimulator options tried previously: shape=(4000, 4000),
# shape=(256, 256), offset=(0.2, 0.2).)
DETECTOR_DEFOCUS = 0.0015  # detector shift for the donut pair (1.5 mm)

# Identical for both branches, so compute once.
intra = telescope.withGloballyShiftedOptic("Detector", [0, 0, -DETECTOR_DEFOCUS])
extra = telescope.withGloballyShiftedOptic("Detector", [0, 0, +DETECTOR_DEFOCUS])

if "intra_simulator" not in globals():
    intra_simulator = wfsim.SimpleSimulator(
        observation,
        atm_kwargs,
        intra,
        bandpass,
        name="R00_SW0",
        rng=rng
    )
    extra_simulator = wfsim.SimpleSimulator(
        observation,
        atm_kwargs,
        extra,
        bandpass,
        name="R00_SW0",
        rng=rng
    )
else:
    intra_simulator.telescope = intra
    extra_simulator.telescope = extra
    intra_simulator.image.setZero()
    extra_simulator.image.setZero()

In [None]:
# The star: presumably a blackbody SED (wfsim.BBSED) at a random temperature
# drawn uniformly from [4000, 10000).
star_T = rng.uniform(low=4000, high=10000)
sed = wfsim.BBSED(star_T)

# Star flux.  Values tried before: 500_000, or int(rng.uniform(1_000_000, 2_000_000)).
flux = 10**7

In [None]:
# Print the sensor's angular extent (degrees) so we know which field angles
# land on it.
bounds = intra_simulator.get_bounds(units=galsim.degrees)
for axis_label, row in (("x", 0), ("y", 1)):
    print(f"{bounds[row, 0]:.3f} < {axis_label} < {bounds[row, 1]:.3f}")

In [None]:
# Field angle of the star, converted to radians.  (Use 0.0 for both to put
# the star on-axis.)  The same star goes into both simulators so that the
# intra/extra donuts form a matching pair.
thx, thy = np.deg2rad(-1.12), np.deg2rad(-1.12)
for sim in (intra_simulator, extra_simulator):
    sim.add_star(thx, thy, sed, flux, rng)

In [None]:
# Add a background level of 1000.0 to each image (intra first, then extra,
# so that rng draws occur in the same order).
for sim in (intra_simulator, extra_simulator):
    sim.add_background(1000.0, rng)

In [None]:
# Full sensor images: intra-focal (left) and extra-focal (right).
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(6, 3), sharex=True, sharey=True)
for ax, sim, title in zip(axes, (intra_simulator, extra_simulator), ("intra", "extra")):
    ax.imshow(sim.image.array, origin="lower")
    ax.set_title(title)
plt.tight_layout()
plt.show()

In [None]:
# Crop 256x256 postage stamps centred on the star in each image to feed to CWFS.

def crop_donut(simulator, thx, thy, half_size=128):
    """Return a (2*half_size, 2*half_size) view of ``simulator.image`` centred on a star.

    Parameters
    ----------
    simulator : wfsim.SimpleSimulator
        Simulator whose image contains the star.
    thx, thy : float
        Field angle of the star in radians, as passed to ``add_star``.
    half_size : int
        Half the side length of the stamp, in pixels.

    Returns
    -------
    numpy.ndarray
        A view (not a copy) into ``simulator.image.array``, indexed [y, x].

    Raises
    ------
    ValueError
        If the stamp does not fit inside the image.  Plain numpy slicing would
        otherwise silently return a truncated, empty or wrapped-around array.
    """
    x, y = simulator.wcs.radecToxy(thx, thy, galsim.radians)
    # GalSim pixel centres sit at integer coordinates, so round (rather than
    # truncate) to get the index of the pixel containing the star.
    ix = int(round(x - simulator.image.bounds.xmin))
    iy = int(round(y - simulator.image.bounds.ymin))
    array = simulator.image.array
    ny, nx = array.shape
    if not (half_size <= ix <= nx - half_size and half_size <= iy <= ny - half_size):
        raise ValueError(
            f"Stamp of half-size {half_size} around pixel ({ix}, {iy}) "
            f"does not fit in a {nx}x{ny} image"
        )
    return array[iy-half_size:iy+half_size, ix-half_size:ix+half_size]


intra_img = crop_donut(intra_simulator, thx, thy)
extra_img = crop_donut(extra_simulator, thx, thy)

fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(6, 3), sharex=True, sharey=True)
for ax, img, title in zip(axes, (intra_img, extra_img), ("intra", "extra")):
    # origin="lower" matches the orientation of the full-image plot.
    ax.imshow(img, origin="lower")
    ax.set_title(title)
plt.tight_layout()
plt.show()

In [None]:
# CWFS setup: locate ts_wep's "cwfs" configuration directory, load the
# instrument data, and remember where the algorithm settings live.
cwfsConfigDir = os.path.join(getConfigDir(), "cwfs")
instDir, algoDir = (os.path.join(cwfsConfigDir, sub) for sub in ("instData", "algo"))
inst = Instrument(instDir)

In [None]:
# Run both CWFS solvers ("fft" and "exp") on the same donut pair.

def run_cwfs(algo_name, intra_img, extra_img, fieldXY, inst, algo_dir,
             defocal_dis_mm=1.5, tol=1e-3):
    """Estimate the wavefront from an intra/extra donut pair with CWFS.

    Fresh CompensableImage objects are built on every call.  The original code
    rebuilt them between solvers instead of looking for a reset method, which
    suggests that ``runIt`` mutates them (TODO confirm).  The donut arrays are
    copied so that the caller's stamps are never modified.

    Parameters
    ----------
    algo_name : str
        CWFS solver name passed to ``Algorithm.config`` (e.g. "fft" or "exp").
    intra_img, extra_img : numpy.ndarray
        Intra- and extra-focal donut stamps.
    fieldXY : numpy.ndarray
        Field position of the donuts in degrees, as ``[x, y]``.
    inst : Instrument
        Instrument, (re)configured here for LsstFamCam.
    algo_dir : str
        Directory holding the CWFS algorithm configuration.
    defocal_dis_mm : float
        Announced defocal distance in mm; must match the detector shift used
        to simulate the donuts (0.0015 -> 1.5 mm).
    tol : float
        Convergence tolerance passed to ``runIt``.

    Returns
    -------
    Algorithm
        The solved algorithm object; query it with e.g. ``getZer4UpInNm()``.
    """
    I1 = CompensableImage()
    I2 = CompensableImage()
    I1.setImg(fieldXY, DefocalType.Intra, image=intra_img.copy())
    I2.setImg(fieldXY, DefocalType.Extra, image=extra_img.copy())
    inst.config(CamType.LsstFamCam, I1.getImgSizeInPix(),
                announcedDefocalDisInMm=defocal_dis_mm)

    algo = Algorithm(algo_dir)
    algo.config(algo_name, inst)
    algo.runIt(I1, I2, "offAxis", tol=tol)
    return algo


fieldXY = np.array([np.rad2deg(thx), np.rad2deg(thy)])  # degrees
fftAlgo = run_cwfs("fft", intra_img, extra_img, fieldXY, inst, algoDir)
expAlgo = run_cwfs("exp", intra_img, extra_img, fieldXY, inst, algoDir)

In [None]:
from matplotlib.ticker import MaxNLocator

# Compare the CWFS estimates with the true wavefront of the perturbed
# telescope, all as Zernike coefficients in nm, starting at Noll index 4.
fft_zk = fftAlgo.getZer4UpInNm()
exp_zk = expAlgo.getZer4UpInNm()

# Truth: batoid Zernikes (in waves) of the in-focus telescope, converted to nm.
# They must be evaluated at the star's field angle (thx, thy): the donuts are
# off-axis, so using (0, 0) compares against the wrong field point.  Use the
# same effective wavelength as the simulation rather than a hard-coded 622 nm.
# eps=0.61 is presumably the LSST central-obscuration ratio.
wavelength_nm = bandpass.effective_wavelength
bzk = batoid.zernike(telescope, thx, thy, wavelength_nm*1e-9, eps=0.61) * wavelength_nm

# bzk is indexed by Noll index (entries 0..3 unused); the CWFS arrays start at 4.
assert len(fft_zk) == len(exp_zk) == len(bzk) - 4, "Zernike index ranges differ"
noll = np.arange(4, len(bzk))

print(f"{'j':>2}  {'exp':>11}  {'fft':>11}  {'truth':>11}")
for j, e, f in zip(noll, exp_zk, fft_zk):
    print(f"{j:2}  {e:8.3f} nm  {f:8.3f} nm  {bzk[j]:8.3f} nm")

fig, ax = plt.subplots()
ax.plot(noll, fft_zk, label='fft')
ax.plot(noll, exp_zk, label='exp')
ax.plot(noll, bzk[noll], label='truth')
ax.axhline(0, c='k')
ax.legend()
ax.set_xlabel("Noll index")
ax.set_ylabel("Perturbation amplitude (nm)")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.show()

In [None]:
# Wavefront map estimated by the fft solver.
# NOTE(review): the units of getWavefrontMapEsti() are not established here
# (possibly nm, like getZer4UpInNm) -- confirm before labelling the colorbar.
wf = fftAlgo.getWavefrontMapEsti()
fig, ax = plt.subplots()
im = ax.imshow(wf)
fig.colorbar(im, ax=ax, label="wavefront estimate")
ax.set_title("fft wavefront estimate")
plt.show()