In [1]:
#!/usr/bin/env python3

import unittest
from unittest.mock import MagicMock, patch

import torch

import linear_operator
from linear_operator.operators import DenseLinearOperator
from linear_operator.test.base_test_case import BaseTestCase


class TestInvQuadLogDetNonBatch(BaseTestCase, unittest.TestCase):
    seed = 0
    matrix_shape = torch.Size((50, 50))

    def _test_inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, improper_logdet=False, add_diag=False):
        # Set up
        x = torch.randn(*self.__class__.matrix_shape[:-1], 3)
        ls = torch.tensor(2.0).requires_grad_(True)
        ls_clone = torch.tensor(2.0).requires_grad_(True)
        mat = (x[..., :, None, :] - x[..., None, :, :]).pow(2.0).sum(dim=-1).mul(-0.5 * ls).exp()
        mat_clone = (x[..., :, None, :] - x[..., None, :, :]).pow(2.0).sum(dim=-1).mul(-0.5 * ls_clone).exp()

        if inv_quad_rhs is not None:
            inv_quad_rhs.requires_grad_(True)
            inv_quad_rhs_clone = inv_quad_rhs.detach().clone().requires_grad_(True)

        mat_clone_with_diag = mat_clone
        if add_diag:
            mat_clone_with_diag = mat_clone_with_diag + torch.eye(mat_clone.size(-1))

        if inv_quad_rhs is not None:
            actual_inv_quad = mat_clone_with_diag.inverse().matmul(inv_quad_rhs_clone).mul(inv_quad_rhs_clone)
            actual_inv_quad = actual_inv_quad.sum([-1, -2]) if inv_quad_rhs.dim() >= 2 else actual_inv_quad.sum()
        if logdet:
            flattened_tensor = mat_clone_with_diag.view(-1, *mat_clone.shape[-2:])
            logdets = torch.cat([mat.logdet().unsqueeze(0) for mat in flattened_tensor])
            if mat_clone.dim() > 2:
                actual_logdet = logdets.view(*mat_clone.shape[:-2])
            else:
                actual_logdet = logdets.squeeze()

        # Compute values with LinearOperator
        _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
        with linear_operator.settings.num_trace_samples(2000), linear_operator.settings.max_cholesky_size(
            0
        ), linear_operator.settings.cg_tolerance(1e-5), linear_operator.settings.skip_logdet_forward(
            improper_logdet
        ), patch(
            "linear_operator.utils.linear_cg", new=_wrapped_cg
        ) as linear_cg_mock, linear_operator.settings.min_preconditioning_size(
            0
        ), linear_operator.settings.max_preconditioner_size(
            30
        ):
            linear_op = DenseLinearOperator(mat)

            if add_diag:
                linear_op = linear_op.add_jitter(1.0)

            res_inv_quad, res_logdet = linear_operator.inv_quad_logdet(
                linear_op, inv_quad_rhs=inv_quad_rhs, logdet=logdet
            )

        # Compare forward pass
        if inv_quad_rhs is not None:
            self.assertAllClose(res_inv_quad, actual_inv_quad, rtol=1e-2)
        if logdet and not improper_logdet:
            self.assertAllClose(res_logdet, actual_logdet, rtol=1e-1, atol=2e-1)

        # Backward
        if inv_quad_rhs is not None:
            actual_inv_quad.sum().backward(retain_graph=True)
            res_inv_quad.sum().backward(retain_graph=True)
        if logdet:
            actual_logdet.sum().backward()
            res_logdet.sum().backward()

        self.assertAllClose(ls.grad, ls_clone.grad, rtol=1e-2, atol=1e-2)
        if inv_quad_rhs is not None:
            self.assertAllClose(inv_quad_rhs.grad, inv_quad_rhs_clone.grad, rtol=2e-2, atol=1e-2)

        # Make sure CG was called
        self.assertTrue(linear_cg_mock.called)

    def test_inv_quad_logdet_vector(self):
        rhs = torch.randn(self.matrix_shape[-1])
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)

    def test_precond_inv_quad_logdet_vector(self):
        rhs = torch.randn(self.matrix_shape[-1])
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True, add_diag=True)

    def test_inv_quad_only_vector(self):
        rhs = torch.randn(self.matrix_shape[-1])
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=False)

    def test_precond_inv_quad_only_vector(self):
        rhs = torch.randn(self.matrix_shape[-1])
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=False, add_diag=True)

    def test_inv_quad_logdet_many_vectors(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)

    def test_precond_inv_quad_logdet_many_vectors(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True, add_diag=True)

    def test_inv_quad_logdet_many_vectors_improper(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True, improper_logdet=True)

    def test_precond_inv_quad_logdet_many_vectors_improper(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=True, improper_logdet=True, add_diag=True)

    def test_inv_quad_only_many_vectors(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=False)

    def test_precond_inv_quad_only_many_vectors(self):
        rhs = torch.randn(*self.matrix_shape[:-1], 5)
        self._test_inv_quad_logdet(inv_quad_rhs=rhs, logdet=False, add_diag=True)


