93 changes: 4 additions & 89 deletions botorch/fit.py
@@ -16,42 +16,34 @@
from warnings import catch_warnings, simplefilter, warn, warn_explicit, WarningMessage

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
from botorch.exceptions.warnings import OptimizationWarning
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.converter import batched_to_model_list, model_list_to_batched
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
from botorch.optim.utils import (
_warning_handler_template,
allclose_mll,
get_parameters,
sample_all_priors,
)
from botorch.settings import debug
from botorch.utils.context_managers import (
del_attribute_ctx,
module_rollback_ctx,
parameter_rollback_ctx,
requires_grad_ctx,
TensorCheckpoint,
)
from botorch.utils.dispatcher import (
Dispatcher,
MDNotImplementedError,
type_bypassing_encoder,
)
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, mean, Tensor
from torch import device, Tensor
from torch.nn import Parameter
from torch.utils.data import DataLoader

@@ -299,83 +291,6 @@ def _fit_list(
return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll


@FitGPyTorchMLL.register(
(MarginalLogLikelihood, _ApproximateMarginalLogLikelihood),
object,
BatchedMultiOutputGPyTorchModel,
)
def _fit_multioutput_independent(
mll: MarginalLogLikelihood,
_: Type[Likelihood],
__: Type[BatchedMultiOutputGPyTorchModel],
*,
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
sequential: bool = True,
**kwargs: Any,
) -> MarginalLogLikelihood:
r"""Fitting routine for multioutput Gaussian processes.

Args:
closure: Forward-backward closure for obtaining objective values and gradients.
Responsible for setting parameters' `grad` attributes. If no closure is
provided, one will be obtained by calling `get_loss_closure_with_grads`.
sequential: Boolean specifying whether an attempt should be made to fit
the model as a collection of independent GPs. Only relevant for certain
types of GPs with independent outputs; see `batched_to_model_list`.
**kwargs: Passed to the next method unaltered.

Returns:
The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.
"""
if ( # incompatible models
not sequential
or closure is not None
or mll.model.num_outputs == 1
or mll.likelihood is not getattr(mll.model, "likelihood", None)
):
raise MDNotImplementedError # defer to generic

# TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
# pre-transformed in __init__, so try fitting with outcome_transform hidden
mll.train()
with del_attribute_ctx(mll.model, "outcome_transform"):
try:
# Attempt to unpack batched model into a list of independent submodels
unpacked_model = batched_to_model_list(mll.model)
unpacked_mll = SumMarginalLogLikelihood( # avg. over MLLs internally
unpacked_model.likelihood, unpacked_model
)
if not allclose_mll(a=mll, b=unpacked_mll, transform_a=mean):
raise RuntimeError( # validate model unpacking
"Training loss of unpacked model differs from that of the original."
)

# Fit submodels independently
unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)

# Repackage submodels and copy over state_dict
repacked_model = model_list_to_batched(unpacked_mll.model.train())
repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
with module_rollback_ctx(mll, device=device("cpu")) as ckpt:
mll.load_state_dict(repacked_mll.state_dict())
if not allclose_mll(a=mll, b=repacked_mll):
raise RuntimeError( # validate model repacking
"Training loss of repacked model differs from that of the "
"original."
)
ckpt.clear() # do not rollback when exiting
return mll.eval() # DONE!

except (AttributeError, RuntimeError, UnsupportedError) as err:
msg = f"Failed to independently fit submodels with exception: {err}"
warn(
f"{msg.rstrip('.')}. Deferring to generic dispatch...",
BotorchWarning,
)
raise MDNotImplementedError
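
For reference, the sequential fitting path removed above can still be reproduced by hand with the model converter. A minimal sketch, assuming a batched multi-output `SingleTaskGP` with independent outputs and no outcome transform (the data here is illustrative):

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.converter import batched_to_model_list, model_list_to_batched
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

train_X = torch.rand(20, 2)
train_Y = torch.rand(20, 2)  # two independent outputs -> batched model

model = SingleTaskGP(train_X, train_Y)

# Unpack the batched model into a list of independent submodels, fit each
# submodel's hyperparameters, then repack into the original batched form.
model_list = batched_to_model_list(model)
mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
fit_gpytorch_mll(mll)
model = model_list_to_batched(model_list.train())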


@FitGPyTorchMLL.register(_ApproximateMarginalLogLikelihood, object, object)
def _fit_fallback_approximate(
mll: _ApproximateMarginalLogLikelihood,
@@ -385,7 +300,7 @@ def _fit_fallback_approximate(
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
data_loader: Optional[DataLoader] = None,
optimizer: Optional[Callable] = None,
full_batch_limit: int = 1024, # TODO: To be determined.
full_batch_limit: int = 1024,
**kwargs: Any,
) -> _ApproximateMarginalLogLikelihood:
r"""Fallback method for fitting approximate Gaussian processes.
2 changes: 0 additions & 2 deletions botorch/optim/utils/__init__.py
@@ -18,7 +18,6 @@
)
from botorch.optim.utils.model_utils import (
_get_extra_mll_args,
allclose_mll,
get_data_loader,
get_name_filter,
get_parameters,
@@ -38,7 +37,6 @@
"_get_extra_mll_args",
"_handle_numerical_errors",
"_warning_handler_template",
"allclose_mll",
"as_ndarray",
"columnwise_clamp",
"DEFAULT",
44 changes: 0 additions & 44 deletions botorch/optim/utils/model_utils.py
@@ -214,47 +214,3 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
)
else:
raise e


def allclose_mll(
a: MarginalLogLikelihood,
b: MarginalLogLikelihood,
transform_a: Optional[Callable[[Tensor], Tensor]] = None,
transform_b: Optional[Callable[[Tensor], Tensor]] = None,
rtol: float = 1e-05,
atol: float = 1e-08,
) -> bool:
r"""Convenience method for testing whether the log likelihoods produced by different
MarginalLogLikelihood instances, when evaluated on their respective models' training
sets, are allclose.

Args:
a: A MarginalLogLikelihood instance.
b: A second MarginalLogLikelihood instance.
transform_a: Optional callable used to post-transform log likelihoods under `a`.
transform_b: Optional callable used to post-transform log likelihoods under `b`.
rtol: Relative tolerance.
atol: Absolute tolerance.

Returns:
Boolean result of the allclose test.
"""
warn("`allclose_mll` is marked for deprecation.", DeprecationWarning)

values_a = a(
a.model(*a.model.train_inputs),
a.model.train_targets,
*_get_extra_mll_args(a),
)
if transform_a:
values_a = transform_a(values_a)

values_b = b(
b.model(*b.model.train_inputs),
b.model.train_targets,
*_get_extra_mll_args(b),
)
if transform_b:
values_b = transform_b(values_b)

return values_a.allclose(values_b, rtol=rtol, atol=atol)
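
With `allclose_mll` removed, callers that relied on it can perform the same comparison inline. A minimal sketch, assuming two exact-GP MLL instances `a` and `b` in training mode:

from botorch.optim.utils import _get_extra_mll_args

def _train_loss(mll):
    # Evaluate the mll on its own model's training data.
    output = mll.model(*mll.model.train_inputs)
    return mll(output, mll.model.train_targets, *_get_extra_mll_args(mll))

assert _train_loss(a).allclose(_train_loss(b), rtol=1e-05, atol=1e-08)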
5 changes: 5 additions & 0 deletions test/models/test_gp_regression_mixed.py
@@ -11,6 +11,7 @@
import torch
from botorch.exceptions.warnings import OptimizationWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models.converter import batched_to_model_list
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.kernels.categorical import CategoricalKernel
from botorch.models.transforms import Normalize
@@ -156,6 +157,10 @@ def test_gp(self):
pvar_exp = _get_pvar_expected(posterior, model, X, m)
self.assertTrue(torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5))

# test that the model converter raises NotImplementedError
with self.assertRaisesRegex(NotImplementedError, "not supported"):
batched_to_model_list(model)

def test_condition_on_observations(self):
d = 3
for batch_shape, m, ncat, dtype in itertools.product(
31 changes: 0 additions & 31 deletions test/optim/utils/test_model_utils.py
@@ -6,11 +6,9 @@

from __future__ import annotations

import math
import re
import warnings
from copy import deepcopy
from itertools import product
from string import ascii_lowercase
from unittest.mock import MagicMock, patch

@@ -19,7 +17,6 @@
from botorch.models import ModelListGP, SingleTaskGP
from botorch.optim.utils import (
_get_extra_mll_args,
allclose_mll,
get_data_loader,
get_name_filter,
get_parameters,
@@ -253,31 +250,3 @@ def test_sample_all_priors(self):
original_state_dict = dict(deepcopy(mll.model.state_dict()))
with self.assertRaises(RuntimeError):
sample_all_priors(model)


class TestAllcloseMLL(BotorchTestCase):
def setUp(self):
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_Y = torch.sin((2 * math.pi) * train_X)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)

self.mlls = []
for nu in (1.5, 2.5):
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
model.covar_module.base_kernel.nu = nu
self.mlls.append(ExactMarginalLogLikelihood(model.likelihood, model))

def test_allclose_mll(self):
self.assertTrue(allclose_mll(a=self.mlls[0], b=self.mlls[0]))
for transform_a, transform_b in product(
*(2 * [(None, lambda vals: torch.zeros_like(vals))])
):
out = allclose_mll(
a=self.mlls[0],
b=self.mlls[1],
transform_a=transform_a,
transform_b=transform_b,
)
self.assertEqual(out, transform_a is not None and transform_b is not None)