In [None]:
# Reload edited modules (e.g. the local `sva` package) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion in IPython.
%config Completer.use_jedi = False

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Global figure styling for the notebook: STIX fonts for both regular text and
# mathtext (no external LaTeX), 12 pt tick and axis labels, and high-DPI output.
mpl.rcParams.update(
    {
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
        "text.usetex": False,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.labelsize": 12,
        "figure.dpi": 300,
    }
)

In [None]:
import sys

# Make the parent directory importable (presumably where the local `sva` package lives).
# Guarded so that re-running this cell does not keep appending duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
# Import the various experiments we need for the notebook
from sva.experiments import Simple2d, PolicyPerformanceEvaluator

# Import the seeding function for reproducibility
from sva.utils import seed_everything

# A simple 2d example

The `Simple2d` example has its maximum at $x=(2, -4)$. Below, we initialize it and plot the ground-truth function as a 2D heatmap, along with the `n` randomly initialized data points (black dots) and a red "x" marking the location of the true maximum.

In [None]:
# Seed first: the experiment construction and the random initial design below must
# follow the seeding call for the initial points to be reproducible.
seed_everything(1234)
experiment = Simple2d()
# Draw n=5 initial points with the "random" protocol; these populate `experiment.data`,
# which is plotted as black dots on the heatmap.
experiment.initialize_data(n=5, protocol="random")

In [None]:
# Dense evaluation grid over the domain; ppd=100 presumably means points per dimension,
# i.e. a 100x100 grid for this 2d problem — TODO confirm.
x = experiment.get_dense_coordinates(ppd=100)
# Evaluate the ground-truth function on the grid; the second return value is unused here.
y, _ = experiment(x)
# Domain bounds, presumably already in imshow's (left, right, bottom, top) `extent` order.
extent = experiment.get_experimental_domain_mpl_extent()

In [None]:
# Ground-truth heatmap with the initial data points (black dots) and the true maximum (red "x").
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

# Side length of the square evaluation grid, derived from the number of ground-truth
# values instead of hard-coding 100 (which had to match `ppd` in the grid cell by hand).
n_side = int(round(np.sqrt(len(y))))
assert n_side * n_side == len(y), "expected a square 2d grid of ground-truth values"

X = experiment.data.X
ax.imshow(
    y.reshape(n_side, n_side).T,
    extent=extent,
    interpolation="nearest",
    origin="lower",
)
ax.scatter(X[:, 0], X[:, 1], color="black", s=0.5)
# Location of the true maximum, x = (2, -4), as stated in the markdown above.
ax.scatter(2, -4, s=10, color="red", marker="x")
ax.set_xlabel(r"$x_1$")
ax.set_ylabel(r"$x_2$")

plt.show()

With relatively few initial samples, we expect the policy performance evaluator to show that more exploratory acquisition functions are preferable. For UCB, a larger $\beta$ weights the model uncertainty more heavily and is therefore more exploratory.

In [None]:
# Re-seed and rebuild the experiment from scratch so the evaluator starts from a freshly
# initialized n=5 design. Presumably this reproduces the same points as plotted above, and
# is needed because evaluating the ground truth may have touched the experiment's state
# — TODO confirm.
seed_everything(1234)
experiment = Simple2d()
experiment.initialize_data(n=5, protocol="random")

In [None]:
# Policies to compare: EI (no extra kwargs) followed by UCB at three increasing beta values.
# The two lists are parallel — entry i of the kwargs list configures entry i of the names list.
acquisition_function_list = ["EI"] + ["UCB"] * 3
acquisition_function_kwargs_list = [None] + [{"beta": beta} for beta in (2.0, 20.0, 100.0)]

In [None]:
# Wrap the n=5 experiment in a policy evaluator. checkpoint_dir presumably stores per-policy
# runs so that re-executing the notebook can reuse them — TODO confirm.
policy_evaluator = PolicyPerformanceEvaluator(
    experiment, checkpoint_dir="checkpoints/simple2d/n5"
)
# Evaluate the (acquisition function, kwargs) pairs from the two parallel lists above.
# NOTE(review): the positional arguments 10 and 50 are unnamed here; check the
# PolicyPerformanceEvaluator.run signature for their meaning (the n=50 run uses 60 and 50)
# and consider passing them by keyword.
policy_evaluator.run(
    10,
    50,
    acquisition_function_list,
    acquisition_function_kwargs_list,
    n_jobs=12,  # presumably the number of parallel workers — TODO confirm
)

In [None]:
# Collect the evaluator's results; the plotting code iterates over this as a mapping from
# a policy label (used as the legend entry) to an array of results.
policy_results = policy_evaluator.process_results()

