In [None]:
%matplotlib widget
import aotools
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate as interp
import torch as t

pup_diam_p = 64   # diameter of pupil in pixels
pup_diam_c = 8.0  # diameter of pupil in metres
fft_width = 128   # width of fourier support (must be at least 2*pup_diam_p)
assert fft_width >= 2 * pup_diam_p, "fft_width too small: PSF would be aliased"

as2rad = t.pi/180/3600  # arcseconds -> radians

# Fix the torch RNG so the random aberrations drawn below (t.randn) are reproducible.
SEED = 0
t.manual_seed(SEED)

# circular pupil (borrowing zernike piston mode); later cells use `pup == 1` as the in-pupil mask
pup = t.tensor(aotools.phaseFromZernikes([1],pup_diam_p),dtype=t.float32)

# maximum field angle (arcsec) per axis, used to define meta-pupil widths;
# 15*sqrt(2) is the corner distance of the +/-15 arcsec source grid built later
max_theta = 15*2**0.5

max_zern = 30  # number of zernike modes used (length of every coefficient vector)
# zernike basis, indexed [mode, row, col] -> shape (max_zern, pup_diam_p, pup_diam_p)
zerns = t.tensor(aotools.zernikeArray(max_zern,pup_diam_p),dtype=t.float32)

def phase_from_zerns(x: t.Tensor):
    """Synthesise a phase map from Zernike coefficients for a full pupil.

    Computes ``sum_i x[i] * zerns[i]`` with the module-level Zernike basis
    ``zerns`` (indexed [mode, row, col]). The operation is a plain einsum, so
    gradients flow back to ``x`` (the fitting loop relies on this).

    Args:
        x (t.Tensor): coefficients, shape (max_zern,).

    Returns:
        t.Tensor: phase, shape (pup_diam_p, pup_diam_p), indexed [row (y), col (x)].
    """
    return t.einsum("ijk,i->jk",zerns,x)

# random test coefficients to visualise the Zernike phase synthesis
x_ell = t.randn(max_zern)
phase_demo = phase_from_zerns(x_ell)
plt.matshow(phase_demo)
plt.colorbar()

In [None]:
# build phase to zernike projector: least squares p2z = (Z^T Z)^{-1} Z^T,
# where Z holds each zernike sampled at the in-pupil pixels (one column per mode)
in_pupil = pup == 1
zern_mat = zerns[:, in_pupil].T
gram = zern_mat.T @ zern_mat
p2z = t.linalg.solve(gram, zern_mat.T)

# sanity check: projecting a synthesised phase should give back its coefficients
recovered = p2z @ phase_from_zerns(x_ell)[in_pupil]
plt.figure()
plt.plot(x_ell, "bx-", label="true zernikes values")
plt.plot(recovered, "r+-", label="projected zernike values")
plt.legend()
plt.xlabel("Mode number")
plt.ylabel("Mode value")


In [None]:
# conjugation altitudes in metres (negative values are allowed)
alts = t.tensor([-10000.0, -5000.0, 0.0, 5000.0, 10000.0])
num_alts = len(alts)  # number of layers
# small random initial aberrations: one row of zernike coefficients per layer
x_alts = 1e-2 * t.randn(num_alts, max_zern)

