Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@
from botorch.optim.optimize import optimize_acqf
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
Expand Down Expand Up @@ -718,7 +717,7 @@ def construct_inputs_qLogNEI(
X_baseline=X_baseline,
prune_baseline=prune_baseline,
cache_root=cache_root,
constraint=constraints,
constraints=constraints,
eta=eta,
),
"fat": fat,
Expand Down Expand Up @@ -853,11 +852,12 @@ def construct_inputs_EHVI(
training_data: MaybeDict[SupervisedDataset],
objective_thresholds: Tensor,
objective: Optional[AnalyticMultiOutputObjective] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `ExpectedHypervolumeImprovement` constructor."""
num_objectives = objective_thresholds.shape[0]
if kwargs.get("outcome_constraints") is not None:
if constraints is not None:
raise NotImplementedError("EHVI does not yet support outcome constraints.")

X = _get_dataset_field(
Expand Down Expand Up @@ -914,6 +914,7 @@ def construct_inputs_qEHVI(
training_data: MaybeDict[SupervisedDataset],
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `qExpectedHypervolumeImprovement` constructor."""
Expand All @@ -928,15 +929,10 @@ def construct_inputs_qEHVI(
# compute posterior mean (for ref point computation ref pareto frontier)
with torch.no_grad():
Y_pmean = model.posterior(X).mean

outcome_constraints = kwargs.pop("outcome_constraints", None)
# For HV-based acquisition functions we pass the constraint transform directly
if outcome_constraints is None:
cons_tfs = None
else:
cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
if constraints is not None:
# Adjust `Y_pmean` to contrain feasible points only.
feas = torch.stack([c(Y_pmean) <= 0 for c in cons_tfs], dim=-1).all(dim=-1)
feas = torch.stack([c(Y_pmean) <= 0 for c in constraints], dim=-1).all(dim=-1)
Y_pmean = Y_pmean[feas]

if objective is None:
Expand All @@ -962,7 +958,7 @@ def construct_inputs_qEHVI(
add_qehvi_kwargs = {
"sampler": sampler,
"X_pending": kwargs.get("X_pending"),
"constraints": cons_tfs,
"constraints": constraints,
"eta": kwargs.get("eta", 1e-3),
}
return {**ehvi_kwargs, **add_qehvi_kwargs}
Expand All @@ -975,6 +971,7 @@ def construct_inputs_qNEHVI(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
X_baseline: Optional[Tensor] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement` constructor."""
Expand All @@ -991,16 +988,12 @@ def construct_inputs_qNEHVI(
if objective is None:
objective = IdentityMCMultiOutputObjective()

outcome_constraints = kwargs.pop("outcome_constraints", None)
if outcome_constraints is None:
cons_tfs = None
else:
if constraints is not None:
if isinstance(objective, RiskMeasureMCObjective):
raise UnsupportedError(
"Outcome constraints are not supported with risk measures. "
"Use a feasibility-weighted risk measure instead."
)
cons_tfs = get_outcome_constraint_transforms(outcome_constraints)

sampler = kwargs.get("sampler")
if sampler is None and isinstance(model, GPyTorchModel):
Expand All @@ -1021,7 +1014,7 @@ def construct_inputs_qNEHVI(
"X_baseline": X_baseline,
"sampler": sampler,
"objective": objective,
"constraints": cons_tfs,
"constraints": constraints,
"X_pending": kwargs.get("X_pending"),
"eta": kwargs.get("eta", 1e-3),
"prune_baseline": kwargs.get("prune_baseline", True),
Expand Down
17 changes: 10 additions & 7 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,12 @@ def compute_best_feasible_objective(
is_feasible = compute_feasibility_indicator(
constraints=constraints, samples=samples
) # sample_shape x batch_shape x q
if is_feasible.any():
obj = torch.where(is_feasible, obj, -torch.inf)
with torch.no_grad():
return obj.amax(dim=-1, keepdim=True)

if is_feasible.any(dim=-1).all():
infeasible_value = -torch.inf

elif infeasible_obj is not None:
return infeasible_obj.expand(*obj.shape[:-1], 1)
infeasible_value = infeasible_obj.item()

else:
if model is None:
Expand All @@ -323,12 +322,16 @@ def compute_best_feasible_objective(
raise ValueError(
"Must specify `X_baseline` when no feasible observation exists."
)
return _estimate_objective_lower_bound(
infeasible_value = _estimate_objective_lower_bound(
model=model,
objective=objective,
posterior_transform=posterior_transform,
X=X_baseline,
).expand(*obj.shape[:-1], 1)
).item()

obj = torch.where(is_feasible, obj, infeasible_value)
with torch.no_grad():
return obj.amax(dim=-1, keepdim=True)


def _estimate_objective_lower_bound(
Expand Down
62 changes: 44 additions & 18 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self):
self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
Expand All @@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self):
best_f=best_f_expected,
)
self.assertEqual(kwargs["best_f"], best_f_expected)
# test passing constraints
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_multiY,
objective=objective,
X_pending=X_pending,
best_f=best_f_expected,
constraints=constraints,
)
self.assertIs(kwargs["constraints"], constraints)

# testing qLogEI input constructor
log_constructor = get_acqf_input_constructor(qLogExpectedImprovement)
Expand All @@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self):
objective=objective,
X_pending=X_pending,
best_f=best_f_expected,
constraints=constraints,
)
# includes strict superset of kwargs tested above
self.assertTrue(kwargs.items() <= log_kwargs.items())
Expand All @@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self):
self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
self.assertTrue("tau_relu" in log_kwargs)
self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
self.assertIs(log_kwargs["constraints"], constraints)

def test_construct_inputs_qNEI(self):
c = get_acqf_input_constructor(qNoisyExpectedImprovement)
Expand All @@ -441,29 +456,36 @@ def test_construct_inputs_qNEI(self):
with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
c(model=mock_model, training_data=self.multiX_multiY)
X_baseline = torch.rand(2, 2)
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_blockY,
X_baseline=X_baseline,
prune_baseline=False,
constraints=constraints,
)
self.assertEqual(kwargs["model"], mock_model)
self.assertIsNone(kwargs["objective"])
self.assertIsNone(kwargs["X_pending"])
self.assertIsNone(kwargs["sampler"])
self.assertFalse(kwargs["prune_baseline"])
self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
self.assertIs(kwargs["constraints"], constraints)

# testing qLogNEI input constructor
log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)

log_kwargs = log_constructor(
model=mock_model,
training_data=self.blockX_blockY,
X_baseline=X_baseline,
prune_baseline=False,
constraints=constraints,
)
# includes strict superset of kwargs tested above
self.assertTrue(kwargs.items() <= log_kwargs.items())
Expand All @@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self):
self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
self.assertTrue("tau_relu" in log_kwargs)
self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
self.assertIs(log_kwargs["constraints"], constraints)

def test_construct_inputs_qPI(self):
c = get_acqf_input_constructor(qProbabilityOfImprovement)
Expand Down Expand Up @@ -499,23 +522,28 @@ def test_construct_inputs_qPI(self):
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertEqual(kwargs["tau"], 1e-2)
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = objective(multi_Y).max()
self.assertEqual(kwargs["best_f"], best_f_expected)
# Check explicitly specifying `best_f`.
best_f_expected = best_f_expected - 1 # Random value.
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_multiY,
objective=objective,
X_pending=X_pending,
tau=1e-2,
best_f=best_f_expected,
constraints=constraints,
)
self.assertEqual(kwargs["best_f"], best_f_expected)
self.assertIs(kwargs["constraints"], constraints)

def test_construct_inputs_qUCB(self):
c = get_acqf_input_constructor(qUpperConfidenceBound)
Expand Down Expand Up @@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self):
model=mock_model,
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
outcome_constraints=mock.Mock(),
constraints=mock.Mock(),
)

# test with Y_pmean supplied explicitly
Expand Down Expand Up @@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self):
weights = torch.rand(2)
obj = WeightedMCMultiOutputObjective(weights=weights)
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
X_pending = torch.rand(1, 2)
kwargs = c(
model=mm,
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
objective=obj,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
alpha=0.05,
eta=1e-2,
Expand All @@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self):
Y_expected = mean[:1] * weights
self.assertTrue(torch.equal(partitioning._neg_Y, -Y_expected))
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
cons_tfs = kwargs["constraints"]
self.assertEqual(len(cons_tfs), 1)
cons_eval = cons_tfs[0](mean)
cons_eval_expected = torch.tensor([-0.25, 0.5])
self.assertTrue(torch.equal(cons_eval, cons_eval_expected))
self.assertIs(kwargs["constraints"], constraints)
self.assertEqual(kwargs["eta"], 1e-2)

# Test check for block designs
Expand All @@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self):
training_data=self.multiX_multiY,
objective_thresholds=objective_thresholds,
objective=obj,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
alpha=0.05,
eta=1e-2,
Expand Down Expand Up @@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self):
X_baseline = torch.rand(2, 2)
sampler = IIDNormalSampler(sample_shape=torch.Size([4]))
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
X_pending = torch.rand(1, 2)
kwargs = c(
model=mock_model,
Expand All @@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self):
objective=objective,
X_baseline=X_baseline,
sampler=sampler,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
eta=1e-2,
prune_baseline=True,
Expand All @@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self):
self.assertIsInstance(sampler_, IIDNormalSampler)
self.assertEqual(sampler_.sample_shape, torch.Size([4]))
self.assertEqual(kwargs["objective"], objective)
cons_tfs_expected = get_outcome_constraint_transforms(outcome_constraints)
cons_tfs = kwargs["constraints"]
self.assertEqual(len(cons_tfs), 1)
test_Y = torch.rand(1, 2)
self.assertTrue(torch.equal(cons_tfs[0](test_Y), cons_tfs_expected[0](test_Y)))
self.assertIs(kwargs["constraints"], constraints)
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertEqual(kwargs["eta"], 1e-2)
self.assertTrue(kwargs["prune_baseline"])
Expand All @@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self):
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
objective=MultiOutputExpectation(n_w=3),
outcome_constraints=outcome_constraints,
constraints=constraints,
)
for use_preprocessing in (True, False):
obj = MultiOutputExpectation(
Expand Down
Loading