Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ax/analysis/healthcheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis

__all__ = [
"create_healthcheck_analysis_card",
Expand All @@ -39,4 +40,5 @@
"ComplexityRatingAnalysis",
"PredictableMetricsAnalysis",
"BaselineImprovementAnalysis",
"TransferLearningAnalysis",
]
164 changes: 164 additions & 0 deletions ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest.mock import patch

from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
from ax.core.auxiliary import TransferLearningMetadata
from ax.core.experiment import Experiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


def _make_experiment(
param_names: list[str],
experiment_type: str | None = None,
) -> Experiment:
"""Create a simple experiment with the given parameter names."""
return Experiment(
search_space=SearchSpace(
parameters=[
RangeParameter(
name=name,
parameter_type=ParameterType.FLOAT,
lower=0.0,
upper=1.0,
)
for name in param_names
]
),
name="test_experiment",
experiment_type=experiment_type,
)


_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments"


class TestTransferLearningAnalysis(TestCase):
def test_no_experiment_type_returns_pass(self) -> None:
"""When no experiment_type is set and no experiment_types provided,
return PASS."""
experiment = _make_experiment(["x1", "x2"], experiment_type=None)
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
self.assertTrue(card.is_passing())
self.assertIn("No experiment type set", card.subtitle)

@patch(_MOCK_TARGET, return_value={})
def test_no_candidates_returns_pass(self, mock_identify: object) -> None:
experiment = _make_experiment(["x1", "x2"], experiment_type="my_type")
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
self.assertTrue(card.is_passing())
self.assertTrue(card.df.empty)

@patch(_MOCK_TARGET)
def test_single_candidate_returns_warning(self, mock_identify: object) -> None:
experiment = _make_experiment(
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
)
mock_identify.return_value = { # pyre-ignore[16]
"source_exp": TransferLearningMetadata(
overlap_parameters=["x1", "x2", "x3", "x4"],
),
}
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
self.assertFalse(card.is_passing())
self.assertIn("source_exp", card.subtitle)
self.assertIn("80.0%", card.subtitle)
self.assertEqual(len(card.df), 1)
self.assertEqual(card.df.iloc[0]["Experiment"], "source_exp")
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 80.0)

@patch(_MOCK_TARGET)
def test_multiple_candidates_preserves_order(self, mock_identify: object) -> None:
"""Results should preserve the order from identify_transferable_experiments
(sorted by overlap then recency)."""
experiment = _make_experiment(
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
)
# Mock returns already-sorted results (as identify_transferable_experiments
# now handles sorting by overlap desc, then recency desc).
mock_identify.return_value = { # pyre-ignore[16]
"exp_high": TransferLearningMetadata(
overlap_parameters=["x1", "x2", "x3", "x4"],
),
"exp_mid": TransferLearningMetadata(
overlap_parameters=["x1", "x2", "x3"],
),
"exp_low": TransferLearningMetadata(
overlap_parameters=["x1"],
),
}
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)

# Verify order is preserved from identify_transferable_experiments
self.assertEqual(card.df.iloc[0]["Experiment"], "exp_high")
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
self.assertEqual(card.df.iloc[1]["Experiment"], "exp_mid")
self.assertEqual(card.df.iloc[1]["Overlapping Parameters"], 3)
self.assertEqual(card.df.iloc[2]["Experiment"], "exp_low")
self.assertEqual(card.df.iloc[2]["Overlapping Parameters"], 1)

# All experiments listed in subtitle
self.assertIn("exp_high", card.subtitle)
self.assertIn("exp_mid", card.subtitle)
self.assertIn("exp_low", card.subtitle)
self.assertIn("We found **3 eligible source experiment(s)**", card.subtitle)

@patch(_MOCK_TARGET)
def test_percentage_calculation(self, mock_identify: object) -> None:
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
mock_identify.return_value = { # pyre-ignore[16]
"exp_a": TransferLearningMetadata(
overlap_parameters=["x1"],
),
}
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 33.3)

@patch(_MOCK_TARGET)
def test_parameters_listed_alphabetically(self, mock_identify: object) -> None:
experiment = _make_experiment(
["alpha", "beta", "gamma", "delta"], experiment_type="my_type"
)
mock_identify.return_value = { # pyre-ignore[16]
"exp_a": TransferLearningMetadata(
overlap_parameters=["gamma", "alpha", "delta"],
),
}
analysis = TransferLearningAnalysis()
card = analysis.compute(experiment=experiment)
self.assertEqual(card.df.iloc[0]["Parameters"], "alpha, delta, gamma")

def test_requires_experiment(self) -> None:
analysis = TransferLearningAnalysis()
with self.assertRaises(UserInputError):
analysis.compute(experiment=None)

@patch(_MOCK_TARGET, return_value={})
def test_experiment_name_passed_to_identify(self, mock_identify: object) -> None:
"""Verify that experiment.name is forwarded to
identify_transferable_experiments so it can filter the target out."""
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
analysis = TransferLearningAnalysis()
analysis.compute(experiment=experiment)
mock_identify.assert_called_once() # pyre-ignore[16]
call_kwargs = mock_identify.call_args.kwargs # pyre-ignore[16]
self.assertEqual(call_kwargs["experiment_name"], "test_experiment")
152 changes: 152 additions & 0 deletions ax/analysis/healthcheck/transfer_learning_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import json
from typing import final, TYPE_CHECKING

