### Example 3: Closed-Loop line of sight capture

This example shows how to use lowfssim to model wavefront sensing and control in a dynamical system.  For CGI, the exported forces of the actuators are negligible, and the cross dynamics (i.e., undesirable disturbance introduced by moving the FSM) are ignored in this model.  The Z4..Z11 WFS/C channels are also ignored for purposes of this example.

The lowfs-disturbances file was generated by a controls engineer at JPL, and contains five minutes of ACS residual wander and (simplified) RWA jitter, at 1kHz temporal sampling.  The file contains a single variable `a`, which is 300k points x 4 (time, acsx, acsy, rwa).  For purposes of this notebook we pretend the RWA jitter is identical in the two Cartesian axes, which is not completely realistic.  However, the five-tone RWA model prepared for dissemination here is overly simplified anyway, for purposes of this teaching notebook.

In [None]:
import json
import pickle
from functools import partial
from pathlib import Path


import numpy as np

from scipy.io import loadmat, savemat

from lowfsc import props, control

from lowfsc.data import DesignData
from lowfsc.spectral import StellarDatabase, LOWFS_BANDPASS, ThroughputDatabase
from lowfsc.automate import flt_chop_seq

from lowfsc.emccd import EMCCD

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# a utility function to generate starting points that are both
# aesthetically pleasing, and optimally sample the messy Z2/Z3
# response of the estimator
def pol2cart(r,t):
    """Convert polar coordinates to Cartesian coordinates.

    Parameters
    ----------
    r : `float` or `numpy.ndarray`
        radial coordinate
    t : `float` or `numpy.ndarray`
        azimuthal angle, radians

    Returns
    -------
    `tuple`
        (x, y) Cartesian coordinates, in the same units as r

    """
    cosine, sine = np.cos(t), np.sin(t)
    return r * cosine, r * sine
 
def chebygauss_quadrature_xy(rings, radius=1, spokes=-1, center=(0,0)):
    """Use Chebyshev-Gauss quadrature to sample a polar coordinate grid.

    The ring radii are the Chebyshev-Gauss nodes mapped onto [0, radius].
    Each successive ring is clocked in azimuth by an additional
    (1 / golden ratio) of the spoke spacing, so points on neighboring rings
    do not line up along radial lines.

    Parameters
    ----------
    rings : `int`
        number of rings to use; degree of radial sampling
    radius : `float`
        radius of the grid, process units
    spokes : `int`, optional
        number of spokes (points per ring); if -1, use rings*2 + 1
    center : `tuple`
        (x,y) center point of the grid

    Returns
    -------
    `tuple` of `numpy.ndarray`
        (x, y) coordinates of the sample points.  Each array has
        rings*spokes elements, ordered ring-by-ring from the outermost
        ring to the innermost.

    """
    if spokes == -1:
        spokes = 2*rings + 1

    # Chebyshev-Gauss nodes cos((2k-1)/(2n) pi) live on [-1,1].  The affine map
    # onto [a,b] is (a+b)/2 + (b-a)/2 * node, and for [a,b] = [0,1] both
    # prefactors equal 0.5
    radii = []
    for k in range(1, rings+1):
        node = np.cos(((2*k - 1) / (2 * rings)) * np.pi)
        radii.append((0.5 + 0.5 * node) * radius)

    psi = (5 ** .5 + 1) / 2 # golden ratio
    # the spoke index "j" is the same for every ring
    j = np.arange(1, spokes+1, dtype=np.float64)
    cx, cy = center

    o_x = np.empty(spokes*len(radii))
    o_y = np.empty(spokes*len(radii))
    for k, rr in enumerate(radii, start=1):
        # Greg Forbes' theta = (j+k/psi)Delta; Delta = 2pi/J
        t = (j + (k/psi)) * (2*np.pi / spokes)
        x, y = pol2cart(rr, t)
        lower = (k - 1) * spokes
        upper = lower + spokes
        # shift by the requested center; a no-op for the default (0,0)
        o_x[lower:upper] = x + cx
        o_y[lower:upper] = y + cy

    return o_x, o_y

