In [None]:
# Workaround cell. Original author's note: this plot / style.use / close
# sequence is required to reset the rcParams "for some reason".
# NOTE(review): presumably the notebook backend touches rcParams the first time
# pyplot draws something, so a throwaway figure is created before the style is
# applied — confirm before reordering or removing any of these lines.
import matplotlib.pyplot as plt

plt.plot()  # creates a throwaway (empty) figure
plt.style.use(["default", "./araa-gps.mplstyle"])  # defaults first, then the local style sheet
plt.close()  # closes the throwaway figure again

In [None]:
import warnings
from functools import partial

import corner
import emcee
import jax
import jax.numpy as jnp
import jaxopt
import matplotlib.pyplot as plt
import numpy as np
from jax.flatten_util import ravel_pytree
from statsmodels.datasets import co2
from tinygp import GaussianProcess, kernels

from paths import data, figures

In [None]:
# Global notebook configuration.

# Silence every FutureWarning so they do not clutter the cell outputs.
warnings.simplefilter("ignore", FutureWarning)

# Turn on JAX's 64-bit mode (it is off by default).
jax.config.update("jax_enable_x64", True)

# Default `figsize` tuple passed to `plt.figure` by the plotting cells.
figsize = (4, 3)

In [None]:
# Sketch of the hyperprior figure: a top-hat curve that is 1 between
# theta = 0 and theta = 1 and 0 everywhere else.
theta_nodes = [-0.5, 0, 0, 1, 1, 1.5]
density_nodes = [0, 0, 1, 1, 0, 0]

fig = plt.figure(figsize=figsize)
ax = fig.gca()
ax.plot(theta_nodes, density_nodes, "k")
ax.set_xlim(-0.1, 1.1)
ax.set_xlabel(r"$\theta$")
ax.set_ylabel(r"$p(\theta)$")
fig.savefig(figures / "workflow-hyperprior.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Load the CO2 dataset that ships with statsmodels.
# NOTE(review): this rebinds `data`, which is also imported from `paths` at the
# top of the notebook. The name is kept unchanged here so that the notebook's
# global state stays the same; consider renaming one of the two.
data = co2.load_pandas().data

# Convert the datetime index to fractional years: Julian date 2451545.0 is
# mapped to 2000.0 and one year is taken to be 365.25 days.
t = 2000 + (np.array(data.index.to_julian_date()) - 2451545.0) / 365.25
y = np.array(data.co2)

# Keep only finite measurements taken before 1996, then thin the series by
# keeping every 4th point.
m = np.isfinite(t) & np.isfinite(y) & (t < 1996)
t, y = t[m][::4], y[m][::4]

# Prediction grid: 2000 points from the last retained observation up to 2025.
# `t.max()` is the vectorised NumPy reduction (the builtin `max` would iterate
# over the array element by element) and matches the `plt.xlim` call below.
x = np.linspace(t.max(), 2025, 2000)

fig = plt.figure(figsize=figsize)
plt.plot(t, y, ".k")
plt.xlim(t.min(), t.max())
plt.xlabel("year")
plt.ylabel("CO$_2$ in ppm")

fig.savefig(figures / "workflow-data.pdf", dpi=300, bbox_inches="tight")

In [None]:
def build_gp(theta, X):
    """Build the `GaussianProcess` for parameters `theta` evaluated at `X`.

    `theta` is a dict with the keys "mean", "log_diag", "log_amps",
    "log_scales", "log_period", "log_gamma" and "log_alpha". Every "log_*"
    entry is passed through `exp` before use, so the value handed to the
    kernel (or used as `diag`) is positive. "log_amps" and "log_scales" are
    indexed at positions 0 through 3.

    The kernel is the sum of four amplitude-scaled terms: an `ExpSquared`,
    an `ExpSquared` times `ExpSineSquared` product, a `RationalQuadratic`
    and a second `ExpSquared`.
    """
    # Use `jnp` rather than `np` for the transforms, as in the rest of the
    # JAX-facing code.
    amp = jnp.exp(theta["log_amps"])
    length = jnp.exp(theta["log_scales"])
    period = jnp.exp(theta["log_period"])
    gamma = jnp.exp(theta["log_gamma"])
    alpha = jnp.exp(theta["log_alpha"])
    diag = jnp.exp(theta["log_diag"])

    # Kernels are composed by multiplying and adding `Kernel` objects.
    sq_exp_first = amp[0] * kernels.ExpSquared(length[0])
    quasi_periodic = (
        amp[1]
        * kernels.ExpSquared(length[1])
        * kernels.ExpSineSquared(scale=period, gamma=gamma)
    )
    rational_quad = amp[2] * kernels.RationalQuadratic(alpha=alpha, scale=length[2])
    sq_exp_second = amp[3] * kernels.ExpSquared(length[3])

    total_kernel = sq_exp_first + quasi_periodic + rational_quad + sq_exp_second
    return GaussianProcess(total_kernel, X, diag=diag, mean=theta["mean"])


def neg_log_likelihood(theta, X, y):
    """Return minus the log probability of `y` under `build_gp(theta, X)`."""
    return -build_gp(theta, X).log_probability(y)


# Starting point for the optimisation. Every entry except "mean" is stored in
# log space and exponentiated inside `build_gp`.
theta_init = {
    "mean": np.float64(340.0),
    "log_diag": np.log(0.19),
    "log_amps": np.log([66.0, 2.4, 0.66, 0.18]),
    "log_scales": np.log([67.0, 90.0, 0.78, 1.6]),
    "log_period": np.float64(0.0),
    "log_gamma": np.log(4.3),
    "log_alpha": np.log(1.2),
}

# `jax` can be used to differentiate functions, and also note that we're calling
# `jax.jit` for the best performance.
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))

