From 4c56e7cc677550262c07aebf2c57756ed89fa016 Mon Sep 17 00:00:00 2001
From: Michael Shvartsman
Date: Tue, 29 Sep 2020 17:46:17 -0700
Subject: [PATCH] Remove noise term in PairwiseGP and add ScaleKernel by
 default (#571)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/571

In a probit model, the noise is not identifiable separately from a scale on
the latent function. In the current implementation, this scale is implemented
via a noise term, which has three disadvantages:

1. A user coming from anywhere else in gpytorch/botorch wouldn't necessarily
   know not to use a ScaleKernel here.
2. When running prediction / computing posteriors, by default we don't
   include the noise (consistent with behavior elsewhere), so our prediction
   of f(x) is off by an arbitrary constant factor.
3. When computing posteriors with `observation_noise=True`, the noise is
   added to the covariance, but this doesn't scale the function value either.
   This means that if I use a batched single-outcome acquisition function to
   do pairwise acquisition (which is reasonable), I'm not doing acquisition
   on the same model I'm using for interpolation (since now I have noise on
   individual items rather than on the comparison). Explicitly pairwise MC
   acquisition functions that take draws of pairs and do something with them
   should still be correct here, I think.

This PR changes the default so that we use a ScaleKernel and removes the
noise term. I've added docs to this effect in a few places, but this is a
breaking change API-wise.

Reviewed By: Balandat

Differential Revision: D23854124

fbshipit-source-id: da802d9df05ac1cb873b4f7e4fe6ff9bfa67da8a
---
 botorch/models/pairwise_gp.py   | 136 +++++++++++---------------------
 test/models/test_pairwise_gp.py |  61 +++-----------
 2 files changed, 55 insertions(+), 142 deletions(-)
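To make the identifiability point concrete, here is a minimal standalone check
(plain PyTorch; illustrative only and not part of the patch; the helper name is
ours): multiplying the latent utilities and sigma by the same constant leaves
the probit likelihood of a comparison unchanged, so sigma cannot be estimated
separately from the scale of f.

```python
import math

import torch

# standard normal CDF, Phi
std_norm = torch.distributions.Normal(0.0, 1.0)


def probit_comparison_likelihood(f_win, f_lose, sigma):
    # P(item with utility f_win is preferred to item with utility f_lose)
    # = Phi((f_win - f_lose) / (sqrt(2) * sigma)), as in Chu & Ghahramani (2005)
    return std_norm.cdf(torch.tensor((f_win - f_lose) / (math.sqrt(2) * sigma)))


p_original = probit_comparison_likelihood(1.0, 0.5, sigma=1.0)
p_rescaled = probit_comparison_likelihood(2.0, 1.0, sigma=2.0)  # f and sigma both doubled
assert torch.allclose(p_original, p_rescaled)  # identical likelihoods
```

This is why the patch below fixes sigma to 1 and moves the scale of the latent
function into the kernel outputscale instead.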
""" def __init__( @@ -60,14 +65,11 @@ def __init__( datapoints: Tensor, comparisons: Tensor, covar_module: Optional[Module] = None, - noise_module: Optional[HomoskedasticNoise] = None, **kwargs, ) -> None: super().__init__() - r"""A probit-likelihood GP with Laplace approximation model. - - A probit-likelihood GP with Laplace approximation model that learns via - pairwise comparison data. By default it uses a scaled-RBF kernel. + r"""A probit-likelihood GP with Laplace approximation model that learns via + pairwise comparison data. By default it uses a scaled RBF kernel. Args: datapoints: A `batch_shape x n x d` tensor of training features. @@ -75,7 +77,6 @@ def __init__( comparisons[i] is a noisy indicator suggesting the utility value of comparisons[i, 0]-th is greater than comparisons[i, 1]-th. covar_module: Covariance module - noise_module: Noise module """ # Compatibility variables with fit_gpytorch_*: Dummy likelihood @@ -132,26 +133,24 @@ def __init__( param.requires_grad = False # set covariance module - if noise_module is None: - noise_module = HomoskedasticNoise( - noise_prior=SmoothedBoxPrior(-5, 5, 0.5, transform=torch.log), - noise_constraint=GreaterThan(1e-4), # if None, 1e-4 by default - batch_shape=self._input_batch_shape, - ) - self.noise_module = noise_module - - # set covariance module + # the default outputscale here is only a rule of thumb, meant to keep + # estimates away from scale value that would make Phi(f(x)) saturate + # at 0 or 1 if covar_module is None: ls_prior = GammaPrior(1.2, 0.5) ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate - covar_module = RBFKernel( - batch_shape=self._input_batch_shape, - ard_num_dims=self.dim, - lengthscale_prior=ls_prior, - lengthscale_constraint=Positive( - transform=None, initial_value=ls_prior_mode + covar_module = ScaleKernel( + RBFKernel( + batch_shape=self._input_batch_shape, + ard_num_dims=self.dim, + lengthscale_prior=ls_prior, + lengthscale_constraint=Positive( + transform=None, initial_value=ls_prior_mode + ), ), + outputscale_prior=SmoothedBoxPrior(a=1, b=4), ) + self.covar_module = covar_module self._x0 = None # will store temporary results for warm-starting @@ -191,14 +190,6 @@ def __deepcopy__(self, memo) -> PairwiseGP: self.__deepcopy__ = dcp return new_model - @property - def std_noise(self) -> Tensor: - return self.noise_module.noise - - @std_noise.setter - def std_noise(self, value: Tensor) -> None: - self.noise_module.initialize(noise=value) - @property def num_outputs(self) -> int: r"""The number of outputs of the model.""" @@ -212,16 +203,10 @@ def _has_no_data(self): or self.comparisons is None ) - def _calc_covar( - self, X1: Tensor, X2: Tensor, observation_noise: bool = False - ) -> Union[Tensor, LazyTensor]: + def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LazyTensor]: r"""Calculate the covariance matrix given two sets of datapoints""" X1, X2 = X1, X2 covar = self.covar_module(X1, X2) - if observation_noise: - noise_shape = self._input_batch_shape + self.covar.shape[-1:] - noise = self.noise_module(shape=noise_shape) - covar = AddedDiagLazyTensor(covar, noise) return covar.evaluate() def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor: @@ -302,7 +287,7 @@ def _add_jitter(self, X: Tensor) -> Tensor: return X def _calc_z( - self, utility: Tensor, D: Tensor, std_noise: Tensor + self, utility: Tensor, D: Tensor ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r"""Calculate z score. 
@@ -311,7 +296,6 @@ def _calc_z( Args: utility: A Tensor of shape `batch_size x n`, the utility at MAP point D: as in self.D - std_noise: comparison noise (as a Tensor since it's a pytorch param) Returns: z: z score calculated as in [Chu2005preference]_. @@ -320,7 +304,7 @@ def _calc_z( hazard: hazard function defined as pdf(z)/cdf(z) """ - scaled_util = (utility / (math.sqrt(2) * std_noise)).unsqueeze(-1).to(D) + scaled_util = (utility / math.sqrt(2)).unsqueeze(-1).to(D) z = (D @ scaled_util).squeeze(-1) std_norm = torch.distributions.normal.Normal( torch.zeros(1, dtype=z.dtype, device=z.device), @@ -336,9 +320,7 @@ def _calc_z( hazard = torch.exp(z_logpdf - z_logcdf) return z, z_logpdf, z_logcdf, hazard - def _grad_likelihood_f_sum( - self, utility: Tensor, D: Tensor, std_noise: Tensor - ) -> Tensor: + def _grad_likelihood_f_sum(self, utility: Tensor, D: Tensor) -> Tensor: r"""Compute the sum over of grad. of negative Log-LH wrt utility f. Original grad should be of dimension m x n, as in (6) from [Chu2005preference]_. Sum over the first dimension and return a tensor of shape n @@ -347,20 +329,17 @@ def _grad_likelihood_f_sum( Args: utility: A Tensor of shape `batch_size x n` D: A Tensor of shape `batch_size x m x n` as in self.D - std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise Returns: The sum over the first dimension of grad. of negative Log-LH wrt utility f """ - _, _, _, h = self._calc_z(utility, D, std_noise) - h_factor = (h / (math.sqrt(2) * std_noise)).unsqueeze(-2) + _, _, _, h = self._calc_z(utility, D) + h_factor = (h / math.sqrt(2)).unsqueeze(-2) grad = (h_factor @ (-D)).squeeze(-2) return grad - def _hess_likelihood_f_sum( - self, utility: Tensor, D: Tensor, DT: Tensor, std_noise: Tensor - ) -> Tensor: + def _hess_likelihood_f_sum(self, utility: Tensor, D: Tensor, DT: Tensor) -> Tensor: r"""Compute the sum over of hessian of neg. Log-LH wrt utility f. Original hess should be of dimension m x n x n, as in (7) from @@ -371,13 +350,12 @@ def _hess_likelihood_f_sum( utility: A Tensor of shape `batch_size x n` D: A Tensor of shape `batch_size x m x n` as in self.D DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT - std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise Returns: The sum over the first dimension of hess. of negative Log-LH wrt utility f """ - z, _, _, h = self._calc_z(utility, D, std_noise) - mul_factor = h * (h + z) / (2 * (std_noise ** 2)) + z, _, _, h = self._calc_z(utility, D) + mul_factor = h * (h + z) / 2 weighted_DT = DT * mul_factor.unsqueeze(-2).expand(*DT.size()) hess = weighted_DT @ D @@ -389,7 +367,6 @@ def _grad_posterior_f( datapoints: Tensor, D: Tensor, DT: Tensor, - std_noise: Tensor, covar_chol: Tensor, covar_inv: Tensor, ret_np: bool = False, @@ -405,7 +382,6 @@ def _grad_posterior_f( datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints D: A Tensor of shape `batch_size x m x n` as in self.D DT: Transpose of D. 
A Tensor of shape `batch_size x n x m` as in self.DT - std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv ret_np: return a numpy array if true, otherwise a Tensor @@ -416,7 +392,7 @@ def _grad_posterior_f( utility = torch.tensor(utility, dtype=self.datapoints.dtype) prior_mean = prior_mean.cpu() - b = self._grad_likelihood_f_sum(utility, D, std_noise) + b = self._grad_likelihood_f_sum(utility, D) # g_ = covar_inv x (utility - pred_prior) p = (utility - prior_mean).unsqueeze(-1).to(covar_chol) @@ -434,7 +410,6 @@ def _hess_posterior_f( datapoints: Tensor, D: Tensor, DT: Tensor, - std_noise: Tensor, covar_chol: Tensor, covar_inv: Tensor, ret_np: bool = False, @@ -450,7 +425,6 @@ def _hess_posterior_f( datapoints: A Tensor of shape `batch_size x n x d` as in self.datapoints D: A Tensor of shape `batch_size x m x n` as in self.D DT: Transpose of D. A Tensor of shape `batch_size x n x m` as in self.DT - std_noise: A Tensor of shape `batch_size x 1`, as in self.std_noise covar_chol: A Tensor of shape `batch_size x n x n`, as in self.covar_chol covar_inv: A Tensor of shape `batch_size x n x n`, as in self.covar_inv ret_np: return a numpy array if true, otherwise a Tensor @@ -458,7 +432,7 @@ def _hess_posterior_f( if ret_np: utility = torch.tensor(utility, dtype=self.datapoints.dtype) - hl = self._hess_likelihood_f_sum(utility, D, DT, std_noise) + hl = self._hess_likelihood_f_sum(utility, D, DT) hess = hl + covar_inv return hess.numpy() if ret_np else hess @@ -470,7 +444,7 @@ def _posterior_f(self, utility: Union[Tensor, np.ndarray]) -> Tensor: Args: utility: A Tensor of shape `batch_size x n` """ - _, _, z_logcdf, _ = self._calc_z(utility, self.D, self.std_noise) + _, _, z_logcdf, _ = self._calc_z(utility, self.D) loss1 = -(torch.sum(z_logcdf, dim=-1)) inv_prod = torch.cholesky_solve(utility.unsqueeze(-1), self.covar_chol) loss2 = 0.5 * (utility.unsqueeze(-2) @ inv_prod).squeeze(-1).squeeze(-1) @@ -536,23 +510,11 @@ def _update(self, **kwargs) -> None: dp_v = self.datapoints.view(-1, self.n, self.dim).cpu() D_v = self.D.view(-1, self.m, self.n).cpu() DT_v = self.DT.view(-1, self.n, self.m).cpu() - # Use `expand` here since we need to expand std_noise along - # the batch shape dimensions if we start off as non-batch model, - # but later conditioned on batched new data - sn_v = self.std_noise.expand(*init_x0_size[:-1], 1).reshape(-1).cpu() ch_v = self.covar_chol.view(-1, self.n, self.n).cpu() ci_v = self.covar_inv.view(-1, self.n, self.n).cpu() x = np.empty(x0.shape) for i in range(x0.shape[0]): - fsolve_args = ( - dp_v[i], - D_v[i], - DT_v[i], - sn_v[i], - ch_v[i], - ci_v[i], - True, - ) + fsolve_args = (dp_v[i], D_v[i], DT_v[i], ch_v[i], ci_v[i], True) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) x[i] = optimize.fsolve( @@ -571,7 +533,6 @@ def _update(self, **kwargs) -> None: self.datapoints.cpu(), self.D.cpu(), self.DT.cpu(), - self.std_noise.cpu(), self.covar_chol.cpu(), self.covar_inv.cpu(), True, @@ -597,9 +558,7 @@ def _update(self, **kwargs) -> None: # when calling forward() in order to obtain correct gradients # self.likelihood_hess is updated here is for the rare case where we # do not want to call forward() - self.likelihood_hess = self._hess_likelihood_f_sum( - f, self.D, self.DT, self.std_noise - ) + self.likelihood_hess = self._hess_likelihood_f_sum(f, self.D, self.DT) # Lazy 
update hlcov_eye, which is used in calculating posterior during training self.pred_cov_fac_need_update = True @@ -655,11 +614,10 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor: finishing `max_iter` updates """ xtol = float("-Inf") if xtol is None else xtol - dp, D, DT, sn, ch, ci = ( + dp, D, DT, ch, ci = ( self.datapoints, self.D, self.DT, - self.std_noise, self.covar_chol, self.covar_inv, ) @@ -669,7 +627,7 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor: x = x0 eye = None while i < max_iter and diff > xtol: - hl = self._hess_likelihood_f_sum(x, D, DT, sn) + hl = self._hess_likelihood_f_sum(x, D, DT) cov_hl = covar @ hl if eye is None: eye = torch.eye( @@ -678,7 +636,7 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor: device=self.datapoints.device, ).expand(cov_hl.shape) cov_hl = cov_hl + eye # add 1 to cov_hl - g = self._grad_posterior_f(x, dp, D, DT, sn, ch, ci) + g = self._grad_posterior_f(x, dp, D, DT, ch, ci) cov_g = covar @ g.unsqueeze(-1) x_update = torch.solve(cov_g, cov_hl).solution.squeeze(-1) x_next = x - x_update @@ -789,7 +747,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal: self.utility = self._util_newton_updates(self.utility, max_iter=1) hl = self.likelihood_hess = self._hess_likelihood_f_sum( - self.utility, self.D, self.DT, self.std_noise + self.utility, self.D, self.DT ) covar = self.covar # Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_ @@ -883,11 +841,12 @@ def posterior( X: A `batch_shape x q x d`-dim Tensor, where `d` is the dimension of the feature space and `q` is the number of points considered jointly. output_indices: As defined in parent Model class, not used for this model. - observation_noise: If True, add observation noise to the posterior. + observation_noise: Ignored (since noise is not identifiable from scale + in probit models). Returns: A `Posterior` object, representing joint - distributions over `q` points. Includes observation noise if specified. + distributions over `q` points. """ self.eval() # make sure model is in eval mode @@ -899,10 +858,6 @@ def posterior( post = self(X) - if observation_noise: - noise_module = self.noise_module(shape=post.mean.shape).evaluate() - post = MultivariateNormal(post.mean, post.covariance_matrix + noise_module) - return GPyTorchPosterior(post) def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model: @@ -918,8 +873,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode Returns: A (deepcopied) `Model` object of the same type, representing the - original model conditioned on the new observations `(X, Y)` (and - possibly noise observations passed in via kwargs). + original model conditioned on the new observations `(X, Y)`. 
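As a sketch of the resulting API (illustrative only, not part of the patch;
assumes a botorch build that includes this change): the default covar_module
is now a ScaleKernel wrapping an ARD RBFKernel, and posteriors over f no
longer involve a noise term.

```python
import torch
from botorch.models.pairwise_gp import PairwiseGP
from gpytorch.kernels import RBFKernel, ScaleKernel

train_X = torch.rand(6, 2)  # 6 items, 2 features each
# each row [i, j] records that item i was preferred to item j
comparisons = torch.tensor([[0, 1], [2, 3], [4, 5]])

model = PairwiseGP(train_X, comparisons)
assert isinstance(model.covar_module, ScaleKernel)
assert isinstance(model.covar_module.base_kernel, RBFKernel)

# posterior over the latent utility f; observation_noise is now ignored
post = model.posterior(torch.rand(4, 2))
print(post.mean.shape)  # torch.Size([4, 1])
```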
""" new_model = deepcopy(self) if self._has_no_data(): diff --git a/test/models/test_pairwise_gp.py b/test/models/test_pairwise_gp.py index 1f8f630066..bcc40b4374 100644 --- a/test/models/test_pairwise_gp.py +++ b/test/models/test_pairwise_gp.py @@ -14,11 +14,10 @@ from botorch.posteriors import GPyTorchPosterior from botorch.sampling.pairwise_samplers import PairwiseSobolQMCNormalSampler from botorch.utils.testing import BotorchTestCase -from gpytorch.kernels import RBFKernel +from gpytorch.kernels import RBFKernel, ScaleKernel from gpytorch.kernels.linear_kernel import LinearKernel -from gpytorch.likelihoods.noise_models import HomoskedasticNoise from gpytorch.means import ConstantMean -from gpytorch.priors import GammaPrior +from gpytorch.priors import GammaPrior, SmoothedBoxPrior class TestPairwiseGP(BotorchTestCase): @@ -82,43 +81,25 @@ def test_pairwise_gp(self): # test init self.assertIsInstance(model.mean_module, ConstantMean) - self.assertIsInstance(model.covar_module, RBFKernel) - self.assertIsInstance(model.covar_module.lengthscale_prior, GammaPrior) + self.assertIsInstance(model.covar_module, ScaleKernel) + self.assertIsInstance(model.covar_module.base_kernel, RBFKernel) + self.assertIsInstance( + model.covar_module.base_kernel.lengthscale_prior, GammaPrior + ) + self.assertIsInstance( + model.covar_module.outputscale_prior, SmoothedBoxPrior + ) self.assertEqual(model.num_outputs, 1) - # test custom noise prior - custom_noise_prior = GammaPrior(concentration=2.0, rate=1.0) - custom_noise_module = HomoskedasticNoise(noise_prior=custom_noise_prior) - custom_m = PairwiseGP(**model_kwargs, noise_module=custom_noise_module) - self.assertEqual( - custom_m.noise_module.noise_prior.concentration, torch.tensor(2.0) - ) - self.assertEqual(custom_m.noise_module.noise_prior.rate, torch.tensor(1.0)) # test custom models custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel()) self.assertIsInstance(custom_m.covar_module, LinearKernel) - # std_noise setter - custom_m.std_noise = 123 - self.assertTrue(torch.all(custom_m.std_noise == 123)) # prior prediction prior_m = PairwiseGP(None, None) prior_m.eval() post = prior_m.posterior(train_X) self.assertIsInstance(post, GPyTorchPosterior) - # test methods that are not commonly or explicitly used - # _calc_covar with observation noise - no_noise_cov = model._calc_covar(train_X, train_X, observation_noise=False) - noise_cov = model._calc_covar(train_X, train_X, observation_noise=True) - diag_diff = (noise_cov - no_noise_cov).diagonal(dim1=-2, dim2=-1) - self.assertTrue( - torch.allclose( - diag_diff, - model.std_noise.expand(diag_diff.shape), - rtol=1e-4, - atol=1e-5, - ) - ) # test trying adding jitter pd_mat = torch.eye(2, 2) with warnings.catch_warnings(): @@ -155,18 +136,6 @@ def test_pairwise_gp(self): posterior = model.posterior(X) self.assertIsInstance(posterior, GPyTorchPosterior) - # test adding observation noise - posterior_pred = model.posterior(X, observation_noise=True) - self.assertIsInstance(posterior_pred, GPyTorchPosterior) - self.assertEqual(posterior_pred.mean.shape, expected_shape) - self.assertEqual(posterior_pred.variance.shape, expected_shape) - pvar = posterior_pred.variance - reshaped_noise = model.std_noise.unsqueeze(-2).expand( - posterior.variance.shape - ) - pvar_exp = posterior.variance + reshaped_noise - self.assertTrue(torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)) - # test batch evaluation X = torch.rand(2, *batch_shape, 3, X_dim, **tkwargs) expected_shape = torch.Size([2]) + batch_shape + 
torch.Size([3, 1]) @@ -174,16 +143,6 @@ def test_pairwise_gp(self): posterior = model.posterior(X) self.assertIsInstance(posterior, GPyTorchPosterior) self.assertEqual(posterior.mean.shape, expected_shape) - # test adding observation noise in batch mode - posterior_pred = model.posterior(X, observation_noise=True) - self.assertIsInstance(posterior_pred, GPyTorchPosterior) - self.assertEqual(posterior_pred.mean.shape, expected_shape) - pvar = posterior_pred.variance - reshaped_noise = model.std_noise.unsqueeze(-2).expand( - posterior.variance.shape - ) - pvar_exp = posterior.variance + reshaped_noise - self.assertTrue(torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5)) def test_condition_on_observations(self): for batch_shape, dtype in itertools.product(
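For users hit by the breaking change, a hedged migration sketch (the
commented-out lines show the removed API; the replacement relies only on
standard gpytorch kernel parameters):

```python
import torch
from botorch.models.pairwise_gp import PairwiseGP
from gpytorch.kernels import RBFKernel, ScaleKernel

train_X = torch.rand(6, 2)
comparisons = torch.tensor([[0, 1], [2, 3], [4, 5]])

# Before this PR (no longer supported):
#   model = PairwiseGP(train_X, comparisons, noise_module=HomoskedasticNoise(...))
#   model.std_noise = 0.1
# After this PR: control the latent scale through the kernel outputscale instead.
covar_module = ScaleKernel(RBFKernel(ard_num_dims=train_X.shape[-1]))
model = PairwiseGP(train_X, comparisons, covar_module=covar_module)
model.covar_module.outputscale = 0.5  # plays the role the old noise scale did
```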