Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
TParameterization,
TParamValue,
)
from ax.core.utils import compute_metric_availability, MetricAvailability
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.early_stopping.utils import estimate_early_stopping_savings
from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION
Expand Down Expand Up @@ -1407,9 +1408,14 @@ def fit_model(self) -> None:
raise DataRequiredError(
"At least one trial must be completed with data to fit a model."
)
if self.experiment.lookup_data(trial_indices=completed_trial_indices).df.empty:
availability = compute_metric_availability(
experiment=self.experiment,
trial_indices=completed_trial_indices,
)
if not any(v == MetricAvailability.COMPLETE for v in availability.values()):
raise DataRequiredError(
"At least one completed trial must have data attached to fit a model."
"At least one completed trial must have data for all required "
"metrics to fit a model."
)
self.generation_strategy.fit(experiment=self.experiment)

Expand Down
22 changes: 22 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,6 +2581,28 @@ def test_get_model_predictions_no_next_trial_no_completed_trial(self) -> None:
):
ax_client.get_model_predictions()

def test_fit_model_partial_metric_data(self) -> None:
"""Test that fit_model raises when completed trials only have data for
a subset of required metrics."""
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
# Attach a trial and complete it with data for only one of the two
# required metrics (test_metric1 is the objective, test_metric2 is the
# constraint). We bypass complete_trial() because it marks the trial as
# failed when required metrics are missing. Instead, we attach data and
# mark completed directly, simulating the case where the data check at
# completion time is skipped (e.g., data is attached asynchronously).
trial: TParameterization = {"x1": 0.1, "x2": 0.1}
_parameters, trial_index = ax_client.attach_trial(trial)
ax_trial = ax_client.get_trial(trial_index)
ax_trial.update_trial_data(raw_data={"test_metric1": (1.0, 0.0)})
ax_trial.mark_completed()

with self.assertRaisesRegex(
DataRequiredError,
"At least one completed trial must have data for all required metrics",
):
ax_client.fit_model()

def test_get_model_predictions_no_next_trial_filtered(self) -> None:
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
_attach_completed_trials(ax_client)
Expand Down
Loading