Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions ax/analysis/insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from pyre_extensions import none_throws, override


# When the number of parameters exceeds this threshold, use total-order
# sensitivity analysis instead of second-order. Second-order computes pairwise
# interaction effects (O(p^2)) which becomes expensive for high-dimensional
# search spaces, while total-order captures each parameter's overall importance
# including all interactions at O(p) cost.
_MAX_NUM_PARAMS_FOR_SECOND_ORDER: int = 25

INSIGHTS_CARDGROUP_TITLE = "Insights Analysis"

INSIGHTS_CARDGROUP_SUBTITLE = (
Expand All @@ -33,6 +40,26 @@
)


def _choose_sensitivity_order(
num_params: int,
) -> Literal["first", "second", "total"]:
"""Choose the sensitivity analysis order based on parameter count.

- 1 parameter: first-order (second-order requires >= 2 for interaction
effects).
- Many parameters (> threshold): total-order to avoid the O(p^2) cost of
second-order pairwise interactions.
- Otherwise: second-order to surface pairwise interactions for contour
plots.
"""
if num_params == 1:
return "first"
elif num_params > _MAX_NUM_PARAMS_FOR_SECOND_ORDER:
return "total"
else:
return "second"


@final
class InsightsAnalysis(Analysis):
"""
Expand Down Expand Up @@ -116,18 +143,16 @@ def compute(
# For non-bandit experiments, for each objective and constraint, compute a
# sensitivity analysis and plot the top 3 surfaces.
else:
# Default to second-order sensitivity analysis, but fall back to first-order
# if there is only one parameter (second-order requires at least 2
# parameters for interaction effects).
order: Literal["first", "second"] = (
"first" if len(experiment.search_space.parameters) == 1 else "second"
num_params = len(experiment.search_space.parameters)
sensitivity_order: Literal["first", "second", "total"] = (
_choose_sensitivity_order(num_params=num_params)
)
top_surfaces_groups = [
TopSurfacesAnalysis(
metric_name=metric_name,
top_k=3,
relativize=relativize,
order=order,
order=sensitivity_order,
).compute_or_error_card(
experiment=experiment,
generation_strategy=generation_strategy,
Expand Down
54 changes: 53 additions & 1 deletion ax/analysis/tests/test_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@


from datetime import datetime
from unittest.mock import patch

import pandas as pd
from ax.adapter.base import Adapter
from ax.adapter.registry import Generators
from ax.analysis.insights import InsightsAnalysis
from ax.analysis.insights import _MAX_NUM_PARAMS_FOR_SECOND_ORDER, InsightsAnalysis
from ax.analysis.overview import OverviewAnalysis
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
from ax.analysis.plotly.scatter import ScatterPlot
from ax.analysis.plotly.top_surfaces import TopSurfacesAnalysis
from ax.analysis.results import ResultsAnalysis
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
Expand Down Expand Up @@ -448,3 +450,53 @@ def test_insights_analysis_single_parameter(self) -> None:
# Check that none of the cards are error cards
for card in all_cards:
self.assertNotIsInstance(card, ErrorAnalysisCard)

@mock_botorch_optimize
def test_insights_analysis_many_parameters_uses_total_order(self) -> None:
"""Test that InsightsAnalysis uses total-order sensitivity for
high-dimensional experiments (> _MAX_NUM_PARAMS_FOR_SECOND_ORDER
parameters) to avoid the O(p^2) cost of second-order analysis.
"""
num_params = _MAX_NUM_PARAMS_FOR_SECOND_ORDER + 1
client = Client()
client.configure_experiment(
name="many_params",
parameters=[
RangeParameterConfig(
name=f"x{i}",
bounds=(0.0, 1.0),
parameter_type="float",
)
for i in range(num_params)
],
)
client.configure_optimization(objective="objective_metric")

for _ in range(num_params + 2):
for trial_index, parameters in client.get_next_trials(max_trials=1).items():
client.complete_trial(
trial_index=trial_index,
raw_data={
"objective_metric": sum(float(v) for v in parameters.values()),
},
)

# Patch TopSurfacesAnalysis.__init__ to capture the order argument
original_init: object = TopSurfacesAnalysis.__init__
captured_orders: list[object] = []

def patched_init(self_inner: TopSurfacesAnalysis, **kwargs: object) -> None:
captured_orders.append(kwargs.get("order", "second"))
# pyre-ignore[29]: `object` is not callable
original_init(self_inner, **kwargs)

with patch.object(TopSurfacesAnalysis, "__init__", patched_init):
InsightsAnalysis().compute(
experiment=client._experiment,
generation_strategy=client._generation_strategy,
)

# Should use total-order for many parameters
self.assertGreater(len(captured_orders), 0)
for order in captured_orders:
self.assertEqual(order, "total")
Loading