From 4330fab6550d782ce12357572dbe8addd4830233 Mon Sep 17 00:00:00 2001
From: James Wilson
Date: Thu, 15 Dec 2022 10:40:28 -0800
Subject: [PATCH] Removing `_fit_multioutput_independent` and `allclose_mll`. (#1570)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1570

The original bug that motivated fitting certain types of multioutput models independently has been resolved, so `_fit_multioutput_independent` is no longer needed.

Reviewed By: Balandat

Differential Revision: D42071480

fbshipit-source-id: f1f6d3f2212344243ba5ad4c3bb8d9e5f7007aa6
---
 botorch/fit.py                          |  93 +--------
 botorch/optim/utils/__init__.py         |   2 -
 botorch/optim/utils/model_utils.py      |  44 ----
 test/models/test_gp_regression_mixed.py |   5 +
 test/optim/utils/test_model_utils.py    |  31 ---
 test/test_fit.py                        | 254 +++---------------------
 6 files changed, 36 insertions(+), 393 deletions(-)

diff --git a/botorch/fit.py b/botorch/fit.py
index bcbd75e060..1e67c89c86 100644
--- a/botorch/fit.py
+++ b/botorch/fit.py
@@ -16,42 +16,34 @@
 from warnings import catch_warnings, simplefilter, warn, warn_explicit, WarningMessage
 
 from botorch.exceptions.errors import ModelFittingError, UnsupportedError
-from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
+from botorch.exceptions.warnings import OptimizationWarning
 from botorch.models.approximate_gp import ApproximateGPyTorchModel
-from botorch.models.converter import batched_to_model_list, model_list_to_batched
 from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
 from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
-from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.optim.closures import get_loss_closure_with_grads
 from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
 from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
 from botorch.optim.utils import (
     _warning_handler_template,
-    allclose_mll,
     get_parameters,
     sample_all_priors,
 )
 from botorch.settings import debug
 from botorch.utils.context_managers import (
-    del_attribute_ctx,
     module_rollback_ctx,
     parameter_rollback_ctx,
     requires_grad_ctx,
     TensorCheckpoint,
 )
-from botorch.utils.dispatcher import (
-    Dispatcher,
-    MDNotImplementedError,
-    type_bypassing_encoder,
-)
+from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
 from gpytorch.likelihoods import Likelihood
 from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
 from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
 from linear_operator.utils.errors import NotPSDError
 from pyro.infer.mcmc import MCMC, NUTS
-from torch import device, mean, Tensor
+from torch import device, Tensor
 from torch.nn import Parameter
 from torch.utils.data import DataLoader
 
@@ -299,83 +291,6 @@ def _fit_list(
     return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll
 
 
-@FitGPyTorchMLL.register(
-    (MarginalLogLikelihood, _ApproximateMarginalLogLikelihood),
-    object,
-    BatchedMultiOutputGPyTorchModel,
-)
-def _fit_multioutput_independent(
-    mll: MarginalLogLikelihood,
-    _: Type[Likelihood],
-    __: Type[BatchedMultiOutputGPyTorchModel],
-    *,
-    closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
-    sequential: bool = True,
-    **kwargs: Any,
-) -> MarginalLogLikelihood:
-    r"""Fitting routine for multioutput Gaussian processes.
-
-    Args:
-        closure: Forward-backward closure for obtaining objective values and gradients.
-            Responsible for setting parameters' `grad` attributes. If no closure is
-            provided, one will be obtained by calling `get_loss_closure_with_grads`.
-        sequential: Boolean specifying whether or not an attempt should be made to
-            fit the model as a collection of independent GPs. Only relevant for
-            certain types of GPs with independent outputs, see `batched_to_model_list`.
-        **kwargs: Passed to the next method unaltered.
-
-    Returns:
-        The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
-        i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.
-    """
-    if (  # incompatible models
-        not sequential
-        or closure is not None
-        or mll.model.num_outputs == 1
-        or mll.likelihood is not getattr(mll.model, "likelihood", None)
-    ):
-        raise MDNotImplementedError  # defer to generic
-
-    # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
-    # pre-transformed in __init__, so try fitting with outcome_transform hidden
-    mll.train()
-    with del_attribute_ctx(mll.model, "outcome_transform"):
-        try:
-            # Attempt to unpack batched model into a list of independent submodels
-            unpacked_model = batched_to_model_list(mll.model)
-            unpacked_mll = SumMarginalLogLikelihood(  # avg. over MLLs internally
-                unpacked_model.likelihood, unpacked_model
-            )
-            if not allclose_mll(a=mll, b=unpacked_mll, transform_a=mean):
-                raise RuntimeError(  # validate model unpacking
-                    "Training loss of unpacked model differs from that of the original."
-                )
-
-            # Fit submodels independently
-            unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)
-
-            # Repackage submodels and copy over state_dict
-            repacked_model = model_list_to_batched(unpacked_mll.model.train())
-            repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
-            with module_rollback_ctx(mll, device=device("cpu")) as ckpt:
-                mll.load_state_dict(repacked_mll.state_dict())
-                if not allclose_mll(a=mll, b=repacked_mll):
-                    raise RuntimeError(  # validate model repacking
-                        "Training loss of repacked model differs from that of the "
-                        "original."
-                    )
-                ckpt.clear()  # do not rollback when exiting
-                return mll.eval()  # DONE!
-
-        except (AttributeError, RuntimeError, UnsupportedError) as err:
-            msg = f"Failed to independently fit submodels with exception: {err}"
-            warn(
-                f"{msg.rstrip('.')}. Deferring to generic dispatch...",
-                BotorchWarning,
-            )
-            raise MDNotImplementedError
-
-
 @FitGPyTorchMLL.register(_ApproximateMarginalLogLikelihood, object, object)
 def _fit_fallback_approximate(
     mll: _ApproximateMarginalLogLikelihood,
@@ -382,10 +297,10 @@ def _fit_fallback_approximate(
     _: Type[Likelihood],
     __: Type[ApproximateGPyTorchModel],
     *,
     closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
     data_loader: Optional[DataLoader] = None,
     optimizer: Optional[Callable] = None,
-    full_batch_limit: int = 1024,  # TODO: To be determined.
+    full_batch_limit: int = 1024,
     **kwargs: Any,
 ) -> _ApproximateMarginalLogLikelihood:
     r"""Fallback method for fitting approximate Gaussian processes.
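
Note: with this dispatch path removed, a batched multioutput model (e.g. a `SingleTaskGP` with several outcome columns) is now fitted through the generic fallback routine rather than being unpacked into independent submodels. A minimal sketch of the affected call pattern follows; the data and shapes are illustrative, not taken from this patch:

    import torch
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    train_X = torch.rand(10, 2)
    train_Y = torch.rand(10, 2)  # two outputs: previously eligible for the removed path
    model = SingleTaskGP(train_X, train_Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)  # now resolved by the generic fitting dispatch
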
diff --git a/botorch/optim/utils/__init__.py b/botorch/optim/utils/__init__.py
index ddd1e9d72f..24fe6b5c27 100644
--- a/botorch/optim/utils/__init__.py
+++ b/botorch/optim/utils/__init__.py
@@ -18,7 +18,6 @@
 )
 from botorch.optim.utils.model_utils import (
     _get_extra_mll_args,
-    allclose_mll,
     get_data_loader,
     get_name_filter,
     get_parameters,
@@ -38,7 +37,6 @@
     "_get_extra_mll_args",
     "_handle_numerical_errors",
     "_warning_handler_template",
-    "allclose_mll",
     "as_ndarray",
     "columnwise_clamp",
     "DEFAULT",
diff --git a/botorch/optim/utils/model_utils.py b/botorch/optim/utils/model_utils.py
index 70a17b6b1b..e596685767 100644
--- a/botorch/optim/utils/model_utils.py
+++ b/botorch/optim/utils/model_utils.py
@@ -214,47 +214,3 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
                 )
             else:
                 raise e
-
-
-def allclose_mll(
-    a: MarginalLogLikelihood,
-    b: MarginalLogLikelihood,
-    transform_a: Optional[Callable[[Tensor], Tensor]] = None,
-    transform_b: Optional[Callable[[Tensor], Tensor]] = None,
-    rtol: float = 1e-05,
-    atol: float = 1e-08,
-) -> bool:
-    r"""Convenience method for testing whether the log likelihoods produced by different
-    MarginalLogLikelihood instances, when evaluated on their respective models' training
-    sets, are allclose.
-
-    Args:
-        a: A MarginalLogLikelihood instance.
-        b: A second MarginalLogLikelihood instance.
-        transform_a: Optional callable used to post-transform log likelihoods under `a`.
-        transform_b: Optional callable used to post-transform log likelihoods under `b`.
-        rtol: Relative tolerance.
-        atol: Absolute tolerance.
-
-    Returns:
-        Boolean result of the allclose test.
-    """
-    warn("`allclose_mll` is marked for deprecation.", DeprecationWarning)
-
-    values_a = a(
-        a.model(*a.model.train_inputs),
-        a.model.train_targets,
-        *_get_extra_mll_args(a),
-    )
-    if transform_a:
-        values_a = transform_a(values_a)
-
-    values_b = b(
-        b.model(*b.model.train_inputs),
-        b.model.train_targets,
-        *_get_extra_mll_args(b),
-    )
-    if transform_b:
-        values_b = transform_b(values_b)
-
-    return values_a.allclose(values_b, rtol=rtol, atol=atol)
diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py
index 2044fdf331..00410af6ff 100644
--- a/test/models/test_gp_regression_mixed.py
+++ b/test/models/test_gp_regression_mixed.py
@@ -11,6 +11,7 @@
 import torch
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
+from botorch.models.converter import batched_to_model_list
 from botorch.models.gp_regression_mixed import MixedSingleTaskGP
 from botorch.models.kernels.categorical import CategoricalKernel
 from botorch.models.transforms import Normalize
@@ -156,6 +157,10 @@ def test_gp(self):
             pvar_exp = _get_pvar_expected(posterior, model, X, m)
             self.assertTrue(torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5))
 
+            # test that model converter throws an exception
+            with self.assertRaisesRegex(NotImplementedError, "not supported"):
+                batched_to_model_list(model)
+
     def test_condition_on_observations(self):
         d = 3
         for batch_shape, m, ncat, dtype in itertools.product(
diff --git a/test/optim/utils/test_model_utils.py b/test/optim/utils/test_model_utils.py
index 87d88ebadf..3e56207794 100644
--- a/test/optim/utils/test_model_utils.py
+++ b/test/optim/utils/test_model_utils.py
@@ -6,11 +6,9 @@
 
 from __future__ import annotations
 
-import math
 import re
 import warnings
 from copy import deepcopy
-from itertools import product
 from string import ascii_lowercase
 from unittest.mock import MagicMock, patch
 
@@ -19,7 +17,6 @@
 from botorch.models import ModelListGP, SingleTaskGP
 from botorch.optim.utils import (
     _get_extra_mll_args,
-    allclose_mll,
     get_data_loader,
     get_name_filter,
     get_parameters,
@@ -253,31 +250,3 @@ def test_sample_all_priors(self):
         original_state_dict = dict(deepcopy(mll.model.state_dict()))
         with self.assertRaises(RuntimeError):
             sample_all_priors(model)
-
-
-class TestAllcloseMLL(BotorchTestCase):
-    def setUp(self):
-        with torch.random.fork_rng():
-            torch.manual_seed(0)
-            train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
-            train_Y = torch.sin((2 * math.pi) * train_X)
-            train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-
-        self.mlls = []
-        for nu in (1.5, 2.5):
-            model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
-            model.covar_module.base_kernel.nu = nu
-            self.mlls.append(ExactMarginalLogLikelihood(model.likelihood, model))
-
-    def test_allclose_mll(self):
-        self.assertTrue(allclose_mll(a=self.mlls[0], b=self.mlls[0]))
-        for transform_a, transform_b in product(
-            *(2 * [(None, lambda vals: torch.zeros_like(vals))])
-        ):
-            out = allclose_mll(
-                a=self.mlls[0],
-                b=self.mlls[1],
-                transform_a=transform_a,
-                transform_b=transform_b,
-            )
-            self.assertEqual(out, transform_a is not None and transform_b is not None)
diff --git a/test/test_fit.py b/test/test_fit.py
index e609f53b64..e7fe8742d3 100644
--- a/test/test_fit.py
+++ b/test/test_fit.py
@@ -5,9 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-import warnings
 from contextlib import nullcontext
-from copy import deepcopy
 from itertools import filterfalse, product
 from typing import Callable, Iterable, Optional
 from unittest.mock import MagicMock, patch
@@ -16,29 +14,20 @@
 
 import torch
 from botorch import fit
 from botorch.exceptions.errors import ModelFittingError, UnsupportedError
-from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
-from botorch.fit import fit_gpytorch_mll
-from botorch.models import (
-    FixedNoiseGP,
-    HeteroskedasticSingleTaskGP,
-    SingleTaskGP,
-    SingleTaskVariationalGP,
-)
-from botorch.models.converter import batched_to_model_list
+from botorch.exceptions.warnings import OptimizationWarning
+from botorch.models import SingleTaskGP, SingleTaskVariationalGP
 from botorch.models.transforms.input import Normalize
 from botorch.models.transforms.outcome import Standardize
 from botorch.optim.closures import get_loss_closure_with_grads
 from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
-from botorch.optim.utils import allclose_mll, get_data_loader
+from botorch.optim.utils import get_data_loader
 from botorch.settings import debug
 from botorch.utils.context_managers import (
-    del_attribute_ctx,
     module_rollback_ctx,
     requires_grad_ctx,
     TensorCheckpoint,
 )
-from botorch.utils.dispatcher import MDNotImplementedError
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.kernels import MaternKernel
 from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
@@ -95,6 +84,30 @@ def setUp(self):
             )
         self.mll = ExactMarginalLogLikelihood(model.likelihood, model)
 
+    def test_fit_gpytorch_mll(self):
+        # Test that `optimizer` is only passed when non-None
+        with patch.object(fit, "FitGPyTorchMLL") as mock_dispatcher:
+            fit.fit_gpytorch_mll(self.mll, optimizer=None)
+            mock_dispatcher.assert_called_once_with(
+                self.mll,
+                type(self.mll.likelihood),
+                type(self.mll.model),
+                closure=None,
+                closure_kwargs=None,
+                optimizer_kwargs=None,
+            )
+
+            fit.fit_gpytorch_mll(self.mll, optimizer="foo")
+            mock_dispatcher.assert_called_with(
+                self.mll,
+                type(self.mll.likelihood),
+                type(self.mll.model),
+                closure=None,
+                closure_kwargs=None,
+                optimizer="foo",
+                optimizer_kwargs=None,
+            )
+
     def test_fit_gyptorch_model(self):
         r"""Test support for legacy API"""
 
@@ -460,216 +473,3 @@ def test_main(self):
             closure=self.closure,
             data_loader=self.data_loader,
         )
-
-
-class TestFitMultioutputIndependent(BotorchTestCase):
-    def setUp(self):
-        with torch.random.fork_rng():
-            torch.manual_seed(0)
-            train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
-            train_F = torch.sin(2 * math.pi * train_X)
-
-        self.mlls = {}
-        self.checkpoints = {}
-        self.converted_mlls = {}
-        for model_type, output_dim in product(
-            [SingleTaskGP, FixedNoiseGP, HeteroskedasticSingleTaskGP], [1, 2]
-        ):
-            train_Y = train_F.repeat(1, output_dim)
-            train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-            model = model_type(
-                train_X=train_X,
-                train_Y=train_Y,
-                input_transform=Normalize(d=1),
-                outcome_transform=Standardize(m=output_dim),
-                **(
-                    {}
-                    if model_type is SingleTaskGP
-                    else {"train_Yvar": torch.full_like(train_Y, 0.1)}
-                ),
-            )
-            self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
-            model.covar_module.base_kernel.nu = 2.5
-
-            mll = ExactMarginalLogLikelihood(model.likelihood, model)
-            for dtype in (torch.float32, torch.float64):
-                key = model_type, output_dim
-                self.mlls[key] = mll.to(dtype=dtype).train()
-                self.checkpoints[key] = {
-                    k: TensorCheckpoint(
-                        values=v.detach().clone(), device=v.device, dtype=v.dtype
-                    )
-                    for k, v in mll.state_dict().items()
-                }
-                if output_dim > 1:
-                    with del_attribute_ctx(mll.model, "outcome_transform"):
-                        _mll = self.converted_mlls[key] = deepcopy(mll)
-                        _mll.model = deepcopy(mll.model)
-                        _mll.model.covar_module.base_kernel.nu = 1.5  # break on purpose
-
-    def test_main(self):
-        for case, mll in self.mlls.items():
-            self._test_main(mll, self.checkpoints[case])
-
-    def test_unpack(self):
-        for case, mll in self.mlls.items():
-            if case in self.converted_mlls:
-                self._test_unpack(
-                    mll, self.checkpoints[case], self.converted_mlls[case]
-                )
-
-    def test_repack(self):
-        for case, mll in self.mlls.items():
-            if case in self.converted_mlls:
-                self._test_repack(
-                    mll, self.checkpoints[case], self.converted_mlls[case]
-                )
-
-    def test_exceptions(self):
-        for case, mll in self.mlls.items():
-            if case in self.converted_mlls:
-                self._test_exceptions(
-                    mll, self.checkpoints[case], self.converted_mlls[case]
-                )
-
-    def _test_main(self, mll, ckpt):
-        # Test that ineligible models error out appropriately, then short-circuit
-        if mll.model.num_outputs == 1 or mll.likelihood is not getattr(
-            mll.model, "likelihood", None
-        ):
-            with self.assertRaises(MDNotImplementedError):
-                fit._fit_multioutput_independent(mll, None, None)
-
-            return
-
-        optimizer = MockOptimizer()
-        with module_rollback_ctx(mll, checkpoint=ckpt), debug(
-            True
-        ), warnings.catch_warnings(record=True) as ws:
-            warnings.simplefilter("always", BotorchWarning)
-            warnings.simplefilter("ignore", DeprecationWarning)
-            try:
-                fit._fit_multioutput_independent(
-                    mll,
-                    None,
-                    None,
-                    optimizer=optimizer,
-                    warning_handler=lambda w: True,  # mark all warnings as resolved
-                    max_attempts=1,
-                )
-            except Exception:
-                pass  # exception handling tested separately
-            else:
-                self.assertEqual(0, len(ws))
-                self.assertFalse(mll.training)
-                self.assertEqual(optimizer.call_count, mll.model.num_outputs)
-                self.assertTrue(
-                    all(
-                        v.equal(ckpt[k].values) != v.requires_grad
-                        for k, v in mll.named_parameters()
-                    )
-                )
-
-    def _test_unpack(self, mll, ckpt, bad_mll):
-        # Test that model unpacking fails gracefully
-        optimizer = MockOptimizer()
-        converter = MagicMock(return_value=bad_mll.model)
-        with patch.multiple(
-            fit,
-            batched_to_model_list=converter,
-            SumMarginalLogLikelihood=MagicMock(return_value=bad_mll),
-        ):
-            with catch_warnings(record=True) as ws, debug(True):
-                with self.assertRaises(MDNotImplementedError):
-                    fit._fit_multioutput_independent(
-                        mll, None, None, optimizer=optimizer, max_attempts=1
-                    )
-
-        self.assertEqual(converter.call_count, 1)
-        self.assertEqual(optimizer.call_count, 0)  # should fail beforehand
-        self.assertTrue(
-            all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items())
-        )
-        self.assertTrue(any("unpacked model differs" in str(w.message) for w in ws))
-
-    def _test_repack(self, mll, ckpt, bad_mll):
-        # Test that model repacking fails gracefully
-        with patch.multiple(
-            fit,  # skips unpacking + fitting, tests bad model repacking
-            allclose_mll=lambda a, b, **kwargs: allclose_mll(a, b),
-            batched_to_model_list=lambda model: model,
-            SumMarginalLogLikelihood=MagicMock(return_value=mll),
-            fit_gpytorch_mll=lambda mll, **kwargs: mll,
-            model_list_to_batched=MagicMock(return_value=bad_mll.model),
-        ):
-            with catch_warnings(record=True) as ws, debug(True):
-                with self.assertRaises(MDNotImplementedError):
-                    fit._fit_multioutput_independent(mll, None, None, max_attempts=1)
-
-        self.assertTrue(
-            all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items())
-        )
-        self.assertTrue(any("repacked model differs" in str(w.message) for w in ws))
-
-    def _test_exceptions(self, mll, ckpt, bad_mll):
-        for exception in (
-            AttributeError("test_attribute_error"),
-            RuntimeError("test_runtime_error"),
-            UnsupportedError("test_unsupported_error"),
-        ):
-            converter = MagicMock(return_value=bad_mll.model)
-            with catch_warnings(record=True) as ws, debug(True):
-
-                def mock_fit_gpytorch_mll(*args, **kwargs):
-                    raise exception
-
-                try:
-                    with patch.multiple(
-                        fit,  # skip unpacking, throw exception from fit_gpytorch_mll
-                        allclose_mll=lambda a, b, **kwargs: True,
-                        batched_to_model_list=converter,
-                        model_list_to_batched=converter,  # should not get called
-                        fit_gpytorch_mll=mock_fit_gpytorch_mll,
-                        SumMarginalLogLikelihood=type(mll),
-                        module_rollback_ctx=lambda *args, **kwargs: nullcontext({}),
-                    ):
-                        fit._fit_multioutput_independent(mll, None, None)
-                except MDNotImplementedError:
-                    pass
-
-            self.assertEqual(converter.call_count, 1)
-            self.assertTrue(any(str(exception) in str(w.message) for w in ws))
-
-
-class TestFitOther(BotorchTestCase):
-    def helper_fit_with_converter(self, dtype) -> None:
-        # Check that sequential optimization using converter does not
-        # break input transforms.
-        tkwargs = {"device": self.device, "dtype": dtype}
-        # Set the seed to a number that doesn't generate numerical
-        # issues (no NaNs)
-        torch.manual_seed(0)
-        X = torch.rand(5, 2, **tkwargs) * 10
-        Y = X**2
-        intf = Normalize(2)
-        model = SingleTaskGP(X, Y, input_transform=intf)
-        mll = ExactMarginalLogLikelihood(model.likelihood, model)
-        with patch(
-            f"{fit_gpytorch_mll.__module__}.batched_to_model_list",
-            wraps=batched_to_model_list,
-        ) as wrapped_converter, warnings.catch_warnings(record=True) as ws:
-            warnings.simplefilter("always", BotorchWarning)
-            fit_gpytorch_mll(mll)
-            # Check that MLL repacking succeeded.
-            self.assertFalse(
-                any("Training loss of repacked model" in str(w.message) for w in ws)
-            )
-        wrapped_converter.assert_called_once()
-        self.assertFalse(torch.allclose(intf.mins, torch.zeros(1, 2, **tkwargs)))
-        self.assertFalse(torch.allclose(intf.ranges, torch.ones(1, 2, **tkwargs)))
-
-    def test_fit_with_converter_float32(self) -> None:
-        self.helper_fit_with_converter(torch.float)
-
-    def test_fit_with_converter_float64(self) -> None:
-        self.helper_fit_with_converter(torch.double)
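
Note: callers who still want per-output independent fits can construct them explicitly. `SumMarginalLogLikelihood` dispatches to `_fit_list`, which fits each submodel on its own. A minimal sketch with made-up data; shapes and sizes here are illustrative only:

    import torch
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import ModelListGP, SingleTaskGP
    from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

    train_X = torch.rand(10, 2)
    train_Y = torch.rand(10, 2)
    # one single-output GP per outcome column
    models = [SingleTaskGP(train_X, train_Y[:, i : i + 1]) for i in range(2)]
    model = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)  # handled by `_fit_list`, one submodel at a time
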