# The value of the multi-exposure visit

## author:
- **David W. Hogg** (Flatiron)

## projects:
- Figure out how splitting an exposure into sub-exposures helps or hurts.
- Look at information gain per unit time on the mean function and (parameters of) the variance function.
- Use the magic of `jax` to take derivatives of, or optimize, strategies.

## notes:
- Everything is in SI units, of course; units are not checked programmatically (but maybe they should be?).
- The equations have not been carefully checked. There could be missing factors of 2 or pi, and maybe outright bugs. I have only done trivial checks.
- I need to check that derivatives can be taken and visualized.
- See https://claude.ai/share/4b0d0ccb-bfe6-44ee-b1a9-677bf4e4b3de for relevant discussion.

In [None]:
import jax
jax.config.update("jax_enable_x64", True)
from jax import vmap
import jax.numpy as jnp
import pylab as plt

In [None]:
# set up the assumptions of the method
# (all times in seconds, frequencies in Hz, velocities in m/s)
READ_TIME = 30. # seconds; optimistic; detector readout between consecutive sub-exposures
SLEW_TIME = 120. # seconds; very optimistic; overhead at the end of a visit (see exptime)
PMODE_AMPLITUDE = 3.0 # m^2 s^{-2} (variance); total p-mode RV variance, i.e. the integral of the p-mode PSD
PMODE_NUMAX = 1. / (5.2 * 60.) # Hz; making this up; centre of the p-mode envelope (period 5.2 min)
PMODE_QEFF = 8. # making this up; envelope Gaussian sigma is PMODE_NUMAX / PMODE_QEFF (see pmodes)
WHITE_NOISE_AMPLITUDE = 40. # m^2 s^{-1} (variance times time); white-noise RV variance is this / Texp
READ_NOISE_TIME = 10. # seconds; integration time at which white noise equals read noise (read-noise variance is WHITE_NOISE_AMPLITUDE * READ_NOISE_TIME / Texp^2)
SATURATION_TIME = 1200. # s; exposure time at which saturation starts to matter; longer exposures are masked out of the results

In [None]:
# make a frequency grid on which all integrals will be done
# (integrals over frequency are evaluated as Riemann sums: jnp.sum(DF * integrand) over FS)
day = 86400 # seconds
DF = 1. / (5. * day) # in Hz; fine enough to resolve the individual modes maybe
fmax = 0.01 # Hz; upper end of the grid (exclusive)
# The grid starts at DF rather than 0, so f = 0 is never evaluated; presumably this
# keeps the sin(x) / x window functions used below away from 0 / 0.
FS = jnp.arange(DF, fmax, DF)
print(len(FS))  # number of grid frequencies

In [None]:
def oned_gaussian(xs, mu, V):
    """
    Unit-normalized 1-D Gaussian probability density.

    ## inputs:
    - `xs`: point(s) at which to evaluate the density.
    - `mu`: mean.
    - `V`: variance (sigma squared), not the standard deviation.
    """
    exponent = -0.5 * (xs - mu) ** 2 / V
    normalization = jnp.sqrt(2. * jnp.pi * V)
    return jnp.exp(exponent) / normalization

def pmodes(fs, A=PMODE_AMPLITUDE, numax=None, qeff=None):
    """
    Power spectral density (PSD) of the p-mode RV signal.

    The envelope is a unit-normalized Gaussian in frequency, centred on `numax`
    with standard deviation `numax / qeff`. Because the Gaussian integrates to
    one, the PSD integrates over frequency to `A`.

    ## inputs:
    - `fs`: frequencies (Hz).
    - `A`: total p-mode RV variance (m^2 s^{-2}), i.e. the integral of the PSD
      over frequency. (It is the returned PSD, not `A`, that has units of RV
      variance times time.)
    - `numax`: envelope centre (Hz); defaults to `PMODE_NUMAX`, read at call time.
    - `qeff`: envelope centre-to-width ratio; defaults to `PMODE_QEFF`, read at
      call time.

    ## outputs:
    - PSD evaluated at `fs`, in m^2 s^{-2} Hz^{-1} = m^2 s^{-1}.
    """
    # None sentinels keep the original behaviour of reading the globals at call
    # time, so re-assigning PMODE_NUMAX / PMODE_QEFF still takes effect.
    if numax is None:
        numax = PMODE_NUMAX
    if qeff is None:
        qeff = PMODE_QEFF
    nuw = numax / qeff  # Gaussian sigma of the envelope (Hz)
    return A * oned_gaussian(fs, numax, nuw * nuw)

In [None]:
# BUG: I just typed these from memory and they are probably wrong!

