### 1D Hallway dataset
Optimizes the hyperparameters of a PD+SE GP, then constructs three GPs from the optimized hyperparameters:
1. A pure PD GP
2. A pure SE GP
3. A PD+SE GP

In [None]:
import jax 
import jax.numpy as jnp
import scipy.io as sio
import gpjax as gpx
import optax as ox
jax.config.update("jax_enable_x64", True)
%matplotlib widget
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
%load_ext autoreload
%autoreload 2
sns.set()

### Load data and construct dataset

In [None]:
data = sio.loadmat("data/hallway.mat")
X = data['x']
# Target: the Euclidean norm of data['y'] taken over axis 0, reshaped into a
# column vector. (data['y_norm'] was previously loaded here too but was always
# overwritten by this line, so it is no longer read.)
y = jnp.linalg.norm(data['y'], axis=0)[:, None]
# Centre the targets; ymu is added back wherever predictions are plotted/saved.
ymu = y.mean()
D = gpx.Dataset(X, y - ymu)

### Build the PD+SE GP for MLL optimization

In [None]:
import kernels
import objectives
import gps

# Periodic component: Q+1 fixed frequencies k/12 (k = 0..Q), frozen during
# optimisation.
Q = 20
mu = jnp.arange(Q + 1) / 12
pd_kernel = kernels.PeriodicPD(q=len(mu), d=1, mu=mu).replace_trainable(mu=False)

# SE component built on 200 Laplace basis functions centred on the inputs.
# L is 0.6 x the input range (presumably the half-width of the basis domain —
# confirm in kernels.LaplaceBF).
boundary = jnp.stack([D.X.min(axis=0), D.X.max(axis=0)], axis=1)
L = 0.6 * jnp.diff(boundary, axis=1).squeeze()
center = D.X.mean(axis=0)
basis = kernels.LaplaceBF(num_bfs=[200], center=center, L=L)
kernel = kernels.SumKernel(
    kernels=[pd_kernel, kernels.RBF(lengthscale=.4, variance=1., basis=basis)],
)

mean_function = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1.))
posterior = prior * likelihood

# Basis-function GP, updated with the full dataset D before optimisation.
gp = gps.BFGP(likelihood=likelihood, prior=prior)
pgp = gp.update_with_batch(D)

### Optimize hyperparameters

In [None]:
# Optimise the PD+SE basis-function GP: 150 Adam steps (learning rate 0.1)
# on the negative BF marginal log-likelihood.
optimizer = ox.adam(learning_rate=0.1)
key = jax.random.PRNGKey(13)
gp_full, history = gpx.fit(
    model=pgp,
    train_data=D,
    objective=objectives.BF_MLL(negative=True, compute_bf=False),
    optim=optimizer,
    key=key,
    num_iters=150,
)

In [None]:
# Training curve: objective value recorded at each step of the preceding fit.
plt.figure(figsize=(8, 3)).gca().plot(history)
plt.show()

### Predictions of the three GPs (PD, SE and PD+SE)
Constructs the three GPs from the optimized hyperparameters, then predicts over the entire dataset while conditioning only on the first 300 observations.

In [None]:
# Condition every model on the first 300 observations only.
D1 = gpx.Dataset(D.X[:300], D.y[:300])

# Rebuild BFGPs from the trained hyperparameters: PD kernel alone, SE kernel
# alone, and the full PD+SE sum, all sharing the trained likelihood.
pd_post_gp, se_post_gp, full_post_gp = (
    gps.BFGP(likelihood=gp_full.likelihood, prior=prior.replace(kernel=k))
    for k in (
        gp_full.prior.kernel.kernels[0],
        gp_full.prior.kernel.kernels[1],
        gp_full.prior.kernel,
    )
)

# Predict on every input location of the dataset.
full_yhat, pd_yhat, se_yhat = (
    model(D.X, extra_data=D1) for model in (full_post_gp, pd_post_gp, se_post_gp)
)

