In [None]:
import matplotlib.pyplot as plt
import torch as t
import math
import aotools
import scipy.interpolate as interp

if get_ipython().__class__.__name__ == "ZMQInteractiveShell" :
    %matplotlib widget
    from tqdm.notebook import tqdm
elif get_ipython().__class__.__name__ == "TerminalInteractiveShell":
    plt.ion()
    from tqdm import tqdm
else:
    raise RuntimeError("Unknown shell name: {}\nExpecting one of:\n"
                       "\tZMQInteractiveShell\n"
                       "\tTerminalInteractiveShell\n")

In [None]:

# alias math.pi as t.pi so later cells can write t.pi regardless of torch version
t.pi = math.pi

t.set_num_threads(10)

# seed torch's RNGs (CPU and CUDA) so the random aberrations below are reproducible
seed = 0
t.manual_seed(seed)

device = "cuda"

pup_diam_p = 64    # diameter of pupil in pixels
pup_diam_c = 8.0   # diameter of pupil in metres
fft_width = 128    # width of fourier support (must be at least 2*pup_diam_p)

as2rad = t.pi/180/3600  # arcseconds -> radians

# circular pupil (borrowing zernike piston mode)
pup = t.tensor(aotools.phaseFromZernikes([1],pup_diam_p),dtype=t.float32,device=device)

# maximum field angle in arcsec, used to define meta-pupil widths;
# 15*sqrt(2) is the corner of a +/-15 arcsec square field
max_theta = 15*2**0.5

max_zern = 100 # maximum order of zernike function to use
# array of zernike functions, shape (max_zern, pup_diam_p, pup_diam_p)
zerns = t.tensor(aotools.zernikeArray(max_zern,pup_diam_p),dtype=t.float32,device=device)

def phase_from_zerns(x: t.Tensor, basis: t.Tensor = None):
    """Get the pupil phase from Zernike coefficients.

    Args:
        x (t.Tensor): coefficients of shape (..., max_zern); any leading batch
            dimensions are preserved in the output.
        basis (t.Tensor, optional): modal basis of shape (max_zern, H, W).
            Defaults to the module-level ``zerns`` array.

    Returns:
        t.Tensor: phase of shape (..., H, W), i.e. the coefficient-weighted sum
        of the basis functions.
    """
    if basis is None:
        basis = zerns
    # the ellipsis also matches zero dimensions, so a 1-D ``x`` works directly
    return t.einsum("ijk,...i->...jk", basis, x)

# sanity check: render the phase produced by a random set of coefficients
x_ell = t.randn(max_zern, device=device)
phase_ell = phase_from_zerns(x_ell)
plt.matshow(phase_ell.cpu())
plt.colorbar()

In [None]:
# Least-squares projector from in-pupil phase samples to Zernike coefficients:
# p2z = (Z^T Z)^{-1} Z^T, where the columns of Z are the modes sampled inside the pupil.
z = zerns[:, pup == 1].T
p2z = t.linalg.solve(z.T @ z, z.T)

# check: projecting a synthesised phase should give back its coefficients
fig, ax = plt.subplots()
ax.plot(x_ell.cpu(), "bx-", label="true zernikes values")
ax.plot((p2z @ phase_from_zerns(x_ell)[pup == 1]).cpu(), "r+-", label="projected zernike values")
ax.legend()
ax.set_xlabel("Mode number")
ax.set_ylabel("Mode value")


In [None]:
# conjugation altitudes in metres (negative values are allowed)
alts = t.linspace(-10000.0, 10000, 5, device=device)
num_alts = len(alts)  # number of layers
# random aberrations on every layer; later cells treat these as the ground
# truth that the optimiser tries to recover
x_alts = 4e-3 * t.randn(num_alts, max_zern, device=device)

