Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 191 additions & 10 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def __init__(
self._trial_type_to_runner: dict[str | None, Runner | None] = {
default_trial_type: runner
}
# Maps each trial type to the set of metric names relevant to that type.
# This is the complement of _trial_type_to_runner and is used for
# multi-type experiments where different metrics apply to different
# trial types.
self._trial_type_to_metric_names: dict[str, set[str]] = {}
# Used to keep track of whether any trials on the experiment
# specify a TTL. Since trials need to be checked for their TTL's
# expiration often, having this attribute helps avoid unnecessary
Expand Down Expand Up @@ -197,6 +202,10 @@ def __init__(
# a naming collision occurs.
for m in [*(tracking_metrics or []), *(metrics or [])]:
self._metrics[m.name] = m
if self._default_trial_type is not None:
self._trial_type_to_metric_names.setdefault(
self._default_trial_type, set()
).add(m.name)

# call setters defined below
self.status_quo = status_quo
Expand Down Expand Up @@ -585,6 +594,12 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
"but not found on experiment. Add it first with add_metric()."
)
self._optimization_config = optimization_config
resolved_trial_type = self._resolve_trial_type(None)
if resolved_trial_type is not None:
for metric_name in optimization_config.metric_names:
self._trial_type_to_metric_names.setdefault(
resolved_trial_type, set()
).add(metric_name)

@property
def is_moo_problem(self) -> bool:
Expand Down Expand Up @@ -837,7 +852,41 @@ def get_metric(self, name: str) -> Metric:
)
return self._metrics[name]

def add_metric(self, metric: Metric) -> Self:
def _resolve_trial_type(self, trial_type: str | None) -> str | None:
    """Validate an explicit trial type or fall back to the default.

    An explicitly supplied ``trial_type`` is validated against
    ``supports_trial_type`` and returned as-is. When omitted, the
    experiment's ``_default_trial_type`` is used if one is configured.
    If neither is available but this experiment already tracks
    trial-type-aware metrics (``_trial_type_to_metric_names`` is
    non-empty), resolution is ambiguous and an error is raised.

    Args:
        trial_type: The explicitly provided trial type, or ``None``.

    Returns:
        The resolved trial type, which may be ``None`` for single-type
        experiments.

    Raises:
        ValueError: If ``trial_type`` is provided but not supported, or if
            no trial type could be resolved for a multi-type experiment.
    """
    # An explicit trial type wins, but must be one the experiment supports.
    if trial_type is not None:
        if self.supports_trial_type(trial_type):
            return trial_type
        raise ValueError(f"`{trial_type}` is not a supported trial type.")
    default = self._default_trial_type
    if default is not None:
        return default
    # No type could be resolved, yet metrics here are trial-type-aware:
    # fail loudly rather than silently dropping the association.
    if self._trial_type_to_metric_names:
        raise ValueError(
            "This experiment has trial-type-aware metrics but no "
            "`trial_type` was specified and no `default_trial_type` is set. "
            "Please specify a `trial_type`."
        )
    return None

def add_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
    """Add a new metric to the experiment.

    Metrics that are not referenced by the experiment's optimization config
    are still stored on the experiment but do not drive the optimization.

    Args:
        metric: Metric to be added.
        trial_type: If provided, associates the metric with this trial type.
            When ``None`` and a ``default_trial_type`` is set, defaults to
            the default trial type.

    Returns:
        The experiment, for chaining.

    Raises:
        ValueError: If the metric already exists, the trial type is not
            supported, or trial types are in use but none could be resolved.
    """
    name = metric.name
    if name in self._metrics:
        raise ValueError(
            f"Metric `{name}` already defined on experiment. "
            "Use `update_metric` to update an existing metric definition."
        )
    resolved = self._resolve_trial_type(trial_type)
    if resolved is not None:
        # Record the trial-type association for multi-type experiments.
        names_for_type = self._trial_type_to_metric_names.setdefault(resolved, set())
        names_for_type.add(name)
    self._metrics[name] = metric
    return self

