### Basement dataset

In [None]:
import jax 
import jax.numpy as jnp
import scipy.io as sio
import gpjax as gpx
import optax as ox
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
%matplotlib widget
import seaborn as sns
import pandas as pd
from tqdm.notebook import tqdm
%load_ext autoreload
%autoreload 2
sns.set()

### Load and visualize the data

In [None]:
# Read the recorded data set: field 'p' provides the inputs, field 'y' the targets.
data = sio.loadmat("data/basement.mat")
X, Y = data['p'], data['y']
Da = gpx.Dataset(X, Y)

# Domain bounds: +-50 and +-40, each widened outwards by the margin l2.
l2 = 4
west, east = -50 - l2, 50 + l2
south, north = -40 - l2, 40 + l2
# Half-widths of the widened domain; the third half-width is the margin itself.
Lx = (jnp.abs(east) + jnp.abs(west)).squeeze() / 2
Ly = (jnp.abs(south) + jnp.abs(north)).squeeze() / 2
Lz = l2

# Scatter the first two input columns, coloured by the row-wise norm of Y.
plt.close("all")
plt.figure()
plt.scatter(X[:, 0], X[:, 1], c=jnp.linalg.norm(Y, axis=1))
plt.show()

### Construct the curl-free sum GP and condition it on the dataset

In [None]:
import kernels
import gps
import objectives
# Frequency grid of the periodic component: Q + 1 multiples of a per-axis
# base frequency, combined into every 3-D combination ('ij' indexing).
Q = 4
fx, fy, fz = jnp.meshgrid(
    jnp.arange(0, Q + 1) * 1 / 6,
    jnp.arange(0, Q + 1) * 1 / 15,
    jnp.arange(0, Q + 1) * 1 / 2,
    indexing='ij',
)
# One row per frequency vector, columns ordered (fx, fy, fz).
mu = jnp.stack([fx.flatten(), fy.flatten(), fz.flatten()], axis=-1)
# The frequencies are passed in and then marked as non-trainable.
pd_kernel = kernels.PotentialPeriodicPD(
    q=mu.shape[0],
    d=3,
    mu=mu,
).replace_trainable(mu=False)

# Second summand: an RBF kernel on a Laplace basis with domain sizes (Lx, Ly, Lz).
bf = kernels.PotentialLaplaceBF(num_bfs=[40, 40, 3], L=[Lx, Ly, Lz])
kernel = kernels.SumKernel(
    kernels=[
        pd_kernel,
        kernels.RBF(basis=bf, lengthscale=2., variance=5.),
    ]
)
mean_function = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=Da.n, obs_stddev=5.)
posterior = prior * likelihood

# Wrap prior and likelihood in the project's GP class and feed it the data set.
gp = gps.PotentialBFGP(likelihood=likelihood, prior=prior)
pgp = gp.update_with_batch(Da)

#### Optimize hyperparameters (may take some time)

In [None]:
# PRNG key (fixed seed 13) handed to gpx.fit below.
key = jax.random.PRNGKey(13)
# Adam optimiser with learning rate 0.1.
optimizer = ox.adam(learning_rate=1e-1)

# Run 150 optimisation steps on the BF_MLL objective (constructed with
# negative=True) starting from the conditioned model `pgp`.
# `history` is plotted in the next cell (presumably the objective value per
# iteration -- confirm against gpx.fit).
gp_full, history = gpx.fit(
                            model=pgp,
                            objective=objectives.BF_MLL(negative=True, compute_bf=False),
                            train_data=Da,
                            optim=optimizer,
                            num_iters=150,
                            key=key,
)

In [None]:
# Plot the `history` returned by gpx.fit in a wide, short figure.
plt.figure(figsize=(8,3))
plt.plot(history)
plt.show()

##### Extract the SE component to construct an SE GP as well

In [None]:
import utils
# Offset at which the second kernel's parameters start: twice the `q` of the
# first (periodic) kernel.
# NOTE(review): presumably two basis functions per frequency -- confirm.
dQ = gp_full.prior.kernel.kernels[0].q * 2
# Keep only the trailing part of alpha and the trailing diagonal block of B.
dparams = gp_full.parameters.replace(alpha=gp_full.alpha[dQ:], B=gp_full.B[dQ:, dQ:])
# Same fitted model, but with the sliced parameters and with the prior kernel
# replaced by the second summand of the sum kernel (the RBF kernel).
se_full = gp_full.replace(parameters=dparams, prior=gp_full.prior.replace(kernel=gp_full.prior.kernel.kernels[1]))