In [None]:
def pl_conf(post, Z, ax=None, beta=3, *, offset=None, **kwargs):
    """Plot a predictive mean as a dashed line with a ±beta·std shaded band.

    Args:
        post: Predictive distribution exposing ``mean()`` and ``covariance()``.
        Z: 1-D x-locations matching the entries of ``post.mean()``.
        ax: Matplotlib axes to draw on; defaults to the current axes.
        beta: Half-width of the band, in standard deviations.
        offset: Constant added to the mean before plotting. Defaults to the
            notebook-global ``ymu`` so mean-centred predictions are shown on
            the original data scale.
        **kwargs: Forwarded to ``ax.plot`` for the mean line (e.g. color, label).

    Returns:
        The Line2D of the mean curve.
    """
    if offset is None:
        offset = ymu
    if ax is None:
        ax = plt.gca()
    mean = post.mean() + offset
    # Pointwise std from the marginal variances; no observation-noise term is
    # added here (noise would have to be added to the variance, not the std).
    std = jnp.sqrt(post.covariance().diagonal())
    line = ax.plot(Z, mean, '--', **kwargs)[0]
    # Shade the band in the same colour as the mean line.
    ax.fill_between(Z, mean - beta * std, mean + beta * std, alpha=.2, color=line.get_color())
    return line

# Compare the three models: all observations (black), the 300 conditioning
# points (blue), and each model's predictive mean with a ±3σ band.
plt.close("all")
plt.figure()
plt.plot(D.X[:, 0], D.y + ymu, 'k')
plt.plot(D1.X[:, 0], D1.y + ymu, color='tab:blue')
pl_conf(pd_yhat, D.X[:, 0], color='tab:orange', label='Periodic')
pl_conf(se_yhat, D.X[:, 0], color='tab:purple', label='SE')
pl_conf(full_yhat, D.X[:, 0], color='tab:green', label='Full')
plt.legend()
plt.show()

### Save results to CSV for the TikZ plots

In [None]:
import pandas as pd

# Export the data and each model's predictive mean/std (means shifted back by
# ymu to the original scale) for the TikZ plots. The path is handed straight
# to to_csv so pandas manages the file: a text-mode handle opened without
# newline='' writes \r\r\n line endings on Windows.
pd.DataFrame(dict(
    x=D.X.squeeze(),
    y=D.y.squeeze() + ymu,
    pdmu=pd_yhat.mean() + ymu,
    pdstd=jnp.sqrt(pd_yhat.covariance().diagonal()),
    semu=se_yhat.mean() + ymu,
    sestd=jnp.sqrt(se_yhat.covariance().diagonal()),
    summu=full_yhat.mean() + ymu,
    sumstd=jnp.sqrt(full_yhat.covariance().diagonal()),
)).to_csv("hallway_results.csv", index=False)

In [None]:
import kernels

# PD-only model for the two-stage fit: a Fourier-type periodic kernel with
# fundamental frequency 1/12 (q presumably counts harmonics — confirm in
# kernels.FourierPD).
Q = 20
# Frequencies k/12, k = 1..Q. Only needed if switching back to
# kernels.PeriodicPD(q=Q, mu=mu); the PeriodicPD kernel previously built here
# was immediately overwritten and never used.
mu = jnp.arange(1, Q + 1) / 12
pd_kernel = kernels.FourierPD(q=Q, fundamental_frequency=1/12)
kernel = pd_kernel
mean_function = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=1.)
posterior = prior * likelihood

In [None]:
# Stage 1: fit the PD-only posterior on the negative conjugate MLL
# (100 Adam steps, learning rate 0.1).
optimizer = ox.adam(learning_rate=0.1)
key = jax.random.PRNGKey(13)

pd_full, history = gpx.fit(
    model=posterior,
    train_data=D,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    optim=optimizer,
    key=key,
    num_iters=100,
)

In [None]:
# Training curve: objective value recorded at each step of the preceding fit.
plt.figure(figsize=(8, 3)).gca().plot(history)
plt.show()

##### Re-optimize with an SE kernel added to the trained PD kernel

In [None]:
# Keep the trained PD kernel, add a default-initialised SE (RBF) kernel, and
# reuse the likelihood trained in the PD-only fit.
kernel = pd_full.prior.kernel
prior = gpx.gps.Prior(
    kernel=gpx.kernels.SumKernel(kernels=[kernel, gpx.kernels.RBF()]),
    mean_function=mean_function,
)
posterior = prior * pd_full.likelihood

In [None]:
# Stage 2: refit the PD+SE sum on the negative conjugate MLL, reusing the
# optimizer and PRNG key defined for the PD-only fit.
sum_full, history = gpx.fit(
    model=posterior,
    train_data=D,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    optim=optimizer,
    key=key,
    num_iters=100,
)

