# Wavefront Estimation
*David Thomas and Emily Li 2019/07/11*

**Abstract:** We implement the full forward-model-powered wavefront estimation.

**Table of Contents:**
- [Starter Code](#Starter-Code)
- [Research Questions](#Research-Questions)

## Starter Code
Our goal is to estimate the Zernike coefficients of the wavefront, which characterize the state of the telescope, from donut images of stars. We develop a pipeline to realize this goal. Most of the components were introduced in the previous notebooks; now we must combine them. Here are the steps of the pipeline:
- Create telescope (potentially with a perturbation).
- Raytrace a donut (at the field center (0,0) for now).
- Find the zernikes by optimizing the forward model.
    - We use an off-the-shelf optimizer to minimize a cost function.
    - The cost function is the difference between the forward-modelled donut and the raytraced donut.
- Evaluate the zernikes to get a wavefront image.
- Use the OPD (optical path difference) to get the true wavefront image.
- Examine the difference between the true and predicted wavefront.

#### Basic Script
The script below captures these steps and highlights the abstractions and functionality we will need.

In [None]:
telescope = loadTelescope()
donut = raytraceDonut(telescope)

predZ = findZernikes(donut)
predW = zernikesToWavefront(predZ)
trueW = opd(telescope)

err = np.abs(trueW - predW)

Great! What do we need to build so that this code will actually work?

#### Support Functions
Let's start by importing the modules we will need.

In [None]:
%matplotlib inline

from scipy.interpolate import interp2d
from scipy.signal import correlate
import matplotlib.pyplot as plt
import numpy as np
import os
import batoid
import yaml
from batoid.utils import fieldToDirCos
import galsim

Below we copy over the functions we developed in previous notebooks. We will be able to use these without change.
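
One caveat before running the cell below: `lsstPixels` relies on `scipy.interpolate.interp2d`, which has been deprecated in recent SciPy releases and removed in the newest ones. If that import fails, a minimal sketch of the same re-pixelization idea with `RegularGridInterpolator` might look like the following. The function name `repixelize`, the `[ix, iy]` axis convention, and the way the output size is computed are assumptions of this sketch, not part of the original notebooks, so compare the results against `lsstPixels` before relying on it.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

def repixelize(z, dx, dy, pix=10e-6):
    """Resample z (indexed [ix, iy], spacings dx, dy in meters) onto pixels of size pix."""
    nx, ny = z.shape
    x = np.arange(nx) * dx
    y = np.arange(ny) * dy
    # fill_value=None allows extrapolation, which guards against
    # round-off right at the upper edge of the grid.
    interp = RegularGridInterpolator((x, y), z, method="linear",
                                     bounds_error=False, fill_value=None)
    # New sample positions, kept inside the original grid extent.
    xprime = np.arange(int(x[-1] // pix) + 1) * pix
    yprime = np.arange(int(y[-1] // pix) + 1) * pix
    xx, yy = np.meshgrid(xprime, yprime, indexing="ij")
    return interp(np.stack([xx, yy], axis=-1))

# Toy check: a plane sampled every 5 microns, resampled to 10 micron pixels.
fine_x = np.arange(16) * 5e-6
fine = 1e6 * (fine_x[:, None] + 2 * fine_x[None, :])
coarse = repixelize(fine, dx=5e-6, dy=5e-6)
coarse_x = np.arange(coarse.shape[0]) * 10e-6
expected = 1e6 * (coarse_x[:, None] + 2 * coarse_x[None, :])
```

Linear interpolation reproduces a plane exactly, so `coarse` should agree with `expected` here; real PSF lattices will of course show interpolation error.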

In [None]:
def lsstPixels(lattice, pix=10e-6):
    """
    Pixelizes the lattice so that pixels correspond to LSST pixels.
    
    Parameters
    ----------
    lattice: batoid.Lattice
        The psf lattice to pixelize.
    pix: float
        The size in meters of an LSST pixel.
        
    Returns
    -------
    lattice: batoid.Lattice
        The re-pixelized lattice.
    
    Note
    ----
    We assume lattice has cartesian grid vectors.
    """
    nx, ny = lattice.array.shape
    dx, dy = lattice.primitiveVectors[0,0], lattice.primitiveVectors[1,1]
    z = lattice.array
    if dx < 0:
        z = z[::-1,:]
        dx = np.abs(dx)
    if dy < 0:
        z = z[:,::-1]
        dy = np.abs(dy)
    x = np.arange(nx) * dx
    y = np.arange(ny) * dy
    
    interp = interp2d(x, y, z)
    xprime = np.arange(nx * dx // pix) * pix
    yprime = np.arange(ny * dy // pix) * pix

    arr = interp(xprime, yprime)
    primitiveVectors = np.eye(2) * pix
    return batoid.Lattice(arr, primitiveVectors)

def compare(donutLattice, psfLattice):
    """
    Re-pixelizes, normalizes, and centers lattices so that they can be compared on the same footing.
    
    Parameters
    ----------
    donutLattice: batoid.Lattice
        The donut lattice.
    psfLattice: batoid.Lattice
        The psf lattice.
        
    Returns
    -------
    numpy.ndarray, numpy.ndarray
        The donut and psf images, in that order.
    
    Notes
    -----
    Assume psf is larger.
    """
    psfprime = lsstPixels(psfLattice)
    
    psf = psfprime.array
    donut = donutLattice.array
    
    # first pad smaller array so that both are the same size
    psf_nx = psf.shape[0]
    donut_nx = donut.shape[0]
    nx = max(psf_nx, donut_nx)
    larger = np.zeros((nx,nx))
    if psf_nx > donut_nx:
        larger[:donut_nx,:donut_nx] = donut
        donut = larger
    else:
        larger[:psf_nx,:psf_nx] = psf
        psf = larger
        
    # align the psf with the donut: the cross-correlation peak gives their relative shift
    center = nx // 2
    corr = correlate(psf, donut, mode='same')
    idx = np.argmax(corr)
    xmatch = (idx // nx)
    ymatch = (idx % nx)
    dx = center - xmatch
    dy = center - ymatch
    psf = np.roll(np.roll(psf, dx, axis=0), dy, axis=1)
    
    donut /= donut.sum()
    psf /= psf.sum()
    
    return donut, psf

def loadTelescope(offset=1.5e-3):
    """
    Convenience function for loading a telescope.
    
    Parameters
    ----------
    offset: float
        The offset of the detector. Defaults to 1.5e-3 = 1.5mm.
    
    Returns
    -------
    batoid.Optic
        The loaded telescope.
    """
    LSST_g_fn = os.path.join(batoid.datadir, "LSST", "LSST_g.yaml")
    config = yaml.safe_load(open(LSST_g_fn))
    telescope = batoid.parse.parse_optic(config['opticalSystem'])
    telescope = telescope.withGloballyShiftedOptic('LSST.LSSTCamera.Detector', [0, 0, offset])
    return telescope

def opd(telescope, theta_x=0, theta_y=0, wavelength=500e-9, nx=256, projection='zemax', lattice=False):
    """
    Computes the optical path difference, or wavefront, of a telescope.
    
    Parameters
    ----------
    telescope: batoid.Optic
        The telescope for which to compute wavefront.
    theta_x, theta_y : float
        Field of incoming rays (gnomonic projection). Default to 0.
    wavelength: float
        The Wavelength of light. Defaults to 500e-9 = 500nm.
    nx: int
        Number of grid points in each dimension. Defaults to 256.
    projection : {'postel', 'zemax', 'gnomonic', 'stereographic', 'lambert', 'orthographic'}
        Projection used to convert field angle to direction cosines.
    lattice : bool, optional
        If true, then decenter the grid so it spans (-N/2, N/2+1), as appropriate
        for Fourier transforms.
        
    Returns
    -------
    batoid.Lattice
        The wavefront, or optical path difference, of the telescope.
    """
    flux = 1
    dirCos = fieldToDirCos(theta_x, theta_y, projection=projection)
    rays = batoid.rayGrid(
        telescope.dist/dirCos[2], telescope.pupilSize,
        dirCos[0], dirCos[1], -dirCos[2],
        nx, wavelength, flux, telescope.inMedium, lattice=lattice)
    
    # chief ray index.  works if lattice=True and nx is even,
    # or if lattice=False and nx is odd
    cridx = (nx//2)*nx+nx//2
    
    telescope.traceInPlace(rays, outCoordSys=batoid.globalCoordSys)
    spherePoint = rays[cridx].r
    
    # We want to place the vertex of the reference sphere one radius length away from the
    # intersection point.  So transform our rays into that coordinate system.
    radius = np.hypot(telescope.sphereRadius, np.hypot(spherePoint[0], spherePoint[1]))
    transform = batoid.CoordTransform(
            batoid.globalCoordSys, batoid.CoordSys(spherePoint+np.array([0,0,radius])))
    transform.applyForwardInPlace(rays)

    sphere = batoid.Sphere(-radius)
    sphere.intersectInPlace(rays) 
    t0 = rays[cridx].t
    arr = np.ma.masked_array(t0-rays.t, mask=rays.vignetted).reshape(nx, nx)

    primitiveVectors = np.vstack([[telescope.pupilSize/nx, 0], [0, telescope.pupilSize/nx]])
    return batoid.Lattice(arr, primitiveVectors)

def raytraceDonut(telescope, nphot=int(1e6), theta_x=0, theta_y=0, wavelength=500e-9):
    """Simulate a donut image by raytracing photons through telescope.
    
    Parameters
    ----------
    telescope: batoid.Optic
        The telescope to raytrace through.
    nphot: int
        The number of photons to raytrace. Defaults to 1 million.
    theta_x: float
        The x field position. Defaults to 0.
    theta_y: float
        The y field position. Defaults to 0.
    wavelength: float
        The wavelength of light to use. Defaults to 500e-9 = 500 nm.
    
    Returns
    -------
    batoid.Lattice
        The donut image.
    """
    flux = 1
    xcos, ycos, zcos = batoid.utils.gnomonicToDirCos(theta_x, theta_y)
    rays = batoid.uniformCircularGrid(
        telescope.dist, 
        telescope.pupilSize/2, 
        telescope.pupilSize*telescope.pupilObscuration/2,
        xcos, ycos, -zcos,
        nphot, wavelength, flux,
        telescope.inMedium)
    telescope.traceInPlace(rays)
    rays.trimVignettedInPlace()
    
    xcent, ycent = np.median(rays.x), np.median(rays.y)
    pix = 10e-6
    width = 192 * pix
    
    xedges = np.arange(xcent-width/2, xcent+width/2, pix)
    yedges = np.arange(ycent-width/2, ycent+width/2, pix)
    # flip here because 1st dimension corresponds to y-dimension in bitmap image
    result, _, _ = np.histogram2d(rays.y, rays.x, bins=[yedges, xedges])
    
    primitiveX = np.array([[pix,0],[0,pix]])
    return batoid.Lattice(result, primitiveX)
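
The centering step inside `compare` can be hard to read in isolation, so here is a minimal sketch of the same idea on a toy image. It assumes two square arrays of the same shape and uses the fact that, with `mode='same'`, zero lag sits at index `n // 2`. The helper name `align_to` and the Gaussian blob are illustrative only.

```python
import numpy as np
from scipy.signal import correlate

def align_to(reference, image):
    """Shift `image` so that it best overlaps `reference` (square arrays, same shape)."""
    n = reference.shape[0]
    center = n // 2  # index of zero lag in the mode='same' output
    corr = correlate(image, reference, mode="same")
    peak_row, peak_col = np.unravel_index(np.argmax(corr), corr.shape)
    # The peak offset from the center is the shift of `image` relative to `reference`.
    shift = (int(center - peak_row), int(center - peak_col))
    return np.roll(image, shift, axis=(0, 1)), shift

# Toy data: a Gaussian blob and a copy displaced by (3, -2) pixels.
yy, xx = np.mgrid[:32, :32]
reference = np.exp(-((yy - 16) ** 2 + (xx - 16) ** 2) / 8.0)
image = np.roll(reference, (3, -2), axis=(0, 1))

aligned, shift = align_to(reference, image)
print(shift)  # the shift that undoes the (3, -2) displacement
```

Note that `np.roll` wraps around the array edges, so this only behaves well when the donut sits comfortably inside the stamp.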

In [None]:
1.5e-3 / 10.31  # detector offset (m) divided by the focal length (m)

We will need to write three new support functions. I have implemented the cost function for you. You will need to implement 'fftPSF'. This function will be almost identical to the function with the same name in the [ForwardModelingDonuts.ipynb](https://github.com/davidthomas5412/ForwardModelingLSSTDonuts/blob/master/notebooks/ForwardModelingDonuts.ipynb) notebook, but there are two differences described in the comments below. 

You will also need to fill in 'zernikesToWavefront'. This function takes a vector of Zernike coefficients as input and produces the corresponding wavefront image. This will be similar to the work you did in [ZernikePolynomials.ipynb](https://github.com/davidthomas5412/ForwardModelingLSSTDonuts/blob/master/notebooks/ZernikePolynomials.ipynb). Remember to use the annular form of the Zernike polynomials. To get the correct image size (the number of pixels across, and so on), several geometric parameters are needed. Note that the diameter of a donut in pixels is (diameter / focalLength) * (absOffset / pixSize).
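
As a quick numerical check of the donut size, the snippet below evaluates the expression from the `zernikesToWavefront` comment, (diameter / focalLength) * (absOffset / pixSize), with that function's default values. The helper name `donut_diameter_pixels` is hypothetical and only used for this illustration.

```python
def donut_diameter_pixels(absOffset=1.5e-3, pixSize=10e-6,
                          focalLength=10.31, diameter=8.31):
    """Donut diameter in pixels; all lengths are in meters."""
    return (diameter / focalLength) * (absOffset / pixSize)

donut_pixels = donut_diameter_pixels()
print(round(donut_pixels, 1))  # roughly 121 pixels
```

A donut of about 121 pixels fits inside the 192-pixel stamp that `raytraceDonut` produces.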

**Problem 1:** Why does this give the diameter of a donut?

**Problem 2:** Fill in the code below.

In [None]:
def cost(donut1, donut2):
    """
    The squared L2 distance between two donut images.
    
    Parameters
    ----------
    donut1: numpy.ndarray
        The image of donut 1.
    donut2: numpy.ndarray
        The image of donut 2.
        
    Returns
    -------
    float
        The sum of squared pixel differences.
    """
    return np.sum((donut1 - donut2) ** 2)

def fftPSF(zernikes, theta_x=0, theta_y=0, wavelength=500e-9, nx=2500, projection='zemax', pad_factor=1):
    # TODO: your code here.
    
    # Note: This function will be a little different from the one in ForwardModelingDonuts.ipynb.
    # First, it is a function of the zernikes rather than the telescope. Second, you can evaluate
    # the zernikes to get the wavefront ('wf' in the other code) and use the nominal state of the telescope
    # with batoid.psf.dkdu.
    raise NotImplementedError("fftPSF is left as an exercise (Problem 2)")
    
    
def zernikesToWavefront(zernikes, absOffset=1.5e-3, pixSize=10e-6, focalLength=10.31, diameter=8.31):
    """
    Produces the image of these zernikes.
    
    Parameters
    ----------
    zernikes: numpy.ndarray
        Vector of zernike coefficients.
    absOffset: float
        The absolute value of the offset from focus in meters. Defaults to 1.5e-3 = 1.5mm.
    pixSize: float
        The size of a LSST pixel in meters. Defaults to 10e-6 = 10um.
    focalLength: float
        The LSST focal length in meters. Defaults to 10.31m.
    diameter: float
        The outer diameter of the primary mirror M1 in meters. Defaults to 8.31m.

    Returns
    -------
    numpy.ndarray
        The corresponding image.
    """
    # TODO: your code here
    
    # Note the diameter of a donut is (diameter / focalLength ) * (absOffset / pixSize)

#### Wavefront Estimation

Here is how we put these pieces together to find the zernike coefficients from a raytraced donut image.

In [None]:
from scipy.optimize import minimize

def findZernikes(raytracedDonut):
    """
    Estimates the zernikes of the wavefront from a donut.
    
    Parameters
    ----------
    raytracedDonut: batoid.Lattice
        The donut image, as returned by raytraceDonut.
    
    Returns
    -------
    numpy.ndarray
        The vector of zernike coefficients.
    """
    # function to optimize
    def func(zernikes):
        psfLattice = fftPSF(zernikes)
        donut, estimate = compare(raytracedDonut, psfLattice)
        return cost(donut, estimate)
    
    w0 = np.zeros(22)
    res = minimize(func, w0)
    
    # raise Error if optimization fails
    if not res.success:
        print('Message: ', res.message)
        print('Status: ', res.status)
        raise RuntimeError('Optimizer Failed')
        
    zernikes = res.x
    return zernikes
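
To see the optimizer call pattern in isolation, here is a toy version of the same fit. It assumes a deliberately simple linear "forward model" (an image built from three basis images) in place of `fftPSF`, so it only illustrates how `scipy.optimize.minimize` is driven by a cost function and a starting vector; the real problem is nonlinear and much more expensive. All names in the sketch (`forward_model`, `basis`, `true_coeffs`) are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize

# Toy "forward model": an image that is a linear combination of three basis images.
grid = np.linspace(-1, 1, 16)
x, y = np.meshgrid(grid, grid)
basis = np.stack([x, y, x**2 + y**2])  # two tilt-like terms and a defocus-like term

def forward_model(coeffs):
    """Image predicted for a given coefficient vector."""
    return np.tensordot(coeffs, basis, axes=1)

true_coeffs = np.array([0.3, -0.2, 0.5])
observed = forward_model(true_coeffs)  # stands in for the raytraced donut

def func(coeffs):
    """Sum of squared pixel differences, as in `cost`."""
    return np.sum((forward_model(coeffs) - observed) ** 2)

w0 = np.zeros(3)
res = minimize(func, w0)  # objective first, starting point second
print(res.x)
```

In this toy setting the recovered `res.x` should land very close to `true_coeffs`; how well this carries over to donuts is exactly what the research questions below are about.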

**Problem 3:** Write and debug code to get the 'basic script' to work.

In [None]:
telescope = loadTelescope()
donut = raytraceDonut(telescope)

predZ = findZernikes(donut)
predW = zernikesToWavefront(predZ)
trueW = opd(telescope)

err = np.abs(trueW - predW)
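
A single number is often easier to track than an error image. The sketch below computes an RMS wavefront error over the unvignetted pupil. It assumes both wavefronts have already been brought onto the same grid as arrays (since `opd` returns a `batoid.Lattice`, that would be its `.array`, which is masked where rays are vignetted). Removing the mean (piston) before taking the RMS is a choice made here, not something the pipeline prescribes, and the annular pupil with inner radius 0.6 is a hypothetical stand-in.

```python
import numpy as np

def rms_wavefront_error(true_w, pred_w, remove_piston=True):
    """RMS of (true_w - pred_w) over unmasked pixels."""
    diff = np.ma.asarray(true_w) - pred_w
    if remove_piston:
        diff = diff - diff.mean()  # a constant offset is ignored
    return float(np.sqrt((diff ** 2).mean()))

# Toy data: a defocus-like wavefront on a hypothetical annular pupil.
n = 64
grid = np.linspace(-1, 1, n)
x, y = np.meshgrid(grid, grid)
r = np.hypot(x, y)
outside = (r > 1.0) | (r < 0.6)
true_w = np.ma.masked_array(0.1 * (2 * r**2 - 1), mask=outside)
pred_w = true_w.filled(0.0) + 0.01 * x  # prediction with a small tilt error

rms = rms_wavefront_error(true_w, pred_w)
print(rms)
```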

## Research Questions
Now that we have finished the pipeline we can embark on many research questions. These fall into several themes.
- Runtime: What is the runtime of the different pieces of the pipeline? How can we make the optimization faster? Can we decrease the number of photons?
- Accuracy: How accurate is the pipeline? Is it more accurate on certain perturbations? How does the number of photons used in the raytracing impact accuracy? What if we add Poisson background noise to the donut image?
- Optimization: What does the optimization surface look like? Which scipy minimizer is best? What are the trade-offs? Can we trade accuracy for better runtime?
- Artifacts: How do Fourier transform artifacts (the rings in the fftPSF) impact the pipeline? Can we remove them?
- Visualization: How can we visualize the optimization surface? What diagnostic metrics can we use to assess the optimization (e.g., the Hessian matrix)?
- Generalization: How can we generalize this technique to other field positions? How does it compare to other approaches?
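
As a possible starting point for the noise question, the sketch below adds a Poisson sky background to a donut image expressed in photon counts. The function name `add_poisson_background`, the background level of 10 counts per pixel, and the toy annulus are all assumptions made for illustration; realistic levels would have to come from elsewhere.

```python
import numpy as np

def add_poisson_background(donut_counts, background_per_pixel=10.0, seed=0):
    """Return the donut image plus an independent Poisson background per pixel."""
    rng = np.random.default_rng(seed)
    background = rng.poisson(background_per_pixel, size=donut_counts.shape)
    return donut_counts + background

# Toy donut: a flat annulus on a 192 x 192 stamp (the stamp size used by raytraceDonut).
n = 192
yy, xx = np.mgrid[:n, :n]
r = np.hypot(xx - n / 2, yy - n / 2)
toy_donut = np.where((r < 60) & (r > 36), 25.0, 0.0)

noisy = add_poisson_background(toy_donut, background_per_pixel=10.0, seed=42)
```

Because `compare` normalizes both images by their sums, a background like this would probably need to be subtracted or modelled before the cost is evaluated.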

We can start with the most immediate questions and progress to more open-ended ones. Congratulations on completing the basic training! Welcome to the journey that is research!