import markdown as md
import pandas as pd
from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis
from ax.analysis.healthcheck.healthcheck_analysis import (
create_healthcheck_analysis_card,
HealthcheckAnalysisCard,
HealthcheckStatus,
)
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import override

if TYPE_CHECKING:
from ax.storage.sqa_store.sqa_config import SQAConfig


class TransferLearningAnalysisCard(HealthcheckAnalysisCard):
"""HealthcheckAnalysisCard with markdown-aware rendering for notebooks."""

def _body_html(self, depth: int) -> str:
parts = [md.markdown(self.subtitle)]
if not self.df.empty:
parts.append(self.df.to_html(index=False))
return f"<div class='content'>{''.join(parts)}</div>"


@final
class TransferLearningAnalysis(Analysis):
def __init__(
self,
experiment_types: list[str] | None = None,
overlap_threshold: float = 0.50,
max_num_exps: int = 10,
config: SQAConfig | None = None,
) -> None:
self.experiment_types = experiment_types
self.overlap_threshold = overlap_threshold
self.max_num_exps = max_num_exps
self.config = config

@override
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
adapter: Adapter | None = None,
) -> HealthcheckAnalysisCard:
if experiment is None:
raise UserInputError(
"TransferLearningAnalysis requires a non-null experiment to compute "
"overlap percentages. Please provide an experiment."
)

# Determine experiment types to query for.
experiment_types = self.experiment_types
if experiment_types is None:
if experiment.experiment_type is None:
return create_healthcheck_analysis_card(
name=self.__class__.__name__,
title="Transfer Learning Eligibility",
subtitle=(
"No experiment type set on this experiment. "
"Cannot search for transferable experiments."
),
df=pd.DataFrame(),
status=HealthcheckStatus.PASS,
)
experiment_types = [experiment.experiment_type]

# Lazy import to avoid circular dependency (sqa_store depends on
# healthcheck_analysis).
from ax.storage.sqa_store.load import identify_transferable_experiments

transferable_experiments = identify_transferable_experiments(
search_space=experiment.search_space,
experiment_types=experiment_types,
overlap_threshold=self.overlap_threshold,
max_num_exps=self.max_num_exps,
config=self.config,
experiment_name=experiment.name,
)

if not transferable_experiments:
return create_healthcheck_analysis_card(
name=self.__class__.__name__,
title="Transfer Learning Eligibility",
subtitle="No eligible source experiments found for transfer learning.",
df=pd.DataFrame(),
status=HealthcheckStatus.PASS,
)

total_parameters = len(experiment.search_space.parameters)

rows = []
for exp_name, metadata in transferable_experiments.items():
overlap_count = len(metadata.overlap_parameters)
overlap_pct = (
(overlap_count / total_parameters * 100)
if total_parameters > 0
else 0.0
)
rows.append(
{
"Experiment": exp_name,
"Overlapping Parameters": overlap_count,
"Overlap (%)": round(overlap_pct, 1),
"Parameters": ", ".join(sorted(metadata.overlap_parameters)),
}
)

df = pd.DataFrame(rows)

n = len(rows)
exp_lines = "\n".join(
f"- **{r['Experiment']}** ({r['Overlap (%)']:.1f}% parameter overlap)"
for r in rows
)
subtitle = (
"Transfer learning can improve optimization by leveraging data "
"from similar past experiments. We found "
f"**{n} eligible source experiment(s)** "
"for transfer learning:\n\n"
f"{exp_lines}\n\n"
"Caution: Only use source experiments that are closely related "
"to your current experiment. "
"Using data from unrelated experiments can lead to negative "
"transfer, which may hurt "
"optimization performance. Review the overlapping parameters "
"before enabling transfer learning."
)

return TransferLearningAnalysisCard(
name=self.__class__.__name__,
title="Transfer Learning Eligibility",
subtitle=subtitle,
df=df,
blob=json.dumps({"status": HealthcheckStatus.WARNING}),
)
4 changes: 4 additions & 0 deletions ax/analysis/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
from ax.analysis.insights import InsightsAnalysis
from ax.analysis.results import ResultsAnalysis
from ax.analysis.trials import AllTrialsAnalysis
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
options: OrchestratorOptions | None = None,
tier_metadata: dict[str, Any] | None = None,
model_fit_threshold: float | None = None,
sqa_config: Any = None,
) -> None:
super().__init__()
self.can_generate = can_generate
Expand All @@ -124,6 +126,7 @@ def __init__(
self.options = options
self.tier_metadata = tier_metadata
self.model_fit_threshold = model_fit_threshold
self.sqa_config = sqa_config

@override
def validate_applicable_state(
Expand Down Expand Up @@ -229,6 +232,7 @@ def compute(
if not has_batch_trials
else None,
BaselineImprovementAnalysis() if not has_batch_trials else None,
TransferLearningAnalysis(config=self.sqa_config),
*[
SearchSpaceAnalysis(trial_index=trial.index)
for trial in candidate_trials
Expand Down
Loading