Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
mod_name = full_name.split(".")[:-1]
mod_batch_shape(new_model, mod_name, m if m > 1 else 0)

# subset outcome transform if present
try:
subset_octf = new_model.outcome_transform.subset_output(idcs=idcs)
new_model.outcome_transform = subset_octf
except AttributeError:
pass

return new_model


Expand Down
84 changes: 84 additions & 0 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ def forward(
"""
pass # pragma: no cover

def subset_output(self, idcs: List[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.

This functionality is used tpo properly treat outcome transfomrations
in the `subset_model` functionality.

Args:
idcs: The output indices to subset the transform to.

Returns:
The current outcome transform, subset to the specified output indices.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement the "
"`subset_output` method"
)

def untransform(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down Expand Up @@ -112,6 +129,19 @@ def forward(
Y, Yvar = tf.forward(Y, Yvar)
return Y, Yvar

def subset_output(self, idcs: List[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.

Args:
idcs: The output indices to subset the transform to.

Returns:
The current outcome transform, subset to the specified output indices.
"""
return self.__class__(
**{name: tf.subset_output(idcs=idcs) for name, tf in self.items()}
)

def untransform(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down Expand Up @@ -221,6 +251,38 @@ def forward(
Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
return Y_tf, Yvar_tf

def subset_output(self, idcs: List[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.

Args:
idcs: The output indices to subset the transform to.

Returns:
The current outcome transform, subset to the specified output indices.
"""
new_m = len(idcs)
if new_m > self._m:
raise RuntimeError(
"Trying to subset a transform have more outputs than "
" the original transform."
)
nlzd_idcs = normalize_indices(idcs, d=self._m)
new_outputs = None
if self._outputs is not None:
new_outputs = [i for i in self._outputs if i in nlzd_idcs]
new_tf = self.__class__(
m=new_m,
outputs=new_outputs,
batch_shape=self._batch_shape,
min_stdv=self._min_stdv,
)
new_tf.means = self.means[..., nlzd_idcs]
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
if not self.training:
new_tf.eval()
return new_tf

def untransform(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down Expand Up @@ -321,6 +383,28 @@ def __init__(self, outputs: Optional[List[int]] = None) -> None:
super().__init__()
self._outputs = outputs

def subset_output(self, idcs: List[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.

Args:
idcs: The output indices to subset the transform to.

Returns:
The current outcome transform, subset to the specified output indices.
"""
new_outputs = None
if self._outputs is not None:
if min(self._outputs + idcs) < 0:
raise NotImplementedError(
f"Negative indexing not supported for {self.__class__.__name__} "
"when subsetting outputs and only transforming some outputs."
)
new_outputs = [i for i in self._outputs if i in idcs]
new_tf = self.__class__(outputs=new_outputs)
if not self.training:
new_tf.eval()
return new_tf

def forward(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down
7 changes: 4 additions & 3 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,13 @@ def test_fantasize(self):
self.assertIsInstance(fm, model.__class__)

def test_subset_model(self):
for batch_shape, dtype in itertools.product(
(torch.Size(), torch.Size([2])), (torch.float, torch.double)
for batch_shape, dtype, use_octf in itertools.product(
(torch.Size(), torch.Size([2])), (torch.float, torch.double), (True, False)
):
tkwargs = {"device": self.device, "dtype": dtype}
octf = Standardize(m=2, batch_shape=batch_shape) if use_octf else None
model, model_kwargs = self._get_model_and_data(
batch_shape=batch_shape, m=2, **tkwargs
batch_shape=batch_shape, m=2, outcome_transform=octf, **tkwargs
)
subset_model = model.subset_output([0])
X = torch.rand(torch.Size(batch_shape + torch.Size([3, 1])), **tkwargs)
Expand Down
4 changes: 3 additions & 1 deletion test/models/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(self, x):


class SimpleBatchedMultiOutputGPyTorchModel(BatchedMultiOutputGPyTorchModel, ExactGP):
def __init__(self, train_X, train_Y):
def __init__(self, train_X, train_Y, outcome_transform=None):
self._validate_tensor_args(train_X, train_Y)
self._set_dimensions(train_X=train_X, train_Y=train_Y)
train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
Expand All @@ -61,6 +61,8 @@ def __init__(self, train_X, train_Y):
RBFKernel(batch_shape=self._aug_batch_shape),
batch_shape=self._aug_batch_shape,
)
if outcome_transform is not None:
self.outcome_transform = outcome_transform
self.to(train_X)

def forward(self, x):
Expand Down
53 changes: 48 additions & 5 deletions test/models/transforms/test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_abstract_base_outcome_transform(self):
with self.assertRaises(TypeError):
OutcomeTransform()
oct = NotSoAbstractOutcomeTransform()
with self.assertRaises(NotImplementedError):
oct.subset_output(None)
with self.assertRaises(NotImplementedError):
oct.untransform(None, None)
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -92,6 +94,14 @@ def test_standardize(self):
torch.allclose(Y_utf, Y)
self.assertIsNone(Yvar_utf)

# subset_output
tf_subset = tf.subset_output(idcs=[0])
Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]])
self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
self.assertIsNone(Yvar_tf_subset)
with self.assertRaises(RuntimeError):
tf.subset_output(idcs=[0, 1, 2])

# with observation noise
tf = Standardize(m=m, batch_shape=batch_shape)
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Expand Down Expand Up @@ -160,11 +170,10 @@ def test_standardize(self):

# test error on incompatible output dimension
tf_big = Standardize(m=4).eval()
with self.assertRaises(RuntimeError) as e:
with self.assertRaises(RuntimeError):
tf_big.untransform_posterior(posterior2)
self.assertTrue("Incompatible output dimensions" in str(e))

# test subset outcomes
# test transforming a subset of outcomes
for batch_shape, dtype in itertools.product(batch_shapes, dtypes):

m = 2
Expand Down Expand Up @@ -192,6 +201,14 @@ def test_standardize(self):
torch.allclose(Y_utf, Y)
self.assertIsNone(Yvar_utf)

# subset_output
tf_subset = tf.subset_output(idcs=[0])
Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]])
self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
self.assertIsNone(Yvar_tf_subset)
with self.assertRaises(RuntimeError):
tf.subset_output(idcs=[0, 1, 2])

