diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py index 021a4217f2d..304e5fc5559 100644 --- a/ax/modelbridge/tests/test_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_modelbridge.py @@ -340,6 +340,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None: ma._model_space = get_branin_search_space() ma._optimization_config = None ma.outcomes = ["test_metric"] + ma._fit_out_of_design = False with self.assertRaisesRegex(ValueError, "optimization_config"): ma.evaluate_acquisition_function( diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 046b0f43845..1c466b4999c 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -871,6 +871,7 @@ def _get_transformed_model_gen_args( opt_config_metrics=opt_config_metrics, is_moo=optimization_config.is_moo_problem, risk_measure=risk_measure, + fit_out_of_design=self._fit_out_of_design, ) return search_space_digest, torch_opt_config diff --git a/ax/models/tests/test_torch_utils.py b/ax/models/tests/test_torch_utils.py index cbf5cb86ec8..f5d3972c434 100644 --- a/ax/models/tests/test_torch_utils.py +++ b/ax/models/tests/test_torch_utils.py @@ -81,6 +81,28 @@ def _to_obs_set(X: torch.Tensor) -> Set[Tuple[float]]: expected = Xs[0] self.assertEqual(_to_obs_set(expected), _to_obs_set(not_none(X_observed))) + # Out of design observations are filtered out + Xs = [torch.tensor([[2.0, 3.0], [3.0, 4.0]])] + _, X_observed = _get_X_pending_and_observed( + Xs=Xs, + objective_weights=objective_weights, + bounds=bounds, + fixed_features=fixed_features, + fit_out_of_design=False, + ) + self.assertIsNone(X_observed) + + # Keep out of design observations + _, X_observed = _get_X_pending_and_observed( + Xs=Xs, + objective_weights=objective_weights, + bounds=bounds, + fixed_features=fixed_features, + fit_out_of_design=True, + ) + expected = Xs[0] + self.assertEqual(_to_obs_set(expected), _to_obs_set(not_none(X_observed))) + @patch( f"{get_botorch_objective_and_transform.__module__}.get_infeasible_cost", return_value=1.0, diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 7ea4cf5ba62..89dab01a145 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -353,6 +353,7 @@ def gen( outcome_constraints=torch_opt_config.outcome_constraints, linear_constraints=torch_opt_config.linear_constraints, fixed_features=torch_opt_config.fixed_features, + fit_out_of_design=torch_opt_config.fit_out_of_design, ) model = self.model # subset model only to the outcomes we need for the optimization 357 diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py index bc0099de319..414539aaa87 100644 --- a/ax/models/torch/utils.py +++ b/ax/models/torch/utils.py @@ -99,6 +99,7 @@ def _filter_X_observed( outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None, linear_constraints: Optional[Tuple[Tensor, Tensor]] = None, fixed_features: Optional[Dict[int, float]] = None, + fit_out_of_design: bool = False, ) -> Optional[Tensor]: r"""Filter input points to those appearing in objective or constraints. @@ -115,6 +116,8 @@ def _filter_X_observed( A x <= b. (Not used by single task models) fixed_features: A map {feature_index: value} for features that should be fixed to a particular value during generation. + fit_out_of_design: If specified, all training data is returned. + Otherwise, only in design points are returned. Returns: Tensor: All points that are feasible and appear in the objective or @@ -126,13 +129,14 @@ def _filter_X_observed( objective_weights=objective_weights, outcome_constraints=outcome_constraints, ) - # Filter to those that satisfy constraints. - X_obs = filter_constraints_and_fixed_features( - X=X_obs, - bounds=bounds, - linear_constraints=linear_constraints, - fixed_features=fixed_features, - ) + if not fit_out_of_design: + # Filter to those that satisfy constraints. + X_obs = filter_constraints_and_fixed_features( + X=X_obs, + bounds=bounds, + linear_constraints=linear_constraints, + fixed_features=fixed_features, + ) if len(X_obs) > 0: return torch.as_tensor(X_obs) # please the linter @@ -145,6 +149,7 @@ def _get_X_pending_and_observed( outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None, linear_constraints: Optional[Tuple[Tensor, Tensor]] = None, fixed_features: Optional[Dict[int, float]] = None, + fit_out_of_design: bool = False, ) -> Tuple[Optional[Tensor], Optional[Tensor]]: r"""Get pending and observed points. @@ -167,6 +172,8 @@ def _get_X_pending_and_observed( A x <= b. (Not used by single task models) fixed_features: A map {feature_index: value} for features that should be fixed to a particular value during generation. + fit_out_of_design: If specified, all training data is returned. + Otherwise, only in design points are returned. Returns: Tensor: Pending points that are feasible and appear in the objective or @@ -192,6 +199,7 @@ def _get_X_pending_and_observed( bounds=bounds, linear_constraints=linear_constraints, fixed_features=fixed_features, + fit_out_of_design=fit_out_of_design, ) if filtered_X_observed is not None and len(filtered_X_observed) > 0: return X_pending, filtered_X_observed @@ -201,6 +209,7 @@ def _get_X_pending_and_observed( objective_weights=objective_weights, bounds=bounds, outcome_constraints=outcome_constraints, + fit_out_of_design=fit_out_of_design, ) return X_pending, unfiltered_X_observed diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py index 4991b4e04de..7822a62d920 100644 --- a/ax/models/torch_base.py +++ b/ax/models/torch_base.py @@ -88,6 +88,7 @@ class TorchOptConfig: opt_config_metrics: Optional[Dict[str, Metric]] = None is_moo: bool = False risk_measure: Optional[RiskMeasureMCObjective] = None + fit_out_of_design: bool = False @dataclass(frozen=True)