# multi-exposure visit

## author:
- **David W. Hogg**

## project:
- Figure out whether splitting one exposure into several sub-exposures helps or hurts the RV precision of a visit.

## notes:
- Using SI units, of course.
- Doesn't model read noise correctly: each readout costs time (`READ_TIME`) but adds no noise. This is a known issue.

In [None]:
import jax
jax.config.update("jax_enable_x64", True)
from jax import vmap
import jax.numpy as jnp
import pylab as plt

In [None]:
# Model assumptions (SI units: seconds and hertz).

# Overheads, in seconds:
READ_TIME = 45  # readout time charged after every exposure (see `exptime`)
SLEW_TIME = 120  # slew time; charged once per visit together with READ_TIME (see `exptime`)

# Stellar p-mode oscillation hump:
PMODE_AMPLITUDE = 1.0e4  # peak power of the hump; units not yet pinned down
PMODE_NUMAX = 1.0 / (5.2 * 60.0)  # Hz; invented value, i.e. a 5.2-minute period
PMODE_QEFF = 10.0  # invented value; ratio of numax to the Gaussian hump width

# Flat (white) noise floor:
WHITE_NOISE_AMPLITUDE = 500.0  # power level; units not yet pinned down

In [None]:
# Frequency grid (Hz) shared by every spectral integral below; each integral
# is evaluated as a plain Riemann sum with bin width DF.
day = 86400  # seconds per day
fmax = 1.0  # Hz; the grid stops just below this frequency
DF = 1.0 / (day * 5.0)  # Hz; 1 / (5 days), hopefully fine enough to resolve individual modes
FS = jnp.arange(DF, fmax, DF)  # starts at DF, so f = 0 is never on the grid
print(len(FS))

In [None]:
def white_noise(fs, S_0=WHITE_NOISE_AMPLITUDE):
    """
    Flat (frequency-independent) noise power spectrum.

    Returns an array shaped like `fs` in which every entry equals `S_0`.
    - `S_0` is the constant power level; its units are not yet pinned down
      (possibly RV variance in one second).
    """
    flat = jnp.ones_like(fs)
    return flat * S_0

def pmodes(fs, A=PMODE_AMPLITUDE):
    """
    Gaussian hump of p-mode power centred on `PMODE_NUMAX`.

    The hump has peak value `A` and standard deviation
    ``PMODE_NUMAX / PMODE_QEFF`` (in Hz). It is evaluated element-wise on `fs`.
    - `A` sets the peak power. Its units are unresolved; they presumably
      depend on the grid spacing `DF` used when the spectrum is integrated.
    """
    centre = PMODE_NUMAX
    width = centre / PMODE_QEFF
    offset = fs - centre
    return A * jnp.exp(-0.5 * offset ** 2 / width ** 2)

In [None]:
# BUG: I typed these from memory and they may be wrong. Check them against a reference, in particular the one-sided vs. two-sided power-spectrum normalization.

def one_exposure_variance(Texp):
    """
    Variance of the RV averaged over a single exposure of length `Texp` seconds.

    Computed as a Riemann sum over the global frequency grid `FS` (bin width
    `DF`). Each term is the total power spectrum (white noise + p modes) times
    the squared transfer function of a boxcar of width `Texp`,
    ``sin(pi f Texp) / (pi f Texp)``.

    The transfer function is evaluated with `jnp.sinc`
    (``sinc(x) = sin(pi x) / (pi x)`` with ``sinc(0) = 1``). As a result,
    ``Texp == 0`` yields the instantaneous variance instead of NaN from 0/0.
    """
    power = white_noise(FS) + pmodes(FS)
    window = jnp.sinc(FS * Texp)  # boxcar transfer function, finite at Texp == 0
    return jnp.sum(DF * power * window ** 2)

