Botorch closures (#1191)
Summary:
Pull Request resolved: #1191

X-link: pytorch/botorch#1439

This diff is a follow-up to the recent model fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization.

The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions.
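As a minimal sketch of the closure pattern described above (names and signatures below are illustrative, not the actual botorch.optim API):

    from typing import Callable, Sequence, Tuple

    import torch
    from torch import Tensor

    def make_forward_backward_closure(
        loss_fn: Callable[[], Tensor],
        parameters: Sequence[Tensor],
    ) -> Callable[[], Tuple[Tensor, Tuple[Tensor, ...]]]:
        """Bundle a forward and a backward pass into a single callable."""

        def closure() -> Tuple[Tensor, Tuple[Tensor, ...]]:
            # Forward pass: evaluate the loss, e.g. a negative marginal log likelihood.
            loss = loss_fn()
            # Backward pass: gradients w.r.t. the parameters (which must require grad).
            grads = torch.autograd.grad(loss, parameters)
            return loss, grads

        return closure

In botorch itself, per the summary above, such closures are assembled automatically by multiple dispatch; the sketch hard-codes the composition for brevity.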

Reviewed By: Balandat

Differential Revision: D39101211

fbshipit-source-id: c2058a387fd74058073cfe73c9404d2df2f9b55a
James Wilson authored and facebook-github-bot committed Nov 11, 2022
1 parent a67302f commit f8fabe0
Showing 2 changed files with 7 additions and 11 deletions.
ax/models/torch/alebo.py: 15 changes (6 additions, 9 deletions)

@@ -30,11 +30,10 @@
 from botorch.models.gp_regression import FixedNoiseGP
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model_list_gp_regression import ModelListGP
-from botorch.optim.fit import fit_gpytorch_scipy
+from botorch.optim.fit import fit_gpytorch_mll_scipy
 from botorch.optim.initializers import initialize_q_batch_nonneg
-from botorch.optim.numpy_converter import module_to_array
+from botorch.optim.numpy_converter import _scipy_objective_and_grad, module_to_array
 from botorch.optim.optimize import optimize_acqf
-from botorch.optim.utils import _scipy_objective_and_grad
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.distributions.multivariate_normal import MultivariateNormal

@@ -283,12 +282,10 @@ def get_map_model(
         m.load_state_dict(init_state_dict)
         mll = ExactMarginalLogLikelihood(m.likelihood, m)
         mll.train()
-        mll, info_dict = fit_gpytorch_scipy(mll, track_iterations=False, method="tnc")
-        logger.debug(info_dict)
-        # pyre-fixme[58]: `<` is not supported for operand types
-        #  `Union[List[botorch.optim.fit.OptimizationIteration], float]` and `float`.
-        if info_dict["fopt"] < f_best:
-            f_best = float(info_dict["fopt"])  # pyre-ignore
+        result = fit_gpytorch_mll_scipy(mll, method="tnc")
+        logger.debug(result)
+        if result.fval < f_best:
+            f_best = float(result.fval)
             sd_best = m.state_dict()
     # Set the final value
     m = ALEBOGP(B=B, train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
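For orientation, this is roughly how a forward-backward closure like the one sketched under the summary can be driven by scipy, which is what a routine like fit_gpytorch_mll_scipy does internally. The glue function and its names here are illustrative, not botorch internals:

    import numpy as np
    import torch
    from scipy.optimize import OptimizeResult, minimize

    def fit_with_scipy(closure, parameters, method="tnc") -> OptimizeResult:
        """Drive a (loss, grads) closure with scipy.optimize.minimize."""

        def scipy_obj(x: np.ndarray):
            # Copy the flat numpy vector back into the torch parameters.
            offset = 0
            with torch.no_grad():
                for p in parameters:
                    n = p.numel()
                    p.copy_(torch.as_tensor(x[offset : offset + n]).view_as(p))
                    offset += n
            # One forward-backward pass via the closure.
            loss, grads = closure()
            grad = np.concatenate([g.detach().cpu().numpy().ravel() for g in grads])
            return loss.item(), grad

        x0 = np.concatenate([p.detach().cpu().numpy().ravel() for p in parameters])
        # jac=True tells scipy the objective returns (value, gradient) pairs.
        return minimize(scipy_obj, x0, jac=True, method=method)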
ax/utils/testing/mock.py: 3 changes (1 addition, 2 deletions)

@@ -29,7 +29,6 @@ def one_iteration_minimize(*args, **kwargs):
             kwargs["options"] = {}

         kwargs["options"]["maxiter"] = 1
-
         return minimize(*args, **kwargs)

     # pyre-fixme[3]: Return type must be annotated.

@@ -58,7 +57,7 @@ def minimal_gen_os_ics(*args, **kwargs):

     mock_fit = es.enter_context(
         mock.patch(
-            "botorch.optim.fit.minimize",
+            "botorch.optim.core.minimize",
             wraps=one_iteration_minimize,
         )
     )
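The retargeted patch above reflects a general unittest.mock rule: patch the name where it is looked up at call time, not where it is defined. Since the refactored fitting loop now calls minimize from botorch.optim.core, that module's attribute is the one to intercept. A minimal illustration of the principle (the surrounding test scaffolding is assumed):

    from unittest import mock
    from scipy.optimize import minimize

    # Patching scipy.optimize.minimize directly would miss the reference that
    # botorch.optim.core already holds; patch the name at its point of use.
    with mock.patch("botorch.optim.core.minimize", wraps=minimize) as mocked:
        ...  # fitting code run here resolves `minimize` through the mock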