In [None]:
# Shared styling for the errorbar curves: thin lines, tiny square markers, and
# thin short caps so the per-policy curves stay readable in a small figure.
plot_kwargs = dict(
    linewidth=1.0,
    marker="s",
    ms=1.0,
    capthick=0.3,
    capsize=2.0,
    elinewidth=0.3,
)

In [None]:
# Median opportunity cost for each policy, with error bars spanning the interquartile
# range (25th to 75th percentile) computed across the rows of each result array.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# Distinct names (not `x`/`y`) so the ground-truth grid values used by the heatmap cell
# are not clobbered if cells are re-run out of order.
for key, res in policy_results.items():
    steps = np.arange(res.shape[1])
    median = np.median(res, axis=0)
    lower, upper = np.percentile(res, q=[25, 75], axis=0)
    # `yerr` takes non-negative offsets from the plotted value (row 0 = below, row 1 = above),
    # not absolute bounds, so convert the percentiles into distances from the median.
    ax.errorbar(
        steps,
        median,
        yerr=[median - lower, upper - median],
        label=key,
        **plot_kwargs,
    )
# x is the index along axis 1 of each result array, assumed to be the acquisition step — TODO confirm.
ax.set_xlabel("Step")
ax.set_ylabel("Opportunity cost")
ax.legend(frameon=False, bbox_to_anchor=(1, 0.5), loc="center left")
ax.set_yscale("log")
plt.show()

The overall opportunity cost does appear to be lowest (in the long term) for the most exploratory acquisition function. With far more initial points, however, this may no longer be the case.

In [None]:
# Same seed and experiment, but with ten times as many random initial points (n=50),
# to test whether the preferred policy changes when more data is available up front.
seed_everything(1234)
experiment = Simple2d()
experiment.initialize_data(n=50, protocol="random")

In [None]:
# The same four policies as in the n=5 comparison, written as (name, kwargs) pairs and
# unzipped into the two parallel lists that the evaluator consumes.
acquisition_function_list, acquisition_function_kwargs_list = map(
    list,
    zip(
        ("EI", None),
        ("UCB", {"beta": 2.0}),
        ("UCB", {"beta": 20.0}),
        ("UCB", {"beta": 100.0}),
    ),
)

In [None]:
# Fresh evaluator for the n=50 experiment, with its own checkpoint directory so these results
# are kept separate from the n=5 run.
policy_evaluator = PolicyPerformanceEvaluator(
    experiment, checkpoint_dir="checkpoints/simple2d/n50"
)
# NOTE(review): the positional arguments 60 and 50 are unnamed here (the n=5 run uses 10 and 50);
# check the PolicyPerformanceEvaluator.run signature for their meaning and consider passing
# them by keyword.
policy_evaluator.run(
    60,
    50,
    acquisition_function_list,
    acquisition_function_kwargs_list,
    n_jobs=12,  # presumably the number of parallel workers — TODO confirm
)

In [None]:
# Collect the n=50 results; rebinds `policy_results` (previously the n=5 results) to a mapping
# from a policy label (used as the legend entry) to an array of results.
policy_results = policy_evaluator.process_results()

In [None]:
# Errorbar styling: line and marker settings first, then the error-bar cap/line settings.
plot_kwargs = dict(linewidth=1.0, marker="s", ms=1.0)
plot_kwargs.update(capthick=0.3, capsize=2.0, elinewidth=0.3)

In [None]:
# Median opportunity cost for each policy (n=50 initial points), with error bars spanning
# the interquartile range (25th to 75th percentile) across the rows of each result array.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# Distinct names (not `x`/`y`) so the ground-truth grid values used by the heatmap cell
# are not clobbered if cells are re-run out of order.
for key, res in policy_results.items():
    steps = np.arange(res.shape[1])
    median = np.median(res, axis=0)
    lower, upper = np.percentile(res, q=[25, 75], axis=0)
    # `yerr` takes non-negative offsets from the plotted value (row 0 = below, row 1 = above),
    # not absolute bounds, so convert the percentiles into distances from the median.
    ax.errorbar(
        steps,
        median,
        yerr=[median - lower, upper - median],
        label=key,
        **plot_kwargs,
    )
# x is the index along axis 1 of each result array, assumed to be the acquisition step — TODO confirm.
ax.set_xlabel("Step")
ax.set_ylabel("Opportunity cost")
ax.legend(frameon=False, bbox_to_anchor=(1, 0.5), loc="center left")
ax.set_yscale("log")
plt.show()

The best policy here does appear to be the most exploitative UCB variant, with $\beta=2$.