### Compute predictions on a fine grid

In [None]:
import gp_utils as gpu

# Spacing of the prediction grid.
step = 0.25
# `west + l2` / `east - l2` remove the margin again, so the grid covers
# -50..50 and -40..40; `+ step` on the stop value is there to include the
# upper bound.
x = jnp.arange(west + l2, east - l2 + step, step)
y = jnp.arange(south + l2, north - l2 + step, step)
# NOTE: X and Y are rebound here; from this cell on they no longer refer to
# the arrays loaded from the .mat file.
X, Y = jnp.meshgrid(x, y)
# X becomes the flat list of 2-D grid points (one row per point); Y keeps its
# 2-D meshgrid shape, which plot_preds uses to reshape flat predictions.
X = jnp.array([X.flatten(), Y.flatten()]).T
# Append a zero third coordinate to every grid point.
xtest = jnp.hstack([X, jnp.zeros((X.shape[0],1))])
def pred_on_grid(gp, batch_size=1000):
    """Predict the three output components of ``gp`` on the global ``xtest`` grid.

    The test points are processed in consecutive slices of ``batch_size`` rows
    so that only one slice is passed to ``gpu.predict`` at a time.

    Args:
        gp: model exposing ``mean_parameters`` (unpacked as ``m, S``) and
            ``prior.kernel``.
        batch_size: number of rows of ``xtest`` per ``gpu.predict`` call
            (default 1000, the value that used to be hard-coded).

    Returns:
        ``((mux, covx), (muy, covy), (muz, covz))``: for each of the first
        three components, the concatenated predictions and the concatenated
        diagonals of the corresponding slice of ``V``.
    """
    m, S = gp.mean_parameters
    kernel = gp.prior.kernel
    means = ([], [], [])
    variances = ([], [], [])
    # The loop bound is taken from xtest itself (the array being sliced)
    # rather than from the global X, which other cells rebind.
    for start in tqdm(range(0, xtest.shape[0], batch_size)):
        mu, V = gpu.predict(m, S, kernel, xtest[start:start + batch_size])
        # assumes mu is (batch, 3) and V is (batch, batch, 3) -- TODO confirm
        mu_t, V_t = mu.T, V.T
        for comp in range(3):
            means[comp].append(mu_t[comp])
            variances[comp].append(V_t[comp].diagonal())
    return tuple(
        (jnp.concatenate(comp_mean), jnp.concatenate(comp_var))
        for comp_mean, comp_var in zip(means, variances)
    )
def norm_pred(xmu, xcov, ymu, ycov, zmu, zcov, base_y=None):
    """Combine per-component predictions into a norm and a summed variance.

    Args:
        xmu, ymu, zmu: predicted values of the three components.
        xcov, ycov, zcov: variances belonging to the three components.
        base_y: optional length-3 offset added to the component predictions
            before the norm is taken.  When omitted, the column-wise mean of
            the global ``data['y_raw']`` is used (the original behaviour).

    Returns:
        ``(norm, var)`` where ``norm`` is the Euclidean norm of the offset
        predictions and ``var`` is ``xcov + ycov + zcov``.  Note that ``var``
        is the plain sum of the component variances, not a propagated
        variance of the norm.
    """
    if base_y is None:
        base_y = data['y_raw'].mean(axis=0)
    norm = jnp.sqrt(
        (xmu + base_y[0]) ** 2
        + (ymu + base_y[1]) ** 2
        + (zmu + base_y[2]) ** 2
    )
    var = xcov + ycov + zcov
    return norm, var

##### This may take some time

In [None]:
# Grid predictions for the SE-only model and for the full sum-kernel model;
# each call returns a (prediction, variance) pair per component.
(se_x, se_xcov), (se_y, se_ycov), (se_z, se_zcov) = pred_on_grid(se_full)
(sum_x, sum_xcov), (sum_y, sum_ycov), (sum_z, sum_zcov) = pred_on_grid(gp_full)

In [None]:
# Collapse the three components of each model into a norm and a summed variance.
se_norm, se_var = norm_pred(se_x, se_xcov, se_y, se_ycov, se_z, se_zcov)
sum_norm, sum_var = norm_pred(sum_x, sum_xcov, sum_y, sum_ycov, sum_z, sum_zcov)

##### Plotting function

