Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion botorch/acquisition/knowledge_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
from botorch.utils.transforms import match_batch_shape, t_batch_mode_transform
from botorch.utils.transforms import (
match_batch_shape,
t_batch_mode_transform,
concatenate_pending_points,
)
from torch import Tensor


Expand Down Expand Up @@ -183,6 +187,65 @@ def forward(self, X: Tensor) -> Tensor:
# return average over the fantasy samples
return values.mean(dim=0)

@concatenate_pending_points
@t_batch_mode_transform()
def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
    r"""Compute qKnowledgeGradient at `X_actual` by explicitly solving the
    inner optimization problem.

    Args:
        X_actual: A `b x q x d` Tensor with `b` t-batches of `q` design points
            each. In contrast to `forward()`, the solutions of the inner
            optimization problem are not part of this input.
        bounds: A `2 x d` tensor holding the lower (`bounds[0]`) and upper
            (`bounds[1]`) bounds for each column of the inner-problem
            solutions.
        kwargs: Additional keyword arguments. The keys read here are
            `num_restarts` (default 20), `raw_samples` (default 1024) and
            `options`, a dictionary handed to both the initial-condition
            generator and the scipy candidate generator.

    Returns:
        A Tensor of shape `b`. For t-batch b, the q-KG value of the design
        `X_actual[b]`, averaged across the fantasy models.
        NOTE: When `current_value` is not set, the returned quantity is not
        the true KG value of `X_actual[b]`.
    """
    # fantasize at the candidate points; the result has shape
    # `num_fantasies x b`
    fantasized = self.model.fantasize(
        X=X_actual, sampler=self.sampler, observation_noise=True
    )

    # value function defined on top of the fantasy model
    inner_value_fn = _get_value_function(
        model=fantasized, objective=self.objective, sampler=self.inner_sampler
    )

    # function-scope imports, kept exactly where the original had them
    from botorch.optim.initializers import gen_value_function_initial_conditions
    from botorch.generation.gen import gen_candidates_scipy

    inner_options = kwargs.get("options")

    # starting points for the inner problem
    starts = gen_value_function_initial_conditions(
        acq_function=inner_value_fn,
        bounds=bounds,
        num_restarts=kwargs.get("num_restarts", 20),
        raw_samples=kwargs.get("raw_samples", 1024),
        current_model=self.model,
        options=inner_options,
    )

    # run the inner optimization; only the attained values are needed
    optimized = gen_candidates_scipy(
        initial_conditions=starts,
        acquisition_function=inner_value_fn,
        lower_bounds=bounds[0],
        upper_bounds=bounds[1],
        options=inner_options,
    )
    inner_values = optimized[1]

    # best value over the restart dimension for each batch
    best = torch.max(inner_values, dim=0).values
    if self.current_value is not None:
        best = best - self.current_value

    # average over the fantasy samples
    return best.mean(dim=0)

