Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 124 additions & 25 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def untransform(self, X: Tensor) -> Tensor:
A `batch_shape x n x d`-dim tensor of un-transformed inputs.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement the `untransform` method"
f"{self.__class__.__name__} does not implement the `untransform` method."
)

def equals(self, other: InputTransform) -> bool:
Expand All @@ -104,7 +104,7 @@ def equals(self, other: InputTransform) -> bool:
pytorch. See https://github.com/pytorch/pytorch/issues/7733.

Args:
other: Another input transform
other: Another input transform.

Returns:
A boolean indicating if the other transform is equivalent.
Expand Down Expand Up @@ -141,7 +141,7 @@ def preprocess_transform(self, X: Tensor) -> Tensor:


class ChainedInputTransform(InputTransform, ModuleDict):
r"""An input transform representing the chaining of individual transforms"""
r"""An input transform representing the chaining of individual transforms."""

def __init__(self, **transforms: InputTransform) -> None:
r"""Chaining of input transforms.
Expand Down Expand Up @@ -204,7 +204,7 @@ def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.

Args:
other: Another input transform
other: Another input transform.

Returns:
A boolean indicating if the other transform is equivalent.
Expand All @@ -219,7 +219,7 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
The main use cases for this method are 1) to preprocess training data
before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
acquisition functions so that `X_baseline` is "preprocessed" with the
same transformations as the cached training inputs
same transformations as the cached training inputs.

Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
Expand Down Expand Up @@ -291,7 +291,7 @@ def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.

Args:
other: Another input transform
other: Another input transform.

Returns:
A boolean indicating if the other transform is equivalent.
Expand Down Expand Up @@ -328,19 +328,19 @@ def __init__(
of shape `batch_shape x n x d`). If provided, perform individual
normalization per batch, otherwise uses a single normalization.
transform_on_train: A boolean indicating whether to apply the
transforms in train() mode. Default: True
transforms in train() mode. Default: True.
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True
transform when called from within a `fantasize` call. Default: True.
reverse: A boolean indicating whether the forward pass should untransform
the inputs.
"""
super().__init__()
if bounds is not None:
if bounds.size(-1) != d:
raise BotorchTensorDimensionError(
"Incompatible dimensions of provided bounds"
"Dimensions of provided `bounds` are incompatible with `d`!"
)
mins = bounds[..., 0:1, :]
ranges = bounds[..., 1:2, :] - mins
Expand Down Expand Up @@ -376,8 +376,8 @@ def _transform(self, X: Tensor) -> Tensor:
if self.learn_bounds and self.training:
if X.size(-1) != self.mins.size(-1):
raise BotorchTensorDimensionError(
f"Wrong input. dimension. Received {X.size(-1)}, "
f"expected {self.mins.size(-1)}"
f"Wrong input dimension. Received {X.size(-1)}, "
f"expected {self.mins.size(-1)}."
)
self.mins = X.min(dim=-2, keepdim=True)[0]
self.ranges = X.max(dim=-2, keepdim=True)[0] - self.mins
Expand All @@ -403,7 +403,7 @@ def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.

Args:
other: Another input transform
other: Another input transform.

