diff --git a/botorch/cross_validation.py b/botorch/cross_validation.py
index 9a0bfd29ca..6bf8631961 100644
--- a/botorch/cross_validation.py
+++ b/botorch/cross_validation.py
@@ -18,7 +18,11 @@
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.multitask import MultiTaskGP
 from botorch.posteriors.gpytorch import GPyTorchPosterior
+from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
+from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
+from linear_operator.operators import DiagLinearOperator
+from linear_operator.utils.cholesky import psd_safe_cholesky
 from torch import Tensor
@@ -32,6 +36,22 @@ class CVFolds(NamedTuple):
 
 
 class CVResults(NamedTuple):
+    """Results from cross-validation.
+
+    This named tuple contains the cross-validation predictions and observed values.
+    For both ``batch_cross_validation`` and ``efficient_loo_cv``, the ``posterior``
+    field contains the predictive distribution, with mean and variance accessible
+    via ``posterior.mean`` and ``posterior.variance``.
+
+    For ``batch_cross_validation``, the posterior has shape ``n x 1 x m``, where n
+    is the number of folds, 1 is the single held-out point per fold, and m is the
+    number of outputs.
+
+    For ``efficient_loo_cv``, the posterior has the same shape structure to maintain
+    consistency, though the underlying distribution is constructed from the
+    efficient LOO formulas rather than from separate model fits.
+    """
+
     model: GPyTorchModel
     posterior: GPyTorchPosterior
     observed_Y: Tensor
@@ -125,7 +145,11 @@ def batch_cross_validation(
         model_cls: A GPyTorchModel class. This class must initialize the likelihood
             internally. Note: Multi-task GPs are not currently supported.
         mll_cls: A MarginalLogLikelihood class.
-        cv_folds: A CVFolds tuple.
+        cv_folds: A CVFolds tuple. For LOO-CV with n training points, the leading
+            dimension of size n represents the n folds (batch dimension), e.g.,
+            ``cv_folds.train_X`` has shape ``n x (n-1) x d`` and ``cv_folds.test_X``
+            has shape ``n x 1 x d``. This batch structure enables fitting n
+            independent GPs simultaneously.
         fit_args: Arguments passed along to fit_gpytorch_mll.
         model_init_kwargs: Keyword arguments passed to the model constructor.
@@ -194,3 +218,347 @@ def batch_cross_validation(
         observed_Y=cv_folds.test_Y,
         observed_Yvar=cv_folds.test_Yvar,
     )
+
+
+def efficient_loo_cv(
+    model: GPyTorchModel,
+    observation_noise: bool = True,
+) -> CVResults:
+    r"""Compute efficient Leave-One-Out cross-validation for a GP model.
+
+    NOTE: This function does not refit the model to each LOO fold, in contrast to
+    batch_cross_validation. This is a memory- and compute-efficient way to compute
+    LOO, but it does not account for potential changes in the model parameters due
+    to the removal of a single observation. This is typically fine in cases with a
+    lot of data, but can result in substantial differences (typically
+    over-estimating performance) in the low-data regime.
+
+    This function leverages a well-known linear algebraic identity to compute
+    all LOO predictive distributions in O(n^3) time, compared to the naive
+    approach, which requires O(n^4) time (O(n^3) per fold for n folds).
+
+    The efficient LOO formulas for GPs are:
+
+    .. math::
+
+        \mu_{LOO,i} = y_i - \frac{[K^{-1}(y - \mu)]_i}{[K^{-1}]_{ii}}
+
+        \sigma^2_{LOO,i} = \frac{1}{[K^{-1}]_{ii}}
+
+    where K is the covariance matrix including observation noise. This gives
+    the posterior predictive variance (including noise). To get the posterior
+    variance (excluding noise), we subtract the observation noise:
+
+    .. math::
+
+        \sigma^2_{posterior,i} = \sigma^2_{LOO,i} - \sigma^2_{noise}
+
+    NOTE: This function assumes the model has already been fitted and that the
+    model's `forward` method returns a `MultivariateNormal` distribution.
+
+    Args:
+        model: A fitted GPyTorchModel whose `forward` method returns a
+            `MultivariateNormal` distribution.
+        observation_noise: If True (default), return the posterior
+            predictive variance (including observation noise). If False,
+            return the posterior variance of the latent function (excluding
+            observation noise).
+
+    Returns:
+        CVResults: A named tuple containing:
+        - model: The fitted GP model.
+        - posterior: A GPyTorchPosterior with the LOO predictive distributions.
+          The posterior mean and variance have shape ``n x 1 x m`` or
+          ``batch_shape x n x 1 x m``, matching the structure of
+          ``batch_cross_validation`` (n folds, 1 held-out point per fold,
+          m outputs). The underlying distribution has diagonal covariance
+          since LOO predictions at different held-out points are computed
+          independently.
+        - observed_Y: The observed Y values with shape ``n x 1 x m`` or
+          ``batch_shape x n x 1 x m``.
+        - observed_Yvar: The observed noise variances (if provided) with shape
+          ``n x 1 x m`` or ``batch_shape x n x 1 x m``.
+
+    Example:
+        >>> import torch
+        >>> from botorch.cross_validation import efficient_loo_cv
+        >>> from botorch.models import SingleTaskGP
+        >>> from botorch.fit import fit_gpytorch_mll
+        >>> from gpytorch.mlls import ExactMarginalLogLikelihood
+        >>>
+        >>> train_X = torch.rand(20, 2, dtype=torch.float64)
+        >>> train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
+        >>> fit_gpytorch_mll(mll)
+        >>> loo_results = efficient_loo_cv(model)
+        >>> loo_results.posterior.mean.shape
+        torch.Size([20, 1, 1])
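+        >>> # LOO predictions of the latent function, i.e. with observation
+        >>> # noise excluded from the predictive variance:
+        >>> loo_noiseless = efficient_loo_cv(model, observation_noise=False)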
+    """
+    # Compute raw LOO predictions
+    loo_mean, loo_variance, train_Y = _compute_loo_predictions(
+        model, observation_noise=observation_noise
+    )
+
+    # Get the number of outputs
+    num_outputs = model.num_outputs
+
+    # Build the posterior from raw LOO predictions
+    posterior = _build_loo_posterior(
+        loo_mean=loo_mean, loo_variance=loo_variance, num_outputs=num_outputs
+    )
+
+    # Reshape observed data to LOO CV output format: n x 1 x m
+    observed_Y = _reshape_to_loo_cv_format(train_Y, num_outputs)
+
+    # Get observed Yvar if available (for fixed noise models)
+    observed_Yvar = None
+    if isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
+        observed_Yvar = _reshape_to_loo_cv_format(model.likelihood.noise, num_outputs)
+
+    return CVResults(
+        model=model,
+        posterior=posterior,
+        observed_Y=observed_Y,
+        observed_Yvar=observed_Yvar,
+    )
+
+
+def _subtract_observation_noise(model: GPyTorchModel, loo_variance: Tensor) -> Tensor:
+    r"""Subtract observation noise from LOO variance to get posterior variance.
+
+    The efficient LOO formula computes the posterior predictive variance, which
+    includes observation noise. To get the posterior variance of the latent
+    function (without noise), we subtract the observation noise variance.
+
+    This implementation uses the likelihood's ``forward`` method to extract noise
+    variances in a general way. The ``forward`` method takes function samples and
+    returns a distribution whose variance represents the observation noise.
+
+    .. math::
+
+        \sigma^2_{posterior,i} = \sigma^2_{LOO,i} - \sigma^2_{noise}
+
+    Args:
+        model: The GP model with a likelihood containing the noise variance.
+        loo_variance: The LOO posterior predictive variance with shape
+            ``... x n x 1``.
+
+    Returns:
+        The posterior variance (without noise) with the same shape.
+    """
+    likelihood = model.likelihood
+
+    # Use the likelihood's forward method to extract noise variances.
+    # By passing zeros as function samples, the returned distribution's
+    # variance gives us the observation noise at each point.
+    noise_shape = loo_variance.shape[:-1]  # ... x n
+    zeros = torch.zeros(
+        noise_shape, dtype=loo_variance.dtype, device=loo_variance.device
+    )
+
+    # Some likelihoods (e.g., SparseOutlierGaussianLikelihood) require training
+    # inputs to be passed to correctly compute the noise. We pass the model's
+    # train_inputs if available.
+    train_inputs = getattr(model, "train_inputs", None)
+
+    # Call forward to get the observation noise distribution.
+    # We pass train_inputs as a positional argument so it flows through *params
+    # to the noise model, which is compatible with both standard Noise classes
+    # (that use *params) and SparseOutlierNoise (that uses X as the first arg).
+    noise_dist = likelihood.forward(zeros, train_inputs)
+
+    # Extract noise variance and reshape to match loo_variance
+    noise = noise_dist.variance.unsqueeze(-1)  # ... x n x 1
+
+    loo_variance = loo_variance - noise
+
+    # Clamp to ensure non-negative variance
+    return loo_variance.clamp(min=0.0)
+
+
+def _compute_loo_predictions(
+    model: GPyTorchModel,
+    observation_noise: bool = True,
+) -> tuple[Tensor, Tensor, Tensor]:
+    r"""Compute raw LOO predictions (means and variances) for a GP model.
+
+    This is an internal helper that computes the leave-one-out predictive means
+    and variances using efficient matrix algebra. The formulas are:
+
+    .. math::
+
+        \mu_{LOO,i} = y_i - \frac{[K^{-1}(y - \mu)]_i}{[K^{-1}]_{ii}}
+
+        \sigma^2_{LOO,i} = \frac{1}{[K^{-1}]_{ii}}
+
+    where K is the covariance matrix including observation noise and μ is the
+    prior mean. This gives the posterior predictive variance (including noise).
+    To get the posterior variance (excluding noise), we subtract the observation
+    noise variance.
+
+    Args:
+        model: A fitted GPyTorchModel in eval mode whose `forward` method returns
+            a `MultivariateNormal` distribution.
+        observation_noise: If True (default), return the posterior
+            predictive variance (including observation noise). If False,
+            return the posterior variance of the latent function (excluding
+            observation noise).
+
+    Returns:
+        A tuple of (loo_mean, loo_variance, train_Y) where:
+        - loo_mean: LOO predictive means with shape `... x n x 1`
+        - loo_variance: LOO predictive variances with shape `... x n x 1`
+        - train_Y: The training targets from the model
+
+    Raises:
+        UnsupportedError: If the model doesn't have the required attributes or
+            the forward method doesn't return a MultivariateNormal.
+    """
+    # Get training data - model should have train_inputs attribute
+    if not hasattr(model, "train_inputs") or model.train_inputs is None:
+        raise UnsupportedError(
+            "Model must have train_inputs attribute for efficient LOO CV."
+        )
+    if not hasattr(model, "train_targets") or model.train_targets is None:
+        raise UnsupportedError(
+            "Model must have train_targets attribute for efficient LOO CV."
+        )
+
+    train_X = model.train_inputs[0]  # Shape: n x d or batch_shape x n x d
+
+    # Check for models with auxiliary inputs (e.g., auxiliary experiment data).
+    # In such models, train_inputs[0] is a tuple of tensors rather than a single
+    # tensor.
+    if isinstance(train_X, tuple):
+        raise UnsupportedError(
+            "Efficient LOO CV is not supported for models with auxiliary inputs. "
+            "train_inputs[0] is a tuple of tensors, indicating auxiliary data."
+        )
+
+    train_Y = model.train_targets  # Shape: n or batch_shape x n (batched outputs)
+
+    n = train_X.shape[-2]
+    prior_dist = model.forward(train_X)
+
+    # Check that we got a MultivariateNormal
+    if not isinstance(prior_dist, MultivariateNormal):
+        raise UnsupportedError(
+            f"Model's forward method must return a MultivariateNormal, "
+            f"got {type(prior_dist).__name__}."
+        )
+
+    # Extract mean from the prior
+    # Shape: n for single-output, or m x n for batched multi-output
+    mean = prior_dist.mean
+
+    # Add observation noise to the diagonal via the likelihood.
+    # The likelihood adds noise: K_noisy = K + sigma^2 * I
+    # Some likelihoods (e.g., SparseOutlierGaussianLikelihood) require training
+    # inputs to be passed to correctly apply the noise model. We pass them as
+    # a positional argument for compatibility with both standard likelihoods
+    # and SparseOutlierGaussianLikelihood.
+    train_inputs = model.train_inputs
+    noisy_mvn = model.likelihood(prior_dist, train_inputs)
+
+    # Get the covariance matrix - use lazy representation for potential caching
+    K_noisy = noisy_mvn.lazy_covariance_matrix.to_dense()
+
+    # Compute Cholesky decomposition (adds jitter if needed)
+    L = psd_safe_cholesky(K_noisy)
+
+    # Compute K^{-1}(y - mean) via Cholesky solve
+    # Shape: ... x n x 1 where ... is batch_shape (includes m for multi-output)
+    residuals = (train_Y - mean).unsqueeze(-1)
+    K_inv_residuals = torch.cholesky_solve(residuals, L)
+
+    # Compute diagonal of K^{-1}
+    # K_inv = L^{-T} @ L^{-1}, so K_inv_diag[i] = sum_j (L^{-1}[j,i])^2
+    identity = torch.eye(n, dtype=L.dtype, device=L.device)
+    if L.dim() > 2:
+        identity = identity.expand(*L.shape[:-2], n, n)
+    L_inv = torch.linalg.solve_triangular(L, identity, upper=False)
+    K_inv_diag = (L_inv**2).sum(dim=-2)  # ... x n
+
+    # Compute LOO predictions using the efficient formulas:
+    #   sigma2_loo_i = 1 / [K^{-1}]_{ii}
+    #   mu_loo_i = y_i - [K^{-1}(y - mean)]_i / [K^{-1}]_{ii}
+    # K_inv_diag has shape ... x n, so after unsqueeze(-1) we get ... x n x 1
+    # (the last dim is 1 because each GP is single-output).
+    loo_variance = (1.0 / K_inv_diag).unsqueeze(-1)  # ... x n x 1
+    loo_mean = train_Y.unsqueeze(-1) - K_inv_residuals * loo_variance  # ... x n x 1
+
+    # If we want the posterior (noiseless) variance, subtract the noise
+    if not observation_noise:
+        loo_variance = _subtract_observation_noise(model, loo_variance)
+
+    return loo_mean, loo_variance, train_Y
+
+
+def _build_loo_posterior(
+    loo_mean: Tensor,
+    loo_variance: Tensor,
+    num_outputs: int,
+) -> GPyTorchPosterior:
+    r"""Build a GPyTorchPosterior from raw LOO predictions.
+
+    Args:
+        loo_mean: LOO means with shape `... x m x n x 1` (multi-output) or
+            `... x n x 1` (single-output), where `...` is an optional batch_shape.
+        loo_variance: LOO variances with the same shape as loo_mean.
+        num_outputs: Number of outputs (m). 1 for single-output models.
+
+    Returns:
+        A GPyTorchPosterior with shape `... x n x 1 x m`.
+    """
+    # Reshape tensors to final shape: ... x n x 1 x m
+    if num_outputs > 1:
+        # Multi-output: ... x m x n x 1 -> ... x n x 1 x m
+        # The m dimension is at position -3, move it to position -1
+        loo_mean = loo_mean.movedim(-3, -1)
+        loo_variance = loo_variance.movedim(-3, -1)
+    else:
+        # Single-output: ... x n x 1 -> ... x n x 1 x 1
+        loo_mean = loo_mean.unsqueeze(-1)
+        loo_variance = loo_variance.unsqueeze(-1)
+
+    # Create distribution: for multi-output use MTMVN, for single-output use MVN.
+    # Both require mean shape ... x n x q (where q=1) and diagonal covariance.
+    # We squeeze the m dimension to get ... x n x 1 for the MVN mean, then
+    # iterate over outputs to create independent MVNs.
+    mvns = [
+        MultivariateNormal(
+            mean=loo_mean[..., t],
+            covariance_matrix=DiagLinearOperator(loo_variance[..., t]),
+        )
+        for t in range(num_outputs)
+    ]
+
+    if num_outputs > 1:
+        mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
+    else:
+        mvn = mvns[0]
+
+    return GPyTorchPosterior(distribution=mvn)
+
+
+def _reshape_to_loo_cv_format(tensor: Tensor, num_outputs: int) -> Tensor:
+    r"""Reshape a tensor to the standard LOO CV output format: ``n x 1 x m``.
+
+    This helper converts tensors from the internal model format (which varies by
+    number of outputs) to the consistent output format used by LOO CV results.
+
+    Args:
+        tensor: Input tensor with shape:
+            - Single-output: ``n`` (1D)
+            - Multi-output: ``m x n`` (2D)
+        num_outputs: Number of outputs (m). 1 for single-output models.
+
+    Returns:
+        Reshaped tensor with shape ``n x 1 x m``.
+    """
+    if num_outputs > 1:
+        # Multi-output: m x n -> n x m -> n x 1 x m
+        return tensor.movedim(-2, -1).unsqueeze(-2)
+    else:
+        # Single-output: n -> n x 1 -> n x 1 x 1
+        return tensor.unsqueeze(-1).unsqueeze(-1)
diff --git a/test/models/test_relevance_pursuit.py b/test/models/test_relevance_pursuit.py
index 082d6a911e..6a3ad328cf 100644
--- a/test/models/test_relevance_pursuit.py
+++ b/test/models/test_relevance_pursuit.py
@@ -14,6 +14,7 @@
 import gpytorch
 import torch
+from botorch.cross_validation import efficient_loo_cv
 from botorch.exceptions.errors import UnsupportedError
 from botorch.exceptions.warnings import InputDataWarning
 from botorch.fit import fit_gpytorch_mll
@@ -29,6 +30,7 @@
     get_posterior_over_support,
     RelevancePursuitMixin,
 )
+
 from botorch.models.robust_relevance_pursuit_model import (
     FRACTIONS_OF_OUTLIERS,
     RobustRelevancePursuitSingleTaskGP,
@@ -732,3 +734,52 @@ def test_experimental_utils(self) -> None:
         # after num_seeds has been exhausted, the evaluation will error.
         with self.assertRaises(StopIteration):
             f(X)
+
+    def test_efficient_loo_cv(self) -> None:
+        """Test that efficient_loo_cv works with RobustRelevancePursuitSingleTaskGP."""
+        n = 10
+        train_X = torch.rand(n, 2, dtype=torch.float64)
+        train_Y = (
+            torch.sin(train_X[:, :1]) + torch.randn(n, 1, dtype=torch.float64) * 0.1
+        )
+        model = RobustRelevancePursuitSingleTaskGP(train_X, train_Y)
+
+        # Test both observation_noise=True and observation_noise=False
+        prev_variance = None
+        for observation_noise in [True, False]:
+            with self.subTest(observation_noise=observation_noise):
+                # Run efficient LOO CV and check that no warnings are raised
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    warnings.simplefilter("always")
+                    loo_results = efficient_loo_cv(
+                        model, observation_noise=observation_noise
+                    )
+                self.assertEqual(
+                    caught_warnings,
+                    [],
+                    "Unexpected warnings raised: "
+                    f"{[str(w.message) for w in caught_warnings]}",
+                )
+
+                # Check shapes - posterior has shape n x 1 x m
+                self.assertEqual(
+                    loo_results.posterior.mean.shape, torch.Size([n, 1, 1])
+                )
+                self.assertEqual(
+                    loo_results.posterior.variance.shape, torch.Size([n, 1, 1])
+                )
+                self.assertEqual(loo_results.observed_Y.shape, torch.Size([n, 1, 1]))
+
+                # Check that variances are positive
+                self.assertTrue((loo_results.posterior.variance > 0).all())
+
+                # Check that the model is returned
+                self.assertIs(loo_results.model, model)
+
+                # When observation_noise=False, the variance should be smaller
+                # (noise subtracted from the predictive variance)
+                if prev_variance is not None:
+                    self.assertTrue(
+                        (loo_results.posterior.variance < prev_variance).all()
+                    )
+                prev_variance = loo_results.posterior.variance
diff --git a/test/test_cross_validation.py b/test/test_cross_validation.py
index 9aea1a70bb..2b19b3e7ef 100644
--- a/test/test_cross_validation.py
+++ b/test/test_cross_validation.py
@@ -8,12 +8,17 @@
 import warnings
 
 import torch
-from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
+from botorch.cross_validation import (
+    batch_cross_validation,
+    efficient_loo_cv,
+    gen_loo_cv_folds,
+)
 from botorch.exceptions.errors import UnsupportedError
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import Normalize
+from botorch.models.transforms.outcome import Standardize
 from botorch.utils.testing import BotorchTestCase, get_random_data
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
@@ -119,3 +124,295 @@
             cv_folds=cv_folds,
             fit_args={"optimizer_kwargs": {"options": {"maxiter": 1}}},
         )
+
+
+class TestEfficientLOOCV(BotorchTestCase):
+    def test_basic(self) -> None:
+        """Test efficient LOO CV with various data configurations.
+
+        This test covers:
+        - Single and multiple outputs (m=1 and m>1)
+        - With and without batch dimensions
+        """
+        n = 10
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        for m, batch_shape in itertools.product(
+            (1, 3),  # single and multi-output
+            (torch.Size(), torch.Size([2])),  # no batch and with batch
+        ):
+            with self.subTest(m=m, batch_shape=batch_shape):
+                train_X, train_Y = get_random_data(
+                    batch_shape=batch_shape, m=m, n=n, **tkwargs
+                )
+
+                model = SingleTaskGP(train_X, train_Y)
+                model.eval()
+
+                loo_results = efficient_loo_cv(model)
+
+                # Output shape: batch_shape x n x 1 x m
+                expected_shape = batch_shape + torch.Size([n, 1, m])
+                self.assertEqual(loo_results.posterior.mean.shape, expected_shape)
+                self.assertEqual(loo_results.posterior.variance.shape, expected_shape)
+                self.assertEqual(loo_results.observed_Y.shape, expected_shape)
+                self.assertTrue((loo_results.posterior.variance > 0).all())
+
+    def test_matches_naive(self) -> None:
+        """Test that efficient LOO CV matches naive LOO CV."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        n, d = 6, 2
+
+        for (
+            m,
+            batch_shape,
+            use_transforms,
+            use_fixed_noise,
+            obs_noise,
+        ) in itertools.product(
+            (1, 3),  # single and multi-output
+            (torch.Size(), torch.Size([2])),  # no batch and with batch
+            (False, True),  # transforms
+            (False, True),  # fixed noise
+            (False, True),  # observation noise
+        ):
+            # Skip transforms with batch dimensions - Standardize requires a
+            # matching batch_shape argument, which complicates the test setup.
+            # The core functionality is tested without transforms.
+            if batch_shape and use_transforms:
+                continue
+
+            with self.subTest(
+                m=m,
+                batch_shape=batch_shape,
+                transforms=use_transforms,
+                fixed_noise=use_fixed_noise,
+                obs_noise=obs_noise,
+            ):
+                train_X, train_Y = get_random_data(
+                    batch_shape=batch_shape, m=m, n=n, d=d, **tkwargs
+                )
+
+                # Build model kwargs with optional transforms
+                model_kwargs = {}
+                if use_transforms:
+                    model_kwargs["input_transform"] = Normalize(d=d)
+                    model_kwargs["outcome_transform"] = Standardize(m=m)
+                else:
+                    model_kwargs["outcome_transform"] = None
+
+                train_Yvar = (
+                    torch.full_like(train_Y, 5e-3) if use_fixed_noise else None
+                )
+                if use_fixed_noise:
+                    model = SingleTaskGP(train_X, train_Y, train_Yvar, **model_kwargs)
+                else:
+                    model = SingleTaskGP(train_X, train_Y, **model_kwargs)
+
+                # Put into eval mode, simulating a post-fit model
+                model.eval()
+
+                # Compare efficient vs naive
+                loo_results = efficient_loo_cv(model, observation_noise=obs_noise)
+                naive_mean, naive_var = naive_loo_cv(
+                    model, observation_noise=obs_noise, batch_shape=batch_shape
+                )
+
+                loo_mean = loo_results.posterior.mean.squeeze(-2)
+                loo_var = loo_results.posterior.variance.squeeze(-2)
+                self.assertAllClose(loo_mean, naive_mean, rtol=1e-6, atol=1e-6)
+                self.assertAllClose(loo_var, naive_var, rtol=1e-6, atol=1e-6)
+
+                # Verify observed_Y and observed_Yvar shapes
+                expected_shape = batch_shape + torch.Size([n, 1, m])
+                self.assertEqual(loo_results.observed_Y.shape, expected_shape)
+
+                if use_fixed_noise:
+                    self.assertIsNotNone(loo_results.observed_Yvar)
+                    self.assertEqual(
+                        loo_results.observed_Yvar.shape,
+                        expected_shape,
+                        f"observed_Yvar shape mismatch: got "
+                        f"{loo_results.observed_Yvar.shape}, "
+                        f"expected {expected_shape}",
+                    )
+                else:
+                    self.assertIsNone(loo_results.observed_Yvar)
+
+    def test_error_handling(self) -> None:
+        """Test error cases for efficient_loo_cv."""
+
+        # Test 1: Model without train_inputs
+        class MockModelNoInputs:
+            train_inputs = None
+            train_targets = None
+            training = False
+
+            def eval(self):
+                self.training = False
+                return self
+
+        model_no_inputs = MockModelNoInputs()
+        with self.assertRaisesRegex(
+            UnsupportedError, "Model must have train_inputs attribute"
+        ):
+            efficient_loo_cv(model_no_inputs)
+
+        # Test 2: Model without train_targets
+        class MockModelNoTargets:
+            def __init__(self, train_X: torch.Tensor) -> None:
+                self.train_inputs = (train_X,)
+                self.train_targets = None
+                self.training = False
+
+            def eval(self):
+                self.training = False
+                return self
+
+        train_X = torch.rand(10, 2, device=self.device)
+        model_no_targets = MockModelNoTargets(train_X)
+        with self.assertRaisesRegex(
+            UnsupportedError, "Model must have train_targets attribute"
+        ):
+            efficient_loo_cv(model_no_targets)
+
+        # Test 3: Model whose forward doesn't return a MultivariateNormal
+        class MockModelBadForward:
+            def __init__(self, train_X: torch.Tensor, train_Y: torch.Tensor) -> None:
+                self.train_inputs = (train_X,)
+                self.train_targets = train_Y.squeeze(-1)
+                self.training = False
+                self.input_transform = None
+
+            def eval(self):
+                self.training = False
+                return self
+
+            def train(self, mode: bool = True):
+                self.training = mode
+                return self
+
+            def forward(self, x: torch.Tensor):
+                return x.mean()
+
+        train_Y = torch.rand(10, 1, device=self.device)
+        model_bad_forward = MockModelBadForward(train_X, train_Y)
+        with self.assertRaisesRegex(
+            UnsupportedError,
+            "Model's forward method must return a MultivariateNormal",
+        ):
+            efficient_loo_cv(model_bad_forward)
+
+        # Test 4: Model with auxiliary inputs (tuple train_inputs)
+        model = SingleTaskGP(train_X, train_Y)
+        model.eval()
+        # Mock train_inputs[0] to be a tuple (simulating auxiliary inputs)
+        model.train_inputs = ((train_X, train_X),)
+        with self.assertRaisesRegex(
+            UnsupportedError, "not supported for models with auxiliary inputs"
+        ):
+            efficient_loo_cv(model)
+
+
+_EMPTY_BATCH_SHAPE: torch.Size = torch.Size()
+
+
+def naive_loo_cv(
+    fitted_model: SingleTaskGP,
+    observation_noise: bool = True,
+    batch_shape: torch.Size = _EMPTY_BATCH_SHAPE,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compute naive LOO CV by creating a model for each fold.
+
+    This O(n^4) implementation creates a model for each left-out point,
+    copying hyperparameters from fitted_model (no refitting).
+
+    Args:
+        fitted_model: A fitted GP model whose hyperparameters will be copied.
+        observation_noise: If True, include observation noise in the posterior
+            variance. For fixed noise models, uses the held-out point's noise.
+        batch_shape: The batch shape of the data. For batched models, LOO is
+            computed independently for each batch element.
+
+    Returns:
+        A tuple of (loo_means, loo_variances) with shape `batch_shape x n x m`.
+    """
+    from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
+
+    fitted_model.eval()
+    train_X = fitted_model.train_inputs[0]
+    train_Y = fitted_model.train_targets
+    n = train_X.shape[-2]
+    m = fitted_model.num_outputs
+    has_fixed_noise = isinstance(fitted_model.likelihood, FixedNoiseGaussianLikelihood)
+
+    # Normalize train_X: for multi-output models, BoTorch internally stores X as
+    # [batch x] m x n x d (replicated per output) because it uses batched GPs.
+    # See BatchedMultiOutputGPyTorchModel.__init__, which calls:
+    #     train_X = train_X.unsqueeze(-3).expand(..., self._num_outputs, ...)
+    # We extract the canonical X by selecting the first output's X.
+    if m > 1:
+        train_X = train_X.select(-3, 0)  # Remove output dim: [batch x] n x d
+
+    # Normalize train_Y from internal format to SingleTaskGP input format:
+    # - Multi-output: [batch x] m x n -> [batch x] n x m
+    # - Single-output: [batch x] n -> [batch x] n x 1
+    if m > 1:
+        train_Y = train_Y.movedim(-2, -1)
+    else:
+        train_Y = train_Y.unsqueeze(-1)
+
+    # Normalize noise similarly if present
+    if has_fixed_noise:
+        noise = fitted_model.likelihood.noise
+        if m > 1:
+            noise = noise.movedim(-2, -1)  # [batch x] m x n -> [batch x] n x m
+        else:
+            noise = noise.unsqueeze(-1)  # [batch x] n -> [batch x] n x 1
+    else:
+        noise = None
+
+    # Output shape: batch_shape x n x m
+    output_shape = batch_shape + torch.Size([n, m])
+    loo_means = torch.zeros(output_shape, dtype=train_X.dtype, device=train_X.device)
+    loo_vars = torch.zeros(output_shape, dtype=train_X.dtype, device=train_X.device)
+
+    for i in range(n):
+        # Create mask excluding point i
+        mask = torch.arange(n, device=train_X.device) != i
+
+        # Extract fold data - ellipsis handles any batch dimensions
+        # train_X: [batch x] n x d -> [batch x] (n-1) x d
+        # train_Y: [batch x] n x m -> [batch x] (n-1) x m
+        fold_X = train_X[..., mask, :]
+        fold_Y = train_Y[..., mask, :]
+        test_X = train_X[..., i : i + 1, :]
+
+        # Create fold model
+        kwargs = {"outcome_transform": None, "input_transform": None}
+        if has_fixed_noise:
+            fold_noise = noise[..., mask, :]
+            model = SingleTaskGP(fold_X, fold_Y, fold_noise, **kwargs)
+        else:
+            model = SingleTaskGP(fold_X, fold_Y, **kwargs)
+
+        # Copy matching hyperparameters
+        fitted_state = fitted_model.state_dict()
+        fold_state = model.state_dict()
+        for name, param in fitted_state.items():
+            if has_fixed_noise and "noise" in name.lower():
+                continue
+            if name in fold_state and fold_state[name].shape == param.shape:
+                fold_state[name] = param
+        model.load_state_dict(fold_state)
+        model.eval()
+
+        # Get posterior prediction
+        with torch.no_grad():
+            if has_fixed_noise and observation_noise:
+                held_out_noise = noise[..., i : i + 1, :]
+                posterior = model.posterior(test_X, observation_noise=held_out_noise)
+            else:
+                posterior = model.posterior(
+                    test_X, observation_noise=observation_noise
+                )
+
+        # posterior.mean/variance: [batch x] 1 x m -> [batch x] m
+        loo_means[..., i, :] = posterior.mean.squeeze(-2)
+        loo_vars[..., i, :] = posterior.variance.squeeze(-2)
+
+    return loo_means, loo_vars
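
A minimal usage sketch, assuming the patch above is applied: the efficient path can be checked against the existing `batch_cross_validation` path on a toy problem. Because `batch_cross_validation` refits each fold while `efficient_loo_cv` reuses the full-data hyperparameters, the two results agree only approximately.

    import torch
    from botorch.cross_validation import (
        batch_cross_validation,
        efficient_loo_cv,
        gen_loo_cv_folds,
    )
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    train_X = torch.rand(20, 2, dtype=torch.float64)
    train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

    # Efficient O(n^3) LOO: reuses the hyperparameters of one fitted model.
    model = SingleTaskGP(train_X, train_Y, outcome_transform=None)
    fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
    loo = efficient_loo_cv(model)

    # Brute-force LOO: refits one GP per fold via the batched code path.
    cv_folds = gen_loo_cv_folds(train_X=train_X, train_Y=train_Y)
    cv = batch_cross_validation(
        model_cls=SingleTaskGP,
        mll_cls=ExactMarginalLogLikelihood,
        cv_folds=cv_folds,
        model_init_kwargs={"outcome_transform": None},
    )

    # Both posteriors have shape n x 1 x m; inspect how far the held-out means diverge.
    print((loo.posterior.mean - cv.posterior.mean).abs().max())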