class TestInvQuadLogDetBatch(TestInvQuadLogDetNonBatch):
    seed = 0
    matrix_shape = torch.Size((3, 50, 50))

    def test_inv_quad_logdet_vector(self):
        pass

    def test_precond_inv_quad_logdet_vector(self):
        pass

    def test_inv_quad_only_vector(self):
        pass

    def test_precond_inv_quad_only_vector(self):
        pass


class TestInvQuadLogDetMultiBatch(TestInvQuadLogDetBatch):
    seed = 0
    matrix_shape = torch.Size((2, 3, 50, 50))


if __name__ == "__main__":
    unittest.main(argv=[""], exit=False)

..............................
----------------------------------------------------------------------
Ran 30 tests in 4.230s

OK


In [5]:
#!/usr/bin/env python3

import unittest

import torch

import linear_operator
from linear_operator.test.base_test_case import BaseTestCase
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.permutation import apply_permutation, inverse_permutation


def _ensure_symmetric_grad(grad):
    """
    A gradient-hook hack to ensure that symmetric matrix gradients are symmetric
    """
    res = torch.add(grad, grad.mT).mul(0.5)
    return res


class TestPivotedCholesky(BaseTestCase, unittest.TestCase):
    seed = 0

    def _create_mat(self):
        mat = torch.randn(8, 8)
        mat = mat @ mat.mT
        return mat

    def test_pivoted_cholesky(self, max_iter=3):
        mat = self._create_mat().detach().requires_grad_(True)
        mat.register_hook(_ensure_symmetric_grad)
        mat_copy = mat.detach().clone().requires_grad_(True)
        mat_copy.register_hook(_ensure_symmetric_grad)

        # Forward (with function)
        res, pivots = linear_operator.pivoted_cholesky(mat, rank=max_iter, return_pivots=True)

        # Forward (manual pivoting, actual Cholesky)
        inverse_pivots = inverse_permutation(pivots)
        # Apply pivoting
        pivoted_mat_copy = apply_permutation(mat_copy, pivots, pivots)
        # Compute Cholesky
        actual_pivoted = psd_safe_cholesky(pivoted_mat_copy)[..., :max_iter]
        # Undo pivoting
        actual = apply_permutation(actual_pivoted, left_permutation=inverse_pivots)

        self.assertAllClose(res, actual)

        # Backward
        grad_output = torch.randn_like(res)
        res.backward(gradient=grad_output)
        actual.backward(gradient=grad_output)
        self.assertAllClose(mat.grad, mat_copy.grad)


class TestPivotedCholeskyBatch(TestPivotedCholesky, unittest.TestCase):
    seed = 0

    def _create_mat(self):
        mat = torch.randn(2, 3, 8, 8)
        mat = mat @ mat.mT
        return mat


if __name__ == "__main__":
    unittest.main(argv=[""], exit=False)

.....................................
----------------------------------------------------------------------
Ran 37 tests in 4.083s

OK


In [6]:
import torch
import torch.nn as nn

# linear_operator packages
from linear_operator import inv_quad_logdet
from linear_operator.operators import DenseLinearOperator

