From 42d668522edceee3237f62bb5ab15fc43bb2a0e9 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Fri, 14 Nov 2025 11:04:55 -0800 Subject: [PATCH] Tighten TopSurfacesAnalysis.validate_applicable_state Summary: Add checks for experiment having trials, data, that the metric name is present on the data, and that the current adapter is the TorchAdapter since Sensitivity will fail if the adapter is anything but the TorchAdapter Differential Revision: D87089811 --- ax/analysis/plotly/tests/test_top_surfaces.py | 63 +++++++++++++++++++ ax/analysis/plotly/top_surfaces.py | 63 +++++++++++++++++-- 2 files changed, 120 insertions(+), 6 deletions(-) diff --git a/ax/analysis/plotly/tests/test_top_surfaces.py b/ax/analysis/plotly/tests/test_top_surfaces.py index 58140d6e0c2..81da374dcd5 100644 --- a/ax/analysis/plotly/tests/test_top_surfaces.py +++ b/ax/analysis/plotly/tests/test_top_surfaces.py @@ -24,6 +24,69 @@ def test_validate_applicable_state(self) -> None: none_throws(TopSurfacesAnalysis().validate_applicable_state()), ) + client = Client() + client.configure_experiment( + name="foo", + parameters=[ + RangeParameterConfig( + name="x1", + parameter_type="float", + bounds=(0, 1), + ), + RangeParameterConfig( + name="x2", + parameter_type="float", + bounds=(0, 1), + ), + ], + ) + client.configure_optimization(objective="bar") + + for _ in range(1): + for trial_index, parameterization in client.get_next_trials( + max_trials=1 + ).items(): + client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": assert_is_instance(parameterization["x1"], float) + - 2 * assert_is_instance(parameterization["x2"], float) + }, + ) + + self.assertIn( + "Ax has not yet reached a GenerationNode", + none_throws( + TopSurfacesAnalysis( + metric_name="bar", order="first" + ).validate_applicable_state( + client._experiment, client._generation_strategy + ) + ), + ) + for _ in range(5): + for trial_index, parameterization in client.get_next_trials( + max_trials=1 + ).items(): + client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": assert_is_instance(parameterization["x1"], float) + - 2 * assert_is_instance(parameterization["x2"], float) + }, + ) + + self.assertIn( + "no data for metrics {'baz'}", + none_throws( + TopSurfacesAnalysis( + metric_name="baz", order="first" + ).validate_applicable_state( + client._experiment, client._generation_strategy + ) + ), + ) + @mock_botorch_optimize def test_compute(self) -> None: client = Client() diff --git a/ax/analysis/plotly/top_surfaces.py b/ax/analysis/plotly/top_surfaces.py index 7aa678172cd..ccbd23fce8a 100644 --- a/ax/analysis/plotly/top_surfaces.py +++ b/ax/analysis/plotly/top_surfaces.py @@ -8,6 +8,7 @@ from typing import final, Literal from ax.adapter.base import Adapter +from ax.adapter.torch import TorchAdapter from ax.analysis.analysis import Analysis from ax.analysis.analysis_card import ( AnalysisCard, @@ -15,6 +16,7 @@ AnalysisCardGroup, ErrorAnalysisCard, ) +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.analysis.plotly.sensitivity import SensitivityAnalysisPlot from ax.analysis.plotly.surface.contour import ( CONTOUR_CARDGROUP_SUBTITLE, @@ -27,8 +29,13 @@ SlicePlot, ) from ax.analysis.plotly.utils import select_metric -from ax.analysis.utils import validate_experiment +from ax.analysis.utils import ( + extract_relevant_adapter, + validate_experiment, + validate_experiment_has_trials, +) from ax.core.experiment import Experiment +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import assert_is_instance, none_throws, override @@ -68,14 +75,46 @@ def validate_applicable_state( adapter: Adapter | None = None, ) -> str | None: """ - TopSurfacesAnalysis requires an experiment with trials and data. + TopSurfacesAnalysis requires an experiment with trials and data as well as + a TorchAdapter. """ - if self.metric_name is None: - return validate_experiment( + if ( + experiment_invalid_reason := validate_experiment( experiment=experiment, require_trials=True, require_data=True, ) + ) is not None: + return experiment_invalid_reason + + metric_name = ( + self.metric_name + if self.metric_name is not None + else select_metric(experiment=none_throws(experiment)) + ) + + if ( + experiment_invalid_reason := validate_experiment_has_trials( + experiment=none_throws(experiment), + required_metric_names=[metric_name], + # Any trial indices and statuses will do since we use all data here + trial_indices=None, + trial_statuses=None, + ) + ) is not None: + return experiment_invalid_reason + + try: + relevant_adapter = extract_relevant_adapter( + experiment=experiment, + generation_strategy=generation_strategy, + adapter=adapter, + ) + + if not isinstance(relevant_adapter, TorchAdapter): + return f"TorchAdapter is required, found {type(relevant_adapter)}." + except UserInputError as e: + return e.message @override def compute( @@ -93,15 +132,27 @@ def compute( # Process the sensitivity analysis card to find the top K surfaces which # consist exclusively of tunable parameters (i.e. no fixed parameters, task # parameters, or OneHot parameters). - sensitivity_analysis_card = SensitivityAnalysisPlot( + maybe_sensitivity_analysis_card = SensitivityAnalysisPlot( metric_name=metric_name, order=self.order, top_k=self.top_k, - ).compute( + ).compute_result( experiment=experiment, generation_strategy=generation_strategy, adapter=adapter, ) + + if maybe_sensitivity_analysis_card.is_err(): + err = none_throws(maybe_sensitivity_analysis_card.err) + raise err.exception or RuntimeError( + "Failed to compute SensitivityAnalysisPlot" + f"({metric_name=}, {self.order=}, {self.top_k=})" + ) + + sensitivity_analysis_card = assert_is_instance( + maybe_sensitivity_analysis_card.ok, PlotlyAnalysisCard + ) + children: list[AnalysisCardBase] = [sensitivity_analysis_card] sensitivity_df = sensitivity_analysis_card.df.copy()