def phase_at_layer(x: t.Tensor, theta_x: float, theta_y: float, alt: float):
    """Compute the phase of one layer projected onto the pupil in a given direction.

    The layer's meta-pupil phase (built from Zernike coefficients ``x``) is
    linearly interpolated at the pupil footprint seen from field position
    (theta_x, theta_y) at altitude ``alt``.

    Args:
        x (t.Tensor): coefficients for the layer, shape (max_zern,).
        theta_x (float or 0-dim tensor): x field position on-sky (arcseconds).
        theta_y (float or 0-dim tensor): y field position on-sky (arcseconds).
        alt (float or 0-dim tensor): layer altitude (metres), may be negative.

    Returns:
        t.Tensor: float32 phase on ``device``, shape (pup_diam_p, pup_diam_p),
        laid out [row (y), column (x)] like ``phase_from_zerns``.

    Raises:
        ValueError: if the requested footprint falls outside the meta-pupil.
    """
    import numpy as np
    # accept plain floats as well as (possibly GPU-resident) 0-dim tensors;
    # t.abs() on a Python float would raise
    alt = float(alt)
    theta_x = float(theta_x)
    theta_y = float(theta_y)
    # phase of the whole meta-pupil, indexed phi[row (y), column (x)]
    phi = phase_from_zerns(x).cpu().numpy()
    # pupil sampling coordinates in pixels: yy indexes rows, xx indexes columns
    yy, xx = np.mgrid[:pup_diam_p, :pup_diam_p] * 1.0
    # the meta-pupil is w times wider than the pupil but has the same number of
    # pixels, so pupil coordinates shrink by 1/w in meta-pupil pixels
    w = 1 + 2 * abs(alt) * as2rad * max_theta / pup_diam_c
    sc = pup_diam_p * as2rad / pup_diam_c / w
    # shift the footprint according to altitude and field position
    xx = xx / w + (max_theta * abs(alt) + theta_x * alt) * sc
    yy = yy / w + (max_theta * abs(alt) + theta_y * alt) * sc
    grid = np.arange(pup_diam_p)
    # grid axes and query columns are (row, column) to match phi's layout;
    # querying with (x, y) would sample phi transposed
    interp_func = interp.RegularGridInterpolator((grid, grid), phi, bounds_error=True)
    try:
        vals = interp_func(np.stack([yy.ravel(), xx.ravel()], axis=-1))
    except ValueError as err:
        raise ValueError(
            "Requested phase outside of defined metapupils "
            f"(x: {xx.min():.3f}..{xx.max():.3f}, y: {yy.min():.3f}..{yy.max():.3f})"
        ) from err
    return t.tensor(vals.reshape(pup_diam_p, pup_diam_p), dtype=t.float32, device=device)

# on-axis projection of layer index 4
layer4_phase = phase_at_layer(x_alts[4], theta_x=0, theta_y=0.0, alt=alts[4])
plt.matshow(layer4_phase.cpu())
plt.colorbar()

In [None]:
def phase_in_direction(x: t.tensor, theta_x: float, theta_y: float):
    """Sum the projected phase of every layer along one line of sight.

    Args:
        x (t.Tensor): coefficients for all layers, shape (num_alts, max_zern);
            row ``n`` is paired with altitude ``alts[n]``.
        theta_x (float): x field position on-sky (arcseconds)
        theta_y (float): y field position on-sky (arcseconds)

    Returns:
        t.Tensor: pupil-masked phase, shape (pup_diam_p, pup_diam_p)
    """
    layer_phases = [
        phase_at_layer(coeffs, theta_x, theta_y, altitude)
        for coeffs, altitude in zip(x, alts)
    ]
    total = t.stack(layer_phases, dim=0).sum(dim=0)
    return total * pup

# integrated phase towards a source 10 arcsec off-axis in x
offaxis_phase = phase_in_direction(x_alts, theta_x=10.0, theta_y=0.0)
plt.matshow(offaxis_phase.cpu())
plt.colorbar()

In [None]:
# build tomographic system: an n_src_x x n_src_x grid of sources spanning
# +/-15 arcsec in each field axis
n_src_x = 8
n_src = n_src_x**2
pos_x,pos_y = t.meshgrid(t.arange(n_src_x),t.arange(n_src_x),indexing="ij")
pos_x = (pos_x-(n_src_x-1)/2)/(n_src_x-1)*30
pos_y = (pos_y-(n_src_x-1)/2)/(n_src_x-1)*30
pos_x = pos_x.flatten()
pos_y = pos_y.flatten()

plt.figure()
plt.plot(pos_x,pos_y,"rx",label="sources")
plt.axis("square")
plt.xlabel("x-field Position - [arcsec]")
plt.ylabel("y-field Position - [arcsec]")
plt.legend()

