From 1bd3cc086082a1fb6bd81e491a67f5a555adedc5 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 30 Jun 2021 18:19:41 -0700 Subject: [PATCH] Make `variance` optional for `TransformedPosterior.mean` call (#855) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/855 Currently, `TransformedPosterior.mean` throws an error if the `_posterior` doesn't have a `variance` attribute or if the `variance` throws a `NotImplementedError`. This wraps `_posterior.variance` with a `try/except` block to support the use with posteriors that don't have a `variance`. Differential Revision: D29506115 fbshipit-source-id: 9c3188804bee7c303d870d853e95595d8784a64f --- botorch/posteriors/transformed.py | 6 +++++- test/posteriors/test_transformed.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/botorch/posteriors/transformed.py b/botorch/posteriors/transformed.py index 31fd12ef69..61f9ca18b3 100644 --- a/botorch/posteriors/transformed.py +++ b/botorch/posteriors/transformed.py @@ -68,7 +68,11 @@ def mean(self) -> Tensor: r"""The mean of the posterior as a `batch_shape x n x m`-dim Tensor.""" if self._mean_transform is None: raise NotImplementedError("No mean transform provided.") - return self._mean_transform(self._posterior.mean, self._posterior.variance) + try: + variance = self._posterior.variance + except (NotImplementedError, AttributeError): + variance = None + return self._mean_transform(self._posterior.mean, variance) @property def variance(self) -> Tensor: diff --git a/test/posteriors/test_transformed.py b/test/posteriors/test_transformed.py index eb7e067c05..74953331b3 100644 --- a/test/posteriors/test_transformed.py +++ b/test/posteriors/test_transformed.py @@ -69,3 +69,22 @@ def test_transformed_posterior(self): p_tf_2.mean with self.assertRaises(NotImplementedError): p_tf_2.variance + + # check that `mean` works even if posterior doesn't have `variance` + for error in (AttributeError, NotImplementedError): + + class DummyPosterior(object): + mean = torch.zeros(5) + + @property + def variance(self): + raise error + + post = DummyPosterior() + transformed_post = TransformedPosterior( + posterior=post, + sample_transform=None, + mean_transform=lambda x, _: x + 1, + ) + transformed_mean = transformed_post.mean + self.assertTrue(torch.allclose(transformed_mean, torch.ones(5)))