From b38fb43d7596d3daf41db964ae42fe1d9e377839 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Fri, 19 Feb 2021 17:35:44 -0800
Subject: [PATCH 1/3] Add ability to subset outcome transforms

Addresses #708
---
 botorch/models/gpytorch.py             |  7 +++
 botorch/models/transforms/outcome.py   | 69 ++++++++++++++++++++++++++
 test/models/transforms/test_outcome.py |  8 +--
 3 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index 4d7ce22e8b..80366af8ce 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -456,6 +456,13 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
             mod_name = full_name.split(".")[:-1]
             mod_batch_shape(new_model, mod_name, m if m > 1 else 0)
 
+        # subset outcome transform if present
+        try:
+            subset_octf = new_model.outcome_transform.subset_output(idcs=idcs)
+            new_model.outcome_transform = subset_octf
+        except AttributeError:
+            pass
+
         return new_model
 
 
diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 676f5ddf78..22a83c1aa7 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -44,6 +44,23 @@ def forward(
         """
         pass  # pragma: no cover
 
+    def subset_output(self, idcs: List[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        This functionality is used to properly treat outcome transformations
+        in the `subset_model` functionality.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement the "
+            "`subset_output` method"
+        )
+
     def untransform(
         self, Y: Tensor, Yvar: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -112,6 +129,17 @@ def forward(
             Y, Yvar = tf.forward(Y, Yvar)
         return Y, Yvar
 
+    def subset_output(self, idcs: List[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        return self.__class__(*(tf.subset_output(idcs=idcs) for tf in self.values()))
+
     def untransform(
         self, Y: Tensor, Yvar: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -221,6 +249,33 @@ def forward(
         Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
         return Y_tf, Yvar_tf
 
+    def subset_output(self, idcs: List[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        nlzd_idcs = normalize_indices(idcs, d=self._m)
+        new_m = len(nlzd_idcs)
+        if new_m > self._m:
+            raise RuntimeError(
+                "Trying to subset a transform to have more outputs than "
+                "the original transform."
+            )
+        new_tf = self.__class__(
+            m=new_m,
+            outputs=None if self._outputs is None else self._outputs[nlzd_idcs],
+            batch_shape=self._batch_shape,
+            min_stdv=self._min_stdv,
+        )
+        new_tf.means = self.means[..., nlzd_idcs]
+        new_tf.stdvs = self.stdvs[..., nlzd_idcs]
+        new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
+        return new_tf
+
     def untransform(
         self, Y: Tensor, Yvar: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -321,6 +376,20 @@ def __init__(self, outputs: Optional[List[int]] = None) -> None:
         super().__init__()
         self._outputs = outputs
 
+    def subset_output(self, idcs: List[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        nlzd_idcs = normalize_indices(idcs, d=self._m)
+        return self.__class__(
+            outputs=None if self._outputs is None else self._outputs[nlzd_idcs],
+        )
+
     def forward(
         self, Y: Tensor, Yvar: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index e87c8f6763..7ec493ddec 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -50,6 +50,8 @@ def test_abstract_base_outcome_transform(self):
         with self.assertRaises(TypeError):
             OutcomeTransform()
         oct = NotSoAbstractOutcomeTransform()
+        with self.assertRaises(NotImplementedError):
+            oct.subset_output(None)
         with self.assertRaises(NotImplementedError):
             oct.untransform(None, None)
         with self.assertRaises(NotImplementedError):
@@ -164,7 +166,7 @@ def test_standardize(self):
             tf_big.untransform_posterior(posterior2)
         self.assertTrue("Incompatible output dimensions" in str(e))
 
-        # test subset outcomes
+        # test transforming a subset of outcomes
         for batch_shape, dtype in itertools.product(batch_shapes, dtypes):
             m = 2
@@ -278,7 +280,7 @@ def test_log(self):
             samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
             self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
 
-        # test subset outcomes
+        # test transforming a subset of outcomes
        for batch_shape, dtype in itertools.product(batch_shapes, dtypes):
             m = 2
@@ -377,7 +379,7 @@ def test_chained_outcome_transform(self):
             samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
             self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
 
-        # test subset outcomes
+        # test transforming a subset of outcomes
         for batch_shape, dtype in itertools.product(batch_shapes, dtypes):
             m = 2

From d7083f18bcd6192b49d1d018c0bd9563f6e2b495 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Fri, 19 Feb 2021 18:35:52 -0800
Subject: [PATCH 2/3] Unit tests

---
 botorch/models/transforms/outcome.py   | 31 ++++++++++++-----
 test/models/test_gp_regression.py      |  7 ++--
 test/models/test_gpytorch.py           |  4 ++-
 test/models/transforms/test_outcome.py | 46 ++++++++++++++++++++++++--
 4 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 22a83c1aa7..69f978d9d8 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -138,7 +138,9 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform:
         Returns:
             The current outcome transform, subset to the specified output indices.
""" - return self.__class__(*(tf.subset_output(idcs=idcs) for tf in self.values())) + return self.__class__( + **{name: tf.subset_output(idcs=idcs) for name, tf in self.items()} + ) def untransform( self, Y: Tensor, Yvar: Optional[Tensor] = None @@ -258,22 +260,27 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform: Returns: The current outcome transform, subset to the specified output indices. """ - nlzd_idcs = normalize_indices(idcs, d=self._m) - new_m = len(nlzd_idcs) + new_m = len(idcs) if new_m > self._m: raise RuntimeError( "Trying to subset a transform have more outputs than " " the original transform." ) + nlzd_idcs = normalize_indices(idcs, d=self._m) + new_outputs = None + if self._outputs is not None: + new_outputs = [i for i in self._outputs if i in nlzd_idcs] new_tf = self.__class__( m=new_m, - outputs=None if self._outputs is None else self._outputs[nlzd_idcs], + outputs=new_outputs, batch_shape=self._batch_shape, min_stdv=self._min_stdv, ) new_tf.means = self.means[..., nlzd_idcs] new_tf.stdvs = self.stdvs[..., nlzd_idcs] new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs] + if not self.training: + new_tf.eval() return new_tf def untransform( @@ -385,10 +392,18 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform: Returns: The current outcome transform, subset to the specified output indices. """ - nlzd_idcs = normalize_indices(idcs, d=self._m) - return self.__class__( - outputs=None if self._outputs is None else self._outputs[nlzd_idcs], - ) + new_outputs = None + if self._outputs is not None: + if min(self._outputs + idcs) < 0: + raise NotImplementedError( + f"Negative indexing not supported for {self.__class__.__name__} " + "when subsetting outputs and only transforming some outputs." + ) + new_outputs = [i for i in self._outputs if i in idcs] + new_tf = self.__class__(outputs=new_outputs) + if not self.training: + new_tf.eval() + return new_tf def forward( self, Y: Tensor, Yvar: Optional[Tensor] = None diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 72d445ca8a..073156f229 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -278,12 +278,13 @@ def test_fantasize(self): self.assertIsInstance(fm, model.__class__) def test_subset_model(self): - for batch_shape, dtype in itertools.product( - (torch.Size(), torch.Size([2])), (torch.float, torch.double) + for batch_shape, dtype, use_octf in itertools.product( + (torch.Size(), torch.Size([2])), (torch.float, torch.double), (True, False) ): tkwargs = {"device": self.device, "dtype": dtype} + octf = Standardize(m=2, batch_shape=batch_shape) if use_octf else None model, model_kwargs = self._get_model_and_data( - batch_shape=batch_shape, m=2, **tkwargs + batch_shape=batch_shape, m=2, outcome_transform=octf, **tkwargs ) subset_model = model.subset_output([0]) X = torch.rand(torch.Size(batch_shape + torch.Size([3, 1])), **tkwargs) diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index 11766e368f..e51362aa14 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -50,7 +50,7 @@ def forward(self, x): class SimpleBatchedMultiOutputGPyTorchModel(BatchedMultiOutputGPyTorchModel, ExactGP): - def __init__(self, train_X, train_Y): + def __init__(self, train_X, train_Y, outcome_transform=None): self._validate_tensor_args(train_X, train_Y) self._set_dimensions(train_X=train_X, train_Y=train_Y) train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y) @@ -61,6 +61,8 @@ def __init__(self, 
train_X, train_Y): RBFKernel(batch_shape=self._aug_batch_shape), batch_shape=self._aug_batch_shape, ) + if outcome_transform is not None: + self.outcome_transform = outcome_transform self.to(train_X) def forward(self, x): diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 7ec493ddec..5e42ed1e01 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -94,6 +94,14 @@ def test_standardize(self): torch.allclose(Y_utf, Y) self.assertIsNone(Yvar_utf) + # subset_output + tf_subset = tf.subset_output(idcs=[0]) + Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]]) + self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset)) + self.assertIsNone(Yvar_tf_subset) + with self.assertRaises(RuntimeError): + tf.subset_output(idcs=[0, 1, 2]) + # with observation noise tf = Standardize(m=m, batch_shape=batch_shape) Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype) @@ -162,9 +170,8 @@ def test_standardize(self): # test error on incompatible output dimension tf_big = Standardize(m=4).eval() - with self.assertRaises(RuntimeError) as e: + with self.assertRaises(RuntimeError): tf_big.untransform_posterior(posterior2) - self.assertTrue("Incompatible output dimensions" in str(e)) # test transforming a subset of outcomes for batch_shape, dtype in itertools.product(batch_shapes, dtypes): @@ -194,6 +201,14 @@ def test_standardize(self): torch.allclose(Y_utf, Y) self.assertIsNone(Yvar_utf) + # subset_output + tf_subset = tf.subset_output(idcs=[0]) + Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]]) + self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset)) + self.assertIsNone(Yvar_tf_subset) + with self.assertRaises(RuntimeError): + tf.subset_output(idcs=[0, 1, 2]) + # with observation noise tf = Standardize(m=m, outputs=outputs, batch_shape=batch_shape) Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype) @@ -244,6 +259,12 @@ def test_log(self): torch.allclose(Y_utf, Y) self.assertIsNone(Yvar_utf) + # subset_output + tf_subset = tf.subset_output(idcs=[0]) + Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]]) + self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset)) + self.assertIsNone(Yvar_tf_subset) + # test error if observation noise present tf = Log() Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype) @@ -305,6 +326,10 @@ def test_log(self): torch.allclose(Y_utf, Y) self.assertIsNone(Yvar_utf) + # subset_output + with self.assertRaises(NotImplementedError): + tf_subset = tf.subset_output(idcs=[0]) + # with observation noise tf = Log(outputs=outputs) Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype) @@ -318,6 +343,15 @@ def test_log(self): with self.assertRaises(NotImplementedError): tf.untransform_posterior(None) + # test subset_output with positive on subset of outcomes (pos. 
+        tf = Log(outputs=[0])
+        Y_tf, Yvar_tf = tf(Y, None)
+        tf_subset = tf.subset_output(idcs=[0])
+        Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]], None)
+        self.assertTrue(torch.equal(Y_tf_subset, Y_tf[..., [0]]))
+        self.assertIsNone(Yvar_tf_subset)
+
+
     def test_chained_outcome_transform(self):
         ms = (1, 2)
@@ -352,6 +386,14 @@ def test_chained_outcome_transform(self):
             torch.allclose(Y_utf, Y)
             self.assertIsNone(Yvar_utf)
 
+            # subset_output
+            tf_subset = tf.subset_output(idcs=[0])
+            Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]])
+            self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
+            self.assertIsNone(Yvar_tf_subset)
+            with self.assertRaises(RuntimeError):
+                tf.subset_output(idcs=[0, 1, 2])
+
             # test error if observation noise present
             Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
             Yvar = 1e-8 + torch.rand(

From 9c5e832dc235f6412a07e8823ae37b0d94e9368a Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Sat, 20 Feb 2021 15:25:09 -0800
Subject: [PATCH 3/3] Fix black

---
 test/models/transforms/test_outcome.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index 5e42ed1e01..9eac29df0c 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -351,7 +351,6 @@ def test_log(self):
         self.assertTrue(torch.equal(Y_tf_subset, Y_tf[..., [0]]))
         self.assertIsNone(Yvar_tf_subset)
 
-
     def test_chained_outcome_transform(self):
         ms = (1, 2)
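
Note: the snippet below is a minimal usage sketch of the new `subset_output` API, distilled from the unit tests in these patches. It assumes a BoTorch build with all three patches applied; the training data, shapes, and variable names are illustrative assumptions, not part of the patches themselves.

    import torch

    from botorch.models import SingleTaskGP
    from botorch.models.transforms.outcome import Standardize

    # Illustrative two-outcome training data (names and shapes are assumptions).
    train_X = torch.rand(10, 3)
    train_Y = torch.rand(10, 2)

    # Calling a Standardize transform in train mode fits its means/stdvs.
    octf = Standardize(m=2)
    Y_tf, _ = octf(train_Y)

    # New in PATCH 1/3: subset the fitted transform to the first outcome only.
    octf_sub = octf.subset_output(idcs=[0])
    Y_tf_sub, _ = octf_sub(train_Y[..., [0]])
    assert torch.equal(Y_tf_sub, Y_tf[..., [0]])

    # Via the gpytorch.py change, subsetting a model that carries an outcome
    # transform now also subsets that transform.
    model = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=2))
    subset_model = model.subset_output([0])
    print(type(subset_model.outcome_transform))  # a single-output Standardize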