Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 85 additions & 6 deletions botorch/utils/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,93 @@
Containers to standardize inputs into models and acquisition functions.
"""

from typing import NamedTuple, Optional
from dataclasses import dataclass
from typing import List, Optional

import torch
from botorch.exceptions.errors import UnsupportedError
from torch import Tensor


class TrainingData(NamedTuple):
r"""Standardized struct of model training data for a single outcome."""
@dataclass
class TrainingData:
r"""Standardized container of model training data for models.

X: Tensor
Y: Tensor
Yvar: Optional[Tensor] = None
Properties:
Xs: A list of tensors, each of shape `batch_shape x n_i x d`,
where `n_i` is the number of training inputs for the i-th model.
Ys: A list of tensors, each of shape `batch_shape x n_i x 1`,
where `n_i` is the number of training observations for the i-th
(single-output) model.
Yvars: A list of tensors, each of shape `batch_shape x n_i x 1`,
where `n_i` is the number of training observations of the
observation noise for the i-th (single-output) model.
If `None`, the observation noise level is unobserved.
"""

# One entry per (single-output) model; shapes follow the class docstring.
Xs: List[Tensor] # training inputs, each `batch_shape x n_i x d`
Ys: List[Tensor] # training observations, each `batch_shape x n_i x 1`
Yvars: Optional[List[Tensor]] = None # observed noise, each `batch_shape x n_i x 1`; `None` if unobserved

def __post_init__(self):
self._is_block_design = all(torch.equal(X, self.Xs[0]) for X in self.Xs[1:])

@classmethod
def from_block_design(cls, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None):
r"""Construct a TrainingData object from a block design description.

Args:
X: A `batch_shape x n x d` tensor of training points (shared across
all outcomes).
Y: A `batch_shape x n x m` tensor of training observations.
Yvar: A `batch_shape x n x m` tensor of training noise variance
observations, or `None`.

Returns:
The `TrainingData` object (with `is_block_design=True`).
"""
return cls(
Xs=[X for _ in range(Y.shape[-1])],
Ys=list(torch.split(Y, 1, dim=-1)),
Yvars=None if Yvar is None else list(torch.split(Yvar, 1, dim=-1)),
)

@property
def is_block_design(self) -> bool:
r"""Indicates whether training data is a "block design".

Block designs are designs in which all outcomes are observed
at the same training inputs.
"""
return self._is_block_design

@property
def X(self) -> Tensor:
r"""The training inputs (block-design only).

This raises an `UnsupportedError` in the non-block-design case.
"""
if not self.is_block_design:
raise UnsupportedError
return self.Xs[0]

@property
def Y(self) -> Tensor:
r"""The training observations (block-design only).

This raises an `UnsupportedError` in the non-block-design case.
"""
if not self.is_block_design:
raise UnsupportedError
return torch.cat(self.Ys, dim=-1)

@property
def Yvar(self) -> Optional[List[Tensor]]:
r"""The training observations's noise variance (block-design only).

