Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions ax/analysis/plotly/tests/test_top_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,69 @@ def test_validate_applicable_state(self) -> None:
none_throws(TopSurfacesAnalysis().validate_applicable_state()),
)

client = Client()
client.configure_experiment(
name="foo",
parameters=[
RangeParameterConfig(
name="x1",
parameter_type="float",
bounds=(0, 1),
),
RangeParameterConfig(
name="x2",
parameter_type="float",
bounds=(0, 1),
),
],
)
client.configure_optimization(objective="bar")

for _ in range(1):
for trial_index, parameterization in client.get_next_trials(
max_trials=1
).items():
client.complete_trial(
trial_index=trial_index,
raw_data={
"bar": assert_is_instance(parameterization["x1"], float)
- 2 * assert_is_instance(parameterization["x2"], float)
},
)

self.assertIn(
"Ax has not yet reached a GenerationNode",
none_throws(
TopSurfacesAnalysis(
metric_name="bar", order="first"
).validate_applicable_state(
client._experiment, client._generation_strategy
)
),
)
for _ in range(5):
for trial_index, parameterization in client.get_next_trials(
max_trials=1
).items():
client.complete_trial(
trial_index=trial_index,
raw_data={
"bar": assert_is_instance(parameterization["x1"], float)
- 2 * assert_is_instance(parameterization["x2"], float)
},
)

self.assertIn(
"no data for metrics {'baz'}",
none_throws(
TopSurfacesAnalysis(
metric_name="baz", order="first"
).validate_applicable_state(
client._experiment, client._generation_strategy
)
),
)

@mock_botorch_optimize
def test_compute(self) -> None:
client = Client()
Expand Down
63 changes: 57 additions & 6 deletions ax/analysis/plotly/top_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from typing import final, Literal

from ax.adapter.base import Adapter
from ax.adapter.torch import TorchAdapter
from ax.analysis.analysis import Analysis
from ax.analysis.analysis_card import (
AnalysisCard,
AnalysisCardBase,
AnalysisCardGroup,
ErrorAnalysisCard,
)
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.analysis.plotly.sensitivity import SensitivityAnalysisPlot
from ax.analysis.plotly.surface.contour import (
CONTOUR_CARDGROUP_SUBTITLE,
Expand All @@ -27,8 +29,13 @@
SlicePlot,
)
from ax.analysis.plotly.utils import select_metric
from ax.analysis.utils import validate_experiment
from ax.analysis.utils import (
extract_relevant_adapter,
validate_experiment,
validate_experiment_has_trials,
)
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import assert_is_instance, none_throws, override

Expand Down Expand Up @@ -68,14 +75,46 @@ def validate_applicable_state(
adapter: Adapter | None = None,
) -> str | None:
"""
TopSurfacesAnalysis requires an experiment with trials and data.
TopSurfacesAnalysis requires an experiment with trials and data as well as
a TorchAdapter.
"""
if self.metric_name is None:
return validate_experiment(
if (
experiment_invalid_reason := validate_experiment(
experiment=experiment,
require_trials=True,
require_data=True,
)
) is not None:
return experiment_invalid_reason

metric_name = (
self.metric_name
if self.metric_name is not None
else select_metric(experiment=none_throws(experiment))
)

if (
experiment_invalid_reason := validate_experiment_has_trials(
experiment=none_throws(experiment),
required_metric_names=[metric_name],
# Any trial indices and statuses will do since we use all data here
trial_indices=None,
trial_statuses=None,
)
) is not None:
return experiment_invalid_reason

try:
relevant_adapter = extract_relevant_adapter(
experiment=experiment,
generation_strategy=generation_strategy,
adapter=adapter,
)

if not isinstance(relevant_adapter, TorchAdapter):
return f"TorchAdapter is required, found {type(relevant_adapter)}."
except UserInputError as e:
return e.message

@override
def compute(
Expand All @@ -93,15 +132,27 @@ def compute(
# Process the sensitivity analysis card to find the top K surfaces which
# consist exclusively of tunable parameters (i.e. no fixed parameters, task
# parameters, or OneHot parameters).
sensitivity_analysis_card = SensitivityAnalysisPlot(
maybe_sensitivity_analysis_card = SensitivityAnalysisPlot(
metric_name=metric_name,
order=self.order,
top_k=self.top_k,
).compute(
).compute_result(
experiment=experiment,
generation_strategy=generation_strategy,
adapter=adapter,
)

if maybe_sensitivity_analysis_card.is_err():
err = none_throws(maybe_sensitivity_analysis_card.err)
raise err.exception or RuntimeError(
"Failed to compute SensitivityAnalysisPlot"
f"({metric_name=}, {self.order=}, {self.top_k=})"
)

sensitivity_analysis_card = assert_is_instance(
maybe_sensitivity_analysis_card.ok, PlotlyAnalysisCard
)

children: list[AnalysisCardBase] = [sensitivity_analysis_card]

sensitivity_df = sensitivity_analysis_card.df.copy()
Expand Down