# with observation noise
tf = Standardize(m=m, outputs=outputs, batch_shape=batch_shape)
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Expand Down Expand Up @@ -242,6 +259,12 @@ def test_log(self):
torch.allclose(Y_utf, Y)
self.assertIsNone(Yvar_utf)

# subset_output
tf_subset = tf.subset_output(idcs=[0])
Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]])
self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
self.assertIsNone(Yvar_tf_subset)

# test error if observation noise present
tf = Log()
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Expand Down Expand Up @@ -278,7 +301,7 @@ def test_log(self):
samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)

# test subset outcomes
# test transforming a subset of outcomes
for batch_shape, dtype in itertools.product(batch_shapes, dtypes):

m = 2
Expand All @@ -303,6 +326,10 @@ def test_log(self):
torch.allclose(Y_utf, Y)
self.assertIsNone(Yvar_utf)

# subset_output
with self.assertRaises(NotImplementedError):
tf_subset = tf.subset_output(idcs=[0])

# with observation noise
tf = Log(outputs=outputs)
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Expand All @@ -316,6 +343,14 @@ def test_log(self):
with self.assertRaises(NotImplementedError):
tf.untransform_posterior(None)

# test subset_output with positive on subset of outcomes (pos. index)
tf = Log(outputs=[0])
Y_tf, Yvar_tf = tf(Y, None)
tf_subset = tf.subset_output(idcs=[0])
Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]], None)
self.assertTrue(torch.equal(Y_tf_subset, Y_tf[..., [0]]))
self.assertIsNone(Yvar_tf_subset)

def test_chained_outcome_transform(self):

ms = (1, 2)
Expand Down Expand Up @@ -350,6 +385,14 @@ def test_chained_outcome_transform(self):
torch.allclose(Y_utf, Y)
self.assertIsNone(Yvar_utf)

# subset_output
tf_subset = tf.subset_output(idcs=[0])
Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]])
self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
self.assertIsNone(Yvar_tf_subset)
with self.assertRaises(RuntimeError):
tf.subset_output(idcs=[0, 1, 2])

# test error if observation noise present
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Yvar = 1e-8 + torch.rand(
Expand Down Expand Up @@ -377,7 +420,7 @@ def test_chained_outcome_transform(self):
samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)

# test subset outcomes
# test transforming a subset of outcomes
for batch_shape, dtype in itertools.product(batch_shapes, dtypes):

m = 2
Expand Down