In [None]:
from utils import bitmappify
import matplotlib
import tikzplotlib
def plot_preds(norm, var, basevar, filename=None, extra_params=None, standalone=False):
    """Show a map of ``norm`` whose per-pixel opacity is ``1 - var / basevar``.

    ``norm`` and the opacity are reshaped to the shape of the global ``Y``
    and drawn over the x/y range of the global ``xtest``.  The colour range
    is fixed to 14..23.

    Args:
        norm: flat array of values to colour-map.
        var: flat array of variances, same length as ``norm``.
        basevar: reference variance used to normalise ``var``.
            NOTE: Matplotlib expects array-valued alpha inside [0, 1], so
            ``basevar`` should be an upper bound of ``var``.
        filename: when not ``None``, the figure is additionally exported with
            ``tikzplotlib.save`` to this path.
        extra_params: optional list of extra axis options for tikzplotlib.
            The list is copied; neither the argument nor a shared default is
            modified, so repeated calls do not accumulate options.
        standalone: forwarded to ``tikzplotlib.save``.
    """
    # Private copy: the original implementation extended a mutable default
    # (and the caller's list) in place on every call.
    axis_params = [] if extra_params is None else list(extra_params)
    alpha = 1 - var/basevar
    with sns.axes_style("white"):
        matplotlib.rcParams['xtick.direction'] ='in'
        matplotlib.rcParams['ytick.direction'] ='in'
        matplotlib.rcParams['xtick.bottom'] = True
        matplotlib.rcParams['ytick.left'] = True
        plt.close("all")
        fig = plt.figure(figsize=(8, 8))
        g = plt.imshow(norm.reshape(Y.shape), 
                       alpha=alpha.reshape(Y.shape),
                       cmap=sns.color_palette("viridis", as_cmap=True),
                       vmin=14,
                       vmax=23,
                       origin='lower',
                       extent=(xtest[:,0].min(), xtest[:,0].max(), xtest[:,1].min(), xtest[:,1].max()),
                      interpolation='bicubic')
        
        # Project helper from utils, called on the current axes with dpi=300.
        bitmappify(plt.gca(), dpi=300)
        plt.xlabel(r'$p_1~[m]$')
        plt.ylabel(r'$p_2~[m]$')
        # Axis options that are always added to the export.
        axis_params.extend(["scale only axis",
                        "axis lines = left",
                       "xlabel style={yshift=.25cm}",
                        "ylabel style={yshift=-.25cm}"])
        fig.colorbar(g, ax=plt.gca(), location='top', label=r'$\mu T$')
        if filename is not None:
            tikzplotlib.save(filename,
                         axis_width="4cm",
                         extra_axis_parameters=axis_params,
                        override_externals=True,
                        standalone=standalone)
        plt.show()

##### Plot SE predictions

In [None]:
# Extra axis options for the TikZ export of the SE plot.
# The colorbar style is assembled line by line and joined with newlines.  The
# xlabel line is a raw string so that the backslashes of \si{\micro\tesla}
# reach the .tex file intact; in an ordinary string literal "\t" is read as a
# TAB character, which corrupted "\tesla".
# NOTE(review): "point meta max=24" differs from vmax=23 used by plot_preds
# for the on-screen image -- confirm which one is intended.
extra_params = ["colormap/viridis",
                "colorbar horizontal",
                "\n".join(["colorbar style={",
                           "at={(0.5, 1.025)},",
                           "anchor=south,",
                           "point meta min=14,",
                           "point meta max=24,",
                           "xticklabel pos=upper,",
                           r"xlabel = {$\si{\micro\tesla}$},",
                           "height = .15cm,",
                           "}"])]
# Normalise by the largest SE variance so that the opacity stays within [0, 1].
basevar = se_var.max()
plot_preds(se_norm, se_var, basevar, filename="se_warehouse_predictions.tex", extra_params=extra_params, standalone=False)

##### Plot sum predictions

In [None]:
# Reference variance for the sum model: the first kernel's variance times the
# sum of exp(a), plus the second kernel's variance.
basevar = gp_full.prior.kernel.kernels[0].variance*jnp.exp(gp_full.prior.kernel.kernels[0].a).sum() + gp_full.prior.kernel.kernels[1].variance
# No extra_params are passed here, so only the options built into plot_preds
# end up in the exported file.
plot_preds(sum_norm, sum_var, basevar, filename="sum_warehouse_predictions.tex", standalone=False)

##### Save hyperparams -- for MC runs in MATLAB

