Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion botorch/utils/gp_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
from __future__ import annotations

from copy import deepcopy
from typing import Optional
from math import pi
from typing import List, Optional

import torch
from botorch.models.converter import batched_to_model_list
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.utils.sampling import manual_seed
from gpytorch.kernels import Kernel, RBFKernel, MaternKernel, ScaleKernel
from gpytorch.utils.cholesky import psd_safe_cholesky
from torch import Tensor
from torch.distributions import MultivariateNormal
from torch.nn import Module


Expand Down Expand Up @@ -93,3 +101,187 @@ def forward(self, X: Tensor) -> Tensor:
self.register_buffer("_seed", seed)
self.register_buffer("_base_samples", base_samples)
return self.Ys[..., -(X.size(-2)) :, :]


class RandomFourierFeatures(Module):
    """A class that represents Random Fourier Features."""

    def __init__(self, kernel: Kernel, input_dim: int, num_rff_features: int) -> None:
        r"""Initialize RandomFourierFeatures.

        Args:
            kernel: the GP kernel; either a Matern or RBF kernel, or a
                ScaleKernel wrapping one of those
            input_dim: the input dimension to the GP kernel
            num_rff_features: the number of fourier features

        Raises:
            NotImplementedError: If the base kernel is neither a Matern nor an
                RBF kernel, or if it has a non-empty batch shape.
        """
        is_scaled = isinstance(kernel, ScaleKernel)
        base_kernel = kernel.base_kernel if is_scaled else kernel
        # Validate the base kernel before reading its lengthscale, so that an
        # unsupported kernel always surfaces as the NotImplementedError below
        # instead of failing earlier on attribute access.
        if not isinstance(base_kernel, (MaternKernel, RBFKernel)):
            raise NotImplementedError("Only Matern and RBF kernels are supported.")
        elif len(base_kernel.batch_shape) > 0:
            raise NotImplementedError("Batched kernels are not supported.")
        lengthscale = base_kernel.lengthscale.detach().clone()
        if is_scaled:
            outputscale = kernel.outputscale.detach().clone()
        else:
            # an unscaled kernel is treated as having an output scale of one
            outputscale = torch.tensor(
                1.0, dtype=lengthscale.dtype, device=lengthscale.device
            )
        super().__init__()
        self.register_buffer("outputscale", outputscale)
        self.register_buffer("lengthscale", lengthscale)
        self.register_buffer(
            "weights",
            self._get_weights(
                base_kernel=base_kernel,
                input_dim=input_dim,
                num_rff_features=num_rff_features,
            ),
        )
        # initialize uniformly in [0, 2 * pi]
        self.register_buffer(
            "bias",
            2
            * pi
            * torch.rand(
                num_rff_features,
                dtype=lengthscale.dtype,
                device=lengthscale.device,
            ),
        )

    def _get_weights(
        self, base_kernel: Kernel, input_dim: int, num_rff_features: int
    ) -> Tensor:
        r"""Sample weights for RFF.

        Args:
            base_kernel: the GP base kernel
            input_dim: the input dimension to the GP kernel
            num_rff_features: the number of fourier features

        Returns:
            A `input_dim x num_rff_features`-dim tensor of weights
        """
        # NOTE(review): a reviewer asked whether `draw_sobol_normal_samples`
        # would be beneficial here instead of i.i.d. normal draws; this might
        # be worth looking into.
        weights = torch.randn(
            input_dim,
            num_rff_features,
            dtype=base_kernel.lengthscale.dtype,
            device=base_kernel.lengthscale.device,
        )
        if isinstance(base_kernel, MaternKernel):
            # For Matern kernels, each feature column of the normal draws is
            # rescaled by the inverse square root of a Gamma(nu, nu) sample.
            gamma_dist = torch.distributions.Gamma(base_kernel.nu, base_kernel.nu)
            gamma_samples = gamma_dist.sample(torch.Size([1, num_rff_features])).to(
                weights
            )
            weights = torch.rsqrt(gamma_samples) * weights
        return weights

    def forward(self, X: Tensor) -> Tensor:
        r"""Get fourier basis features for the provided inputs.

        Args:
            X: a tensor of inputs whose last dimension is `input_dim`

        Returns:
            A tensor of features whose last dimension is `num_rff_features`.
        """
        X_scaled = torch.div(X, self.lengthscale)
        outputs = torch.cos(X_scaled @ self.weights + self.bias)
        # A plain Python scalar keeps the scaling factor on the dtype and
        # device of the registered buffers.
        return torch.sqrt(2.0 * self.outputscale / self.weights.shape[-1]) * outputs


