Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions ax/generators/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def gen(
max_draws = model_gen_options.get("max_rs_draws", DEFAULT_MAX_RS_DRAWS)
max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float)))
try:
# With equality constraints, unconstrained sampling has probability
# zero of producing feasible points, so skip straight to polytope
# sampling.
if equality_constraints is not None:
raise SearchSpaceExhausted(
"Equality constraints require polytope sampling."
)
# Always rejection sample, but this only rejects if there are
# constraints or actual duplicates and deduplicate is specified.
# If rejection sampling fails, fall back to polytope sampling.
Expand Down Expand Up @@ -184,11 +191,15 @@ def gen(
num_generated = (
len(generated_points) if generated_points is not None else 0
)
interior_point = ( # A feasible point of shape `d x 1`.
torch.from_numpy(generated_points[-1].reshape((-1, 1))).double()
if generated_points is not None
else None
)
# Use a previously generated point as the interior point
# hint, but only if it's likely feasible. When equality
# constraints are present, previous points (generated
# without those constraints) won't satisfy them.
interior_point: torch.Tensor | None = None
if generated_points is not None and equality_constraints is None:
interior_point = torch.from_numpy(
generated_points[-1].reshape((-1, 1))
).double()
kwargs = {"n_burnin": 100, "n_thinning": 20}
kwargs.update(self.polytope_sampler_kwargs)
polytope_sampler: HitAndRunPolytopeSampler = HitAndRunPolytopeSampler(
Expand Down Expand Up @@ -229,6 +240,7 @@ def gen_polytope_sampler(
fixed_features=fixed_features,
rounding_func=rounding_func,
existing_points=generated_points,
equality_constraints=equality_constraints,
)
else:
raise e
Expand Down
170 changes: 161 additions & 9 deletions ax/generators/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
AnalyticExpectedUtilityOfBestOption,
qExpectedUtilityOfBestOption,
)
from botorch.exceptions.errors import BotorchError, InputDataError
from botorch.exceptions.errors import (
BotorchError,
CandidateGenerationError,
InputDataError,
)
from botorch.generation.sampling import SamplingStrategy
from botorch.models.model import Model
from botorch.optim.optimize import (
Expand All @@ -72,7 +76,10 @@
optimize_acqf_mixed_alternating,
should_use_mixed_alternating_optimizer,
)
from botorch.optim.parameter_constraints import evaluate_feasibility
from botorch.optim.parameter_constraints import (
evaluate_feasibility,
project_to_feasible_space_via_slsqp,
)
from botorch.utils.constraints import get_outcome_constraint_transforms
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
Expand Down Expand Up @@ -892,6 +899,7 @@ def optimize(
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
bounds=bounds,
)
# Validate candidates before returning
validate_candidates(
Expand Down Expand Up @@ -1007,6 +1015,7 @@ def _prune_irrelevant_parameters(
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
fixed_features: dict[int, float] | None = None,
bounds: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
r"""Prune irrelevant parameters from the candidates using BONSAI.

Expand Down Expand Up @@ -1042,6 +1051,11 @@ def _prune_irrelevant_parameters(
corresponds to the `l_i`-th feature of that element.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
bounds: A `2 x d`-dim tensor of lower and upper parameter bounds.
Required when `inequality_constraints` or `equality_constraints`
are provided: pruned candidates are projected onto the feasible
set via SLSQP, and the projection needs the parameter bounds to
define the feasible region. Unused when no constraints are set.

Returns:
A two-element tuple containing an `q x d`-dim tensor of generated
Expand Down Expand Up @@ -1085,12 +1099,14 @@ def _prune_irrelevant_parameters(
# dense AF val
final_af_val = dense_af_val
# If the current incremental AF value is zero, then we skip pruning
has_constraints = bool(inequality_constraints or equality_constraints)
if dense_incremental_af_val > 0.0:
remaining_indices = set(range(candidates.shape[-1])) - excluded_indices
# remove features that are already set to target_point
remaining_indices -= set(
(candidates[i] == target_point).nonzero().view(-1).tolist()
)
initial_remaining = set(remaining_indices)
# len(remaining_indices) - 1 is used here so that we do not prune
# every dimension
for _ in range(len(remaining_indices) - 1):
Expand All @@ -1107,13 +1123,23 @@ def _prune_irrelevant_parameters(
indices=indices,
targets=target_point[indices],
)
# remove candidates that violate constraints after pruning
pruned_candidates, indices = _remove_infeasible_candidates(
candidates=pruned_candidates,
indices=indices,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)
# Project pruned candidates onto the feasible set
# (pinning the pruned dim and previously pruned dims),
# then filter any that remain infeasible.
if has_constraints:
previously_pruned = initial_remaining - remaining_indices
pruned_candidates, indices = (
_project_and_filter_pruned_candidates(
candidates=pruned_candidates,
indices=indices,
target_point=target_point,
pruned_dims=previously_pruned,
bounds=none_throws(bounds),
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
)
)
if pruned_candidates.shape[0] == 0:
# no feasible points, continue to
# next candidate
Expand Down Expand Up @@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates(
candidates = candidates[is_feasible]
indices = indices[is_feasible]
return candidates, indices


def _project_and_filter_pruned_candidates(
    candidates: Tensor,
    indices: Tensor,
    target_point: Tensor,
    pruned_dims: set[int],
    bounds: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    fixed_features: dict[int, float] | None = None,
) -> tuple[Tensor, Tensor]:
    r"""Project pruned candidates onto the feasible set, then filter infeasible.

    Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is
    only meaningful in the context of that greedy-pruning loop and is not
    intended for standalone use.

    Background: BONSAI pruning evaluates a candidate dimension-by-dimension
    by setting one dimension at a time to its target-point value. Each row
    of ``candidates`` is one such trial -- the dense candidate with the
    dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``.
    Under linear constraints, swapping a single dimension to the target
    typically violates the constraints; rather than discarding the trial
    (the prior behavior), we adjust the *other* free dimensions to recover
    feasibility while keeping the swapped dimension and all previously
    pruned dimensions pinned. Trials whose pins make the constraint system
    infeasible -- and the rare case where projection succeeds but the
    result still violates constraints -- are filtered out via the mask
    returned to the caller.

    Args:
        candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row
            per single-dimension prune attempt for the current BONSAI
            iteration).
        indices: A ``b``-dim tensor indicating which dimension was pruned
            in each batch element.
        target_point: A ``d``-dim tensor of target values for pruning.
        pruned_dims: Set of dimension indices already pruned in prior
            greedy iterations (to be kept pinned during projection).
        bounds: A ``2 x d``-dim tensor of lower and upper bounds.
        inequality_constraints: Inequality constraints in BoTorch format.
        equality_constraints: Equality constraints in BoTorch format.
        fixed_features: A map ``{feature_index: value}`` from the caller.
            These dimensions are excluded from pruning at the outer loop and
            must also be pinned during projection so SLSQP cannot adjust
            them while satisfying the constraints. Without this, fixed
            features could be silently altered.

    Returns:
        A two-element tuple of filtered ``(candidates, indices)``.
    """
    # Pre-compute which dims participate in any constraint, and check whether
    # any constraint is inter-point (2D index tensor). Inter-point constraints
    # apply across the q-batch, but each row here is a single-candidate prune
    # attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate
    # inter-point constraints on a 1 x d input. Fall back to the original
    # filter-only behavior in that case.
    constrained_dims: set[int] = set()
    has_interpoint_constraint = False
    for constraints in (inequality_constraints, equality_constraints):
        if constraints is None:
            continue
        for c_indices, _, _ in constraints:
            if c_indices.dim() == 1:
                constrained_dims.update(c_indices.tolist())
            else:
                # Last column of a 2D index tensor holds the feature index.
                constrained_dims.update(c_indices[:, -1].tolist())
                has_interpoint_constraint = True
    if has_interpoint_constraint:
        return _remove_infeasible_candidates(
            candidates=candidates,
            indices=indices,
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
        )

    # Build fixed_features for previously pruned dims and the caller's
    # fixed_features (both shared across all candidates in this iteration).
    prev_fixed: dict[int, float] = {k: target_point[k].item() for k in pruned_dims}
    if fixed_features is not None:
        prev_fixed.update(fixed_features)

    # Keep the mask on the candidates' device so boolean-mask indexing of
    # `result`/`indices` below works for CUDA inputs as well as CPU.
    feasible_mask = torch.ones(
        candidates.shape[0], dtype=torch.bool, device=candidates.device
    )
    result = candidates.clone()

    for i in range(candidates.shape[0]):
        j: int = int(indices[i].item())
        # If the pruned dim doesn't participate in any constraint,
        # pruning it can't violate anything — skip projection.
        if j not in constrained_dims:
            continue
        # Pin the currently pruned dim, all previously pruned dims, and the
        # caller's fixed features. (`j` is never in `prev_fixed`: previously
        # pruned and fixed dims are excluded from pruning by the caller.)
        fixed: dict[int, float | Tensor] = {
            j: float(target_point[j].item()),
            **prev_fixed,
        }
        try:
            projected = project_to_feasible_space_via_slsqp(
                X=candidates[i],  # 1 x d
                bounds=bounds,
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
                fixed_features=fixed,
            )
            result[i] = projected
        except CandidateGenerationError:
            # Pin makes the system infeasible — mark for removal.
            # The post-projection feasibility check below is the safety net
            # for any candidates that project but still violate constraints.
            feasible_mask[i] = False

    # Final safety-net feasibility check after projection.
    if feasible_mask.any():
        is_feasible = evaluate_feasibility(
            X=result[feasible_mask],
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
        )
        # Map the subset verdicts back onto the full mask in one vectorized
        # index-assignment instead of a Python loop over tensor scalars.
        subset_idx = feasible_mask.nonzero(as_tuple=True)[0]
        feasible_mask[subset_idx] = is_feasible.to(feasible_mask)

    return result[feasible_mask], indices[feasible_mask]
Loading
Loading