In [None]:
# Run the shared notebook_setup script first; with a plain %run its globals
# are loaded into this notebook's namespace.
# NOTE(review): presumably common plotting/config defaults for the tutorials — confirm in that file.
%run notebook_setup

# Scalable Gaussian processes in PyMC3

PyMC3 has support for [Gaussian Processes (GPs)](https://docs.pymc.io/gp.html), but this implementation is too slow for many applications in time-series astrophysics.
So *exoplanet* comes with an implementation of scalable GPs powered by [celerite](https://celerite.readthedocs.io/).
More information about the algorithm can be found in the [celerite docs](https://celerite.readthedocs.io/) and in the papers ([Paper 1](https://arxiv.org/abs/1703.09710) and [Paper 2](https://arxiv.org/abs/1801.10156)); this tutorial gives a hands-on demonstration of how to use celerite in PyMC3.

## A simple demo

Let's start with the quickstart demo from the [celerite docs](https://celerite.readthedocs.io/en/stable/tutorials/first/).
We'll fit the following simulated dataset using the sum of two :class:`exoplanet.gp.terms.SHOTerm` objects.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)


def _true_signal(times):
    """Noise-free simulated signal: a linear trend plus a chirped sinusoid."""
    return 0.2 * (times - 5) + np.sin(3 * times + 0.1 * (times - 5) ** 2)


# Two observing windows separated by a gap. The input coordinates must be
# sorted, so the concatenated draws are sorted before use.
t_window_a = np.random.uniform(0, 3.8, 57)
t_window_b = np.random.uniform(5.5, 10, 68)
t = np.sort(np.concatenate((t_window_a, t_window_b)))

# Per-point uncertainties drawn first, then the noisy observations.
yerr = np.random.uniform(0.08, 0.22, len(t))
y = _true_signal(t) + yerr * np.random.randn(len(t))

# Densely sampled noise-free curve, used only as a plotting reference.
true_t = np.linspace(0, 10, 5000)
true_y = _true_signal(true_t)

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("t")
plt.ylabel("y")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5);

In [None]:
import pymc3 as pm
import theano.tensor as tt
from exoplanet.gp import terms, GP
from exoplanet.sampling import TuningSchedule

# exoplanet's TuningSchedule drives the tune -> sample workflow at the end of this cell.
schedule = TuningSchedule()

with pm.Model() as model:
    
    # Broad Normal(0, 15) priors on the log-hyperparameters of two SHO terms.
    # Both log_S0 test values start at the log of the data variance and both
    # log_w0 test values at log(3).
    logS0_1 = pm.Normal("logS01", mu=0.0, sd=15.0, testval=np.log(np.var(y)))
    logw0_1 = pm.Normal("logw01", mu=0.0, sd=15.0, testval=np.log(3.0))
    logS0_2 = pm.Normal("logS02", mu=0.0, sd=15.0, testval=np.log(np.var(y)))
    logw0_2 = pm.Normal("logw02", mu=0.0, sd=15.0, testval=np.log(3.0))
    # Only the second term gets a free quality factor.
    logQ_2 = pm.Normal("logQ2", mu=0.0, sd=15.0, testval=0)
    
    # Set up the kernel and GP: the first SHO term has Q fixed at 1/sqrt(2),
    # the second uses the sampled log_Q. yerr**2 is passed as the per-point
    # noise variance.
    kernel = terms.SHOTerm(log_S0=logS0_1, log_w0=logw0_1, Q=1.0/np.sqrt(2))
    kernel += terms.SHOTerm(log_S0=logS0_2, log_w0=logw0_2, log_Q=logQ_2)
    gp = GP(kernel, t, yerr**2)
    
    # Add a custom "potential" (log probability function) with the GP likelihood
    pm.Potential("gp", gp.log_likelihood(y))
    
    # Optimize to the MAP starting from the model's test point, then tune and
    # sample two chains starting from that solution. `burnin` is not used below.
    map_soln = pm.find_MAP(start=model.test_point)
    burnin = schedule.tune(tune=1000, chains=2, start=map_soln)
    trace = schedule.sample(draws=2000, chains=2)

In [None]:
import corner

# One column per sampled parameter of the first model's posterior.
samples = pm.trace_to_dataframe(trace)
figure = corner.corner(samples)

In [None]:
from exoplanet.utils import eval_in_model

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)

# Overplot GP predictions on the dense grid for 50 random posterior draws.
# A chain *id* is taken from trace.chains (ids need not equal list indices)
# and the draw is fetched through the public MultiTrace.point API instead of
# the private `_straces` dict.
with model:
    pred = gp.predict(true_t)
    for i in range(50):
        chain_idx = np.random.randint(len(trace.chains))
        sample_idx = np.random.randint(len(trace))
        sample = trace.point(sample_idx, chain=trace.chains[chain_idx])
        plt.plot(true_t, eval_in_model(pred, sample), color="C1", alpha=0.1)

# Same axes as the data plot above so the two figures are comparable.
plt.xlabel("t")
plt.ylabel("y")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5);

In [None]:
# Tabulate posterior summary statistics for every variable in the first model's trace.
pm.summary(trace)

In [None]:
from astropy.io import fits

