# CalibrationAdversary
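An attempt to make a concrete example demonstrating the issues with which the author is obsessed: how an asymmetric line-spread function (LSF) biases the apparent positions of laser-frequency-comb (LFC) calibration lines and the apparent (binary-mask CCF) Doppler shift of a stellar spectrum.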

## Author
- **David W. Hogg** (NYU)

## To-do
- Write the peak finder (for the LFC lines) and the CCF peak finder (for the spectrum).
- Plot the velocity offset vs the PSF asymmetry.
- Make a version of the CCF peak finder for the LFC lines.
- Show that it improves the situation a lot?

## Bugs
- There should be an SNR setting and a noise level, maybe?
- Should I get the LFC modes in exactly the same place? I think it's possible.
- Binary mask construction is a hack. Binary mask *usage* is even more of a hack.

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import vmap
import jaxopt
import pylab as plt

In [None]:
# Run JAX in double precision: by default it computes in 32-bit floats.
# The flag should be set before any JAX arrays are created.
import jax

jax.config.update("jax_enable_x64", True)

In [None]:
# define functions that will make an asymmetric line-spread function

def gaussian_1d(xs, mean, sigma):
    """
    Evaluate a unit-area one-dimensional Gaussian.

    Parameters:
        xs: Position(s) at which to evaluate the Gaussian.
        mean: Center of the Gaussian.
        sigma: Standard deviation (it enters only through its square).

    Returns:
        The Gaussian density evaluated at `xs`.
    """
    variance = sigma ** 2
    residuals = xs - mean
    exponent = -0.5 * residuals ** 2 / variance
    return jnp.exp(exponent) / jnp.sqrt(2. * jnp.pi * variance)

def linespread_function(xs, sigma, A2, A3):
    """
    Evaluate an asymmetric three-Gaussian line-spread function (LSF).

    The LSF is a weighted mixture of a core Gaussian of width `sigma`
    centered at zero and two wings of width `2 * sigma` centered at `+sigma`
    (relative weight `A2`) and `-sigma` (relative weight `A3`). Unequal `A2`
    and `A3` make the profile asymmetric.

    Parameters:
        xs: Offset(s) from the line center.
        sigma: Width of the core Gaussian.
        A2: Relative weight of the wing centered at `+sigma`; must be > 0.
        A3: Relative weight of the wing centered at `-sigma`; must be > 0.

    Returns:
        The LSF evaluated at `xs`.

    Raises:
        ValueError: If `A2` or `A3` is not strictly positive.
    """
    # Validate with explicit exceptions: `assert` statements are removed when
    # Python runs with -O, which would silently disable these checks.
    if not A2 > 0:
        raise ValueError("A2 must be positive")
    if not A3 > 0:
        raise ValueError("A3 must be positive")
    # Rescale the relative weights (1, A2, A3) so that they sum to one.
    norm = 1. / (1. + A2 + A3)
    return norm * (gaussian_1d(xs, 0., sigma) +
                   A2 * gaussian_1d(xs, sigma, 2. * sigma) +
                   A3 * gaussian_1d(xs, -sigma, 2. * sigma))

In [None]:
# Build the set of LSF parameters used for the example plots.
# Each (A2, A3) pair gets an "asymmetry" value, (A2 - A3) / (A2 + A3),
# which is also formatted into a plot label.
sigma = 1.5
A2list = jnp.arange(0.1, 1.2, 0.3)
A3list = 1.2 - A2list
wing_difference = A2list - A3list
wing_sum = A2list + A3list
asymmetrylist = wing_difference / wing_sum
label_template = "${:+6.3f}$"
labellist = list(map(label_template.format, asymmetrylist))
print(A2list, labellist)

In [None]:
# Plot one LSF curve per (A2, A3) pair on a fine grid, labeled by asymmetry.
tiny = 0.001
half_step = 0.5 * tiny
plotxs = jnp.arange(half_step + -10., 10., tiny)
# NOTE: A2, A3 and ll stay bound to the last triple after this loop finishes.
for A2, A3, ll in zip(A2list, A3list, labellist):
    lsf_values = linespread_function(plotxs, sigma, A2, A3)
    plt.plot(plotxs, lsf_values, label=ll)
plt.xlabel("x offset")
plt.ylabel("(pixel-convolved) PSF value")
plt.title("Example PSFs with different asymmetries")
plt.legend()
plt.savefig("psfs.png")

In [None]:
# define functions that make a laser frequency comb spectrum

# Positions of the comb lines: from 10 up to (not including) 1000, spaced by
# 25 * pi / 3.
lfc_linelist = jnp.arange(10., 1000., 25. * jnp.pi / 3.)

def lfc_spectrum(xs, sigma, A2, A3):
    """
    Make a laser-frequency-comb (LFC) spectrum: one LSF per comb line.

    Parameters:
        xs: Positions at which to evaluate the spectrum.
        sigma: LSF core width, passed to `linespread_function`.
        A2: LSF wing weight, passed to `linespread_function`.
        A3: LSF wing weight, passed to `linespread_function`.

    Returns:
        The sum over all entries of `lfc_linelist` of the LSF centered on
        that line, evaluated at `xs`.
    """
    def one_line(x):
        return linespread_function(xs - x, sigma, A2, A3)
    # Evaluate all comb lines in one batched call (instead of iterating over
    # the JAX array element by element in Python), then add the profiles.
    return jnp.sum(vmap(one_line)(lfc_linelist), axis=0)

In [None]:
# show me an LFC spectrum at one LSF
# `tiny` is the spacing of the evaluation grid; points sit at half-integer
# multiples of it.
tiny = 1.0
lfc_xs = jnp.arange(0. + 0.5 * tiny, 1000., tiny)
# Pick the LSF explicitly (the last entry of the parameter lists) rather than
# relying on loop variables left over from another cell, which fails with a
# NameError if that cell has not been run.
A2, A3 = A2list[-1], A3list[-1]
lfc_ys = lfc_spectrum(lfc_xs, sigma, A2, A3)
plt.step(lfc_xs, lfc_ys)
plt.xlim(200, 400)
plt.axhline(0, alpha=0.25)
plt.xlabel("x")
plt.ylabel("flux")
plt.title("Example LFC spectrum (detail)")
plt.savefig("lfc.png")

In [None]:
# define functions that will make me a fake stellar spectrum

# Random (but seeded, hence reproducible) line positions in [0, 1000) and
# line strengths for 1000 stellar absorption lines.
rng = np.random.default_rng(17)
star_linelist = jnp.array(1000. * rng.uniform(size=1000))
star_amplist = jnp.array(np.exp(rng.normal(size=1000) - 3.) ** 3)

def star_spectrum(xs, sigma, A2, A3):
    """
    Make a fake stellar absorption spectrum.

    Each line contributes its strength times an LSF centered on the line;
    the flux is the exponential of minus the summed contributions, so it
    equals 1 where no line contributes.

    Parameters:
        xs: Positions at which to evaluate the spectrum.
        sigma: LSF core width, passed to `linespread_function`.
        A2: LSF wing weight, passed to `linespread_function`.
        A3: LSF wing weight, passed to `linespread_function`.

    Returns:
        The flux evaluated at `xs`.
    """
    def one_line(A, x):
        return A * linespread_function(xs - x, sigma, A2, A3)
    # Evaluate all lines in one batched call (instead of iterating over the
    # JAX arrays element by element in Python), then add the profiles.
    total_absorption = jnp.sum(vmap(one_line)(star_amplist, star_linelist), axis=0)
    return jnp.exp(0. - total_absorption)

In [None]:
# show me a stellar spectrum at one LSF
star_xs = 1.0 * lfc_xs
# Pick the LSF explicitly (the last entry of the parameter lists) rather than
# relying on loop variables left over from another cell, which fails with a
# NameError if that cell has not been run.
A2, A3 = A2list[-1], A3list[-1]
star_ys = star_spectrum(star_xs, sigma, A2, A3)
plt.step(star_xs, star_ys)
# plt.xlim(200, 400)
plt.axhline(1., alpha=0.25)
plt.xlabel("x")
plt.ylabel("flux")
plt.title("Example stellar spectrum")
plt.savefig("star.png")

In [None]:
# Build the binary mask from the 15 strongest stellar lines (strongest first).
# This is a hack.
idx = jnp.flip(jnp.argsort(star_amplist))[:15]
mask_linelist, mask_amplist = star_linelist[idx], star_amplist[idx]
# every mask window gets the same half-width: the LSF core width
mask_halfwidthlist = sigma + jnp.zeros_like(mask_linelist)
print(mask_linelist)

In [None]:
# define functions that perform cross-correlation in the traditional EPRV way
# Ish.

def nn_interp(x, xp, fp):
    """
    Nearest-neighbor interpolation.

    For every query point, return the known value whose x-coordinate is
    closest to it. When two known points are equally close, the one that
    comes first in `xp` is used.

    Parameters:
        x: The x-coordinates at which to interpolate.
        xp: The x-coordinates of the known data points.
        fp: The y-coordinates of the known data points.

    Returns:
        The interpolated y-values, one per entry of `x`.

    Author:
        Google AI (ugh).
    """
    # all pairwise distances: rows run over known points, columns over queries
    distances = jnp.abs(xp[:, None] - x[None, :])
    nearest = jnp.argmin(distances, axis=0)
    return fp[nearest]

def make_fine_mask_grid():
    """
    Sample the binary mask on a fine grid of points.

    Around every mask line (globals `mask_linelist`, `mask_halfwidthlist`,
    `mask_amplist`), points are placed at fractional offsets between -1 and 1
    half-widths, spaced by 0.1, and each point carries its line's amplitude.

    Returns:
        A tuple `(xs, ys)` of flattened point positions and amplitudes,
        sorted by position.
    """
    step = 0.1
    offsets = jnp.arange(-1. + step, 1., step)
    centers = mask_linelist[:, None]
    halfwidths = mask_halfwidthlist[:, None]
    grid_xs = centers + halfwidths * offsets[None, :]
    grid_ys = 0. * grid_xs + mask_amplist[:, None]
    flat_xs = grid_xs.flatten()
    flat_ys = grid_ys.flatten()
    order = jnp.argsort(flat_xs)
    return flat_xs[order], flat_ys[order]

def binary_ccf_one(dx, data_xs, data_ys, mask_xs, mask_ys):
    """
    Evaluate the binary-mask cross-correlation at a single trial shift.

    Parameters:
        dx: Trial shift; the mask points are moved to `mask_xs - dx`.
        data_xs: Positions of the data samples.
        data_ys: Values of the data samples.
        mask_xs: Positions of the mask points.
        mask_ys: Weights of the mask points.

    Returns:
        The sum (ignoring NaNs) of mask weight times the nearest-neighbor
        data value at each shifted mask point.
    """
    shifted_xs = mask_xs - dx
    sampled_ys = nn_interp(shifted_xs, data_xs, data_ys)
    return jnp.nansum(mask_ys * sampled_ys)

# Batched version: maps over the first argument (the trial shifts) only;
# the data and the mask are shared by every shift.
binary_ccf = vmap(binary_ccf_one, in_axes=(0,) + (None,) * 4)

def fit_1d_gaussian(xs, ys):
    """
    Least-squares fit of a constant plus a scaled Gaussian to `(xs, ys)`.

    The model is `A0 + A1 * gaussian_1d(xs, mu, sigma)`, and the sum of
    squared residuals is minimized with `jaxopt.LBFGS`. The starting point
    uses the median of `ys` as the constant, the middle sample minus that
    median as the scale, the mean of `xs` as the center and 1.0 as the width.

    Parameters:
        xs: Sample positions.
        ys: Sample values, one per entry of `xs`.

    Returns:
        The optimized parameter vector `(A0, A1, mu, sigma)`.
    """
    def objective(pars):
        offset, scale, center, width = pars
        residuals = ys - offset - scale * gaussian_1d(xs, center, width)
        return jnp.sum(residuals ** 2)

    baseline = jnp.median(ys)
    middle_value = ys[len(ys) // 2]
    pars0 = jnp.array((baseline, middle_value - baseline, jnp.mean(xs), 1.0))
    best_pars, _ = jaxopt.LBFGS(fun=objective).run(pars0)
    return best_pars

def fit_binary_ccf(dxs, bccfs):
    """
    Fit the constant-plus-Gaussian model to a binary-mask CCF.

    Parameters:
        dxs: Trial shifts at which the CCF was evaluated.
        bccfs: CCF values, one per entry of `dxs`.

    Returns:
        The parameter vector returned by `fit_1d_gaussian`.
    """
    ccf_pars = fit_1d_gaussian(dxs, bccfs)
    return ccf_pars

In [None]:
# show me the binary mask on an example, plus best-fit dx
dxs = jnp.arange(-3.0, 3.1, 0.5)
mask_xs, mask_ys = make_fine_mask_grid()
bccfs = binary_ccf(dxs, star_xs, star_ys, mask_xs, mask_ys)
# Unpack the fitted width into its own name: binding it to `sigma` would
# overwrite the global LSF width that the spectrum-generating code reads.
A0, A1, mu, ccf_sigma = fit_binary_ccf(dxs, bccfs)
plt.plot(dxs, bccfs)
plt.plot(dxs, A0 + A1 * gaussian_1d(dxs, mu, ccf_sigma))
plt.axvline(mu)
plt.xlabel("Doppler shift delta-x")
plt.ylabel("CCF value")
plt.title("Example of a Gaussian fit to a binary-mask CCF")
plt.savefig("ccf.png")

In [None]:
# todo: define functions to centroid and combine LFC lines

def fit_one_lfc_line(xs, ys):
    """
    Fit the constant-plus-Gaussian model to the pixels around one LFC line.

    Parameters:
        xs: Pixel positions in the window around the line.
        ys: Fluxes at those pixels.

    Returns:
        The parameter vector returned by `fit_1d_gaussian`.
    """
    line_pars = fit_1d_gaussian(xs, ys)
    return line_pars

def fit_lfc_lines(xs, ys):
    """
    Measure the centroid offset of every LFC line in a spectrum.

    For each entry of the global `lfc_linelist`, a 5-pixel window around the
    pixel nearest to the line is fit with `fit_one_lfc_line`, and the true
    line position is subtracted from the fitted center (element 2 of the fit
    parameters).

    Parameters:
        xs: 1-D array of pixel positions (at least 5 entries).
        ys: 1-D array of fluxes, one per entry of `xs`.

    Returns:
        Array of (fitted center - true line position), one per LFC line.

    ## bugs:
    - A line within two pixels of either end of `xs` is fit in a window that
      is shifted inward, so the line is not centered in its window.
    - A line outside the range covered by `xs` still yields a (meaningless)
      fit.
    """
    half_window = 2
    idx = jnp.argmin(jnp.abs(xs[None, :] - lfc_linelist[:, None]), axis=1)
    # Keep every window fully inside the arrays. JAX does not raise on
    # out-of-range indices, so an unclipped window near an edge would silently
    # pick up the wrong pixels.
    idx = jnp.clip(idx, half_window, len(xs) - 1 - half_window)
    idxs = idx[:, None] + jnp.arange(-half_window, half_window + 1)[None, :]
    fits = vmap(fit_one_lfc_line)(xs[idxs], ys[idxs])
    return fits[:, 2] - lfc_linelist

In [None]:
# Sanity check: print the per-line LFC centroid offsets, then their mean.
measured_lfc_dxs = fit_lfc_lines(lfc_xs, lfc_ys)
for summary in (measured_lfc_dxs, jnp.mean(measured_lfc_dxs)):
    print(summary)

In [None]:
# For a sequence of LSFs, measure the apparent shift of the LFC lines and of
# the stellar spectrum.
# WARNING: THERE IS A MEDIAN HERE WHICH IS PROBABLY WRONG.
A2list = jnp.arange(0.1, 1.05, 0.1)
A3list = 1.1 - A2list

# NOTE: `tiny` and `sigma` are read from the notebook's global state here.
lfc_xs = jnp.arange(0. + 0.5 * tiny, 1000., tiny)
star_xs = 1.0 * lfc_xs
mask_xs, mask_ys = make_fine_mask_grid()
dxs = jnp.arange(-3.0, 3.1, 0.5)

# writable NumPy arrays (same length as A2list) that collect the results
lfc_shiftlist = np.array(0. * A2list)
star_shiftlist = np.array(0. * A2list)
for i, (A2, A3) in enumerate(zip(A2list, A3list)):
    print(i, A2, A3)
    lfc_ys = lfc_spectrum(lfc_xs, sigma, A2, A3)
    star_ys = star_spectrum(star_xs, sigma, A2, A3)
    # LFC shift: median over the per-line centroid offsets
    per_line_offsets = fit_lfc_lines(lfc_xs, lfc_ys)
    lfc_shiftlist[i] = jnp.median(per_line_offsets)
    # stellar shift: fitted center (element 2) of the binary-mask CCF
    ccf_values = binary_ccf(dxs, star_xs, star_ys, mask_xs, mask_ys)
    star_shiftlist[i] = fit_binary_ccf(dxs, ccf_values)[2]
print(lfc_shiftlist)
print(star_shiftlist)

In [None]:
# Plot both apparent shifts against the LSF asymmetry to show the bias.
asymlist = (A2list - A3list) / (A2list + A3list)
for shifts, name in ((lfc_shiftlist, "LFC"), (star_shiftlist, "star")):
    plt.plot(asymlist, shifts, "o", label=name)
plt.xlabel("PSF asymmetry")
plt.ylabel("apparent Doppler shift")
plt.title("Dependence of apparent Doppler shifts on PSF asymmetry")
plt.legend()
plt.savefig("apparent_shifts.png")