From c963f34f17a02f54f187970c9f6fe39268d8b1ba Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Fri, 8 May 2026 09:32:10 -0700 Subject: [PATCH 1/2] BONSAI: pin-and-project for constrained pruning (#5180) Summary: Extend BONSAI's irrelevance pruning to handle both equality and inequality constraints via a pin-and-project approach. Previously, BONSAI simply discarded pruned candidates that violated constraints. This was overly conservative (inequality) or completely broken (equality, where almost all single-dimension prunes violate the constraint). The new approach: 1. Set x_j = target[j] (unchanged) 2. Project the other dimensions onto the feasible set via SLSQP, keeping x_j pinned (and all previously pruned dims pinned) 3. Filter any candidates that remain infeasible after projection This is strictly better than discarding: it recovers feasibility when possible by adjusting other dimensions, while infeasible pins (where no adjustment can satisfy the constraints) are still caught. Key implementation details: - `_project_and_filter_pruned_candidates`: new function that uses `project_to_feasible_space_via_slsqp` with `fixed_features` to pin the pruned dim, all previously pruned dims, and the caller's own `fixed_features`. - Optimization: skip projection for dims not in any constraint's index set (pruning them can't violate anything). - Inter-point constraints (2D index tensors) are detected explicitly; because the projection runs on a single candidate and cannot evaluate them, pruning falls back to the original filter-only behavior in that case. - `_prune_irrelevant_parameters` now accepts a `bounds` parameter, which is required whenever inequality or equality constraints are provided. 
Reviewed By: esantorella Differential Revision: D100256483 --- .../torch/botorch_modular/acquisition.py | 170 +++++++++++++++++- ax/generators/torch/tests/test_acquisition.py | 125 +++++++++++++ 2 files changed, 286 insertions(+), 9 deletions(-) diff --git a/ax/generators/torch/botorch_modular/acquisition.py b/ax/generators/torch/botorch_modular/acquisition.py index 167598cb961..97abfb68cd6 100644 --- a/ax/generators/torch/botorch_modular/acquisition.py +++ b/ax/generators/torch/botorch_modular/acquisition.py @@ -57,7 +57,11 @@ AnalyticExpectedUtilityOfBestOption, qExpectedUtilityOfBestOption, ) -from botorch.exceptions.errors import BotorchError, InputDataError +from botorch.exceptions.errors import ( + BotorchError, + CandidateGenerationError, + InputDataError, +) from botorch.generation.sampling import SamplingStrategy from botorch.models.model import Model from botorch.optim.optimize import ( @@ -72,7 +76,10 @@ optimize_acqf_mixed_alternating, should_use_mixed_alternating_optimizer, ) -from botorch.optim.parameter_constraints import evaluate_feasibility +from botorch.optim.parameter_constraints import ( + evaluate_feasibility, + project_to_feasible_space_via_slsqp, +) from botorch.utils.constraints import get_outcome_constraint_transforms from pyre_extensions import assert_is_instance, none_throws from torch import Tensor @@ -892,6 +899,7 @@ def optimize( inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, fixed_features=fixed_features, + bounds=bounds, ) # Validate candidates before returning validate_candidates( @@ -1007,6 +1015,7 @@ def _prune_irrelevant_parameters( inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, fixed_features: dict[int, float] | None = None, + bounds: Tensor | None = None, ) -> tuple[Tensor, Tensor]: r"""Prune irrelevant parameters from the candidates using BONSAI. 
@@ -1042,6 +1051,11 @@ def _prune_irrelevant_parameters( corresponds to the `l_i`-th feature of that element. fixed_features: A map `{feature_index: value}` for features that should be fixed to a particular value during generation. + bounds: A `2 x d`-dim tensor of lower and upper parameter bounds. + Required when `inequality_constraints` or `equality_constraints` + are provided: pruned candidates are projected onto the feasible + set via SLSQP, and the projection needs the parameter bounds to + define the feasible region. Unused when no constraints are set. Returns: A two-element tuple containing an `q x d`-dim tensor of generated @@ -1085,12 +1099,14 @@ def _prune_irrelevant_parameters( # dense AF val final_af_val = dense_af_val # If the current incremental AF value is zero, then we skip pruning + has_constraints = bool(inequality_constraints or equality_constraints) if dense_incremental_af_val > 0.0: remaining_indices = set(range(candidates.shape[-1])) - excluded_indices # remove features that are already set to target_point remaining_indices -= set( (candidates[i] == target_point).nonzero().view(-1).tolist() ) + initial_remaining = set(remaining_indices) # len(remaining_indices) - 1 is used here so that we do not prune # every dimension for _ in range(len(remaining_indices) - 1): @@ -1107,13 +1123,23 @@ def _prune_irrelevant_parameters( indices=indices, targets=target_point[indices], ) - # remove candidates that violate constraints after pruning - pruned_candidates, indices = _remove_infeasible_candidates( - candidates=pruned_candidates, - indices=indices, - inequality_constraints=inequality_constraints, - equality_constraints=equality_constraints, - ) + # Project pruned candidates onto the feasible set + # (pinning the pruned dim and previously pruned dims), + # then filter any that remain infeasible. 
+ if has_constraints: + previously_pruned = initial_remaining - remaining_indices + pruned_candidates, indices = ( + _project_and_filter_pruned_candidates( + candidates=pruned_candidates, + indices=indices, + target_point=target_point, + pruned_dims=previously_pruned, + bounds=none_throws(bounds), + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + fixed_features=fixed_features, + ) + ) if pruned_candidates.shape[0] == 0: # no feasible points, continue to # next candidate @@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates( candidates = candidates[is_feasible] indices = indices[is_feasible] return candidates, indices + + +def _project_and_filter_pruned_candidates( + candidates: Tensor, + indices: Tensor, + target_point: Tensor, + pruned_dims: set[int], + bounds: Tensor, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + fixed_features: dict[int, float] | None = None, +) -> tuple[Tensor, Tensor]: + r"""Project pruned candidates onto the feasible set, then filter infeasible. + + Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is + only meaningful in the context of that greedy-pruning loop and is not + intended for standalone use. + + Background: BONSAI pruning evaluates a candidate dimension-by-dimension + by setting one dimension at a time to its target-point value. Each row + of ``candidates`` is one such trial -- the dense candidate with the + dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``. + Under linear constraints, swapping a single dimension to the target + typically violates the constraints; rather than discarding the trial + (the prior behavior), we adjust the *other* free dimensions to recover + feasibility while keeping the swapped dimension and all previously + pruned dimensions pinned. 
Trials whose pins make the constraint system + infeasible -- and the rare case where projection succeeds but the + result still violates constraints -- are dropped from the filtered + ``(candidates, indices)`` pair returned to the caller. + + Args: + candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row + per single-dimension prune attempt for the current BONSAI + iteration). + indices: A ``b``-dim tensor indicating which dimension was pruned + in each batch element. + target_point: A ``d``-dim tensor of target values for pruning. + pruned_dims: Set of dimension indices already pruned in prior + greedy iterations (to be kept pinned during projection). + bounds: A ``2 x d``-dim tensor of lower and upper bounds. + inequality_constraints: Inequality constraints in BoTorch format. + equality_constraints: Equality constraints in BoTorch format. + fixed_features: A map ``{feature_index: value}`` from the caller. + These dimensions are excluded from pruning at the outer loop and + must also be pinned during projection so SLSQP cannot adjust + them while satisfying the constraints. Without this, fixed + features could be silently altered. + + Returns: + A two-element tuple of the projected, feasible ``(candidates, indices)``. + """ + # Pre-compute which dims participate in any constraint, and check whether + # any constraint is inter-point (2D index tensor). Inter-point constraints + # apply across the q-batch, but each row here is a single-candidate prune + # attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate + # inter-point constraints on a 1 x d input. Fall back to the original + # filter-only behavior in that case. 
+ constrained_dims: set[int] = set() + has_interpoint_constraint = False + for constraints in (inequality_constraints, equality_constraints): + if constraints is not None: + for c_indices, _, _ in constraints: + if c_indices.dim() == 1: + constrained_dims.update(c_indices.tolist()) + else: + constrained_dims.update(c_indices[:, -1].tolist()) + has_interpoint_constraint = True + if has_interpoint_constraint: + return _remove_infeasible_candidates( + candidates=candidates, + indices=indices, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) + + # Build fixed_features for previously pruned dims and the caller's + # fixed_features (both shared across all candidates in this iteration). + prev_fixed: dict[int, float] = {k: target_point[k].item() for k in pruned_dims} + if fixed_features is not None: + prev_fixed.update(fixed_features) + + feasible_mask = torch.ones(candidates.shape[0], dtype=torch.bool) + result = candidates.clone() + + for i in range(candidates.shape[0]): + j: int = int(indices[i].item()) + # If the pruned dim doesn't participate in any constraint, + # pruning it can't violate anything — skip projection. + if j not in constrained_dims: + continue + # Pin the currently pruned dim, all previously pruned dims, and the + # caller's fixed features. + fixed: dict[int, float | Tensor] = { + j: float(target_point[j].item()), + **prev_fixed, + } + try: + projected = project_to_feasible_space_via_slsqp( + X=candidates[i], # 1 x d + bounds=bounds, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + fixed_features=fixed, + ) + result[i] = projected + except CandidateGenerationError: + # Pin makes the system infeasible — mark for removal. + # The post-projection feasibility check below is the safety net + # for any candidates that project but still violate constraints. + feasible_mask[i] = False + + # Final safety-net feasibility check after projection. 
+ if feasible_mask.any(): + is_feasible = evaluate_feasibility( + X=result[feasible_mask], + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) + # Map back to the full mask. + feasible_subset_indices = feasible_mask.nonzero(as_tuple=True)[0] + for idx, feas in zip(feasible_subset_indices, is_feasible): + if not feas: + feasible_mask[idx] = False + + return result[feasible_mask], indices[feasible_mask] diff --git a/ax/generators/torch/tests/test_acquisition.py b/ax/generators/torch/tests/test_acquisition.py index d18ed99200e..c4232c21f7f 100644 --- a/ax/generators/torch/tests/test_acquisition.py +++ b/ax/generators/torch/tests/test_acquisition.py @@ -1831,6 +1831,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None: candidates=candidates, search_space_digest=search_space_digest, inequality_constraints=inequality_constraints, + bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]), ) self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.2, 0.8]]))) self.assertTrue(torch.equal(pruned_values, torch.tensor([0.91]))) @@ -1848,6 +1849,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None: inequality_constraints=[ (torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.5) ], + bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]), ) # No pruning: setting either dim to 0.2 gives sum=1.0 < 1.5 (infeasible) self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.8, 0.8]]))) @@ -2055,6 +2057,7 @@ def test_prune_irrelevant_parameters_with_constraints_exact_values(self) -> None 1.0, ) ], + bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]), ) # Only dimension 0 should be pruned @@ -2062,6 +2065,128 @@ def test_prune_irrelevant_parameters_with_constraints_exact_values(self) -> None self.assertTrue(torch.equal(pruned_candidates, expected_candidate)) self.assertTrue(torch.equal(pruned_values, torch.tensor([1.0]))) + def test_prune_irrelevant_parameters_with_equality_constraints(self) 
-> None: + # Test pruning with an equality constraint (x1 + x2 + x3 = 1). + # When a dimension is pruned to its target, the remaining dims should + # be projected onto the equality constraint hyperplane. + search_space_digest = SearchSpaceDigest( + feature_names=["x1", "x2", "x3"], + bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)], + ) + target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3]) + acq = Acquisition( + surrogate=self.surrogate, + search_space_digest=search_space_digest, + torch_opt_config=dataclasses.replace( + self.torch_opt_config, + pruning_target_point=target_point, + ), + botorch_acqf_class=DummyAcquisitionFunction, + ) + mock_acqf = Mock() + mock_acqf._log = False + acq.acqf = mock_acqf + acq._instantiate_acquisition = Mock() + + # Candidate that satisfies x1 + x2 + x3 = 1. + candidates = torch.tensor([[0.5, 0.3, 0.2]]) + # Equality constraint: x1 + x2 + x3 = 1 + equality_constraints = [ + (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0) + ] + bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + + mock_evaluate = Mock( + side_effect=[ + torch.tensor([0.0]), # baseline af val + torch.tensor([1.0]), # dense af val + # After pruning dim 0 to 1/3 and projecting, the candidate + # still satisfies x1+x2+x3=1. Two pruning candidates + # (dim 1 and dim 2) survive projection. + torch.tensor([0.95, 0.90]), # pruned af vals + torch.tensor([0.93]), # second round pruned af val + ] + ) + acq.evaluate = mock_evaluate + + pruned_candidates, pruned_values = acq._prune_irrelevant_parameters( + candidates=candidates, + search_space_digest=search_space_digest, + equality_constraints=equality_constraints, + bounds=bounds, + ) + # Verify that pruning occurred and the result satisfies the constraint. 
+ self.assertEqual(pruned_candidates.shape[-1], 3) + for i in range(pruned_candidates.shape[0]): + self.assertAlmostEqual( + pruned_candidates[i].sum().item(), + 1.0, + places=4, + ) + + def test_prune_irrelevant_parameters_fixed_features_pinned_in_projection( + self, + ) -> None: + # When constraints are active and `fixed_features` is provided, the + # SLSQP projection must pin the fixed dims so they cannot be silently + # adjusted to satisfy the constraint. + search_space_digest = SearchSpaceDigest( + feature_names=["x1", "x2", "x3"], + bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)], + ) + target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3]) + acq = Acquisition( + surrogate=self.surrogate, + search_space_digest=search_space_digest, + torch_opt_config=dataclasses.replace( + self.torch_opt_config, + pruning_target_point=target_point, + ), + botorch_acqf_class=DummyAcquisitionFunction, + ) + mock_acqf = Mock() + mock_acqf._log = False + acq.acqf = mock_acqf + acq._instantiate_acquisition = Mock() + + # Candidate that satisfies x1 + x2 + x3 = 1 with x1 fixed at 0.6. + candidates = torch.tensor([[0.6, 0.3, 0.1]]) + # Equality constraint: x1 + x2 + x3 = 1 + equality_constraints = [ + (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0) + ] + bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + # Fix x1 to its current value. Pruning dim 1 (x2 -> 1/3) breaks the + # constraint; without pinning x1 in the projection, SLSQP could move + # x1 to recover feasibility, silently overwriting the fixed value. + fixed_features = {0: 0.6} + + mock_evaluate = Mock( + side_effect=[ + torch.tensor([0.0]), # baseline af val + torch.tensor([1.0]), # dense af val + # Only dim 1 and dim 2 are eligible (dim 0 is fixed). Both + # pruning attempts should yield projected candidates that + # keep x1 == 0.6 exactly. 
+ torch.tensor([0.95, 0.90]), # pruned af vals + torch.tensor([0.93]), # second-round pruned af val + ] + ) + acq.evaluate = mock_evaluate + + pruned_candidates, _ = acq._prune_irrelevant_parameters( + candidates=candidates, + search_space_digest=search_space_digest, + equality_constraints=equality_constraints, + bounds=bounds, + fixed_features=fixed_features, + ) + # The fixed feature must be preserved exactly through projection, + # and the constraint must still be satisfied. + self.assertEqual(pruned_candidates.shape[-1], 3) + self.assertAlmostEqual(pruned_candidates[0, 0].item(), 0.6, places=6) + self.assertAlmostEqual(pruned_candidates[0].sum().item(), 1.0, places=4) + def test_prune_irrelevant_parameters_with_task_and_fidelity_features(self) -> None: # Test pruning with both task and fidelity features that should be excluded # from pruning From 286fd3f1c6704048f21b32c20d5fc274a8008733 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Fri, 8 May 2026 09:32:10 -0700 Subject: [PATCH 2/2] Enforce equality constraints in random generator rejection sampling (#5182) Summary: Random generators (Sobol, etc.) were not respecting equality constraints during candidate generation. Two fixes: 1. When equality constraints are present, skip rejection sampling entirely and go straight to polytope sampling. Unconstrained random samples have probability zero of satisfying continuous equality constraints, so rejection sampling would always exhaust max_draws. 2. Add `equality_constraints` parameter to `rejection_sample` and `check_param_constraints` so that post-rounding feasibility checks also validate equality constraints (important when the polytope sampler fallback uses rejection_sample for deduplication). 
Reviewed By: esantorella Differential Revision: D100256485 --- ax/generators/random/base.py | 22 +++++++++++++---- ax/generators/utils.py | 48 +++++++++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/ax/generators/random/base.py b/ax/generators/random/base.py index cc7d29a2e86..48d3560d256 100644 --- a/ax/generators/random/base.py +++ b/ax/generators/random/base.py @@ -154,6 +154,13 @@ def gen( max_draws = model_gen_options.get("max_rs_draws", DEFAULT_MAX_RS_DRAWS) max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float))) try: + # With equality constraints, unconstrained sampling has probability + # zero of producing feasible points, so skip straight to polytope + # sampling. + if equality_constraints is not None: + raise SearchSpaceExhausted( + "Equality constraints require polytope sampling." + ) # Always rejection sample, but this only rejects if there are # constraints or actual duplicates and deduplicate is specified. # If rejection sampling fails, fall back to polytope sampling. @@ -184,11 +191,15 @@ def gen( num_generated = ( len(generated_points) if generated_points is not None else 0 ) - interior_point = ( # A feasible point of shape `d x 1`. - torch.from_numpy(generated_points[-1].reshape((-1, 1))).double() - if generated_points is not None - else None - ) + # Use a previously generated point as the interior point + # hint, but only if it's likely feasible. When equality + # constraints are present, previous points (generated + # without those constraints) won't satisfy them. 
+ interior_point: torch.Tensor | None = None + if generated_points is not None and equality_constraints is None: + interior_point = torch.from_numpy( + generated_points[-1].reshape((-1, 1)) + ).double() kwargs = {"n_burnin": 100, "n_thinning": 20} kwargs.update(self.polytope_sampler_kwargs) polytope_sampler: HitAndRunPolytopeSampler = HitAndRunPolytopeSampler( @@ -229,6 +240,7 @@ def gen_polytope_sampler( fixed_features=fixed_features, rounding_func=rounding_func, existing_points=generated_points, + equality_constraints=equality_constraints, ) else: raise e diff --git a/ax/generators/utils.py b/ax/generators/utils.py index 897a1507063..f23ea25a3fc 100644 --- a/ax/generators/utils.py +++ b/ax/generators/utils.py @@ -68,6 +68,7 @@ def rejection_sample( fixed_features: dict[int, float] | None = None, rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None, existing_points: npt.NDArray | None = None, + equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None, ) -> tuple[npt.NDArray, int]: """Rejection sample in parameter space. @@ -96,6 +97,9 @@ def rejection_sample( existing_points: A set of previously generated points to use for deduplication. These should be provided in the parameter space model operates in. + equality_constraints: A tuple of (A, b). For k equality constraints + on d-dimensional x, A is (k x d) and b is (k x 1) such that + A x = b. Returns: 2-element tuple containing the generated points and the number of @@ -124,9 +128,26 @@ def rejection_sample( )[0] # Check parameter constraints, always in raw transformed space. 
+ has_constraints = ( + linear_constraints is not None or equality_constraints is not None + ) if linear_constraints is not None: all_constraints_satisfied, _ = check_param_constraints( - linear_constraints=linear_constraints, point=point + linear_constraints=linear_constraints, + point=point, + equality_constraints=equality_constraints, + ) + elif equality_constraints is not None: + # No inequality constraints but have equality constraints. + # Use a dummy (0, d) inequality matrix so check_param_constraints works. + dummy_ineq = ( + np.zeros((0, len(point))), + np.zeros((0, 1)), + ) + all_constraints_satisfied, _ = check_param_constraints( + linear_constraints=dummy_ineq, + point=point, + equality_constraints=equality_constraints, ) else: all_constraints_satisfied = True @@ -140,9 +161,15 @@ def rejection_sample( # Re-check constraints after rounding for discrete parameters # (e.g. numerical choice parameters) because rounding can push values # in a direction that violates sum constraints. - if linear_constraints is not None: + if has_constraints: + ineq = linear_constraints or ( + np.zeros((0, len(point))), + np.zeros((0, 1)), + ) all_constraints_satisfied, _ = check_param_constraints( - linear_constraints=linear_constraints, point=point + linear_constraints=ineq, + point=point, + equality_constraints=equality_constraints, ) if not all_constraints_satisfied: attempted_draws += 1 @@ -228,6 +255,7 @@ def add_fixed_features( def check_param_constraints( linear_constraints: tuple[npt.NDArray, npt.NDArray], point: npt.NDArray, + equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None, ) -> tuple[bool, npt.NDArray]: """Check if a point satisfies parameter constraints. @@ -236,6 +264,9 @@ def check_param_constraints( d-dimensional x, A is (k x d) and b is (k x 1) such that A x <= b. point: A candidate point in d-dimensional space, as a (1 x d) matrix. + equality_constraints: A tuple of (A, b). 
For k equality constraints on + d-dimensional x, A is (k x d) and b is (k x 1) such that + A x = b. Returns: 2-element tuple containing @@ -246,6 +277,17 @@ def check_param_constraints( constraints_satisfied = ( linear_constraints[0] @ np.expand_dims(point, axis=1) <= linear_constraints[1] ) + if equality_constraints is not None: + eq_satisfied = ( + np.abs( + equality_constraints[0] @ np.expand_dims(point, axis=1) + - equality_constraints[1] + ) + <= 1e-8 + ) + constraints_satisfied = np.concatenate( + [constraints_satisfied, eq_satisfied], axis=0 + ) if np.all(constraints_satisfied): return True, np.array([]) else: