99 changes: 33 additions & 66 deletions botorch/models/higher_order_gp.py
@@ -425,16 +425,7 @@ def posterior(
es.enter_context(skip_posterior_variances(True))
mvn = self(X)
if observation_noise is not False:
# TODO: implement Kronecker + diagonal solves so that this is possible.
# if torch.is_tensor(observation_noise):
# # TODO: Validate noise shape
# # make observation_noise `batch_shape x q x n`
# obs_noise = observation_noise.transpose(-1, -2)
# mvn = self.likelihood(mvn, X, noise=obs_noise)
# elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
# noise = self.likelihood.noise.mean().expand(X.shape[:-1])
# mvn = self.likelihood(mvn, X, noise=noise)
# else:
# TODO: ensure that this still works for structured noise solves.
mvn = self.likelihood(mvn, X)

# lazy covariance matrix includes the interpolated version of the full
@@ -449,32 +440,43 @@
)
else:
train_inputs = self.train_inputs[0]
full_covar = self.covar_modules[0](torch.cat((train_inputs, X), dim=-2))

# we now compute the data covariances for the training data, the testing
# data, the joint covariances, and the test train cross-covariance
train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()
base_train_train_covar = train_train_covar.lazy_tensor

data_train_covar = base_train_train_covar.lazy_tensors[0]
data_covar = self.covar_modules[0]
data_train_test_covar = data_covar(X, train_inputs)
data_test_test_covar = data_covar(X)
data_joint_covar = data_train_covar.cat_rows(
cross_mat=data_train_test_covar,
new_mat=data_test_test_covar,
)

# we detach the latents so that they don't cause gradient errors
# TODO: Can we enable backprop through the latent covariances?
batch_shape = data_train_test_covar.batch_shape
latent_covar_list = []
for latent_covar in base_train_train_covar.lazy_tensors[1:]:
if latent_covar.batch_shape != batch_shape:
latent_covar = BatchRepeatLazyTensor(latent_covar, batch_shape)
latent_covar_list.append(latent_covar.detach())

joint_covar = KroneckerProductLazyTensor(
data_joint_covar, *latent_covar_list
)
test_train_covar = KroneckerProductLazyTensor(
data_train_test_covar, *latent_covar_list
)

# compute the posterior variance if necessary
if no_pred_variance:
pred_variance = mvn.variance
else:
# we detach all of the latent dimension posteriors, which precludes
# computing quantities on the posterior with respect to the latents,
# but somewhat reduces the memory overhead
# TODO: add these back in if necessary
joint_covar = self._get_joint_covariance([X])
pred_variance = self.make_posterior_variances(joint_covar)

full_covar = KroneckerProductLazyTensor(
full_covar, *[x.detach() for x in joint_covar.lazy_tensors[1:]]
)

joint_covar_list = [self.covar_modules[0](X, train_inputs)]
batch_shape = joint_covar_list[0].batch_shape
for cm, param in zip(self.covar_modules[1:], self.latent_parameters):
covar = cm(param).detach()
if covar.batch_shape != batch_shape:
covar = BatchRepeatLazyTensor(covar, batch_shape)
joint_covar_list.append(covar)

test_train_covar = KroneckerProductLazyTensor(*joint_covar_list)

# mean and variance get reshaped into the target shape
new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
if not no_pred_variance:
@@ -487,16 +489,14 @@ def posterior(

mvn = MultivariateNormal(new_mean, new_variance)

train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()

# return a specialized Posterior to allow for sampling
# cloning the full covar allows backpropagation through it
posterior = HigherOrderGPPosterior(
mvn=mvn,
train_targets=self.train_targets.unsqueeze(-1),
train_train_covar=train_train_covar,
test_train_covar=test_train_covar,
joint_covariance_matrix=full_covar.clone(),
joint_covariance_matrix=joint_covar.clone(),
output_shape=X.shape[:-1] + self.target_shape,
num_outputs=self._num_outputs,
)
@@ -505,39 +505,6 @@

return posterior

# TODO: remove when this gets exposed in gpytorch
def _get_joint_covariance(self, inputs):
"""
Internal method to expose the joint test train covariance.
"""

from gpytorch.models import ExactGP
from gpytorch.utils.broadcasting import _mul_broadcast_shape

train_inputs = self.train_inputs
# Concatenate the input to the training input
full_inputs = []
batch_shape = train_inputs[0].shape[:-2]
for train_input, input in zip(train_inputs, inputs):
# Make sure the batch shapes agree for training/test data
# This seems to be deprecated
# if batch_shape != train_input.shape[:-2]:
# batch_shape = _mul_broadcast_shape(
# batch_shape, train_input.shape[:-2]
# )
# train_input = train_input.expand(
# *batch_shape, *train_input.shape[-2:]
# )
if batch_shape != input.shape[:-2]:
batch_shape = _mul_broadcast_shape(batch_shape, input.shape[:-2])
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
input = input.expand(*batch_shape, *input.shape[-2:])
full_inputs.append(torch.cat([train_input, input], dim=-2))

# Get the joint distribution for training/test data
full_output = super(ExactGP, self).__call__(*full_inputs)
return full_output.lazy_covariance_matrix

def make_posterior_variances(self, joint_covariance_matrix: LazyTensor) -> Tensor:
r"""
Computes the posterior variances given the data points X. As currently
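
To make the new covariance assembly above easier to follow, here is a minimal standalone sketch of the same pattern. It is illustrative only: it assumes gpytorch's lazy-tensor API (cat_rows, KroneckerProductLazyTensor) as used in the diff, and the kernels, shapes, and single latent dimension are placeholders rather than the actual HOGP configuration.

# Minimal sketch of the Kronecker-structured covariance assembly used above.
# Assumptions: gpytorch lazy-tensor API; placeholder kernels, shapes, and a
# single latent dimension -- not the actual HigherOrderGP configuration.
import torch
from gpytorch.kernels import RBFKernel
from gpytorch.lazy import KroneckerProductLazyTensor

train_x, test_x = torch.rand(10, 1), torch.rand(4, 1)
data_kernel, latent_kernel = RBFKernel(), RBFKernel()
latents = torch.rand(5, 1)  # hypothetical latent parameters for one output dimension

# data covariances: train/train, test/train cross-covariance, and test/test
data_train_covar = data_kernel(train_x)
data_train_test_covar = data_kernel(test_x, train_x)
data_test_test_covar = data_kernel(test_x)

# stitch the blocks into the joint (train + test) data covariance
data_joint_covar = data_train_covar.cat_rows(
    cross_mat=data_train_test_covar,
    new_mat=data_test_test_covar,
)

# the covariance over the tensorized output is the Kronecker product of the
# data covariance with the (detached) latent covariance(s)
latent_covar = latent_kernel(latents).detach()
joint_covar = KroneckerProductLazyTensor(data_joint_covar, latent_covar)
test_train_covar = KroneckerProductLazyTensor(data_train_test_covar, latent_covar)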
10 changes: 7 additions & 3 deletions botorch/posteriors/higher_order.py
@@ -21,9 +21,12 @@ class HigherOrderGPPosterior(GPyTorchPosterior):
HOGP is a tensorized GP model so the posterior covariance grows to be extremely
large, but is highly structured, which means that we can exploit Kronecker
identities to sample from the posterior using Matheron's rule as described in
[Doucet2010sampl]_. In general, this posterior should ONLY be used for HOGP models
[Doucet2010sampl]_.

In general, this posterior should ONLY be used for HOGP models
that have highly structured covariances. It should also only be used internally when
called from the HigherOrderGP.posterior(...) method.
called from the HigherOrderGP.posterior(...) method. At this time, the posterior
does not support gradients with respect to the training data.
"""

def __init__(
@@ -168,7 +171,8 @@ def rsample(
# base samples now have trailing sample dimension
covariance_matrix = self.joint_covariance_matrix
covar_root = covariance_matrix.root_decomposition().root
samples = covar_root.matmul(base_samples)

samples = covar_root.matmul(base_samples[..., : covar_root.shape[-1], :])

# now pluck out Y_x and X_x
noiseless_train_marginal_samples = samples[
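
The docstring change above notes that sampling goes through Matheron's rule. For readers unfamiliar with it, a dense-tensor sketch of the rule follows; it is illustrative only (zero prior mean, scalar i.i.d. noise, plain Cholesky root), whereas the HOGP posterior applies the same update with Kronecker-structured lazy tensors and the truncated base samples shown in rsample above.

# Illustrative dense-tensor sketch of Matheron's rule for posterior sampling.
# Assumptions (not BoTorch's implementation): zero prior mean, scalar i.i.d.
# observation noise, and a plain Cholesky root rather than the Kronecker-
# structured root decomposition used by HigherOrderGPPosterior.rsample.
import torch

def matheron_posterior_sample(joint_covar, y_train, noise_var, n_train):
    n_total = joint_covar.shape[-1]
    # draw one joint prior sample over the train and test points
    root = torch.linalg.cholesky(joint_covar + 1e-6 * torch.eye(n_total))
    prior_sample = root @ torch.randn(n_total)
    prior_train, prior_test = prior_sample[:n_train], prior_sample[n_train:]

    # residual between the (noisy) observations and the prior draw at the train points
    eps = (noise_var ** 0.5) * torch.randn(n_train)
    residual = y_train - (prior_train + eps)

    # condition the prior draw: f* + K_*n (K_nn + s^2 I)^{-1} residual
    train_train = joint_covar[:n_train, :n_train] + noise_var * torch.eye(n_train)
    test_train = joint_covar[n_train:, :n_train]
    return prior_test + test_train @ torch.linalg.solve(train_train, residual)

With a Kronecker-structured joint covariance, the root decomposition and the train/train solve factor across the Kronecker factors, which is what keeps this update tractable for HOGP's tensorized outputs.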
47 changes: 24 additions & 23 deletions test/models/test_higher_order_gp.py
@@ -19,7 +19,7 @@
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.settings import skip_posterior_variances
from gpytorch.settings import skip_posterior_variances, max_cholesky_size


class TestHigherOrderGP(BotorchTestCase):
@@ -77,28 +77,29 @@ def test_num_output_dims(self):

def test_posterior(self):
for dtype in [torch.float, torch.double]:
torch.random.manual_seed(0)
test_x = torch.rand(2, 30, 1).to(device=self.device, dtype=dtype)

self.model.to(dtype)
if dtype == torch.double:
# need to clear float caches
self.model.train()
self.model.eval()
# test the posterior works
posterior = self.model.posterior(test_x)
self.assertIsInstance(posterior, GPyTorchPosterior)

# test the posterior works with observation noise
posterior = self.model.posterior(test_x, observation_noise=True)
self.assertIsInstance(posterior, GPyTorchPosterior)

# test the posterior works with no variances
# some funkiness in MVN registration means the variance is non-zero.
with skip_posterior_variances():
posterior = self.model.posterior(test_x)
self.assertIsInstance(posterior, GPyTorchPosterior)
self.assertLessEqual(posterior.variance.max(), 1e-6)
for mcs in [800, 10]:
torch.random.manual_seed(0)
with max_cholesky_size(mcs):
test_x = torch.rand(2, 12, 1).to(device=self.device, dtype=dtype)

self.model.to(dtype)
# clear caches
self.model.train()
self.model.eval()
# test the posterior works
posterior = self.model.posterior(test_x)
self.assertIsInstance(posterior, GPyTorchPosterior)

# test the posterior works with observation noise
posterior = self.model.posterior(test_x, observation_noise=True)
self.assertIsInstance(posterior, GPyTorchPosterior)

# test the posterior works with no variances
# some funkiness in MVN registration means the variance is non-zero.
with skip_posterior_variances():
posterior = self.model.posterior(test_x)
self.assertIsInstance(posterior, GPyTorchPosterior)
self.assertLessEqual(posterior.variance.max(), 1e-6)

def test_transforms(self):
for dtype in [torch.float, torch.double]:
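
The rewritten test loops over max_cholesky_size thresholds so that both the dense-Cholesky and the iterative (Lanczos/CG-based) code paths in gpytorch are exercised. A small standalone sketch of that usage is below; the model construction and data shapes are assumptions modeled on the test fixtures, not the fixtures themselves.

# Sketch of the max_cholesky_size toggle the new test relies on. Assumes
# HigherOrderGP accepts tensor-valued targets of shape batch x n x d1 x d2,
# as in the test setup above; shapes here are placeholders.
import torch
from botorch.models.higher_order_gp import HigherOrderGP
from gpytorch.settings import max_cholesky_size

train_x = torch.rand(2, 10, 1)
train_y = torch.randn(2, 10, 3, 5)  # tensor-valued targets
model = HigherOrderGP(train_x, train_y)
model.eval()

test_x = torch.rand(2, 12, 1)
for mcs in [800, 10]:  # above/below the kernel size -> Cholesky vs. iterative solves
    with max_cholesky_size(mcs):
        posterior = model.posterior(test_x)
        print(mcs, posterior.mean.shape)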