def get_augmented_q_batch_size(self, q: int) -> int:
r"""Get augmented q batch size for one-shot optimzation.

Expand Down
162 changes: 149 additions & 13 deletions botorch/optim/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
_get_value_function,
qKnowledgeGradient,
)
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.utils import is_nonnegative
from botorch.models.model import Model
from botorch.exceptions.warnings import BadInitialCandidatesWarning, SamplingWarning
from botorch.utils.sampling import draw_sobol_samples, manual_seed
from botorch.utils.sampling import draw_sobol_samples, manual_seed, batched_multinomial
from botorch.utils.transforms import standardize
from torch import Tensor
from torch.quasirandom import SobolEngine
Expand Down Expand Up @@ -226,6 +228,126 @@ def gen_one_shot_kg_initial_conditions(
return ics


def gen_value_function_initial_conditions(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    num_restarts: int,
    raw_samples: int,
    current_model: Model,
    options: Optional[Dict[str, Union[bool, float, int]]] = None,
) -> Tensor:
    r"""Generate a batch of smart initializations for optimizing
    the value function of qKnowledgeGradient.

    This function generates initial conditions for optimizing the inner problem of
    KG, i.e. its value function, using the maximizer of the posterior objective.
    Intuitively, the maximizer of the fantasized posterior will often be close to a
    maximizer of the current posterior. This function uses that fact to generate the
    initial conditions for the fantasy points. Specifically, a fraction of `1 -
    frac_random` (see options) of raw samples is generated by sampling from the set of
    maximizers of the posterior objective (obtained via random restart optimization)
    according to a softmax transformation of their respective values. This means that
    this initialization strategy internally solves an acquisition function
    maximization problem. The remaining raw samples are generated using
    `draw_sobol_samples`. All raw samples are then evaluated, and the initial
    conditions are selected according to the standard initialization strategy in
    'initialize_q_batch' individually for each inner problem.

    Args:
        acq_function: The value function instance to be optimized.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of
            task features.
        num_restarts: The number of starting points for multistart acquisition
            function optimization.
        raw_samples: The number of raw samples to consider in the initialization
            heuristic.
        current_model: The model of the KG acquisition function that was used to
            generate the fantasy model of the value function.
        options: Options for initial condition generation. These contain all
            settings for the standard heuristic initialization from
            `gen_batch_initial_conditions`. In addition, they contain
            `frac_random` (the fraction of fully random fantasy points,
            default 0.6), `num_inner_restarts` and `raw_inner_samples` (the
            number of random restarts and raw samples for solving the posterior
            objective maximization problem, defaults 20 and 1024) and `eta`
            (temperature parameter for sampling heuristic from posterior
            objective maximizers, default 2.0). An optional `seed` is passed
            on to `draw_sobol_samples`.

    Returns:
        A `num_restarts x batch_shape x q x d` tensor that can be used as initial
        conditions for `optimize_acqf()`. Here `batch_shape` is the
        `_input_batch_shape` of value function model.

    Raises:
        ValueError: If `frac_random` is not strictly between 0 and 1.

    Example:
        >>> fant_X = torch.rand(5, 1, 2)
        >>> fantasy_model = model.fantasize(fant_X, SobolQMCNormalSampler(16))
        >>> value_function = PosteriorMean(fantasy_model)
        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
        >>> Xinit = gen_value_function_initial_conditions(
        >>>     value_function, bounds, num_restarts=10, raw_samples=512,
        >>>     current_model=model, options={"frac_random": 0.25},
        >>> )
    """
    options = options or {}
    seed: Optional[int] = options.get("seed")
    frac_random: float = options.get("frac_random", 0.6)
    eta = options.get("eta", 2.0)
    if not 0 < frac_random < 1:
        raise ValueError(
            f"frac_random must take on values in (0,1). Value: {frac_random}"
        )

    # compute maximizer of the current value function
    value_function = _get_value_function(
        model=current_model,
        objective=acq_function.objective,
        sampler=acq_function.sampler
        if isinstance(acq_function, MCAcquisitionFunction)
        else None,
    )
    # NOTE(review): function-scope import kept from the original, presumably
    # to avoid a circular import with botorch.optim.optimize -- confirm.
    from botorch.optim.optimize import optimize_acqf

    fantasy_cands, fantasy_vals = optimize_acqf(
        acq_function=value_function,
        bounds=bounds,
        q=1,
        num_restarts=options.get("num_inner_restarts", 20),
        raw_samples=options.get("raw_inner_samples", 1024),
        return_best_only=False,
        options=options,
    )

    batch_shape = acq_function.model._input_batch_shape
    # sampling from the optimizers
    n_value = int((1 - frac_random) * raw_samples)  # number of non-random ICs
    if n_value > 0:
        weights = torch.exp(eta * standardize(fantasy_vals))
        idx = batched_multinomial(
            weights=weights.expand(*batch_shape, -1),
            num_samples=n_value,
            replacement=True,
        ).permute(-1, *range(len(batch_shape)))
        resampled = fantasy_cands[idx]
    else:
        # Empty placeholder so the concatenation below works unchanged. It
        # must share dtype AND device with the Sobol samples (both derived
        # from `bounds`), otherwise `torch.cat` fails for non-CPU bounds.
        resampled = torch.empty(
            0,
            *batch_shape,
            1,
            bounds.shape[-1],
            dtype=bounds.dtype,
            device=bounds.device,
        )
    # add qMC samples
    randomized = draw_sobol_samples(
        bounds=bounds,
        n=raw_samples - n_value,
        q=1,
        batch_shape=batch_shape,
        seed=seed,
    )
    # full set of raw samples
    X_rnd = torch.cat([resampled, randomized], dim=0)

    # evaluate the raw samples
    with torch.no_grad():
        Y_rnd = acq_function(X_rnd)

    # select the restart points using the heuristic
    return initialize_q_batch(X=X_rnd, Y=Y_rnd, n=num_restarts, eta=eta)


def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor:
r"""Heuristic for selecting initial conditions for candidate generation.

