### Visionen dataset
Data from an unstructured environment, where the PD component is not expected to be beneficial.
Optimizes hyperparameters and produces plots for:
1. PD GP
2. SE GP
3. Sum GP

In [None]:
import jax 
import jax.numpy as jnp
import scipy.io as sio
import gpjax as gpx
import optax as ox
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

### Load and visualize the data

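Subsampled to approximately 900 training points (every 100th sample); the samples halfway between consecutive training points are held out as test data.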

In [None]:
data = sio.loadmat("data/visionen.mat")
inds = jnp.arange(0, data['y_mag_norm'].shape[-1], step=100)
test_inds = inds[:-1] + 50
X = jnp.array(data['p'].T[inds])/1000
X = X.at[:,-1].set(0)
XT = jnp.array(data['p'].T[test_inds])/1000
XT = XT.at[:,-1].set(0)
YT = jnp.array(data['y_mag_norm'].T[test_inds])
D = gpx.Dataset(X, jnp.array(data['y_mag_norm'].T[inds]))
%matplotlib widget
plt.close("all")
plt.figure()
plt.scatter(data['p'][0, inds]/1000, data['p'][1,inds]/1000, c=data['y_mag_norm'][:,inds].flatten())#, antialiased=True)
plt.colorbar()
plt.show()

## Optimize hyperparameters of the sum kernel

In [None]:
import kernels
# Sum GP: zero mean, kernel = PeriodicPD(q=3) + RBF, and a Gaussian likelihood
# whose observation-noise std is initialised to 1.
mean_function = gpx.mean_functions.Zero()
pd_kernel = kernels.PeriodicPD(q=3)
sum_kernel = gpx.kernels.SumKernel(kernels=[pd_kernel, gpx.kernels.RBF()])
likelihood = gpx.likelihoods.Gaussian(obs_stddev=1., num_datapoints=D.n)
prior = gpx.gps.Prior(kernel=sum_kernel, mean_function=mean_function)
posterior = prior * likelihood

In [None]:
# Adam (lr 0.1) for 150 steps on the negative conjugate marginal log-likelihood.
key = jax.random.PRNGKey(13)
optimizer = ox.adam(learning_rate=1e-1)
objective = gpx.objectives.ConjugateMLL(negative=True)
sum_full, history = gpx.fit(model=posterior, objective=objective, train_data=D,
                            optim=optimizer, num_iters=150, key=key)

In [None]:
# Optimisation trace of the sum GP: the loss history returned by gpx.fit.
sum_loss_fig, sum_loss_ax = plt.subplots(figsize=(8, 3))
sum_loss_ax.plot(history)
plt.show()

### SE kernel
Full GP with a standard squared-exponential (SE) kernel. Hyperparameters are optimized by minimizing the negative marginal log-likelihood (MLL).

In [None]:
# Baseline: a GP with a single RBF (squared-exponential) kernel, sharing the zero
# mean and Gaussian likelihood set up for the sum GP.
# NOTE(review): `se_full` is not used by the later prediction/NLPD cells, which
# use the SE component of the sum fit instead — confirm this is intended.
kernel = gpx.kernels.RBF()
prior = gpx.gps.Prior(kernel=kernel, mean_function=mean_function)
posterior = prior * likelihood
key = jax.random.PRNGKey(13)
optimizer = ox.adam(learning_rate=1e-1)

fit_settings = dict(
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
    optim=optimizer,
    num_iters=150,
    key=key,
)
se_full, history = gpx.fit(model=posterior, **fit_settings)

In [None]:
# Optimisation trace of the SE GP: the loss history returned by gpx.fit.
se_loss_fig, se_loss_ax = plt.subplots(figsize=(8, 3))
se_loss_ax.plot(history)
plt.show()

### Predictions on a fine test grid (note -- this may take a few minutes)

In [None]:
import numpy as np
# Regular ntest x ntest evaluation grid spanning the training inputs in (x1, x2);
# the third input coordinate is fixed to 0, matching how X and XT were built.
ntest = 200
x = jnp.linspace(D.X[:, 0].min(), D.X[:, 0].max(), ntest)
y = jnp.linspace(D.X[:, 1].min(), D.X[:, 1].max(), ntest)
grid_cols = [axis_vals.ravel() for axis_vals in jnp.meshgrid(x, y)]
grid_cols.append(jnp.zeros(ntest * ntest))
xtest = jnp.column_stack(grid_cols)

import gp_utils as gpu
from gpjax.distributions import GaussianDistribution
import cola
from tqdm.notebook import tqdm
def pred_on_grid(gp, batch_size=2000):
    """Predict with `gp` at every row of the global `xtest` grid.

    The grid is processed in chunks of `batch_size` rows, each conditioned on
    the global training data `D`. Only a `batch_size x batch_size` predictive
    covariance is formed per call, rather than one for the whole grid.

    Args:
        gp: fitted gpjax posterior, called as ``gp(test_inputs, train_data)``.
        batch_size: number of grid points per prediction call (default 2000).

    Returns:
        (mean, variance): the per-chunk predictive means and covariance
        diagonals, concatenated in grid order.

    Raises:
        ValueError: if `batch_size` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    means, variances = [], []
    for start in tqdm(range(0, xtest.shape[0], batch_size)):
        pred = gp(xtest[start:start + batch_size], D)
        means.append(pred.mean())
        # Only the marginal variances are kept; the chunk covariance is discarded.
        variances.append(pred.covariance().diagonal())
    return jnp.concatenate(means), jnp.concatenate(variances)

In [None]:
# NOTE: this rebinds `pd`, shadowing the pandas alias imported at the top.
def _component_posterior(index):
    """Copy of `sum_full` whose prior uses only the index-th summand kernel.

    The jointly optimised hyperparameters and fitted likelihood are kept.
    """
    component_kernel = sum_full.prior.kernel.kernels[index]
    return sum_full.replace(prior=sum_full.prior.replace(kernel=component_kernel))


se = _component_posterior(1)
pd = _component_posterior(0)
# Grid predictions (mean, variance), computed in the order PD, SE, sum.
(pd_mu, pd_var), (se_mu, se_var), (sum_mu, sum_var) = (
    pred_on_grid(model) for model in (pd, se, sum_full)
)

In [None]:
# Plot limits: the bounding box of the training inputs in (x1, x2), padded on
# both sides of each axis by `pad_frac` times that axis's own span.
pad_frac = 0.04
pmin = np.asarray(D.X.min(axis=0)[:2])
pmax = np.asarray(D.X.max(axis=0)[:2])
ext = pad_frac * (pmax - pmin)  # per-axis padding: [pad_x1, pad_x2]
pmin = pmin - ext
pmax = pmax + ext
# Rows of the stacked (2, 2) array are [min, max] for x1 and for x2.
xlim, ylim = np.stack([pmin, pmax], axis=1)

##### Save raw measurements

In [None]:
import matplotlib
import tikzplotlib
from utils import bitmappify
plt.close("all")
# Shared colour scale for the exported maps: the range of the training targets.
vmin, vmax = D.y.min(), D.y.max()
cmap = sns.color_palette("viridis", as_cmap=True)
with sns.axes_style("whitegrid"):
    # Inward ticks drawn on the bottom and left axes.
    matplotlib.rcParams.update({
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.bottom': True,
        'ytick.left': True,
    })
    fig = plt.figure(figsize=(4, 4))
    g = plt.scatter(D.X[:, 0], D.X[:, 1], c=D.y.flatten(), marker='s',
                    cmap=cmap, vmin=vmin, vmax=vmax)
    bitmappify(plt.gca(), dpi=300)
    plt.xlabel(r"$x_1~[m]$")
    plt.ylabel(r"$x_2~[m]$")
    plt.xlim(xlim)
    plt.ylim(ylim)
    raw_axis_params = [
        "scale only axis",
        "xlabel style={yshift=.25cm}",
        "ylabel style={yshift=-.25cm}",
        "title style={yshift=-.25cm}",
    ]
    tikzplotlib.save(
        "raw_measurements.tex",
        axis_width="3.5cm",
        extra_axis_parameters=raw_axis_params,
        override_externals=True,
        standalone=False,
    )

##### Save predictions

In [None]:
def save_pred(mu, var, filename=None, extra_axis_params=None, standalone=False):
    """Draw a GP prediction on the `xtest` grid and optionally export it as TikZ.

    The predictive mean is shown as an ntest x ntest image, coloured with the
    global `cmap` between the global `vmin`/`vmax`. Each pixel's opacity is
    ``1 - var / var.max()``, so high-variance regions fade out. Axis limits come
    from the global `xlim`/`ylim`.

    Args:
        mu: predictive mean, one value per row of `xtest`.
        var: predictive variance, same layout as `mu`.
        filename: output .tex path for `tikzplotlib.save`; nothing is written
            when None.
        extra_axis_params: extra pgfplots axis options appended to the defaults
            (None means no extra options).
        standalone: forwarded to `tikzplotlib.save`.
    """
    if extra_axis_params is None:
        extra_axis_params = []
    plt.close("all")
    # Opaque where the variance is zero, transparent at the largest variance.
    # Clipped to [0, 1] so tiny negative variances from round-off cannot
    # produce an out-of-range alpha.
    alpha = jnp.clip(1 - var / var.max(), 0.0, 1.0)
    with sns.axes_style("whitegrid"):
        matplotlib.rcParams['xtick.direction'] = 'in'
        matplotlib.rcParams['ytick.direction'] = 'in'
        matplotlib.rcParams['xtick.bottom'] = True
        matplotlib.rcParams['ytick.left'] = True
        plt.figure(figsize=(4, 4))
        plt.imshow(
            mu.reshape(ntest, ntest),
            alpha=alpha.reshape(ntest, ntest),
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            # xtest is built from meshgrid(x, y), so image rows run along x2;
            # 'lower' puts the smallest x2 at the bottom.
            origin='lower',
            extent=(xtest[:, 0].min(), xtest[:, 0].max(),
                    xtest[:, 1].min(), xtest[:, 1].max()),
            interpolation='bicubic',
        )
        plt.xlim(xlim)
        plt.ylim(ylim)
        bitmappify(plt.gca(), dpi=300)
        if filename is not None:
            extra_axis_parameters = [
                "scale only axis",
                "xlabel style={yshift=.25cm}",
                "ylabel style={yshift=-.25cm}",
                "title style={yshift=-.25cm}",
            ] + list(extra_axis_params)
            tikzplotlib.save(filename,
                             axis_width="4.0cm",
                             extra_axis_parameters=extra_axis_parameters,
                             override_externals=True,
                             standalone=standalone)

In [None]:
# PD panel: keeps the x2 label, hides the x tick labels.
pd_axis_params = ["ylabel={$x_2~[m]$}", "xticklabels={}"]
save_pred(pd_mu, pd_var, filename="pd_pred.tex",
          extra_axis_params=pd_axis_params, standalone=False)

In [None]:
# SE panel: hides both the x and y tick labels.
se_axis_params = ["yticklabels={}", "xticklabels={}"]
save_pred(se_mu, se_var, filename="se_pred.tex",
          extra_axis_params=se_axis_params, standalone=False)

In [None]:
# Sum panel: keeps the x1 label, hides the y tick labels.
sum_axis_params = ["xlabel={$x_1~[m]$}", "yticklabels={}"]
save_pred(sum_mu, sum_var, filename="sum_pred.tex",
          extra_axis_params=sum_axis_params, standalone=False)

In [None]:
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
def nlpd(gp, test_inputs=None, test_targets=None, train_data=None):
    """Mean negative log predictive density (NLPD) of `gp` on held-out data.

    Each held-out target is scored under an independent Gaussian. Its mean is
    the GP predictive mean, and its variance is the latent predictive variance
    plus the likelihood's noise variance ``obs_stddev ** 2``.

    Args:
        gp: fitted gpjax posterior, called as ``gp(test_inputs, train_data)``;
            must expose ``gp.likelihood.obs_stddev``.
        test_inputs: held-out inputs; defaults to the global ``XT``.
        test_targets: held-out targets, one entry/row per test input; defaults
            to the global ``YT``.
        train_data: dataset to condition on; defaults to the global ``D``.

    Returns:
        The mean of the per-point negative log densities.
    """
    from jax.scipy.stats import norm

    if test_inputs is None:
        test_inputs = XT
    if test_targets is None:
        test_targets = YT
    if train_data is None:
        train_data = D

    pred = gp(test_inputs, train_data)
    mean = pred.mean()
    # Noisy-observation variance = latent variance + observation-noise variance.
    var = pred.covariance().diagonal() + gp.likelihood.obs_stddev ** 2
    # The Normal's `scale` is a standard deviation, not a variance.
    std = jnp.sqrt(var)
    # vmap pairs the i-th target with the i-th mean/std. This avoids
    # broadcasting a column of targets against the predictions into a grid.
    log_dens = jax.vmap(norm.logpdf)(test_targets, mean, std)
    return -log_dens.mean()
# Held-out NLPD for the sum GP and for its SE-only and PD-only components.
tuple(nlpd(model) for model in (sum_full, se, pd))

In [None]:
# Re-plot the CSV point clouds (columns x1, x2, y, a) as PDFs, and save a
# standalone horizontal colour bar taken from the measurement scatter.
# `pd` is rebound to a GP posterior earlier in this notebook, so pandas is
# imported under its full name here.
import pandas

csv_cmap = sns.color_palette("viridis", as_cmap=True)
filenames = ['pd', 'se', 'sum', 'meas']
for name in filenames:
    fig, ax = plt.subplots(figsize=(6, 6))
    d = pandas.read_csv(name + '.csv')
    g = ax.scatter(d.x1, d.x2, c=d.y, vmin=0, vmax=2.5, cmap=csv_cmap, alpha=d.a)
    ax.set(xticklabels=[], yticklabels=[])
    fig.savefig(name + '.pdf', bbox_inches='tight')
    if name == 'meas':
        cbar_fig, cbar_host = plt.subplots()
        cbar = plt.colorbar(g, ax=cbar_host, orientation='horizontal')
        cbar.ax.xaxis.set_ticks_position('top')
        cbar_host.remove()
        cbar_fig.savefig('visionen_cbar.pdf', bbox_inches='tight')
        plt.close(cbar_fig)
    # Close each figure once saved so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# 2 x 2 overview: raw measurements, plus the PD, SE and PD + SE predictive means
# on the `xtest` grid, each point drawn with opacity 1 - var / max(var).
vmin, vmax = D.y.min(), D.y.max()
viridis = sns.color_palette("viridis", as_cmap=True)

plt.close("all")
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
g = ax[1, 0].scatter(D.X[:, 0], D.X[:, 1], c=D.y.flatten(),
                     vmin=vmin, vmax=vmax, cmap=viridis)
ax[1, 0].set_title("Raw measurements")

# (axes, mean, variance, title); the grid predictions come from pred_on_grid.
panels = [
    (ax[0, 0], pd_mu, pd_var, "PD kernel"),
    (ax[0, 1], se_mu, se_var, "SE kernel"),
    (ax[1, 1], sum_mu, sum_var, "PD + SE kernel"),
]
for panel_ax, mean, var, title in panels:
    # Clip so round-off negatives in the variance cannot give alpha outside [0, 1].
    alpha = np.clip(np.asarray(1 - var / var.max()), 0.0, 1.0)
    # `g` ends up as the last (PD + SE) scatter, which feeds the colour bar.
    g = panel_ax.scatter(xtest[:, 0], xtest[:, 1], c=mean, vmin=vmin, vmax=vmax,
                         cmap=viridis, alpha=alpha)
    panel_ax.set_title(title)

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(g, cax=cbar_ax)

for axi in ax.flatten():
    axi.set_box_aspect(1)
plt.show()

In [None]:
# KL divergences of the PD-only and the sum-kernel predictive distributions from
# the SE-only component's, evaluated jointly at the held-out test inputs XT.
# Both values are returned as one tuple so the notebook displays both.
# NOTE(review): assumes the SE reference is the SE component `se` of the sum fit
# (as in the NLPD cell), not the separately trained `se_full` — confirm.
se_pred_test = se(XT, D)
(pd(XT, D).kl_divergence(se_pred_test), sum_full(XT, D).kl_divergence(se_pred_test))