In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
import matplotlib.pyplot as plt
from functools import partial
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
import gpytorch

In [None]:
import sys

sys.path.append("..")

In [None]:
# Import the various experiments we need for the notebook
from sva.experiments.ggce.ggce import Peierls

# Plotting quality of life utils
from sva.mpl_utils import set_mpl_defaults

# Import the campaign
from sva.campaign import Campaign, CampaignData, FixedPolicy, FixedSVAPolicy

# Model to use
from sva.models import EasySingleTaskGP, EasyFixedNoiseGP

# Proximity penalty module
from sva.bayesian_optimization import ProximityPenalty

# Value function
from sva.value import SVF

In [None]:
# Apply the project's matplotlib defaults from sva.mpl_utils before any figure
# is created (presumably rcParams such as fonts/sizes — see that module).
set_mpl_defaults()

# GGCE example

In [None]:
# Random seed shared by data priming and the campaign, and the resolution of the
# dense evaluation grid (`ppd` points along each of the two input dimensions).
SEED, ppd = 133, 100

In [None]:
experiment = Peierls(y_log=True)
x = experiment.get_dense_coordinates(ppd=ppd)
y = experiment(x)

# imshow extent is [xmin, xmax, ymin, ymax]. The first input (k) is shown in
# units of pi, so both of its bounds are divided by pi (sample overlays likewise
# plot data.X[:, 0] / np.pi). This replaces a hard-coded xmax of 1.0, which only
# holds for a [0, pi] k-domain. Work on a copy so the object returned by the
# experiment is never modified in place.
extent = list(experiment.get_domain_mpl_extent())
extent[0] /= np.pi
extent[1] /= np.pi

Plot the ground-truth spectral function $\log_{10} A(k,\omega)$ on the dense $(k, \omega)$ grid.

In [None]:
# Ground truth as an imshow-ready image: lay the flat values out on the
# (ppd, ppd) grid, transpose, and flip the rows (np.flipud). The flip presumably
# puts increasing omega at the top of the plot — this assumes
# get_dense_coordinates varies the first coordinate slowest; confirm.
# The trailing .copy() keeps A independent of `y`.
A = np.flipud(y.reshape(ppd, ppd).T).copy()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

im = ax.imshow(A, extent=extent, aspect="auto", cmap="binary")
# Attach the colorbar explicitly to this figure/axes instead of relying on
# pyplot's implicit "current axes" state.
cbar = fig.colorbar(im, ax=ax)
cbar.set_label(r"$\log_{10} A(k,\omega)$")

ax.set_title("Ground truth")
ax.set_ylabel(r"$\omega/t$")
ax.set_xlabel(r"$k/\pi$")
ax.set_xticks([0, 1])
plt.show()

In [None]:
x_grid = np.linspace(0, 2, 100)


def sigmoid(d, x0, a):
    """Logistic step 1 / (1 + exp(-(d - x0) / a)).

    Evaluated as 0.5 * (1 + tanh((d - x0) / (2 * a))), which is algebraically
    identical but cannot overflow: the naive ``exp`` form overflows (with a
    RuntimeWarning) when ``d`` lies far below ``x0`` relative to ``a``.

    Parameters
    ----------
    d : float or array_like
        Point(s) at which to evaluate.
    x0 : float
        Midpoint, where the output is exactly 0.5.
    a : float
        Width of the transition; smaller ``|a|`` gives a sharper step.
        Must be non-zero.

    Returns
    -------
    float or numpy.ndarray
        Values in [0, 1], same shape as ``d``.
    """
    return 0.5 * (1.0 + np.tanh((np.asarray(d) - x0) / (2.0 * a)))


xig = sigmoid(x_grid, 1, 0.05)

In [None]:
# Visual check of the sigmoid profile evaluated on x_grid above.
fig, ax = plt.subplots()
ax.plot(x_grid, xig)
ax.set_xlabel("d")
ax.set_ylabel("sigmoid(d)")
plt.show()

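Next, set up the remaining pieces needed to run a campaign: the number of initial random samples, the sampling budget, the surrogate-model factory, and the acquisition policy.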

In [None]:
# Campaign budget: N_start random points prime the data; N_max is handed to the
# policy as `n_max` (presumably the total sample budget — confirm in sva.campaign).
N_start, N_max = 3, 225

In [None]:
# Prime the campaign data with N_start randomly placed measurements.
data = CampaignData()
data.prime(experiment, "random", seed=SEED, n=N_start)

# Surrogate factory: single-task GP with a default-constructed, scaled Matern kernel.
# NOTE(review): `partial` binds this ONE kernel instance, so every model the
# factory builds receives the same `covar_module` object. Unless
# `EasySingleTaskGP.from_default` copies it, fitted hyperparameters carry over
# between refits — confirm this is intended.
covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())
model_factory = partial(EasySingleTaskGP.from_default, covar_module=covar_module)

# Acquisition-optimizer settings: one point per step (q=1), 20 restarts from 100
# raw samples (presumably forwarded to BoTorch's optimize_acqf — confirm).
optimize_kwargs = dict(q=1, num_restarts=20, raw_samples=100)
policy = FixedSVAPolicy(
    n_max=N_max,
    acquisition_function="UCB-100",
    model_factory=model_factory,
    save_model=False,
    optimize_kwargs=optimize_kwargs,
)

campaign = Campaign(data=data, experiment=experiment, policy=policy, seed=SEED)

In [None]:
# Run the acquisition loop configured above (policy n_max=N_max, 20 optimizer
# restarts per step), so this is likely the slowest cell. Later cells read
# data.X / data.Y, which are presumably extended in place by the run.
campaign.run()

In [None]:
# Refit a fresh GP on every campaign sample for visualization. Unlike the
# campaign surrogate (Matern), this one uses a scaled rational-quadratic kernel
# with a separate lengthscale per input (ard_num_dims=2) and a zero mean.
kernel = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RQKernel(ard_num_dims=2)
)
mean = gpytorch.means.ZeroMean()
model_kwargs = dict(mean_module=mean, covar_module=kernel)
model = EasySingleTaskGP.from_default(data.X, data.Y, **model_kwargs)
model.fit_mll()

# Predictions on the dense grid `x` (presumably posterior mean and std).
pred, std = model.predict(x)

In [None]:
def _as_image(values):
    """Lay flat values from the dense ``ppd x ppd`` grid out as an imshow image.

    Uses the same orientation as the ground-truth image ``A``: reshape to
    (ppd, ppd), transpose, then flip the rows.
    """
    return np.asarray(values).reshape(ppd, ppd).T[::-1, :]


# Panel data. Flatten both operands before subtracting so the residual is
# elementwise even if `y` and `pred` have different shapes (e.g. (N, 1) vs (N,));
# broadcasting those would produce an (N, N) array instead.
pred_img = _as_image(pred)
delta_img = _as_image(np.ravel(y) - np.ravel(pred))
std_img = _as_image(std)

# Truth and prediction share one color scale so they can be compared by eye.
truth_vmin, truth_vmax = np.nanmin(A), np.nanmax(A)
# Symmetric limits so white in the diverging "RdBu" map means zero residual;
# fall back to +/-1 if the residual is identically zero (degenerate range).
delta_lim = float(np.nanmax(np.abs(delta_img))) or 1.0

f = plt.figure(figsize=(10, 2))

axs = ImageGrid(
    f,
    111,
    nrows_ncols=(1, 4),
    axes_pad=0.6,  # room for each panel's colorbar tick labels
    share_all=True,
    cbar_location="right",
    cbar_mode="each",  # panels have different scales, so each gets a colorbar
    cbar_size="7%",
    cbar_pad=0.05,
    aspect=False,
)

panels = [
    (A, "Truth w/ samples", dict(cmap="viridis", vmin=truth_vmin, vmax=truth_vmax)),
    (pred_img, "Pred", dict(cmap="viridis", vmin=truth_vmin, vmax=truth_vmax)),
    (delta_img, "Delta", dict(cmap="RdBu", vmin=-delta_lim, vmax=delta_lim)),
    (std_img, "Std", dict(cmap="viridis")),
]
for ax, cax, (img, title, style) in zip(axs, axs.cbar_axes, panels):
    im = ax.imshow(img, aspect="auto", extent=extent, **style)
    cax.colorbar(im)
    ax.set_title(title)
    ax.set_xlabel(r"$k/\pi$")
axs[0].set_ylabel(r"$\omega/t$")

# Overlay sampled points on the truth panel: the N_start initial points in
# black, every sample as a small red dot (k shown in units of pi).
ax = axs[0]
ax.scatter(data.X[:N_start, 0] / np.pi, data.X[:N_start, 1], color="black")
ax.scatter(data.X[:, 0] / np.pi, data.X[:, 1], s=1, color="red")

plt.show()