# sample a disk of radius 80 with 21 rings, then discard every starting point
# that lies within 25 of the origin (units are mas, per the axis labels)
pt_x, pt_y = chebygauss_quadrature_xy(rings=21, radius=80)
r = np.hypot(pt_x, pt_y)
mask = r > 25

# the points are generated outermost ring first; flip the order so the study
# starts in the middle, for aesthetic reasons -- we want to see a smattering
# of successes before many failures
pt_x = list(pt_x[mask][::-1])
pt_y = list(pt_y[mask][::-1])

plt.scatter(pt_x, pt_y)
plt.gca().set(aspect='equal', xlabel='X, mas ots', ylabel='Y, mas ots')

In [None]:
# Point prysm's numerical backend at the GPU (via cupy), then load the
# disturbance time series and the design/stellar/throughput data used by the
# rest of the notebook.  Statement order matters here: the backend is swapped
# before any model data is created.
from prysm.conf import config
from prysm.fttools import mdft

# NOTE(review): presumably drops matrix-DFT data cached with the previous
# backend -- confirm against the prysm documentation
mdft.clear()

import cupy as cp
from cupyx.scipy import fft as cpfft

# use CUDA device 0
cp.cuda.runtime.setDevice(0)

# replace the modules behind prysm's numpy/fft shims with their cupy equivalents
from prysm.mathops import np as pnp, fft as pfft
pfft._srcmodule = cpfft
pnp._srcmodule = cp
# NOTE(review): presumably selects 32-bit (single precision) arithmetic -- confirm
config.precision = 32

# directory holding the disturbance file and the design data
root = Path('~/proj/wfirst/lowfs/data2').expanduser()

# NOTE(review): presumably the conversion from mas on the sky to nm of Z2/Z3,
# judging by the units noted on the logs in run_capture_study -- confirm
mas2nm = 2.87

# variable `a` of the .mat file; per the introduction, an array with columns
# (time, acsx, acsy, rwa).  The scale factor multiplies every column, time
# included, so the time axis must be divided by mas2nm before it is used.
disturbances = loadmat(root/'lowfs-disturbances.mat')['a'] * mas2nm

plt.style.use('bmh')

mode='hlc' # or spec or wfov; not referenced again in this cell
dd = DesignData.hlc_design(root)
sd = StellarDatabase.bijan_data(root)
td = ThroughputDatabase.bijan_data(root)
# indices 2..11: ten entries, matching the ten-element weight vector below
dd.seed_zernikes(range(2,12))
# all-zero weight vector; run_capture_study copies it and overwrites
# elements 0 and 1 with the X and Y line-of-sight position
wt = cp.zeros(10)

In [None]:
# plot the last column of the disturbance array against sample index; per the
# introduction this is the RWA jitter channel (already scaled by mas2nm at load)
plt.plot(disturbances[...,-1])

