In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import aotools
import scipy.interpolate as interp

pup_diam_p = 64   # diameter of pupil in pixels
pup_diam_c = 8.0   # diameter of pupil in metres
fft_width = 128    # width of fourier support (must be at least 2*pup_diam_p)
assert fft_width >= 2*pup_diam_p, "fft_width must be at least 2*pup_diam_p"

as2rad = np.pi/180/3600 # conversion factor from arcseconds to radians

# seed numpy's global random generator so that the random aberrations drawn with
# np.random.randn further down are identical on every "Restart & Run All"
np.random.seed(0)

pup = aotools.phaseFromZernikes([1],pup_diam_p) # circular pupil (borrowing zernike piston mode)

max_theta = 15*2**0.5 # maximum field angle in arcseconds, used to define meta-pupil widths

max_zern = 30 # number of zernike modes used per layer (length of each coefficient vector)
zerns = aotools.zernikeArray(max_zern,pup_diam_p) # array of zernike functions, indexed (mode, row, column)

def phase_from_zerns(x: np.ndarray):
    """Build a phase map as the weighted sum of the zernike functions in `zerns`.

    Args:
        x (np.ndarray): one coefficient per zernike function, shape (max_zern,)

    Returns:
        np.ndarray: phase map, shape (pup_diam_p,pup_diam_p)
    """
    # contract the mode axis of `zerns` (first axis) against the coefficient vector
    phase_map = np.einsum("ijk,i->jk",zerns,x)
    return phase_map

# sanity check: phase map produced by a random draw of zernike coefficients
plt.matshow(phase_from_zerns(np.random.randn(max_zern)))
plt.title("Phase from random zernike coefficients")
plt.colorbar(label="phase");

In [None]:
# conjugation altitudes of the layers (can be negative)
alts = np.array([-10000.0, -5000.0, 0.0, 5000.0, 20000.0])
num_alts = len(alts) # number of layers
# random initialisation of aberrations over all layers: one row of coefficients per layer
x_alts = 1e-2*np.random.randn(num_alts, max_zern)

def phase_at_layer(x: np.ndarray, theta_x: float, theta_y: float, alt: float):
    """compute the phase at a particular layer projected onto the pupil in a given direction

    The layer's zernike coefficients describe the phase over a meta-pupil that is
    ``w`` times wider than the pupil, sampled on a (pup_diam_p,pup_diam_p) grid.
    The pupil footprint for the requested direction is resampled from that
    meta-pupil by bilinear interpolation.

    Args:
        x (np.ndarray): numpy array of coefficients for the layer (max_zern,)
        theta_x (float): x field position on-sky (arcseconds)
        theta_y (float): y field position on-sky (arcseconds)
        alt (float): layer altitude (metres)

    Returns:
        np.ndarray: phase (pup_diam_p,pup_diam_p), indexed [y, x]

    Raises:
        ValueError: if the projected pupil falls outside of the meta-pupil
    """
    # get phase of whole meta-pupil (indexed [row, column] = [y, x])
    phi = phase_from_zerns(x)
    # ratio of meta-pupil width to pupil width at this altitude
    w = 1+2*np.abs(alt)*as2rad*max_theta/pup_diam_c
    # define coordinates (in meta-pupil pixels) to sample that phase (projected from pupil)
    yy,xx = np.mgrid[:pup_diam_p,:pup_diam_p]
    xx = xx/w                                         # shrink pupil to its size within the meta-pupil
    xx += pup_diam_p*(1-1/w)/2                        # centre the pupil within the meta-pupil
    xx += pup_diam_p*alt*as2rad*theta_x/pup_diam_c/w  # shift according to field angle and altitude
    yy = yy/w
    yy += pup_diam_p*(1-1/w)/2
    yy += pup_diam_p*alt*as2rad*theta_y/pup_diam_c/w
    # original coordinates (in pixels), identical along both axes
    pix = np.arange(pup_diam_p,dtype=float)
    # build interpolator function: first grid axis is the row (y), second is the column (x)
    interp_func = interp.RegularGridInterpolator([pix,pix],phi,bounds_error=True)
    # query points must follow the axis order of `phi`, i.e. (y, x); passing (x, y)
    # would return the transposed meta-pupil
    points = np.stack([yy.ravel(),xx.ravel()],axis=-1)
    # evaluate interpolator function at projected coordinates
    try:
        return interp_func(points).reshape([pup_diam_p,pup_diam_p])
    except ValueError as err:
        raise ValueError(
            "Requested phase outside of defined metapupils "
            f"(x range [{xx.min()}, {xx.max()}], y range [{yy.min()}, {yy.max()}] pixels)"
        ) from err

