diff --git a/botorch/models/higher_order_gp.py b/botorch/models/higher_order_gp.py index e4f8c65077..92cd930d0f 100644 --- a/botorch/models/higher_order_gp.py +++ b/botorch/models/higher_order_gp.py @@ -425,16 +425,7 @@ def posterior( es.enter_context(skip_posterior_variances(True)) mvn = self(X) if observation_noise is not False: - # TODO: implement Kronecker + diagonal solves so that this is possible. - # if torch.is_tensor(observation_noise): - # # TODO: Validate noise shape - # # make observation_noise `batch_shape x q x n` - # obs_noise = observation_noise.transpose(-1, -2) - # mvn = self.likelihood(mvn, X, noise=obs_noise) - # elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood): - # noise = self.likelihood.noise.mean().expand(X.shape[:-1]) - # mvn = self.likelihood(mvn, X, noise=noise) - # else: + # TODO: ensure that this still works for structured noise solves. mvn = self.likelihood(mvn, X) # lazy covariance matrix includes the interpolated version of the full @@ -449,32 +440,43 @@ def posterior( ) else: train_inputs = self.train_inputs[0] - full_covar = self.covar_modules[0](torch.cat((train_inputs, X), dim=-2)) + # we now compute the data covariances for the training data, the testing + # data, the joint covariances, and the test train cross-covariance + train_train_covar = self.prediction_strategy.lik_train_train_covar.detach() + base_train_train_covar = train_train_covar.lazy_tensor + + data_train_covar = base_train_train_covar.lazy_tensors[0] + data_covar = self.covar_modules[0] + data_train_test_covar = data_covar(X, train_inputs) + data_test_test_covar = data_covar(X) + data_joint_covar = data_train_covar.cat_rows( + cross_mat=data_train_test_covar, + new_mat=data_test_test_covar, + ) + + # we detach the latents so that they don't cause gradient errors + # TODO: Can we enable backprop through the latent covariances? + batch_shape = data_train_test_covar.batch_shape + latent_covar_list = [] + for latent_covar in base_train_train_covar.lazy_tensors[1:]: + if latent_covar.batch_shape != batch_shape: + latent_covar = BatchRepeatLazyTensor(latent_covar, batch_shape) + latent_covar_list.append(latent_covar.detach()) + + joint_covar = KroneckerProductLazyTensor( + data_joint_covar, *latent_covar_list + ) + test_train_covar = KroneckerProductLazyTensor( + data_train_test_covar, *latent_covar_list + ) + + # compute the posterior variance if necessary if no_pred_variance: pred_variance = mvn.variance else: - # we detach all of the latent dimension posteriors which precludes - # computing quantities computed on the posterior wrt latents as - # this reduces the memory overhead somewhat - # TODO: add these back in if necessary - joint_covar = self._get_joint_covariance([X]) pred_variance = self.make_posterior_variances(joint_covar) - full_covar = KroneckerProductLazyTensor( - full_covar, *[x.detach() for x in joint_covar.lazy_tensors[1:]] - ) - - joint_covar_list = [self.covar_modules[0](X, train_inputs)] - batch_shape = joint_covar_list[0].batch_shape - for cm, param in zip(self.covar_modules[1:], self.latent_parameters): - covar = cm(param).detach() - if covar.batch_shape != batch_shape: - covar = BatchRepeatLazyTensor(covar, batch_shape) - joint_covar_list.append(covar) - - test_train_covar = KroneckerProductLazyTensor(*joint_covar_list) - # mean and variance get reshaped into the target shape new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape) if not no_pred_variance: @@ -487,8 +489,6 @@ def posterior( mvn = MultivariateNormal(new_mean, new_variance) - train_train_covar = self.prediction_strategy.lik_train_train_covar.detach() - # return a specialized Posterior to allow for sampling # cloning the full covar allows backpropagation through it posterior = HigherOrderGPPosterior( @@ -496,7 +496,7 @@ def posterior( train_targets=self.train_targets.unsqueeze(-1), train_train_covar=train_train_covar, test_train_covar=test_train_covar, - joint_covariance_matrix=full_covar.clone(), + joint_covariance_matrix=joint_covar.clone(), output_shape=X.shape[:-1] + self.target_shape, num_outputs=self._num_outputs, ) @@ -505,39 +505,6 @@ def posterior( return posterior - # TODO: remove when this gets exposed in gpytorch - def _get_joint_covariance(self, inputs): - """ - Internal method to expose the joint test train covariance. - """ - - from gpytorch.models import ExactGP - from gpytorch.utils.broadcasting import _mul_broadcast_shape - - train_inputs = self.train_inputs - # Concatenate the input to the training input - full_inputs = [] - batch_shape = train_inputs[0].shape[:-2] - for train_input, input in zip(train_inputs, inputs): - # Make sure the batch shapes agree for training/test data - # This seems to be deprecated - # if batch_shape != train_input.shape[:-2]: - # batch_shape = _mul_broadcast_shape( - # batch_shape, train_input.shape[:-2] - # ) - # train_input = train_input.expand( - # *batch_shape, *train_input.shape[-2:] - # ) - if batch_shape != input.shape[:-2]: - batch_shape = _mul_broadcast_shape(batch_shape, input.shape[:-2]) - train_input = train_input.expand(*batch_shape, *train_input.shape[-2:]) - input = input.expand(*batch_shape, *input.shape[-2:]) - full_inputs.append(torch.cat([train_input, input], dim=-2)) - - # Get the joint distribution for training/test data - full_output = super(ExactGP, self).__call__(*full_inputs) - return full_output.lazy_covariance_matrix - def make_posterior_variances(self, joint_covariance_matrix: LazyTensor) -> Tensor: r""" Computes the posterior variances given the data points X. As currently diff --git a/botorch/posteriors/higher_order.py b/botorch/posteriors/higher_order.py index 42e45e5dc4..eac57d1c6b 100644 --- a/botorch/posteriors/higher_order.py +++ b/botorch/posteriors/higher_order.py @@ -21,9 +21,12 @@ class HigherOrderGPPosterior(GPyTorchPosterior): HOGP is a tensorized GP model so the posterior covariance grows to be extremely large, but is highly structured, which means that we can exploit Kronecker identities to sample from the posterior using Matheron's rule as described in - [Doucet2010sampl]_. In general, this posterior should ONLY be used for HOGP models + [Doucet2010sampl]_. + + In general, this posterior should ONLY be used for HOGP models that have highly structured covariances. It should also only be used internally when - called from the HigherOrderGP.posterior(...) method. + called from the HigherOrderGP.posterior(...) method. At this time, the posterior + does not support gradients with respect to the training data. """ def __init__( @@ -168,7 +171,8 @@ def rsample( # base samples now have trailing sample dimension covariance_matrix = self.joint_covariance_matrix covar_root = covariance_matrix.root_decomposition().root - samples = covar_root.matmul(base_samples) + + samples = covar_root.matmul(base_samples[..., : covar_root.shape[-1], :]) # now pluck out Y_x and X_x noiseless_train_marginal_samples = samples[ diff --git a/test/models/test_higher_order_gp.py b/test/models/test_higher_order_gp.py index deaee2ed70..805e2ddc7d 100644 --- a/test/models/test_higher_order_gp.py +++ b/test/models/test_higher_order_gp.py @@ -19,7 +19,7 @@ from gpytorch.kernels import RBFKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood -from gpytorch.settings import skip_posterior_variances +from gpytorch.settings import skip_posterior_variances, max_cholesky_size class TestHigherOrderGP(BotorchTestCase): @@ -77,28 +77,29 @@ def test_num_output_dims(self): def test_posterior(self): for dtype in [torch.float, torch.double]: - torch.random.manual_seed(0) - test_x = torch.rand(2, 30, 1).to(device=self.device, dtype=dtype) - - self.model.to(dtype) - if dtype == torch.double: - # need to clear float caches - self.model.train() - self.model.eval() - # test the posterior works - posterior = self.model.posterior(test_x) - self.assertIsInstance(posterior, GPyTorchPosterior) - - # test the posterior works with observation noise - posterior = self.model.posterior(test_x, observation_noise=True) - self.assertIsInstance(posterior, GPyTorchPosterior) - - # test the posterior works with no variances - # some funkiness in MVNs registration so the variance is non-zero. - with skip_posterior_variances(): - posterior = self.model.posterior(test_x) - self.assertIsInstance(posterior, GPyTorchPosterior) - self.assertLessEqual(posterior.variance.max(), 1e-6) + for mcs in [800, 10]: + torch.random.manual_seed(0) + with max_cholesky_size(mcs): + test_x = torch.rand(2, 12, 1).to(device=self.device, dtype=dtype) + + self.model.to(dtype) + # clear caches + self.model.train() + self.model.eval() + # test the posterior works + posterior = self.model.posterior(test_x) + self.assertIsInstance(posterior, GPyTorchPosterior) + + # test the posterior works with observation noise + posterior = self.model.posterior(test_x, observation_noise=True) + self.assertIsInstance(posterior, GPyTorchPosterior) + + # test the posterior works with no variances + # some funkiness in MVNs registration so the variance is non-zero. + with skip_posterior_variances(): + posterior = self.model.posterior(test_x) + self.assertIsInstance(posterior, GPyTorchPosterior) + self.assertLessEqual(posterior.variance.max(), 1e-6) def test_transforms(self): for dtype in [torch.float, torch.double]: