Based on the Bambi 2D HSGP example notebook: https://bambinos.github.io/bambi/notebooks/hsgp_2d.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

In [None]:
# Regular 12 x 12 grid covering [0, 10] x [0, 10]; identical axes in x and y.
x1 = np.linspace(0, 10, 12)
x2 = np.linspace(0, 10, 12)
xx, yy = np.meshgrid(x1, x2)

# One row per grid point (row-major over the meshgrid):
# column 0 holds the x coordinate, column 1 the y coordinate.
X = np.stack([xx.ravel(), yy.ravel()], axis=1)
X.shape

In [None]:
# Simulate one realisation of a zero-mean Gaussian process on the grid `X`,
# using a 2D squared-exponential (ExpQuad) kernel with lengthscale `ell`
# scaled by 1.2.
ell = 2
cov = 1.2 * pm.gp.cov.ExpQuad(2, ls=ell)
K = cov(X).eval()  # evaluate the symbolic covariance to a concrete array
mu = np.zeros(len(X))
print(mu.shape, K.shape)

# Seeded generator so the simulated surface is reproducible.
rng = np.random.default_rng(1234)
f = rng.multivariate_normal(mean=mu, cov=K)

# Show the draw: each grid point is a large square marker coloured by f.
fig, ax = plt.subplots()
ax.scatter(xx, yy, c=f, s=900, marker="s");

In [None]:
# Long-format table of the simulated data: one row per grid point, with the
# two input coordinates and the sampled GP value as the outcome.
# `ravel` lines the grids up in the same row-major order used to build `X`,
# so each row's (x, y) matches the point where `f` was drawn.
data = pd.DataFrame(
    {
        "x": xx.ravel(),
        "y": yy.ravel(),
        "outcome": np.ravel(f),
    }
)

In [None]:
# The HSGP term exactly as written in the formula. Priors and aliases are
# keyed by this string, so it is defined once and reused below.
hsgp_term = "hsgp(x, y, c=1.5, m=10)"

# Priors for the parameters of the HSGP term itself
# (presumably `sigma` is the GP amplitude and `ell` its lengthscale).
prior_hsgp = {
    "sigma": bmb.Prior("Exponential", lam=3),
    "ell": bmb.Prior("InverseGamma", mu=2, sigma=0.2),
}
# Top-level "sigma" is presumably the prior on the response noise scale.
priors = {hsgp_term: prior_hsgp, "sigma": bmb.Prior("HalfNormal", sigma=2)}

# `0 +` removes the intercept, leaving the HSGP term as the only predictor.
model = bmb.Model(f"outcome ~ 0 + {hsgp_term}", data, priors=priors)
model.set_alias({hsgp_term: "hsgp"})
model

In [None]:
# Build the model explicitly, then display its graph representation.
# NOTE(review): graph rendering presumably needs graphviz installed — confirm.
model.build()
model.graph()

In [None]:
# `inference_method="numpyro_nuts"` raised NotImplementedError with the
# installed Bambi version, so fit with Bambi's default sampler instead
# (leaving `inference_method` unset works across Bambi versions).
# The default sampler takes `chains`, not the JAX-style `num_chains`.
# Extra keyword arguments such as `target_accept` are passed through to the
# sampler, per the Bambi `Model.fit` documentation.
idata = model.fit(target_accept=0.9, chains=4)

# Total number of divergent transitions over all chains and draws;
# a non-zero count suggests the posterior was hard to sample.
print(idata.sample_stats.diverging.sum().item())

With `inference_method="numpyro_nuts"`, the fit cell above raised `NotImplementedError: 'numpyro_nuts' method has not been implemented`. Because the remaining cells of the original notebook appear to use the same inference method, the notebook stops here.