Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,19 @@ def attach_data(
progression: int | None = None,
) -> None:
"""
Attach data without indicating the trial is complete. Missing metrics are,
Attach data without indicating the trial is complete. Missing metrics are
allowed, and unexpected metric values will be added to the Experiment as
tracking metrics.

Saves to database on completion if ``storage_config`` is present.
"""

# Auto-register any metrics present in raw_data but not yet on the
# experiment as tracking metrics, matching the docstring contract.
extra_metrics = set(raw_data.keys()) - set(self._experiment.metrics.keys())
if extra_metrics:
self.configure_tracking_metrics(metric_names=list(extra_metrics))

# If no progression is provided assume the data is not timeseries-like and
# set step=NaN
data_with_progression = [
Expand Down
17 changes: 5 additions & 12 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from ax.core.trial_status import TrialStatus
from ax.core.utils import compute_metric_availability, MetricAvailability
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UnsupportedError
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.with_db_settings_base import (
_save_generation_strategy_to_db_if_possible,
Expand Down Expand Up @@ -656,21 +656,14 @@ def test_attach_data(self) -> None:
)

# With extra metrics
# Try and attach data for a metric that doesn't exist
with self.assertRaisesRegex(
UserInputError,
"Unable to find the metric signature for one or more metrics.",
):
client.attach_data(
trial_index=trial_index,
raw_data={"foo": 1.0, "bar": 2.0},
)

client.configure_metrics(metrics=[DummyMetric(name="bar")])
# Extraneous metrics should be auto-registered as tracking metrics
self.assertNotIn("bar", client._experiment.metrics)
client.attach_data(
trial_index=trial_index,
raw_data={"foo": 1.0, "bar": 2.0},
)
self.assertIn("bar", client._experiment.metrics)
self.assertIn("bar", [m.name for m in client._experiment.tracking_metrics])
self.assertEqual(
client._experiment.trials[trial_index].status,
TrialStatus.RUNNING,
Expand Down
Loading