In [None]:
# Automatically reload edited modules (e.g. the local `sva` package) before each
# cell runs, so changes to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion.
# NOTE(review): presumably a workaround for slow/broken jedi completion in this
# environment — confirm it is still needed.
%config Completer.use_jedi = False

In [None]:
import matplotlib.pyplot as plt
from functools import partial

In [None]:
import sys

# Make the parent directory importable so `sva` can be imported without installing it.
# NOTE(review): ".." is resolved relative to the kernel's current working directory,
# which is assumed to be the folder containing this notebook — confirm.
# Guarded so that re-running this cell does not append duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
# Import the various experiments we need for the notebook
from sva.experiments import Simple2d

# Plotting quality of life utils
from sva.mpl_utils import set_mpl_defaults

In [None]:
# Apply the project's matplotlib defaults (from sva.mpl_utils) before any figures
# in this notebook are created.
set_mpl_defaults()

# A simple 2D example

In [None]:
# Evaluate the toy Simple2d experiment on a dense grid so its response can be
# visualized. `ppd` is presumably "points per dimension"; the plotting cells below
# assume a square ppd x ppd grid.
experiment = Simple2d()
points_per_dim = 100
x = experiment.get_dense_coordinates(ppd=points_per_dim)
y = experiment(x)

# Domain bounds, passed as `extent=` to `imshow` in the plotting cells.
extent = experiment.get_domain_mpl_extent()

In [None]:
# Visualize the experiment's response `y` over the dense grid from the previous cell.
# Assumes `y` holds one value per grid point on a square (ppd x ppd) grid. The side
# length is recovered from its size rather than hard-coded, so changing `ppd` does
# not silently break the reshape.
grid_side = round(len(y) ** 0.5)
assert grid_side**2 == len(y), "expected a square (ppd x ppd) dense grid"

fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(
    # NOTE(review): `.T` presumes the dense grid is flattened with the first
    # coordinate varying slowest. The transpose then puts that coordinate on the
    # horizontal axis (origin="lower" puts its low end on the left). Confirm
    # against `get_dense_coordinates`.
    y.reshape(grid_side, grid_side).T,
    extent=extent,
    interpolation="nearest",
    origin="lower",
    cmap="viridis",
)
ax.set_xlabel("$x_0$")
ax.set_ylabel("$x_1$")
ax.set_title("Simple2d response")

plt.show()

In [None]:
# Import the campaign
from sva.campaign import Campaign

# Import the standard fixed policy
from sva.policy import FixedPolicy

# Model to use
from sva.models.gp import EasySingleTaskGP

# Default fitting function
from sva.models.gp import fit_EasyGP_mll

In [None]:
# Model construction/fitting callables and acquisition-optimizer settings for the
# policy. The no-argument `partial` wrappers are kept as in the original.
# NOTE(review): they are functionally plain references, but may matter for how the
# policy is serialized — confirm before simplifying.
gp_factory = partial(EasySingleTaskGP.from_default)
gp_fitter = partial(fit_EasyGP_mll)
# NOTE(review): `num_restarts`/`raw_samples` presumably forwarded to the acquisition
# optimizer (BoTorch-style) — confirm in sva.policy.
acqf_optimize_kwargs = {"q": 1, "num_restarts": 20, "raw_samples": 100}

# Fixed policy with a cold-start priming protocol and a UCB acquisition function.
# NOTE(review): `n_max` is presumably the total evaluation budget and "UCB-10" a UCB
# variant with exploration parameter 10 — confirm in sva.policy.
policy = FixedPolicy(
    n_max=100,
    prime_kwargs={"protocol": "cold_start"},
    model_factory=gp_factory,
    model_fitting_function=gp_fitter,
    optimize_kwargs=acqf_optimize_kwargs,
    acquisition_function="UCB-10",
    save_model=True,
)

# A fresh Simple2d instance is used rather than `experiment` from above.
# NOTE(review): presumably this keeps the campaign independent of the dense-grid
# evaluation — confirm before reusing `experiment`.
campaign = Campaign(seed=123, experiment=Simple2d(), policy=policy)
campaign.run()

In [None]:
# Overlay the points sampled by the campaign on the dense-grid response `y`.
# Assumes `y` holds one value per grid point on a square (ppd x ppd) grid. The side
# length is recovered from its size rather than hard-coded to 100.
grid_side = round(len(y) ** 0.5)
assert grid_side**2 == len(y), "expected a square (ppd x ppd) dense grid"

fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(
    # `.T` + origin="lower" puts the first coordinate on the horizontal axis, matching
    # the `X[:, 0]` vs `X[:, 1]` scatter below.
    # NOTE(review): assumes the grid is flattened with the first coordinate varying
    # slowest — confirm.
    y.reshape(grid_side, grid_side).T,
    extent=extent,
    interpolation="nearest",
    origin="lower",
    cmap="viridis",
)

# Inputs recorded on the campaign (presumably every point it queried).
X = campaign.data.X
ax.scatter(X[:, 0], X[:, 1], s=0.5, color="black")
ax.set_xlabel("$x_0$")
ax.set_ylabel("$x_1$")
ax.set_title("Campaign samples")

plt.show()