diff --git a/ax/api/client.py b/ax/api/client.py index ee2d98882bb..b8653637de2 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -571,13 +571,19 @@ def attach_data( progression: int | None = None, ) -> None: """ - Attach data without indicating the trial is complete. Missing metrics are, + Attach data without indicating the trial is complete. Missing metrics are allowed, and unexpected metric values will be added to the Experiment as tracking metrics. Saves to database on completion if ``storage_config`` is present. """ + # Auto-register any metrics present in raw_data but not yet on the + # experiment as tracking metrics, matching the docstring contract. + extra_metrics = set(raw_data.keys()) - set(self._experiment.metrics.keys()) + if extra_metrics: + self.configure_tracking_metrics(metric_names=list(extra_metrics)) + # If no progression is provided assume the data is not timeseries-like and # set step=NaN data_with_progression = [ diff --git a/ax/api/tests/test_client.py b/ax/api/tests/test_client.py index 9d0fd707e22..cc3dceb882d 100644 --- a/ax/api/tests/test_client.py +++ b/ax/api/tests/test_client.py @@ -41,7 +41,7 @@ from ax.core.trial_status import TrialStatus from ax.core.utils import compute_metric_availability, MetricAvailability from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy -from ax.exceptions.core import UnsupportedError, UserInputError +from ax.exceptions.core import UnsupportedError from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.storage.sqa_store.with_db_settings_base import ( _save_generation_strategy_to_db_if_possible, @@ -656,21 +656,14 @@ def test_attach_data(self) -> None: ) # With extra metrics - # Try and attach data for a metric that doesn't exist - with self.assertRaisesRegex( - UserInputError, - "Unable to find the metric signature for one or more metrics.", - ): - client.attach_data( - trial_index=trial_index, - raw_data={"foo": 1.0, "bar": 2.0}, - ) - - client.configure_metrics(metrics=[DummyMetric(name="bar")]) + # Extraneous metrics should be auto-registered as tracking metrics + self.assertNotIn("bar", client._experiment.metrics) client.attach_data( trial_index=trial_index, raw_data={"foo": 1.0, "bar": 2.0}, ) + self.assertIn("bar", client._experiment.metrics) + self.assertIn("bar", [m.name for m in client._experiment.tracking_metrics]) self.assertEqual( client._experiment.trials[trial_index].status, TrialStatus.RUNNING,