Remove everything deprecated in closures refactor (pytorch#1995)
Summary:
## Motivation

Removes everything that was deprecated in pytorch#1439 (version 0.8.0), except for `_get_extra_mll_args`. Removal was straightforward since the functionality was unused and the relevant unit tests were clearly labeled and self-contained.
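
For reference, a minimal sketch of the migration path for callers of the removed `fit_gpytorch_model`; the training tensors here are illustrative, not taken from this PR:

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Illustrative training data.
train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)

gp = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)

# Before this commit: fit_gpytorch_model(mll)
mll = fit_gpytorch_mll(mll)
```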

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#1995

Test Plan:
Existing unit tests; make sure codecov has not regressed from deleting tests.

## Related PRs

pytorch#1439

Differential Revision: https://internalfb.com/D48738275

fbshipit-source-id: 4cb19467d42d782c4abe95810e48428c193bef99
esantorella authored and facebook-github-bot committed Oct 12, 2023
1 parent 6d330eb commit 21afeb1
Showing 8 changed files with 6 additions and 1,046 deletions.
7 changes: 1 addition & 6 deletions botorch/__init__.py
@@ -16,11 +16,7 @@
test_functions,
)
from botorch.cross_validation import batch_cross_validation
from botorch.fit import (
fit_fully_bayesian_model_nuts,
fit_gpytorch_mll,
fit_gpytorch_model,
)
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.generation.gen import (
gen_candidates_scipy,
gen_candidates_torch,
@@ -56,7 +52,6 @@
"exceptions",
"fit_fully_bayesian_model_nuts",
"fit_gpytorch_mll",
"fit_gpytorch_model",
"gen_candidates_scipy",
"gen_candidates_torch",
"get_best_candidates",
66 changes: 2 additions & 64 deletions botorch/fit.py
@@ -9,11 +9,10 @@
from __future__ import annotations

import logging
from contextlib import nullcontext
from functools import partial
from itertools import filterfalse
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Type, Union
from warnings import catch_warnings, simplefilter, warn, warn_explicit, WarningMessage
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
@@ -33,7 +32,6 @@
from botorch.utils.context_managers import (
module_rollback_ctx,
parameter_rollback_ctx,
requires_grad_ctx,
TensorCheckpoint,
)
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
@@ -113,66 +111,6 @@ def fit_gpytorch_mll(
)


def fit_gpytorch_model(
mll: MarginalLogLikelihood,
optimizer: Optional[Callable] = None,
optimizer_kwargs: Optional[dict] = None,
exclude: Optional[Iterable[str]] = None,
max_retries: Optional[int] = None,
**kwargs: Any,
) -> MarginalLogLikelihood:
r"""Convenience method for fitting GPyTorch models using legacy API. For more
details, see `fit_gpytorch_mll`.
Args:
mll: A GPyTorch MarginalLogLikelihood instance.
optimizer: User specified optimization algorithm. When `optimizer is None`,
this keyword argument is omitted when calling the dispatcher from inside
`fit_gpytorch_mll`.
optimizer_kwargs: Keyword arguments passed to `optimizer`.
exclude: Legacy argument for specifying parameters `x` that should be held fixed
during optimization. Internally, used to temporarily set `x.requires_grad`
to False.
max_retries: Legacy name for `max_attempts`. When `max_retries is None`,
this keyword argument is omitted when calling `fit_gpytorch_mll`.
"""
warn(
"`fit_gpytorch_model` is marked for deprecation, consider using "
"`fit_gpytorch_mll` instead.",
DeprecationWarning,
)
if max_retries is not None:
kwargs["max_attempts"] = max_retries

optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
for key in ("bounds", "options"):
if key not in kwargs:
continue

val = kwargs.pop(key)
if key in optimizer_kwargs and val is not optimizer_kwargs[key]:
raise SyntaxError(f"keyword argument repeated: {key}")

optimizer_kwargs[key] = val

with (
nullcontext()
if exclude is None
else requires_grad_ctx(mll, assignments={name: False for name in exclude})
):
try:
mll = fit_gpytorch_mll(
mll,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
**kwargs,
)
except ModelFittingError as err:
warn(str(err), RuntimeWarning)

return mll


@FitGPyTorchMLL.register(MarginalLogLikelihood, object, object)
def _fit_fallback(
mll: MarginalLogLikelihood,
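
The removed shim translated legacy arguments for the new entry point: `max_retries` became `max_attempts`, and top-level `bounds`/`options` were folded into `optimizer_kwargs`. A hedged sketch of the equivalent direct call, reusing an already-constructed `mll` (argument values are illustrative):

```python
from botorch.fit import fit_gpytorch_mll

# Legacy (removed): fit_gpytorch_model(mll, max_retries=3, options={"maxiter": 200})
mll = fit_gpytorch_mll(
    mll,
    max_attempts=3,  # legacy `max_retries`
    optimizer_kwargs={"options": {"maxiter": 200}},  # legacy top-level `options`
)
```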
3 changes: 0 additions & 3 deletions botorch/optim/__init__.py
@@ -23,7 +23,6 @@
LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.numpy_converter import module_to_array, set_params_with_array
from botorch.optim.optimize import (
gen_batch_initial_conditions,
optimize_acqf,
@@ -51,9 +50,7 @@
"optimize_acqf_discrete_local_search",
"optimize_acqf_mixed",
"optimize_acqf_homotopy",
"module_to_array",
"scipy_minimize",
"set_params_with_array",
"torch_minimize",
"ExpMAStoppingCriterion",
"FixedHomotopySchedule",
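
The removed `module_to_array`/`set_params_with_array` pair round-tripped module parameters through a flat numpy array. A minimal sketch of working with the surviving dict-based utility instead, assuming `get_parameters_and_bounds` behaves as its import in the `botorch/optim/fit.py` hunk below suggests (returning a parameter dict and a bounds dict):

```python
from botorch.optim.utils import get_parameters_and_bounds

# Parameters come back as a {name: tensor} dict rather than a flat array.
parameters, bounds = get_parameters_and_bounds(mll)
for name, param in parameters.items():
    print(name, tuple(param.shape))
```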
228 changes: 2 additions & 226 deletions botorch/optim/fit.py
@@ -9,21 +9,7 @@
from __future__ import annotations

from functools import partial
from itertools import filterfalse
from time import monotonic
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union
from warnings import warn

from botorch.exceptions.warnings import OptimizationWarning
@@ -34,25 +20,11 @@
scipy_minimize,
torch_minimize,
)
from botorch.optim.numpy_converter import (
_scipy_objective_and_grad,
module_to_array,
set_params_with_array,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import (
_filter_kwargs,
_get_extra_mll_args,
get_name_filter,
get_parameters_and_bounds,
TorchAttr,
)
from botorch.optim.utils.model_utils import get_parameters
from botorch.optim.utils import get_parameters_and_bounds, TorchAttr
from botorch.utils.types import DEFAULT
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.settings import fast_computations
from numpy import ndarray
from scipy.optimize import Bounds, minimize
from torch import Tensor
from torch.nn import Module
from torch.optim.adam import Adam
@@ -200,199 +172,3 @@ def fit_gpytorch_mll_torch(
callback=callback,
timeout_sec=timeout_sec,
)


def fit_gpytorch_scipy(
mll: MarginalLogLikelihood,
bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None,
method: str = "L-BFGS-B",
options: Optional[Dict[str, Any]] = None,
track_iterations: bool = False,
approx_mll: bool = False,
scipy_objective: TScipyObjective = _scipy_objective_and_grad,
module_to_array_func: TModToArray = module_to_array,
module_from_array_func: TArrayToMod = set_params_with_array,
**kwargs: Any,
) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationResult]]]]:
r"""Legacy method for scipy-based fitting of gpytorch models.
The model and likelihood in mll must already be in train mode. This method requires
that the model has `train_inputs` and `train_targets`.
Args:
mll: MarginalLogLikelihood to be maximized.
bounds: A dictionary mapping parameter names to tuples of lower and upper
bounds.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to scipy.optimize.minimize.
approx_mll: If True, use gpytorch's approximate MLL computation. This is
disabled by default since the stochasticity is an issue for
deterministic optimizers. Enabling this is only recommended when
working with large training data sets (n > 2000).
Returns:
2-element tuple containing
- MarginalLogLikelihood with parameters optimized in-place.
- Dictionary with the following key/values:
"fopt": Best mll value.
"wall_time": Wall time of fitting.
"iterations": List of OptimizationResult objects with information on each
iteration. If track_iterations is False, will be empty.
"OptimizeResult": The result returned by `scipy.optim.minimize`.
"""
warn(
"`fit_gpytorch_scipy` is marked for deprecation, consider using "
"`scipy_minimize` or its model fitting helper `fit_gpytorch_mll_scipy`.",
DeprecationWarning,
)
start_time = monotonic()
iterations: List[OptimizationResult] = []

options = {} if options is None else options.copy()
exclude: Iterator[Union[Pattern, str]] = options.pop("exclude", None)
if exclude:
exclude, _ = zip( # get the qualified names of excluded parameters
*filterfalse(get_name_filter(exclude), mll.named_parameters())
)

x0, property_dict, bounds = module_to_array_func(
module=mll, exclude=exclude, bounds=bounds
)
if bounds is not None:
bounds = Bounds(lb=bounds[0], ub=bounds[1], keep_feasible=True)

def wrapper(x: ndarray) -> Tuple[float, ndarray]:
with fast_computations(log_prob=approx_mll):
return scipy_objective(x=x, mll=mll, property_dict=property_dict)

def store_iteration(xk):
iterations.append(
OptimizationResult(
step=len(iterations),
fval=float(wrapper(xk)[0]),
status=OptimizationStatus.RUNNING,
runtime=monotonic() - start_time,
)
)

result = minimize(
wrapper,
x0,
bounds=bounds,
method=method,
jac=True,
options=options,
callback=store_iteration if track_iterations else None,
)

info_dict = {
"fopt": float(result.fun),
"wall_time": monotonic() - start_time,
"iterations": iterations,
"OptimizeResult": result,
}
if not result.success:
try:
# Some result.message are bytes
msg = result.message.decode("ascii")
except AttributeError:
# Others are str
msg = result.message
warn(
f"Fitting failed with the optimizer reporting '{msg}'", OptimizationWarning
)

# Set to optimum
mll = module_from_array_func(mll, result.x, property_dict)
return mll, info_dict
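
Per the deprecation message above, a minimal sketch of the replacement call via `fit_gpytorch_mll_scipy`, assuming `mll` is already in train mode (the `options` value is illustrative):

```python
from botorch.optim.fit import fit_gpytorch_mll_scipy

# Optimizes the parameters of `mll` in place and returns an OptimizationResult.
result = fit_gpytorch_mll_scipy(mll, options={"maxiter": 100})
print(result.fval, result.runtime)  # final loss value and wall time
```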


def fit_gpytorch_torch(
mll: MarginalLogLikelihood,
bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None,
optimizer_cls: Optimizer = Adam,
options: Optional[Dict[str, Any]] = None,
track_iterations: bool = False,
approx_mll: bool = False,
) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationResult]]]]:
r"""Legacy method for torch-based fitting of gpytorch models.
The model and likelihood in mll must already be in train mode.
Note: this method requires that the model has `train_inputs` and `train_targets`.
Args:
mll: MarginalLogLikelihood to be maximized.
bounds: An optional dictionary mapping parameter names to tuples
of lower and upper bounds. Bounds specified here take precedence
over bounds on the same parameters specified in the constraints
registered with the module.
optimizer_cls: Torch optimizer to use. Must not require a closure.
options: options for model fitting. Relevant options will be passed to
the `optimizer_cls`. Additionally, options can include: "disp"
to specify whether to display model fitting diagnostics and "maxiter"
to specify the maximum number of iterations.
Returns:
2-element tuple containing
- mll with parameters optimized in-place.
- Dictionary with the following key/values:
"fopt": Best mll value.
"wall_time": Wall time of fitting.
"iterations": List of OptimizationResult objects with information on each
iteration. If track_iterations is False, will be empty.
Example:
>>> gp = SingleTaskGP(train_X, train_Y)
>>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
>>> mll.train()
>>> fit_gpytorch_torch(mll)
>>> mll.eval()
"""
warn(
"`fit_gpytorch_torch` is marked for deprecation, consider using "
"`torch_minimize` or its model fitting helper `fit_gpytorch_mll_torch`.",
DeprecationWarning,
)
_options = {"maxiter": 100, "disp": True, "lr": 0.05}
_options.update(options or {})
exclude = _options.pop("exclude", None)
parameters = get_parameters(
mll,
requires_grad=True,
name_filter=None if exclude is None else get_name_filter(exclude),
)

optimizer = optimizer_cls(
params=list(parameters.values()), **_filter_kwargs(optimizer_cls, **_options)
)
iterations: List[OptimizationResult] = []
stopping_criterion = ExpMAStoppingCriterion(
**_filter_kwargs(ExpMAStoppingCriterion, **_options)
)

def closure() -> Tuple[Tensor, Tuple[Tensor, ...]]:
optimizer.zero_grad()
with fast_computations(log_prob=approx_mll):
out = mll.model(*mll.model.train_inputs)
loss = -mll(out, mll.model.train_targets, *_get_extra_mll_args(mll)).sum()
loss.backward()

return loss, tuple(param.grad for param in parameters.values())

def store_iteration(parameters: Dict[str, Tensor], result: OptimizationResult):
iterations.append(result)

result = fit_gpytorch_mll_torch(
mll=mll,
closure=closure,
bounds=bounds,
parameters=parameters,
optimizer=optimizer,
stopping_criterion=stopping_criterion,
callback=store_iteration if track_iterations else None,
)
return mll, {
"fopt": result.fval,
"wall_time": result.runtime,
"iterations": iterations,
}
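
Likewise, the torch path survives as `fit_gpytorch_mll_torch`, which returns a single `OptimizationResult` instead of the legacy `(mll, info_dict)` tuple. A minimal sketch, assuming a `step_limit` keyword is available (the value is illustrative):

```python
from botorch.optim.fit import fit_gpytorch_mll_torch

# Drop-in replacement for the removed `fit_gpytorch_torch`; optimizes in place.
result = fit_gpytorch_mll_torch(mll, step_limit=100)
print(result.fval, result.runtime)
```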