# Build the tomographic projector from modes in altitude to modes in directions:
# alt_to_dir[s, ell*max_zern + i, :] holds the pupil Zernike coefficients seen
# from source s for a unit amount of mode i on layer ell. The row order matches
# x.flatten() of a (num_alts, max_zern) coefficient array.
pup_mask = pup == 1
x_unit = t.zeros(max_zern, device=device)
alt_to_dir = []
for px, py in tqdm(zip(pos_x, pos_y), total=n_src):
    rows = []
    for alt in alts:
        for i in range(max_zern):
            x_unit.zero_()
            x_unit[i] = 1.0
            # only this layer is non-zero, so project it alone rather than
            # interpolating the other num_alts - 1 all-zero layers
            phase = phase_at_layer(x_unit, px, py, alt) * pup
            rows.append(p2z @ phase[pup_mask])
    alt_to_dir.append(t.stack(rows))
alt_to_dir = t.stack(alt_to_dir)
print(alt_to_dir.shape)

In [None]:
def image_from_phase(phi: t.Tensor, wavelength: float = 0.55, window_size=32):
    """Compute images/PSFs from pupil phases.

    Args:
        phi (t.Tensor): phase in microns, shape (pup_diam_p, pup_diam_p) or
            (..., pup_diam_p, pup_diam_p) for any number of batch dimensions.
        wavelength (float): wavelength in microns.
        window_size (int): width in pixels of the central crop of the PSF.

    Returns:
        t.Tensor: intensity image(s) of shape (..., window_size, window_size);
        a single 2-D phase yields a single 2-D image.
    """
    # complex pupil field, zero-padded to fft_width by fft2's ``s`` argument
    field = pup * t.exp(1j * 2 * t.pi * phi / wavelength)
    spectrum = t.fft.fft2(field, s=[fft_width, fft_width], norm="ortho", dim=(-2, -1))
    im = t.abs(t.fft.fftshift(spectrum, dim=(-2, -1))) ** 2
    # crop a window_size x window_size region centred on the zero-frequency pixel
    c_row = im.shape[-2] // 2
    c_col = im.shape[-1] // 2
    half = window_size // 2
    return im[..., c_row - half:c_row + half, c_col - half:c_col + half]

# peak of the unaberrated (flat-phase) PSF; dividing by it puts PSFs in "strehl" units
psf_scaling = image_from_phase(t.zeros_like(pup)).max()

offaxis_psf = image_from_phase(phase_in_direction(x_alts, theta_x=10.0, theta_y=0.0))
plt.matshow((offaxis_psf / psf_scaling).cpu())
plt.colorbar()

In [None]:
# Compute target images at each source, first from x_alts alone and then with a
# known defocus offset added on the ground layer (the cost function reuses it).
ground_layer = alts == 0
# without an exact 0 m layer the offset below would silently be all zeros
assert ground_layer.any(), "no layer at altitude 0: defocus offset would be empty"
offset = t.zeros_like(x_alts)
offset[ground_layer, 3] = 1e-1  # defocus on ground layer only
dir_modes = t.cat([
    t.einsum("ijk,j->ik", alt_to_dir, x_alts.flatten()),
    t.einsum("ijk,j->ik", alt_to_dir, (x_alts + offset).flatten()),
])
imgs_target = image_from_phase(phase_from_zerns(dir_modes)) / psf_scaling

# top n_src_x rows: without offset; bottom n_src_x rows: with defocus offset
fig, ax = plt.subplots(n_src_x * 2, n_src_x, figsize=[8, 16])
for panel, im in zip(ax.flat, imgs_target):
    panel.imshow(im.cpu())
    panel.set_xticks([])
    panel.set_yticks([])
plt.tight_layout()

print(imgs_target.shape)

In [None]:
def cost(x,target):
    """Residual vector for the layer-coefficient guess ``x``.

    Images are simulated at every source both for ``x`` and for ``x + offset``
    and compared with ``target``; the coefficients themselves, scaled by 1/10,
    are appended as a regularisation term.

    Args:
        x (t.Tensor): layer coefficients, shape (num_alts, max_zern).
        target (t.Tensor): images laid out like the output of image_from_phase
            for the stacked (no offset, offset) directions.

    Returns:
        t.Tensor: 1-D concatenation of image residuals and scaled coefficients.
    """
    flat_nominal = x.flatten()
    flat_diverse = (x + offset).flatten()
    dir_modes = t.cat([
        t.einsum("ijk,j->ik", alt_to_dir, flat_nominal),
        t.einsum("ijk,j->ik", alt_to_dir, flat_diverse),
    ])
    imgs = image_from_phase(phase_from_zerns(dir_modes)) / psf_scaling
    image_residual = (imgs - target).flatten()
    regulariser = x.flatten() / 10
    return t.cat([image_residual, regulariser])

