diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 404b09cb99..063aa04ca7 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -223,9 +223,9 @@ def __init__(
                 standardization (if lower, only de-mean the data).
         """
         super().__init__()
-        self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
-        self.register_buffer("stdvs", torch.zeros(*batch_shape, 1, m))
-        self.register_buffer("_stdvs_sq", torch.zeros(*batch_shape, 1, m))
+        self.register_buffer("means", None)
+        self.register_buffer("stdvs", None)
+        self.register_buffer("_stdvs_sq", None)
         self._outputs = normalize_indices(outputs, d=m)
         self._m = m
         self._batch_shape = batch_shape
@@ -296,9 +296,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform:
             batch_shape=self._batch_shape,
             min_stdv=self._min_stdv,
         )
-        new_tf.means = self.means[..., nlzd_idcs]
-        new_tf.stdvs = self.stdvs[..., nlzd_idcs]
-        new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
+        if self.means is not None:
+            new_tf.means = self.means[..., nlzd_idcs]
+            new_tf.stdvs = self.stdvs[..., nlzd_idcs]
+            new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
         if not self.training:
             new_tf.eval()
         return new_tf
@@ -319,6 +320,13 @@ def untransform(
            - The un-standardized outcome observations.
            - The un-standardized observation noise (if applicable).
        """
+        if self.means is None:
+            raise RuntimeError(
+                "`Standardize` transforms must be called on outcome data "
+                "(e.g. `transform(Y)`) before calling `untransform`, since "
+                "means and standard deviations need to be computed."
+            )
+
        Y_utf = self.means + self.stdvs * Y
        Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
        return Y_utf, Yvar_utf
@@ -338,13 +346,20 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior:
                 "Standardize does not yet support output selection for "
                 "untransform_posterior"
             )
+        if self.means is None:
+            raise RuntimeError(
+                "`Standardize` transforms must be called on outcome data "
+                "(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
+                "means and standard deviations need to be computed."
+            )
         is_mtgp_posterior = False
         if type(posterior) is GPyTorchPosterior:
             is_mtgp_posterior = posterior._is_mt
         if not self._m == posterior._extended_shape()[-1] and not is_mtgp_posterior:
             raise RuntimeError(
-                "Incompatible output dimensions encountered for transform "
-                f"{self._m} and posterior {posterior._extended_shape()[-1]}."
+                "Incompatible output dimensions encountered. Transform has output "
+                f"dimension {self._m} and posterior has "
+                f"{posterior._extended_shape()[-1]}."
             )
 
         if type(posterior) is not GPyTorchPosterior:
diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index 1818c55837..6cf843483b 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -63,6 +63,36 @@ def test_abstract_base_outcome_transform(self):
         with self.assertRaises(NotImplementedError):
             oct.untransform_posterior(None)
 
+    def test_standardize_raises_when_mean_not_set(self) -> None:
+        posterior = _get_test_posterior(
+            shape=torch.Size([1, 1]), device=self.device, dtype=torch.float64
+        )
+        for transform in [
+            Standardize(m=1),
+            ChainedOutcomeTransform(
+                chained=ChainedOutcomeTransform(stand=Standardize(m=1))
+            ),
+        ]:
+            with self.assertRaises(
+                RuntimeError,
+                msg="`Standardize` transforms must be called on outcome data "
+                "(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
+                "means and standard deviations need to be computed.",
+            ):
+                transform.untransform_posterior(posterior)
+
+            new_tf = transform.subset_output([0])
+            assert isinstance(new_tf, type(transform))
+
+            y = torch.arange(3, device=self.device, dtype=torch.float64)
+            with self.assertRaises(
+                RuntimeError,
+                msg="`Standardize` transforms must be called on outcome data "
+                "(e.g. `transform(Y)`) before calling `untransform`, since "
+                "means and standard deviations need to be computed.",
+            ):
+                transform.untransform(y)
+
     def test_standardize(self):
         # test error on incompatible dim
         tf = Standardize(m=1)
@@ -208,8 +238,15 @@ def test_standardize(self):
 
             # test error on incompatible output dimension
             # TODO: add a unit test for MTGP posterior once #840 goes in
-            tf_big = Standardize(m=4).eval()
-            with self.assertRaises(RuntimeError):
+            tf_big = Standardize(m=4)
+            Y = torch.arange(4, device=self.device, dtype=dtype).reshape((1, 4))
+            tf_big(Y)
+            with self.assertRaises(
+                RuntimeError,
+                msg="Incompatible output dimensions encountered. Transform has output "
+                f"dimension {tf._m} and posterior has "
+                f"{posterior._extended_shape()[-1]}.",
+            ):
                 tf_big.untransform_posterior(posterior2)
 
         # test transforming a subset of outcomes
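
For illustration, a minimal sketch of the user-facing behavior after this change (not part of the patch itself); it assumes an environment with the patched `Standardize` transform installed, and the tensor shapes and values are arbitrary:

```python
# Sketch of the behavior introduced by this patch: `Standardize` buffers start
# as None and `untransform` raises until the transform has seen outcome data.
import torch

from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=1)

# Before the transform is applied to any outcomes, means/stdvs are unset, so
# untransforming now raises instead of silently de-standardizing with zeros.
try:
    tf.untransform(torch.zeros(3, 1))
except RuntimeError as e:
    print(e)  # "`Standardize` transforms must be called on outcome data ..."

# After transforming training outcomes, the buffers are populated and the
# round trip works as before.
Y = torch.randn(5, 1)
Y_tf, _ = tf(Y)
Y_round_trip, _ = tf.untransform(Y_tf)
assert torch.allclose(Y_round_trip, Y)
```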