From ba14496b80a26e64bcf553f2faf0790914ca5c27 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Wed, 26 Feb 2020 14:01:46 -0800
Subject: [PATCH] qNegIntegratedPosteriorVariance acquisition function (#377)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/377

Moves `qNegIntegratedPosteriorVariance` to OSS.

ToDo:
- tutorial notebook
- full test coverage (see https://codecov.io/gh/pytorch/botorch/compare/7038aa7f0cf48d5c5f9fde62e3f8386a96a52079...491e15c61830c3ebf614f341a935d0f044e4a3af/diff)

Reviewed By: ItsMrLin

Differential Revision: D17572731

fbshipit-source-id: 956821a198f298ffe0e75b3b38eecbfdd80a7b9a
---
 botorch/acquisition/__init__.py          |   2 +
 botorch/acquisition/active_learning.py   | 122 ++++++++++++++++++++++
 sphinx/source/acquisition.rst            |   5 +
 test/acquisition/test_active_learning.py | 126 +++++++++++++++++++++++
 4 files changed, 255 insertions(+)
 create mode 100644 botorch/acquisition/active_learning.py
 create mode 100644 test/acquisition/test_active_learning.py

diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py
index e3918cd697..9a66c1cc20 100644
--- a/botorch/acquisition/__init__.py
+++ b/botorch/acquisition/__init__.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from .acquisition import AcquisitionFunction, OneShotAcquisitionFunction
+from .active_learning import qNegIntegratedPosteriorVariance
 from .analytic import (
     AnalyticAcquisitionFunction,
     ConstrainedExpectedImprovement,
@@ -55,6 +56,7 @@
     "qMaxValueEntropy",
     "qMultiFidelityMaxValueEntropy",
     "qNoisyExpectedImprovement",
+    "qNegIntegratedPosteriorVariance",
     "qProbabilityOfImprovement",
     "qSimpleRegret",
     "qUpperConfidenceBound",
diff --git a/botorch/acquisition/active_learning.py b/botorch/acquisition/active_learning.py
new file mode 100644
index 0000000000..0f94607ee6
--- /dev/null
+++ b/botorch/acquisition/active_learning.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""
+Active learning acquisition functions.
+
+.. [Seo2014activedata]
+    S. Seo, M. Wallat, T. Graepel, and K. Obermayer. Gaussian process regression:
+    Active data selection and test point rejection. IJCNN 2000.
+
+.. [Chen2014seqexpdesign]
+    X. Chen and Q. Zhou. Sequential experimental designs for stochastic kriging.
+    Winter Simulation Conference 2014.
+
+.. [Binois2017repexp]
+    M. Binois, J. Huang, R. B. Gramacy, and M. Ludkovski. Replication or
+    exploration? Sequential design for stochastic simulation experiments.
+    ArXiv 2017.
+"""
+
+from typing import Optional
+
+from botorch import settings
+from torch import Tensor
+
+from ..models.model import Model
+from ..sampling.samplers import MCSampler, SobolQMCNormalSampler
+from ..utils.transforms import concatenate_pending_points, t_batch_mode_transform
+from .analytic import AnalyticAcquisitionFunction
+from .objective import ScalarizedObjective
+
+
+class qNegIntegratedPosteriorVariance(AnalyticAcquisitionFunction):
+    r"""Batch Integrated Negative Posterior Variance for Active Learning.
+
+    This acquisition function quantifies the (negative) integrated posterior
+    variance of the model (excluding observation noise, computed via MC
+    integration). As such, it is a proxy for global model uncertainty and is
+    purely focused on "exploration", rather than the "exploitation" typical of
+    many classic Bayesian Optimization acquisition functions.
+
+    See [Seo2014activedata]_, [Chen2014seqexpdesign]_, and [Binois2017repexp]_.
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        mc_points: Tensor,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[ScalarizedObjective] = None,
+        X_pending: Optional[Tensor] = None,
+    ) -> None:
+        r"""q-Integrated Negative Posterior Variance.
+
+        Args:
+            model: A fitted model.
+            mc_points: A `batch_shape x N x d` tensor of points to use for
+                MC-integrating the posterior variance. Usually, these are qMC
+                samples on the whole design space, but biased sampling directly
+                allows weighted integration of the posterior variance.
+            sampler: The sampler used for drawing fantasy samples. In the basic
+                setting of a standard GP (default) this is a dummy, since the
+                variance of the model after conditioning does not actually
+                depend on the sampled values.
+            objective: A ScalarizedObjective. Required for multi-output models.
+            X_pending: A `n' x d`-dim Tensor of `n'` design points that have
+                been submitted for function evaluation but have not yet been
+                evaluated.
+        """
+        super().__init__(model=model, objective=objective)
+        if sampler is None:
+            # If no sampler is provided, we use the following dummy sampler for
+            # the fantasize() method in forward. IMPORTANT: This assumes that
+            # the posterior variance does not depend on the samples y (only on
+            # x), which is true for standard GP models, but not in general
+            # (e.g. for other likelihoods or heteroskedastic GPs using a
+            # separate noise model fit on data).
+            sampler = SobolQMCNormalSampler(
+                num_samples=1, resample=False, collapse_batch_dims=True
+            )
+        self.sampler = sampler
+        self.X_pending = X_pending
+        self.register_buffer("mc_points", mc_points)
+
+    @concatenate_pending_points
+    @t_batch_mode_transform()
+    def forward(self, X: Tensor) -> Tensor:
+        # Construct the fantasy model (we actually do not use the full model,
+        # this is just a convenient way of computing fast posterior covariances).
+        fantasy_model = self.model.fantasize(
+            X=X, sampler=self.sampler, observation_noise=True
+        )
+
+        bdims = tuple(1 for _ in X.shape[:-2])
+        if self.model.num_outputs > 1:
+            # We use q=1 here b/c ScalarizedObjective currently does not fully
+            # exploit lazy tensor operations and thus may be slow / overly
+            # memory-hungry.
+            # TODO (T52818288): Properly use lazy tensors in scalarize_posterior
+            mc_points = self.mc_points.view(-1, *bdims, 1, X.size(-1))
+        else:
+            # While we only need marginal variances, we can evaluate for q>1
+            # b/c for GPyTorch models lazy evaluation can make this quite a bit
+            # faster than evaluating in t-batch mode with q-batch size of 1.
+            mc_points = self.mc_points.view(*bdims, -1, X.size(-1))
+
+        # evaluate the posterior at the grid points
+        with settings.propagate_grads(True):
+            posterior = fantasy_model.posterior(mc_points)
+
+        # transform with the scalarized objective
+        if self.objective is not None:
+            posterior = self.objective(posterior)
+
+        neg_variance = posterior.variance.mul(-1.0)
+
+        if self.objective is None:
+            # if single-output, shape is 1 x batch_shape x num_grid_points x 1
+            return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)
+        else:
+            # if multi-output + obj, shape is num_grid_points x batch_shape x 1 x 1
+            return neg_variance.mean(dim=0).squeeze(-1).squeeze(-1)
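
A minimal usage sketch (not part of the diff): for a single-output GP, the
forward pass above returns minus the average posterior variance over
`mc_points` after fantasizing on the candidate set `X`. The toy training data
below is an assumption for illustration only:

    import torch
    from botorch.acquisition import qNegIntegratedPosteriorVariance
    from botorch.models import SingleTaskGP

    # Assumed toy data; any fitted single-output GP works the same way here.
    train_X = torch.rand(8, 2)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)

    # qMC or uniform points covering the design space; these are the points
    # over which the posterior variance is MC-integrated.
    mc_points = torch.rand(128, 2)
    acqf = qNegIntegratedPosteriorVariance(model=model, mc_points=mc_points)

    X = torch.rand(3, 1, 2)  # t-batch of 3 candidates, each with q=1
    vals = acqf(X)  # shape (3,); larger (less negative) = lower integrated variance
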
diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst
index 8c66606354..d93d5e574b 100644
--- a/sphinx/source/acquisition.rst
+++ b/sphinx/source/acquisition.rst
@@ -53,6 +53,11 @@ Entropy-Based Acquisition Functions
 .. automodule:: botorch.acquisition.max_value_entropy_search
     :members:

+Active Learning Acquisition Functions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.acquisition.active_learning
+    :members:
+
 Objectives and Cost-Aware Utilities
 -------------------------------------------

diff --git a/test/acquisition/test_active_learning.py b/test/acquisition/test_active_learning.py
new file mode 100644
index 0000000000..649df27ada
--- /dev/null
+++ b/test/acquisition/test_active_learning.py
@@ -0,0 +1,126 @@
+#! /usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from unittest import mock
+
+import torch
+from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
+from botorch.acquisition.objective import IdentityMCObjective, ScalarizedObjective
+from botorch.exceptions.errors import UnsupportedError
+from botorch.posteriors.gpytorch import GPyTorchPosterior
+from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
+from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
+from gpytorch.distributions import MultitaskMultivariateNormal
+
+
+class TestQNegIntegratedPosteriorVariance(BotorchTestCase):
+    def test_init(self):
+        mm = MockModel(MockPosterior(mean=torch.rand(2, 1)))
+        mc_points = torch.rand(2, 2)
+        qNIPV = qNegIntegratedPosteriorVariance(model=mm, mc_points=mc_points)
+        sampler = qNIPV.sampler
+        self.assertIsInstance(sampler, SobolQMCNormalSampler)
+        self.assertEqual(sampler.sample_shape, torch.Size([1]))
+        self.assertFalse(sampler.resample)
+        self.assertTrue(torch.equal(mc_points, qNIPV.mc_points))
+        self.assertIsNone(qNIPV.X_pending)
+        self.assertIsNone(qNIPV.objective)
+        sampler = IIDNormalSampler(num_samples=2, resample=True)
+        qNIPV = qNegIntegratedPosteriorVariance(
+            model=mm, mc_points=mc_points, sampler=sampler
+        )
+        self.assertIsInstance(qNIPV.sampler, IIDNormalSampler)
+        self.assertEqual(qNIPV.sampler.sample_shape, torch.Size([2]))
+
+    def test_q_neg_int_post_variance(self):
+        no = "botorch.utils.testing.MockModel.num_outputs"
+        for dtype in (torch.float, torch.double):
+            # basic test
+            mean = torch.zeros(4, 1, device=self.device, dtype=dtype)
+            variance = torch.rand(4, 1, device=self.device, dtype=dtype)
+            mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(MockPosterior(mean=mean, variance=variance))
+            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
+                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                    mock_num_outputs.return_value = 1
+                    # TODO: Make this work with arbitrary models
+                    mm = MockModel(None)
+                    qNIPV = qNegIntegratedPosteriorVariance(
+                        model=mm, mc_points=mc_points
+                    )
+                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                    val = qNIPV(X)
+                    self.assertTrue(torch.allclose(val, -variance.mean(), atol=1e-4))
+            # batched model
+            mean = torch.zeros(2, 4, 1, device=self.device, dtype=dtype)
+            variance = torch.rand(2, 4, 1, device=self.device, dtype=dtype)
+            mc_points = torch.rand(2, 10, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(MockPosterior(mean=mean, variance=variance))
+            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
+                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                    mock_num_outputs.return_value = 1
+                    # TODO: Make this work with arbitrary models
+                    mm = MockModel(None)
+                    qNIPV = qNegIntegratedPosteriorVariance(
+                        model=mm, mc_points=mc_points
+                    )
+                    # TODO: Allow broadcasting for batch evaluation
+                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                    val = qNIPV(X)
+                    val_exp = -variance.mean(dim=-2).squeeze(-1)
+                    self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))
+            # multi-output model
+            mean = torch.zeros(4, 2, device=self.device, dtype=dtype)
+            variance = torch.rand(4, 2, device=self.device, dtype=dtype)
+            cov = torch.diag_embed(variance.view(-1))
+            f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
+            mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(f_posterior)
+            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
+                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                    mock_num_outputs.return_value = 2
+                    mm = MockModel(None)
+
+                    # check error if objective is not ScalarizedObjective
+                    with self.assertRaises(UnsupportedError):
+                        qNegIntegratedPosteriorVariance(
+                            model=mm,
+                            mc_points=mc_points,
+                            objective=IdentityMCObjective(),
+                        )
+
+                    weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
+                    qNIPV = qNegIntegratedPosteriorVariance(
+                        model=mm,
+                        mc_points=mc_points,
+                        objective=ScalarizedObjective(weights=weights),
+                    )
+                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                    val = qNIPV(X)
+                    self.assertTrue(
+                        torch.allclose(val, -0.5 * variance.mean(), atol=1e-4)
+                    )
+            # batched multi-output model
+            mean = torch.zeros(4, 3, 1, 2, device=self.device, dtype=dtype)
+            variance = torch.rand(4, 3, 1, 2, device=self.device, dtype=dtype)
+            cov = torch.diag_embed(variance.view(4, 3, -1))
+            f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
+            mc_points = torch.rand(4, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(f_posterior)
+            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
+                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                    mock_num_outputs.return_value = 2
+                    mm = MockModel(None)
+                    weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
+                    qNIPV = qNegIntegratedPosteriorVariance(
+                        model=mm,
+                        mc_points=mc_points,
+                        objective=ScalarizedObjective(weights=weights),
+                    )
+                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                    val = qNIPV(X)
+                    val_exp = -0.5 * variance.mean(dim=0).view(3, -1).mean(dim=-1)
+                    self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))
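
Pending the tutorial notebook listed in the ToDo above, a minimal end-to-end
sketch (not part of the diff) of using the new acquisition function for active
learning. The training data and the unit-cube bounds below are assumptions for
illustration only:

    import torch
    from botorch.acquisition import qNegIntegratedPosteriorVariance
    from botorch.fit import fit_gpytorch_model
    from botorch.models import SingleTaskGP
    from botorch.optim import optimize_acqf
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # Fit a single-output GP on assumed training data in [0, 1]^2.
    train_X = torch.rand(10, 2)
    train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)

    # qMC-style integration points spanning the design space.
    mc_points = torch.rand(512, 2)
    qNIPV = qNegIntegratedPosteriorVariance(model=model, mc_points=mc_points)

    # Maximizing qNIPV picks the q=2 candidates that most reduce the
    # MC-integrated posterior variance of the model.
    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
    candidates, _ = optimize_acqf(
        acq_function=qNIPV, bounds=bounds, q=2, num_restarts=10, raw_samples=256
    )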