class IterativeGP(nn.Module):
    """
    A simple illustration of using linear_operator's inv_quad_logdet
    to fit a GP model given a kernel matrix + diagonal noise.
    """
    def __init__(self, kernel, noise=0.1):
        super().__init__()
        # "kernel" can be something from GPyTorch (e.g. RBFKernel),
        # or any callable that can produce an (n x n) covariance matrix
        self.kernel = kernel
        # We make noise a Parameter so that it can be trained via gradient
        self.raw_noise = nn.Parameter(torch.log(torch.tensor(noise, dtype=torch.float32)))

    def noise_value(self):
        # Use softplus or exponent to keep noise strictly positive
        return torch.exp(self.raw_noise)

    def fit(self, train_x, train_y):
        """
        train_x: tensor of shape (n x d)
        train_y: tensor of shape (n,) or (n x 1)

        Returns: The scalar MLL (float) after computing inv_quad and logdet.
        """
        train_x = train_x.detach().clone()
        train_y = train_y.detach().clone()

        # Mark them as requiring gradient only if you want to backprop
        # through train_x or train_y themselves (usually you do not)
        # train_x.requires_grad_(False)
        # train_y.requires_grad_(False)

        # Produce a dense covariance matrix from the kernel
        K = self.kernel(train_x, train_x).evaluate()  # shape (n, n)
        # Add noise on the diagonal
        K = K + self.noise_value() * torch.eye(K.size(-1), dtype=K.dtype, device=K.device)

        # Wrap the matrix with DenseLinearOperator
        linear_op = DenseLinearOperator(K)

        # If train_y is just shape (n,), pass it as inv_quad_rhs
        # logdet=True means we also want the approximate log|K|
        inv_quad_term, logdet_term = inv_quad_logdet(
            linear_op,
            inv_quad_rhs=train_y,
            logdet=True
        )

        n = train_y.size(0)
        # Standard Gaussian log-likelihood:
        # 0.5 * ( y^T K^{-1} y + log|K| + n * log(2π) )
        const = n * torch.log(torch.tensor(2.0 * 3.141592653589793, dtype=K.dtype, device=K.device))
        mll = 0.5 * (inv_quad_term + logdet_term + const)

        # Backprop:
        mll.backward()  # This will compute d/d(kernel params) and d/d(noise)
        
        # Return the numeric value
        return mll.item()

In [13]:
import numpy as np
import torch
import os,sys
import matplotlib.pyplot as plt
from gpytorch.kernels import RBFKernel,MaternKernel,ScaleKernel
from gpytorch.priors import GammaPrior
from gpytorch.likelihoods import GaussianLikelihood
notebook_dir = os.getcwd()
src_path = os.path.abspath(os.path.join(notebook_dir, '../code'))
if src_path not in sys.path:
    sys.path.append(src_path)
from gps import CholeskyGaussianProcess
from plotting import plot_gp_simple,plot_gp_sample
from util import train,eval,plot_gpr_results
from gps import IterativeGaussianProcess,CholeskyGaussianProcess
%load_ext autoreload
%autoreload 2
# np.random.seed(42)
# torch.manual_seed(42)

device="cuda:0"
global_dtype=torch.float32


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Define data generation function with an abstract true function
def generate_data(true_function, train_range=(-3, 3), test_range=(-3, 3), 
                  n_train=40, n_test=100, noise_std=0.1, 
                  device='cuda:0', dtype=torch.float64):
    # Generate training data
    X_train = torch.linspace(train_range[0], train_range[1], n_train, dtype=dtype, device=device).unsqueeze(-1)
    y_train = true_function(X_train) + noise_std * torch.randn_like(X_train)
    
    # Generate test data
    X_test = torch.linspace(test_range[0], test_range[1], n_test, dtype=dtype, device=device).unsqueeze(-1)
    y_test = true_function(X_test)  # No noise added to test data
    
    return X_train, y_train.squeeze(), X_test, y_test.squeeze()

# Define the true function
def true_function(x):
    
    return torch.sin(2 * x) + torch.cos(3 * x)

# Generate data using the true function
train_x, train_y, test_x, test_y = generate_data(true_function,train_range=(-3, 3), test_range=(-5,5), dtype=global_dtype)


import torch
import torch.nn as nn

from linear_operator.operators import DenseLinearOperator
from linear_operator import inv_quad_logdet