def two_exposure_covariance(Texp1, Texp2, DeltaT):
    """
    Covariance between the RVs averaged over two exposures.

    The exposures have lengths `Texp1` and `Texp2` (seconds). Their midpoints
    are `DeltaT` seconds apart.

    Computed as a Riemann sum over `FS` (bin width `DF`). Each term is the
    p-mode spectrum times the two boxcar transfer functions times
    ``cos(2 pi f DeltaT)``.

    Only the p modes enter. The white-noise spectrum is flat, so it is
    uncorrelated between exposures that do not overlap in time. This assumes
    ``|DeltaT| >= (Texp1 + Texp2) / 2``.

    `jnp.sinc` keeps the transfer functions finite when an exposure length is 0.
    """
    window1 = jnp.sinc(FS * Texp1)
    window2 = jnp.sinc(FS * Texp2)
    return jnp.sum(DF * pmodes(FS) * window1 * window2
                   * jnp.cos(2. * jnp.pi * FS * DeltaT))

In [None]:
def exptime(T, N):
    """
    Length in seconds of each of `N` equal exposures in a visit lasting `T` seconds.

    One overhead of ``max(READ_TIME, SLEW_TIME)`` is charged per visit. Every
    exposure is then followed by a readout of ``READ_TIME``. The result can be
    zero or negative when `T` is too short for `N` exposures.
    """
    per_visit_overhead = max(READ_TIME, SLEW_TIME)
    slot = (T - per_visit_overhead) / N  # one exposure plus its readout
    return slot - READ_TIME

def one_visit_variance(T, N):
    """
    Covariance matrix of the `N` exposure-averaged RVs in a visit of length `T`.

    All exposures share the length ``exptime(T, N)``. Consecutive exposures
    start ``exptime(T, N) + READ_TIME`` seconds apart, so entry (i, j) depends
    only on the lag ``|i - j|``; the matrix is Toeplitz.
    - The diagonal is `one_exposure_variance` (white noise + p modes).
    - Off-diagonal entries come from `two_exposure_covariance` (p modes only).

    `N` must be a concrete Python int because it sets the matrix size.
    Returns an (N, N) array.
    """
    Texp = exptime(T, N)
    cadence = Texp + READ_TIME  # start-to-start spacing of consecutive exposures
    # Covariance at each lag 0 .. N-1; lag 0 is the single-exposure variance.
    lag_covariances = jnp.stack(
        [one_exposure_variance(Texp)]
        + [two_exposure_covariance(Texp, Texp, dd * cadence) for dd in range(1, N)]
    )
    idx = jnp.arange(N)
    lags = jnp.abs(idx[:, None] - idx[None, :])
    return lag_covariances[lags]

In [None]:
def one_visit_information(T, N):
    """
    Information on the visit-mean RV from `N` exposures in a visit of length `T`.

    Returns ``1^T C^{-1} 1`` with ``C = one_visit_variance(T, N)``. This is
    the inverse variance of the inverse-covariance-weighted mean of the `N`
    exposure RVs.
    """
    ones = jnp.ones(N)
    weights = jnp.linalg.solve(one_visit_variance(T, N), ones)
    return jnp.dot(ones, weights)

In [None]:
# Plot the RV uncertainty of a visit against its duration, for N = 1..5 exposures.
times = jnp.arange(200., 3600., 10.)  # total visit durations T, seconds
MIN_TRUSTED_EXPTIME = 10.  # seconds; shorter exposures are masked out below
# vmap over the visit duration T only; N stays a Python int (it sets matrix sizes).
vectorized_visit_information = vmap(one_visit_information, (0, None), 0)
vectorized_exptime = vmap(exptime, (0, None), 0)
for N in [1, 2, 3, 4, 5]:
    infos = vectorized_visit_information(times, N)
    # Don't trust exposures shorter than MIN_TRUSTED_EXPTIME. Zeroing their
    # information makes the corresponding sigma infinite.
    too_short = vectorized_exptime(times, N) < MIN_TRUSTED_EXPTIME
    infos = jnp.where(too_short, 0., infos)
    sigmas = 1. / jnp.sqrt(infos)
    plt.plot(times - max(SLEW_TIME, READ_TIME), sigmas, label=f"N = {N}")
plt.loglog()
plt.xlabel("visit time minus per-visit overhead (s)")
plt.ylabel("uncertainty on the visit-mean RV")
plt.legend()