Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,14 @@ def _input_transform_argparse_learned_feature_imputation(
torch.ones(d, dtype=dtype, device=torch_device),
]
)
# The target task is at position 0 (target_dataset is prepended above), so
# at posterior time — when X arrives without a task column — LFI applies
# the target task's imputation pattern.
kwargs: dict[str, Any] = {
"feature_indices": feature_indices,
"d": d,
"task_feature_index": task_feature_index,
"target_task": 0,
"bounds": bounds,
"device": torch_device,
"dtype": dtype,
Expand Down
26 changes: 23 additions & 3 deletions ax/generators/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from logging import Logger
from typing import Any, cast

Expand Down Expand Up @@ -67,6 +67,7 @@
from botorch.models.transforms.input import (
ChainedInputTransform,
InputTransform,
LearnedFeatureImputation,
Normalize,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
Expand Down Expand Up @@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp(
) -> dict[str, Any]:
if len(dataset.outcome_names) > 1:
raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
# If LearnedFeatureImputation is in the model config, tell construct_inputs
# to map heterogeneous per-task features to the full joint feature space.
# This must happen before the base call so construct_inputs can handle
# heterogeneous MultiTaskDatasets without raising.
uses_lfi = isinstance(model_config.input_transform_classes, list) and any(
issubclass(cls, LearnedFeatureImputation)
for cls in model_config.input_transform_classes
)
if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options:
model_config = replace(
model_config,
model_options={
**model_config.model_options,
"map_heterogeneous_to_full": True,
},
)
formatted_model_inputs = _submodel_input_constructor_base(
botorch_model_class=botorch_model_class,
model_config=model_config,
Expand All @@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp(
# specify output tasks so that model.num_outputs = 1
# since the model only models a single outcome
if formatted_model_inputs.get("output_tasks") is None:
# SSD doesn't use -1, so we need to normalize here
# SSD doesn't use -1, so we need to normalize here. Use the SSD's bound
# length since target_values is keyed by SSD column index — for
# heterogeneous MultiTaskDatasets this differs from the per-task
# dataset's feature_names length.
task_feature = none_throws(
normalize_indices(indices=[task_feature], d=len(dataset.feature_names))
normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds))
)[0]
if (search_space_digest.target_values is not None) and (
target_value := search_space_digest.target_values.get(task_feature)
Expand Down
77 changes: 37 additions & 40 deletions ax/generators/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,32 @@ def use_model_list(
return True


def _ensure_input_transform(
model_config: ModelConfig,
transform_cls: type[InputTransform],
position: int | None = None,
) -> None:
"""Ensure ``transform_cls`` is in ``model_config.input_transform_classes``.

If the user hasn't specified any transforms (``DEFAULT``), initialise the
list with ``[transform_cls]``. Otherwise append (or insert at ``position``)
only when the class isn't already present. Mutates ``model_config``
in-place.
"""
itc = model_config.input_transform_classes
if isinstance(itc, list):
if transform_cls not in itc:
if position is not None:
itc.insert(position, transform_cls)
else:
itc.append(transform_cls)
else:
model_config.input_transform_classes = [transform_cls]
ito = model_config.input_transform_options or {}
ito.setdefault(transform_cls.__name__, {})
model_config.input_transform_options = ito


def copy_model_config_with_default_values(
model_config: ModelConfig,
dataset: SupervisedDataset,
Expand All @@ -235,43 +261,15 @@ def copy_model_config_with_default_values(
specified_model_class=model_config_copy.botorch_model_class,
)

# Handle heterogeneous multi-task datasets.
# Handle heterogeneous multi-task datasets: ensure Normalize is present
# and add LearnedFeatureImputation for models that don't handle
# heterogeneity natively.
if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features:
if model_config_copy.botorch_model_class is HeterogeneousMTGP:
# HeterogeneousMTGP handles heterogeneity natively; just ensure
# Normalize is present (bounds are set later by the TL adapter).
itc = model_config_copy.input_transform_classes
if isinstance(itc, list):
if Normalize not in itc:
itc.insert(0, Normalize)
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
else:
model_config_copy.input_transform_classes = [Normalize]
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
else:
# Other models need Normalize + LFI to pad features via
# map_heterogeneous_to_full.
itc = model_config_copy.input_transform_classes
if isinstance(itc, list):
if Normalize not in itc:
itc.insert(0, Normalize)
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
if LearnedFeatureImputation not in itc:
itc.append(LearnedFeatureImputation)
else:
model_config_copy.input_transform_classes = [
Normalize,
LearnedFeatureImputation,
]
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
_ensure_input_transform(model_config_copy, Normalize, position=0)
if model_config_copy.botorch_model_class is not None and not issubclass(
model_config_copy.botorch_model_class, HeterogeneousMTGP
):
_ensure_input_transform(model_config_copy, LearnedFeatureImputation)

if model_config_copy.mll_class is None:
model_config_copy.mll_class = (
Expand Down Expand Up @@ -321,16 +319,15 @@ def choose_model_class(
)

# Check for heterogeneous multi-task datasets. If a model class was
# explicitly specified, respect it; otherwise default to HeterogeneousMTGP.
# explicitly specified, respect it; otherwise default to MultiTaskGP
# (LearnedFeatureImputation handles missing features).
if (
search_space_digest.task_features
and isinstance(dataset, MultiTaskDataset)
and dataset.has_heterogeneous_features
):
model_class = (
specified_model_class
if specified_model_class is not None
else HeterogeneousMTGP
specified_model_class if specified_model_class is not None else MultiTaskGP
)
logger.debug(f"Chose BoTorch model class: {model_class}.")
return model_class
Expand Down
76 changes: 75 additions & 1 deletion ax/generators/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_construct_specified_input_transforms,
_extract_model_kwargs,
_make_botorch_input_transform,
_submodel_input_constructor_mtgp,
submodel_input_constructor,
Surrogate,
SurrogateSpec,
Expand Down Expand Up @@ -59,7 +60,12 @@
from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks.
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize
from botorch.models.transforms.input import (
ChainedInputTransform,
LearnedFeatureImputation,
Log10,
Normalize,
)
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.evaluation import compute_in_sample_model_fit_metric
Expand Down Expand Up @@ -277,6 +283,74 @@ def test__make_botorch_input_transform(self) -> None:
self.assertEqual(transform.indices.tolist(), [0])
self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]])

def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None:
    """When LFI is configured, _submodel_input_constructor_mtgp forwards
    map_heterogeneous_to_full to construct_inputs, so zero-padded
    heterogeneous datasets can be used with MultiTaskGP."""
    # Target task observes only x0; the source task observes x0 and x1,
    # making the combined MultiTaskDataset heterogeneous.
    target_ds = SupervisedDataset(
        X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]),
        Y=torch.tensor([[1.0], [2.0]]),
        feature_names=["x0", "task"],
        outcome_names=["y_task_0"],
    )
    source_ds = SupervisedDataset(
        X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]),
        Y=torch.tensor([[3.0], [4.0]]),
        feature_names=["x0", "x1", "task"],
        outcome_names=["y_task_1"],
    )
    dataset = MultiTaskDataset(
        datasets=[target_ds, source_ds],
        target_outcome_name="y_task_0",
        task_feature_index=-1,
    )
    self.assertTrue(dataset.has_heterogeneous_features)
    digest = SearchSpaceDigest(
        feature_names=["x0", "x1", "task"],
        bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)],
        task_features=[2],
        target_values={2: 0.0},
    )
    surrogate = Surrogate(
        surrogate_spec=SurrogateSpec(
            model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)]
        )
    )

    with self.subTest("with LFI — construct_inputs succeeds"):
        lfi_config = ModelConfig(
            botorch_model_class=MultiTaskGP,
            input_transform_classes=[Normalize, LearnedFeatureImputation],
        )
        inputs = _submodel_input_constructor_mtgp(
            botorch_model_class=MultiTaskGP,
            model_config=lfi_config,
            dataset=dataset,
            search_space_digest=digest,
            surrogate=surrogate,
        )
        # train_X is padded out to the full joint feature space.
        self.assertEqual(inputs["train_X"].shape[-1], 3)

    with self.subTest("without LFI — construct_inputs raises"):
        from botorch.exceptions.errors import (
            UnsupportedError as BotorchUnsupportedError,
        )

        plain_config = ModelConfig(
            botorch_model_class=MultiTaskGP,
            input_transform_classes=[Normalize],
        )
        with self.assertRaisesRegex(
            BotorchUnsupportedError, "heterogeneous feature sets"
        ):
            _submodel_input_constructor_mtgp(
                botorch_model_class=MultiTaskGP,
                model_config=plain_config,
                dataset=dataset,
                search_space_digest=digest,
                surrogate=surrogate,
            )