# class IterativeGP(nn.Module):
#     """
#     A minimal iterative Gaussian Process model that:
#       - Does a single pass of fit(...) with inv_quad_logdet
#       - Caches logdet(K) and inv_quad term for repeated usage
#       - Provides a separate compute_mll(...) method
#     """

#     def __init__(self, kernel, noise=0.1, dtype=torch.float64, device="cuda:0"):
#         super().__init__()
#         self.kernel = kernel
#         self.dtype = dtype
#         self.device = device
#         # Use a raw noise parameter if you want to backprop through noise
#         self.raw_noise = nn.Parameter(torch.log(torch.tensor(noise, dtype=self.dtype, device=self.device)))

#         # We'll cache these during fit
#         self.cached_logdet = None
#         self.cached_inv_quad = None
#         self.cached_mll = None

#     def noise_value(self):
#         """
#         Exponentiate raw_noise to keep it strictly > 0
#         """
#         return torch.exp(self.raw_noise)

#     def fit(self, train_x, train_y):
#         """
#         Perform a single 'fit' step:
#          1) Build K = kernel(...) + noise * I
#          2) inv_quad_logdet(...) => y^T K^{-1} y, logdet(K)
#          3) Store/copy to self.cached_(...) variables
#          4) (Optional) call backward if you want to do gradient-based updates right away
#         """
#         train_x = train_x.to(self.device, self.dtype)
#         train_y = train_y.to(self.device, self.dtype)

#         # Build kernel matrix
#         K = self.kernel(train_x, train_x).evaluate()  # (n x n)
#         K = K + self.noise_value() * torch.eye(
#             K.size(0), dtype=self.dtype, device=self.device
#         )

#         # Wrap in DenseLinearOperator
#         lin_op = DenseLinearOperator(K)

#         # Compute inverse-quad & logdet
#         # inv_quad_rhs=train_y => y^T K^{-1} y
#         # logdet=True => approximate log|K|
#         inv_quad_term, logdet_term = inv_quad_logdet(
#             lin_op, inv_quad_rhs=train_y, logdet=True
#         )

#         # Optionally store for repeated usage
#         self.cached_inv_quad = inv_quad_term.detach()
#         self.cached_logdet = logdet_term.detach()

#         # Also store the final MLL if you like
#         n = train_y.size(0)
#         const = n * torch.log(torch.tensor(2.0 * 3.141592653589793, dtype=self.dtype, device=self.device))
#         mll = 0.5 * (inv_quad_term + logdet_term + const)

#         self.cached_mll = mll.detach()

#         # If you want to do a single backward pass right now, uncomment:
#         # mll.backward()

#     def compute_mll(self):
#         """
#         Given that fit(...) has cached inv_quad and logdet,
#         compute the same MLL without re-running inv_quad_logdet.
#         """
#         if self.cached_inv_quad is None or self.cached_logdet is None:
#             raise RuntimeError("Must call fit(...) first to cache inv_quad/logdet values.")

#         # We'll need to know the size of training data for the constant term
#         # but we haven't stored train_y. Let's assume we also stored n in fit:
#         n = self.cached_inv_quad.size(0) if self.cached_inv_quad.dim() > 0 else 1
#         # Actually, if y was shape (n,), then inv_quad is just a scalar,
#         # so we can't parse n from that. Instead, suppose we stored `self.n_train` in fit(...).

#         # If we have self.cached_mll from fit, we can just return that:
#         if self.cached_mll is not None:
#             return self.cached_mll

#         # Or recompute from cached scalars:
#         #  mll = 0.5 * ( self.cached_inv_quad + self.cached_logdet + n*log(2π) )
#         #  return mll

#         raise NotImplementedError("Either store self.n_train or store self.cached_mll in fit(...) and just return it.")

#     def predict(self, X_star):
#         """
#         Optional. If you want to do posterior prediction with CG solves:
#          - Build (K + noise*I) again
#          - Solve for alpha = K^{-1} y
#          - Then compute K(X_star, X) alpha for the predictive mean
#          - Possibly multiple solves for the predictive covariance
#         We'll keep it minimal and leave it as a placeholder.
#         """
#         raise NotImplementedError("Implement iterative solves for predictive mean/cov if desired.")

In [3]:
#!/usr/bin/env python3

import warnings

import torch
from torch.autograd import Function

