From 1b7ebbc76972159b679d6326a37ba15df6578ab2 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 24 Mar 2026 06:06:26 -0700 Subject: [PATCH 1/3] Add _trial_type_to_metric_names to base Experiment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This is Phase 1 of moving MultiTypeExperiment features into the base Experiment class, enabling eventual deprecation of MultiTypeExperiment. Adds `_trial_type_to_metric_names: dict[str, set[str]]` to Experiment — a mapping from trial type to the set of metric names relevant to that type. This is the natural complement to the existing `_trial_type_to_runner` dict. Along with it, adds the following properties and methods to Experiment: - `trial_type_to_metric_names`: read-only property (shallow copy) - `metric_to_trial_type`: computed inverse mapping, with optimization config metrics pinned to `default_trial_type` - `metrics_for_trial_type(trial_type)`: returns Metric objects for a given trial type - `default_trials`: returns trial indices matching the default type MultiTypeExperiment is updated to populate `_trial_type_to_metric_names` alongside `_metric_to_trial_type` in all mutation paths (init, optimization_config setter, add/update/remove tracking metric). The redundant MTE overrides for `metric_to_trial_type`, `metrics_for_trial_type`, `default_trials`, and `default_trial_type` are removed — they are now inherited from the base class. The JSON decoder is updated to rebuild `_trial_type_to_metric_names` from `_metric_to_trial_type` during deserialization for backward compatibility. Differential Revision: D94970662 --- ax/core/experiment.py | 56 ++++++++++++++++++++++++ ax/core/multi_type_experiment.py | 73 +++++++++++--------------------- ax/storage/json_store/decoder.py | 6 +++ 3 files changed, 86 insertions(+), 49 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 1f13f6181aa..1b827041a1d 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -161,6 +161,11 @@ def __init__( self._trial_type_to_runner: dict[str | None, Runner | None] = { default_trial_type: runner } + # Maps each trial type to the set of metric names relevant to that type. + # This is the complement of _trial_type_to_runner and is used for + # multi-type experiments where different metrics apply to different + # trial types. + self._trial_type_to_metric_names: dict[str, set[str]] = {} # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's # expiration often, having this attribute helps avoid unnecessary @@ -1928,6 +1933,57 @@ def default_trial_type(self) -> str | None: """ return self._default_trial_type + @property + def trial_type_to_metric_names(self) -> dict[str, set[str]]: + """Map from trial type to the set of metric names relevant to that + type. + + Returns a shallow copy of the internal mapping. + """ + return dict(self._trial_type_to_metric_names) + + @property + def metric_to_trial_type(self) -> dict[str, str]: + """Map each metric name to its associated trial type. + + Computed from ``_trial_type_to_metric_names``. When a + ``default_trial_type`` is set and an ``optimization_config`` exists, + optimization config metrics are pinned to the default trial type. + """ + result: dict[str, str] = {} + for trial_type, metric_names in self._trial_type_to_metric_names.items(): + for name in metric_names: + result[name] = trial_type + opt_config = self._optimization_config + default_trial_type = self._default_trial_type + if default_trial_type is not None and opt_config is not None: + for metric_name in opt_config.metric_names: + result[metric_name] = default_trial_type + return result + + def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: + """Return the metrics associated with a given trial type. + + Args: + trial_type: The trial type to look up metrics for. + + Raises: + ValueError: If the trial type is not supported. + """ + if not self.supports_trial_type(trial_type): + raise ValueError(f"Trial type `{trial_type}` is not supported.") + valid_names = self._trial_type_to_metric_names.get(trial_type, set()) + return [self._metrics[name] for name in valid_names if name in self._metrics] + + @property + def default_trials(self) -> set[int]: + """Return the indices for trials of the default type.""" + return { + idx + for idx, trial in self.trials.items() + if trial.trial_type == self.default_trial_type + } + def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: """The default runner to use for a given trial type. diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index f6a1ce91add..59c9ec28ad2 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -96,15 +96,16 @@ def __init__( default_data_type=default_data_type, ) - # Ensure tracking metrics are registered in _metric_to_trial_type. + # Ensure tracking metrics are registered in _metric_to_trial_type + # and _trial_type_to_metric_names. # super().__init__ sets self._metrics directly, bypassing # add_tracking_metric, so tracking metrics won't be in # _metric_to_trial_type yet. for m in tracking_metrics or []: if m.name not in self._metric_to_trial_type: - self._metric_to_trial_type[m.name] = none_throws( - self._default_trial_type - ) + tt = none_throws(self._default_trial_type) + self._metric_to_trial_type[m.name] = tt + self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name) def add_trial_type(self, trial_type: str, runner: Runner) -> Self: """Add a new trial_type to be supported by this experiment. @@ -129,9 +130,9 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: for metric_name in optimization_config.metric_names: # Optimization config metrics are required to be the default trial type # currently. TODO: remove that restriction (T202797235) - self._metric_to_trial_type[metric_name] = none_throws( - self.default_trial_type - ) + tt = none_throws(self.default_trial_type) + self._metric_to_trial_type[metric_name] = tt + self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name) def update_runner(self, trial_type: str, runner: Runner) -> Self: """Update the default runner for an existing trial_type. @@ -166,7 +167,9 @@ def add_tracking_metric( raise ValueError(f"`{trial_type}` is not a supported trial type.") super().add_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) + tt = none_throws(trial_type) + self._metric_to_trial_type[metric.name] = tt + self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self @@ -242,7 +245,14 @@ def update_tracking_metric( raise ValueError(f"`{trial_type}` is not a supported trial type.") super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) + # Remove from old trial type set + old_tt = self._metric_to_trial_type.get(metric.name) + if old_tt is not None and old_tt in self._trial_type_to_metric_names: + self._trial_type_to_metric_names[old_tt].discard(metric.name) + # Add to new trial type set + tt = none_throws(trial_type) + self._metric_to_trial_type[metric.name] = tt + self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self @@ -252,6 +262,11 @@ def remove_tracking_metric(self, metric_name: str) -> Self: if metric_name not in self._metrics: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") + # Clean up _trial_type_to_metric_names + old_tt = self._metric_to_trial_type.get(metric_name) + if old_tt is not None and old_tt in self._trial_type_to_metric_names: + self._trial_type_to_metric_names[old_tt].discard(metric_name) + # Required fields del self._metrics[metric_name] del self._metric_to_trial_type[metric_name] @@ -295,46 +310,6 @@ def _fetch_trial_data( # Invoke parent's fetch method using only metrics for this trial_type return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs) - @property - def default_trials(self) -> set[int]: - """Return the indicies for trials of the default type.""" - return { - idx - for idx, trial in self.trials.items() - if trial.trial_type == self.default_trial_type - } - - @property - def metric_to_trial_type(self) -> dict[str, str]: - """Map metrics to trial types. - - Adds in default trial type for OC metrics to custom defined trial types.. - """ - opt_config_types = { - metric_name: self.default_trial_type - for metric_name in self.optimization_config.metric_names - } - return {**opt_config_types, **self._metric_to_trial_type} - - # -- Overridden functions from Base Experiment Class -- - @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment.""" - return self._default_trial_type - - def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: - """The default runner to use for a given trial type. - - Looks up the appropriate runner for this trial type in the trial_type_to_runner. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Trial type `{trial_type}` is not supported.") - return [ - self.metrics[metric_name] - for metric_name, metric_trial_type in self._metric_to_trial_type.items() - if metric_trial_type == trial_type - ] - def supports_trial_type(self, trial_type: str | None) -> bool: """Whether this experiment allows trials of the given type. diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index fc5827a89a1..1c7cf70ffdc 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -720,6 +720,12 @@ def multi_type_experiment_from_json( experiment._metric_to_trial_type = _metric_to_trial_type experiment._trial_type_to_runner = _trial_type_to_runner + # Rebuild _trial_type_to_metric_names from _metric_to_trial_type + trial_type_to_metric_names: dict[str, set[str]] = {} + for metric_name, trial_type in _metric_to_trial_type.items(): + trial_type_to_metric_names.setdefault(trial_type, set()).add(metric_name) + experiment._trial_type_to_metric_names = trial_type_to_metric_names + _load_experiment_info( exp=experiment, exp_info=experiment_info, From 8b446f544962c7b757f99afd7c2cb3ff67313674 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 24 Mar 2026 08:07:22 -0700 Subject: [PATCH 2/3] Add trial_type support to base Experiment metric methods Summary: Phase 2 of moving MultiTypeExperiment features into base Experiment. Updates the base Experiment metric management methods (`add_metric`, `update_metric`, `remove_metric`) to accept an optional `trial_type` parameter. When provided, metrics are associated with the specified trial type in `_trial_type_to_metric_names`. The `__init__` and `optimization_config` setter also now register metrics when `default_trial_type` is set. The deprecated wrappers (`add_tracking_metric`, `add_tracking_metrics`, `update_tracking_metric`) now accept and pass through `trial_type` and `canonical_name` parameters. On MultiTypeExperiment, overrides are simplified to delegate to the base class methods: - `add_tracking_metric` delegates to `self.add_metric()` - `add_tracking_metrics` override removed (inherited from base) - `update_tracking_metric` delegates to `self.update_metric()` - `remove_tracking_metric` replaced with `remove_metric` override Differential Revision: D94986440 --- ax/core/experiment.py | 107 ++++++++++++++++-- ax/core/multi_type_experiment.py | 114 ++++---------------- ax/core/tests/test_multi_type_experiment.py | 6 -- 3 files changed, 117 insertions(+), 110 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 1b827041a1d..71c13289123 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -202,6 +202,10 @@ def __init__( # a naming collision occurs. for m in [*(tracking_metrics or []), *(metrics or [])]: self._metrics[m.name] = m + if self._default_trial_type is not None: + self._trial_type_to_metric_names.setdefault( + self._default_trial_type, set() + ).add(m.name) # call setters defined below self.status_quo = status_quo @@ -590,6 +594,12 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: "but not found on experiment. Add it first with add_metric()." ) self._optimization_config = optimization_config + resolved_trial_type = self._resolve_trial_type(None) + if resolved_trial_type is not None: + for metric_name in optimization_config.metric_names: + self._trial_type_to_metric_names.setdefault( + resolved_trial_type, set() + ).add(metric_name) @property def is_moo_problem(self) -> bool: @@ -842,7 +852,41 @@ def get_metric(self, name: str) -> Metric: ) return self._metrics[name] - def add_metric(self, metric: Metric) -> Self: + def _resolve_trial_type(self, trial_type: str | None) -> str | None: + """Resolve an explicit or default trial type and validate it. + + Returns ``trial_type`` if explicitly provided (after validating via + ``supports_trial_type``), falls back to ``_default_trial_type`` when + available, and raises ``ValueError`` if this experiment uses trial types + (``_trial_type_to_metric_names`` is non-empty) but none could be + resolved. + + Args: + trial_type: The explicitly provided trial type, or ``None``. + + Returns: + The resolved trial type, which may be ``None`` for single-type + experiments. + + Raises: + ValueError: If ``trial_type`` is provided but not supported, or if + no trial type could be resolved for a multi-type experiment. + """ + if trial_type is not None: + if not self.supports_trial_type(trial_type): + raise ValueError(f"`{trial_type}` is not a supported trial type.") + return trial_type + if self._default_trial_type is not None: + return self._default_trial_type + if self._trial_type_to_metric_names: + raise ValueError( + "This experiment has trial-type-aware metrics but no " + "`trial_type` was specified and no `default_trial_type` is set. " + "Please specify a `trial_type`." + ) + return None + + def add_metric(self, metric: Metric, trial_type: str | None = None) -> Self: """Add a new metric to the experiment. Metrics that are not referenced by the experiment's optimization config @@ -851,54 +895,98 @@ def add_metric(self, metric: Metric) -> Self: Args: metric: Metric to be added. + trial_type: If provided, associates the metric with this trial type. + When ``None`` and a ``default_trial_type`` is set, defaults to + the default trial type. + + Raises: + ValueError: If the metric already exists, the trial type is not + supported, or trial types are in use but none could be resolved. """ if metric.name in self._metrics: raise ValueError( f"Metric `{metric.name}` already defined on experiment. " "Use `update_metric` to update an existing metric definition." ) + trial_type = self._resolve_trial_type(trial_type) + if trial_type is not None: + self._trial_type_to_metric_names.setdefault(trial_type, set()).add( + metric.name + ) self._metrics[metric.name] = metric return self - def add_tracking_metric(self, metric: Metric) -> Self: + def add_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + canonical_name: str | None = None, + ) -> Self: """*Deprecated.* Use ``add_metric`` instead.""" warnings.warn( "add_tracking_metric is deprecated. Use add_metric instead.", DeprecationWarning, stacklevel=2, ) - return self.add_metric(metric) + return self.add_metric(metric, trial_type=trial_type) - def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + canonical_names: dict[str, str] | None = None, + ) -> Experiment: """*Deprecated.* Use ``add_metric`` instead.""" warnings.warn( "add_tracking_metrics is deprecated. Use add_metric instead.", DeprecationWarning, stacklevel=2, ) + metrics_to_trial_types = metrics_to_trial_types or {} for metric in metrics: - self.add_metric(metric) + canonical_name = (canonical_names or {}).get(metric.name) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + canonical_name=canonical_name, + ) return self - def update_metric(self, metric: Metric) -> Self: + def update_metric(self, metric: Metric, trial_type: str | None = None) -> Self: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: If provided, reassociates the metric with this trial + type. When ``None``, keeps the metric's existing trial type. """ if metric.name not in self._metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") + if trial_type is not None: + trial_type = self._resolve_trial_type(trial_type) + # Remove from any existing trial type set + for names in self._trial_type_to_metric_names.values(): + names.discard(metric.name) + # Add to new trial type set + self._trial_type_to_metric_names.setdefault(trial_type, set()).add( + metric.name + ) self._metrics[metric.name] = metric return self - def update_tracking_metric(self, metric: Metric) -> Experiment: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + canonical_name: str | None = None, + ) -> Experiment: """*Deprecated.* Use ``update_metric`` instead.""" warnings.warn( "update_tracking_metric is deprecated. Use update_metric instead.", DeprecationWarning, stacklevel=2, ) - return self.update_metric(metric) + return self.update_metric(metric, trial_type=trial_type) def remove_metric(self, metric_name: str) -> Self: """Remove a metric from the experiment. @@ -919,6 +1007,9 @@ def remove_metric(self, metric_name: str) -> Self: f"Metric `{metric_name}` is referenced by the optimization config " "and cannot be removed. Update the optimization config first." ) + # Clean up _trial_type_to_metric_names + for names in self._trial_type_to_metric_names.values(): + names.discard(metric_name) del self._metrics[metric_name] return self diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index 59c9ec28ad2..c746c82a5c2 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -96,16 +96,13 @@ def __init__( default_data_type=default_data_type, ) - # Ensure tracking metrics are registered in _metric_to_trial_type - # and _trial_type_to_metric_names. - # super().__init__ sets self._metrics directly, bypassing - # add_tracking_metric, so tracking metrics won't be in - # _metric_to_trial_type yet. + # Ensure tracking metrics are registered in _metric_to_trial_type. + # The base __init__ handles _trial_type_to_metric_names. for m in tracking_metrics or []: if m.name not in self._metric_to_trial_type: - tt = none_throws(self._default_trial_type) - self._metric_to_trial_type[m.name] = tt - self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name) + self._metric_to_trial_type[m.name] = none_throws( + self._default_trial_type + ) def add_trial_type(self, trial_type: str, runner: Runner) -> Self: """Add a new trial_type to be supported by this experiment. @@ -127,12 +124,11 @@ def add_trial_type(self, trial_type: str, runner: Runner) -> Self: def optimization_config(self, optimization_config: OptimizationConfig) -> None: # pyre-fixme[16]: `Optional` has no attribute `fset`. Experiment.optimization_config.fset(self, optimization_config) + # Base setter handles _trial_type_to_metric_names; update legacy dict. for metric_name in optimization_config.metric_names: - # Optimization config metrics are required to be the default trial type - # currently. TODO: remove that restriction (T202797235) - tt = none_throws(self.default_trial_type) - self._metric_to_trial_type[metric_name] = tt - self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name) + self._metric_to_trial_type[metric_name] = none_throws( + self.default_trial_type + ) def update_runner(self, trial_type: str, runner: Runner) -> Self: """Update the default runner for an existing trial_type. @@ -163,56 +159,12 @@ def add_tracking_metric( """ if trial_type is None: trial_type = self._default_trial_type - if not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().add_tracking_metric(metric) - tt = none_throws(trial_type) - self._metric_to_trial_type[metric.name] = tt - self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name) + self.add_metric(metric, trial_type=trial_type) + self._metric_to_trial_type[metric.name] = none_throws(trial_type) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self - def add_tracking_metrics( - self, - metrics: list[Metric], - metrics_to_trial_types: dict[str, str] | None = None, - canonical_names: dict[str, str] | None = None, - ) -> Experiment: - """Add a list of new metrics to the experiment. - - If any of the metrics are already defined on the experiment, - we raise an error and don't add any of them to the experiment - - Args: - metrics: Metrics to be added. - metrics_to_trial_types: The mapping from metric names to corresponding - trial types for each metric. If provided, the metrics will be - added to their trial types. If not provided, then the default - trial type will be used. - canonical_names: A mapping of metric names to their - canonical names(The default metrics for which the metrics are - proxies.) - - Returns: - The experiment with the added metrics. - """ - metrics_to_trial_types = metrics_to_trial_types or {} - canonical_name = None - for metric in metrics: - if canonical_names is not None: - canonical_name = none_throws(canonical_names).get(metric.name, None) - - self.add_tracking_metric( - metric=metric, - trial_type=metrics_to_trial_types.get( - metric.name, self._default_trial_type - ), - canonical_name=canonical_name, - ) - return self - def update_tracking_metric( self, metric: Metric, @@ -233,47 +185,17 @@ def update_tracking_metric( trial_type = self._metric_to_trial_type.get( metric.name, self._default_trial_type ) - oc = self.optimization_config - oc_metric_names = oc.metric_names if oc else set() - if metric.name in oc_metric_names and trial_type != self._default_trial_type: - raise ValueError( - f"Metric `{metric.name}` must remain a " - f"`{self._default_trial_type}` metric because it is part of the " - "optimization_config." - ) - elif not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().update_tracking_metric(metric) - # Remove from old trial type set - old_tt = self._metric_to_trial_type.get(metric.name) - if old_tt is not None and old_tt in self._trial_type_to_metric_names: - self._trial_type_to_metric_names[old_tt].discard(metric.name) - # Add to new trial type set - tt = none_throws(trial_type) - self._metric_to_trial_type[metric.name] = tt - self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name) + self.update_metric(metric, trial_type=trial_type) + self._metric_to_trial_type[metric.name] = none_throws(trial_type) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self - @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> Self: - if metric_name not in self._metrics: - raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") - - # Clean up _trial_type_to_metric_names - old_tt = self._metric_to_trial_type.get(metric_name) - if old_tt is not None and old_tt in self._trial_type_to_metric_names: - self._trial_type_to_metric_names[old_tt].discard(metric_name) - - # Required fields - del self._metrics[metric_name] - del self._metric_to_trial_type[metric_name] - - # Optional - if metric_name in self._metric_to_canonical_name: - del self._metric_to_canonical_name[metric_name] + @copy_doc(Experiment.remove_metric) + def remove_metric(self, metric_name: str) -> Self: + super().remove_metric(metric_name) + self._metric_to_trial_type.pop(metric_name, None) + self._metric_to_canonical_name.pop(metric_name, None) return self @copy_doc(Experiment.fetch_data) diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py index b314cfd5b75..5418d5e73e8 100644 --- a/ax/core/tests/test_multi_type_experiment.py +++ b/ax/core/tests/test_multi_type_experiment.py @@ -125,12 +125,6 @@ def test_BadBehavior(self) -> None: with self.assertRaises(ValueError): self.experiment.remove_tracking_metric("m3") - # Try to change optimization metric to non-primary trial type - with self.assertRaises(ValueError): - self.experiment.update_tracking_metric( - BraninMetric("m1", ["x1", "x2"]), "type2" - ) - # Update metric definition for trial_type that doesn't exist with self.assertRaises(ValueError): self.experiment.update_tracking_metric( From 7ef2c6bea25b35da4cbe8ed59255007ad50897be Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 24 Mar 2026 17:18:12 -0700 Subject: [PATCH 3/3] Move add_trial_type, update_runner, supports_trial_type to base Experiment (#5003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5003 Phase 3 of moving MultiTypeExperiment features into base Experiment. Moves `add_trial_type` and `update_runner` from MultiTypeExperiment to the base Experiment class, making them available to all experiments. Updates `supports_trial_type` to unify the logic: for multi-type experiments (where `default_trial_type` is set), only trial types registered in `_trial_type_to_runner` are supported. For single-type experiments, `None` is supported along with `SHORT_RUN` and `LONG_RUN` for backward compatibility with generation strategies that use those trial types. Removes the corresponding overrides from MultiTypeExperiment — all three methods are now inherited from the base class. Reviewed By: saitcakmak Differential Revision: D94988577 --- ax/core/experiment.py | 38 ++++++++++++++++++++++++++++++-- ax/core/multi_type_experiment.py | 34 ---------------------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 71c13289123..54e504b8e5d 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -2075,6 +2075,34 @@ def default_trials(self) -> set[int]: if trial.trial_type == self.default_trial_type } + def add_trial_type(self, trial_type: str, runner: Runner | None = None) -> Self: + """Add a new trial type to be supported by this experiment. + + Args: + trial_type: The new trial type to be added. + runner: The default runner for trials of this type. + """ + if self.supports_trial_type(trial_type): + raise ValueError(f"Experiment already contains trial_type `{trial_type}`") + + if runner is not None: + self._trial_type_to_runner[trial_type] = runner + + return self + + def update_runner(self, trial_type: str, runner: Runner) -> Self: + """Update the default runner for an existing trial type. + + Args: + trial_type: The trial type whose runner should be updated. + runner: The new runner for trials of this type. + """ + if not self.supports_trial_type(trial_type): + raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") + self._trial_type_to_runner[trial_type] = runner + self._runner = runner + return self + def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: """The default runner to use for a given trial type. @@ -2089,14 +2117,20 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: def supports_trial_type(self, trial_type: str | None) -> bool: """Whether this experiment allows trials of the given type. - The base experiment class only supports None. For experiments - with multiple trial types, use the MultiTypeExperiment class. + For experiments with a ``default_trial_type`` (multi-type experiments), + only trial types registered in ``_trial_type_to_runner`` are supported. + For single-type experiments, ``None`` is always supported, along with + ``SHORT_RUN`` and ``LONG_RUN`` for backward compatibility with + generation strategies that use those trial types. """ + if self._default_trial_type is not None: + return trial_type in self._trial_type_to_runner return ( trial_type is None or trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN or trial_type == Keys.LILO_LABELING + or trial_type in self._trial_type_to_runner ) def attach_trial( diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index c746c82a5c2..62d734edf35 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -104,19 +104,6 @@ def __init__( self._default_trial_type ) - def add_trial_type(self, trial_type: str, runner: Runner) -> Self: - """Add a new trial_type to be supported by this experiment. - - Args: - trial_type: The new trial_type to be added. - runner: The default runner for trials of this type. - """ - if self.supports_trial_type(trial_type): - raise ValueError(f"Experiment already contains trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - return self - # pyre does not support inferring the type of property setter decorators # or the `.fset` attribute on properties. # pyre-fixme[56]: Pyre was not able to infer the type of the decorator. @@ -130,20 +117,6 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: self.default_trial_type ) - def update_runner(self, trial_type: str, runner: Runner) -> Self: - """Update the default runner for an existing trial_type. - - Args: - trial_type: The new trial_type to be added. - runner: The new runner for trials of this type. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - self._runner = runner - return self - def add_tracking_metric( self, metric: Metric, @@ -232,13 +205,6 @@ def _fetch_trial_data( # Invoke parent's fetch method using only metrics for this trial_type return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs) - def supports_trial_type(self, trial_type: str | None) -> bool: - """Whether this experiment allows trials of the given type. - - Only trial types defined in the trial_type_to_runner are allowed. - """ - return trial_type in self._trial_type_to_runner.keys() - def filter_trials_by_type( trials: Sequence[BaseTrial], trial_type: str | None