class SurrogateTest(TestCase):
def setUp(self, cuda: bool = False) -> None:
Expand Down
72 changes: 62 additions & 10 deletions ax/generators/torch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import LearnedFeatureImputation, Normalize
from botorch.models.transforms.input import LearnedFeatureImputation, Normalize, Warp
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.types import DEFAULT
Expand Down Expand Up @@ -186,9 +186,9 @@ def test_choose_model_class_heterogeneous_task_features(self) -> None:
mt_dataset = self._get_heterogeneous_mt_dataset()
ssd = dataclasses.replace(self.search_space_digest, task_features=[-1])

# Default: HeterogeneousMTGP.
# Default: MultiTaskGP (LearnedFeatureImputation handles missing features).
self.assertEqual(
HeterogeneousMTGP,
MultiTaskGP,
choose_model_class(dataset=mt_dataset, search_space_digest=ssd),
)

Expand Down Expand Up @@ -233,19 +233,23 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None:
mt_dataset = self._get_heterogeneous_mt_dataset()
ssd = dataclasses.replace(self.search_space_digest, task_features=[-1])

# Default (no model class specified) -> HeterogeneousMTGP.
# LFI is NOT injected; input_transform_classes stays DEFAULT.
# Default (no model class specified) -> MultiTaskGP.
# LFI is injected for MultiTaskGP with heterogeneous data.
updated_config = copy_model_config_with_default_values(
model_config=ModelConfig(),
dataset=mt_dataset,
search_space_digest=ssd,
)
self.assertEqual(updated_config.botorch_model_class, HeterogeneousMTGP)
self.assertEqual(updated_config.input_transform_classes, [Normalize])
self.assertEqual(updated_config.botorch_model_class, MultiTaskGP)
self.assertEqual(
none_throws(updated_config.input_transform_options),
{"Normalize": {"bounds": None}},
updated_config.input_transform_classes,
[Normalize, LearnedFeatureImputation],
)
# LFI is present in transform classes but absent from options; its
# argparse computes kwargs from the dataset at construction time.
ito = none_throws(updated_config.input_transform_options)
self.assertEqual(ito, {"Normalize": {}})
self.assertNotIn("LearnedFeatureImputation", ito)

