From 38d8da75f54c6da8ee92aa71abddf4da85722cc1 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 15 Dec 2022 11:40:41 -0800 Subject: [PATCH] Raise an error if `Standardize` outcome transform's `untransform_posterior` is used without first calling the transform on outcomes (#1569) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1569 Say someone does the following bad thing: > tf = Standardize(m=1) > tf.untransform_posterior(posterior) Old behavior: - means and standard deviations are initialized in `Standardize.__init__` with a tensor of zeros, with 'device' not set - With a posterior on the CPU, the posterior would be nonsensically untransformed with means and standard deviations of zero - With a posterior on the GPU, this would cause an error about tensors on different devices, e.g. https://www.internalfb.com/diff/D42019721?dst_version_fbid=1618282175279712&selected_signal=dGVzdF9pZDo1NjI5NTAwMjcwNTY2NTk%3D&selected_signal_verification_phase=1 New behavior: - means and standard deviations are initialized as None - An informative error is raised Reviewed By: saitcakmak, Balandat Differential Revision: D42039100 fbshipit-source-id: c479a326ad615a42cf0553ecc675aaf10f320925 --- botorch/models/transforms/outcome.py | 31 ++++++++++++++----- test/models/transforms/test_outcome.py | 41 ++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 404b09cb99..063aa04ca7 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -223,9 +223,9 @@ def __init__( standardization (if lower, only de-mean the data). 
""" super().__init__() - self.register_buffer("means", torch.zeros(*batch_shape, 1, m)) - self.register_buffer("stdvs", torch.zeros(*batch_shape, 1, m)) - self.register_buffer("_stdvs_sq", torch.zeros(*batch_shape, 1, m)) + self.register_buffer("means", None) + self.register_buffer("stdvs", None) + self.register_buffer("_stdvs_sq", None) self._outputs = normalize_indices(outputs, d=m) self._m = m self._batch_shape = batch_shape @@ -296,9 +296,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform: batch_shape=self._batch_shape, min_stdv=self._min_stdv, ) - new_tf.means = self.means[..., nlzd_idcs] - new_tf.stdvs = self.stdvs[..., nlzd_idcs] - new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs] + if self.means is not None: + new_tf.means = self.means[..., nlzd_idcs] + new_tf.stdvs = self.stdvs[..., nlzd_idcs] + new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs] if not self.training: new_tf.eval() return new_tf @@ -319,6 +320,13 @@ def untransform( - The un-standardized outcome observations. - The un-standardized observation noise (if applicable). """ + if self.means is None: + raise RuntimeError( + "`Standardize` transforms must be called on outcome data " + "(e.g. `transform(Y)`) before calling `untransform`, since " + "means and standard deviations need to be computed." + ) + Y_utf = self.means + self.stdvs * Y Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None return Y_utf, Yvar_utf @@ -338,13 +346,20 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior: "Standardize does not yet support output selection for " "untransform_posterior" ) + if self.means is None: + raise RuntimeError( + "`Standardize` transforms must be called on outcome data " + "(e.g. `transform(Y)`) before calling `untransform_posterior`, since " + "means and standard deviations need to be computed." 
+ ) is_mtgp_posterior = False if type(posterior) is GPyTorchPosterior: is_mtgp_posterior = posterior._is_mt if not self._m == posterior._extended_shape()[-1] and not is_mtgp_posterior: raise RuntimeError( - "Incompatible output dimensions encountered for transform " - f"{self._m} and posterior {posterior._extended_shape()[-1]}." + "Incompatible output dimensions encountered. Transform has output " + f"dimension {self._m} and posterior has " + f"{posterior._extended_shape()[-1]}." ) if type(posterior) is not GPyTorchPosterior: diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 1818c55837..6cf843483b 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -63,6 +63,36 @@ def test_abstract_base_outcome_transform(self): with self.assertRaises(NotImplementedError): oct.untransform_posterior(None) + def test_standardize_raises_when_mean_not_set(self) -> None: + posterior = _get_test_posterior( + shape=torch.Size([1, 1]), device=self.device, dtype=torch.float64 + ) + for transform in [ + Standardize(m=1), + ChainedOutcomeTransform( + chained=ChainedOutcomeTransform(stand=Standardize(m=1)) + ), + ]: + with self.assertRaises( + RuntimeError, + msg="`Standardize` transforms must be called on outcome data " + "(e.g. `transform(Y)`) before calling `untransform_posterior`, since " + "means and standard deviations need to be computed.", + ): + transform.untransform_posterior(posterior) + + new_tf = transform.subset_output([0]) + assert isinstance(new_tf, type(transform)) + + y = torch.arange(3, device=self.device, dtype=torch.float64) + with self.assertRaises( + RuntimeError, + msg="`Standardize` transforms must be called on outcome data " + "(e.g. 
`transform(Y)`) before calling `untransform`, since " + "means and standard deviations need to be computed.", + ): + transform.untransform(y) + def test_standardize(self): # test error on incompatible dim tf = Standardize(m=1) @@ -208,8 +238,15 @@ def test_standardize(self): # test error on incompatible output dimension # TODO: add a unit test for MTGP posterior once #840 goes in - tf_big = Standardize(m=4).eval() - with self.assertRaises(RuntimeError): + tf_big = Standardize(m=4) + Y = torch.arange(4, device=self.device, dtype=dtype).reshape((1, 4)) + tf_big(Y) + with self.assertRaises( + RuntimeError, + msg="Incompatible output dimensions encountered. Transform has output " + f"dimension {tf._m} and posterior has " + f"{posterior._extended_shape()[-1]}.", + ): tf_big.untransform_posterior(posterior2) # test transforming a subset of outcomes