Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions botorch/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import warnings
from contextlib import contextmanager
from typing import Generator, Optional
from typing import Generator, Optional, Iterable

import torch
from botorch.exceptions.warnings import SamplingWarning
Expand Down Expand Up @@ -144,7 +144,11 @@ def construct_base_samples_from_posterior(


def draw_sobol_samples(
bounds: Tensor, n: int, q: int, seed: Optional[int] = None
bounds: Tensor,
n: int,
q: int,
batch_shape: Optional[Iterable[int], torch.Size] = None,
seed: Optional[int] = None,
) -> Tensor:
r"""Draw qMC samples from the box defined by bounds.

Expand All @@ -154,22 +158,30 @@ def draw_sobol_samples(
to lower and upper bounds, respectively.
n: The number of (q-batch) samples.
q: The size of each q-batch.
batch_shape: The batch shape of the samples. If given, returns samples
of shape `n x batch_shape x q x d`, where each batch is an
`n x q x d`-dim tensor of qMC samples.
seed: The seed used for initializing Owen scrambling. If None (default),
use a random seed.

Returns:
A `n x q x d`-dim tensor of qMC samples from the box defined by bounds.
A `n x batch_shape x q x d`-dim tensor of qMC samples from the box
defined by bounds.

Example:
>>> bounds = torch.stack([torch.zeros(3), torch.ones(3)])
>>> samples = draw_sobol_samples(bounds, 10, 2)
"""
batch_shape = batch_shape or torch.Size()
batch_size = int(torch.prod(torch.tensor(batch_shape)))
d = bounds.shape[-1]
lower = bounds[0]
rng = bounds[1] - bounds[0]
sobol_engine = SobolEngine(q * d, scramble=True, seed=seed)
samples_raw = sobol_engine.draw(n, dtype=lower.dtype).view(n, q, d)
samples_raw = samples_raw.to(device=lower.device)
samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device)
if batch_shape != torch.Size():
samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1)
return lower + rng * samples_raw


Expand Down
17 changes: 13 additions & 4 deletions test/utils/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,22 @@ def test_construct_base_samples_from_posterior(self): # noqa: C901

class TestSampleUtils(BotorchTestCase):
def test_draw_sobol_samples(self):
for d, q, n, seed, dtype in itertools.product(
(1, 3), (1, 2), (2, 5), (None, 1234), (torch.float, torch.double)
batch_shapes = [None, [3, 5], torch.Size([2]), (5, 3, 2, 3), []]
for d, q, n, batch_shape, seed, dtype in itertools.product(
(1, 3),
(1, 2),
(2, 5),
batch_shapes,
(None, 1234),
(torch.float, torch.double),
):
tkwargs = {"device": self.device, "dtype": dtype}
bounds = torch.stack([torch.rand(d), 1 + torch.rand(d)]).to(**tkwargs)
samples = draw_sobol_samples(bounds=bounds, n=n, q=q, seed=seed)
self.assertEqual(samples.shape, torch.Size([n, q, d]))
samples = draw_sobol_samples(
bounds=bounds, n=n, q=q, batch_shape=batch_shape, seed=seed
)
batch_shape = batch_shape or torch.Size()
self.assertEqual(samples.shape, torch.Size([n, *batch_shape, q, d]))
self.assertTrue(torch.all(samples >= bounds[0]))
self.assertTrue(torch.all(samples <= bounds[1]))
self.assertEqual(samples.device.type, self.device.type)
Expand Down