In [None]:
"""bo_loop.ipynb

This notebook is a simple copy of the TuRBO loop from BoTorch with
a Vecchia GP swapped in. See the link for details:
https://botorch.org/tutorials/turbo_1
"""

In [1]:
# standard imports
import os
import math
from dataclasses import dataclass
import torch
import numpy as np

# botorch imports
from botorch.acquisition import qExpectedImprovement
from botorch.generation import MaxPosteriorSampling
from botorch.optim import optimize_acqf
from botorch.test_functions import Ackley
from botorch.utils.transforms import unnormalize
from torch.quasirandom import SobolEngine

# gpytorch imports
import gpytorch
from gpytorch.constraints import Interval
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ZeroMean

# pyvecch imports
from pyvecch.nbrs import ExactOracle, ApproximateOracle
from pyvecch.models import RFVecchia
from pyvecch.prediction import IndependentRF, VarianceCalibration
from pyvecch.training import fit_model
from pyvecch.input_transforms import Warping, Scaling, Identity

# Tensor factory options (dtype/device) unpacked into tensor constructors below.
tkwargs = dict(dtype=torch.float, device="cpu")

In [2]:

# Evaluations per BO iteration.
batch_size = 4

# Test problem: 20-dimensional Ackley constructed with negate=True.
fun = Ackley(dim=20, negate=True).to(**tkwargs)
dim = fun.dim

# Size of the initial design.
n_init = 2 * dim

# Overwrite the bounds in place so the search box is [-5, 10] in every coordinate.
fun.bounds[0].fill_(-5)
fun.bounds[1].fill_(10)
lb, ub = fun.bounds[0], fun.bounds[1]

def eval_objective(x):
    """Unnormalize ``x`` to the bounds of ``fun`` and return ``fun`` evaluated there."""
    x_in_bounds = unnormalize(x, fun.bounds)
    return fun(x_in_bounds)

def get_initial_points(dim, n_pts, seed=0):
    """Draw ``n_pts`` scrambled Sobol points in ``dim`` dimensions.

    Args:
        dim: Dimension passed to the Sobol engine.
        n_pts: Number of points to draw.
        seed: Seed for the scrambling, so the design is reproducible.

    Returns:
        The drawn points, converted with the notebook-wide ``tkwargs``.
    """
    engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
    return engine.draw(n=n_pts).to(**tkwargs)


