Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 139 additions & 4 deletions botorch/optim/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
qKnowledgeGradient,
)
from botorch.acquisition.utils import is_nonnegative
from botorch.exceptions.warnings import BadInitialCandidatesWarning, SamplingWarning
from botorch.exceptions.warnings import (
BadInitialCandidatesWarning,
SamplingWarning,
)
from botorch.models.model import Model
from botorch.optim.utils import fix_features
from botorch.optim.utils import fix_features, get_X_baseline
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.sampling import (
batched_multinomial,
draw_sobol_samples,
manual_seed,
get_polytope_samples,
)
from botorch.utils.transforms import standardize
from botorch.utils.transforms import unnormalize, normalize, standardize
from torch import Tensor
from torch.distributions import Normal
from torch.quasirandom import SobolEngine


Expand Down Expand Up @@ -137,7 +142,22 @@ def gen_batch_initial_conditions(
.view(n, q, -1)
.cpu()
)

# sample points around best
if options.get("sample_around_best", False):
X_best_rnd = sample_points_around_best(
acq_function=acq_function,
n_discrete_points=n * q,
sigma=options.get("sample_around_best_std", 1e-3),
bounds=bounds,
)
if X_best_rnd is not None:
X_rnd = torch.cat(
[
X_rnd,
X_best_rnd.view(n, q, bounds.shape[-1]).cpu(),
],
dim=0,
)
X_rnd = fix_features(X_rnd, fixed_features=fixed_features)
with torch.no_grad():
if batch_limit is None:
Expand Down Expand Up @@ -566,3 +586,118 @@ def initialize_q_batch_nonneg(
if max_idx not in idcs:
idcs[-1] = max_idx
return X[idcs]


def sample_points_around_best(
    acq_function: AcquisitionFunction,
    n_discrete_points: int,
    sigma: float,
    bounds: Tensor,
    best_pct: float = 5.0,
) -> Optional[Tensor]:
    r"""Find best points and sample nearby points.

    The baseline points of `acq_function` are scored with the model's
    posterior mean (passed through the acquisition function's objective when
    it has one). If the scores are multi-output, the non-dominated points are
    kept; otherwise the top `best_pct` percent of points are kept. The kept
    points are then perturbed via `sample_truncated_normal_perturbations`.

    Args:
        acq_function: The acquisition function.
        n_discrete_points: The number of points to sample.
        sigma: The standard deviation of the additive gaussian noise for
            perturbing the best points.
        bounds: A `2 x d`-dim tensor containing the bounds.
        best_pct: The percentage of best points to perturb.

    Returns:
        An optional `n_discrete_points x d`-dim tensor containing the
        sampled points. This is None if no baseline points are found.
    """
    X = get_X_baseline(acq_function=acq_function)
    if X is None:
        return None
    with torch.no_grad():
        posterior = acq_function.model.posterior(X)
        mean = posterior.mean
        while mean.ndim > 2:
            # take average over batch dims
            mean = mean.mean(dim=0)

        # Not every acquisition function carries an objective (and it may be
        # set to None); score with the raw posterior mean in that case
        # instead of failing.
        objective = getattr(acq_function, "objective", None)
        f_pred = mean if objective is None else objective(mean)

        # handle constraints for EHVI-based acquisition functions; only the
        # attribute lookup is guarded so that errors raised inside a
        # constraint callable are not silently swallowed
        constraints = getattr(acq_function, "constraints", None)
        if constraints is not None:
            neg_violation = -torch.stack(
                [c(mean).clamp_min(0.0) for c in constraints], dim=-1
            ).sum(dim=-1)
            feas = neg_violation == 0
            if feas.any():
                # infeasible points can never be among the best
                f_pred[~feas] = float("-inf")
            else:
                # set objective equal to negative violation
                f_pred = neg_violation

        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
            # multi-objective
            # find pareto set
            is_pareto = is_non_dominated(f_pred)
            best_X = X[is_pareto]
        else:
            if f_pred.ndim > 1 and f_pred.shape[-1] == 1:
                # `torch.topk` ranks along the last dim; drop the singleton
                # output dim so that the ranking is over points
                f_pred = f_pred.squeeze(-1)
            n_best = max(1, round(X.shape[0] * best_pct / 100))
            best_idcs = torch.topk(f_pred, n_best).indices
            best_X = X[best_idcs]
    return sample_truncated_normal_perturbations(
        X=best_X,
        n_discrete_points=n_discrete_points,
        sigma=sigma,
        bounds=bounds,
    )


def sample_truncated_normal_perturbations(
    X: Tensor,
    n_discrete_points: int,
    sigma: float,
    bounds: Tensor,
    qmc: bool = True,
) -> Tensor:
    r"""Sample points around `X`.

    The points in `X` are normalized with respect to `bounds`, perturbed by
    noise drawn from N(0, sigma^2 I) that is truncated so the result stays
    within [0,1]^d, and mapped back to the original scale.

    Args:
        X: A `n x d`-dim tensor starting points.
        n_discrete_points: The number of points to sample.
        sigma: The standard deviation of the additive gaussian noise for
            perturbing the points.
        bounds: A `2 x d`-dim tensor containing the bounds.
        qmc: A boolean indicating whether to use qmc.

    Returns:
        A `n_discrete_points x d`-dim tensor containing the sampled points.
    """
    centers = normalize(X, bounds=bounds)
    n_centers, dim = centers.shape[0], centers.shape[1]
    tkwargs = {"dtype": centers.dtype, "device": centers.device}
    if n_centers > 1:
        # pick (with replacement) which center each sample is drawn around;
        # a single center is simply broadcast against the uniform draws
        picks = torch.randint(n_centers, (n_discrete_points,), device=centers.device)
        centers = centers[picks]
    # uniform draws on the unit cube that drive the inverse-transform sampler
    if not qmc:
        unif = torch.rand((n_discrete_points, dim), **tkwargs)
    else:
        unit_cube = torch.zeros(2, dim, **tkwargs)
        unit_cube[1] = 1
        sobol = draw_sobol_samples(bounds=unit_cube, n=n_discrete_points, q=1)
        unif = sobol.squeeze(1)
    std_normal = Normal(0, 1)
    # CDF values at the z-scores of the admissible perturbation range
    # [-center, 1 - center]
    lower_cdf = std_normal.cdf(-centers / sigma)
    upper_cdf = std_normal.cdf((1 - centers) / sigma)
    # inverse transform: map the uniforms into [lower_cdf, upper_cdf], then
    # back through the normal quantile function
    noise = std_normal.icdf(lower_cdf + unif * (upper_cdf - lower_cdf)) * sigma
    # clip whatever still ends up outside the unit cube
    shifted = (centers + noise).clamp(0.0, 1.0)
    return unnormalize(shifted, bounds=bounds)
14 changes: 9 additions & 5 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
from botorch.optim.stopping import ExpMAStoppingCriterion
from torch import Tensor

# Keys of the `options` dict that `optimize_acqf` strips out before forwarding
# the remaining options to the candidate generation call.
INIT_OPTION_KEYS = (
"init_batch_limit",
"batch_limit",
"nonnegative",
"sample_around_best",
"sample_around_best_std",
)


def optimize_acqf(
acq_function: AcquisitionFunction,
Expand Down Expand Up @@ -186,11 +194,7 @@ def optimize_acqf(
acquisition_function=acq_function,
lower_bounds=bounds[0],
upper_bounds=bounds[1],
options={
k: v
for k, v in options.items()
if k not in ("init_batch_limit", "batch_limit", "nonnegative")
},
options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
Expand Down
44 changes: 43 additions & 1 deletion botorch/optim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import numpy as np
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions.errors import BotorchError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.gpytorch import ModelListGPyTorchModel, GPyTorchModel
from botorch.optim.numpy_converter import TorchAttr, set_params_with_array
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
Expand Down Expand Up @@ -241,3 +242,44 @@ def _handle_numerical_errors(
):
return float("nan"), np.full_like(x, "nan")
raise error # pragma: nocover


def get_X_baseline(acq_function: AcquisitionFunction) -> Optional[Tensor]:
    r"""Extract X_baseline from an acquisition function.

    This tries to find the baseline set of points. First, this checks if the
    acquisition function has a non-empty `X_baseline` attribute. If it does
    not, then this method attempts to use the `train_inputs` of the
    acquisition function's model (`mo_model` if present, otherwise `model`)
    as `X_baseline`.

    Args:
        acq_function: The acquisition function.

    Returns:
        An optional `n x d`-dim tensor of baseline points. This is None (and
        a `BotorchWarning` is issued) if no baseline points are found.
    """
    try:
        X = acq_function.X_baseline
        # if there are no baseline points, use training points
        if X.shape[0] == 0:
            raise BotorchError
    except (BotorchError, AttributeError):
        try:
            try:
                # for entropy MOO methods
                model = acq_function.mo_model
            except AttributeError:
                # an acquisition function without any model raises here and
                # is handled by the enclosing `except`, so the caller gets a
                # warning and None rather than an uncaught AttributeError
                model = acq_function.model
            # make sure input transforms are not applied
            # NOTE(review): the model is left in training mode afterwards;
            # nothing in this function switches it back.
            model.train()
            if isinstance(model, ModelListGPyTorchModel):
                # only the first sub-model's training inputs are used
                X = model.models[0].train_inputs[0]
            else:
                X = model.train_inputs[0]
        except (BotorchError, AttributeError):
            warnings.warn("Failed to extract X_baseline.", BotorchWarning)
            return None
    # just use one batch
    while X.ndim > 2:
        X = X[0]
    return X
Loading