From db7d6b733b347e2060987ae2d3d5ba9571778960 Mon Sep 17 00:00:00 2001 From: Eric Lou Date: Wed, 8 Jul 2020 08:20:39 -0700 Subject: [PATCH] Model Input Standardization Using `TrainingData` (#477) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/477 Different GP models take different kwargs as inputs into their constructors. To standardize the inputs, we create a `TrainingData` container (a `NamedTuple`) in conjunction with a classmethod `construct_inputs()`. Reviewed By: Balandat Differential Revision: D22395030 fbshipit-source-id: 38e6283a2b86fdfc69060eaa579d5b1d152c7475 --- botorch/models/gp_regression.py | 19 ++++++++++++++- botorch/models/model.py | 10 +++++++- botorch/utils/containers.py | 21 +++++++++++++++++ sphinx/source/utils.rst | 5 ++++ test/models/test_gp_regression.py | 39 +++++++++++++++++++++++++++++++ test/models/test_model.py | 2 ++ test/utils/test_containers.py | 26 +++++++++++++++++++++ 7 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 botorch/utils/containers.py create mode 100644 test/utils/test_containers.py diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index 4d2adc7297..1981efc5e7 100644 --- a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from botorch import settings @@ -18,6 +18,7 @@ from botorch.models.transforms.outcome import Log, OutcomeTransform from botorch.models.utils import validate_input_scaling from botorch.sampling.samplers import MCSampler +from botorch.utils.containers import TrainingData from gpytorch.constraints.constraints import GreaterThan from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels.matern_kernel import MaternKernel @@ -136,6 +137,11 @@ def forward(self, x: Tensor) -> MultivariateNormal: covar_x = self.covar_module(x) return
MultivariateNormal(mean_x, covar_x) + @classmethod + def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]: + r"""Build `train_X`/`train_Y` constructor kwargs from `training_data`.""" + return {"train_X": training_data.Xs[0], "train_Y": training_data.Ys[0]} + class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP): r"""A single-task exact GP model using fixed noise levels. @@ -276,6 +282,17 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel: new_model.likelihood.noise_covar.noise = new_noise return new_model + @classmethod + def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]: + r"""Build constructor kwargs from `training_data`; requires `Yvars`.""" + if training_data.Yvars is None: + raise ValueError("Training data is missing Yvars member") + return { + "train_X": training_data.Xs[0], + "train_Y": training_data.Ys[0], + "train_Yvar": training_data.Yvars[0], + } + class HeteroskedasticSingleTaskGP(SingleTaskGP): r"""A single-task exact GP model using a heteroskeastic noise model.
diff --git a/botorch/models/model.py b/botorch/models/model.py index cfd0a7a3c5..9741dc8a19 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -11,11 +11,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from botorch import settings from botorch.posteriors import Posterior from botorch.sampling.samplers import MCSampler +from botorch.utils.containers import TrainingData from torch import Tensor from torch.nn import Module @@ -123,3 +124,10 @@ def fantasize( post_X = self.posterior(X, observation_noise=observation_noise) Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs) + + @classmethod + def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]: + r"""Build constructor kwargs; the base version raises NotImplementedError.""" + raise NotImplementedError( + f"`construct_inputs` not implemented for {cls.__name__}." + ) diff --git a/botorch/utils/containers.py b/botorch/utils/containers.py new file mode 100644 index 0000000000..28c7ce20ed --- /dev/null +++ b/botorch/utils/containers.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Containers to standardize inputs into models and acquisition functions. +""" + +from typing import List, NamedTuple, Optional + +from torch import Tensor + + +class TrainingData(NamedTuple): + r"""Standardized struct of model training data (lists of tensors).""" + + Xs: List[Tensor] + Ys: List[Tensor] + Yvars: Optional[List[Tensor]] = None diff --git a/sphinx/source/utils.rst b/sphinx/source/utils.rst index 5be2f18dea..37038110d0 100644 --- a/sphinx/source/utils.rst +++ b/sphinx/source/utils.rst @@ -12,6 +12,11 @@ Constraints ..
automodule:: botorch.utils.constraints :members: +Containers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.containers + :members: + Objective ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.utils.objective diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 8dfe40b0fc..2a3ced49b3 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -19,6 +19,7 @@ from botorch.models.utils import add_output_dim from botorch.posteriors import GPyTorchPosterior from botorch.sampling import SobolQMCNormalSampler +from botorch.utils.containers import TrainingData from botorch.utils.sampling import manual_seed from botorch.utils.testing import BotorchTestCase, _get_random_data from gpytorch.kernels import MaternKernel, ScaleKernel @@ -271,6 +272,22 @@ def test_subset_model(self): ) ) + def test_construct_inputs(self): + for batch_shape, dtype in itertools.product( + (torch.Size(), torch.Size([2])), (torch.float, torch.double) + ): + tkwargs = {"device": self.device, "dtype": dtype} + model, model_kwargs = self._get_model_and_data( + batch_shape=batch_shape, m=2, **tkwargs + ) + training_data = TrainingData( + Xs=[model_kwargs["train_X"]], + Ys=[model_kwargs["train_Y"]], + Yvars=[torch.full_like(model_kwargs["train_Y"], 0.01)], + ) + data_dict = model.construct_inputs(training_data) + self.assertNotIn("train_Yvar", data_dict) + class TestFixedNoiseGP(TestSingleTaskGP): def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs): @@ -303,6 +320,28 @@ def test_fixed_noise_likelihood(self): ) ) + def test_construct_inputs(self): + for batch_shape, dtype in itertools.product( + (torch.Size(), torch.Size([2])), (torch.float, torch.double) + ): + tkwargs = {"device": self.device, "dtype": dtype} + model, model_kwargs = self._get_model_and_data( + batch_shape=batch_shape, m=2, **tkwargs + ) + training_data = TrainingData( + Xs=[model_kwargs["train_X"]],
+ Ys=[model_kwargs["train_Y"]], + Yvars=[model_kwargs["train_Yvar"]], + ) + data_dict = model.construct_inputs(training_data) + self.assertIn("train_Yvar", data_dict) + # if Yvars is missing, then raise error + training_data = TrainingData( + Xs=[model_kwargs["train_X"]], Ys=[model_kwargs["train_Y"]] + ) + with self.assertRaises(ValueError): + model.construct_inputs(training_data) + class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP): def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs): diff --git a/test/models/test_model.py b/test/models/test_model.py index ce0366fdf3..a32bccdec4 100644 --- a/test/models/test_model.py +++ b/test/models/test_model.py @@ -26,3 +26,5 @@ def test_not_so_abstract_base_model(self): model.num_outputs with self.assertRaises(NotImplementedError): model.subset_output([0]) + with self.assertRaises(NotImplementedError): + model.construct_inputs(None) diff --git a/test/utils/test_containers.py b/test/utils/test_containers.py new file mode 100644 index 0000000000..ba527a2468 --- /dev/null +++ b/test/utils/test_containers.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import torch +from botorch.utils.containers import TrainingData +from botorch.utils.testing import BotorchTestCase + + +class TestConstructContainers(BotorchTestCase): + def test_TrainingData(self): + Xs = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])] + Ys = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])] + Yvars = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])] + + training_data = TrainingData(Xs, Ys) + self.assertTrue(torch.equal(training_data.Xs[0], Xs[0])) + self.assertTrue(torch.equal(training_data.Ys[0], Ys[0])) + self.assertIsNone(training_data.Yvars) + + training_data = TrainingData(Xs, Ys, Yvars) + self.assertTrue(torch.equal(training_data.Xs[0], Xs[0])) + self.assertTrue(torch.equal(training_data.Ys[0], Ys[0])) + self.assertTrue(torch.equal(training_data.Yvars[0], Yvars[0]))