In [None]:
def run_capture_study(weights5, GAIN5, disturbances, Rtarg):
    """Run a closed-loop line-of-sight capture trial from each starting point.

    For every starting offset in the notebook globals (pt_x, pt_y), the loop
    is stepped one disturbance sample at a time.  A trial stops when the line
    of sight comes within 5 mas of the origin (success), moves more than
    500 mas away (give up), or 10,000 samples have elapsed.

    A lot of the arguments to this function are carryovers from the original
    purpose, and the notebook has been re-written for a more general user.
    Please forgive the naming.

    Parameters
    ----------
    weights5 : `object`
        spectral weights, passed through unchanged to props.polychromatic
    GAIN5 : `float`
        EM gain, assigned to cam.em_gain before the trials are run
    disturbances : `numpy.ndarray`
        disturbance time series; each row unpacks as (time, acsx, acsy, rwa).
        The last three are added directly to the position, which is kept in
        the mas2nm-scaled units (nm), so they must already be scaled.
        At least 10,000 rows are required.
    Rtarg : `object`
        estimator; Rtarg.estimate(image).get() must return a sequence whose
        elements 1 and 2 are used as the Z2 and Z3 estimates

    Returns
    -------
    `dict`
        log with one entry per trial under each key.  startx/starty and
        endx/endy are scalars in mas; trajx/trajy (position, mas),
        estx/esty and fbkx/fbky (estimates and feedback, divided by mas2nm
        before logging) are per-sample lists; tickstoclose is the zero-based
        index of the last sample run (one less than the trajectory length);
        closed is whether the trial ended within 5 mas of the origin.

    Notes
    -----
    Reads the notebook globals pt_x, pt_y, wt, wvl, dd and cam, and sets
    cam.em_gain as a side effect.

    """
    mas2nm = 2.87
    max_ticks = 10_000               # 1 kHz, give up to 10s to close
    capture_radius = 5 * mas2nm      # near the origin: declare success
    giveup_radius = 500 * mas2nm     # extremely far away: give up

    logs = {
        'startx': [], # mas
        'starty': [], # mas
        'endx': [],   # mas
        'endy': [],   # mas
        'trajx': [],  # mas
        'trajy': [],  # mas
        'fbkx': [],   # nm
        'fbky': [],   # nm
        'estx': [],   # nm
        'esty': [],   # nm
        'tickstoclose': [],
        'closed': []
    }
    d = disturbances
    cam.em_gain = GAIN5
    # the real CGI controllers are redacted here, and these no-ops do nothing.  Replace with your own
    Z2ctl = control.NoOpFilter()
    Z3ctl = control.NoOpFilter()
    for px, py in tqdm(zip(pt_x, pt_y), total=len(pt_x)):
        # zero the controllers' `x` (presumably their internal state) so that
        # no trial inherits anything from the previous one
        Z2ctl.x[:] = 0
        Z3ctl.x[:] = 0
        logs['startx'].append(px)
        logs['starty'].append(py)
        refx = px * mas2nm
        refy = py * mas2nm

        trajx = []
        trajy = []
        estx = []
        esty = []
        fbkx = []
        fbky = []
        z2fbk = 0
        z3fbk = 0
        wt_lcl = wt.copy()
        for idx in range(max_ticks):
            _, acsx, acsy, rwa = d[idx]

            # the same RWA jitter is applied to both axes
            posx = refx + z2fbk + acsx + rwa
            posy = refy + z3fbk + acsy + rwa

            wt_lcl[0] = posx
            wt_lcl[1] = posy
            im = props.polychromatic(wvl, weights5, dd, wt_lcl)
            im = cam.expose(im)
            est = Rtarg.estimate(im).get()
            z2 = est[1]
            z3 = est[2]
            z2fbk = Z2ctl.update(-z2)  # implicit: setpt = 0; arg is (setpt) - (meas)
            z3fbk = Z3ctl.update(-z3)

            fbkx.append(z2fbk/mas2nm)
            fbky.append(z3fbk/mas2nm)

            estx.append(z2/mas2nm)
            esty.append(z3/mas2nm)

            trajx.append(posx/mas2nm)
            trajy.append(posy/mas2nm)

            # this is a capture study, if we get near the origin stop
            # and declare success, and if we get extremely far away give up
            posr = np.hypot(posx, posy)
            if posr < capture_radius or posr > giveup_radius:
                break

        logs['endx'].append(posx / mas2nm)
        logs['endy'].append(posy / mas2nm)
        logs['trajx'].append(trajx)
        logs['trajy'].append(trajy)
        logs['estx'].append(estx)
        logs['esty'].append(esty)
        logs['fbkx'].append(fbkx)
        logs['fbky'].append(fbky)
        logs['tickstoclose'].append(idx)
        logs['closed'].append(posr < capture_radius)

    return logs