Expand All @@ -238,15 +360,18 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
`initialize_q_batch_nonneg` instead.

Args:
X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
feature space. Typically, these are generated using qMC sampling.
Y: A tensor of `b` outcomes associated with the samples. Typically, this
is the value of the batch acquisition function to be maximized.
X: A `b x batch_shape x q x d` tensor of `b` - `batch_shape` samples of
`q`-batches from a `d`-dim feature space. Typically, these are generated
using qMC sampling.
Y: A tensor of `b x batch_shape` outcomes associated with the samples.
Typically, this is the value of the batch acquisition function to be
maximized.
n: The number of initial condition to be generated. Must be less than `b`.
eta: Temperature parameter for weighting samples.

Returns:
A `n x q x d` tensor of `n` `q`-batch initial conditions.
A `n x batch_shape x q x d` tensor of `n` - `batch_shape` `q`-batch initial
conditions, where each batch of `n x q x d` samples is selected independently.

Example:
>>> # To get `n=10` starting points of q-batch size `q=3`
Expand All @@ -256,6 +381,7 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
>>> Xinit = initialize_q_batch(Xrnd, qUCB(Xrnd), 10)
"""
n_samples = X.shape[0]
batch_shape = X.shape[1:-2] or torch.Size()
if n > n_samples:
raise RuntimeError(
f"n ({n}) cannot be larger than the number of "
Expand All @@ -264,27 +390,37 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
elif n == n_samples:
return X

Ystd = Y.std()
if Ystd == 0:
Ystd = Y.std(dim=0)
if torch.any(Ystd == 0):
warnings.warn(
"All acqusition values for raw samples points are the same. "
"Choosing initial conditions at random.",
"All acquisition values for raw samples points are the same for "
"at least one batch. Choosing initial conditions at random.",
BadInitialCandidatesWarning,
)
return X[torch.randperm(n=n_samples, device=X.device)][:n]

max_val, max_idx = torch.max(Y, dim=0)
Z = (Y - Y.mean()) / Ystd
Z = (Y - Y.mean(dim=0)) / Ystd
etaZ = eta * Z
weights = torch.exp(etaZ)
while torch.isinf(weights).any():
etaZ *= 0.5
weights = torch.exp(etaZ)
idcs = torch.multinomial(weights, n)
if batch_shape == torch.Size():
idcs = torch.multinomial(weights, n)
else:
idcs = batched_multinomial(
weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
).permute(-1, *range(len(batch_shape)))
# make sure we get the maximum
if max_idx not in idcs:
idcs[-1] = max_idx
return X[idcs]
if batch_shape == torch.Size():
return X[idcs]
else:
return X.gather(
dim=0, index=idcs.view(*idcs.shape, 1, 1).expand(n, *X.shape[1:])
)


def initialize_q_batch_nonneg(
Expand Down
4 changes: 2 additions & 2 deletions botorch/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def t_batch_mode_transform(

def decorator(method: Callable[[Any, Tensor], Any]) -> Callable[[Any, Tensor], Any]:
@wraps(method)
def decorated(cls: Any, X: Tensor) -> Any:
def decorated(cls: Any, X: Tensor, **kwargs: Any) -> Any:
if X.dim() < 2:
raise ValueError(
f"{type(cls).__name__} requires X to have at least 2 dimensions,"
Expand All @@ -166,7 +166,7 @@ def decorated(cls: Any, X: Tensor) -> Any:
f" got X with shape {X.shape}."
)
X = X if X.dim() > 2 else X.unsqueeze(0)
return method(cls, X)
return method(cls, X, **kwargs)

return decorated

Expand Down
56 changes: 54 additions & 2 deletions test/acquisition/test_knowledge_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from unittest import mock

import torch
from botorch.acquisition.analytic import PosteriorMean
from botorch.acquisition.cost_aware import GenericCostAwareUtility
Expand All @@ -17,13 +16,13 @@
)
from botorch.acquisition.monte_carlo import qSimpleRegret
from botorch.acquisition.objective import GenericMCObjective, ScalarizedObjective
from botorch.models import SingleTaskGP
from botorch.exceptions.errors import UnsupportedError
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultitaskMultivariateNormal


NO = "botorch.utils.testing.MockModel.num_outputs"


Expand Down Expand Up @@ -213,6 +212,59 @@ def test_evaluate_q_knowledge_gradient(self):
val_expected = (mean * weights).sum(-1).mean(0)
self.assertTrue(torch.allclose(val, val_expected))

def test_evaluate_kg(self):
    # Part 1: end-to-end run against a real GP model in double precision.
    d = 2
    dtype = torch.double
    tkwargs = {"device": self.device, "dtype": dtype}
    bounds = torch.tensor([[0], [1]], **tkwargs).repeat(1, d)
    train_X = torch.rand(3, d, **tkwargs)
    train_Y = torch.rand(3, 1, **tkwargs)
    model = SingleTaskGP(train_X, train_Y)
    # random draws stay in the same order as before: pending points first,
    # then the current value, then the candidate set
    X_pending = torch.rand(2, d, **tkwargs)
    current_value = torch.rand(1, **tkwargs)
    qKG = qKnowledgeGradient(
        model=model,
        num_fantasies=2,
        objective=None,
        X_pending=X_pending,
        current_value=current_value,
    )
    X = torch.rand(4, 3, d, **tkwargs)
    val = qKG.evaluate(
        X,
        bounds=bounds,
        num_restarts=2,
        raw_samples=3,
        options={"num_inner_restarts": 2, "raw_inner_samples": 3},
    )
    # one value per t-batch, in the input dtype
    self.assertEqual(val.size(), torch.Size([4]))
    self.assertEqual(val.dtype, dtype)

    # Part 2: covers i) no dimension being squeezed out, ii) dtype float,
    # iii) an MC objective, and iv) t_batch_mode_transform (2-dim input X).
    dtype = torch.float
    tkwargs = {"device": self.device, "dtype": dtype}
    bounds = torch.tensor([[0], [1]], **tkwargs)
    train_X = torch.rand(1, 1, **tkwargs)
    train_Y = torch.rand(1, 1, **tkwargs)
    model = SingleTaskGP(train_X, train_Y)
    mc_objective = GenericMCObjective(objective=lambda Y: Y.norm(dim=-1))
    qKG = qKnowledgeGradient(
        model=model,
        num_fantasies=1,
        objective=mc_objective,
    )
    X = torch.rand(1, 1, **tkwargs)
    val = qKG.evaluate(
        X,
        bounds=bounds,
        num_restarts=1,
        raw_samples=1,
        options={"num_inner_restarts": 1, "raw_inner_samples": 1},
    )
    # a single t-batch yields a single value, in the input dtype
    self.assertEqual(val.size(), torch.Size([1]))
    self.assertEqual(val.dtype, dtype)


class TestQMultiFidelityKnowledgeGradient(BotorchTestCase):
def test_initialize_q_multi_fidelity_knowledge_gradient(self):
Expand Down
Loading