# example: on-axis view of the last layer, using that layer's own coefficients and altitude
plt.matshow(phase_at_layer(x_alts[4],theta_x=0.0,theta_y=0.0,alt=alts[4]))
plt.colorbar()

In [None]:
def phase_in_direction(x: np.ndarray, theta_x: float, theta_y: float):
    """Sum the phase contributions of every layer along one line of sight.

    Args:
        x (np.ndarray): coefficients for all layers (num_alts,max_zern)
        theta_x (float): x field position on-sky (arcseconds)
        theta_y (float): y field position on-sky (arcseconds)

    Returns:
        np.ndarray: pupil-masked phase (pup_diam_p,pup_diam_p)
    """
    # project each layer onto the pupil for the requested direction
    layer_phases = []
    for idx, layer_alt in enumerate(alts):
        layer_phases.append(phase_at_layer(x[idx],theta_x,theta_y,layer_alt))
    # integrate over the layers, then apply the pupil mask
    return np.sum(layer_phases,axis=0)*pup

# example: integrated phase seen 10 arcsec off-axis along x
plt.matshow(phase_in_direction(x_alts,theta_x=10.0,theta_y=0.0))
plt.title("Integrated phase at theta_x=10 arcsec, theta_y=0 arcsec")
plt.colorbar(label="phase");

In [None]:
# display the arcsecond-to-radian conversion factor
as2rad

In [None]:
def image_from_phase(phi: np.ndarray, wavelength: float = 0.55, window_size=16):
    """compute image/PSF given phase at pupil

    Args:
        phi (np.ndarray): phase in microns (pup_diam_p,pup_diam_p)
        wavelength (float): wavelength in microns
        window_size (int): width in pixels of the central window of the image
            that is returned (between 1 and fft_width)

    Returns:
        np.ndarray: central window of the image (window_size,window_size)

    Raises:
        ValueError: if window_size does not fit inside the fourier support
    """
    if not 0 < window_size <= fft_width:
        raise ValueError(f"window_size must be between 1 and {fft_width}, got {window_size}")
    # complex field at the pupil; fft2 zero-pads it to (fft_width,fft_width)
    field = pup*np.exp(1j*2*np.pi*phi/wavelength)
    im = np.abs(np.fft.fftshift(np.fft.fft2(
            field,
            s=[fft_width,fft_width],norm="ortho")
         ))**2
    # crop exactly window_size pixels around the zero-frequency pixel, which
    # fftshift places at index n//2 (also correct for odd window_size)
    row_start = im.shape[0]//2-window_size//2
    col_start = im.shape[1]//2-window_size//2
    return im[row_start:row_start+window_size,
              col_start:col_start+window_size]

# scaling factor to obtain PSF in "strehl" units: peak of the image for zero phase
psf_scaling = image_from_phase(pup*0.0).max()

# example: normalised PSF seen 10 arcsec off-axis along x
plt.matshow(image_from_phase(phase_in_direction(x_alts,theta_x=10.0,theta_y=0.0))/psf_scaling)
plt.title("PSF at theta_x=10 arcsec, theta_y=0 arcsec")
plt.colorbar(label="intensity relative to zero-phase peak");

