diff --git a/botorch/fit.py b/botorch/fit.py index 2c1de057bc..2671054adb 100644 --- a/botorch/fit.py +++ b/botorch/fit.py @@ -23,6 +23,7 @@ from botorch.optim.utils import sample_all_priors from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood +from gpytorch.utils.errors import NotPSDError FAILED_CONVERSION_MSG = ( @@ -80,7 +81,6 @@ def fit_gpytorch_model( tf = mll.model.outcome_transform mll.model.outcome_transform = None model_list = batched_to_model_list(mll.model) - model_ = model_list_to_batched(model_list) mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list) fit_gpytorch_model( mll=mll_, @@ -121,12 +121,26 @@ def fit_gpytorch_model( if retry > 0: # use normal initial conditions on first try mll.model.load_state_dict(original_state_dict) sample_all_priors(mll.model) - mll, _ = optimizer(mll, track_iterations=False, **kwargs) + try: + mll, _ = optimizer(mll, track_iterations=False, **kwargs) + except NotPSDError: + retry += 1 + logging.log( + logging.DEBUG, + f"Fitting failed on try {retry} due to a NotPSDError.", + ) + continue has_optwarning = False for w in ws: + # Do not count reaching `maxiter` as an optimization failure. + if "ITERATIONS REACHED LIMIT" in str(w.message): + logging.log( + logging.DEBUG, + "Fitting ended early due to reaching the iteration limit.", + ) + continue has_optwarning |= issubclass(w.category, OptimizationWarning) warnings.warn(w.message, w.category) - # TODO: this counts hitting `maxiter` as an optimization failure! if not has_optwarning: mll.eval() return mll diff --git a/test/test_fit.py b/test/test_fit.py index 8ac6161bee..bbe4a57361 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -97,9 +97,7 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy): self.assertTrue( any(issubclass(w.category, OptimizationWarning)) for w in ws ) - self.assertEqual( - sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1 - ) + self.assertFalse(any(MAX_RETRY_MSG in str(w.message) for w in ws)) model = mll.model # Make sure all of the parameters changed self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3) @@ -111,21 +109,13 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy): # test overriding the default bounds with user supplied bounds mll = self._getModel(double=double) - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - mll = fit_gpytorch_model( - mll, - optimizer=optimizer, - options=options, - max_retries=1, - bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)}, - ) - if optimizer == fit_gpytorch_scipy: - self.assertTrue( - any(issubclass(w.category, OptimizationWarning)) for w in ws - ) - self.assertEqual( - sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1 - ) + mll = fit_gpytorch_model( + mll, + optimizer=optimizer, + options=options, + max_retries=1, + bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)}, + ) model = mll.model self.assertGreaterEqual(model.likelihood.raw_noise.abs().item(), 1e-1) @@ -175,14 +165,9 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy): ) ), ) - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - mll = fit_gpytorch_model( - mll, optimizer=optimizer, options=options, max_retries=1 - ) - if optimizer == fit_gpytorch_scipy: - self.assertEqual( - sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1 - ) + mll = fit_gpytorch_model( + mll, optimizer=optimizer, options=options, max_retries=1 + ) self.assertTrue(mll.dummy_param.grad is None) # test excluding a parameter @@ -193,17 +178,9 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy): "model.mean_module.constant", "likelihood.noise_covar.raw_noise", ] - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - mll = fit_gpytorch_model( - mll, optimizer=optimizer, options=options, max_retries=1 - ) - if optimizer == fit_gpytorch_scipy: - self.assertTrue( - any(issubclass(w.category, OptimizationWarning)) for w in ws - ) - self.assertEqual( - sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1 - ) + mll = fit_gpytorch_model( + mll, optimizer=optimizer, options=options, max_retries=1 + ) model = mll.model # Make excluded params did not change self.assertEqual( @@ -221,21 +198,13 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy): # test non-default setting for approximate MLL computation is_scipy = optimizer == fit_gpytorch_scipy mll = self._getModel(double=double) - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - mll = fit_gpytorch_model( - mll, - optimizer=optimizer, - options=options, - max_retries=1, - approx_mll=is_scipy, - ) - if is_scipy: - self.assertTrue( - any(issubclass(w.category, OptimizationWarning)) for w in ws - ) - self.assertEqual( - sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1 - ) + mll = fit_gpytorch_model( + mll, + optimizer=optimizer, + options=options, + max_retries=1, + approx_mll=is_scipy, + ) model = mll.model # Make sure all of the parameters changed self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3) @@ -268,8 +237,9 @@ def test_fit_gpytorch_model_singular(self): ) mll = ExactMarginalLogLikelihood(gp.likelihood, gp) mll.to(device=self.device, dtype=dtype) - with self.assertRaises(NotPSDError): + with self.assertLogs(level="DEBUG") as logs: fit_gpytorch_model(mll, options=options, max_retries=2) + self.assertTrue(any("NotPSDError" in log for log in logs.output)) # ensure we can handle NaNErrors in the optimizer with mock.patch.object(SingleTaskGP, "__call__", side_effect=NanError): gp = SingleTaskGP(X_train, Y_train, likelihood=test_likelihood) @@ -278,6 +248,32 @@ def test_fit_gpytorch_model_singular(self): fit_gpytorch_model( mll, options={"disp": False, "maxiter": 1}, max_retries=1 ) + # ensure we catch NotPSDErrors + with mock.patch.object(SingleTaskGP, "__call__", side_effect=NotPSDError): + mll = self._getModel() + with self.assertLogs(level="DEBUG") as logs: + fit_gpytorch_model(mll, max_retries=2) + for retry in [1, 2]: + self.assertTrue( + any( + f"Fitting failed on try {retry} due to a NotPSDError." + in log + for log in logs.output + ) + ) + + # Failure due to optimization warning + + def optimize_w_warning(mll, **kwargs): + warnings.warn("Dummy warning.", OptimizationWarning) + return mll, None + + mll = self._getModel() + with self.assertLogs(level="DEBUG") as logs, settings.debug(True): + fit_gpytorch_model(mll, optimizer=optimize_w_warning, max_retries=2) + self.assertTrue( + any("Fitting failed on try 1." in log for log in logs.output) + ) def test_fit_gpytorch_model_torch(self): self.test_fit_gpytorch_model(optimizer=fit_gpytorch_torch) @@ -317,3 +313,12 @@ def test_fit_gpytorch_model_sequential(self): for w in ws ) ) + + def test_fit_w_maxiter(self): + options = {"maxiter": 1} + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + mll = self._getModel() + fit_gpytorch_model(mll, options=options, max_retries=3) + mll = self._getBatchedModel() + fit_gpytorch_model(mll, options=options, max_retries=3) + self.assertFalse(any("ITERATIONS REACHED LIMIT" in str(w.message) for w in ws))