# `obj` returns the (value, gradient) pair, so evaluate it once and reuse both
# pieces instead of running the full computation a second time for the print.
nll_init, grad_init = obj(theta_init, t, y)
print(f"Initial negative log likelihood: {nll_init}")
print(
    f"Gradient of the negative log likelihood, wrt the parameters:\n{grad_init}"
)

# Minimise the negative log likelihood starting from `theta_init`.
solver = jaxopt.ScipyMinimize(fun=neg_log_likelihood)
soln = solver.run(theta_init, X=t, y=y)
print(f"Final negative log likelihood: {soln.state.fun_val}")

In [None]:
# Evaluate the optimised kernel between each lag `tau` in [0, 20] and a single
# point at 0, giving k as a function of the lag.
kernel = build_gp(soln.params, t).kernel
tau = np.linspace(0, 20, 1000)
k = kernel(tau, np.zeros(1))[:, 0]

fig = plt.figure(figsize=figsize)
ax = fig.gca()
ax.plot(tau, k, "k")
ax.set_xlim(tau.min(), tau.max())
ax.set_ylabel(r"$k(\tau)$")
ax.set_xlabel(r"$\tau$ [year]")

fig.savefig(figures / "workflow-kernel.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Square figure (side = the smaller entry of `figsize`) showing the kernel
# evaluated between all pairs of observation times, with the ticks removed.
side = min(figsize)
fig = plt.figure(figsize=(side, side))
ax = fig.gca()
ax.imshow(kernel(t, t), origin="upper", cmap="Greys")
ax.set_xticks([])
ax.set_yticks([])
fig.savefig(figures / "workflow-covariance.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Draw 5 samples from the GP built (not conditioned on `y`) on a dense grid
# covering 1960-1975, using the optimised parameters.
t_ = np.linspace(1960, 1975, 1000)
gp = build_gp(soln.params, t_)
samples = gp.sample(jax.random.PRNGKey(1), (5,))

# Per-draw shift: after subtracting it, draw i has its median at 300 + 5 * i,
# which stacks the curves vertically instead of overlapping them.
shift = np.median(samples, axis=-1) - 300 - 5 * np.arange(len(samples))
stacked = (samples - shift[:, None]).T

fig = plt.figure(figsize=figsize)
ax = fig.gca()
ax.plot(t_, stacked, lw=1)
ax.set_xlim(t_.min(), t_.max())
ax.set_xlabel("year")
ax.set_ylabel("CO$_2$ in ppm")
fig.savefig(figures / "workflow-prior-samples.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Condition the optimised GP on the observations and evaluate it on the
# prediction grid `x`.
gp = build_gp(soln.params, t)
cond_gp = gp.condition(y, x).gp
mu, var = cond_gp.loc, cond_gp.variance

# Half-width of the shaded band: two standard deviations.
two_sigma = 2 * np.sqrt(var)

fig = plt.figure(figsize=figsize)
ax = fig.gca()
ax.plot(t, y, ".k", ms=3)
ax.fill_between(
    x,
    mu + two_sigma,
    mu - two_sigma,
    color="C0",
    alpha=0.5,
    edgecolor="none",
)
ax.plot(x, mu, color="C0", lw=1)

ax.set_xlim(t.min(), 2025)
ax.set_xlabel("year")
ax.set_ylabel("CO$_2$ in ppm")
fig.savefig(figures / "workflow-pred.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Seeded NumPy generator, used for the walker initialisation below and for
# picking chain entries in the posterior plot.
np_random = np.random.default_rng(0)

# Flatten the optimised parameter dict into a 1-D vector; `unravel_fn` maps a
# flat vector back to the dict structure.
flat_params, unravel_fn = ravel_pytree(soln.params)


@jax.jit
def log_prob(flat):
    """Log probability handed to emcee for a flat parameter vector.

    This is the negated `neg_log_likelihood` of the data `(t, y)`; no
    additional prior term is added here.
    """
    theta = unravel_fn(flat)
    return -neg_log_likelihood(theta, t, y)


ndim = flat_params.shape[0]
nwalkers = 36

# Start every walker in a tight Gaussian ball (standard deviation 1e-4 in each
# dimension) centred on the optimised parameters.
init_params = np_random.normal(
    loc=flat_params,
    scale=np.full_like(flat_params, 1e-4),
    size=(nwalkers, ndim),
)
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)

In [None]:
# Burn-in: advance the walkers 200 steps away from the tight initial ball and
# keep only their final positions.
print("Running burn-in")
p0, _, _ = sampler.run_mcmc(init_params, 200)

# Discard the burn-in steps stored in the sampler. Without this reset the
# second `run_mcmc` call appends to the same chain, so the burn-in samples
# would end up in every downstream use of the chain (posterior draws and the
# corner plot).
sampler.reset()

print("Running production chain")
sampler.run_mcmc(p0, 200);

In [None]:
@jax.jit
def sample(key, flat):
    """Draw one sample on the grid `x` from the GP conditioned on `y`.

    `flat` is a flat parameter vector (as stored in the emcee chain); it is
    unravelled into the parameter dict, used to build the GP on `t`, and the
    conditioned GP is sampled with the JAX PRNG `key`.
    """
    return build_gp(unravel_fn(flat), t).condition(y, x).gp.sample(key)


fig = plt.figure(figsize=figsize)

# Fetch the chain once through the supported accessor instead of reading the
# deprecated `sampler.chain` property three times per iteration.
# `get_chain()` is laid out as (step, walker, parameter), i.e. the first two
# axes are swapped relative to `sampler.chain`.
chain = sampler.get_chain()
n_steps, n_walkers = chain.shape[:2]

for i in range(50):
    # Choose a random walker and step (drawn in this order so the random
    # stream, and therefore the selected samples, stay the same).
    w = np_random.integers(n_walkers)
    n = np_random.integers(n_steps)
    y_ = sample(jax.random.PRNGKey(i), chain[n, w])

    # Plot a single sample.
    plt.plot(x, y_, "C0", alpha=0.1)

plt.plot(t, y, ".k", ms=3)
plt.xlim(t.min(), 2025)
plt.xlabel("year")
plt.ylabel("CO$_2$ in ppm")
fig.savefig(figures / "workflow-posterior.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Corner plot of three columns of the flattened chain.
# NOTE(review): the labels assume that columns 0, 5 and -1 of the flat vector
# are "log_alpha", "log_diag" and "mean" respectively, i.e. that
# `ravel_pytree` orders the dict entries by sorted key — confirm.
corner_columns = [0, 5, -1]
corner_labels = [r"$\log \alpha$", r"$\log \sigma$", r"$\mu$"]

flat_chain = sampler.get_chain(flat=True)
fig = corner.corner(
    flat_chain[:, corner_columns],
    plot_datapoints=False,
    smooth=1,
    smooth1d=1,
    labels=corner_labels,
)
fig.savefig(figures / "workflow-corner.pdf", dpi=300, bbox_inches="tight")