In [None]:
try:
    import tinygp
except ImportError:
    !pip install -q tinygp
    
try:
    import numpyro
except ImportError:
    !pip uninstall -y jax jaxlib
    !pip install -q numpyro jax jaxlib
    
try:
    import arviz
except ImportError:
    !pip install arviz

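# Alternative likelihoods

This example fits a Gaussian process to Poisson-distributed count data. The GP models the latent log-rate, and NumPyro's NUTS sampler infers it jointly with the GP hyperparameters.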

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Simulate Poisson counts whose log-rate is a smooth function of x.
# The NumPy generator is called `rng` rather than `random` so that it does
# not collide with `jax.random`, which is imported under the name `random`
# in the modelling cell.
rng = np.random.default_rng(203618)
x = np.linspace(-3, 3, 20)
true_log_rate = 2 * np.cos(2 * x)
y = rng.poisson(np.exp(true_log_rate))

fig, ax = plt.subplots()
ax.plot(x, y, ".k", label="data")
ax.plot(x, np.exp(true_log_rate), "C1", label="true rate")
ax.legend(loc=2)
ax.set_xlabel("x")
_ = ax.set_ylabel("counts")

In [None]:
from jax.config import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from tinygp import kernels, GaussianProcess


def model(x, y=None):
    """NumPyro model: a GP prior on the log-rate with a Poisson likelihood.

    Args:
        x: Input coordinates at which the latent GP is evaluated.
        y: Observed counts at ``x``, used as the ``obs`` of the Poisson site.
            Defaults to ``None``, in which case ``obs`` is left unobserved.
    """
    # Hyperpriors. The sample sites are kept in a fixed order; reordering them
    # would change which PRNG key each site receives.
    mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(3.0))
    rho = numpyro.sample("rho", dist.HalfNormal(10.0))

    # Matern-3/2 kernel scaled by sigma**2, with rho passed as the kernel's
    # scale argument. The small diagonal term (1e-5) is presumably there for
    # numerical stability of the covariance factorization.
    variance = sigma ** 2
    gp = GaussianProcess(variance * kernels.Matern32(rho), x, diag=1e-5, mean=mean)

    # Latent log-rate drawn from the GP's multivariate normal.
    latent_prior = dist.MultivariateNormal(loc=gp.loc, scale_tril=gp.scale_tril)
    log_rate = numpyro.sample("log_rate", latent_prior)

    # Poisson observation model on the exponentiated log-rate.
    rate = jnp.exp(log_rate)
    numpyro.sample("obs", dist.Poisson(rate), obs=y)


# NUTS with a raised target acceptance probability, so that step-size
# adaptation favours smaller steps.
nuts_kernel = NUTS(model, target_accept_prob=0.9)
mcmc_options = {
    "num_warmup": 1000,
    "num_samples": 1000,
    "num_chains": 2,
    "progress_bar": False,
}
mcmc = MCMC(nuts_kernel, **mcmc_options)
# `random` is `jax.random` here (imported at the top of this cell).
rng_key = random.PRNGKey(55873)

In [None]:
%%time
mcmc.run(rng_key, x, y=y)
samples = mcmc.get_samples()
_ = samples["log_rate"].block_until_ready()

In [None]:
import arviz as az

# Convert the NumPyro run to ArviZ InferenceData and summarize only the
# scalar hyperparameters; the per-point latent `log_rate` is plotted below.
data = az.from_numpyro(mcmc)
hyperparameter_names = [
    name for name in data.posterior.data_vars if name != "log_rate"
]
az.summary(data, var_names=hyperparameter_names)

In [None]:
# Posterior quantiles of the latent log-rate at each x. exp() is monotonic,
# so exponentiating log-rate quantiles gives the corresponding rate quantiles.
q = np.percentile(samples["log_rate"], [5, 25, 50, 75, 95], axis=0)
rate_q = np.exp(q)

fig, ax = plt.subplots()
ax.plot(x, y, ".k", label="data")
ax.plot(x, np.exp(true_log_rate), color="C1", label="true rate")
ax.plot(x, rate_q[2], color="C0", label="inferred rate")
# Nested shaded bands: 5-95% first (outer), then 25-75% (inner).
for lower, upper in ((0, 4), (1, 3)):
    ax.fill_between(x, rate_q[lower], rate_q[upper], alpha=0.3, lw=0, color="C0")
ax.legend(loc=2)
ax.set_xlabel("x")
_ = ax.set_ylabel("counts")