from linear_operator import settings
from linear_operator.utils.lanczos import lanczos_tridiag_to_diag
from linear_operator.utils.stochastic_lq import StochasticLQ


class customInvQuadLogdet(Function):
    """
    Given a PSD matrix A (or a batch of PSD matrices A), this function computes one or both
    of the following
    - The matrix solves A^{-1} b
    - logdet(A)

    This function uses preconditioned CG and Lanczos quadrature to compute the inverse quadratic
    and log determinant terms, using the variance reduction strategy outlined in:
    ``Reducing the Variance of Gaussian Process Hyperparameter Optimization with Preconditioning''
    (https://arxiv.org/abs/2107.00243)
    """

    @staticmethod
    def forward(
        ctx,
        representation_tree,
        precond_representation_tree,
        preconditioner,
        num_precond_args,
        inv_quad,
        probe_vectors,
        probe_vector_norms,
        *args,
    ):
        """
        *args - The arguments representing the PSD matrix A (or batch of PSD matrices A)
        If self.inv_quad is true, the first entry in *args is inv_quad_rhs (Tensor)
        - the RHS of the matrix solves.

        Returns:
        - (Scalar) The inverse quadratic form (or None, if self.inv_quad is False)
        - (Scalar) The log determinant (or None, self.if logdet is False)
        """

        ctx.representation_tree = representation_tree
        ctx.precond_representation_tree = precond_representation_tree
        ctx.preconditioner = preconditioner
        ctx.inv_quad = inv_quad
        ctx.num_precond_args = num_precond_args

        matrix_args = None
        precond_args = tuple()
        inv_quad_rhs = None
        if ctx.inv_quad:
            inv_quad_rhs = args[0]
            args = args[1:]
        if ctx.num_precond_args:
            matrix_args = args[:-num_precond_args]
            precond_args = args[-num_precond_args:]
        else:
            matrix_args = args

        # Get closure for matmul
        linear_op = ctx.representation_tree(*matrix_args)
        precond_lt = ctx.precond_representation_tree(*precond_args)

        # Get info about matrix
        ctx.dtype = linear_op.dtype
        ctx.device = linear_op.device
        ctx.matrix_shape = linear_op.matrix_shape
        ctx.batch_shape = linear_op.batch_shape

        # Probe vectors
        if probe_vectors is None or probe_vector_norms is None:
            num_random_probes = settings.num_trace_samples.value()
            if settings.deterministic_probes.on():
                # NOTE: calling precond_lt.root_decomposition() is expensive
                # because it requires Lanczos
                # We don't have any other choice for when we want to use deterministic probes, however
                if precond_lt.size()[-2:] == torch.Size([1, 1]):
                    covar_root = precond_lt.to_dense().sqrt()
                else:
                    covar_root = precond_lt.root_decomposition().root

                warnings.warn(
                    "The deterministic probes feature is now deprecated. "
                    "See https://github.com/cornellius-gp/linear_operator/pull/1836.",
                    DeprecationWarning,
                )
                base_samples = settings.deterministic_probes.probe_vectors
                if base_samples is None or covar_root.size(-1) != base_samples.size(-2):
                    base_samples = torch.randn(
                        *precond_lt.batch_shape,
                        covar_root.size(-1),
                        num_random_probes,
                        dtype=precond_lt.dtype,
                        device=precond_lt.device,
                    )
                    settings.deterministic_probes.probe_vectors = base_samples

                probe_vectors = covar_root.matmul(base_samples).permute(-1, *range(precond_lt.dim() - 1))
            else:
                probe_vectors = precond_lt.zero_mean_mvn_samples(num_random_probes)
            probe_vectors = probe_vectors.unsqueeze(-2).transpose(0, -2).squeeze(0).mT.contiguous()
            probe_vector_norms = torch.norm(probe_vectors, p=2, dim=-2, keepdim=True)
            probe_vectors = probe_vectors.div(probe_vector_norms)

        # Probe vectors
        ctx.probe_vectors = probe_vectors
        ctx.probe_vector_norms = probe_vector_norms

        # Collect terms for LinearCG
        # We use LinearCG for both matrix solves and for stochastically estimating the log det
        rhs_list = [ctx.probe_vectors]
        num_random_probes = ctx.probe_vectors.size(-1)
        num_inv_quad_solves = 0

        # RHS for inv_quad
        ctx.is_vector = False
        if ctx.inv_quad:
            if inv_quad_rhs.ndimension() == 1:
                inv_quad_rhs = inv_quad_rhs.unsqueeze(-1)
                ctx.is_vector = True
            rhs_list.append(inv_quad_rhs)
            num_inv_quad_solves = inv_quad_rhs.size(-1)

        # Perform solves (for inv_quad) and tridiagonalization (for estimating logdet)
        rhs = torch.cat(rhs_list, -1)
        solves, t_mat = linear_op._solve(rhs, preconditioner, num_tridiag=num_random_probes)

        # Final values to return
        logdet_term = torch.zeros(linear_op.batch_shape, dtype=ctx.dtype, device=ctx.device)
        inv_quad_term = torch.zeros(linear_op.batch_shape, dtype=ctx.dtype, device=ctx.device)

        # Compute logdet from tridiagonalization
        if settings.skip_logdet_forward.off():
            if torch.any(torch.isnan(t_mat)).item():
                logdet_term = torch.tensor(float("nan"), dtype=ctx.dtype, device=ctx.device)
            else:
                if ctx.batch_shape is None:
                    t_mat = t_mat.unsqueeze(1)
                eigenvalues, eigenvectors = lanczos_tridiag_to_diag(t_mat)
                slq = StochasticLQ()
                (logdet_term,) = slq.to_dense(ctx.matrix_shape, eigenvalues, eigenvectors, [lambda x: x.log()])

        # Extract inv_quad solves from all the solves
        if ctx.inv_quad:
            inv_quad_solves = solves.narrow(-1, num_random_probes, num_inv_quad_solves)
            inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2)

        ctx.num_random_probes = num_random_probes
        ctx.num_inv_quad_solves = num_inv_quad_solves

        to_save = list(precond_args) + list(matrix_args) + [solves]
        ctx.save_for_backward(*to_save)

        return inv_quad_term, logdet_term

    @staticmethod
    def backward(ctx, inv_quad_grad_output, logdet_grad_output):
        # Get input arguments, and get gradients in the proper form
        if ctx.num_precond_args:
            precond_args = ctx.saved_tensors[: ctx.num_precond_args]
            matrix_args = ctx.saved_tensors[ctx.num_precond_args : -1]
        else:
            precond_args = []
            matrix_args = ctx.saved_tensors[:-1]
        solves = ctx.saved_tensors[-1]

        linear_op = ctx.representation_tree(*matrix_args)
        precond_lt = ctx.precond_representation_tree(*precond_args)

        # Fix grad_output sizes
        if ctx.inv_quad:
            inv_quad_grad_output = inv_quad_grad_output.unsqueeze(-2)
        logdet_grad_output = logdet_grad_output.unsqueeze(-1)
        logdet_grad_output.unsqueeze_(-1)

        # Un-normalize probe vector solves
        coef = 1.0 / ctx.probe_vectors.size(-1)
        probe_vector_solves = solves.narrow(-1, 0, ctx.num_random_probes).mul(coef)
        probe_vector_solves.mul_(ctx.probe_vector_norms).mul_(logdet_grad_output)

        # Apply preconditioner to probe vectors (originally drawn from N(0, P))
        # Now the probe vectors will be drawn from N(0, P^{-1})
        if ctx.preconditioner is not None:
            precond_probe_vectors = ctx.preconditioner(ctx.probe_vectors * ctx.probe_vector_norms)
        else:
            precond_probe_vectors = ctx.probe_vectors * ctx.probe_vector_norms

        # matrix gradient
        # Collect terms for arg grads
        left_factors_list = [probe_vector_solves]
        right_factors_list = [precond_probe_vectors]

        inv_quad_solves = None
        neg_inv_quad_solves_times_grad_out = None
        if ctx.inv_quad:
            inv_quad_solves = solves.narrow(-1, ctx.num_random_probes, ctx.num_inv_quad_solves)
            neg_inv_quad_solves_times_grad_out = inv_quad_solves.mul(inv_quad_grad_output).mul_(-1)
            left_factors_list.append(neg_inv_quad_solves_times_grad_out)
            right_factors_list.append(inv_quad_solves)

        left_factors = torch.cat(left_factors_list, -1)
        right_factors = torch.cat(right_factors_list, -1)
        matrix_arg_grads = linear_op._bilinear_derivative(left_factors, right_factors)

        # precond gradient
        precond_arg_grads = precond_lt._bilinear_derivative(
            -precond_probe_vectors * coef, precond_probe_vectors * logdet_grad_output
        )

        # inv_quad_rhs gradients
        if ctx.inv_quad:
            inv_quad_rhs_grad = neg_inv_quad_solves_times_grad_out.mul_(-2)
            if ctx.is_vector:
                inv_quad_rhs_grad.squeeze_(-1)
            res = [inv_quad_rhs_grad] + list(matrix_arg_grads) + list(precond_arg_grads)
        else:
            res = list(matrix_arg_grads) + list(precond_arg_grads)

        return tuple([None] * 7 + res)

