Commit

Merge pull request #2511 from SebastianAment/constant_kernel
`ConstantKernel`
Balandat committed Apr 19, 2024
2 parents e3d8a5e + 89dfb46 commit a158e44
Showing 6 changed files with 251 additions and 7 deletions.
9 changes: 8 additions & 1 deletion docs/source/kernels.rst
@@ -9,7 +9,7 @@ gpytorch.kernels


If you don't know what kernel to use, we recommend that you start out with a
:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())`.
:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`.


Kernel
@@ -22,6 +22,13 @@ Kernel
Standard Kernels
-----------------------------

:hidden:`ConstantKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ConstantKernel
   :members:


:hidden:`CosineKernel`
~~~~~~~~~~~~~~~~~~~~~~

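For context, the updated recommendation amounts to composing the kernels like this (a minimal sketch, assuming a standard gpytorch model setup; `covar_module` is an illustrative name, not part of the diff):

import gpytorch

# Scaled RBF kernel plus a learned constant offset: the additive
# ConstantKernel lets the GP infer a non-zero asymptotic value.
covar_module = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel()
) + gpytorch.kernels.ConstantKernel()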
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -2,6 +2,7 @@
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
from .constant_kernel import ConstantKernel
from .cosine_kernel import CosineKernel
from .cylindrical_kernel import CylindricalKernel
from .distributional_input_kernel import DistributionalInputKernel
@@ -38,6 +39,7 @@
"ArcKernel",
"AdditiveKernel",
"AdditiveStructureKernel",
"ConstantKernel",
"CylindricalKernel",
"MultiDeviceKernel",
"CosineKernel",
123 changes: 123 additions & 0 deletions gpytorch/kernels/constant_kernel.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class ConstantKernel(Kernel):
    """
    Constant covariance kernel for the probabilistic inference of constant coefficients.

    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
    The prior variance of the constant is optimized during the GP hyper-parameter
    optimization stage. The actual value of the constant is computed (implicitly) using
    the linear algebraic approaches for the computation of GP samples and posteriors.

    The constant kernel `k_constant` is most useful as a modification of an arbitrary
    base kernel `k_base`:

    1) Additive constants: The modification `k_base + k_constant` allows the GP to
    infer a non-zero asymptotic value far from the training data, which generally
    leads to more accurate extrapolation. Notably, the uncertainty in this constant
    value affects the posterior covariances through the posterior inference equations.
    This is not the case when a constant prior mean is used instead, since the prior
    mean does not show up in the posterior covariance and is regularized by the
    log-determinant during the optimization of the marginal likelihood.

    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP to
    modulate the variance of the kernel `k_base`, and is mathematically identical to
    `ScaleKernel(base_kernel)` with the same constant.
    """

    has_lengthscale = False

    def __init__(
        self,
        batch_shape: Optional[torch.Size] = None,
        constant_prior: Optional[Prior] = None,
        constant_constraint: Optional[Interval] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        """Constructor of ConstantKernel.

        Args:
            batch_shape: The batch shape of the kernel.
            constant_prior: Prior over the constant parameter.
            constant_constraint: Constraint to place on the constant parameter.
            active_dims: The dimensions of the input with which to evaluate the kernel.
                This has no effect for the constant kernel, but is included for
                compatibility with the Kernel API.
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)

        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if constant_prior is not None:
            if not isinstance(constant_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
            self.register_prior(
                "constant_prior",
                constant_prior,
                lambda m: m.constant,
                lambda m, v: m._set_constant(v),
            )

        if constant_constraint is None:
            constant_constraint = Positive()
        self.register_constraint("raw_constant", constant_constraint)

    @property
    def constant(self) -> Tensor:
        return self.raw_constant_constraint.transform(self.raw_constant)

    @constant.setter
    def constant(self, value: Tensor) -> None:
        self._set_constant(value)

    def _set_constant(self, value: Tensor) -> None:
        value = value.view(*self.batch_shape, 1)
        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))
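    # Note: the constant lives in unconstrained space as `raw_constant`; the
    # property and setter above map through the (default Positive) constraint,
    # so gradient-based optimization always acts on the raw parameter.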

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
    ) -> Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
            last_dim_is_batch: If True, the last dimension of size `d` of the input
                tensors is treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
            constant covariance values if diag is False, resp. True.
        """
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        # The covariance is constant, so it suffices to compute the broadcast
        # output shape and expand the (constrained) constant to it.
        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
        constant = self.constant.to(dtype=dtype, device=x1.device)

        if not diag:
            constant = constant.unsqueeze(-1)

        if last_dim_is_batch:
            constant = constant.unsqueeze(-1)

        return constant.expand(shape)
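To make the intended usage concrete, here is a minimal sketch of the two modifications described in the class docstring (the data and variable names are illustrative, not part of the commit):

import torch

from gpytorch.kernels import ConstantKernel, MaternKernel, ScaleKernel

x = torch.rand(10, 3)  # n x d inputs
k_base = MaternKernel()
k_constant = ConstantKernel()

# 1) Additive constant: shifts the prior covariance by var(c), which lets the
#    GP infer a non-zero asymptotic value far from the training data.
k_sum = k_base + k_constant
K_sum = k_sum(x).to_dense()  # k_base(x) plus k_constant.constant, broadcast over n x n

# 2) Multiplicative constant: modulates the variance of k_base; mathematically
#    identical to ScaleKernel(k_base) with the same constant.
k_prod = k_base * k_constant
K_prod = k_prod(x).to_dense()

# The constant is a positively-constrained hyperparameter; assignment maps
# through the constraint's inverse transform.
k_constant.constant = torch.tensor([2.0])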
10 changes: 4 additions & 6 deletions gpytorch/test/base_kernel_test_case.py
@@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self):
        actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2)
        self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5)

    def test_smoke_double_batch_kernel_double_batch_x_no_ard(self):
    def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None:
        kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2]))
        x = self.create_data_double_batch()
        batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
        kernel(x).evaluate_kernel().to_dense()
        kernel(x, diag=True)
        return batch_covar_mat

    def test_smoke_double_batch_kernel_double_batch_x_ard(self):
    def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None:
        try:
            kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2]))
        except NotImplementedError:
            return

        x = self.create_data_double_batch()
        batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
        kernel(x).evaluate_kernel().to_dense()
        kernel(x, diag=True)
        return batch_covar_mat

    def test_kernel_getitem_single_batch(self):
        kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2]))
1 change: 1 addition & 0 deletions setup.py
@@ -82,6 +82,7 @@ def find_version(*file_paths):
"nbclient<=0.7.3",
"nbformat<=5.8.0",
"nbsphinx<=0.9.1",
"lxml_html_clean",
"platformdirs<=3.2.0",
"setuptools_scm<=7.1.0",
"sphinx<=6.2.1",
113 changes: 113 additions & 0 deletions test/kernels/test_constant_kernel.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python3

import itertools
import unittest

import torch

from torch import Tensor

from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel
from gpytorch.lazy import LazyEvaluatedKernelTensor
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestConstantKernel(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return ConstantKernel(**kwargs)

    def test_constant_kernel(self):
        with self.subTest(device="cpu"):
            self._test_constant_kernel(torch.device("cpu"))

        if torch.cuda.is_available():
            with self.subTest(device="cuda"):
                self._test_constant_kernel(torch.device("cuda"))

    def _test_constant_kernel(self, device: torch.device):
        n, d = 3, 5
        dtypes = [torch.float, torch.double]
        batch_shapes = [(), (2,), (7, 2)]
        torch.manual_seed(123)
        for dtype, batch_shape in itertools.product(dtypes, batch_shapes):
            tkwargs = {"dtype": dtype, "device": device}
            places = 6 if dtype == torch.float else 12
            X = torch.rand(*batch_shape, n, d, **tkwargs)

            constant_kernel = ConstantKernel(batch_shape=batch_shape)
            KL = constant_kernel(X)
            self.assertIsInstance(KL, LazyEvaluatedKernelTensor)
            KM = KL.to_dense()
            self.assertIsInstance(KM, Tensor)
            self.assertEqual(KM.shape, (*batch_shape, n, n))
            self.assertEqual(KM.dtype, dtype)
            self.assertEqual(KM.device.type, device.type)
            # standard deviation is zero iff KM is constant
            self.assertAlmostEqual(KM.std().item(), 0, places=places)

            # testing last_dim_is_batch
            with self.subTest(last_dim_is_batch=True):
                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
                KM = KD.to_dense()
                self.assertIsInstance(KM, Tensor)
                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
                self.assertAlmostEqual(KM.std().item(), 0, places=places)
                self.assertEqual(KM.dtype, dtype)
                self.assertEqual(KM.device.type, device.type)

            # testing diag
            with self.subTest(diag=True):
                KD = constant_kernel(X, diag=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing diag and last_dim_is_batch
            with self.subTest(diag=True, last_dim_is_batch=True):
                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, d, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing AD
            with self.subTest(requires_grad=True):
                X.requires_grad = True
                constant_kernel(X).to_dense().sum().backward()
                self.assertIsNone(X.grad)  # constant kernel is not dependent on X

            # testing algebraic combinations with another kernel
            base_kernel = MaternKernel().to(device=device)

            with self.subTest(additive=True):
                sum_kernel = base_kernel + constant_kernel
                self.assertIsInstance(sum_kernel, AdditiveKernel)
                self.assertAllClose(
                    sum_kernel(X).to_dense(),
                    base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1),
                )

            # product with constant is equivalent to scale kernel
            with self.subTest(product=True):
                product_kernel = base_kernel * constant_kernel
                self.assertIsInstance(product_kernel, ProductKernel)

                scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape)
                scale_kernel.to(device=device)
                self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense())

            # setting constant
            pies = torch.full_like(constant_kernel.constant, torch.pi)
            constant_kernel.constant = pies
            self.assertAllClose(constant_kernel.constant, pies)

            # specifying prior
            constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7))

            with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"):
                ConstantKernel(constant_prior=1)

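As a quick illustration of the shape semantics exercised by these tests (a sketch under the same conventions; the sizes are illustrative):

import torch

from gpytorch.kernels import ConstantKernel

k = ConstantKernel(batch_shape=torch.Size([2]))
x = torch.rand(2, 5, 3)  # batch_shape x n x d

K = k(x).to_dense()  # 2 x 5 x 5, every entry equals k.constant
K_diag = k(x, diag=True)  # 2 x 5
K_last = k(x, last_dim_is_batch=True).to_dense()  # 2 x 3 x 5 x 5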