# Fitting a quasar time delay

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): the author found this throwaway plot necessary to "reset the
# rcParams"; presumably it forces pyplot to finish initializing before the
# style sheets below are applied — confirm.
plt.plot()  # Required to reset the rcParams for some reason
# Apply matplotlib's default style first, then the project style sheet on top
# (styles later in the list override earlier ones). The style file is read
# from the current working directory.
plt.style.use(["default", "./araa-gps.mplstyle"])
# Close the throwaway figure created by plt.plot() above.
plt.close()

In [None]:
import warnings
from functools import partial

import arviz as az
import corner
import jax
import jax.numpy as jnp
import jaxopt
import numpyro
import tinygp
import matplotlib as mpl
from astropy.table import Table
from numpyro import distributions as dist
from numpyro import infer
from tinygp import GaussianProcess, kernels, transforms

from paths import data, figures

Start with some global configuration: silence `FutureWarning`s and let JAX use two CPU devices (one per MCMC chain below):

In [None]:
# Suppress FutureWarning messages (e.g. deprecation notices from dependencies)
# so the notebook output stays readable.
warnings.filterwarnings(action="ignore", category=FutureWarning)
# Expose two host (CPU) devices to JAX so the two MCMC chains configured later
# can presumably run in parallel. NumPyro documents that this only takes
# effect if called before JAX performs any computation.
numpyro.set_host_device_count(n=2)

Next, load the light curves that were digitized from the paper (available on ADS): https://ui.adsabs.harvard.edu/abs/1989A%26A...215....1V/abstract

In [None]:
# Load the digitized light curves (columns used below: jd, a_mag, b_mag,
# a_mag_err, b_mag_err).
# NOTE(review): this rebinds `data` — the directory imported from `paths`
# (presumably a pathlib.Path) — to the loaded Table. Re-running this cell on
# its own will then likely fail, because `data / "quasar.csv"` is applied to
# a Table instead of a path.
data = Table.read(data / "quasar.csv")

Set up the custom GP kernel.
More discussion of this kernel can be found on the tinygp docs here: https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels
The basic idea is that each input coordinate is a tuple `(t, band)` where `t` is the time and `band` is the band index (as an integer).
We use the band index to select the kernel amplitude for that data point.

In [None]:
@tinygp.helpers.dataclass
class Multiband(kernels.quasisep.Wrapper):
    """Quasiseparable kernel wrapper with a separate amplitude per band.

    Each input coordinate is a ``(t, band)`` pair. ``t`` is used both for
    sorting and as the input to the wrapped kernel. The integer ``band``
    selects which entry of ``amplitudes`` scales the wrapped kernel's
    observation model.
    """

    # One amplitude per band, indexed by the integer band label.
    amplitudes: jnp.ndarray

    def coord_to_sortable(self, X):
        # Points are ordered by time alone; the band label plays no role.
        time_coord = X[0]
        return time_coord

    def observation_model(self, X):
        time_coord, band = X[0], X[1]
        band_amplitude = self.amplitudes[band]
        base_observation = self.kernel.observation_model(time_coord)
        return band_amplitude * base_observation

Next, define functions for the per-band mean and the time-delay shift.
These also depend on the `band` index, as above.

In [None]:
def time_delay_transform(lag, X):
    """Map observation times onto a common, lag-corrected time axis.

    ``X`` is a ``(t, band)`` pair. Band 0 is left unchanged and band 1 is
    shifted earlier by ``lag``; in general the shift is ``lag * band``.
    """
    times, band_index = X
    shift = lag * band_index
    return times - shift


def mean_func(means, X):
    """Return the constant mean of the band that coordinate ``X`` belongs to.

    ``X`` is a ``(t, band)`` pair. Only the integer ``band`` is used, as an
    index into ``means``; the time is ignored.
    """
    _, band_index = X
    return means[band_index]

Finally, this is the usual tinygp inference setup (see the [tinygp docs](https://tinygp.readthedocs.io) for more details):

In [None]:
# Stack the two light curves into a single dataset. Each input is a
# (time, band) pair: band 0 holds image A and band 1 holds image B, and both
# bands use the same `jd` times.
N = len(data)
X = (
    jnp.tile(data["jd"].value, 2),
    jnp.repeat(jnp.arange(2), N),
)
y = jnp.concatenate([data["a_mag"].value, data["b_mag"].value])
# Per-point noise variances: squared magnitude uncertainties, aligned with y.
diag = jnp.concatenate([data["a_mag_err"].value, data["b_mag_err"].value]) ** 2


def build_gp(params, X, diag):
    """Construct the two-band, time-delayed GP for one parameter set.

    Parameters
    ----------
    params : dict
        Must contain ``"lag"``, ``"log_ell"`` (log scale of the Matern-3/2
        kernel), ``"amps"`` (per-band amplitudes) and ``"means"`` (per-band
        constant means).
    X : tuple
        ``(t, band)`` arrays with one entry per observation.
    diag : array
        Per-observation noise variances, aligned with ``X``.

    Returns
    -------
    gp : GaussianProcess
        GP over the lag-shifted inputs, ordered by shifted time (tinygp's
        quasiseparable kernels expect sorted inputs).
    inds : array
        Permutation that sorts the inputs. Observed values must be reordered
        with it (``y[inds]``) before being used with ``gp``.
    """
    shifted_t = time_delay_transform(params["lag"], X)
    order = jnp.argsort(shifted_t)
    sorted_inputs = (shifted_t[order], X[1][order])

    base_kernel = kernels.quasisep.Matern32(jnp.exp(params["log_ell"]))
    kernel = Multiband(amplitudes=params["amps"], kernel=base_kernel)
    band_mean = partial(mean_func, params["means"])

    gp = GaussianProcess(kernel, sorted_inputs, diag=diag[order], mean=band_mean)
    return gp, order


@jax.jit
def loss(params):
    """Negative GP log-likelihood of the global data under ``params``.

    Reads the module-level ``X``, ``diag`` and ``y``. Because the function is
    jitted, JAX captures these globals as constants when it first traces the
    function, so rebinding them afterwards may not be reflected in calls that
    reuse the compiled trace.
    """
    gp, order = build_gp(params, X, diag)
    log_like = gp.log_probability(y[order])
    return -log_like


# Reference parameter values. These values are used (1) to simulate the data
# below, (2) as the starting point of the lag grid search, and (3) as the
# "truths" overplotted on the corner plot.
true_params = {
    "lag": 420.0,
    "log_ell": jnp.log(100.0),
    "amps": jnp.array([0.08, 0.12]),
    "means": jnp.array([17.43, 17.53]),
}
gp, inds = build_gp(true_params, X, diag)
# NOTE(review): the next two lines REPLACE the observed magnitudes loaded from
# quasar.csv with a single GP draw at `true_params` (seed 10). Every later
# cell (optimization, MCMC, figures) therefore fits this synthetic `y`, not
# the digitized data. Confirm this is intentional before publishing figures.
y = jnp.empty_like(y)
# gp.sample returns values in the sorted order given by `inds`; scatter them
# back so that `y` stays aligned with the unsorted `X`.
y = y.at[inds].set(gp.sample(jax.random.PRNGKey(10)))
# Quick look at each band (labels are set, but no legend is drawn).
plt.plot(X[0][X[1] == 0], y[X[1] == 0], ".", label="a")
plt.plot(X[0][X[1] == 1], y[X[1] == 1], ".", label="b")

To find an initial guess for the time lag, run a set of optimizations starting from a grid of candidate lags and keep the result with the lowest loss.
This is similar to the approach used by: https://ui.adsabs.harvard.edu/abs/1992ApJ...385..404P/abstract

In [None]:
opt = jaxopt.ScipyMinimize(fun=loss)

# Grid search over starting lags. Every optimization starts from
# `true_params` with only the lag replaced; the lowest-loss solution is kept.
init = dict(true_params)
# Record a *copy* of the starting point so the stored best-fit parameters can
# never alias a dict that is modified later. Otherwise, if no grid run beats
# this starting loss, the recorded lag would not match `minimum[0]`.
minimum = loss(init), dict(init)
lags = []
vals = []
for lag in jnp.linspace(0, 1000, 100):
    # Build a fresh starting dict per grid point instead of mutating `init`.
    soln = opt.run({**init, "lag": lag})
    lags.append(soln.params["lag"])
    vals.append(soln.state.fun_val)
    if soln.state.fun_val < minimum[0]:
        minimum = soln.state.fun_val, soln.params
init = minimum[1]

# Profile of the minimized loss against the lag each optimization converged to.
plt.plot(lags, vals, ".", alpha=0.2)
plt.xlabel("lag [days]")
plt.ylabel("loss minimized over other parameters");

Based on the best-fit lag, define a grid of times where we'll evaluate the GP's predictive distribution.

In [None]:
# Lag-corrected times of every observation at the best-fit lag. This reuses
# the same mapping the GP applies, so the grid always matches the model's
# time axis. The grid then pads the covered range by 200 days on each side.
t_lagged = time_delay_transform(minimum[1]["lag"], X)
t_grid = jnp.linspace(t_lagged.min() - 200, t_lagged.max() + 200, 1000)

Now, set up the model in NumPyro and run MCMC:

In [None]:
def model(X, diag, y):
    """NumPyro model for the two-band, time-delayed GP.

    Samples the lag, the log length scale, the per-band amplitudes, and the
    band means (band B's mean is parameterized as band A's mean plus an
    offset). It then adds the GP likelihood of ``y``, reordered with the GP's
    sort permutation. Finally, it records the GP conditional mean for each
    band on the global ``t_grid`` as the deterministic sites ``pred_a`` and
    ``pred_b``.

    Parameters
    ----------
    X : tuple
        ``(t, band)`` input arrays.
    diag : array
        Per-observation noise variances aligned with ``X``.
    y : array
        Observed magnitudes aligned with ``X``.
    """
    lag = numpyro.sample("lag", dist.Uniform(0.0, 1000.0))
    log_ell = numpyro.sample("log_ell", dist.Uniform(jnp.log(10), jnp.log(1000.0)))
    # NOTE(review): this prior allows amplitudes of either sign. Flipping the
    # sign of both amplitudes leaves the products amp_i * amp_j (and hence the
    # covariance) unchanged, so the posterior in `amps` is presumably bimodal.
    # `lag` and `delta_mean` are unaffected; confirm this is acceptable.
    amps = numpyro.sample("amps", dist.Uniform(-5.0, 5.0), sample_shape=(2,))
    mean_a = numpyro.sample("mean_a", dist.Uniform(17.0, 18.0))
    # Offset of band B's mean relative to band A's.
    delta_mean = numpyro.sample("delta_mean", dist.Uniform(-2.0, 2.0))
    means = jnp.stack((mean_a, mean_a + delta_mean))

    params = {
        "lag": lag,
        "log_ell": log_ell,
        "amps": amps,
        "means": means,
    }
    gp, inds = build_gp(params, X, diag)
    # The GP's inputs are sorted by lagged time, so observations must be too.
    numpyro.sample("y", gp.numpyro_dist(), obs=y[inds])

    # Conditional (predictive) mean of each band on the lag-corrected grid.
    numpyro.deterministic(
        "pred_a",
        gp.condition(y[inds], (t_grid, jnp.zeros_like(t_grid, dtype=int))).gp.loc,
    )
    numpyro.deterministic(
        "pred_b",
        gp.condition(y[inds], (t_grid, jnp.ones_like(t_grid, dtype=int))).gp.loc,
    )


# Start MCMC at the best grid-search solution. The model samples `mean_a` and
# `delta_mean` rather than `means` directly, so derive them from the optimized
# means. The leftover "means" entry has no matching sample site in `model`
# (presumably ignored by init_to_value).
# NOTE(review): if the optimized lag lies outside the Uniform(0, 1000) prior,
# this initial value would fall outside the prior's support — worth checking.
init_params = dict(minimum[1])
init_params["mean_a"] = init_params["means"][0]
init_params["delta_mean"] = init_params["means"][1] - init_params["means"][0]
sampler = infer.MCMC(
    infer.NUTS(
        model,
        dense_mass=True,  # adapt a full (dense) mass matrix during warmup
        target_accept_prob=0.9,
        init_strategy=infer.init_to_value(values=init_params),
    ),
    num_warmup=1000,
    num_samples=5000,
    # Presumably run in parallel on the two host devices requested via
    # numpyro.set_host_device_count(2).
    num_chains=2,
    progress_bar=True,
)
# IPython magic that times the run; this line is not valid in a plain .py script.
%time sampler.run(jax.random.PRNGKey(12), X, diag, y)

Check convergence using ArviZ: `ess_bulk` estimates the (bulk) effective sample size, and `r_hat` should be close to 1.

In [None]:
# Convert the NumPyro run to an ArviZ InferenceData object (reused for the
# corner plot), then summarize the two parameters of interest. The summary
# table is the cell's displayed output.
inf_data = az.from_numpyro(sampler)
az.summary(inf_data, var_names=["lag", "delta_mean"])

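Plot the posteriors for the time delay and the mean magnitude offset, overplotting the reference values from `true_params`. For comparison, see the lag measured by https://ui.adsabs.harvard.edu/abs/1992ApJ...385..404P/abstract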

In [None]:
# Reference lines on the corner plot: the lag and the B - A mean offset taken
# from `true_params`.
truth_values = [true_params["lag"], jnp.diff(true_params["means"])[0]]
with mpl.rc_context({"font.size": 14}):
    fig = corner.corner(
        inf_data,
        var_names=["lag", "delta_mean"],
        labels=["time delay [days]", "mean magnitude offset"],
        truths=truth_values,
    )
    fig.savefig(figures / "quasar_posteriors.pdf", bbox_inches="tight")

Overplot the GP predictive means for a random subset of posterior samples on the time-shifted data:

In [None]:
# Posterior draws from all chains; pred_a/pred_b have one row per draw and
# one column per t_grid point.
samples = sampler.get_samples()
# Shift used to align image A with image B in the figure: the posterior median lag.
lag = jnp.median(samples["lag"])
pred_a = samples["pred_a"]
pred_b = samples["pred_b"]
# 12 random posterior draws to plot (indices may repeat). Note: this rebinds
# the global `inds` that earlier held a GP sort permutation.
inds = jax.random.randint(jax.random.PRNGKey(134), (12,), 0, len(pred_a))

# Vertical shift (in magnitudes) applied to band B for display.
offset = 0.3

plt.figure(figsize=(5, 3.5))
# Predictive means live on the lag-corrected axis; adding `lag` puts them on
# the same axis as the plotted data below.
plt.plot(t_grid + lag, pred_a[inds, :].T, c="C0", alpha=0.3, lw=0.5)
plt.plot(t_grid + lag, pred_b[inds, :].T + offset, c="C1", alpha=0.3, lw=0.5)

# Image A data, shifted forward in time by the median lag.
a = X[1] == 0
plt.errorbar(
    X[0][a] + lag,
    y[a],
    yerr=jnp.sqrt(diag[a]),
    fmt="oC0",
    label="A",
    markersize=4,
    linewidth=1,
)
# Image B data at its observed times, shifted vertically by `offset`.
b = X[1] == 1
plt.errorbar(
    X[0][b],
    y[b] + offset,
    yerr=jnp.sqrt(diag[b]),
    fmt="oC1",
    label="B",
    markerfacecolor="white",
    markersize=4,
    linewidth=1,
)
# Invert the y-axis so smaller magnitudes (brighter) are at the top.
plt.ylim(plt.ylim()[::-1])
plt.xlabel(f"time [days; A + {lag:.0f} days]")
plt.ylabel(f"magnitude [B + {offset}]")
plt.xlim(t_grid.min() + lag, t_grid.max() + lag)
plt.legend(loc="lower right", fontsize=10)
plt.savefig(figures / "quasar.pdf", bbox_inches="tight")