This raises an `UnsupportedError` in the non-block-design case.
"""
if self.Yvars is not None:
if not self.is_block_design:
raise UnsupportedError
return torch.cat(self.Yvars, dim=-1)
6 changes: 3 additions & 3 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_construct_inputs(self):
model, model_kwargs = self._get_model_and_data(
batch_shape=batch_shape, m=2, **tkwargs
)
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
)
data_dict = model.construct_inputs(training_data)
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_construct_inputs(self):
model, model_kwargs = self._get_model_and_data(
batch_shape=batch_shape, m=2, **tkwargs
)
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"],
Y=model_kwargs["train_Y"],
Yvar=model_kwargs["train_Yvar"],
Expand All @@ -365,7 +365,7 @@ def test_construct_inputs(self):
torch.equal(data_dict["train_Yvar"], model_kwargs["train_Yvar"])
)
# if Yvars is missing, then raise error
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
)
with self.assertRaises(ValueError):
Expand Down
6 changes: 3 additions & 3 deletions test/models/test_gp_regression_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def test_construct_inputs(self):
**tkwargs,
)
# len(Xs) == len(Ys) == 1
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"],
Y=model_kwargs["train_Y"],
Yvar=torch.full_like(model_kwargs["train_Y"], 0.01),
Expand Down Expand Up @@ -483,14 +483,14 @@ def test_construct_inputs(self):
lin_truncated=lin_trunc,
**tkwargs,
)
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
)
# missing Yvars
with self.assertRaises(ValueError):
model.construct_inputs(training_data, fidelity_features=[1])
# len(Xs) == len(Ys) == 1
training_data = TrainingData(
training_data = TrainingData.from_block_design(
X=model_kwargs["train_X"],
Y=model_kwargs["train_Y"],
Yvar=torch.full_like(model_kwargs["train_Y"], 0.01),
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_gp_regression_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_construct_inputs(self):
batch_shape=batch_shape, m=m, d=d, **tkwargs
)
cat_dims = list(range(ncat))
training_data = TrainingData(X=train_X, Y=train_Y)
training_data = TrainingData.from_block_design(X=train_X, Y=train_Y)
kwarg_dict = MixedSingleTaskGP.construct_inputs(
training_data, categorical_features=cat_dims
)
Expand Down
8 changes: 5 additions & 3 deletions test/models/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test_MultiTaskGP_construct_inputs(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
model, train_X, train_Y = _get_model_and_training_data(**tkwargs)
training_data = TrainingData(X=train_X, Y=train_Y)
training_data = TrainingData.from_block_design(X=train_X, Y=train_Y)
# Test that task features are required.
with self.assertRaisesRegex(ValueError, "`task_features` required"):
model.construct_inputs(training_data)
Expand Down Expand Up @@ -444,11 +444,13 @@ def test_FixedNoiseMultiTaskGP_construct_inputs(self):
train_Y,
train_Yvar,
) = _get_fixed_noise_model_and_training_data(**tkwargs)
td_no_Yvar = TrainingData(X=train_X, Y=train_Y)
td_no_Yvar = TrainingData.from_block_design(X=train_X, Y=train_Y)
# Test that Yvar is required.
with self.assertRaisesRegex(ValueError, "Yvar required"):
model.construct_inputs(td_no_Yvar)
training_data = TrainingData(X=train_X, Y=train_Y, Yvar=train_Yvar)
training_data = TrainingData.from_block_design(
X=train_X, Y=train_Y, Yvar=train_Yvar
)
# Test that task features are required.
with self.assertRaisesRegex(ValueError, "`task_features` required"):
model.construct_inputs(training_data)
Expand Down
96 changes: 86 additions & 10 deletions test/utils/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,98 @@
# LICENSE file in the root directory of this source tree.

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.utils.containers import TrainingData
from botorch.utils.testing import BotorchTestCase


class TestConstructContainers(BotorchTestCase):
class TestContainers(BotorchTestCase):
def test_TrainingData(self):
X = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
Y = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
Yvar = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])

training_data = TrainingData(X, Y)
# block design, without variance observations
X_bd = torch.rand(2, 4, 3)
Y_bd = torch.rand(2, 4, 2)
training_data = TrainingData.from_block_design(X_bd, Y_bd)
self.assertTrue(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.X, X_bd))
self.assertTrue(torch.equal(training_data.Y, Y_bd))
self.assertIsNone(training_data.Yvar)
# `all(...)` is required here: a bare generator object is always truthy,
# so `assertTrue(<generator>)` would pass without checking anything.
self.assertTrue(all(torch.equal(Xi, X_bd) for Xi in training_data.Xs))
self.assertTrue(torch.equal(training_data.Ys[0], Y_bd[..., :1]))
self.assertTrue(torch.equal(training_data.Ys[1], Y_bd[..., 1:]))
self.assertIsNone(training_data.Yvars)

# block design, with variance observations
Yvar_bd = torch.rand(2, 4, 2)
training_data = TrainingData.from_block_design(X_bd, Y_bd, Yvar_bd)
self.assertTrue(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.X, X_bd))
self.assertTrue(torch.equal(training_data.Y, Y_bd))
self.assertTrue(torch.equal(training_data.Yvar, Yvar_bd))
# `all(...)` is required here: a bare generator object is always truthy,
# so `assertTrue(<generator>)` would pass without checking anything.
self.assertTrue(all(torch.equal(Xi, X_bd) for Xi in training_data.Xs))
self.assertTrue(torch.equal(training_data.Ys[0], Y_bd[..., :1]))
self.assertTrue(torch.equal(training_data.Ys[1], Y_bd[..., 1:]))
self.assertTrue(torch.equal(training_data.Yvars[0], Yvar_bd[..., :1]))
self.assertTrue(torch.equal(training_data.Yvars[1], Yvar_bd[..., 1:]))

# non-block design, without variance observations
# The two input tensors differ in their second-to-last dimension (4 vs. 3),
# so they cannot be equal and the data is not a block design.
Xs = [torch.rand(2, 4, 3), torch.rand(2, 3, 3)]
Ys = [torch.rand(2, 4, 2), torch.rand(2, 3, 2)]
training_data = TrainingData(Xs, Ys)
self.assertFalse(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.Xs[0], Xs[0]))
self.assertTrue(torch.equal(training_data.Xs[1], Xs[1]))
self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
self.assertIsNone(training_data.Yvars)
# `X` and `Y` are block-design-only accessors and must raise here.
with self.assertRaises(UnsupportedError):
training_data.X
with self.assertRaises(UnsupportedError):
training_data.Y
# `Yvar` returns `None` rather than raising when `Yvars` is `None`.
self.assertIsNone(training_data.Yvar)

# non-block design, with variance observations
# Reuses the mismatched `Xs` / `Ys` defined in the previous section.
Yvars = [torch.rand(2, 4, 2), torch.rand(2, 3, 2)]
training_data = TrainingData(Xs, Ys, Yvars)
self.assertFalse(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.Xs[0], Xs[0]))
self.assertTrue(torch.equal(training_data.Xs[1], Xs[1]))
self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
self.assertTrue(torch.equal(training_data.Yvars[0], Yvars[0]))
self.assertTrue(torch.equal(training_data.Yvars[1], Yvars[1]))
# All three block-design-only accessors must raise.
with self.assertRaises(UnsupportedError):
training_data.X
with self.assertRaises(UnsupportedError):
training_data.Y
# With `Yvars` present, `Yvar` now raises as well.
with self.assertRaises(UnsupportedError):
training_data.Yvar

# implicit block design, without variance observations
X = torch.rand(2, 4, 3)
Xs = [X] * 2
Ys = [torch.rand(2, 4, 2), torch.rand(2, 4, 2)]
training_data = TrainingData(Xs, Ys)
self.assertTrue(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.X, X))
self.assertTrue(torch.equal(training_data.Y, Y))
self.assertEqual(training_data.Yvar, None)
self.assertTrue(torch.equal(training_data.Y, torch.cat(Ys, dim=-1)))
self.assertIsNone(training_data.Yvar)
self.assertTrue(torch.equal(training_data.Xs[0], X))
self.assertTrue(torch.equal(training_data.Xs[1], X))
self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
self.assertIsNone(training_data.Yvars)

training_data = TrainingData(X, Y, Yvar)
# implicit block design, with variance observations
Yvars = [torch.rand(2, 4, 2), torch.rand(2, 4, 2)]
training_data = TrainingData(Xs, Ys, Yvars)
self.assertTrue(training_data.is_block_design)
self.assertTrue(torch.equal(training_data.X, X))
self.assertTrue(torch.equal(training_data.Y, Y))
self.assertTrue(torch.equal(training_data.Yvar, Yvar))
self.assertTrue(torch.equal(training_data.Y, torch.cat(Ys, dim=-1)))
self.assertTrue(torch.equal(training_data.Yvar, torch.cat(Yvars, dim=-1)))
self.assertTrue(torch.equal(training_data.Xs[0], X))
self.assertTrue(torch.equal(training_data.Xs[1], X))
self.assertTrue(torch.equal(training_data.Ys[0], Ys[0]))
self.assertTrue(torch.equal(training_data.Ys[1], Ys[1]))
self.assertTrue(torch.equal(training_data.Yvars[0], Yvars[0]))
self.assertTrue(torch.equal(training_data.Yvars[1], Yvars[1]))