# Kepler PDCSAP light curve for target 005809890 from the STScI archive.
url = "https://archive.stsci.edu/missions/kepler/lightcurves/0058/005809890/kplr005809890-2012179063303_llc.fits"
with fits.open(url) as hdus:
    data = hdus[1].data
    # Copy the needed columns out as float64 while the file is still open, so
    # nothing below depends on the FITS buffer after the file is closed.
    x = np.array(data["TIME"], dtype=np.float64)
    y = np.array(data["PDCSAP_FLUX"], dtype=np.float64)
    yerr = np.array(data["PDCSAP_FLUX_ERR"], dtype=np.float64)
    quality = np.array(data["SAP_QUALITY"])

# Keep cadences with SAP_QUALITY == 0 where time, flux, *and* flux uncertainty
# are all finite. A single NaN uncertainty would otherwise propagate into the
# GP noise term (yerr**2) used later.
m = (quality == 0) & np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)

x = np.ascontiguousarray(x[m], dtype=np.float64)
y = np.ascontiguousarray(y[m], dtype=np.float64)
yerr = np.ascontiguousarray(yerr[m], dtype=np.float64)

# Convert to relative flux in parts per thousand (ppt); the uncertainties are
# scaled by the same factor.
mu = np.mean(y)
y = (y / mu - 1) * 1e3
yerr = yerr * 1e3 / mu

plt.plot(x, y, "k")
plt.xlabel("time [days]")
plt.ylabel("relative flux [ppt]");

In [None]:
# Fresh tuning schedule for the rotation model (the earlier one was used by
# the first demo model).
schedule = TuningSchedule()

with pm.Model() as model:
    
    # Log of an extra white-noise variance ("jitter"), added to yerr**2 in the
    # GP noise below. Its test value is the log median measurement variance.
    logs2 = pm.Normal("logs2", mu=0.0, sd=15.0, testval=np.log(np.median(yerr**2)))
    # Log amplitude of the rotation kernel; its test value is the log data variance.
    logamp = pm.Normal("logamp", mu=0.0, sd=15.0, testval=np.log(np.var(y)))
    # Quality-factor hyperparameters of the RotationTerm (narrower sd=5 priors).
    logQ0 = pm.Normal("logQ0", mu=0.0, sd=5.0, testval=3.0)
    logdeltaQ = pm.Normal("logdeltaQ", mu=0.0, sd=5.0)
    # Log rotation period, prior centred on ln(35) with unit width in log space.
    logperiod = pm.Normal("logperiod", mu=np.log(35.0), sd=1.0)
    # Mixing parameter passed to RotationTerm, bounded to [0, 1].
    # NOTE(review): see the RotationTerm docs for its exact role.
    mix = pm.Uniform("mix", lower=0, upper=1, testval=0.5)
    
    kernel = terms.RotationTerm(
        log_amp=logamp,
        log_period=logperiod,
        log_Q0=logQ0,
        log_deltaQ=logdeltaQ,
        mix=mix
    )
    
    # Per-point noise variance: reported measurement variance plus fitted jitter.
    # NOTE(review): J=4 is passed explicitly; presumably the semiseparable width
    # RotationTerm requires — confirm against the exoplanet GP documentation.
    gp = GP(kernel, x, yerr**2 + tt.exp(logs2), J=4)
    
    # Add a custom "potential" (log probability function) with the GP likelihood
    pm.Potential("gp", gp.log_likelihood(y))
    # Record the GP prediction for every draw; it is later plotted against x
    # via trace["pred"], so one full-length vector is stored per draw.
    pm.Deterministic("pred", gp.predict())
    
    # Optimize to the MAP from the test point, then tune and sample two chains
    # starting from that solution. `burnin` is not used below.
    map_soln = pm.find_MAP(start=model.test_point)
    burnin = schedule.tune(tune=1000, chains=2, start=map_soln)
    trace = schedule.sample(draws=2000, chains=2)

In [None]:
# Hyperparameters only: the per-cadence "pred" Deterministic is left out.
hyper_names = ["mix", "logperiod", "logQ0", "logdeltaQ", "logamp", "logs2"]
samples = pm.trace_to_dataframe(trace, varnames=hyper_names)
corner.corner(samples);

In [None]:
# Marginal posterior of the rotation period, converted back from log space.
period_draws = np.exp(trace["logperiod"])
period_edges = np.linspace(20, 40, 25)
plt.hist(period_draws, period_edges, color="k", histtype="step")
plt.xlabel("period [days]")
plt.yticks([]);

In [None]:
# Posterior summary restricted to the GP hyperparameters (skipping "pred").
pm.summary(
    trace,
    varnames=["mix", "logperiod", "logQ0", "logdeltaQ", "logamp", "logs2"],
)

In [None]:
# Light curve with the posterior-median GP prediction at each cadence.
# The trailing ';' suppresses the Line2D repr that was otherwise displayed.
plt.plot(x, y, ".k")
plt.plot(x, np.median(trace["pred"], axis=0))
plt.xlabel("time [days]")
plt.ylabel("relative flux [ppt]");

In [None]:
# Rotation kernel evaluated at the MAP hyperparameters as a function of the
# time lag from the first cadence.
with model:
    lags = x - x[0]
    K = kernel.value(lags)
    plt.plot(lags, eval_in_model(K, map_soln))
plt.xlabel("lag [days]")
plt.ylabel("kernel value");