def phase_at_layer(x: t.Tensor, theta_x: float, theta_y: float, alt: float):
    """Compute the phase of one layer projected onto the pupil in a given direction.

    The coefficients ``x`` describe the phase over the layer's whole meta-pupil,
    which is sampled on a ``pup_diam_p`` x ``pup_diam_p`` grid. The meta-pupil is
    wider than the telescope pupil by ``w = 1 + 2*|alt|*max_theta*as2rad/pup_diam_c``.
    The pupil footprint for direction ``(theta_x, theta_y)`` is a ``1/w``-scaled,
    shifted sub-region of that grid, and it is read out by linear interpolation.
    The footprint stays inside the grid for ``|theta_x|, |theta_y| <= max_theta``.

    Args:
        x (t.Tensor): zernike coefficients for the layer, shape (max_zern,).
        theta_x (float): x field position on-sky (arcseconds); 0-d tensors accepted.
        theta_y (float): y field position on-sky (arcseconds); 0-d tensors accepted.
        alt (float): layer altitude (metres), may be negative; 0-d tensors accepted.

    Returns:
        t.Tensor: float32 phase, shape (pup_diam_p, pup_diam_p), indexed
        [row (y), col (x)] like ``phase_from_zerns``. At ``alt == 0`` this
        equals ``phase_from_zerns(x)``.

    Raises:
        ValueError: if the requested footprint falls outside the meta-pupil.
    """
    # Plain floats keep the geometry in numpy. They accept Python scalars or
    # 0-d tensors (e.g. elements of `alts` / `pos_x`) without relying on
    # numpy deferring arithmetic to torch.
    alt = float(alt)
    theta_x = float(theta_x)
    theta_y = float(theta_y)

    # phase of the whole meta-pupil, indexed [row, col]
    phi = phase_from_zerns(x).detach().cpu().numpy()

    # Sample centres run 0..n_span, so n_span (not pup_diam_p) is the pixel span
    # that covers one (meta-)pupil diameter. Using it for the metres->pixels
    # scale makes the footprint at theta = +max_theta end exactly on the last sample.
    n_span = pup_diam_p - 1
    w = 1 + 2 * abs(alt) * as2rad * max_theta / pup_diam_c  # meta-pupil / pupil width
    m2pix = n_span / (pup_diam_c * w)  # metres -> meta-pupil pixels

    # pupil sample coordinates compressed into the meta-pupil sub-region ...
    yy, xx = np.mgrid[:pup_diam_p, :pup_diam_p] / w
    # ... then shifted by the layer offset for this direction
    xx = xx.ravel() + (max_theta * abs(alt) + theta_x * alt) * as2rad * m2pix
    yy = yy.ravel() + (max_theta * abs(alt) + theta_y * alt) * as2rad * m2pix

    tol = 1e-9
    if min(xx.min(), yy.min()) < -tol or max(xx.max(), yy.max()) > n_span + tol:
        raise ValueError(
            "Requested phase outside of defined metapupils "
            f"(x in [{xx.min():.4f}, {xx.max():.4f}], y in [{yy.min():.4f}, {yy.max():.4f}], "
            f"valid range [0, {n_span}])"
        )
    # absorb floating-point round-off at the exact meta-pupil edge
    xx = np.clip(xx, 0.0, n_span)
    yy = np.clip(yy, 0.0, n_span)

    grid = np.arange(pup_diam_p, dtype=float)
    interp_func = interp.RegularGridInterpolator((grid, grid), phi, bounds_error=True)
    # The first coordinate indexes phi's rows (y) and the second its columns (x).
    # Passing (xx, yy) instead would return the transposed phase.
    values = interp_func(np.stack([yy, xx], axis=-1))
    return t.tensor(values.reshape(pup_diam_p, pup_diam_p), dtype=t.float32)

# visualise the +10 km layer (index 4) as seen on-axis
top_layer = 4
on_axis_top = phase_at_layer(x_alts[top_layer], theta_x=0, theta_y=0.0, alt=alts[top_layer])
plt.matshow(on_axis_top)
plt.colorbar()

In [None]:
def phase_in_direction(x: t.tensor, theta_x: float, theta_y: float):
    """Sum the phase of every layer along one line of sight.

    Args:
        x: per-layer zernike coefficients, shape (num_alts, max_zern); row ``ell``
            belongs to the layer at altitude ``alts[ell]``.
        theta_x: x field position on-sky (arcseconds).
        theta_y: y field position on-sky (arcseconds).

    Returns:
        Tensor of shape (pup_diam_p, pup_diam_p): the summed phase, multiplied by
        the pupil mask ``pup``.
    """
    per_layer = [
        phase_at_layer(coeffs, theta_x, theta_y, alt)
        for coeffs, alt in zip(x, alts)
    ]
    return t.stack(per_layer).sum(dim=0) * pup

# total phase seen 10 arcsec off-axis in x
phase_10as = phase_in_direction(x_alts, theta_x=10.0, theta_y=0.0)
plt.matshow(phase_10as)
plt.colorbar()

