diff --git a/botorch/utils/containers.py b/botorch/utils/containers.py
index c464ee3f99..fbd73694d7 100644
--- a/botorch/utils/containers.py
+++ b/botorch/utils/containers.py
@@ -8,14 +8,93 @@
 Containers to standardize inputs into models and acquisition functions.
 """
 
-from typing import NamedTuple, Optional
+from dataclasses import dataclass
+from typing import List, Optional
 
+import torch
+from botorch.exceptions.errors import UnsupportedError
 from torch import Tensor
 
 
-class TrainingData(NamedTuple):
-    r"""Standardized struct of model training data for a single outcome."""
+@dataclass
+class TrainingData:
+    r"""Standardized container of training data for one or more (single-output) models.
 
-    X: Tensor
-    Y: Tensor
-    Yvar: Optional[Tensor] = None
+    Properties:
+        Xs: A list of tensors, each of shape `batch_shape x n_i x d`,
+            where `n_i` is the number of training inputs for the i-th model.
+        Ys: A list of tensors, each of shape `batch_shape x n_i x 1`,
+            where `n_i` is the number of training observations for the i-th
+            (single-output) model.
+        Yvars: A list of tensors, each of shape `batch_shape x n_i x 1`,
+            where `n_i` is the number of training observations of the
+            observation noise for the i-th (single-output) model.
+            If `None`, the observation noise level is unobserved.
+    """
+
+    Xs: List[Tensor]  # each of shape `batch_shape x n_i x d`
+    Ys: List[Tensor]  # each of shape `batch_shape x n_i x 1`
+    Yvars: Optional[List[Tensor]] = None  # each of shape `batch_shape x n_i x 1`
+
+    def __post_init__(self):
+        self._is_block_design = all(torch.equal(X, self.Xs[0]) for X in self.Xs[1:])
+
+    @classmethod
+    def from_block_design(cls, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None):
+        r"""Construct a TrainingData object from a block design description.
+
+        Args:
+            X: A `batch_shape x n x d` tensor of training points (shared across
+                all outcomes).
+            Y: A `batch_shape x n x m` tensor of training observations.
+            Yvar: A `batch_shape x n x m` tensor of training noise variance
+                observations, or `None`.
+
+        Returns:
+            The `TrainingData` object (with `is_block_design=True`).
+        """
+        return cls(
+            Xs=[X for _ in range(Y.shape[-1])],
+            Ys=list(torch.split(Y, 1, dim=-1)),
+            Yvars=None if Yvar is None else list(torch.split(Yvar, 1, dim=-1)),
+        )
+
+    @property
+    def is_block_design(self) -> bool:
+        r"""Indicates whether the training data is a "block design".
+
+        Block designs are designs in which all outcomes are observed
+        at the same training inputs.
+        """
+        return self._is_block_design
+
+    @property
+    def X(self) -> Tensor:
+        r"""The training inputs (block-design only).
+
+        This raises an `UnsupportedError` in the non-block-design case.
+        """
+        if not self.is_block_design:
+            raise UnsupportedError
+        return self.Xs[0]
+
+    @property
+    def Y(self) -> Tensor:
+        r"""The training observations (block-design only).
+
+        This raises an `UnsupportedError` in the non-block-design case.
+        """
+        if not self.is_block_design:
+            raise UnsupportedError
+        return torch.cat(self.Ys, dim=-1)
+
+    @property
+    def Yvar(self) -> Optional[Tensor]:
+        r"""The noise variance of the training observations (block-design only).
+
+        This raises an `UnsupportedError` in the non-block-design case.
+        """
+        if self.Yvars is not None:
+            if not self.is_block_design:
+                raise UnsupportedError
+            return torch.cat(self.Yvars, dim=-1)
diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py
index ccd9c8ff09..7503f992d3 100644
--- a/test/models/test_gp_regression.py
+++ b/test/models/test_gp_regression.py
@@ -305,7 +305,7 @@ def test_construct_inputs(self):
             model, model_kwargs = self._get_model_and_data(
                 batch_shape=batch_shape, m=2, **tkwargs
             )
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
             )
             data_dict = model.construct_inputs(training_data)
@@ -352,7 +352,7 @@ def test_construct_inputs(self):
             model, model_kwargs = self._get_model_and_data(
                 batch_shape=batch_shape, m=2, **tkwargs
             )
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"],
                 Y=model_kwargs["train_Y"],
                 Yvar=model_kwargs["train_Yvar"],
@@ -365,7 +365,7 @@ def test_construct_inputs(self):
                 torch.equal(data_dict["train_Yvar"], model_kwargs["train_Yvar"])
             )
             # if Yvars is missing, then raise error
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
             )
             with self.assertRaises(ValueError):
diff --git a/test/models/test_gp_regression_fidelity.py b/test/models/test_gp_regression_fidelity.py
index c6a0eb46c9..c485769e10 100644
--- a/test/models/test_gp_regression_fidelity.py
+++ b/test/models/test_gp_regression_fidelity.py
@@ -378,7 +378,7 @@ def test_construct_inputs(self):
                 **tkwargs,
             )
             # len(Xs) == len(Ys) == 1
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"],
                 Y=model_kwargs["train_Y"],
                 Yvar=torch.full_like(model_kwargs["train_Y"], 0.01),
@@ -483,14 +483,14 @@ def test_construct_inputs(self):
                 lin_truncated=lin_trunc,
                 **tkwargs,
             )
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
             )
             # missing Yvars
             with self.assertRaises(ValueError):
                 model.construct_inputs(training_data, fidelity_features=[1])
             # len(Xs) == len(Ys) == 1
-            training_data = TrainingData(
+            training_data = TrainingData.from_block_design(
                 X=model_kwargs["train_X"],
                 Y=model_kwargs["train_Y"],
                 Yvar=torch.full_like(model_kwargs["train_Y"], 0.01),
diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py
index 15fc008995..beea808746 100644
--- a/test/models/test_gp_regression_mixed.py
+++ b/test/models/test_gp_regression_mixed.py
@@ -270,7 +270,7 @@ def test_construct_inputs(self):
                 batch_shape=batch_shape, m=m, d=d, **tkwargs
            )
             cat_dims = list(range(ncat))
-            training_data = TrainingData(X=train_X, Y=train_Y)
+            training_data = TrainingData.from_block_design(X=train_X, Y=train_Y)
             kwarg_dict = MixedSingleTaskGP.construct_inputs(
                 training_data, categorical_features=cat_dims
             )
diff --git a/test/models/test_multitask.py b/test/models/test_multitask.py
index f904ea893a..dafb87940c 100644
--- a/test/models/test_multitask.py
+++ b/test/models/test_multitask.py
@@ -396,7 +396,7 @@ def test_MultiTaskGP_construct_inputs(self):
         for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
             model, train_X, train_Y = _get_model_and_training_data(**tkwargs)
-            training_data = TrainingData(X=train_X, Y=train_Y)
+            training_data = TrainingData.from_block_design(X=train_X, Y=train_Y)
             # Test that task features are required.
             with self.assertRaisesRegex(ValueError, "`task_features` required"):
                 model.construct_inputs(training_data)
@@ -444,11 +444,13 @@ def test_FixedNoiseMultiTaskGP_construct_inputs(self):
                 train_Y,
                 train_Yvar,
             ) = _get_fixed_noise_model_and_training_data(**tkwargs)
-            td_no_Yvar = TrainingData(X=train_X, Y=train_Y)
+            td_no_Yvar = TrainingData.from_block_design(X=train_X, Y=train_Y)
             # Test that Yvar is required.
             with self.assertRaisesRegex(ValueError, "Yvar required"):
                 model.construct_inputs(td_no_Yvar)
-            training_data = TrainingData(X=train_X, Y=train_Y, Yvar=train_Yvar)
+            training_data = TrainingData.from_block_design(
+                X=train_X, Y=train_Y, Yvar=train_Yvar
+            )
             # Test that task features are required.
             with self.assertRaisesRegex(ValueError, "`task_features` required"):
                 model.construct_inputs(training_data)
diff --git a/test/utils/test_containers.py b/test/utils/test_containers.py
index e92d34005c..7fa14ea3a3 100644
--- a/test/utils/test_containers.py
+++ b/test/utils/test_containers.py
@@ -5,22 +5,98 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
+from botorch.exceptions.errors import UnsupportedError
 from botorch.utils.containers import TrainingData
 from botorch.utils.testing import BotorchTestCase
 
 
-class TestConstructContainers(BotorchTestCase):
+class TestContainers(BotorchTestCase):
     def test_TrainingData(self):
-        X = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
-        Y = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
-        Yvar = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
-        training_data = TrainingData(X, Y)
+        # block design, without variance observations
+        X_bd = torch.rand(2, 4, 3)
+        Y_bd = torch.rand(2, 4, 2)
+        training_data = TrainingData.from_block_design(X_bd, Y_bd)
+        self.assertTrue(training_data.is_block_design)
+        self.assertTrue(torch.equal(training_data.X, X_bd))
+        self.assertTrue(torch.equal(training_data.Y, Y_bd))
+        self.assertIsNone(training_data.Yvar)
+        self.assertTrue(all(torch.equal(Xi, X_bd) for Xi in training_data.Xs))
+        self.assertTrue(torch.equal(training_data.Ys[0], Y_bd[..., :1]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Y_bd[..., 1:]))
+        self.assertIsNone(training_data.Yvars)
+
+        # block design, with variance observations
+        Yvar_bd = torch.rand(2, 4, 2)
+        training_data = TrainingData.from_block_design(X_bd, Y_bd, Yvar_bd)
+        self.assertTrue(training_data.is_block_design)
+        self.assertTrue(torch.equal(training_data.X, X_bd))
+        self.assertTrue(torch.equal(training_data.Y, Y_bd))
+        self.assertTrue(torch.equal(training_data.Yvar, Yvar_bd))
+        self.assertTrue(all(torch.equal(Xi, X_bd) for Xi in training_data.Xs))
+        self.assertTrue(torch.equal(training_data.Ys[0], Y_bd[..., :1]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Y_bd[..., 1:]))
+        self.assertTrue(torch.equal(training_data.Yvars[0], Yvar_bd[..., :1]))
+        self.assertTrue(torch.equal(training_data.Yvars[1], Yvar_bd[..., 1:]))
+
+        # non-block design, without variance observations
+        Xs = [torch.rand(2, 4, 3), torch.rand(2, 3, 3)]
+        Ys = [torch.rand(2, 4, 2), torch.rand(2, 3, 2)]
+        training_data = TrainingData(Xs, Ys)
+        self.assertFalse(training_data.is_block_design)
+        self.assertTrue(torch.equal(training_data.Xs[0], Xs[0]))
+        self.assertTrue(torch.equal(training_data.Xs[1], Xs[1]))
+        self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
+        self.assertIsNone(training_data.Yvars)
+        with self.assertRaises(UnsupportedError):
+            training_data.X
+        with self.assertRaises(UnsupportedError):
+            training_data.Y
+        self.assertIsNone(training_data.Yvar)
+
+        # non-block design, with variance observations
+        Yvars = [torch.rand(2, 4, 2), torch.rand(2, 3, 2)]
+        training_data = TrainingData(Xs, Ys, Yvars)
+        self.assertFalse(training_data.is_block_design)
+        self.assertTrue(torch.equal(training_data.Xs[0], Xs[0]))
+        self.assertTrue(torch.equal(training_data.Xs[1], Xs[1]))
+        self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
+        self.assertTrue(torch.equal(training_data.Yvars[0], Yvars[0]))
+        self.assertTrue(torch.equal(training_data.Yvars[1], Yvars[1]))
+        with self.assertRaises(UnsupportedError):
+            training_data.X
+        with self.assertRaises(UnsupportedError):
+            training_data.Y
+        with self.assertRaises(UnsupportedError):
+            training_data.Yvar
+
+        # implicit block design, without variance observations
+        X = torch.rand(2, 4, 3)
+        Xs = [X] * 2
+        Ys = [torch.rand(2, 4, 2), torch.rand(2, 4, 2)]
+        training_data = TrainingData(Xs, Ys)
+        self.assertTrue(training_data.is_block_design)
         self.assertTrue(torch.equal(training_data.X, X))
-        self.assertTrue(torch.equal(training_data.Y, Y))
-        self.assertEqual(training_data.Yvar, None)
+        self.assertTrue(torch.equal(training_data.Y, torch.cat(Ys, dim=-1)))
+        self.assertIsNone(training_data.Yvar)
+        self.assertTrue(torch.equal(training_data.Xs[0], X))
+        self.assertTrue(torch.equal(training_data.Xs[1], X))
+        self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
+        self.assertIsNone(training_data.Yvars)
 
-        training_data = TrainingData(X, Y, Yvar)
+        # implicit block design, with variance observations
+        Yvars = [torch.rand(2, 4, 2), torch.rand(2, 4, 2)]
+        training_data = TrainingData(Xs, Ys, Yvars)
+        self.assertTrue(training_data.is_block_design)
         self.assertTrue(torch.equal(training_data.X, X))
-        self.assertTrue(torch.equal(training_data.Y, Y))
-        self.assertTrue(torch.equal(training_data.Yvar, Yvar))
+        self.assertTrue(torch.equal(training_data.Y, torch.cat(Ys, dim=-1)))
+        self.assertTrue(torch.equal(training_data.Yvar, torch.cat(Yvars, dim=-1)))
+        self.assertTrue(torch.equal(training_data.Xs[0], X))
+        self.assertTrue(torch.equal(training_data.Xs[1], X))
+        self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
+        self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
+        self.assertTrue(torch.equal(training_data.Yvars[0], Yvars[0]))
+        self.assertTrue(torch.equal(training_data.Yvars[1], Yvars[1]))
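For reference, a short usage sketch of the reworked TrainingData API (illustrative only, not part of the patch). It exercises the `from_block_design` constructor and the list-based constructor; tensor shapes follow the conventions in the class docstring.

    import torch

    from botorch.exceptions.errors import UnsupportedError
    from botorch.utils.containers import TrainingData

    # Block design: all outcomes observed at the same training inputs.
    X = torch.rand(10, 3)  # `n x d`
    Y = torch.rand(10, 2)  # `n x m`, here m = 2 outcomes
    td = TrainingData.from_block_design(X=X, Y=Y)
    assert td.is_block_design
    assert torch.equal(td.X, X) and torch.equal(td.Y, Y)
    assert len(td.Xs) == len(td.Ys) == 2  # one (X, Y) pair per outcome
    assert td.Yvar is None  # no noise variance observations supplied

    # Non-block design: per-outcome inputs of different sizes; the aggregate
    # `X` / `Y` accessors are unavailable and raise an UnsupportedError.
    Xs = [torch.rand(10, 3), torch.rand(7, 3)]
    Ys = [torch.rand(10, 1), torch.rand(7, 1)]
    td = TrainingData(Xs=Xs, Ys=Ys)
    assert not td.is_block_design
    try:
        td.X
    except UnsupportedError:
        pass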