In [25]:
import torch
import torch.nn as nn
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.priors import GammaPrior
from linear_operator import inv_quad_logdet
from gpytorch.lazy.lazy_tensor import LazyTensor


def custom_inv_quad_logdet(lazy_tsr, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
    assert isinstance(lazy_tsr, LazyTensor)

    # Special case: use Cholesky to compute these terms
    if settings.fast_computations.log_prob.off() or (lazy_tsr.size(-1) <= settings.max_cholesky_size.value()):
        from linear_operator.operators.chol_linear_operator import CholLinearOperator
        from linear_operator.operators.chol_linear_operator import TriangularLinearOperator

        cholesky = CholLinearOperator(TriangularLinearOperator(lazy_tsr.cholesky()))
        return cholesky.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)

    # Default: use modified batch conjugate gradients to compute these terms
    # See NeurIPS 2018 paper: https://arxiv.org/abs/1809.11165
    if not lazy_tsr.is_square:
        raise RuntimeError(
            "inv_quad_logdet only operates on (batches of) square (positive semi-definite) LazyTensors. "
            "Got a {} of size {}.".format(lazy_tsr.__class__.__name__, lazy_tsr.size())
        )

    if inv_quad_rhs is not None:
        if lazy_tsr.dim() == 2 and inv_quad_rhs.dim() == 1:
            if lazy_tsr.shape[-1] != inv_quad_rhs.numel():
                raise RuntimeError(
                    "LazyTensor (size={}) cannot be multiplied with right-hand-side Tensor (size={}).".format(
                        lazy_tsr.shape, inv_quad_rhs.shape
                    )
                )
        elif lazy_tsr.dim() != inv_quad_rhs.dim():
            raise RuntimeError(
                "LazyTensor (size={}) and right-hand-side Tensor (size={}) should have the same number "
                "of dimensions.".format(lazy_tsr.shape, inv_quad_rhs.shape)
            )
        elif lazy_tsr.batch_shape != inv_quad_rhs.shape[:-2] or lazy_tsr.shape[-1] != inv_quad_rhs.shape[-2]:
            raise RuntimeError(
                "LazyTensor (size={}) cannot be multiplied with right-hand-side Tensor (size={}).".format(
                    lazy_tsr.shape, inv_quad_rhs.shape
                )
            )

    args = lazy_tsr.representation()  # TODO: check this

    if inv_quad_rhs is not None:
        args = [inv_quad_rhs] + list(args)

    probe_vectors, probe_vector_norms = lazy_tsr._probe_vectors_and_norms()

    func = CustomInvQuadLogDet.apply

    inv_quad_term, logdet_term = func(
        lazy_tsr.representation_tree(),
        lazy_tsr.dtype,
        lazy_tsr.device,
        lazy_tsr.matrix_shape,
        lazy_tsr.batch_shape,
        (inv_quad_rhs is not None),
        logdet,
        probe_vectors,
        probe_vector_norms,
        *args,
    )

    if inv_quad_term.numel() and reduce_inv_quad:
        inv_quad_term = inv_quad_term.sum(-1)
    return inv_quad_term, logdet_term



