diff --git a/ax/models/torch/alebo.py b/ax/models/torch/alebo.py index 12c615a743..e49cedeea7 100644 --- a/ax/models/torch/alebo.py +++ b/ax/models/torch/alebo.py @@ -30,11 +30,10 @@ from botorch.models.gp_regression import FixedNoiseGP from botorch.models.gpytorch import GPyTorchModel from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim.fit import fit_gpytorch_scipy +from botorch.optim.fit import fit_gpytorch_mll_scipy from botorch.optim.initializers import initialize_q_batch_nonneg -from botorch.optim.numpy_converter import module_to_array +from botorch.optim.numpy_converter import _scipy_objective_and_grad, module_to_array from botorch.optim.optimize import optimize_acqf -from botorch.optim.utils import _scipy_objective_and_grad from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.utils.datasets import SupervisedDataset from gpytorch.distributions.multivariate_normal import MultivariateNormal @@ -283,12 +282,10 @@ def get_map_model( m.load_state_dict(init_state_dict) mll = ExactMarginalLogLikelihood(m.likelihood, m) mll.train() - mll, info_dict = fit_gpytorch_scipy(mll, track_iterations=False, method="tnc") - logger.debug(info_dict) - # pyre-fixme[58]: `<` is not supported for operand types - # `Union[List[botorch.optim.fit.OptimizationIteration], float]` and `float`. - if info_dict["fopt"] < f_best: - f_best = float(info_dict["fopt"]) # pyre-ignore + result = fit_gpytorch_mll_scipy(mll, method="tnc") + logger.debug(result) + if result.fval < f_best: + f_best = float(result.fval) sd_best = m.state_dict() # Set the final value m = ALEBOGP(B=B, train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar) diff --git a/ax/utils/testing/mock.py b/ax/utils/testing/mock.py index 455d505906..a3aee87298 100644 --- a/ax/utils/testing/mock.py +++ b/ax/utils/testing/mock.py @@ -29,7 +29,6 @@ def one_iteration_minimize(*args, **kwargs): kwargs["options"] = {} kwargs["options"]["maxiter"] = 1 - return minimize(*args, **kwargs) # pyre-fixme[3]: Return type must be annotated. @@ -58,7 +57,7 @@ def minimal_gen_os_ics(*args, **kwargs): mock_fit = es.enter_context( mock.patch( - "botorch.optim.fit.minimize", + "botorch.optim.core.minimize", wraps=one_iteration_minimize, ) )