def one_exposure_variance(Texp):
    """
    Total RV variance (m^2 s^{-2}) of a single exposure lasting `Texp` seconds.

    The result is the sum of three terms:
    - p-mode: the p-mode PSD (`pmodes`) times the squared boxcar window
      sinc^2(pi f Texp), summed over the frequency grid `FS` with spacing `DF`.
      This is the p-mode variance left after averaging over the exposure.
    - white noise: WHITE_NOISE_AMPLITUDE / Texp.
    - read noise: WHITE_NOISE_AMPLITUDE * READ_NOISE_TIME / Texp^2. This equals
      the white-noise term when Texp == READ_NOISE_TIME.
    """
    x = jnp.pi * FS * Texp
    window = (jnp.sin(x) / x) ** 2
    pmode_var = jnp.sum(DF * pmodes(FS) * window)
    white_var = WHITE_NOISE_AMPLITUDE / Texp
    read_var = WHITE_NOISE_AMPLITUDE * READ_NOISE_TIME / Texp ** 2
    return pmode_var + white_var + read_var

def two_exposure_covariance(Texp1, Texp2, DeltaT):
    """
    p-mode RV covariance (m^2 s^{-2}) between two exposures lasting `Texp1` and
    `Texp2` seconds, separated in time by `DeltaT` seconds.

    The p-mode PSD (`pmodes`) is multiplied by the two boxcar windows
    sinc(pi f Texp1) and sinc(pi f Texp2) and by cos(2 pi f DeltaT), then summed
    over the frequency grid `FS` with spacing `DF`. With these real (centred)
    windows, `DeltaT` corresponds to the separation between the exposure
    midpoints. Only the p-mode signal contributes; white and read noise are
    treated as independent between exposures.

    ## notes:
    - Assumes (requires) non-overlapping exposures: `DeltaT > Texp1` when the
      two exposure times are equal. More generally the midpoint separation
      should exceed (Texp1 + Texp2) / 2. Overlap would break the
      independent-white-noise assumption above.

    ## bugs:
    - Because this is evaluated inside `vmap`ed (traced) code below, it can't
      `assert` that condition.
    """
    x1 = jnp.pi * FS * Texp1
    x2 = jnp.pi * FS * Texp2
    integrand = DF * pmodes(FS)
    integrand = integrand * (jnp.sin(x1) / x1)
    integrand = integrand * (jnp.sin(x2) / x2)
    integrand = integrand * jnp.cos(2. * jnp.pi * FS * DeltaT)
    return jnp.sum(integrand)

In [None]:
def exptime(T, N):
    """
    Length (s) of each of `N` equal sub-exposures in a visit that lasts `T`
    seconds of wall-clock time.

    Solves  T = N * Texp + (N - 1) * READ_TIME + max(READ_TIME, SLEW_TIME)
    for Texp. There is one READ_TIME between consecutive sub-exposures, and a
    single closing overhead of max(READ_TIME, SLEW_TIME) (presumably the last
    read overlaps the slew).

    The result can be tiny or negative for short `T` and large `N`; the sweep
    below masks such cases out.
    """
    overhead = max(READ_TIME, SLEW_TIME)
    per_slot = (T - overhead + READ_TIME) / N  # exposure plus its following read
    return per_slot - READ_TIME

def one_visit_variance(T, N):
    """
    N x N covariance matrix of the RVs measured in the `N` sub-exposures of a
    visit lasting `T` seconds of wall-clock time.

    - Every sub-exposure lasts `exptime(T, N)` seconds.
    - Consecutive sub-exposures start `Texp + READ_TIME` seconds apart.
    - Diagonal: `one_exposure_variance` (p-mode + white + read noise).
    - Off-diagonal: `two_exposure_covariance` at the corresponding lag. This is
      the p-mode term only; white and read noise are independent between
      exposures.

    The covariance depends only on |i - j|, so the matrix is symmetric
    Toeplitz. It is built by evaluating one value per lag and gathering,
    instead of scattering the same value into every (i, j) pair.

    `N` must be a concrete Python int (it sets the matrix shape); `T` may be
    traced (e.g. under `vmap`).
    """
    Texp = exptime(T, N)
    cadence = Texp + READ_TIME  # start-to-start spacing of sub-exposures
    # by_lag[d] is the covariance between exposures d apart (d = 0 is the variance)
    by_lag = [one_exposure_variance(Texp)]
    by_lag += [two_exposure_covariance(Texp, Texp, dd * cadence)
               for dd in range(1, N)]
    by_lag = jnp.stack(by_lag)
    idx = jnp.arange(N)
    return by_lag[jnp.abs(idx[:, None] - idx[None, :])]