def generate_batch(
    state,
    model,  # GP model
    X,  # Evaluated points on the domain [0, 1]^d
    Y,  # Function values
    batch_size,
    n_candidates=None,  # Number of candidates for Thompson sampling
    num_restarts=10,
    raw_samples=512,
    acqf="ts",  # "ei" or "ts"
):
    """Propose the next batch of points inside the current trust region.

    The trust region is a box centred on the best evaluated point (the row of
    ``X`` at ``Y.argmax()``), with side lengths ``state.length`` rescaled per
    coordinate by the kernel lengthscales and clamped to the unit cube.

    Args:
        state: Object exposing the current trust-region side length as
            ``state.length``.
        model: Fitted GP; must expose
            ``covar_module.base_kernel.lengthscale`` and be usable by
            ``MaxPosteriorSampling`` / ``qExpectedImprovement``.
        X: Evaluated points; all entries must lie in [0, 1].
        Y: Finite function values for the rows of ``X`` (the values the model
            was fit to).
        batch_size: Number of points to return.
        n_candidates: Number of Thompson-sampling candidates. Defaults to
            ``min(5000, max(2000, 200 * d))``.
        num_restarts: Passed to ``optimize_acqf`` (``"ei"`` only).
        raw_samples: Passed to ``optimize_acqf`` (``"ei"`` only).
        acqf: ``"ts"`` for Thompson sampling or ``"ei"`` for q-Expected
            Improvement.

    Returns:
        The proposed points as returned by the chosen acquisition routine.
    """
    assert acqf in ("ts", "ei")
    assert X.min() >= 0.0 and X.max() <= 1.0 and torch.all(torch.isfinite(Y))
    if n_candidates is None:
        n_candidates = min(5000, max(2000, 200 * X.shape[-1]))

    # Scale the TR to be proportional to the lengthscales
    x_center = X[Y.argmax(), :].clone()
    weights = model.covar_module.base_kernel.lengthscale.squeeze().detach()
    weights = weights / weights.mean()
    # Normalize by the geometric mean so the weights multiply to one.
    weights = weights / torch.prod(weights.pow(1.0 / len(weights)))
    tr_lb = torch.clamp(x_center - weights * state.length / 2.0, 0.0, 1.0)
    tr_ub = torch.clamp(x_center + weights * state.length / 2.0, 0.0, 1.0)

    if acqf == "ts":
        dim = X.shape[-1]
        # Build candidates with X's own dtype/device so they can be written
        # into a copy of x_center without relying on notebook-level settings.
        factory = {"dtype": X.dtype, "device": X.device}
        sobol = SobolEngine(dim, scramble=True)
        pert = sobol.draw(n_candidates).to(**factory)
        pert = tr_lb + (tr_ub - tr_lb) * pert

        # Create a perturbation mask
        prob_perturb = min(20.0 / dim, 1.0)
        mask = torch.rand(n_candidates, dim, **factory) <= prob_perturb
        # Rows with no perturbed coordinate get exactly one forced on.
        # torch.randint's upper bound is exclusive, so `dim` (not `dim - 1`)
        # is needed for every coordinate to be eligible.
        ind = torch.where(mask.sum(dim=1) == 0)[0]
        forced = torch.randint(0, dim, size=(len(ind),), device=X.device)
        mask[ind, forced] = True

        # Create candidate points from the perturbations and the mask
        X_cand = x_center.expand(n_candidates, dim).clone()
        X_cand[mask] = pert[mask]

        # Sample on the candidate points
        thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
        with torch.no_grad():  # We don't need gradients when using TS
            X_next = thompson_sampling(X_cand, num_samples=batch_size)

    elif acqf == "ei":
        # The incumbent is the best value among the targets passed in as Y.
        # NOTE(review): `maximize=True` is kept from the original call;
        # confirm the installed BoTorch version still accepts it.
        ei = qExpectedImprovement(model, Y.max(), maximize=True)
        X_next, _ = optimize_acqf(
            ei,
            bounds=torch.stack([tr_lb, tr_ub]),
            q=batch_size,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )

    return X_next

In [3]:
@dataclass
class TurboState:
    """Bookkeeping for one TuRBO trust region.

    Holds the trust-region side length with its lower/upper limits, the
    consecutive success/failure counters with their tolerances, the best
    value seen so far, and the restart flag.
    """

    dim: int
    batch_size: int
    length: float = 0.8
    length_min: float = 0.5**7
    length_max: float = 1.6
    failure_counter: int = 0
    failure_tolerance: int = float("nan")  # Placeholder, set in __post_init__
    success_counter: int = 0
    success_tolerance: int = 10  # Note: The original paper uses 3
    best_value: float = float("-inf")
    restart_triggered: bool = False

    def __post_init__(self):
        # Tolerance is ceil(max(4, dim) / batch_size); any value passed to the
        # constructor for failure_tolerance is overwritten here.
        floor_term = 4.0 / self.batch_size
        dim_term = float(self.dim) / self.batch_size
        self.failure_tolerance = math.ceil(max(floor_term, dim_term))


def update_state(state, Y_next):
    """Update the trust-region state with a newly evaluated batch.

    A batch is a success when its best value beats ``state.best_value`` by
    more than a relative margin of 1e-3; otherwise it is a failure. After
    ``success_tolerance`` consecutive successes the side length doubles
    (capped at ``length_max``); after ``failure_tolerance`` consecutive
    failures it halves. A restart is flagged once the length falls below
    ``length_min``.

    Args:
        state: Trust-region state; mutated in place.
        Y_next: Objective values of the new batch (iterable of tensor
            elements supporting ``.item()``).

    Returns:
        The same ``state`` object, updated.
    """
    batch_best = max(Y_next).item()
    incumbent = state.best_value

    # With a non-finite incumbent (e.g. the -inf default before any data),
    # `incumbent + 1e-3 * |incumbent|` is NaN and every comparison against it
    # is False, which would count a genuine first improvement as a failure.
    # Compare against the incumbent itself in that case.
    if math.isfinite(incumbent):
        threshold = incumbent + 1e-3 * math.fabs(incumbent)
    else:
        threshold = incumbent

    if batch_best > threshold:
        state.success_counter += 1
        state.failure_counter = 0
    else:
        state.success_counter = 0
        state.failure_counter += 1

    if state.success_counter == state.success_tolerance:  # Expand trust region
        state.length = min(2.0 * state.length, state.length_max)
        state.success_counter = 0
    elif state.failure_counter == state.failure_tolerance:  # Shrink trust region
        state.length /= 2.0
        state.failure_counter = 0

    state.best_value = max(state.best_value, batch_best)
    if state.length < state.length_min:
        state.restart_triggered = True
    return state