def get_deterministic_model(
    weights: List[Tensor], bases: List[RandomFourierFeatures]
) -> GenericDeterministicModel:
    """Build a deterministic multi-output model from per-output weights and bases.

    Output `j` of the returned model is obtained by evaluating `bases[j]` on
    the input and multiplying the resulting features by `weights[j]`.

    Args:
        weights: a list of weights with `m` elements
        bases: a list of RandomFourierFeatures with `m` elements.

    Returns:
        A deterministic model.
    """

    def sample_path(X):
        # evaluate every output separately, then stack along a new last dim
        per_output = []
        for coeffs, features in zip(weights, bases):
            per_output.append(features(X) @ coeffs)
        return torch.stack(per_output, dim=-1)

    n_out = len(weights)
    return GenericDeterministicModel(f=sample_path, num_outputs=n_out)


def get_weights_posterior(X: Tensor, y: Tensor, sigma_sq: float) -> MultivariateNormal:
    r"""Compute the posterior over bayesian linear regression weights.

    With `A = X^T X + sigma_sq * I`, the returned distribution has mean
    `A^{-1} X^T y` and covariance `sigma_sq * A^{-1}`.

    Args:
        X: a `n x num_rff_features`-dim tensor of inputs
        y: a `n`-dim tensor of outputs
        sigma_sq: the noise variance

    Returns:
        The posterior distribution over the weights.
    """
    with torch.no_grad():
        A = X.T @ X + sigma_sq * torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
        L_A = psd_safe_cholesky(A)
        # Invert A directly from its lower Cholesky factor; this replaces the
        # two explicit triangular solves against an identity matrix.
        A_inv = torch.cholesky_inverse(L_A, upper=False)
        # mean is given by: m = S @ x.T @ y, where S = A_inv
        m = A_inv @ X.T @ y
        # covariance is A_inv * sigma_sq
        L = psd_safe_cholesky(A_inv * sigma_sq)
        return MultivariateNormal(loc=m, scale_tril=L)


def get_gp_samples(
    model: Model, num_outputs: int, n_samples: int, num_rff_features: int = 500
) -> List[GenericDeterministicModel]:
    r"""Sample functions from GP posterior using RFF.

    For each output, random fourier features are built from that output's
    kernel, a bayesian linear model over those features is fit to the training
    data, and `n_samples` weight vectors are drawn from its posterior.

    Args:
        model: the model
        num_outputs: the number of outputs
        n_samples: the number of sampled functions to draw
        num_rff_features: the number of random fourier features

    Returns:
        A list of `n_samples` sampled functions, each a deterministic model
        with `num_outputs` outputs.

    Raises:
        NotImplementedError: If the (first) single-output model is a
            MultiTaskGP.
    """
    if num_outputs > 1:
        if not isinstance(model, ModelListGP):
            models = batched_to_model_list(model).models
        else:
            # A ModelListGP already holds one model per output; previously
            # `models` was left unassigned on this path.
            models = model.models
    else:
        models = [model]
    if isinstance(models[0], MultiTaskGP):
        raise NotImplementedError("MultiTaskGP models are not supported.")

    weights = []
    bases = []
    for m in range(num_outputs):
        model_m = models[m]
        train_X = model_m.train_inputs[0]
        # get random fourier features
        basis = RandomFourierFeatures(
            kernel=model_m.covar_module,
            input_dim=train_X.shape[-1],
            num_rff_features=num_rff_features,
        )
        bases.append(basis)
        phi_X = basis(train_X)
        # sample weights from bayesian linear model
        mvn = get_weights_posterior(
            X=phi_X,
            y=model_m.train_targets,
            # the noise is collapsed to a single scalar variance
            sigma_sq=model_m.likelihood.noise.mean().item(),
        )
        weights.append(mvn.sample(torch.Size([n_samples])))
    # construct a deterministic, multi-output model for each sample
    return [
        get_deterministic_model(
            weights=[weights[m][i] for m in range(num_outputs)],
            bases=bases,
        )
        for i in range(n_samples)
    ]
Loading