In [None]:
# Training curve: objective value recorded at each step of the preceding fit.
plt.figure(figsize=(8, 3)).gca().plot(history)
plt.show()

In [None]:
# SE-only component of the two-stage model. Use the likelihood trained in
# sum_full: the module-level `likelihood` still holds the untrained
# obs_stddev=1 initial value, which would give the SE predictions the wrong
# noise level.
prior = gpx.gps.Prior(mean_function=mean_function, kernel=sum_full.prior.kernel.kernels[1])
se_full = prior * sum_full.likelihood

In [None]:
# Predict on a 500-point grid spanning the inputs, conditioning each model only
# on the first 300 observations.
D1 = gpx.Dataset(D.X[:300], D.y[:300])
xmin, xmax = X.min(), X.max()
Z = jnp.linspace(xmin, xmax, 500)
pd_yhat, sum_yhat, se_yhat = (model(Z, D1) for model in (pd_full, sum_full, se_full))

In [None]:
plt.close("all")
plt.figure()
# D.y / D1.y are mean-centred (y - ymu) while pl_conf shifts predictions by
# +ymu, so shift the observations too to put everything on one scale.
plt.plot(D.X, D.y + ymu, 'k')
plt.plot(D1.X, D1.y + ymu, color='tab:blue')
pl_conf(pd_yhat, Z, color='tab:orange', label='PD')
pl_conf(sum_yhat, Z, color='tab:green', label='Sum')
# NOTE(review): mll_yhat is never computed in this notebook; plot it only when
# it exists (e.g. from another session) instead of failing with NameError.
if "mll_yhat" in globals():
    pl_conf(mll_yhat, Z, color='tab:red', label='MLL')
else:
    print("mll_yhat is not defined; skipping the 'MLL' curve")
pl_conf(se_yhat, Z, color='tab:purple', label='SE')
plt.legend()
plt.show()

In [None]:
# NOTE(review): `ob` and `fit` were used here without ever being imported in
# this notebook. `objectives` is the project-local module already used above
# for BF_MLL; `fit` is assumed to be a project-local module whose `fit`
# function takes the same arguments as gpx.fit — confirm both.
import objectives as ob
import fit

# Refit `posterior` (as last assigned above) with the ConditionalLPD objective,
# using a CompactSizePartitioner of size 300 (presumably contiguous 300-point
# chunks — confirm in objectives).
obj = ob.ConditionalLPD(negative=True, partitioner=ob.CompactSizePartitioner(size=300))
key = jax.random.PRNGKey(13)
optimizer = ox.adam(learning_rate=1e-1)

crlpd_full, history = fit.fit(
    model=posterior,
    objective=obj,
    train_data=D,
    optim=optimizer,
    num_iters=200,
    key=key,
)

In [None]:
# Training curve: objective value recorded at each step of the preceding fit.
plt.figure(figsize=(8, 3)).gca().plot(history)
plt.show()

In [None]:
# Predictive distribution of the ConditionalLPD-trained model on the grid Z,
# conditioned on the first 300 observations (D1).
crlpd_yhat = crlpd_full(Z, D1)

In [None]:
plt.close("all")
plt.figure()
# D.y / D1.y are mean-centred (y - ymu) while pl_conf shifts predictions by
# +ymu, so shift the observations too to put everything on one scale.
plt.plot(D.X, D.y + ymu, 'k.')
plt.plot(D1.X, D1.y + ymu, marker='.', color='tab:blue')
# NOTE(review): mll_yhat and rlpd_yhat are never computed in this notebook;
# plot them only when they exist (e.g. from another session) instead of
# failing with NameError.
if "mll_yhat" in globals():
    pl_conf(mll_yhat, Z, color='tab:green', label='MLL')
else:
    print("mll_yhat is not defined; skipping the 'MLL' curve")
if "rlpd_yhat" in globals():
    pl_conf(rlpd_yhat, Z, color='tab:purple', label='RLPD')
else:
    print("rlpd_yhat is not defined; skipping the 'RLPD' curve")
pl_conf(crlpd_yhat, Z, color='tab:orange', label='CRLPD')
plt.legend()
plt.show()

In [None]:
import pandas as pd
models = dict(twostage=sum_yhat, fixed=lpd_yhat, random=rlpd_yhat, compact=crlpd_yhat, se=se_yhat
results =