In [None]:
def image_from_phase(phi: t.Tensor, wavelength: float = 0.55, window_size=16):
    """Compute the centred, cropped image/PSF produced by a pupil phase.

    The complex field ``pup * exp(2*pi*i*phi/wavelength)`` is zero-padded to
    ``fft_width`` x ``fft_width`` and Fourier transformed (orthonormal). Its
    squared modulus is centred with ``fftshift`` and cropped to a
    ``window_size`` x ``window_size`` window around the centre.

    Args:
        phi (t.Tensor): phase in microns, shape (..., pup_diam_p, pup_diam_p).
            Leading batch dimensions are allowed; only the last two are transformed.
        wavelength (float): wavelength in microns.
        window_size (int): side length (pixels) of the centred crop; should not
            exceed ``fft_width``.

    Returns:
        t.Tensor: image, shape (..., window_size, window_size).
    """
    field = pup * t.exp(1j * 2 * t.pi * phi / wavelength)
    spectrum = t.fft.fft2(field, s=[fft_width, fft_width], norm="ortho")
    # shift only the image axes so any batch dimensions are left untouched
    im = t.abs(t.fft.fftshift(spectrum, dim=(-2, -1))) ** 2
    # Use start = centre - size//2 and stop = start + size, so odd window sizes
    # also yield exactly window_size pixels (centre +/- size//2 gives one too few).
    r0 = im.shape[-2] // 2 - window_size // 2
    c0 = im.shape[-1] // 2 - window_size // 2
    return im[..., r0:r0 + window_size, c0:c0 + window_size]

# scaling factor to obtain PSF in "strehl" units: the peak of the zero-phase
# (diffraction-limited) PSF
flat_phase = pup*0.0
psf_scaling = image_from_phase(flat_phase).max()

psf_10as = image_from_phase(phase_in_direction(x_alts, theta_x=10.0, theta_y=0.0))
plt.matshow(psf_10as / psf_scaling)
plt.colorbar()

In [None]:
# build tomographic system: an n_src_x x n_src_x grid of sources spread evenly
# over a 30 arcsec wide field (positions in arcsec, from -15 to +15)
n_src_x = 8
n_src = n_src_x ** 2
offsets_1d = (t.arange(n_src_x) - (n_src_x - 1) / 2) / (n_src_x - 1) * 30
pos_x, pos_y = (g.flatten() for g in t.meshgrid(offsets_1d, offsets_1d, indexing="ij"))

# plot the source layout (the "sources" label is only shown once a legend is drawn)
plt.figure()
plt.plot(pos_x,pos_y,"rx",label="sources")
plt.axis("square")
plt.xlabel("x-field Position - [arcsec]")
plt.ylabel("y-field Position - [arcsec]")
plt.title("Source positions")
plt.legend()

# build tomographic projector from modes in alt to modes in directions:
#   alt_to_dir[s, ell*max_zern + i, :] = zernike coefficients seen in direction s
#   produced by a unit amount of mode i on layer ell.
# phase_at_layer is linear in its coefficients (einsum + linear interpolation),
# so probing one unit mode at a time gives the full linear map.
# Only the probed layer is non-zero, so that single layer is projected directly.
# Summing all num_alts layers with phase_in_direction would give the same
# result at num_alts times the cost.
unit_modes = t.eye(max_zern)
alt_to_dir = t.stack([
    t.stack([
        p2z @ phase_at_layer(e_i, px, py, alt)[pup == 1]
        for alt in alts          # layer-major ...
        for e_i in unit_modes    # ... then mode, matching x.flatten() of (num_alts, max_zern)
    ])
    for px, py in zip(pos_x, pos_y)
])
assert alt_to_dir.shape == (n_src, num_alts * max_zern, max_zern)
print(alt_to_dir.shape)

In [None]:
# Ground-layer defocus "diversity": the second set of measurements adds this
# offset (Noll defocus = column 3) to the ground layer only. `cost` reuses it.
offset = x_alts*0.0
offset[alts==0,3] = 1e-1 # defocus on ground layer only

def _render_psf_grid(coeffs, title):
    """Simulate and plot the PSF in every source direction for layer coefficients `coeffs`.

    Returns the list of n_src images (strehl units), in source order.
    """
    modes_per_dir = t.einsum("ijk,j->ik", alt_to_dir, coeffs.flatten())
    images = [image_from_phase(phase_from_zerns(m)) / psf_scaling for m in modes_per_dir]
    fig, axes = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])
    for a, im in zip(axes.flatten(), images):
        a.imshow(im)
        a.set_xticks([])
        a.set_yticks([])
    fig.suptitle(title)
    fig.tight_layout()
    return images

