From 80d68fd06a2b7e9266916284114904b0164abdbb Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 16 Jun 2021 11:42:50 -0700 Subject: [PATCH] Implements AppendFeatures (#820) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/820 Implements AppendFeatures for built-in risk measure support in BoTorch. AppendFeatures: - Appends a given set of features to the input for joint posterior evaluation. - For a `batch x q x d`-dim input `X` and a `n_f x d_f`-dim `feature_set`, this results in a `batch x (q * n_f) x (d + d_f)`-dim tensor of appended inputs. Reviewed By: sdaulton Differential Revision: D28981162 fbshipit-source-id: f177b20e8162d7a43ae90c3206ede9215ee03dc7 --- botorch/models/transforms/input.py | 149 ++++++++++++++++++++++----- test/models/transforms/test_input.py | 35 +++++++ 2 files changed, 159 insertions(+), 25 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index e2b0f28525..5e038638d3 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -92,7 +92,7 @@ def untransform(self, X: Tensor) -> Tensor: A `batch_shape x n x d`-dim tensor of un-transformed inputs. """ raise NotImplementedError( - f"{self.__class__.__name__} does not implement the `untransform` method" + f"{self.__class__.__name__} does not implement the `untransform` method." ) def equals(self, other: InputTransform) -> bool: @@ -104,7 +104,7 @@ def equals(self, other: InputTransform) -> bool: pytorch. See https://github.com/pytorch/pytorch/issues/7733. Args: - other: Another input transform + other: Another input transform. Returns: A boolean indicating if the other transform is equivalent. @@ -141,7 +141,7 @@ def preprocess_transform(self, X: Tensor) -> Tensor: class ChainedInputTransform(InputTransform, ModuleDict): - r"""An input transform representing the chaining of individual transforms""" + r"""An input transform representing the chaining of individual transforms.""" def __init__(self, **transforms: InputTransform) -> None: r"""Chaining of input transforms. @@ -204,7 +204,7 @@ def equals(self, other: InputTransform) -> bool: r"""Check if another input transform is equivalent. Args: - other: Another input transform + other: Another input transform. Returns: A boolean indicating if the other transform is equivalent. @@ -219,7 +219,7 @@ def preprocess_transform(self, X: Tensor) -> Tensor: The main use cases for this method are 1) to preprocess training data before calling `set_train_data` and 2) preprocess `X_baseline` for noisy acquisition functions so that `X_baseline` is "preprocessed" with the - same transformations as the cached training inputs + same transformations as the cached training inputs. Args: X: A `batch_shape x n x d`-dim tensor of inputs. @@ -291,7 +291,7 @@ def equals(self, other: InputTransform) -> bool: r"""Check if another input transform is equivalent. Args: - other: Another input transform + other: Another input transform. Returns: A boolean indicating if the other transform is equivalent. @@ -328,11 +328,11 @@ def __init__( of shape `batch_shape x n x d`). If provided, perform individual normalization per batch, otherwise uses a single normalization. transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True + transforms in train() mode. Default: True. transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True + transform in eval() mode. Default: True. 
transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True + transform when called from within a `fantasize` call. Default: True. reverse: A boolean indicating whether the forward pass should untransform the inputs. """ @@ -340,7 +340,7 @@ def __init__( if bounds is not None: if bounds.size(-1) != d: raise BotorchTensorDimensionError( - "Incompatible dimensions of provided bounds" + "Dimensions of provided `bounds` are incompatible with `d`!" ) mins = bounds[..., 0:1, :] ranges = bounds[..., 1:2, :] - mins @@ -376,8 +376,8 @@ def _transform(self, X: Tensor) -> Tensor: if self.learn_bounds and self.training: if X.size(-1) != self.mins.size(-1): raise BotorchTensorDimensionError( - f"Wrong input. dimension. Received {X.size(-1)}, " - f"expected {self.mins.size(-1)}" + f"Wrong input dimension. Received {X.size(-1)}, " + f"expected {self.mins.size(-1)}." ) self.mins = X.min(dim=-2, keepdim=True)[0] self.ranges = X.max(dim=-2, keepdim=True)[0] - self.mins @@ -403,7 +403,7 @@ def equals(self, other: InputTransform) -> bool: r"""Check if another input transform is equivalent. Args: - other: Another input transform + other: Another input transform. Returns: A boolean indicating if the other transform is equivalent. @@ -465,16 +465,16 @@ def __init__( r"""Initialize transform. Args: - indices: The indices of the integer inputs + indices: The indices of the integer inputs. transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True + transforms in train() mode. Default: True. transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True + transform in eval() mode. Default: True. transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True + transform when called from within a `fantasize` call. Default: True. approximate: A boolean indicating whether approximate or exact - rounding should be used. Default: approximate - tau: The temperature parameter for approximate rounding + rounding should be used. Default: approximate. + tau: The temperature parameter for approximate rounding. """ super().__init__() self.transform_on_train = transform_on_train @@ -506,7 +506,7 @@ def equals(self, other: InputTransform) -> bool: r"""Check if another input transform is equivalent. Args: - other: Another input transform + other: Another input transform. Returns: A boolean indicating if the other transform is equivalent. @@ -532,13 +532,13 @@ def __init__( r"""Initialize transform. Args: - indices: The indices of the inputs to log transform + indices: The indices of the inputs to log transform. transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True + transforms in train() mode. Default: True. transform_on_eval: A boolean indicating whether to apply the - transform in eval() mode. Default: True + transform in eval() mode. Default: True. transform_on_fantasize: A boolean indicating whether to apply the - transform when called from within a `fantasize` call. Default: True + transform when called from within a `fantasize` call. Default: True. reverse: A boolean indicating whether the forward pass should untransform the inputs. """ @@ -615,7 +615,7 @@ def __init__( transform_on_eval: A boolean indicating whether to apply the transform in eval() mode. Default: True. 
             transform_on_fantasize: A boolean indicating whether to apply the
-                transform when called from within a `fantasize` call. Default: True
+                transform when called from within a `fantasize` call. Default: True.
             reverse: A boolean indicating whether the forward pass should
                 untransform the inputs.
             eps: A small value used to clip values to be in the interval (0, 1).
@@ -728,3 +728,102 @@ def _untransform(self, X: Tensor) -> Tensor:
             (k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
         ).clamp(0.0, 1.0)
         return X_tf
+
+
+class AppendFeatures(InputTransform, Module):
+    r"""A transform that appends the input with a given set of features.
+
+    As an example, this can be used with `RiskMeasureMCObjective` to optimize risk
+    measures as described in [Cakmak2020risk]_. A tutorial notebook implementing the
+    rhoKG acquisition function introduced in [Cakmak2020risk]_ can be found at
+    https://botorch.org/tutorials/risk_averse_bo_with_environmental_variables.
+
+    The steps for using this to obtain samples of a risk measure are as follows:
+
+    - Train a model on `(x, w)` inputs and the corresponding observations;
+
+    - Pass in an instance of `AppendFeatures` with the `feature_set` denoting the
+      samples of `W` as the `input_transform` to the trained model;
+
+    - Call `posterior(...).rsample(...)` on the model with `x` inputs only to
+      get the joint posterior samples over `(x, w)`s, where the `w`s come
+      from the `feature_set`;
+
+    - Pass these posterior samples through the `RiskMeasureMCObjective` of choice to
+      get the samples of the risk measure.
+
+    Note: The samples of the risk measure obtained this way are in general biased
+    since the `feature_set` does not fully represent the distribution of the
+    environmental variable.
+
+    Example:
+        >>> # We consider 1D `x` and 1D `w`, with `W` having a
+        >>> # uniform distribution over [0, 1]
+        >>> model = SingleTaskGP(
+        ...     train_X=torch.rand(10, 2),
+        ...     train_Y=torch.randn(10, 1),
+        ...     input_transform=AppendFeatures(feature_set=torch.rand(10, 1))
+        ... )
+        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
+        >>> fit_gpytorch_model(mll)
+        >>> test_x = torch.rand(3, 1)
+        >>> # `posterior_samples` is a `10 x 30 x 1`-dim tensor
+        >>> posterior_samples = model.posterior(test_x).rsample(torch.Size([10]))
+        >>> risk_measure = VaR(alpha=0.8, n_w=10)
+        >>> # `risk_measure_samples` is a `10 x 3`-dim tensor of samples of the
+        >>> # risk measure VaR
+        >>> risk_measure_samples = risk_measure(posterior_samples)
+    """
+
+    def __init__(
+        self,
+        feature_set: Tensor,
+        transform_on_train: bool = False,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = False,
+    ) -> None:
+        r"""Append `feature_set` to each input.
+
+        Args:
+            feature_set: An `n_f x d_f`-dim tensor denoting the features to be
+                appended to the inputs.
+            transform_on_train: A boolean indicating whether to apply the
+                transforms in train() mode. Default: False.
+            transform_on_eval: A boolean indicating whether to apply the
+                transform in eval() mode. Default: True.
+            transform_on_fantasize: A boolean indicating whether to apply the
+                transform when called from within a `fantasize` call. Default: False.
+        """
+        super().__init__()
+        if feature_set.dim() != 2:
+            raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!")
+        self.feature_set = feature_set
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the inputs by appending `feature_set` to each input.
+
+        For each `1 x d`-dim element in the input tensor, this will produce
+        an `n_f x (d + d_f)`-dim tensor with `feature_set` appended as the last `d_f`
+        dimensions. For a generic `batch_shape x q x d`-dim `X`, this translates to a
+        `batch_shape x (q * n_f) x (d + d_f)`-dim output, where the values corresponding
+        to `X[..., i, :]` are found in `output[..., i * n_f: (i + 1) * n_f, :]`.
+
+        Note: Adding the `feature_set` on the `q-batch` dimension is necessary to avoid
+        introducing additional bias by evaluating the inputs on independent GP
+        sample paths.
+
+        Args:
+            X: A `batch_shape x q x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs.
+        """
+        expanded_X = X.unsqueeze(dim=-2).expand(
+            *X.shape[:-1], self.feature_set.shape[0], -1
+        )
+        expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1)
+        appended_X = torch.cat([expanded_X, expanded_features], dim=-1)
+        return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 2e3a3332c0..54e6bad331 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -10,6 +10,7 @@
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.transforms.input import (
+    AppendFeatures,
     ChainedInputTransform,
     InputTransform,
     Warp,
@@ -568,3 +569,37 @@ def test_warp_transform(self):
         self.assertTrue((warp_tf.concentration0 == 2.0).all())
         warp_tf._set_concentration(i=1, value=3.0)
         self.assertTrue((warp_tf.concentration1 == 3.0).all())
+
+
+class TestAppendFeatures(BotorchTestCase):
+    def test_append_features(self):
+        with self.assertRaises(ValueError):
+            AppendFeatures(torch.ones(1))
+        with self.assertRaises(ValueError):
+            AppendFeatures(torch.ones(3, 4, 2))
+
+        for dtype in (torch.float, torch.double):
+            feature_set = (
+                torch.linspace(0, 1, 6).view(3, 2).to(device=self.device, dtype=dtype)
+            )
+            transform = AppendFeatures(feature_set=feature_set)
+            X = torch.rand(4, 5, 3, device=self.device, dtype=dtype)
+            # in train - no transform
+            transform.train()
+            transformed_X = transform(X)
+            self.assertTrue(torch.equal(X, transformed_X))
+            # in eval - yes transform
+            transform.eval()
+            transformed_X = transform(X)
+            self.assertFalse(torch.equal(X, transformed_X))
+            self.assertEqual(transformed_X.shape, torch.Size([4, 15, 5]))
+            self.assertTrue(
+                torch.equal(transformed_X[..., :3], X.repeat_interleave(3, dim=-2))
+            )
+            self.assertTrue(
+                torch.equal(transformed_X[..., 3:], feature_set.repeat(4, 5, 1))
+            )
+            # in fantasize - no transform
+            with fantasize():
+                transformed_X = transform(X)
+                self.assertTrue(torch.equal(X, transformed_X))
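
Usage sketch (not taken from the patch itself): a minimal, self-contained check of the
shape semantics that the new unit test exercises, assuming the patch above has been
applied to a BoTorch installation.

    import torch
    from botorch.models.transforms.input import AppendFeatures

    # An `n_f x d_f` = 3 x 2 feature set to append to every candidate point.
    feature_set = torch.linspace(0, 1, 6).view(3, 2)
    # With the default flags, the transform is applied in eval() mode only.
    transform = AppendFeatures(feature_set=feature_set).eval()

    # A `batch_shape x q x d` = 4 x 5 x 3 batch of inputs.
    X = torch.rand(4, 5, 3)
    appended = transform(X)
    # Prints the expected `batch_shape x (q * n_f) x (d + d_f)` shape: torch.Size([4, 15, 5]).
    print(appended.shape)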