In [None]:
# build tomographic system: a regular n_src_x by n_src_x grid of sources,
# centred on-axis and spanning 30 arcsec along each axis
n_src_x = 5
n_src = n_src_x**2
pos_x,pos_y = (np.mgrid[:n_src_x,:n_src_x]-(n_src_x-1)/2)/(n_src_x-1)*30
pos_x = pos_x.flatten()
pos_y = pos_y.flatten()
print(f"{pos_x=}")
print(f"{pos_y=}")

plt.figure()
plt.plot(pos_x,pos_y,"rx",label="sources")
plt.axis("square")
plt.xlabel("x-field Position - [arcsec]")
plt.ylabel("y-field Position - [arcsec]")
# the label given to plt.plot is only shown once a legend is drawn
plt.legend();

In [None]:
fig, ax = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])

# simulate the image at every source position from the true aberrations (x_alts)
# and show each one in its own panel
imgs = []
for panel, x, y in zip(ax.flatten(), pos_x, pos_y):
    im = image_from_phase(phase_in_direction(x_alts, theta_x=x, theta_y=y)) / psf_scaling
    imgs.append(im)
    panel.imshow(im)
    panel.set_xticks([])
    panel.set_yticks([])
plt.tight_layout()

# stack the per-source images into a single array
imgs_target = np.array(imgs)
print(imgs_target.shape)

In [None]:
def cost(x,target):
    """Residual vector of the guess 'x' against the target images.

    Args:
        x (np.ndarray): guessed coefficients for all layers (num_alts,max_zern)
        target (np.ndarray): images to match, one per source position

    Returns:
        np.ndarray: 1-D vector made of the flattened image differences
            followed by the flattened coefficients divided by 10
    """
    # model images at each source position for the current guess
    model_imgs = np.array([
        image_from_phase(phase_in_direction(x,theta_x=pos_x[k],theta_y=pos_y[k]))/psf_scaling
        for k in range(n_src)
    ])
    image_residual = (model_imgs - target).ravel()
    # scaled copy of the coefficients appended to the residuals, which
    # penalises large coefficients in a least-squares fit
    coeff_residual = np.ravel(x/10)
    return np.concatenate((image_residual,coeff_residual))

from scipy.optimize import least_squares

# fit the layer coefficients to the target images, starting from all-zero coefficients;
# the solver works on a flat vector, so it is reshaped to (num_alts,max_zern) on the way in and out
x_opt = least_squares(
    fun=lambda flat: cost(flat.reshape([num_alts, max_zern]), imgs_target),
    x0=0.0*x_alts.ravel(),
    verbose=2,
).x.reshape([num_alts, max_zern])

In [None]:
fig, ax = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])

# images predicted by the fitted coefficients (x_opt) at each source position,
# shown to evaluate convergence quality
for panel, x, y in zip(ax.flatten(), pos_x, pos_y):
    im = image_from_phase(phase_in_direction(x_opt, theta_x=x, theta_y=y)) / psf_scaling
    panel.imshow(im)
    panel.set_xticks([])
    panel.set_yticks([])
plt.tight_layout()

In [None]:
fig, ax = plt.subplots(n_src_x, n_src_x, figsize=[8, 8])

# images formed by the residual aberration (true minus fitted coefficients)
# at each source position, shown to evaluate convergence quality
for panel, x, y in zip(ax.flatten(), pos_x, pos_y):
    im = image_from_phase(phase_in_direction(x_alts - x_opt, theta_x=x, theta_y=y)) / psf_scaling
    panel.imshow(im)
    panel.set_xticks([])
    panel.set_yticks([])
plt.tight_layout()

In [None]:
# compare true and fitted coefficients; flattening puts all modes of one layer together
plt.figure()
plt.plot(x_alts.flatten(),label="true (x_alts)")
plt.plot(x_opt.flatten(),label="fitted (x_opt)")
plt.plot((x_alts-x_opt).flatten(),label="difference (x_alts - x_opt)")
plt.xlabel("coefficient index (layer by layer)")
plt.ylabel("zernike coefficient")
plt.legend();


In [None]:
# compare tan(angle) with the angle itself (in radians) for a field angle of -15 arcsec
print(np.tan(-15*as2rad))
print(-15*as2rad)