In [4]:
# Initial design: Sobol points evaluated one at a time on the objective.
x = get_initial_points(dim, n_init)
y = torch.tensor(
    [eval_objective(x_) for x_ in x], **tkwargs
)

# Seed the incumbent with the best initial observation. Leaving best_value at
# its -inf default would make the state ignore the initial design, so later
# batches could be counted as successes (and the reported best value could be
# too low) even when they do not beat a point already evaluated here.
state = TurboState(dim, batch_size=batch_size, best_value=y.max().item())

In [5]:
# Keyword arguments unpacked into fit_model every time the model is refit.
training_settings = dict(
    n_window=50,
    maxiter=100,
    rel_tol=5e-3,
)

In [None]:


# Acquisition settings forwarded to generate_batch. NUM_RESTARTS and
# RAW_SAMPLES are only used by its "ei" branch; N_CANDIDATES only by "ts".
NUM_RESTARTS = 10
RAW_SAMPLES = 512
N_CANDIDATES = min(5000, max(2000, 200 * dim))


# Main loop: each iteration builds and fits a fresh Vecchia GP on all data
# collected so far, proposes a batch by Thompson sampling inside the trust
# region, evaluates it, and updates the trust-region state.
while not state.restart_triggered:  # Run until TuRBO converges

    n = x.shape[0]
    # m grows with log10(n)^2 and is handed to the neighbor oracle below
    # (presumably the number of conditioning neighbors -- confirm in pyvecch).
    m = int(7.2 * np.log10(n) ** 2)
    # Batch size for fitting, capped at 128.
    train_batch_size = np.min([n, 128])

    # Standardize the targets; the model is fit to z, not y.
    z = (y - y.mean()) / y.std()
    # kernel, mean and likelihood can be swapped out, but prediction only
    # expects a zero mean.
    covar_module = ScaleKernel(MaternKernel(ard_num_dims = dim))
    mean_module = ZeroMean()
    # Noise is constrained to the interval [1e-8, 1e-3].
    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))

    # We can get an approximate oracle with
    # neighbor_oracle = ApproximateOracle(x,z,m,n_list = 100, n_probe = 75)
    neighbor_oracle = ExactOracle(x,z,m)
    # We can get variance inflation with
    # prediction_stategy = VarianceCalibration(IndependentRF(), num_p = 2)
    prediction_stategy = IndependentRF()
    # We can get warping (or scaling) with
    # input_transform = Warping(d = dim)
    # input_transform = Scaling(d = dim)
    input_transform = Identity(d = dim)
    model = RFVecchia(covar_module, mean_module, likelihood, 
        neighbor_oracle, prediction_stategy, input_transform)


    # Fit the model, then switch model and likelihood to eval mode before
    # using them for acquisition.
    fit_model(
        model,
        train_batch_size = train_batch_size, 
        **training_settings
    )
    # NOTE(review): presumably refreshes state tied to the input transform
    # after fitting -- confirm against pyvecch.
    model.update_transform()
    model.eval()
    model.likelihood.eval()

    # Create a batch. The standardized targets z are passed as Y, matching
    # what the model was fit to.
    x_next = generate_batch(
        state=state,
        model=model,
        X=x,
        Y=z,
        batch_size=batch_size,
        n_candidates=N_CANDIDATES,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
        acqf="ts",
    ).squeeze(0)  # drops a leading size-1 dimension if there is one

    # Evaluate the proposed points one at a time on the objective.
    y_next = torch.tensor(
        [eval_objective(x_) for x_ in x_next], **tkwargs
    )

    # Update state (uses the raw objective values, not the standardized ones)
    state = update_state(state=state, Y_next=y_next)

    # Append data
    x = torch.cat((x, x_next), dim=0)
    y = torch.cat((y, y_next), dim=0)

    # Print current status
    print(
        f"{len(x)}) Best value: {state.best_value:.2e}, TR length: {state.length:.2e}"
    )