From f8fabe0b273b6ca02ce8fa4d80b973579ece9a14 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Fri, 11 Nov 2022 14:01:14 -0800 Subject: [PATCH] Botorch closures (#1191) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1191 X-link: https://github.com/pytorch/botorch/pull/1439 This diff acts as a follow-up to the recent model fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization. The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions. Reviewed By: Balandat Differential Revision: D39101211 fbshipit-source-id: c2058a387fd74058073cfe73c9404d2df2f9b55a --- ax/models/torch/alebo.py | 15 ++++++--------- ax/utils/testing/mock.py | 3 +-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/ax/models/torch/alebo.py b/ax/models/torch/alebo.py index 12c615a743..e49cedeea7 100644 --- a/ax/models/torch/alebo.py +++ b/ax/models/torch/alebo.py @@ -30,11 +30,10 @@ from botorch.models.gp_regression import FixedNoiseGP from botorch.models.gpytorch import GPyTorchModel from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim.fit import fit_gpytorch_scipy +from botorch.optim.fit import fit_gpytorch_mll_scipy from botorch.optim.initializers import initialize_q_batch_nonneg -from botorch.optim.numpy_converter import module_to_array +from botorch.optim.numpy_converter import _scipy_objective_and_grad, module_to_array from botorch.optim.optimize import optimize_acqf -from botorch.optim.utils import _scipy_objective_and_grad from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.utils.datasets import SupervisedDataset from 
gpytorch.distributions.multivariate_normal import MultivariateNormal @@ -283,12 +282,10 @@ def get_map_model( m.load_state_dict(init_state_dict) mll = ExactMarginalLogLikelihood(m.likelihood, m) mll.train() - mll, info_dict = fit_gpytorch_scipy(mll, track_iterations=False, method="tnc") - logger.debug(info_dict) - # pyre-fixme[58]: `<` is not supported for operand types - # `Union[List[botorch.optim.fit.OptimizationIteration], float]` and `float`. - if info_dict["fopt"] < f_best: - f_best = float(info_dict["fopt"]) # pyre-ignore + result = fit_gpytorch_mll_scipy(mll, method="tnc") + logger.debug(result) + if result.fval < f_best: + f_best = float(result.fval) sd_best = m.state_dict() # Set the final value m = ALEBOGP(B=B, train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar) diff --git a/ax/utils/testing/mock.py b/ax/utils/testing/mock.py index 455d505906..a3aee87298 100644 --- a/ax/utils/testing/mock.py +++ b/ax/utils/testing/mock.py @@ -29,7 +29,6 @@ def one_iteration_minimize(*args, **kwargs): kwargs["options"] = {} kwargs["options"]["maxiter"] = 1 - return minimize(*args, **kwargs) # pyre-fixme[3]: Return type must be annotated. @@ -58,7 +57,7 @@ def minimal_gen_os_ics(*args, **kwargs): mock_fit = es.enter_context( mock.patch( - "botorch.optim.fit.minimize", + "botorch.optim.core.minimize", wraps=one_iteration_minimize, ) )