Returns:
A boolean indicating if the other transform is equivalent.
Expand Down Expand Up @@ -465,16 +465,16 @@ def __init__(
r"""Initialize transform.

Args:
indices: The indices of the integer inputs
indices: The indices of the integer inputs.
transform_on_train: A boolean indicating whether to apply the
transforms in train() mode. Default: True
transforms in train() mode. Default: True.
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True
transform when called from within a `fantasize` call. Default: True.
approximate: A boolean indicating whether approximate or exact
rounding should be used. Default: approximate
tau: The temperature parameter for approximate rounding
rounding should be used. Default: approximate.
tau: The temperature parameter for approximate rounding.
"""
super().__init__()
self.transform_on_train = transform_on_train
Expand Down Expand Up @@ -506,7 +506,7 @@ def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.

Args:
other: Another input transform
other: Another input transform.

Returns:
A boolean indicating if the other transform is equivalent.
Expand All @@ -532,13 +532,13 @@ def __init__(
r"""Initialize transform.

Args:
indices: The indices of the inputs to log transform
indices: The indices of the inputs to log transform.
transform_on_train: A boolean indicating whether to apply the
transforms in train() mode. Default: True
transforms in train() mode. Default: True.
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True
transform when called from within a `fantasize` call. Default: True.
reverse: A boolean indicating whether the forward pass should untransform
the inputs.
"""
Expand Down Expand Up @@ -615,7 +615,7 @@ def __init__(
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True
transform when called from within a `fantasize` call. Default: True.
reverse: A boolean indicating whether the forward pass should untransform
the inputs.
eps: A small value used to clip values to be in the interval (0, 1).
Expand Down Expand Up @@ -728,3 +728,102 @@ def _untransform(self, X: Tensor) -> Tensor:
(k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
).clamp(0.0, 1.0)
return X_tf


class AppendFeatures(InputTransform, Module):
    r"""A transform that appends the input with a given set of features.

    As an example, this can be used with `RiskMeasureMCObjective` to optimize risk
    measures as described in [Cakmak2020risk]_. A tutorial notebook implementing the
    rhoKG acquisition function introduced in [Cakmak2020risk]_ can be found at
    https://botorch.org/tutorials/risk_averse_bo_with_environmental_variables.

    The steps for using this to obtain samples of a risk measure are as follows:

    -   Train a model on `(x, w)` inputs and the corresponding observations;

    -   Pass in an instance of `AppendFeatures` with the `feature_set` denoting the
        samples of `W` as the `input_transform` to the trained model;

    -   Call `posterior(...).rsample(...)` on the model with `x` inputs only to
        get the joint posterior samples over `(x, w)`s, where the `w`s come
        from the `feature_set`;

    -   Pass these posterior samples through the `RiskMeasureMCObjective` of choice to
        get the samples of the risk measure.

    Note: The samples of the risk measure obtained this way are in general biased
    since the `feature_set` does not fully represent the distribution of the
    environmental variable.

    Example:
        >>> # We consider 1D `x` and 1D `w`, with `W` having a
        >>> # uniform distribution over [0, 1]
        >>> model = SingleTaskGP(
        ...     train_X=torch.rand(10, 2),
        ...     train_Y=torch.randn(10, 1),
        ...     input_transform=AppendFeatures(feature_set=torch.rand(10, 1))
        ... )
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> fit_gpytorch_model(mll)
        >>> test_x = torch.rand(3, 1)
        >>> # `posterior_samples` is a `10 x 30 x 1`-dim tensor
        >>> posterior_samples = model.posterior(test_x).rsample(torch.Size([10]))
        >>> risk_measure = VaR(alpha=0.8, n_w=10)
        >>> # `risk_measure_samples` is a `10 x 3`-dim tensor of samples of the
        >>> # risk measure VaR
        >>> risk_measure_samples = risk_measure(posterior_samples)
    """

    def __init__(
        self,
        feature_set: Tensor,
        transform_on_train: bool = False,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = False,
    ) -> None:
        r"""Append `feature_set` to each input.

        Args:
            feature_set: An `n_f x d_f`-dim tensor denoting the features to be
                appended to the inputs. It is registered as a module buffer, so
                it follows `.to(...)` calls made on the transform.
            transform_on_train: A boolean indicating whether to apply the
                transforms in train() mode. Default: False.
            transform_on_eval: A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize: A boolean indicating whether to apply the
                transform when called from within a `fantasize` call. Default: False.

        Raises:
            ValueError: If `feature_set` is not a 2-dimensional tensor.
        """
        super().__init__()
        if feature_set.dim() != 2:
            raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!")
        # Register as a buffer (rather than a plain attribute) so that the
        # features are moved together with the module on device / dtype changes.
        # Attribute access via `self.feature_set` is unchanged.
        self.register_buffer("feature_set", feature_set)
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize

    def transform(self, X: Tensor) -> Tensor:
        r"""Transform the inputs by appending `feature_set` to each input.

        For each `1 x d`-dim element in the input tensor, this will produce
        an `n_f x (d + d_f)`-dim tensor with `feature_set` appended as the last `d_f`
        dimensions. For a generic `batch_shape x q x d`-dim `X`, this translates to a
        `batch_shape x (q * n_f) x (d + d_f)`-dim output, where the values corresponding
        to `X[..., i, :]` are found in `output[..., i * n_f: (i + 1) * n_f, :]`.

        Note: Adding the `feature_set` on the `q-batch` dimension is necessary to avoid
        introducing additional bias by evaluating the inputs on independent GP
        sample paths.

        Args:
            X: A `batch_shape x q x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs.
        """
        # `batch_shape x q x d` -> `batch_shape x q x n_f x d`: every input point
        # is repeated once per feature row.
        expanded_X = X.unsqueeze(dim=-2).expand(
            *X.shape[:-1], self.feature_set.shape[0], -1
        )
        # `n_f x d_f` -> `batch_shape x q x n_f x d_f`.
        expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1)
        appended_X = torch.cat([expanded_X, expanded_features], dim=-1)
        # Collapse the `q` and `n_f` dimensions into a single `q * n_f` dimension.
        return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])
35 changes: 35 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.transforms.input import (
AppendFeatures,
ChainedInputTransform,
InputTransform,
Warp,
Expand Down Expand Up @@ -568,3 +569,37 @@ def test_warp_transform(self):
self.assertTrue((warp_tf.concentration0 == 2.0).all())
warp_tf._set_concentration(i=1, value=3.0)
self.assertTrue((warp_tf.concentration1 == 3.0).all())


class TestAppendFeatures(BotorchTestCase):
    """Unit tests for the `AppendFeatures` input transform."""

    def test_append_features(self):
        """Check input validation and train / eval / fantasize behavior."""
        # `feature_set` must be exactly 2-dimensional: 1-d and 3-d are rejected.
        for bad_feature_set in (torch.ones(1), torch.ones(3, 4, 2)):
            with self.assertRaises(ValueError):
                AppendFeatures(bad_feature_set)

        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            feature_set = torch.linspace(0, 1, 6).view(3, 2).to(**tkwargs)
            tf = AppendFeatures(feature_set=feature_set)
            X = torch.rand(4, 5, 3, **tkwargs)

            # Train mode: the transform is a no-op by default.
            tf.train()
            out = tf(X)
            self.assertTrue(torch.equal(X, out))

            # Eval mode: the features get appended to every input point.
            tf.eval()
            out = tf(X)
            self.assertFalse(torch.equal(X, out))
            self.assertEqual(out.shape, torch.Size([4, 15, 5]))
            # The first 3 columns hold each input repeated once per feature row.
            expected_inputs = X.repeat_interleave(3, dim=-2)
            self.assertTrue(torch.equal(out[..., :3], expected_inputs))
            # The trailing 2 columns hold the feature set tiled over all inputs.
            expected_features = feature_set.repeat(4, 5, 1)
            self.assertTrue(torch.equal(out[..., 3:], expected_features))

            # Within `fantasize`: the transform is a no-op by default.
            with fantasize():
                out = tf(X)
            self.assertTrue(torch.equal(X, out))