def generate_data(f, train_range=(-3, 3), test_range=(-3, 3), n_train=40, n_test=100, noise_std=0.1, device='cuda:0', dtype=torch.float64):
    X_train = torch.linspace(train_range[0], train_range[1], n_train, dtype=dtype, device=device).unsqueeze(-1)
    y_train = f(X_train) + noise_std * torch.randn_like(X_train)
    X_test = torch.linspace(test_range[0], test_range[1], n_test, dtype=dtype, device=device).unsqueeze(-1)
    y_test = f(X_test)
    return X_train, y_train.squeeze(), X_test, y_test.squeeze()

def true_function(x):
    return torch.sin(2 * x) + torch.cos(3 * x)

class IterativeGP2(nn.Module):
    def __init__(self, kernel, noise=0.1, dtype=torch.float64, device="cuda:0"):
        super().__init__()
        self.kernel = kernel
        self.dtype = dtype
        self.device = device
        self.raw_noise = nn.Parameter(torch.log(torch.tensor(noise, dtype=self.dtype, device=self.device)))
        self.train_x = None
        self.train_y = None
        self.logdet_term = None

    def noise_value(self):
        return torch.exp(self.raw_noise)

    def fit(self, train_x, train_y):
        self.train_x = train_x.to(self.device, self.dtype)
        self.train_y = train_y.to(self.device, self.dtype)

        self.K = self.kernel(self.train_x, self.train_x)
        self.K.add_diag(self.noise_value()**2)
        inv_quad,logdet = inv_quad_logdet(self.K, inv_quad_rhs=train_y.unsqueeze(1), logdet=True, reduce_inv_quad=True)
        print(inv_quad)
        print(logdet)
        self.logdet_term= logdet
        self.inv_quad_term = inv_quad

    def compute_mll(self):
        n = self.train_y.size(0)
        const = n * torch.log(torch.tensor(2.0 * 3.141592653589793, dtype=self.dtype, device=self.device))
        return 0.5 * (self.inv_quad_term + self.logdet_term + const)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_x, train_y, test_x, test_y = generate_data(true_function, train_range=(-3,3), test_range=(-5,5), device=device)
