diff --git a/ax/analysis/plotly/tests/test_top_surfaces.py b/ax/analysis/plotly/tests/test_top_surfaces.py index 58140d6e0c2..81da374dcd5 100644 --- a/ax/analysis/plotly/tests/test_top_surfaces.py +++ b/ax/analysis/plotly/tests/test_top_surfaces.py @@ -24,6 +24,69 @@ def test_validate_applicable_state(self) -> None: none_throws(TopSurfacesAnalysis().validate_applicable_state()), ) + client = Client() + client.configure_experiment( + name="foo", + parameters=[ + RangeParameterConfig( + name="x1", + parameter_type="float", + bounds=(0, 1), + ), + RangeParameterConfig( + name="x2", + parameter_type="float", + bounds=(0, 1), + ), + ], + ) + client.configure_optimization(objective="bar") + + for _ in range(1): + for trial_index, parameterization in client.get_next_trials( + max_trials=1 + ).items(): + client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": assert_is_instance(parameterization["x1"], float) + - 2 * assert_is_instance(parameterization["x2"], float) + }, + ) + + self.assertIn( + "Ax has not yet reached a GenerationNode", + none_throws( + TopSurfacesAnalysis( + metric_name="bar", order="first" + ).validate_applicable_state( + client._experiment, client._generation_strategy + ) + ), + ) + for _ in range(5): + for trial_index, parameterization in client.get_next_trials( + max_trials=1 + ).items(): + client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": assert_is_instance(parameterization["x1"], float) + - 2 * assert_is_instance(parameterization["x2"], float) + }, + ) + + self.assertIn( + "no data for metrics {'baz'}", + none_throws( + TopSurfacesAnalysis( + metric_name="baz", order="first" + ).validate_applicable_state( + client._experiment, client._generation_strategy + ) + ), + ) + @mock_botorch_optimize def test_compute(self) -> None: client = Client() diff --git a/ax/analysis/plotly/top_surfaces.py b/ax/analysis/plotly/top_surfaces.py index 7aa678172cd..ccbd23fce8a 100644 --- a/ax/analysis/plotly/top_surfaces.py +++ b/ax/analysis/plotly/top_surfaces.py @@ -8,6 +8,7 @@ from typing import final, Literal from ax.adapter.base import Adapter +from ax.adapter.torch import TorchAdapter from ax.analysis.analysis import Analysis from ax.analysis.analysis_card import ( AnalysisCard, @@ -15,6 +16,7 @@ AnalysisCardGroup, ErrorAnalysisCard, ) +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.analysis.plotly.sensitivity import SensitivityAnalysisPlot from ax.analysis.plotly.surface.contour import ( CONTOUR_CARDGROUP_SUBTITLE, @@ -27,8 +29,13 @@ SlicePlot, ) from ax.analysis.plotly.utils import select_metric -from ax.analysis.utils import validate_experiment +from ax.analysis.utils import ( + extract_relevant_adapter, + validate_experiment, + validate_experiment_has_trials, +) from ax.core.experiment import Experiment +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import assert_is_instance, none_throws, override @@ -68,14 +75,46 @@ def validate_applicable_state( adapter: Adapter | None = None, ) -> str | None: """ - TopSurfacesAnalysis requires an experiment with trials and data. + TopSurfacesAnalysis requires an experiment with trials and data as well as + a TorchAdapter. """ - if self.metric_name is None: - return validate_experiment( + if ( + experiment_invalid_reason := validate_experiment( experiment=experiment, require_trials=True, require_data=True, ) + ) is not None: + return experiment_invalid_reason + + metric_name = ( + self.metric_name + if self.metric_name is not None + else select_metric(experiment=none_throws(experiment)) + ) + + if ( + experiment_invalid_reason := validate_experiment_has_trials( + experiment=none_throws(experiment), + required_metric_names=[metric_name], + # Any trial indices and statuses will do since we use all data here + trial_indices=None, + trial_statuses=None, + ) + ) is not None: + return experiment_invalid_reason + + try: + relevant_adapter = extract_relevant_adapter( + experiment=experiment, + generation_strategy=generation_strategy, + adapter=adapter, + ) + + if not isinstance(relevant_adapter, TorchAdapter): + return f"TorchAdapter is required, found {type(relevant_adapter)}." + except UserInputError as e: + return e.message @override def compute( @@ -93,15 +132,27 @@ def compute( # Process the sensitivity analysis card to find the top K surfaces which # consist exclusively of tunable parameters (i.e. no fixed parameters, task # parameters, or OneHot parameters). - sensitivity_analysis_card = SensitivityAnalysisPlot( + maybe_sensitivity_analysis_card = SensitivityAnalysisPlot( metric_name=metric_name, order=self.order, top_k=self.top_k, - ).compute( + ).compute_result( experiment=experiment, generation_strategy=generation_strategy, adapter=adapter, ) + + if maybe_sensitivity_analysis_card.is_err(): + err = none_throws(maybe_sensitivity_analysis_card.err) + raise err.exception or RuntimeError( + "Failed to compute SensitivityAnalysisPlot" + f"({metric_name=}, {self.order=}, {self.top_k=})" + ) + + sensitivity_analysis_card = assert_is_instance( + maybe_sensitivity_analysis_card.ok, PlotlyAnalysisCard + ) + children: list[AnalysisCardBase] = [sensitivity_analysis_card] sensitivity_df = sensitivity_analysis_card.df.copy()