# time the cost function on the GPU with CUDA events
n_evals = 1000  # single source of truth for loop count and averaging
start = t.cuda.Event(enable_timing=True)
end = t.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_evals):
    cost(x_alts, imgs_target)
end.record()
# wait for all queued kernels so elapsed_time() sees both events completed
t.cuda.synchronize()
print(f"cost func takes approx {start.elapsed_time(end)/n_evals:0.3f} ms to eval")


In [None]:
from torch import optim

# Recover the layer coefficients from the target images, starting from zero.
# tqdm comes from the first cell, which picks the right variant for the shell.
x_opt = t.zeros(x_alts.shape, requires_grad=True, device=device)
optimizer = optim.Adam([x_opt], lr=1e-3)
n_iter = 5000      # maximum number of optimiser steps
sr_target = 0.999  # stop once the mean residual Strehl estimate exceeds this

losses = []
pbar = tqdm(range(n_iter))
for e in pbar:
    loss = t.square(cost(x_opt, imgs_target)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    # diagnostic only (uses the true x_alts): mean peak of the residual PSFs,
    # normalised by the flat-phase peak; no autograd graph is needed here
    with t.no_grad():
        dir_modes = t.einsum("ijk,j->ik", alt_to_dir, (x_alts - x_opt).flatten())
        residual_imgs = image_from_phase(phase_from_zerns(dir_modes)) / psf_scaling
        sr = residual_imgs.amax(dim=(-2, -1)).mean().item()
    pbar.set_postfix(sr=f"{sr:1.5f}")
    if sr > sr_target:
        break

x_opt = x_opt.detach()

# plot the loss history once instead of adding one artist per iteration
fig, ax = plt.subplots(1, 1)
ax.plot(losses, "r.")
ax.set_yscale("log")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")

In [None]:
# PSF at each source predicted from the recovered layers alone
# (to evaluate convergence quality)
fig, ax = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])

peaks = []
for panel, px, py in zip(ax.flat, pos_x, pos_y):
    # detach() guards against x_opt still requiring grad if the optimisation
    # cell was interrupted before it detached x_opt itself
    im = image_from_phase(phase_in_direction(x_opt.detach(), theta_x=px, theta_y=py)) / psf_scaling
    peaks.append(im.max())
    panel.imshow(im.cpu())
    panel.set_xticks([])
    panel.set_yticks([])
# one stack instead of re-concatenating a growing tensor every iteration
sr = t.stack(peaks)
print(f"mean strehl: {t.mean(sr).cpu().numpy():0.5f}")
plt.tight_layout()

In [None]:
# residual PSF at each source after subtracting the recovered layers from the
# true ones (to evaluate convergence quality)
fig, ax = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])

peaks = []
for panel, px, py in zip(ax.flat, pos_x, pos_y):
    # detach() guards against x_opt still requiring grad if the optimisation
    # cell was interrupted before it detached x_opt itself
    residual = x_alts - x_opt.detach()
    im = image_from_phase(phase_in_direction(residual, theta_x=px, theta_y=py)) / psf_scaling
    peaks.append(im.max())
    panel.imshow(im.cpu())
    panel.set_xticks([])
    panel.set_yticks([])
# one stack instead of re-concatenating a growing tensor every iteration
sr = t.stack(peaks)
print(f"mean strehl: {t.mean(sr).cpu().numpy():0.5f}")
plt.tight_layout()

In [None]:
# compare true, recovered and residual layer coefficients; the index runs
# layer-major, i.e. layer * max_zern + mode
fig, ax = plt.subplots()
ax.plot(x_alts.flatten().cpu(), label="true (x_alts)")
ax.plot(x_opt.flatten().cpu(), label="recovered (x_opt)")
ax.plot((x_alts - x_opt).flatten().cpu(), label="residual (x_alts - x_opt)")
ax.set_xlabel("Coefficient index (layer * max_zern + mode)")
ax.set_ylabel("Coefficient value")
ax.legend()