2 changes: 2 additions & 0 deletions botorch/acquisition/__init__.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from .acquisition import AcquisitionFunction, OneShotAcquisitionFunction
from .active_learning import qNegIntegratedPosteriorVariance
from .analytic import (
AnalyticAcquisitionFunction,
ConstrainedExpectedImprovement,
@@ -55,6 +56,7 @@
"qMaxValueEntropy",
"qMultiFidelityMaxValueEntropy",
"qNoisyExpectedImprovement",
"qNegIntegratedPosteriorVariance",
"qProbabilityOfImprovement",
"qSimpleRegret",
"qUpperConfidenceBound",
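With this export in place, the new acquisition function becomes importable from the top-level acquisition namespace, e.g.:

from botorch.acquisition import qNegIntegratedPosteriorVariance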
122 changes: 122 additions & 0 deletions botorch/acquisition/active_learning.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Active learning acquisition functions.

.. [Seo2014activedata]
S. Seo, M. Wallat, T. Graepel, and K. Obermayer. Gaussian process regression:
Active data selection and test point rejection. IJCNN 2000.

.. [Chen2014seqexpdesign]
X. Chen and Q. Zhou. Sequential experimental designs for stochastic kriging.
Winter Simulation Conference 2014.

.. [Binois2017repexp]
M. Binois, J. Huang, R. B. Gramacy, and M. Ludkovski. Replication or
exploration? Sequential design for stochastic simulation experiments.
ArXiv 2017.
"""

from typing import Optional

from botorch import settings
from torch import Tensor

from ..models.model import Model
from ..sampling.samplers import MCSampler, SobolQMCNormalSampler
from ..utils.transforms import concatenate_pending_points, t_batch_mode_transform
from .analytic import AnalyticAcquisitionFunction
from .objective import ScalarizedObjective


class qNegIntegratedPosteriorVariance(AnalyticAcquisitionFunction):
r"""Batch Integrated Negative Posterior Variance for Active Learning.

This acquisition function quantifies the (negative) integrated posterior variance
(excluding observation noise, computed using MC integration) of the model.
As such, it is a proxy for global model uncertainty, and is thus purely focused on
"exploration", rather than the "exploitation" pursued by many of the classic Bayesian
Optimization acquisition functions.

See [Seo2014activedata]_, [Chen2014seqexpdesign]_, and [Binois2017repexp]_.
"""

def __init__(
self,
model: Model,
mc_points: Tensor,
sampler: Optional[MCSampler] = None,
objective: Optional[ScalarizedObjective] = None,
X_pending: Optional[Tensor] = None,
) -> None:
r"""q-Integrated Negative Posterior Variance.

Args:
model: A fitted model.
mc_points: A `batch_shape x N x d` tensor of points to use for
MC-integrating the posterior variance. Usually, these are qMC
samples on the whole design space, but biased sampling directly
allows for weighted integration of the posterior variance.
sampler: The sampler used for drawing fantasy samples. In the basic setting
of a standard GP (default) this is a dummy, since the variance of the
model after conditioning does not actually depend on the sampled values.
objective: A ScalarizedObjective. Required for multi-output models.
X_pending: A `n' x d`-dim Tensor of `n'` design points that have
been submitted for function evaluation but have not yet been
evaluated.
"""
super().__init__(model=model, objective=objective)
if sampler is None:
# If no sampler is provided, we use the following dummy sampler for the
# fantasize() method in forward. IMPORTANT: This assumes that the posterior
# variance does not depend on the samples y (only on x), which is true for
# standard GP models, but not in general (e.g. for other likelihoods or
# heteroskedastic GPs using a separate noise model fit on data).
sampler = SobolQMCNormalSampler(
num_samples=1, resample=False, collapse_batch_dims=True
)
self.sampler = sampler
self.X_pending = X_pending
self.register_buffer("mc_points", mc_points)

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
# Construct the fantasy model (we actually do not use the full model;
# this is just a convenient way of computing fast posterior covariances)
fantasy_model = self.model.fantasize(
X=X, sampler=self.sampler, observation_noise=True
)

bdims = tuple(1 for _ in X.shape[:-2])
if self.model.num_outputs > 1:
# We use q=1 here b/c ScalarizedObjective currently does not fully exploit
# lazy tensor operations and thus may be slow / overly memory-hungry.
# TODO (T52818288): Properly use lazy tensors in scalarize_posterior
mc_points = self.mc_points.view(-1, *bdims, 1, X.size(-1))
else:
# While we only need marginal variances, we can evaluate for q>1
# b/c for GPyTorch models lazy evaluation can make this quite a bit
# faster than evaluating in t-batch mode with q-batch size of 1
mc_points = self.mc_points.view(*bdims, -1, X.size(-1))

# evaluate the posterior at the grid points
with settings.propagate_grads(True):
posterior = fantasy_model.posterior(mc_points)

# transform with the scalarized objective
if self.objective is not None:
posterior = self.objective(posterior)

neg_variance = posterior.variance.mul(-1.0)

if self.objective is None:
# if single-output, shape is 1 x batch_shape x num_grid_points x 1
return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)
else:
# if multi-output + obj, shape is num_grid_points x batch_shape x 1 x 1
return neg_variance.mean(dim=0).squeeze(-1).squeeze(-1)
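A minimal usage sketch for the new acquisition function, assuming a fitted single-output SingleTaskGP and the existing optimize_acqf / draw_sobol_samples helpers; the training data and optimizer settings below are illustrative only.

import torch
from botorch.acquisition import qNegIntegratedPosteriorVariance
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.sampling import draw_sobol_samples
from gpytorch.mlls import ExactMarginalLogLikelihood

# Fit a single-output GP to some (illustrative) training data on [0, 1]^2.
train_X = torch.rand(20, 2)
train_Y = torch.sin(6.28 * train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

# qMC points over the design space used to MC-integrate the posterior
# variance (an `N x d` tensor, as described in the docstring above).
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
mc_points = draw_sobol_samples(bounds=bounds, n=128, q=1).squeeze(1)

qNIPV = qNegIntegratedPosteriorVariance(model=model, mc_points=mc_points)

# Pick the next design point by maximizing the (negative) integrated variance.
candidate, acq_value = optimize_acqf(
    acq_function=qNIPV, bounds=bounds, q=1, num_restarts=10, raw_samples=256
)

For a multi-output model, a ScalarizedObjective must be passed as the objective, as exercised in the tests below.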
5 changes: 5 additions & 0 deletions sphinx/source/acquisition.rst
@@ -53,6 +53,11 @@ Entropy-Based Acquisition Functions
.. automodule:: botorch.acquisition.max_value_entropy_search
:members:

Active Learning Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.active_learning
:members:


Objectives and Cost-Aware Utilities
-------------------------------------------
126 changes: 126 additions & 0 deletions test/acquisition/test_active_learning.py
@@ -0,0 +1,126 @@
#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

import torch
from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
from botorch.acquisition.objective import IdentityMCObjective, ScalarizedObjective
from botorch.exceptions.errors import UnsupportedError
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultitaskMultivariateNormal


class TestQNegIntegratedPosteriorVariance(BotorchTestCase):
def test_init(self):
mm = MockModel(MockPosterior(mean=torch.rand(2, 1)))
mc_points = torch.rand(2, 2)
qNIPV = qNegIntegratedPosteriorVariance(model=mm, mc_points=mc_points)
sampler = qNIPV.sampler
self.assertIsInstance(sampler, SobolQMCNormalSampler)
self.assertEqual(sampler.sample_shape, torch.Size([1]))
self.assertFalse(sampler.resample)
self.assertTrue(torch.equal(mc_points, qNIPV.mc_points))
self.assertIsNone(qNIPV.X_pending)
self.assertIsNone(qNIPV.objective)
sampler = IIDNormalSampler(num_samples=2, resample=True)
qNIPV = qNegIntegratedPosteriorVariance(
model=mm, mc_points=mc_points, sampler=sampler
)
self.assertIsInstance(qNIPV.sampler, IIDNormalSampler)
self.assertEqual(qNIPV.sampler.sample_shape, torch.Size([2]))

def test_q_neg_int_post_variance(self):
no = "botorch.utils.testing.MockModel.num_outputs"
for dtype in (torch.float, torch.double):
# basic test
mean = torch.zeros(4, 1, device=self.device, dtype=dtype)
variance = torch.rand(4, 1, device=self.device, dtype=dtype)
mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
mock_num_outputs.return_value = 1
# TODO: Make this work with arbitrary models
mm = MockModel(None)
qNIPV = qNegIntegratedPosteriorVariance(
model=mm, mc_points=mc_points
)
X = torch.empty(1, 1, device=self.device, dtype=dtype) # dummy
val = qNIPV(X)
self.assertTrue(torch.allclose(val, -variance.mean(), atol=1e-4))
# batched model
mean = torch.zeros(2, 4, 1, device=self.device, dtype=dtype)
variance = torch.rand(2, 4, 1, device=self.device, dtype=dtype)
mc_points = torch.rand(2, 10, 1, device=self.device, dtype=dtype)
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
mock_num_outputs.return_value = 1
# TODO: Make this work with arbitrary models
mm = MockModel(None)
qNIPV = qNegIntegratedPosteriorVariance(
model=mm, mc_points=mc_points
)
# TODO: Allow broadcasting for batch evaluation
X = torch.empty(1, 1, device=self.device, dtype=dtype) # dummy
val = qNIPV(X)
val_exp = -variance.mean(dim=-2).squeeze(-1)
self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))
# multi-output model
mean = torch.zeros(4, 2, device=self.device, dtype=dtype)
variance = torch.rand(4, 2, device=self.device, dtype=dtype)
cov = torch.diag_embed(variance.view(-1))
f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
mfm = MockModel(f_posterior)
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
mock_num_outputs.return_value = 2
mm = MockModel(None)

# check error if objective is not ScalarizedObjective
with self.assertRaises(UnsupportedError):
qNegIntegratedPosteriorVariance(
model=mm,
mc_points=mc_points,
objective=IdentityMCObjective(),
)

weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
qNIPV = qNegIntegratedPosteriorVariance(
model=mm,
mc_points=mc_points,
objective=ScalarizedObjective(weights=weights),
)
X = torch.empty(1, 1, device=self.device, dtype=dtype) # dummy
val = qNIPV(X)
self.assertTrue(
torch.allclose(val, -0.5 * variance.mean(), atol=1e-4)
)
# batched multi-output model
mean = torch.zeros(4, 3, 1, 2, device=self.device, dtype=dtype)
variance = torch.rand(4, 3, 1, 2, device=self.device, dtype=dtype)
cov = torch.diag_embed(variance.view(4, 3, -1))
f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
mc_points = torch.rand(4, 1, device=self.device, dtype=dtype)
mfm = MockModel(f_posterior)
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
mock_num_outputs.return_value = 2
mm = MockModel(None)
weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
qNIPV = qNegIntegratedPosteriorVariance(
model=mm,
mc_points=mc_points,
objective=ScalarizedObjective(weights=weights),
)
X = torch.empty(1, 1, device=self.device, dtype=dtype) # dummy
val = qNIPV(X)
val_exp = -0.5 * variance.mean(dim=0).view(3, -1).mean(dim=-1)
self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))