From 649197f88ca0845fc8cf7ecdec4e2aca6193ada6 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 30 Mar 2026 14:21:21 -0700 Subject: [PATCH] Use compute_metric_availability in AxClient.fit_model Summary: The previous data check in `fit_model` only verified that the DataFrame was non-empty (`lookup_data().df.empty`), which would pass even when data existed for only a subset of required metrics. This could allow model fitting to proceed with incomplete data, leading to downstream errors. Replace the manual check with `compute_metric_availability()` from `ax.core.utils`, which inspects per-trial metric coverage against the optimization config's required metrics. `fit_model` now raises `DataRequiredError` unless at least one completed trial has data for **all** required metrics (`MetricAvailability.COMPLETE`). Reviewed By: saitcakmak Differential Revision: D98208718 --- ax/service/ax_client.py | 10 ++++++++-- ax/service/tests/test_ax_client.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 4a15b7f2ded..e34b4303507 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -40,6 +40,7 @@ TParameterization, TParamValue, ) +from ax.core.utils import compute_metric_availability, MetricAvailability from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.early_stopping.utils import estimate_early_stopping_savings from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION @@ -1407,9 +1408,14 @@ def fit_model(self) -> None: raise DataRequiredError( "At least one trial must be completed with data to fit a model." ) - if self.experiment.lookup_data(trial_indices=completed_trial_indices).df.empty: + availability = compute_metric_availability( + experiment=self.experiment, + trial_indices=completed_trial_indices, + ) + if not any(v == MetricAvailability.COMPLETE for v in availability.values()): raise DataRequiredError( - "At least one completed trial must have data attached to fit a model." + "At least one completed trial must have data for all required " + "metrics to fit a model." ) self.generation_strategy.fit(experiment=self.experiment) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 93fc0beb89e..507d8c6dd31 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -2581,6 +2581,28 @@ def test_get_model_predictions_no_next_trial_no_completed_trial(self) -> None: ): ax_client.get_model_predictions() + def test_fit_model_partial_metric_data(self) -> None: + """Test that fit_model raises when completed trials only have data for + a subset of required metrics.""" + ax_client = _set_up_client_for_get_model_predictions_no_next_trial() + # Attach a trial and complete it with data for only one of the two + # required metrics (test_metric1 is the objective, test_metric2 is the + # constraint). We bypass complete_trial() because it marks the trial as + # failed when required metrics are missing. Instead, we attach data and + # mark completed directly, simulating the case where the data check at + # completion time is skipped (e.g., data is attached asynchronously). + trial: TParameterization = {"x1": 0.1, "x2": 0.1} + _parameters, trial_index = ax_client.attach_trial(trial) + ax_trial = ax_client.get_trial(trial_index) + ax_trial.update_trial_data(raw_data={"test_metric1": (1.0, 0.0)}) + ax_trial.mark_completed() + + with self.assertRaisesRegex( + DataRequiredError, + "At least one completed trial must have data for all required metrics", + ): + ax_client.fit_model() + def test_get_model_predictions_no_next_trial_filtered(self) -> None: ax_client = _set_up_client_for_get_model_predictions_no_next_trial() _attach_completed_trials(ax_client)