botorch/fit.py (20 changes: 17 additions, 3 deletions)
@@ -23,6 +23,7 @@
 from botorch.optim.utils import sample_all_priors
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
 from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
+from gpytorch.utils.errors import NotPSDError
 
 
 FAILED_CONVERSION_MSG = (
@@ -80,7 +81,6 @@ def fit_gpytorch_model(
                 tf = mll.model.outcome_transform
                 mll.model.outcome_transform = None
             model_list = batched_to_model_list(mll.model)
-            model_ = model_list_to_batched(model_list)
             mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
             fit_gpytorch_model(
                 mll=mll_,
@@ -121,12 +121,26 @@ def fit_gpytorch_model(
             if retry > 0:  # use normal initial conditions on first try
                 mll.model.load_state_dict(original_state_dict)
                 sample_all_priors(mll.model)
-            mll, _ = optimizer(mll, track_iterations=False, **kwargs)
+            try:
+                mll, _ = optimizer(mll, track_iterations=False, **kwargs)
+            except NotPSDError:
+                retry += 1
+                logging.log(
+                    logging.DEBUG,
+                    f"Fitting failed on try {retry} due to a NotPSDError.",
+                )
+                continue
         has_optwarning = False
         for w in ws:
+            # Do not count reaching `maxiter` as an optimization failure.
+            if "ITERATIONS REACHED LIMIT" in str(w.message):
+                logging.log(
+                    logging.DEBUG,
+                    "Fitting ended early due to reaching the iteration limit.",
+                )
+                continue
             has_optwarning |= issubclass(w.category, OptimizationWarning)
             warnings.warn(w.message, w.category)
-        # TODO: this counts hitting `maxiter` as an optimization failure!
         if not has_optwarning:
             mll.eval()
             return mll
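The net effect of the `fit.py` change: a `NotPSDError` raised inside the optimizer no longer propagates to the caller but is logged at DEBUG level and counted as a retry, and reaching scipy's iteration limit no longer counts as a failed fit. A minimal sketch of the caller-facing behavior follows; the near-singular training data is illustrative, and whether it actually trips a `NotPSDError` depends on jitter settings, dtype, and the gpytorch version.

```python
import logging

import torch
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

logging.basicConfig(level=logging.DEBUG)

# Duplicated training inputs yield a (nearly) singular kernel matrix, the
# typical source of a NotPSDError during fitting.
X = torch.zeros(5, 1).repeat(2, 1)
Y = torch.randn(10, 1)
gp = SingleTaskGP(X, Y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)

# Before this patch, a NotPSDError raised by the optimizer propagated to the
# caller; with it, the failure is logged at DEBUG level and fitting is retried
# with hyperparameters resampled from their priors, up to `max_retries` times.
mll = fit_gpytorch_model(mll, max_retries=2)
```

Resampling the hyperparameters from their priors on each retry (via `sample_all_priors`) gives every attempt a fresh starting point, which is what makes retrying after a `NotPSDError` worthwhile at all.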
test/test_fit.py (111 changes: 58 additions, 53 deletions)
@@ -97,9 +97,7 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy):
                     self.assertTrue(
                         any(issubclass(w.category, OptimizationWarning) for w in ws)
                     )
-                    self.assertEqual(
-                        sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1
-                    )
+                    self.assertFalse(any(MAX_RETRY_MSG in str(w.message) for w in ws))
             model = mll.model
             # Make sure all of the parameters changed
             self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3)
@@ -111,21 +109,13 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy):

             # test overriding the default bounds with user supplied bounds
             mll = self._getModel(double=double)
-            with warnings.catch_warnings(record=True) as ws, settings.debug(True):
-                mll = fit_gpytorch_model(
-                    mll,
-                    optimizer=optimizer,
-                    options=options,
-                    max_retries=1,
-                    bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)},
-                )
-                if optimizer == fit_gpytorch_scipy:
-                    self.assertTrue(
-                        any(issubclass(w.category, OptimizationWarning) for w in ws)
-                    )
-                    self.assertEqual(
-                        sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1
-                    )
+            mll = fit_gpytorch_model(
+                mll,
+                optimizer=optimizer,
+                options=options,
+                max_retries=1,
+                bounds={"likelihood.noise_covar.raw_noise": (1e-1, None)},
+            )
 
             model = mll.model
             self.assertGreaterEqual(model.likelihood.raw_noise.abs().item(), 1e-1)
@@ -175,14 +165,9 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy):
                 )
             ),
         )
-        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
-            mll = fit_gpytorch_model(
-                mll, optimizer=optimizer, options=options, max_retries=1
-            )
-            if optimizer == fit_gpytorch_scipy:
-                self.assertEqual(
-                    sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1
-                )
+        mll = fit_gpytorch_model(
+            mll, optimizer=optimizer, options=options, max_retries=1
+        )
         self.assertTrue(mll.dummy_param.grad is None)
 
         # test excluding a parameter
@@ -193,17 +178,9 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy):
             "model.mean_module.constant",
             "likelihood.noise_covar.raw_noise",
         ]
-        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
-            mll = fit_gpytorch_model(
-                mll, optimizer=optimizer, options=options, max_retries=1
-            )
-            if optimizer == fit_gpytorch_scipy:
-                self.assertTrue(
-                    any(issubclass(w.category, OptimizationWarning) for w in ws)
-                )
-                self.assertEqual(
-                    sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1
-                )
+        mll = fit_gpytorch_model(
+            mll, optimizer=optimizer, options=options, max_retries=1
+        )
         model = mll.model
         # Make sure the excluded params did not change
         self.assertEqual(
@@ -221,21 +198,13 @@ def test_fit_gpytorch_model(self, optimizer=fit_gpytorch_scipy):
             # test non-default setting for approximate MLL computation
             is_scipy = optimizer == fit_gpytorch_scipy
             mll = self._getModel(double=double)
-            with warnings.catch_warnings(record=True) as ws, settings.debug(True):
-                mll = fit_gpytorch_model(
-                    mll,
-                    optimizer=optimizer,
-                    options=options,
-                    max_retries=1,
-                    approx_mll=is_scipy,
-                )
-                if is_scipy:
-                    self.assertTrue(
-                        any(issubclass(w.category, OptimizationWarning) for w in ws)
-                    )
-                    self.assertEqual(
-                        sum(1 for w in ws if MAX_RETRY_MSG in str(w.message)), 1
-                    )
+            mll = fit_gpytorch_model(
+                mll,
+                optimizer=optimizer,
+                options=options,
+                max_retries=1,
+                approx_mll=is_scipy,
+            )
             model = mll.model
             # Make sure all of the parameters changed
             self.assertGreater(model.likelihood.raw_noise.abs().item(), 1e-3)
@@ -268,8 +237,9 @@ def test_fit_gpytorch_model_singular(self):
             )
             mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
             mll.to(device=self.device, dtype=dtype)
-            with self.assertRaises(NotPSDError):
+            with self.assertLogs(level="DEBUG") as logs:
                 fit_gpytorch_model(mll, options=options, max_retries=2)
+            self.assertTrue(any("NotPSDError" in log for log in logs.output))
             # ensure we can handle NaNErrors in the optimizer
             with mock.patch.object(SingleTaskGP, "__call__", side_effect=NanError):
                 gp = SingleTaskGP(X_train, Y_train, likelihood=test_likelihood)
@@ -278,6 +248,32 @@ def test_fit_gpytorch_model_singular(self):
                     fit_gpytorch_model(
                         mll, options={"disp": False, "maxiter": 1}, max_retries=1
                     )
+        # ensure we catch NotPSDErrors
+        with mock.patch.object(SingleTaskGP, "__call__", side_effect=NotPSDError):
+            mll = self._getModel()
+            with self.assertLogs(level="DEBUG") as logs:
+                fit_gpytorch_model(mll, max_retries=2)
+            for retry in [1, 2]:
+                self.assertTrue(
+                    any(
+                        f"Fitting failed on try {retry} due to a NotPSDError."
+                        in log
+                        for log in logs.output
+                    )
+                )
+
+        # Failure due to optimization warning
+
+        def optimize_w_warning(mll, **kwargs):
+            warnings.warn("Dummy warning.", OptimizationWarning)
+            return mll, None
+
+        mll = self._getModel()
+        with self.assertLogs(level="DEBUG") as logs, settings.debug(True):
+            fit_gpytorch_model(mll, optimizer=optimize_w_warning, max_retries=2)
+        self.assertTrue(
+            any("Fitting failed on try 1." in log for log in logs.output)
+        )
 
     def test_fit_gpytorch_model_torch(self):
         self.test_fit_gpytorch_model(optimizer=fit_gpytorch_torch)
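The new assertions rely on `unittest.TestCase.assertLogs`, which fails the test if nothing is logged at or above the given level inside the block, and exposes the captured records, formatted as `LEVEL:logger:message`, via `logs.output`. A standalone sketch of the pattern, independent of botorch:

```python
import logging
import unittest


class AssertLogsExample(unittest.TestCase):
    def test_captures_debug_records(self):
        # assertLogs fails if no record at DEBUG or above is emitted here.
        with self.assertLogs(level="DEBUG") as logs:
            logging.log(logging.DEBUG, "Fitting failed on try 1.")
        # logs.output holds strings like "DEBUG:root:Fitting failed on try 1."
        self.assertTrue(
            any("Fitting failed on try 1." in log for log in logs.output)
        )


if __name__ == "__main__":
    unittest.main()
```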
@@ -317,3 +313,12 @@ def test_fit_gpytorch_model_sequential(self):
                     for w in ws
                 )
             )
+
+    def test_fit_w_maxiter(self):
+        options = {"maxiter": 1}
+        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
+            mll = self._getModel()
+            fit_gpytorch_model(mll, options=options, max_retries=3)
+            mll = self._getBatchedModel()
+            fit_gpytorch_model(mll, options=options, max_retries=3)
+        self.assertFalse(any("ITERATIONS REACHED LIMIT" in str(w.message) for w in ws))
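For reference, the `ITERATIONS REACHED LIMIT` string that `test_fit_w_maxiter` asserts is absent originates in scipy's L-BFGS-B termination message, which `fit_gpytorch_scipy` surfaces inside an `OptimizationWarning`; the loop in `fit.py` now filters that message out instead of treating it as a failed fit. A small sketch showing where the string comes from (the quadratic objective is illustrative, and in some scipy versions `result.message` is a bytes object):

```python
import numpy as np
from scipy.optimize import minimize

# With maxiter=1 on a non-trivial problem, L-BFGS-B typically stops early and
# reports "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" in result.message.
result = minimize(
    fun=lambda x: float(np.sum(x ** 2)),
    x0=np.ones(10),
    method="L-BFGS-B",
    options={"maxiter": 1},
)
print(str(result.message))  # contains "ITERATIONS REACHED LIMIT"
```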