# target data: first n_src images without, then n_src images with the defocus offset
imgs = (_render_psf_grid(x_alts, "True PSFs")
        + _render_psf_grid(x_alts + offset, "True PSFs + ground-layer defocus"))
imgs_target = t.stack(imgs)
print(imgs_target.shape)

In [None]:
def cost(x,target):
    """Residual vector of the fit at the guess ``x``.

    Images are simulated in every source direction twice, once for ``x`` and
    once for ``x + offset`` (the ground-layer defocus). This is the same order
    in which ``imgs_target`` was built.

    Args:
        x: layer coefficients, shape (num_alts, max_zern).
        target: measured images stacked along dim 0 (2*n_src images).

    Returns:
        1-D tensor: flattened image residuals followed by ``x/10``. Once squared
        and summed, the second part is a weak L2 (ridge) penalty on ``x``.
    """
    simulated = []
    for coeffs in (x, x + offset):
        modes_per_dir = t.einsum("ijk,j->ik", alt_to_dir, coeffs.flatten())
        simulated.extend(image_from_phase(phase_from_zerns(m)) / psf_scaling for m in modes_per_dir)
    residual = t.stack(simulated) - target
    return t.cat([residual.flatten(), x.flatten() / 10])

from scipy.optimize import least_squares
from torch import optim

# gradient descent on the sum of squared residuals, starting from zero
x_opt = t.zeros_like(x_alts, requires_grad=True)
optimizer = optim.SGD([x_opt], lr=1e-5)

n_iter = 1000
log_every = 100  # print progress occasionally rather than on every step
losses = []
for e in range(n_iter):
    loss = (cost(x_opt, imgs_target) ** 2).sum()
    losses.append(loss.item())
    if e % log_every == 0 or e == n_iter - 1:
        print(f"iter {e:4d}: loss = {losses[-1]:.6e}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

x_opt = x_opt.detach()

# draw the whole loss history as one artist (not one plot call per iteration)
fig, ax = plt.subplots(1, 1)
ax.plot(losses, "r.")
ax.set_yscale("log")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
# NOTE: scipy.optimize.least_squares (imported above) is an alternative solver,
# but `cost` works on torch tensors and would need numpy<->torch wrapping.

In [None]:
fig,ax = plt.subplots(n_src_x,n_src_x,figsize=[8,8])

# compute estimated images at each target (to evaluate convergence quality)
for a, px, py in zip(ax.flatten(), pos_x, pos_y):
    im = image_from_phase(phase_in_direction(x_opt.detach(), theta_x=px, theta_y=py)) / psf_scaling
    a.imshow(im)
    a.set_xticks([])
    a.set_yticks([])
fig.suptitle("Estimated PSFs (x_opt, no defocus offset)")
plt.tight_layout()

In [None]:
fig,ax = plt.subplots(n_src_x,n_src_x,figsize=[8,8])

# compute residual images at each target (to evaluate convergence quality):
# a perfect reconstruction gives a diffraction-limited PSF everywhere
residual_coeffs = x_alts - x_opt.detach()
for a, px, py in zip(ax.flatten(), pos_x, pos_y):
    im = image_from_phase(phase_in_direction(residual_coeffs, theta_x=px, theta_y=py)) / psf_scaling
    a.imshow(im)
    a.set_xticks([])
    a.set_yticks([])
fig.suptitle("Residual PSFs (x_alts - x_opt)")
plt.tight_layout()

In [None]:
# compare true and estimated layer coefficients (flattened layer-major:
# index = ell*max_zern + mode)
x_est = x_opt.detach()
fig, ax = plt.subplots()
ax.plot(x_alts.flatten(), label="true")
ax.plot(x_est.flatten(), label="estimated")
ax.plot((x_alts - x_est).flatten(), label="error (true - estimated)")
ax.set_xlabel("Coefficient index (layer*max_zern + mode)")
ax.set_ylabel("Coefficient value")
ax.legend()
