From 6c824ce3f353cc0227577062dbc5f0d1d8a9ef74 Mon Sep 17 00:00:00 2001 From: Eric Onofrey Date: Tue, 3 Mar 2026 15:40:49 -0800 Subject: [PATCH] TransferLearningAnalysis (#4918) Summary: Analysis card to show transferable experiments with a default of 50% parameter overlap (`overlap_threshold=0.50`). Reviewed By: mpolson64 Differential Revision: D92926519 --- ax/analysis/healthcheck/__init__.py | 2 + .../tests/test_transfer_learning_analysis.py | 164 ++++++++++++++++++ .../healthcheck/transfer_learning_analysis.py | 152 ++++++++++++++++ ax/analysis/overview.py | 4 + ax/storage/sqa_store/load.py | 55 +++--- ax/storage/sqa_store/tests/test_sqa_store.py | 2 +- 6 files changed, 358 insertions(+), 21 deletions(-) create mode 100644 ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py create mode 100644 ax/analysis/healthcheck/transfer_learning_analysis.py diff --git a/ax/analysis/healthcheck/__init__.py b/ax/analysis/healthcheck/__init__.py index 87f819e7436..c51edcf57fd 100644 --- a/ax/analysis/healthcheck/__init__.py +++ b/ax/analysis/healthcheck/__init__.py @@ -24,6 +24,7 @@ from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates +from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis __all__ = [ "create_healthcheck_analysis_card", @@ -39,4 +40,5 @@ "ComplexityRatingAnalysis", "PredictableMetricsAnalysis", "BaselineImprovementAnalysis", + "TransferLearningAnalysis", ] diff --git a/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py new file mode 100644 index 00000000000..ec34fa048b0 --- /dev/null +++ b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from unittest.mock import patch + +from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus +from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis +from ax.core.auxiliary import TransferLearningMetadata +from ax.core.experiment import Experiment +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError +from ax.utils.common.testutils import TestCase + + +def _make_experiment( + param_names: list[str], + experiment_type: str | None = None, +) -> Experiment: + """Create a simple experiment with the given parameter names.""" + return Experiment( + search_space=SearchSpace( + parameters=[ + RangeParameter( + name=name, + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + ) + for name in param_names + ] + ), + name="test_experiment", + experiment_type=experiment_type, + ) + + +_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments" + + +class TestTransferLearningAnalysis(TestCase): + def test_no_experiment_type_returns_pass(self) -> None: + """When no experiment_type is set and no experiment_types provided, + return PASS.""" + experiment = _make_experiment(["x1", "x2"], experiment_type=None) + analysis = TransferLearningAnalysis() + card = analysis.compute(experiment=experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.PASS) + self.assertTrue(card.is_passing()) + self.assertIn("No experiment type set", card.subtitle) + + @patch(_MOCK_TARGET, return_value={}) + def test_no_candidates_returns_pass(self, mock_identify: object) -> None: + experiment = _make_experiment(["x1", "x2"], experiment_type="my_type") + analysis = TransferLearningAnalysis() + card = analysis.compute(experiment=experiment) + self.assertEqual(card.get_status(), 
HealthcheckStatus.PASS) + self.assertTrue(card.is_passing()) + self.assertTrue(card.df.empty) + + @patch(_MOCK_TARGET) + def test_single_candidate_returns_warning(self, mock_identify: object) -> None: + experiment = _make_experiment( + ["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type" + ) + mock_identify.return_value = { # pyre-ignore[16] + "source_exp": TransferLearningMetadata( + overlap_parameters=["x1", "x2", "x3", "x4"], + ), + } + analysis = TransferLearningAnalysis() + card = analysis.compute(experiment=experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) + self.assertFalse(card.is_passing()) + self.assertIn("source_exp", card.subtitle) + self.assertIn("80.0%", card.subtitle) + self.assertEqual(len(card.df), 1) + self.assertEqual(card.df.iloc[0]["Experiment"], "source_exp") + self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4) + self.assertEqual(card.df.iloc[0]["Overlap (%)"], 80.0) + + @patch(_MOCK_TARGET) + def test_multiple_candidates_preserves_order(self, mock_identify: object) -> None: + """Results should preserve the order from identify_transferable_experiments + (sorted by overlap then recency).""" + experiment = _make_experiment( + ["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type" + ) + # Mock returns already-sorted results (as identify_transferable_experiments + # now handles sorting by overlap desc, then recency desc). 
+ mock_identify.return_value = { # pyre-ignore[16] + "exp_high": TransferLearningMetadata( + overlap_parameters=["x1", "x2", "x3", "x4"], + ), + "exp_mid": TransferLearningMetadata( + overlap_parameters=["x1", "x2", "x3"], + ), + "exp_low": TransferLearningMetadata( + overlap_parameters=["x1"], + ), + } + analysis = TransferLearningAnalysis() + card = analysis.compute(experiment=experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) + + # Verify order is preserved from identify_transferable_experiments + self.assertEqual(card.df.iloc[0]["Experiment"], "exp_high") + self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4) + self.assertEqual(card.df.iloc[1]["Experiment"], "exp_mid") + self.assertEqual(card.df.iloc[1]["Overlapping Parameters"], 3) + self.assertEqual(card.df.iloc[2]["Experiment"], "exp_low") + self.assertEqual(card.df.iloc[2]["Overlapping Parameters"], 1) + + # All experiments listed in subtitle + self.assertIn("exp_high", card.subtitle) + self.assertIn("exp_mid", card.subtitle) + self.assertIn("exp_low", card.subtitle) + self.assertIn("We found **3 eligible source experiment(s)**", card.subtitle) + + @patch(_MOCK_TARGET) + def test_percentage_calculation(self, mock_identify: object) -> None: + experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type") + mock_identify.return_value = { # pyre-ignore[16] + "exp_a": TransferLearningMetadata( + overlap_parameters=["x1"], + ), + } + analysis = TransferLearningAnalysis() + card = analysis.compute(experiment=experiment) + self.assertEqual(card.df.iloc[0]["Overlap (%)"], 33.3) + + @patch(_MOCK_TARGET) + def test_parameters_listed_alphabetically(self, mock_identify: object) -> None: + experiment = _make_experiment( + ["alpha", "beta", "gamma", "delta"], experiment_type="my_type" + ) + mock_identify.return_value = { # pyre-ignore[16] + "exp_a": TransferLearningMetadata( + overlap_parameters=["gamma", "alpha", "delta"], + ), + } + analysis = TransferLearningAnalysis() 
+ card = analysis.compute(experiment=experiment) + self.assertEqual(card.df.iloc[0]["Parameters"], "alpha, delta, gamma") + + def test_requires_experiment(self) -> None: + analysis = TransferLearningAnalysis() + with self.assertRaises(UserInputError): + analysis.compute(experiment=None) + + @patch(_MOCK_TARGET, return_value={}) + def test_experiment_name_passed_to_identify(self, mock_identify: object) -> None: + """Verify that experiment.name is forwarded to + identify_transferable_experiments so it can filter the target out.""" + experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type") + analysis = TransferLearningAnalysis() + analysis.compute(experiment=experiment) + mock_identify.assert_called_once() # pyre-ignore[16] + call_kwargs = mock_identify.call_args.kwargs # pyre-ignore[16] + self.assertEqual(call_kwargs["experiment_name"], "test_experiment") diff --git a/ax/analysis/healthcheck/transfer_learning_analysis.py b/ax/analysis/healthcheck/transfer_learning_analysis.py new file mode 100644 index 00000000000..94e67f6bb26 --- /dev/null +++ b/ax/analysis/healthcheck/transfer_learning_analysis.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from __future__ import annotations + +import json +from typing import final, TYPE_CHECKING + +import markdown as md +import pandas as pd +from ax.adapter.base import Adapter +from ax.analysis.analysis import Analysis +from ax.analysis.healthcheck.healthcheck_analysis import ( + create_healthcheck_analysis_card, + HealthcheckAnalysisCard, + HealthcheckStatus, +) +from ax.core.experiment import Experiment +from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy +from pyre_extensions import override + +if TYPE_CHECKING: + from ax.storage.sqa_store.sqa_config import SQAConfig + + +class TransferLearningAnalysisCard(HealthcheckAnalysisCard): + """HealthcheckAnalysisCard with markdown-aware rendering for notebooks.""" + + def _body_html(self, depth: int) -> str: + parts = [md.markdown(self.subtitle)] + if not self.df.empty: + parts.append(self.df.to_html(index=False)) + return f"
{''.join(parts)}
" + + +@final +class TransferLearningAnalysis(Analysis): + def __init__( + self, + experiment_types: list[str] | None = None, + overlap_threshold: float = 0.50, + max_num_exps: int = 10, + config: SQAConfig | None = None, + ) -> None: + self.experiment_types = experiment_types + self.overlap_threshold = overlap_threshold + self.max_num_exps = max_num_exps + self.config = config + + @override + def compute( + self, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, + adapter: Adapter | None = None, + ) -> HealthcheckAnalysisCard: + if experiment is None: + raise UserInputError( + "TransferLearningAnalysis requires a non-null experiment to compute " + "overlap percentages. Please provide an experiment." + ) + + # Determine experiment types to query for. + experiment_types = self.experiment_types + if experiment_types is None: + if experiment.experiment_type is None: + return create_healthcheck_analysis_card( + name=self.__class__.__name__, + title="Transfer Learning Eligibility", + subtitle=( + "No experiment type set on this experiment. " + "Cannot search for transferable experiments." + ), + df=pd.DataFrame(), + status=HealthcheckStatus.PASS, + ) + experiment_types = [experiment.experiment_type] + + # Lazy import to avoid circular dependency (sqa_store depends on + # healthcheck_analysis). 
+ from ax.storage.sqa_store.load import identify_transferable_experiments + + transferable_experiments = identify_transferable_experiments( + search_space=experiment.search_space, + experiment_types=experiment_types, + overlap_threshold=self.overlap_threshold, + max_num_exps=self.max_num_exps, + config=self.config, + experiment_name=experiment.name, + ) + + if not transferable_experiments: + return create_healthcheck_analysis_card( + name=self.__class__.__name__, + title="Transfer Learning Eligibility", + subtitle="No eligible source experiments found for transfer learning.", + df=pd.DataFrame(), + status=HealthcheckStatus.PASS, + ) + + total_parameters = len(experiment.search_space.parameters) + + rows = [] + for exp_name, metadata in transferable_experiments.items(): + overlap_count = len(metadata.overlap_parameters) + overlap_pct = ( + (overlap_count / total_parameters * 100) + if total_parameters > 0 + else 0.0 + ) + rows.append( + { + "Experiment": exp_name, + "Overlapping Parameters": overlap_count, + "Overlap (%)": round(overlap_pct, 1), + "Parameters": ", ".join(sorted(metadata.overlap_parameters)), + } + ) + + df = pd.DataFrame(rows) + + n = len(rows) + exp_lines = "\n".join( + f"- **{r['Experiment']}** ({r['Overlap (%)']:.1f}% parameter overlap)" + for r in rows + ) + subtitle = ( + "Transfer learning can improve optimization by leveraging data " + "from similar past experiments. We found " + f"**{n} eligible source experiment(s)** " + "for transfer learning:\n\n" + f"{exp_lines}\n\n" + "Caution: Only use source experiments that are closely related " + "to your current experiment. " + "Using data from unrelated experiments can lead to negative " + "transfer, which may hurt " + "optimization performance. Review the overlapping parameters " + "before enabling transfer learning." 
+ ) + + return TransferLearningAnalysisCard( + name=self.__class__.__name__, + title="Transfer Learning Eligibility", + subtitle=subtitle, + df=df, + blob=json.dumps({"status": HealthcheckStatus.WARNING}), + ) diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index 75d87f2afc4..ecdcd53dee1 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -24,6 +24,7 @@ from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates +from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis from ax.analysis.insights import InsightsAnalysis from ax.analysis.results import ResultsAnalysis from ax.analysis.trials import AllTrialsAnalysis @@ -114,6 +115,7 @@ def __init__( options: OrchestratorOptions | None = None, tier_metadata: dict[str, Any] | None = None, model_fit_threshold: float | None = None, + sqa_config: Any = None, ) -> None: super().__init__() self.can_generate = can_generate @@ -124,6 +126,7 @@ def __init__( self.options = options self.tier_metadata = tier_metadata self.model_fit_threshold = model_fit_threshold + self.sqa_config = sqa_config @override def validate_applicable_state( @@ -229,6 +232,7 @@ def compute( if not has_batch_trials else None, BaselineImprovementAnalysis() if not has_batch_trials else None, + TransferLearningAnalysis(config=self.sqa_config), *[ SearchSpaceAnalysis(trial_index=trial.index) for trial in candidate_trials diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 96c3652982d..2c6d0f7d146 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -8,6 +8,7 @@ import logging from collections.abc import Mapping +from datetime import datetime from math import ceil from typing import Any, cast @@ -735,7 +736,7 @@ def 
_query_historical_experiments_given_parameters( parameter_names: list[str], experiment_types: list[str], config: SQAConfig | None = None, -) -> dict[str, SearchSpace | None]: +) -> dict[str, tuple[SearchSpace | None, datetime]]: r""" Find historical experiments of given types tuning any of the given parameter names. @@ -744,8 +745,9 @@ def _query_historical_experiments_given_parameters( parameter_names: List of parameter names. experiment_types: List of experiment types. - Returns: Dictionary mapping experiment names to their filtered SearchSpace objects - containing only the parameters that are also present in the target experiment. + Returns: Dictionary mapping experiment names to a tuple of their filtered + SearchSpace (containing only parameters also present in the target + experiment) and the experiment's creation time. """ from ax.storage.sqa_store.encoder import Encoder @@ -754,9 +756,9 @@ def _query_historical_experiments_given_parameters( encoder = Encoder(config=config) with session_scope() as session: - # Query both parameters and experiment names + # Query parameters, experiment names, and creation time parameters_query = ( - session.query(SQAParameter, SQAExperiment.name) + session.query(SQAParameter, SQAExperiment.name, SQAExperiment.time_created) .filter(SQAParameter.name.in_(parameter_names)) .join(SQAExperiment, SQAParameter.experiment_id == SQAExperiment.id) .filter( @@ -775,19 +777,24 @@ def _query_historical_experiments_given_parameters( query_results = parameters_query.all() - # Group parameters by experiment name + # Group parameters by experiment name, track creation time experiments_params: dict[str, list[SQAParameter]] = {} - for sqa_param, exp_name in query_results: + experiments_time_created: dict[str, datetime] = {} + for sqa_param, exp_name, time_created in query_results: if exp_name not in experiments_params: experiments_params[exp_name] = [] - experiments_params[exp_name].append(sqa_param) + experiments_time_created[exp_name] = 
time_created return { - exp_name: decoder.search_space_from_sqa( - parameters_sqa=parameters_sqa, - # Parameter constraints don't matter for search space compatibility - parameter_constraints_sqa=[], + exp_name: ( + decoder.search_space_from_sqa( + parameters_sqa=parameters_sqa, + # Parameter constraints don't matter for search space + # compatibility + parameter_constraints_sqa=[], + ), + experiments_time_created[exp_name], ) for exp_name, parameters_sqa in experiments_params.items() } @@ -799,11 +806,15 @@ def identify_transferable_experiments( overlap_threshold: float = 0.0, max_num_exps: int = 10, config: SQAConfig | None = None, + experiment_name: str | None = None, ) -> Mapping[str, TransferLearningMetadata]: r""" Find all transferable historical experiments of given types having at least the given proportion of overlapping parameters with the provided search space. + Results are sorted by overlap proportion descending, then by recency + (most recently created first). + Args: search_space: Search space to compare with historical experiments. experiment_types: List of experiment types to search for. @@ -811,8 +822,11 @@ def identify_transferable_experiments( max_num_exps: Max number of transferable experiments to return with highest prop overlap. config: SQAConfig to use for the query. Defaults to None (use default config). + experiment_name: If provided, exclude this experiment from results (used + to filter out the target experiment itself). - Returns: A dictionary mapping experiment names to overlapping parameter names + Returns: A dictionary mapping experiment names to overlapping parameter names, + ordered by overlap then recency. 
""" experiments_search_spaces = _query_historical_experiments_given_parameters( @@ -825,7 +839,9 @@ def identify_transferable_experiments( # Calculate overlap for each experiment results = [] - for exp_name, exp_search_space in experiments_search_spaces.items(): + for exp_name, (exp_search_space, time_created) in experiments_search_spaces.items(): + if experiment_name is not None and exp_name == experiment_name: + continue if not exp_search_space: continue overlap_params = exp_search_space.get_overlapping_parameters(search_space) @@ -836,6 +852,7 @@ def identify_transferable_experiments( "experiment_name": exp_name, "overlap_params": overlap_params, "prop_overlap": prop_overlap, + "time_created": time_created, } ) @@ -843,7 +860,8 @@ def identify_transferable_experiments( df = pd.DataFrame(results) if df.empty: return {} - df = df.sort_values(by="prop_overlap", ascending=False) + # Sort by overlap descending, then by recency (most recent first) + df = df.sort_values(by=["prop_overlap", "time_created"], ascending=[False, False]) if max_num_exps is not None: df = df.head(max_num_exps) @@ -883,12 +901,9 @@ def load_candidate_source_auxiliary_experiments( search_space=target_experiment.search_space, experiment_types=[experiment_type], config=config, + experiment_name=target_experiment.name, ) - return { - experiment: metadata - for experiment, metadata in transferable_experiments.items() - if experiment != target_experiment.name - } + return transferable_experiments case _: raise NotImplementedError( "Loading candidate source auxiliary experiments for purpose " diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 7a992eaf68d..b3690d58b46 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -3197,7 +3197,7 @@ def test_query_historical_experiments_given_parameters(self) -> None: # Assert: Should find the experiment with the matching parameters 
self.assertIn(experiment.name, result) - returned_ss = result[experiment.name] + returned_ss, _time_created = result[experiment.name] self.assertIsNotNone(returned_ss) # The returned search space should contain w and x self.assertIn("w", none_throws(returned_ss).parameters)