In [None]:
## Pyro GP tutorial used as starting point:
## https://pyro.ai/examples/gp.html

import matplotlib.pyplot as plt
import numpy as np
import torch
import pyro
import pyro.contrib.gp as gp
import arviz

# Partition observations.
# Synthetic 1-D regression data: a polynomial trend minus a fast sin^2 ripple,
# plus Gaussian noise. A fixed seed makes "Restart & Run All" reproduce the same
# dataset and the same train/test split.
DATA_SEED = 0   # change to draw a different dataset / split
N_OBS = 30      # total number of observations
N_TEST = 10     # held-out observations; the remaining N_OBS - N_TEST are used for fitting
NOISE_SD = 0.1  # standard deviation of the additive observation noise

rng = np.random.default_rng(DATA_SEED)

# NOTE(review): x / 29 for x in 1..30 spans [1/29, 30/29], i.e. slightly past 1.
# If the intended grid was [0, 1], use np.arange(N_OBS) / (N_OBS - 1) instead.
X = np.arange(1, N_OBS + 1) / (N_OBS - 1)
rng.shuffle(X)  # random order, so the head/tail slicing below is a random split
Y = (
    6 * np.square(X)
    - np.square(np.sin(6 * np.pi * X))
    - 5 * np.power(X, 4)
    + 3 / 2
    + rng.normal(0.0, NOISE_SD, X.shape)  # noise sized from X, not a repeated literal
)
Xtrain, Xtest, Ytrain, Ytest = X[N_TEST:], X[:N_TEST], Y[N_TEST:], Y[:N_TEST]

### Using NUTS

In [None]:
# Model is a GP regression model from pyro (following the Pyro GP tutorial):
# an RBF-kernel GPRegression whose kernel hyperparameters get LogNormal priors,
# so NUTS samples their posterior instead of optimizing point estimates.
W = 100  # Number of warmup steps
C = 1  # Number of chains
S = 500  # Number of samples used in prediction
MCMC_SEED = 0  # seeds Pyro's RNGs so the chain is reproducible

pyro.clear_param_store()  # drop parameters left over from earlier runs of this cell
pyro.set_rng_seed(MCMC_SEED)

# Fit on the training split only, so Xtest/Ytest stay genuinely held out.
# Convert the float64 numpy arrays to torch tensors in torch's default dtype so
# they match the dtype of the kernel parameters.
X_train_t = torch.as_tensor(Xtrain, dtype=torch.get_default_dtype())
Y_train_t = torch.as_tensor(Ytrain, dtype=torch.get_default_dtype())

kernel = gp.kernels.RBF(input_dim=1)
# Observation-noise variance is held fixed at 0.1**2, matching the noise std used
# when Y was generated; only the kernel hyperparameters below are sampled.
gpr = gp.models.GPRegression(X_train_t, Y_train_t, kernel, noise=torch.tensor(0.1 ** 2))
# LogNormal priors keep both hyperparameters on the positive reals.
gpr.kernel.variance = pyro.nn.PyroSample(pyro.distributions.LogNormal(0.0, 1.0))
gpr.kernel.lengthscale = pyro.nn.PyroSample(pyro.distributions.LogNormal(0.0, 1.0))

model = gpr.model  # GP generative model; the data live on `gpr`, so it takes no arguments
nuts_kernel = pyro.infer.NUTS(model, jit_compile=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=S, num_chains=C, warmup_steps=W)
mcmc.run()

#### Checking quality of samples using arviz

In [None]:
# Sample-quality diagnostics: ArviZ summary table, trace plots, and posterior plots.
posterior_samples = mcmc.get_samples()
data = arviz.from_pyro(mcmc)
summary = arviz.summary(data)
print(summary)

arviz.plot_trace(data)
plt.show()

# Plot every posterior variable instead of hard-coding site names, so this cell
# keeps working when the model's latent sites change. If the model has many
# sites and the figure is truncated, raise arviz.rcParams['plot.max_subplots'].
arviz.plot_posterior(data)
plt.show()