In [None]:
# Automatically reload edited modules (e.g. the local `sva` package) before
# each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion (a common workaround for slow or broken
# completion in some IPython versions).
%config Completer.use_jedi = False

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import string
from sklearn.neighbors import KNeighborsRegressor

In [None]:
# Import the various experiments we need for the notebook
from sva.experiments import Simple2d, GPDream, GPDreamKNN

# Import the helper functions for Gaussian Processes
from sva.models import EasySingleTaskGP

# Other utilities
from sva.utils import seed_everything
from sva.mpl_utils import set_mpl_defaults, set_mpl_grids

In [None]:
# Apply the project's matplotlib style defaults (from sva.mpl_utils; presumably
# via rcParams) before any figures are drawn.
set_mpl_defaults()

# A simple 2d example

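Here's a simple example of using a Gaussian process (GP) to fit data in a 2-dimensional space. We show the ground truth, the GP's predicted mean, and the error of that prediction (prediction minus ground truth); the black dots mark the 35 random training points.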

In [None]:
# Single global seed so the random training inputs below are reproducible
# (seed_everything presumably seeds every RNG the project uses).
SEED = 123
seed_everything(SEED)

In [None]:
# A simple 2d benchmark function, its plotting extent, and N_TRAIN random
# training inputs drawn from its domain.
N_TRAIN = 35

experiment = Simple2d()
extent = experiment.get_domain_mpl_extent()
X = experiment.get_random_coordinates(N_TRAIN)

In [None]:
# Get the ground truth results
# Get the ground truth results on a dense 100 x 100 grid (ppd presumably means
# "points per dimension"); `x` is reused below as the GP prediction grid.
x = experiment.get_dense_coordinates(ppd=100)
# The experiment returns a pair; only the first element (the function values)
# is used here.
y, _ = experiment(x)

In [None]:
# Train a GP on the current data
# Train a GP on the sparse training inputs `X` and predict on the dense grid `x`.
_y, _ = experiment(X)
# Presumably a single-task GP with sva's default kernel/likelihood — see sva.models.
gp = EasySingleTaskGP.from_default(X, _y)
# Fit hyperparameters (the name suggests maximising the marginal log-likelihood).
gp.fit_mll()
# Predictive mean and variance at every dense grid point.
mu, var = gp.predict(x)

In [None]:
# Compare ground truth, GP mean and their difference on the dense grid.
# Side length of the dense evaluation grid: `x` holds grid_side**2 points, so
# derive it rather than repeating the ppd value used to build `x`.
grid_side = int(round(np.sqrt(len(x))))

# Reshape the flattened grid values to images. The transpose plus
# origin="lower" keeps the orientation used by the original cell (assumes
# get_dense_coordinates orders points x_1-major — unchanged from before).
truth_img = y.reshape(grid_side, grid_side).T
pred_img = mu.reshape(grid_side, grid_side).T
delta_img = pred_img - truth_img

# One colour range for truth and prediction so their colours are directly
# comparable panel-to-panel.
value_lo = float(min(truth_img.min(), pred_img.min()))
value_hi = float(max(truth_img.max(), pred_img.max()))

# A diverging colormap is only meaningful when zero sits at its centre: use
# symmetric limits so the neutral colour means "no error".
delta_lim = float(np.abs(delta_img).max())

panels = [
    (truth_img, "Ground Truth", {"vmin": value_lo, "vmax": value_hi}),
    (pred_img, "GP Prediction", {"vmin": value_lo, "vmax": value_hi}),
    (delta_img, "Delta", {"cmap": "RdBu", "vmin": -delta_lim, "vmax": delta_lim}),
]

fig, axs = plt.subplots(1, 3, figsize=(6, 3), sharex=True, sharey=True)

for ax, (image, title, imshow_kwargs) in zip(axs, panels):
    ax.imshow(
        image,
        extent=extent,
        interpolation="nearest",
        origin="lower",
        **imshow_kwargs,
    )
    # Training inputs the GP was conditioned on.
    ax.scatter(X[:, 0], X[:, 1], color="black", s=0.5)
    ax.set_title(title)

axs[0].set_ylabel(r"$x_2$")
axs[1].set_xlabel(r"$x_1$")

plt.show()

# Example 2d test functions

We can also use GPs _themselves_ as test functions. To do this, we sample from an uninformed (unconditioned) GP prior and then fit another deterministic function to those samples. This lets us generate an unlimited number of test functions consistent with a chosen kernel (e.g. RBF or periodic) and its hyperparameters, such as the length scale. In the panels below, a black cross marks each function's optimum.

In [None]:
# (GP kernel settings, seed) for each 2d test function, in panel order.
DREAM_SPECS = [
    ({"kernel": "rbf", "lengthscale": 0.1}, 2),
    ({"kernel": "rbf", "lengthscale": 0.3}, 3),
    ({"kernel": "periodic", "lengthscale": 0.2, "period_length": 1.0}, 4),
    ({"kernel": "periodic", "lengthscale": 0.5, "period_length": 1.0}, 3),
    ({"kernel": "periodic", "lengthscale": 0.8, "period_length": 0.5}, 5),
]

experiments = [
    GPDream.from_default(gp_model_params=params, d=2, seed=seed)
    for params, seed in DREAM_SPECS
]

In [None]:
# Evaluate every test function on a dense 100 x 100 grid, keeping the flattened
# values and the reported optimum for plotting.
# NOTE: `extent` is reassigned on every pass, so after this cell it holds only
# the *last* experiment's domain.
optima, results = [], []
for ii, experiment in enumerate(experiments):
    x = experiment.get_dense_coordinates(ppd=100)
    values, _ = experiment(x)
    y = values.squeeze()
    extent = experiment.get_domain_mpl_extent()
    optima.append(experiment.optimum)
    results.append(y)

In [None]:
# Overall value range across all test functions, e.g. for a shared colour scale.
vmin, vmax = np.min(results), np.max(results)

In [None]:
# One panel per GP test function, with its optimum marked and a letter label.
fig = plt.figure(figsize=(2 * len(experiments), 2))

axs = ImageGrid(
    fig,
    111,
    nrows_ncols=(1, len(experiments)),
    axes_pad=0.25,
    share_all=True,
    aspect=False,
)

# Colours are scaled per panel. For a shared scale (e.g. with a single colorbar
# via ImageGrid's cbar_mode), pass vmin=vmin, vmax=vmax to imshow below.
for ii, (experiment, values, optimum) in enumerate(zip(experiments, results, optima)):
    ax = axs[ii]

    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])

    set_mpl_grids(ax)

    # `values` is a flattened square grid; derive its side instead of assuming 100.
    grid_side = int(round(np.sqrt(values.size)))
    ax.imshow(
        # Transpose + origin="lower" is the same orientation the old
        # `[:, ::-1].T` + default origin produced, made explicit and
        # consistent with the first figure.
        values.reshape(grid_side, grid_side).T,
        # Use this experiment's own domain: the global `extent` left over from
        # the evaluation loop only describes the last experiment.
        extent=experiment.get_domain_mpl_extent(),
        interpolation="nearest",
        origin="lower",
        aspect="equal",
        cmap="viridis",
    )
    ax.scatter(*optimum[0].squeeze(), color="black", marker="x")

    panel_label = string.ascii_lowercase[ii]
    ax.text(0.1, 0.1, f"({panel_label})", ha="left", va="bottom", transform=ax.transAxes, color="white")

axs[0].set_ylabel(r"$x_2$")
axs[len(experiments) // 2].set_xlabel(r"$x_1$")

plt.show()