diff --git a/botorch/utils/gp_sampling.py b/botorch/utils/gp_sampling.py
index df4c61fb82..90eb46e094 100644
--- a/botorch/utils/gp_sampling.py
+++ b/botorch/utils/gp_sampling.py
@@ -7,12 +7,20 @@
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Optional
+from math import pi
+from typing import List, Optional
 
 import torch
+from botorch.models.converter import batched_to_model_list
+from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.model import Model
+from botorch.models.model_list_gp_regression import ModelListGP
+from botorch.models.multitask import MultiTaskGP
 from botorch.utils.sampling import manual_seed
+from gpytorch.kernels import Kernel, RBFKernel, MaternKernel, ScaleKernel
+from gpytorch.utils.cholesky import psd_safe_cholesky
 from torch import Tensor
+from torch.distributions import MultivariateNormal
 from torch.nn import Module
 
 
@@ -93,3 +101,187 @@ def forward(self, X: Tensor) -> Tensor:
         self.register_buffer("_seed", seed)
         self.register_buffer("_base_samples", base_samples)
         return self.Ys[..., -(X.size(-2)) :, :]
+
+
+class RandomFourierFeatures(Module):
+    """A class that represents Random Fourier Features."""
+
+    def __init__(self, kernel: Kernel, input_dim: int, num_rff_features: int) -> None:
+        r"""Initialize RandomFourierFeatures.
+
+        Args:
+            kernel: the GP kernel
+            input_dim: the input dimension to the GP kernel
+            num_rff_features: the number of Fourier features
+        """
+        if not isinstance(kernel, ScaleKernel):
+            base_kernel = kernel
+            outputscale = torch.tensor(
+                1.0,
+                dtype=base_kernel.lengthscale.dtype,
+                device=base_kernel.lengthscale.device,
+            )
+        else:
+            base_kernel = kernel.base_kernel
+            outputscale = kernel.outputscale.detach().clone()
+        if not isinstance(base_kernel, (MaternKernel, RBFKernel)):
+            raise NotImplementedError("Only Matern and RBF kernels are supported.")
+        elif len(base_kernel.batch_shape) > 0:
+            raise NotImplementedError("Batched kernels are not supported.")
+        super().__init__()
+        self.register_buffer("outputscale", outputscale)
+
+        self.register_buffer("lengthscale", base_kernel.lengthscale.detach().clone())
+        self.register_buffer(
+            "weights",
+            self._get_weights(
+                base_kernel=base_kernel,
+                input_dim=input_dim,
+                num_rff_features=num_rff_features,
+            ),
+        )
+        # initialize uniformly in [0, 2 * pi]
+        self.register_buffer(
+            "bias",
+            2
+            * pi
+            * torch.rand(
+                num_rff_features,
+                dtype=base_kernel.lengthscale.dtype,
+                device=base_kernel.lengthscale.device,
+            ),
+        )
+
+    def _get_weights(
+        self, base_kernel: Kernel, input_dim: int, num_rff_features: int
+    ) -> Tensor:
+        r"""Sample weights for RFF.
+
+        Args:
+            base_kernel: the GP base kernel
+            input_dim: the input dimension to the GP kernel
+            num_rff_features: the number of Fourier features
+
+        Returns:
+            An `input_dim x num_rff_features`-dim tensor of weights.
+        """
+        weights = torch.randn(
+            input_dim,
+            num_rff_features,
+            dtype=base_kernel.lengthscale.dtype,
+            device=base_kernel.lengthscale.device,
+        )
+        if isinstance(base_kernel, MaternKernel):
+            gamma_dist = torch.distributions.Gamma(base_kernel.nu, base_kernel.nu)
+            gamma_samples = gamma_dist.sample(torch.Size([1, num_rff_features])).to(
+                weights
+            )
+            weights = torch.rsqrt(gamma_samples) * weights
+        return weights
+
+    def forward(self, X: Tensor) -> Tensor:
+        r"""Get Fourier basis features for the provided inputs."""
+        X_scaled = torch.div(X, self.lengthscale)
+        outputs = torch.cos(X_scaled @ self.weights + self.bias)
+        return (
+            torch.sqrt(torch.tensor(2.0) * self.outputscale / self.weights.shape[-1])
+            * outputs
+        )
+
+
+def get_deterministic_model(
+    weights: List[Tensor], bases: List[RandomFourierFeatures]
+) -> GenericDeterministicModel:
+    """Get a deterministic model using the provided weights and bases for each output.
+
+    Args:
+        weights: a list of weights with `m` elements
+        bases: a list of RandomFourierFeatures with `m` elements
+
+    Returns:
+        A deterministic model.
+    """
+
+    def evaluate_gp_sample(X):
+        return torch.stack([basis(X) @ w for w, basis in zip(weights, bases)], dim=-1)
+
+    return GenericDeterministicModel(f=evaluate_gp_sample, num_outputs=len(weights))
+
+
+def get_weights_posterior(X: Tensor, y: Tensor, sigma_sq: float) -> MultivariateNormal:
+    r"""Compute the posterior over Bayesian linear regression weights.
+
+    Args:
+        X: a `n x num_rff_features`-dim tensor of inputs
+        y: a `n`-dim tensor of outputs
+        sigma_sq: the noise variance
+
+    Returns:
+        The posterior distribution over the weights.
+    """
+    with torch.no_grad():
+        A = X.T @ X + sigma_sq * torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
+        # the posterior mean is m = A_inv @ X.T @ y
+        # and the posterior covariance is A_inv * sigma_sq
+        # compute A_inv using triangular solves
+        L_A = psd_safe_cholesky(A)
+        # solve L_A @ u = I
+        Iw = torch.eye(L_A.shape[0], dtype=X.dtype, device=X.device)
+        u = torch.triangular_solve(Iw, L_A, upper=False).solution
+        # solve L_A^T @ A_inv = u
+        A_inv = torch.triangular_solve(u, L_A.T).solution
+        m = A_inv @ X.T @ y
+        L = psd_safe_cholesky(A_inv * sigma_sq)
+        return MultivariateNormal(loc=m, scale_tril=L)
+
+
+def get_gp_samples(
+    model: Model, num_outputs: int, n_samples: int, num_rff_features: int = 500
+) -> List[GenericDeterministicModel]:
+    r"""Sample functions from the GP posterior using RFFs.
+
+    Args:
+        model: the model
+        num_outputs: the number of outputs
+        n_samples: the number of sampled functions to draw
+        num_rff_features: the number of random Fourier features
+
+    Returns:
+        A list of sampled functions.
+    """
+    if num_outputs > 1:
+        if not isinstance(model, ModelListGP):
+            models = batched_to_model_list(model).models
+        else:
+            models = model.models
+    else:
+        models = [model]
+
+    if isinstance(models[0], MultiTaskGP):
+        raise NotImplementedError
+
+    weights = []
+    bases = []
+    for m in range(num_outputs):
+        train_X = models[m].train_inputs[0]
+        # get random Fourier features
+        basis = RandomFourierFeatures(
+            kernel=models[m].covar_module,
+            input_dim=train_X.shape[-1],
+            num_rff_features=num_rff_features,
+        )
+        bases.append(basis)
+        phi_X = basis(train_X)
+        # sample weights from the Bayesian linear model posterior
+        mvn = get_weights_posterior(
+            X=phi_X,
+            y=models[m].train_targets,
+            sigma_sq=models[m].likelihood.noise.mean().item(),
+        )
+        weights.append(mvn.sample(torch.Size([n_samples])))
+    # construct a deterministic, multi-output model for each sample
+    models = [
+        get_deterministic_model(
+            weights=[weights[m][i] for m in range(num_outputs)],
+            bases=bases,
+        )
+        for i in range(n_samples)
+    ]
+    return models
diff --git a/test/utils/test_gp_sampling.py b/test/utils/test_gp_sampling.py
index 9acefd0d9b..126a193988 100644
--- a/test/utils/test_gp_sampling.py
+++ b/test/utils/test_gp_sampling.py
@@ -4,11 +4,24 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from itertools import product
+from math import pi
+from unittest import mock
+
 import torch
+from botorch.models.deterministic import DeterministicModel
 from botorch.models.gp_regression import SingleTaskGP
-from botorch.utils.gp_sampling import GPDraw
+from botorch.models.multitask import MultiTaskGP
+from botorch.utils.gp_sampling import (
+    GPDraw,
+    RandomFourierFeatures,
+    get_deterministic_model,
+    get_weights_posterior,
+    get_gp_samples,
+)
 from botorch.utils.testing import BotorchTestCase
+from gpytorch.kernels import RBFKernel, MaternKernel, ScaleKernel, PeriodicKernel
+from torch.distributions import MultivariateNormal
 
 
 def _get_model(device, dtype, multi_output=False):
@@ -93,14 +106,14 @@ def _get_model(device, dtype, multi_output=False):
     model = SingleTaskGP(train_X, train_Y)
     model.load_state_dict(state_dict)
     model.to(device=device, dtype=dtype)
-    return model
+    return model, train_X, train_Y
 
 
 class TestGPDraw(BotorchTestCase):
     def test_gp_draw_single_output(self):
         for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
-            model = _get_model(**tkwargs)
+            model, _, _ = _get_model(**tkwargs)
             mean = model.mean_module.constant.detach().clone()
             gp = GPDraw(model)
             # test initialization
@@ -125,18 +138,21 @@ def test_gp_draw_single_output(self):
                 torch.equal(initial_base_samples, new_base_samples[..., :1, :])
             )
             # evaluate in batch mode (need a new model for this!)
-            gp = GPDraw(_get_model(**tkwargs))
+            model, _, _ = _get_model(**tkwargs)
+            gp = GPDraw(model)
             with torch.no_grad():
                 Y_batch = gp(torch.rand(2, 1, 1, **tkwargs))
             self.assertEqual(Y_batch.shape, torch.Size([2, 1, 1]))
             # test random seed
             test_X = torch.rand(1, 1, **tkwargs)
-            gp_a = GPDraw(model=_get_model(**tkwargs), seed=0)
+            model, _, _ = _get_model(**tkwargs)
+            gp_a = GPDraw(model=model, seed=0)
             self.assertEqual(int(gp_a._seed), 0)
             with torch.no_grad():
                 Ya = gp_a(test_X)
             self.assertEqual(int(gp_a._seed), 1)
-            gp_b = GPDraw(model=_get_model(**tkwargs), seed=0)
+            model, _, _ = _get_model(**tkwargs)
+            gp_b = GPDraw(model=model, seed=0)
             with torch.no_grad():
                 Yb = gp_b(test_X)
             self.assertAlmostEqual(Ya, Yb)
@@ -144,7 +160,7 @@ def test_gp_draw_single_output(self):
     def test_gp_draw_multi_output(self):
        for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
-            model = _get_model(**tkwargs, multi_output=True)
+            model, _, _ = _get_model(**tkwargs, multi_output=True)
             mean = model.mean_module.constant.detach().clone()
             gp = GPDraw(model)
             # test initialization
@@ -168,7 +184,185 @@ def test_gp_draw_multi_output(self):
                 torch.equal(initial_base_samples, new_base_samples[..., :1, :])
             )
             # evaluate in batch mode (need a new model for this!)
-            gp = GPDraw(_get_model(**tkwargs, multi_output=True))
+            model, _, _ = _get_model(**tkwargs, multi_output=True)
+            gp = GPDraw(model)
             with torch.no_grad():
                 Y_batch = gp(torch.rand(2, 1, 1, **tkwargs))
             self.assertEqual(Y_batch.shape, torch.Size([2, 1, 2]))
+
+
+class TestRandomFourierFeatures(BotorchTestCase):
+    def test_random_fourier_features(self):
+        # test kernel that is not Scale, RBF, or Matern
+        with self.assertRaises(NotImplementedError):
+            RandomFourierFeatures(
+                kernel=PeriodicKernel(),
+                input_dim=2,
+                num_rff_features=3,
+            )
+
+        # test batched kernel
+        with self.assertRaises(NotImplementedError):
+            RandomFourierFeatures(
+                kernel=RBFKernel(batch_shape=torch.Size([2])),
+                input_dim=2,
+                num_rff_features=3,
+            )
+        tkwargs = {"device": self.device}
+        for dtype in (torch.float, torch.double):
+            tkwargs["dtype"] = dtype
+            # test init with a ScaleKernel
+            base_kernel = RBFKernel(ard_num_dims=2)
+            kernel = ScaleKernel(base_kernel).to(**tkwargs)
+            rff = RandomFourierFeatures(
+                kernel=kernel,
+                input_dim=2,
+                num_rff_features=3,
+            )
+            self.assertTrue(torch.equal(rff.outputscale, kernel.outputscale))
+            # check that rff makes a copy
+            self.assertFalse(rff.outputscale is kernel.outputscale)
+            self.assertTrue(torch.equal(rff.lengthscale, base_kernel.lengthscale))
+            # check that rff makes a copy
+            self.assertFalse(rff.lengthscale is base_kernel.lengthscale)
+
+            # test init with a kernel that is not a ScaleKernel
+            rff = RandomFourierFeatures(
+                kernel=base_kernel,
+                input_dim=2,
+                num_rff_features=3,
+            )
+            self.assertTrue(torch.equal(rff.outputscale, torch.tensor(1, **tkwargs)))
+            self.assertTrue(torch.equal(rff.lengthscale, base_kernel.lengthscale))
+            # check that rff makes a copy
+            self.assertFalse(rff.lengthscale is base_kernel.lengthscale)
+            self.assertEqual(rff.weights.shape, torch.Size([2, 3]))
+            self.assertEqual(rff.bias.shape, torch.Size([3]))
+            self.assertTrue(((rff.bias <= 2 * pi) & (rff.bias >= 0.0)).all())
+
+            # test forward
+            rff = RandomFourierFeatures(
+                kernel=kernel,
+                input_dim=2,
+                num_rff_features=3,
+            )
+            for batch_shape in (torch.Size([]), torch.Size([3])):
+                X = torch.rand(*batch_shape, 1, 2, **tkwargs)
+                Y = rff(X)
+                self.assertEqual(Y.shape, torch.Size([*batch_shape, 1, 3]))
+                expected_Y = torch.sqrt(
+                    2 * rff.outputscale / rff.weights.shape[-1]
+                ) * torch.cos(X / base_kernel.lengthscale @ rff.weights + rff.bias)
+                self.assertTrue(torch.equal(Y, expected_Y))
+
+            # test get_weights
+            with mock.patch("torch.randn", wraps=torch.randn) as mock_randn:
+                rff._get_weights(
+                    base_kernel=base_kernel, input_dim=2, num_rff_features=3
+                )
+                mock_randn.assert_called_once_with(
+                    2,
+                    3,
+                    dtype=base_kernel.lengthscale.dtype,
+                    device=base_kernel.lengthscale.device,
+                )
+            # test get_weights with Matern kernel
+            with mock.patch("torch.randn", wraps=torch.randn) as mock_randn, mock.patch(
+                "torch.distributions.Gamma", wraps=torch.distributions.Gamma
+            ) as mock_gamma:
+                base_kernel = MaternKernel(ard_num_dims=2).to(**tkwargs)
+                rff._get_weights(
+                    base_kernel=base_kernel, input_dim=2, num_rff_features=3
+                )
+                mock_randn.assert_called_once_with(
+                    2,
+                    3,
+                    dtype=base_kernel.lengthscale.dtype,
+                    device=base_kernel.lengthscale.device,
+                )
+                mock_gamma.assert_called_once_with(
+                    base_kernel.nu,
+                    base_kernel.nu,
+                )
+
+    def test_get_deterministic_model(self):
+        tkwargs = {"device": self.device}
+        for dtype, m in product((torch.float, torch.double), (1, 2)):
+            tkwargs["dtype"] = dtype
+            weights = []
+            bases = []
+            for i in range(m):
+                num_rff = 2 * (i + 2)
+                weights.append(torch.rand(num_rff, **tkwargs))
+                kernel = ScaleKernel(RBFKernel(ard_num_dims=2)).to(**tkwargs)
+                kernel.outputscale = 0.3 + torch.rand(1, **tkwargs).view(
+                    kernel.outputscale.shape
+                )
+                kernel.base_kernel.lengthscale = 0.3 + torch.rand(2, **tkwargs).view(
+                    kernel.base_kernel.lengthscale.shape
+                )
+                bases.append(
+                    RandomFourierFeatures(
+                        kernel=kernel,
+                        input_dim=2,
+                        num_rff_features=num_rff,
+                    )
+                )
+
+            model = get_deterministic_model(weights=weights, bases=bases)
+            self.assertIsInstance(model, DeterministicModel)
+            self.assertEqual(model.num_outputs, m)
+            for batch_shape in (torch.Size([]), torch.Size([3])):
+                X = torch.rand(*batch_shape, 1, 2, **tkwargs)
+                Y = model(X)
+                expected_Y = torch.stack(
+                    [basis(X) @ w for w, basis in zip(weights, bases)], dim=-1
+                )
+                self.assertTrue(torch.equal(Y, expected_Y))
+                self.assertEqual(Y.shape, torch.Size([*batch_shape, 1, m]))
+
+    def test_get_weights_posterior(self):
+        tkwargs = {"device": self.device}
+        sigma = 0.01
+        for dtype in (torch.float, torch.double):
+            tkwargs["dtype"] = dtype
+            X = torch.rand(40, 2, **tkwargs)
+            w = torch.rand(2, **tkwargs)
+            Y_true = X @ w
+            Y = Y_true + sigma * torch.randn_like(Y_true)
+            posterior = get_weights_posterior(X=X, y=Y, sigma_sq=sigma ** 2)
+            self.assertIsInstance(posterior, MultivariateNormal)
+            self.assertTrue(torch.allclose(w, posterior.mean, atol=1e-1))
+            w_samp = posterior.sample()
+            self.assertEqual(w_samp.shape, w.shape)
+
+    def test_get_gp_samples(self):
+        # test multi-task model
+        X = torch.stack([torch.rand(3), torch.tensor([1.0, 0.0, 1.0])], dim=-1)
+        Y = torch.rand(3, 1)
+        with self.assertRaises(NotImplementedError):
+            gp_samples = get_gp_samples(
+                model=MultiTaskGP(X, Y, task_feature=1),
+                num_outputs=1,
+                n_samples=20,
+                num_rff_features=500,
+            )
+        tkwargs = {"device": self.device}
+        for dtype, m in product((torch.float, torch.double), (1, 2)):
+            tkwargs["dtype"] = dtype
+            model, X, Y = _get_model(**tkwargs, multi_output=m == 2)
+            gp_samples = get_gp_samples(
+                model=model,
+                num_outputs=m,
+                n_samples=20,
+                num_rff_features=500,
+            )
+            self.assertEqual(len(gp_samples), 20)
+            self.assertIsInstance(gp_samples[0], DeterministicModel)
+            Y_hat_rff = torch.stack(
+                [gp_sample(X) for gp_sample in gp_samples], dim=0
+            ).mean(dim=0)
+            with torch.no_grad():
+                Y_hat = model.posterior(X).mean
+            self.assertTrue(torch.allclose(Y_hat_rff, Y_hat, atol=2e-1))
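For reference, the feature map implemented in `RandomFourierFeatures.forward` is phi(x) = sqrt(2 * outputscale / D) * cos(W^T (x / lengthscale) + b) with D = num_rff_features, so phi(x)^T phi(x') is a Monte Carlo estimate of the kernel k(x, x'); `get_weights_posterior` then applies the standard Bayesian linear regression posterior N(A_inv @ Phi^T @ y, sigma_sq * A_inv) with A = Phi^T @ Phi + sigma_sq * I. A quick numerical sanity check of the kernel approximation (a sketch, separate from the unit tests above; the seed and feature count are arbitrary assumptions):

    import torch
    from botorch.utils.gp_sampling import RandomFourierFeatures
    from gpytorch.kernels import RBFKernel, ScaleKernel

    torch.manual_seed(0)
    kernel = ScaleKernel(RBFKernel(ard_num_dims=2)).to(torch.double)
    rff = RandomFourierFeatures(kernel=kernel, input_dim=2, num_rff_features=4096)
    X = torch.rand(10, 2, dtype=torch.double)
    phi = rff(X)  # 10 x 4096 feature matrix
    K_rff = phi @ phi.T  # Monte Carlo estimate of the kernel matrix
    K_exact = kernel(X).evaluate()
    # the max entrywise error should shrink as O(1 / sqrt(num_rff_features))
    print((K_rff - K_exact).abs().max())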