In [None]:
# the reference and target share one configuration in this example.  G is the
# EM gain; C and V are presumably the spectral class and V magnitude -- confirm
Gref = Gtarg = 20
Cref = Ctarg = 'b3v'
Vref = Vtarg = 2.25
ref = (Vref, Cref, Gref)
targ = (Vtarg, Ctarg, Gtarg)

# ten chops of amplitude 5, one per row of a diagonal matrix; the gains are
# the diagonal of that matrix
ref_z = cp.zeros(10)
chops = 5 * cp.eye(10)
gains = cp.diag(chops)
out = flt_chop_seq(wvl, dd, ref_z, chops, gains, ref=ref, targ=targ, cam=cam)
wtarg, Rtarg = out['wtarg'], out['R']

logs = run_capture_study(wtarg, Gtarg, disturbances, Rtarg)

# bundle the logs with the parameters that produced them
params = {
    'Gref': Gref,
    'Gtarg': Gtarg,
    'Cref': Cref,
    'Ctarg': Ctarg,
    'Vref': Vref,
    'Vtarg': Vtarg,
}
d = {'logs': logs, 'params': params}

# the name is built with a .mat suffix, but the data is pickled to the .pkl sibling
matfn = f'G={Gref}|{Gtarg}-C={Cref}|{Ctarg}-V={Vref}|{Vtarg}-nocut.mat'
pklfn = matfn.replace('.mat', '.pkl')
with open(pklfn, 'wb') as f:
    pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
from matplotlib.colors import ListedColormap

lim = 90
# two-entry colormap: 0 (did not close) is red, 1 (closed) is green
cmap = ListedColormap(['#DB1915', '#16DB54'])

# one marker per starting point, colored by whether that trial closed
fig, ax = plt.subplots()
ax.scatter(logs['startx'], logs['starty'],
           c=logs['closed'], cmap=cmap, s=30, vmin=0, vmax=1)
ax.set(aspect='equal',
       xlabel='X, mas ots', xlim=(-lim,lim),
       ylabel='Y, mas ots', ylim=(-lim,lim))


In [None]:
lim = 90

# the length of each logged trajectory is the number of 1 kHz samples
# that trial ran for
nddata = logs['trajx']
tickstoclose = np.array(list(map(len, nddata)))

fig, ax = plt.subplots()
# samples / 1000 -> seconds
sc = ax.scatter(logs['startx'], logs['starty'],
                c=tickstoclose/1000, cmap='inferno', vmin=0, vmax=1.0)
fig.colorbar(sc, label='seconds to reach 5 mas radial')
ax.set(aspect='equal',
       xlabel='X, mas ots', xlim=(-lim,lim),
       ylabel='Y, mas ots', ylim=(-lim,lim))


In [None]:
lim = 80
lowidx = 0
highidx = 559

fig, ax = plt.subplots()
# overlay the X/Y path of every trial in the selected index window
window = slice(lowidx, highidx)
for path_x, path_y in zip(logs['trajx'][window], logs['trajy'][window]):
    ax.plot(path_x, path_y, lw=1)

ax.set(aspect='equal',
       xlabel='X, mas ots', xlim=(-lim,lim),
       ylabel='Y, mas ots', ylim=(-lim,lim))


In [None]:
# the time column was multiplied by mas2nm together with the rest of the
# disturbance array when it was loaded; divide the factor back out
t = disturbances[...,0] / mas2nm

fig, axs = plt.subplots(nrows=2, figsize=(10,5), sharex=True, sharey=True)
ax_z2, ax_z3 = axs
# one trace per trial; each trial only spans as many samples as it ran for
for trace_x, trace_y in zip(logs['trajx'], logs['trajy']):
    ax_z2.plot(t[:len(trace_x)], trace_x, lw=0.5)
    ax_z3.plot(t[:len(trace_y)], trace_y, lw=0.5)

ax_z2.set(ylabel='Z2, mas ots')
ax_z3.set(ylabel='Z3, mas ots', xlabel='time, sec')