def add_tracking_metric(
    self,
    metric: Metric,
    trial_type: str | None = None,
    canonical_name: str | None = None,
) -> Self:
    """*Deprecated.* Use ``add_metric`` instead.

    Args:
        metric: Metric to be added.
        trial_type: Optional trial type, forwarded to ``add_metric``.
        canonical_name: Accepted for backward compatibility only; canonical
            names are no longer tracked and this value is ignored.
    """
    warnings.warn(
        "add_tracking_metric is deprecated. Use add_metric instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if canonical_name is not None:
        # Silently discarding this argument would hide a no-op from callers.
        warnings.warn(
            "`canonical_name` is no longer supported and will be ignored.",
            DeprecationWarning,
            stacklevel=2,
        )
    return self.add_metric(metric, trial_type=trial_type)

def add_tracking_metrics(
    self,
    metrics: list[Metric],
    metrics_to_trial_types: dict[str, str] | None = None,
    canonical_names: dict[str, str] | None = None,
) -> Experiment:
    """*Deprecated.* Use ``add_metric`` instead.

    Args:
        metrics: Metrics to add.
        metrics_to_trial_types: Optional map from metric name to the trial
            type the metric should be associated with.
        canonical_names: Optional map from metric name to canonical name,
            forwarded to ``add_tracking_metric``.
    """
    warnings.warn(
        "add_tracking_metrics is deprecated. Use add_metric instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    trial_types_by_name = metrics_to_trial_types or {}
    canonical_by_name = canonical_names or {}
    for m in metrics:
        self.add_tracking_metric(
            metric=m,
            trial_type=trial_types_by_name.get(m.name),
            canonical_name=canonical_by_name.get(m.name),
        )
    return self

def update_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
    """Redefine a metric that already exists on the experiment.

    Args:
        metric: New metric definition.
        trial_type: If provided, reassociates the metric with this trial
            type. When ``None``, keeps the metric's existing trial type.

    Returns:
        The experiment, for chaining.

    Raises:
        ValueError: If no metric with this name exists on the experiment,
            or the given trial type is not supported.
    """
    name = metric.name
    if name not in self._metrics:
        raise ValueError(f"Metric `{name}` doesn't exist on experiment.")
    if trial_type is not None:
        resolved = self._resolve_trial_type(trial_type)
        # Detach the metric from whichever trial type currently owns it,
        # then attach it to the requested one.
        for owned_names in self._trial_type_to_metric_names.values():
            owned_names.discard(name)
        self._trial_type_to_metric_names.setdefault(resolved, set()).add(name)
    self._metrics[name] = metric
    return self

def update_tracking_metric(
    self,
    metric: Metric,
    trial_type: str | None = None,
    canonical_name: str | None = None,
) -> Experiment:
    """*Deprecated.* Use ``update_metric`` instead.

    Args:
        metric: New metric definition.
        trial_type: Optional trial type, forwarded to ``update_metric``.
        canonical_name: Accepted for backward compatibility only; canonical
            names are no longer tracked and this value is ignored.
    """
    warnings.warn(
        "update_tracking_metric is deprecated. Use update_metric instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if canonical_name is not None:
        # Silently discarding this argument would hide a no-op from callers.
        warnings.warn(
            "`canonical_name` is no longer supported and will be ignored.",
            DeprecationWarning,
            stacklevel=2,
        )
    return self.update_metric(metric, trial_type=trial_type)

def remove_metric(self, metric_name: str) -> Self:
"""Remove a metric from the experiment.
Expand All @@ -914,6 +1007,9 @@ def remove_metric(self, metric_name: str) -> Self:
f"Metric `{metric_name}` is referenced by the optimization config "
"and cannot be removed. Update the optimization config first."
)
# Clean up _trial_type_to_metric_names
for names in self._trial_type_to_metric_names.values():
names.discard(metric_name)
del self._metrics[metric_name]
return self

Expand Down Expand Up @@ -1928,6 +2024,85 @@ def default_trial_type(self) -> str | None:
"""
return self._default_trial_type

@property
def trial_type_to_metric_names(self) -> dict[str, set[str]]:
    """Map from trial type to the set of metric names relevant to that
    type.

    Returns a copy: both the mapping and the sets it contains are freshly
    constructed, so mutating the returned value cannot corrupt the
    experiment's internal bookkeeping.
    """
    # Copy the sets as well as the dict; returning the internal sets would
    # let callers mutate experiment state through this property.
    return {
        trial_type: set(names)
        for trial_type, names in self._trial_type_to_metric_names.items()
    }

@property
def metric_to_trial_type(self) -> dict[str, str]:
    """Map each metric name to its associated trial type.

    Derived from ``_trial_type_to_metric_names``. When both a
    ``default_trial_type`` and an ``optimization_config`` are present,
    optimization-config metrics are pinned to the default trial type,
    overriding any other association.
    """
    mapping = {
        name: trial_type
        for trial_type, names in self._trial_type_to_metric_names.items()
        for name in names
    }
    opt_config = self._optimization_config
    default = self._default_trial_type
    if opt_config is not None and default is not None:
        # Optimization metrics always belong to the default trial type.
        mapping.update((metric_name, default) for metric_name in opt_config.metric_names)
    return mapping

def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
    """Return the metrics associated with a given trial type.

    Args:
        trial_type: The trial type to look up metrics for.

    Returns:
        The metrics registered under ``trial_type``. Names with no
        corresponding entry in ``_metrics`` are skipped.

    Raises:
        ValueError: If the trial type is not supported.
    """
    if not self.supports_trial_type(trial_type):
        raise ValueError(f"Trial type `{trial_type}` is not supported.")
    metric_names = self._trial_type_to_metric_names.get(trial_type, set())
    metrics = []
    for name in metric_names:
        # A name may linger after its metric was removed; skip those.
        if name in self._metrics:
            metrics.append(self._metrics[name])
    return metrics

@property
def default_trials(self) -> set[int]:
    """Indices of all trials whose type is the default trial type."""
    default = self.default_trial_type
    indices: set[int] = set()
    for index, trial in self.trials.items():
        if trial.trial_type == default:
            indices.add(index)
    return indices

def add_trial_type(self, trial_type: str, runner: Runner | None = None) -> Self:
    """Add a new trial type to be supported by this experiment.

    Args:
        trial_type: The new trial type to be added.
        runner: The default runner for trials of this type, or ``None`` if
            trials of this type have no default runner.

    Returns:
        The experiment, for chaining.

    Raises:
        ValueError: If the trial type is already supported.
    """
    if self.supports_trial_type(trial_type):
        raise ValueError(f"Experiment already contains trial_type `{trial_type}`")

    # Register the trial type even when no runner is provided; otherwise the
    # type would never appear in `_trial_type_to_runner`, and for multi-type
    # experiments `supports_trial_type` would keep rejecting it, making this
    # method a silent no-op.
    self._trial_type_to_runner[trial_type] = runner

    return self

def update_runner(self, trial_type: str, runner: Runner) -> Self:
    """Update the default runner for an existing trial type.

    Args:
        trial_type: The trial type whose runner should be updated.
        runner: The new runner for trials of this type.

    Returns:
        The experiment, for chaining.

    Raises:
        ValueError: If the trial type is not supported by this experiment.
    """
    if not self.supports_trial_type(trial_type):
        raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
    self._trial_type_to_runner[trial_type] = runner
    # Keep the experiment-level default runner in sync only when the default
    # trial type's runner is the one being updated. Unconditionally setting
    # `self._runner` would clobber the default runner whenever any other
    # trial type was updated.
    if trial_type == self._default_trial_type:
        self._runner = runner
    return self

def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
"""The default runner to use for a given trial type.

Expand All @@ -1942,14 +2117,20 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
def supports_trial_type(self, trial_type: str | None) -> bool:
    """Whether this experiment allows trials of the given type.

    For experiments with a ``default_trial_type`` (multi-type experiments),
    only trial types registered in ``_trial_type_to_runner`` are supported.
    For single-type experiments, ``None`` is always supported, along with
    ``SHORT_RUN``, ``LONG_RUN``, and ``LILO_LABELING`` for backward
    compatibility with generation strategies that use those trial types.
    """
    if self._default_trial_type is not None:
        # Multi-type experiments accept only explicitly registered types.
        return trial_type in self._trial_type_to_runner
    if trial_type is None:
        return True
    if trial_type in (Keys.SHORT_RUN, Keys.LONG_RUN, Keys.LILO_LABELING):
        return True
    return trial_type in self._trial_type_to_runner

def attach_trial(
Expand Down
Loading
Loading