In [None]:
# Write the fitted hyperparameters of both kernels and the likelihood to a
# .mat file.  The mixture weights are stored as exp(a) * variance of the
# first kernel.
sio.savemat(
    "hyperparams.mat",
    {
        "obs_stddev": gp_full.likelihood.obs_stddev,
        "mixture_weights": (
            jnp.exp(gp_full.prior.kernel.kernels[0].a)
            * gp_full.prior.kernel.kernels[0].variance
        ),
        "frequencies": gp_full.prior.kernel.kernels[0].mu,
        "se_lengthscale": gp_full.prior.kernel.kernels[1].lengthscale,
        "se_variance": gp_full.prior.kernel.kernels[1].variance,
        "se_num_bfs": gp_full.prior.kernel.kernels[1].basis.num_bfs,
        "se_domain_boundary": gp_full.prior.kernel.kernels[1].basis.L,
    },
)

In [None]:
# For every listed base name, read "<name>.csv" (columns x1, x2, y, a are
# used) and save a tick-label-free scatter plot as "<name>.pdf".
filenames = ['pd', 'se', 'sum', 'meas']
for file in filenames:
    fig, ax = plt.subplots(figsize=(6, 6))
    with open(f"{file}.csv", "r") as f:
        d = pd.read_csv(f)
    g = plt.scatter(
        d.x1,
        d.x2,
        c=d.y,
        vmin=0,
        vmax=2.5,
        cmap=sns.color_palette("viridis", as_cmap=True),
        alpha=d.a,
    )
    ax.set(xticklabels=[], yticklabels=[])
    fig.savefig(f"{file}.pdf", bbox_inches='tight')
    if 'meas' not in file:
        continue
    # For the measurement plot, also save a figure that holds only the
    # horizontal colorbar (its host axes is removed before saving).
    fig, ax = plt.subplots()
    cbar = plt.colorbar(g, ax=ax, orientation='horizontal')
    cbar.ax.xaxis.set_ticks_position('top')
    ax.remove()
    fig.savefig('visionen_cbar.pdf', bbox_inches='tight')

In [None]:
# NOTE(review): D, yh_pd, yhgp and yh_sum are not defined in any of the cells
# visible in this notebook (the data set above is named `Da`); this cell looks
# like it depends on state from elsewhere -- confirm before running.
# Shared colour range taken from the targets.
vmin, vmax = D.y.min(), D.y.max()

plt.close("all")
# 2x2 layout: PD kernel and SE kernel on top, raw measurements and PD + SE below.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
g = ax[1,0].scatter(D.X[:,0],D.X[:,1], c=D.y.flatten(), vmin=vmin, vmax=vmax, cmap=sns.color_palette("viridis", as_cmap=True))
ax[1,0].set_title("Raw measurements")
# For each predictive distribution: scale the diagonal of its covariance by
# its maximum and use one minus that as the marker opacity.
apd = yh_pd.covariance().diagonal()
apd = apd/apd.max()
ax[0,0].scatter(xtest[:,0], xtest[:,1], c=yh_pd.mean(), vmin=vmin, vmax=vmax, cmap=sns.color_palette("viridis", as_cmap=True), alpha=1-apd)
ax[0,0].set_title("PD kernel")
ase = yhgp.covariance().diagonal()
ase = ase/ase.max()
ax[0,1].scatter(xtest[:,0], xtest[:,1], c=yhgp.mean(), vmin=vmin, vmax=vmax, cmap=sns.color_palette("viridis", as_cmap=True), alpha=1-ase)
ax[0,1].set_title("SE kernel")
asum = yh_sum.covariance().diagonal()
asum = asum/asum.max()
# `g` is rebound here, so the colorbar below is attached to this last scatter.
g = ax[1,1].scatter(xtest[:,0], xtest[:,1], c=yh_sum.mean(), vmin=vmin, vmax=vmax, cmap=sns.color_palette("viridis", as_cmap=True), alpha=1-asum)
ax[1,1].set_title("PD + SE kernel")
# Reserve the right-hand strip of the figure for a shared colorbar axes.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(g, cax=cbar_ax)

# Give every panel a square box.
for axi in ax.flatten():
    axi.set_box_aspect(1)
plt.show()

In [None]:
# KL divergence of the PD and of the sum predictive distribution with respect
# to `yhgp`.
# NOTE(review): yh_pd, yh_sum and yhgp are not defined in the visible cells.
# Both lines are bare expressions, so in a notebook cell only the value of the
# last one is displayed; the first result is discarded.
yh_pd.kl_divergence(yhgp)
yh_sum.kl_divergence(yhgp)