In [None]:
def one_visit_information(T, N):
    """
    Fisher information (s^2 m^{-2}) on the mean RV of a visit lasting `T`
    seconds, split into `N` sub-exposures.

    Computes 1^T C^{-1} 1 with C = one_visit_variance(T, N). Its inverse is the
    variance of the optimally weighted (generalized least squares) estimate of
    a constant mean.
    """
    unit = jnp.ones(N)
    weights = jnp.linalg.solve(one_visit_variance(T, N), unit)
    return jnp.dot(unit, weights)

In [None]:
# Sweep total visit time T (log-spaced) for several sub-exposure counts N and
# collect the Fisher information on the visit mean; allinfos[k] goes with Ns[k].
Ns = [1, 2, 3, 4, 5, 7, 10]
times = jnp.exp(jnp.arange(jnp.log(SLEW_TIME + 50.), jnp.log(7300.), 0.01))
# vmap over T only; N stays a concrete Python int because it sets array shapes.
# These wrappers do not depend on N, so build them once, outside the loop.
vectorized_visit_information = vmap(one_visit_information, (0, None), 0)
vectorized_exptime = vmap(exptime, (0, None), 0)
allinfos = ()
for N in Ns:
    infos = vectorized_visit_information(times, N)
    exptimes = vectorized_exptime(times, N)
    # Mask sub-exposures shorter than 1 s. This also removes the negative
    # exposure times exptime() produces for short visits with large N.
    # NOTE(review): a previous comment said "< 20 sec" but the threshold is
    # 1 s -- confirm which is intended.
    infos = jnp.where(exptimes < 1., jnp.nan, infos)
    # Mask sub-exposures long enough for saturation to matter.
    infos = jnp.where(exptimes > SATURATION_TIME, jnp.nan, infos)
    allinfos += (infos, )

In [None]:
# Plot the expected uncertainty on the visit mean (1 / sqrt(information))
# against visit time, one curve per number of sub-exposures N.
plt.figure()  # fresh figure, so this never draws onto a previously current one
# visit time excluding the closing read/slew overhead
shorttimes = times - max(SLEW_TIME, READ_TIME)
for N, infos_for_N in zip(Ns, allinfos):
    sigmas = 1. / jnp.sqrt(infos_for_N)
    plt.plot(shorttimes, sigmas, label=N)
plt.loglog()
plt.legend(title="N")
# vertical reference line: p-mode period 1 / PMODE_NUMAX
plt.axvline(1. / PMODE_NUMAX, color="k", lw=0.5, alpha=0.5)
# reference curve: white noise only, one exposure spanning the whole visit
white_noise_only = jnp.sqrt(WHITE_NOISE_AMPLITUDE / shorttimes)
plt.plot(shorttimes, white_noise_only, "k-", lw=0.5, alpha=0.5)
plt.xlim(jnp.min(shorttimes), 7200.)
plt.ylim(jnp.min(white_noise_only), 3. * jnp.max(white_noise_only))
plt.xlabel("total time (excluding final read or slew) (s)")
plt.ylabel("expected uncertainty on the mean (m/s)")
plt.savefig("sigma_v_time.png")

In [None]:
# Plot information per unit wall-clock time against total visit time, one
# curve per number of sub-exposures N.
plt.figure()  # fresh figure, so this never draws onto a previously current one
maxrate = 0.
for N, infos_for_N in zip(Ns, allinfos):
    info_rates = infos_for_N / times
    # track the largest (non-NaN) rate seen across all N
    # NOTE(review): maxrate is not used below; ylim is pinned to 1 / WHITE_NOISE_AMPLITUDE
    maxrate = max(maxrate, jnp.nanmax(info_rates))
    plt.plot(times, info_rates, label=N)
plt.legend(title="N")
plt.semilogy()
# horizontal reference: white-noise-only information rate
plt.axhline(1. / WHITE_NOISE_AMPLITUDE, color="k", lw=0.5, alpha=0.5)
# vertical reference: one p-mode period plus the closing read/slew overhead
plt.axvline(1. / PMODE_NUMAX + max(READ_TIME, SLEW_TIME), color="k", lw=0.5, alpha=0.5)
plt.xlim(0., 7200.)
plt.ylim(0.11 / WHITE_NOISE_AMPLITUDE, 1.1 / WHITE_NOISE_AMPLITUDE)
plt.xlabel("total wall-clock time (including final read or slew) (s)")
plt.ylabel("information per unit wall-clock time (s / m^2)")
plt.savefig("information_rate_v_time.png")