base_kernel = MaternKernel(ard_num_dims=train_x.shape[-1], nu=1.5, lengthscale_prior=GammaPrior(-3.0, 3.0))
kernel = ScaleKernel(base_kernel, outputscale_prior=GammaPrior(-3.0, 3.0)).to(device)
igp1 = IterativeGP2(kernel, noise=0.4, dtype=torch.float64, device=device)
igp1.fit(train_x, train_y)
mll_value = igp1.compute_mll()
print(mll_value.item())

import gpytorch
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        base_kernel = MaternKernel(ard_num_dims=train_x.shape[-1], nu=1.5, lengthscale_prior=GammaPrior(-3.0, 3.0))
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = ScaleKernel(base_kernel, outputscale_prior=GammaPrior(-3.0, 3.0))
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_x, train_y, test_x, test_y = generate_data(true_function, train_range=(-3, 3), test_range=(-5, 5), device=device)

# Initialize likelihood and the exact GP model
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
model = ExactGPModel(train_x, train_y, likelihood).to(device)

# Train the model by maximizing the marginal log likelihood
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
output = model(train_x)
loss = -mll(output, train_y)
print("Gpytorch mll",loss)



device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_x, train_y, test_x, test_y = generate_data(true_function, train_range=(-3,3), test_range=(-5,5), device=device)
base_kernel = MaternKernel(ard_num_dims=train_x.shape[-1], nu=1.5, lengthscale_prior=GammaPrior(-3.0, 3.0))
kernel = ScaleKernel(base_kernel, outputscale_prior=GammaPrior(-3.0, 3.0)).to(device)
igp2 = CholeskyGaussianProcess(kernel, noise=0.4, dtype=torch.float64, device=device)
igp2.fit(train_x, train_y)
mll_value = igp2.compute_mll(train_y)
print("cholgp",mll_value.item())

tensor(76.8859, device='cuda:0', dtype=torch.float64, grad_fn=<SumBackward1>)
tensor(-119.1794, device='cuda:0', dtype=torch.float64, grad_fn=<SumBackward1>)
15.610799816887443
Gpytorch mll tensor(inf, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
cholgp 27.37256775093035
