diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py
index fd67f388b9..6260d1a3ca 100644
--- a/botorch/utils/datasets.py
+++ b/botorch/utils/datasets.py
@@ -9,15 +9,13 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Iterable, List, Optional, TypeVar, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
+from botorch.exceptions.errors import InputDataError, UnsupportedError
 from botorch.utils.containers import BotorchContainer, SliceContainer
 from torch import long, ones, Tensor
 
-T = TypeVar("T")
-MaybeIterable = Union[T, Iterable[T]]
-
 
 class SupervisedDataset:
     r"""Base class for datasets consisting of labelled pairs `(X, Y)`
@@ -273,3 +271,207 @@ def _validate(self) -> None:
 
             # Same as: torch.where(y_diff == 0, y_incr + 1, 1)
             y_incr = y_incr - y_diff + 1
+
+
+class MultiTaskDataset(SupervisedDataset):
+    """A multi-task dataset constructed from the datasets of individual tasks.
+    It offers functionality to combine parts of individual datasets to
+    construct the inputs necessary for the `MultiTaskGP` models.
+    """
+
+    def __init__(
+        self,
+        datasets: List[SupervisedDataset],
+        target_outcome_name: str,
+        task_feature_index: Optional[int] = None,
+    ):
+        """Construct a `MultiTaskDataset`.
+
+        Args:
+            datasets: A list of the datasets of individual tasks. Each dataset
+                is expected to contain data for only one outcome.
+            target_outcome_name: Name of the target outcome to be modeled.
+            task_feature_index: If the task feature is included in the Xs of the
+                individual datasets, this should be used to specify its index.
+                If omitted, the task feature will be appended while concatenating Xs.
+                If given, we sanity-check that the names of the task features
+                match between all datasets.
+        """
+        self.datasets: Dict[str, SupervisedDataset] = {
+            ds.outcome_names[0]: ds for ds in datasets
+        }
+        self.target_outcome_name = target_outcome_name
+        self.task_feature_index = task_feature_index
+        self._validate_datasets(datasets=datasets)
+        self.feature_names = self.datasets[target_outcome_name].feature_names
+        self.outcome_names = [target_outcome_name]
+
+    @classmethod
+    def from_joint_dataset(
+        cls,
+        dataset: SupervisedDataset,
+        task_feature_index: int,
+        target_task_value: int,
+        outcome_names_per_task: Optional[Dict[int, str]] = None,
+    ) -> MultiTaskDataset:
+        r"""Construct a `MultiTaskDataset` from a joint dataset that includes the
+        data for all tasks with the task feature index.
+
+        This will break down the joint dataset into individual datasets by the value
+        of the task feature. Each resulting dataset will have its outcome name set
+        based on `outcome_names_per_task`, with the missing values defaulting to
+        `task_<task_feature_value>` (except for the target task, which will retain
+        the original outcome name from the dataset).
+
+        Args:
+            dataset: The joint dataset.
+            task_feature_index: The column index of the task feature in `dataset.X`.
+            target_task_value: The value of the task feature for the target task
+                in the dataset. The data for the target task is filtered according to
+                `dataset.X[task_feature_index] == target_task_value`.
+            outcome_names_per_task: Optional dictionary mapping task feature values
+                to the outcome names for each task. If not provided, the auxiliary
+                tasks will be named `task_<task_feature_value>` and the target task
+                will retain the outcome name from the dataset.
+
+        Returns:
+            A `MultiTaskDataset` instance.
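+
+        Example:
+            A minimal illustrative sketch (tensor values, sizes, and names here
+            are made up; the task column takes values 0 and 1, with 0 marking
+            the target task):
+
+            >>> X = torch.rand(6, 2)
+            >>> task_col = torch.cat([torch.zeros(3, 1), torch.ones(3, 1)])
+            >>> joint = SupervisedDataset(
+            ...     X=torch.cat([X, task_col], dim=-1),
+            ...     Y=torch.rand(6, 1),
+            ...     feature_names=["x0", "x1", "task"],
+            ...     outcome_names=["objective"],
+            ... )
+            >>> mtd = MultiTaskDataset.from_joint_dataset(
+            ...     dataset=joint, task_feature_index=-1, target_task_value=0
+            ... )
+            >>> list(mtd.datasets)
+            ['objective', 'task_1']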
+ """ + if len(dataset.outcome_names) > 1: + raise UnsupportedError( + "Dataset containing more than one outcome is not supported. " + f"Got {dataset.outcome_names=}." + ) + outcome_names_per_task = outcome_names_per_task or {} + # Split datasets by task feature. + datasets = [] + all_task_features = dataset.X[:, task_feature_index] + for task_value in all_task_features.unique().long().tolist(): + default_name = ( + dataset.outcome_names[0] + if task_value == target_task_value + else f"task_{task_value}" + ) + outcome_name = outcome_names_per_task.get(task_value, default_name) + filter_mask = all_task_features == task_value + new_dataset = SupervisedDataset( + X=dataset.X[filter_mask], + Y=dataset.Y[filter_mask], + Yvar=dataset.Yvar[filter_mask] if dataset.Yvar is not None else None, + feature_names=dataset.feature_names, + outcome_names=[outcome_name], + ) + datasets.append(new_dataset) + # Return the new + return cls( + datasets=datasets, + target_outcome_name=outcome_names_per_task.get( + target_task_value, dataset.outcome_names[0] + ), + task_feature_index=task_feature_index, + ) + + def _validate_datasets(self, datasets: List[SupervisedDataset]) -> None: + """Validates that: + * Each dataset models only one outcome; + * Each outcome is modeled by only one dataset; + * The target outcome is included in the datasets; + * The datasets do not model batched inputs; + * The task feature names of the datasets all match; + * Either all or none of the datasets specify Yvar. + """ + if any(len(ds.outcome_names) > 1 for ds in datasets): + raise UnsupportedError( + "Datasets containing more than one outcome are not supported." + ) + if len(self.datasets) != len(datasets): + raise UnsupportedError( + "Received multiple datasets for the same outcome. Each dataset " + "must contain data for a unique outcome. Got datasets with " + f"outcome names: {(ds.outcome_names for ds in datasets)}." + ) + if self.target_outcome_name not in self.datasets: + raise InputDataError( + "Target outcome is not present in the datasets. " + f"Got {self.target_outcome_name=} and datasets for " + f"outcomes {list(self.datasets.keys())}." + ) + if any(len(ds.X.shape) > 2 for ds in datasets): + raise UnsupportedError( + "Datasets modeling batched inputs are not supported." + ) + if self.task_feature_index is not None: + tf_names = [ds.feature_names[self.task_feature_index] for ds in datasets] + if any(name != tf_names[0] for name in tf_names[1:]): + raise InputDataError( + "Expected the names of the task features to match across all " + f"datasets. Got {tf_names}." + ) + all_Yvars = [ds.Yvar for ds in datasets] + is_none = [yvar is None for yvar in all_Yvars] + # Check that either all or None of the Yvars exist. + if not all(is_none) and any(is_none): + raise UnsupportedError( + "Expected either all or none of the datasets to have a Yvar. " + "Only subset of datasets define Yvar, which is unsupported." + ) + + @property + def X(self) -> Tensor: + """Appends task features, if needed, and concatenates the Xs of datasets to + produce the `train_X` expected by `MultiTaskGP` and subclasses. + + If appending the task features, 0 is reserved for the target task and the + remaining tasks are populated with 1, 2, ..., len(datasets) - 1. + """ + all_Xs = [] + next_task = 1 + for outcome, ds in self.datasets.items(): + if self.task_feature_index is None: + # Append the task feature index. 
+                if outcome == self.target_outcome_name:
+                    task_feature = 0
+                else:
+                    task_feature = next_task
+                    next_task = next_task + 1
+                all_Xs.append(torch.nn.functional.pad(ds.X, (0, 1), value=task_feature))
+            else:
+                all_Xs.append(ds.X)
+        return torch.cat(all_Xs, dim=0)
+
+    @property
+    def Y(self) -> Tensor:
+        """Concatenates Ys of the datasets."""
+        return torch.cat([ds.Y for ds in self.datasets.values()], dim=0)
+
+    @property
+    def Yvar(self) -> Optional[Tensor]:
+        """Concatenates Yvars of the datasets if they exist."""
+        all_Yvars = [ds.Yvar for ds in self.datasets.values()]
+        return None if all_Yvars[0] is None else torch.cat(all_Yvars, dim=0)
+
+    def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDataset:
+        """A helper for extracting a child dataset with its task feature removed.
+
+        If the task feature index is `None`, the dataset will be returned as is.
+
+        Args:
+            outcome_name: The outcome name for the dataset to extract.
+
+        Returns:
+            The dataset without the task feature.
+        """
+        dataset = self.datasets[outcome_name]
+        if self.task_feature_index is None:
+            return dataset
+        indices = list(range(len(self.feature_names)))
+        indices.pop(self.task_feature_index)
+        return SupervisedDataset(
+            X=dataset.X[..., indices],
+            Y=dataset.Y,
+            Yvar=dataset.Yvar,
+            feature_names=[
+                fn for i, fn in enumerate(dataset.feature_names) if i in indices
+            ],
+            outcome_names=[outcome_name],
+        )
diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py
index 1575b2622b..7e9c419e8e 100644
--- a/test/utils/test_datasets.py
+++ b/test/utils/test_datasets.py
@@ -4,13 +4,42 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import List, Optional
+
 import torch
+from botorch.exceptions.errors import InputDataError, UnsupportedError
 from botorch.utils.containers import DenseContainer, SliceContainer
-from botorch.utils.datasets import FixedNoiseDataset, RankingDataset, SupervisedDataset
+from botorch.utils.datasets import (
+    FixedNoiseDataset,
+    MultiTaskDataset,
+    RankingDataset,
+    SupervisedDataset,
+)
 from botorch.utils.testing import BotorchTestCase
 from torch import rand, randperm, Size, stack, Tensor, tensor
 
 
+def make_dataset(
+    num_samples: int = 3,
+    d: int = 2,
+    m: int = 1,
+    has_yvar: bool = False,
+    feature_names: Optional[List[str]] = None,
+    outcome_names: Optional[List[str]] = None,
+    batch_shape: Optional[torch.Size] = None,
+) -> SupervisedDataset:
+    feature_names = feature_names or [f"x{i}" for i in range(d)]
+    outcome_names = outcome_names or [f"y{i}" for i in range(m)]
+    batch_shape = batch_shape or torch.Size()
+    return SupervisedDataset(
+        X=rand(*batch_shape, num_samples, d),
+        Y=rand(*batch_shape, num_samples, m),
+        Yvar=rand(*batch_shape, num_samples, m) if has_yvar else None,
+        feature_names=feature_names,
+        outcome_names=outcome_names,
+    )
+
+
 class TestDatasets(BotorchTestCase):
     def test_supervised(self):
         # Generate some data
@@ -190,3 +219,119 @@ def test_ranking(self):
                 feature_names=feature_names,
                 outcome_names=outcome_names,
             )
+
+    def test_multi_task(self):
+        dataset_1 = make_dataset(outcome_names=["y"])
+        dataset_2 = make_dataset(outcome_names=["z"])
+        dataset_3 = make_dataset(has_yvar=True, outcome_names=["z"])
+        dataset_4 = make_dataset(has_yvar=True, outcome_names=["y"])
+        # Test validation.
+        with self.assertRaisesRegex(
+            UnsupportedError, "containing more than one outcome"
+        ):
+            MultiTaskDataset(datasets=[make_dataset(m=2)], target_outcome_name="y0")
+        with self.assertRaisesRegex(
+            UnsupportedError, "multiple datasets for the same outcome"
+        ):
+            MultiTaskDataset(datasets=[dataset_1, dataset_1], target_outcome_name="y")
+        with self.assertRaisesRegex(InputDataError, "Target outcome is not present"):
+            MultiTaskDataset(datasets=[dataset_1], target_outcome_name="z")
+        with self.assertRaisesRegex(UnsupportedError, "modeling batched inputs"):
+            MultiTaskDataset(
+                datasets=[make_dataset(batch_shape=torch.Size([2]))],
+                target_outcome_name="y0",
+            )
+        with self.assertRaisesRegex(InputDataError, "names of the task features"):
+            MultiTaskDataset(
+                datasets=[
+                    dataset_1,
+                    make_dataset(feature_names=["x1", "x3"], outcome_names=["z"]),
+                ],
+                target_outcome_name="z",
+                task_feature_index=1,
+            )
+        with self.assertRaisesRegex(
+            UnsupportedError, "all or none of the datasets to have a Yvar."
+        ):
+            MultiTaskDataset(datasets=[dataset_1, dataset_3], target_outcome_name="z")
+
+        # Test correct construction.
+        mt_dataset = MultiTaskDataset(
+            datasets=[dataset_1, dataset_2],
+            target_outcome_name="z",
+        )
+        self.assertEqual(len(mt_dataset.datasets), 2)
+        self.assertIsNone(mt_dataset.task_feature_index)
+        self.assertIs(mt_dataset.datasets["y"], dataset_1)
+        self.assertIs(mt_dataset.datasets["z"], dataset_2)
+        self.assertIsNone(mt_dataset.Yvar)
+        expected_X = torch.cat(
+            [
+                torch.cat([dataset_1.X, torch.ones(3, 1)], dim=-1),
+                torch.cat([dataset_2.X, torch.zeros(3, 1)], dim=-1),
+            ],
+            dim=0,
+        )
+        expected_Y = torch.cat([ds.Y for ds in [dataset_1, dataset_2]], dim=0)
+        self.assertTrue(torch.equal(expected_X, mt_dataset.X))
+        self.assertTrue(torch.equal(expected_Y, mt_dataset.Y))
+        self.assertIs(
+            mt_dataset.get_dataset_without_task_feature(outcome_name="y"), dataset_1
+        )
+
+        # Test with Yvar and task_feature_index.
+        mt_dataset = MultiTaskDataset(
+            datasets=[dataset_3, dataset_4],
+            target_outcome_name="z",
+            task_feature_index=1,
+        )
+        self.assertEqual(mt_dataset.task_feature_index, 1)
+        expected_X_2 = torch.cat([dataset_3.X, dataset_4.X], dim=0)
+        expected_Yvar_2 = torch.cat([dataset_3.Yvar, dataset_4.Yvar], dim=0)
+        self.assertTrue(torch.equal(expected_X_2, mt_dataset.X))
+        self.assertTrue(torch.equal(expected_Yvar_2, mt_dataset.Yvar))
+        # Check that the task feature is removed correctly.
+        ds_3_no_task = mt_dataset.get_dataset_without_task_feature(outcome_name="z")
+        self.assertTrue(torch.equal(ds_3_no_task.X, dataset_3.X[:, :1]))
+        self.assertTrue(torch.equal(ds_3_no_task.Y, dataset_3.Y))
+        self.assertTrue(torch.equal(ds_3_no_task.Yvar, dataset_3.Yvar))
+        self.assertEqual(ds_3_no_task.feature_names, dataset_3.feature_names[:1])
+        self.assertEqual(ds_3_no_task.outcome_names, dataset_3.outcome_names)
+
+        # Test from_joint_dataset.
+        sort_idcs = [3, 4, 5, 0, 1, 2]  # X & Y will get sorted based on task feature.
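+        # (The target task "z" has task value 0 and occupies rows 3-5 of
+        # expected_X; `from_joint_dataset` splits on sorted task values, so
+        # those rows come first in the recombined dataset.)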
+        for outcome_names_per_task in [None, {0: "x", 1: "y"}]:
+            joint_dataset = SupervisedDataset(
+                X=expected_X,
+                Y=expected_Y,
+                feature_names=["x0", "x1", "task"],
+                outcome_names=["z"],
+            )
+            mt_dataset = MultiTaskDataset.from_joint_dataset(
+                dataset=joint_dataset,
+                task_feature_index=-1,
+                target_task_value=0,
+                outcome_names_per_task=outcome_names_per_task,
+            )
+            self.assertEqual(len(mt_dataset.datasets), 2)
+            if outcome_names_per_task is None:
+                self.assertEqual(list(mt_dataset.datasets.keys()), ["z", "task_1"])
+                self.assertEqual(mt_dataset.target_outcome_name, "z")
+            else:
+                self.assertEqual(list(mt_dataset.datasets.keys()), ["x", "y"])
+                self.assertEqual(mt_dataset.target_outcome_name, "x")
+
+            self.assertTrue(torch.equal(mt_dataset.X, expected_X[sort_idcs]))
+            self.assertTrue(torch.equal(mt_dataset.Y, expected_Y[sort_idcs]))
+            self.assertTrue(
+                torch.equal(
+                    mt_dataset.datasets[mt_dataset.target_outcome_name].Y, dataset_2.Y
+                )
+            )
+            self.assertIsNone(mt_dataset.Yvar)
+        with self.assertRaisesRegex(UnsupportedError, "more than one outcome"):
+            MultiTaskDataset.from_joint_dataset(
+                dataset=make_dataset(m=2),
+                task_feature_index=-1,
+                target_task_value=0,
+            )
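
For context, here is a rough usage sketch of how the new dataset class is meant to feed a `MultiTaskGP`. This is not part of the patch; the dataset sizes and outcome names are made up, and the model call follows the public `MultiTaskGP` signature:

```python
import torch
from botorch.models import MultiTaskGP
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

# Two single-outcome datasets: the target task and one auxiliary task.
target = SupervisedDataset(
    X=torch.rand(10, 2),
    Y=torch.rand(10, 1),
    feature_names=["x0", "x1"],
    outcome_names=["target"],
)
auxiliary = SupervisedDataset(
    X=torch.rand(20, 2),
    Y=torch.rand(20, 1),
    feature_names=["x0", "x1"],
    outcome_names=["aux"],
)
mtd = MultiTaskDataset(datasets=[target, auxiliary], target_outcome_name="target")

# mtd.X appends the task feature as the last column (0 = target task, 1 = aux),
# matching the layout MultiTaskGP expects via its `task_feature` argument.
model = MultiTaskGP(
    train_X=mtd.X,
    train_Y=mtd.Y,
    task_feature=2,  # index of the appended task column
    output_tasks=[0],  # only predict the target task
)
```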