Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ def __init__(
standardization (if lower, only de-mean the data).
"""
super().__init__()
self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
self.register_buffer("stdvs", torch.zeros(*batch_shape, 1, m))
self.register_buffer("_stdvs_sq", torch.zeros(*batch_shape, 1, m))
self.register_buffer("means", None)
self.register_buffer("stdvs", None)
self.register_buffer("_stdvs_sq", None)
self._outputs = normalize_indices(outputs, d=m)
self._m = m
self._batch_shape = batch_shape
Expand Down Expand Up @@ -296,9 +296,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform:
batch_shape=self._batch_shape,
min_stdv=self._min_stdv,
)
new_tf.means = self.means[..., nlzd_idcs]
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
if self.means is not None:
new_tf.means = self.means[..., nlzd_idcs]
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
if not self.training:
new_tf.eval()
return new_tf
Expand All @@ -319,6 +320,13 @@ def untransform(
- The un-standardized outcome observations.
- The un-standardized observation noise (if applicable).
"""
if self.means is None:
raise RuntimeError(
"`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform`, since "
"means and standard deviations need to be computed."
)

Y_utf = self.means + self.stdvs * Y
Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
return Y_utf, Yvar_utf
Expand All @@ -338,13 +346,20 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior:
"Standardize does not yet support output selection for "
"untransform_posterior"
)
if self.means is None:
raise RuntimeError(
"`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
"means and standard deviations need to be computed."
)
is_mtgp_posterior = False
if type(posterior) is GPyTorchPosterior:
is_mtgp_posterior = posterior._is_mt
if not self._m == posterior._extended_shape()[-1] and not is_mtgp_posterior:
raise RuntimeError(
"Incompatible output dimensions encountered for transform "
f"{self._m} and posterior {posterior._extended_shape()[-1]}."
"Incompatible output dimensions encountered. Transform has output "
f"dimension {self._m} and posterior has "
f"{posterior._extended_shape()[-1]}."
)

if type(posterior) is not GPyTorchPosterior:
Expand Down
41 changes: 39 additions & 2 deletions test/models/transforms/test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ def test_abstract_base_outcome_transform(self):
with self.assertRaises(NotImplementedError):
oct.untransform_posterior(None)

def test_standardize_raises_when_mean_not_set(self) -> None:
posterior = _get_test_posterior(
shape=torch.Size([1, 1]), device=self.device, dtype=torch.float64
)
for transform in [
Standardize(m=1),
ChainedOutcomeTransform(
chained=ChainedOutcomeTransform(stand=Standardize(m=1))
),
]:
with self.assertRaises(
RuntimeError,
msg="`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
"means and standard deviations need to be computed.",
):
transform.untransform_posterior(posterior)

new_tf = transform.subset_output([0])
assert isinstance(new_tf, type(transform))

y = torch.arange(3, device=self.device, dtype=torch.float64)
with self.assertRaises(
RuntimeError,
msg="`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform`, since "
"means and standard deviations need to be computed.",
):
transform.untransform(y)

def test_standardize(self):
# test error on incompatible dim
tf = Standardize(m=1)
Expand Down Expand Up @@ -208,8 +238,15 @@ def test_standardize(self):

# test error on incompatible output dimension
# TODO: add a unit test for MTGP posterior once #840 goes in
tf_big = Standardize(m=4).eval()
with self.assertRaises(RuntimeError):
tf_big = Standardize(m=4)
Y = torch.arange(4, device=self.device, dtype=dtype).reshape((1, 4))
tf_big(Y)
with self.assertRaises(
RuntimeError,
msg="Incompatible output dimensions encountered. Transform has output "
f"dimension {tf._m} and posterior has "
f"{posterior._extended_shape()[-1]}.",
):
tf_big.untransform_posterior(posterior2)

# test transforming a subset of outcomes
Expand Down