# Explicit HeterogeneousMTGP behaves the same.
updated_config = copy_model_config_with_default_values(
Expand All @@ -257,7 +261,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None:
self.assertEqual(updated_config.input_transform_classes, [Normalize])
self.assertEqual(
none_throws(updated_config.input_transform_options),
{"Normalize": {"bounds": None}},
{"Normalize": {}},
)

def test_copy_model_config_mtgp_with_lfi_injection(self) -> None:
Expand Down Expand Up @@ -302,6 +306,54 @@ def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None
self.assertEqual(updated_config.input_transform_classes, DEFAULT)
self.assertEqual(updated_config.input_transform_options, {})

def test_copy_model_config_adds_imputation_for_heterogeneous(self) -> None:
    """LearnedFeatureImputation is injected for MultiTaskGP on heterogeneous
    multi-task data, regardless of user-specified transform classes."""
    dataset = self._get_heterogeneous_mt_dataset()
    digest = dataclasses.replace(self.search_space_digest, task_features=[-1])

    with self.subTest("no_input_transform_classes"):
        # Starting from defaults: Normalize is prepended and LFI appended.
        updated = copy_model_config_with_default_values(
            model_config=ModelConfig(botorch_model_class=MultiTaskGP),
            dataset=dataset,
            search_space_digest=digest,
        )
        self.assertEqual(updated.botorch_model_class, MultiTaskGP)
        self.assertEqual(
            updated.input_transform_classes,
            [Normalize, LearnedFeatureImputation],
        )

    with self.subTest("existing_transform_classes"):
        # User-specified transforms are preserved between the injected ones.
        updated = copy_model_config_with_default_values(
            model_config=ModelConfig(
                botorch_model_class=MultiTaskGP,
                input_transform_classes=[Warp],
                input_transform_options={"Warp": {}},
            ),
            dataset=dataset,
            search_space_digest=digest,
        )
        self.assertEqual(
            updated.input_transform_classes,
            [Normalize, Warp, LearnedFeatureImputation],
        )

    with self.subTest("imputation_already_present"):
        # No duplicate is added when LFI is already configured.
        updated = copy_model_config_with_default_values(
            model_config=ModelConfig(
                botorch_model_class=MultiTaskGP,
                input_transform_classes=[Normalize, LearnedFeatureImputation],
            ),
            dataset=dataset,
            search_space_digest=digest,
        )
        self.assertEqual(
            updated.input_transform_classes,
            [Normalize, LearnedFeatureImputation],
        )

def test_choose_model_class_discrete_features(self) -> None:
# With discrete features, use MixedSingleTaskGP.
self.assertEqual(
Expand Down
Loading