Mean and kernel functions for first and second derivatives #2235

Merged: 20 commits, May 26, 2023. The diff below shows changes from 14 commits.

Commits:
99697e1
Added an RBF kernel with second (non-mixed) derivatives. Need to be t…
ankushaggarwal Jun 11, 2021
e3e97cb
Added constant mean with second derivative and linear means with firs…
ankushaggarwal Jun 12, 2021
cbf0bf0
Merge branch 'master' into second-derivs
ankushaggarwal Jan 3, 2023
c18e6d8
Hook fixes using pre-commit
ankushaggarwal Jan 22, 2023
6fa7f0d
Merge branch 'master' into second-derivs
ankushaggarwal Jan 22, 2023
c43c8c8
Added new kernel and means to the documentation
ankushaggarwal Jan 22, 2023
451dfd1
Hook fix
ankushaggarwal Jan 22, 2023
9bf9c3b
Changed from lazy tensor to linear operator as per the newer gpytorch…
ankushaggarwal Jan 22, 2023
0bdfc3e
Changed from utils.broadcasting._mul_broadcast_shape to torch.broadca…
ankushaggarwal Jan 22, 2023
f6a3ce2
Revert "Hook fix"
ankushaggarwal Jan 22, 2023
6e91183
Fixed a minor error (as per the new version of gpytorch)
ankushaggarwal Jan 22, 2023
d473faf
Added the diag=True version of rbf gradgrad kernel
ankushaggarwal Jan 22, 2023
1f7ab49
Fix formatting with pre-commit
ankushaggarwal Jan 23, 2023
b18ecc6
Increased the underline length to pass doc test
ankushaggarwal Jan 23, 2023
e246d2f
Getting rid of the changes to pre-commit-hooks
ankushaggarwal Apr 18, 2023
daeedab
Merge remote-tracking branch 'upstream/master' into second-derivs
ankushaggarwal Apr 18, 2023
b717bfc
Added docstrings and type hints
ankushaggarwal Apr 18, 2023
1623823
Fixed the wrong references in doc of ConstantMeanGradGrad
ankushaggarwal Apr 18, 2023
bd8a9cc
Added unit tests
ankushaggarwal Apr 21, 2023
94819cf
Merge branch 'master' into second-derivs
gpleiss May 26, 2023
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
Collaborator:
Would you mind splitting out the hook update into a separate PR? It's good practice, and it's easier to figure out what went wrong when hook changes are not intermingled with feature additions.

Contributor Author:
@Balandat I am not too familiar with the hook updates. Do you mean the changes in .pre-commit-config.yaml, or the changes pre-commit made to all the files (in commit c18e6d8)?

Member:
I'm not sure why we need to change the pre-commit-hooks file at all. Can we get rid of these changes?

Contributor Author:
Yes, removed it.

- repo: https://github.com/pre-commit/pre-commit-hooks
-rev: v4.2.0
+rev: v4.4.0
hooks:
- id: check-byte-order-marker
- id: check-case-conflict
@@ -12,13 +12,13 @@ repos:
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/pycqa/flake8
-rev: 4.0.1
+rev: 6.0.0
hooks:
- id: flake8
args: [--config=setup.cfg]
exclude: ^(examples/*)|(docs/*)
- repo: https://github.com/ambv/black
-rev: 22.3.0
+rev: 22.12.0
hooks:
- id: black
exclude: ^(build/*)|(docs/*)|(examples/*)
@@ -31,15 +31,15 @@ repos:
exclude: ^(build/*)|(docs/*)|(examples/*)
args: [-w120, -m3, --tc, --project=gpytorch]
- repo: https://github.com/jumanjihouse/pre-commit-hooks
-rev: 2.1.6
+rev: 3.0.0
hooks:
- id: require-ascii
exclude: ^(examples/.*\.ipynb)|(.github/ISSUE_TEMPLATE/*)
- id: script-must-have-extension
- id: forbid-binary
exclude: ^(examples/*)|(test/examples/old_variational_strategy_model.pth)
- repo: https://github.com/Lucas-C/pre-commit-hooks
-rev: v1.1.13
+rev: v1.4.0
hooks:
- id: forbid-crlf
- id: forbid-tabs
6 changes: 6 additions & 0 deletions docs/source/kernels.rst
@@ -171,6 +171,12 @@ Specialty Kernels
.. autoclass:: RBFKernelGrad
:members:

:hidden:`RBFKernelGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RBFKernelGradGrad
:members:


Kernels for Scalable GP Regression Methods
--------------------------------------------
18 changes: 18 additions & 0 deletions docs/source/means.rst
@@ -51,3 +51,21 @@ Specialty Means

.. autoclass:: ConstantMeanGrad
:members:

:hidden:`ConstantMeanGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ConstantMeanGradGrad
:members:

:hidden:`LinearMeanGrad`
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LinearMeanGrad
:members:

:hidden:`LinearMeanGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LinearMeanGradGrad
:members:
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -24,6 +24,7 @@
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
from .rff_kernel import RFFKernel
from .rq_kernel import RQKernel
from .scale_kernel import ScaleKernel
@@ -59,6 +60,7 @@
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
"RBFKernelGradGrad",
"RQKernel",
"ScaleKernel",
"SpectralDeltaKernel",
170 changes: 170 additions & 0 deletions gpytorch/kernels/rbf_kernel_gradgrad.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from .rbf_kernel import RBFKernel, postprocess_rbf


class RBFKernelGradGrad(RBFKernel):
r"""
Computes a covariance matrix of the RBF kernel that models the covariance
between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`.

See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

.. note::

This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

Args:
Member:
Can you update the doc string to be in the standard sphinx format? (See the Kernel base class for an example.)

Contributor Author:
Updated.

:attr:`batch_shape` (torch.Size, optional):
Set this if you want a separate lengthscale for each
batch of input data. It should be `b` if :attr:`x1` is a `b x n x d` tensor. Default: `torch.Size([])`.
:attr:`active_dims` (tuple of ints, optional):
Set this if you want to compute the covariance of only a few input dimensions. The ints
correspond to the indices of the dimensions. Default: `None`.
:attr:`lengthscale_prior` (Prior, optional):
Set this if you want to apply a prior to the lengthscale parameter. Default: `None`.
:attr:`lengthscale_constraint` (Constraint, optional):
Set this if you want to apply a constraint to the lengthscale parameter. Default: `Positive`.
:attr:`eps` (float):
The minimum value that the lengthscale can take (prevents divide by zero errors). Default: `1e-6`.

Attributes:
:attr:`lengthscale` (Tensor):
The lengthscale parameter. Size/shape of parameter depends on the
:attr:`ard_num_dims` and :attr:`batch_shape` arguments.

Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> covar = covar_module(x) # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2])))
>>> covar = covar_module(batch_x) # Output: LinearOperator of size (2 x 110 x 110)
"""

def forward(self, x1, x2, diag=False, **params):
batch_shape = x1.shape[:-2]
n_batch_dims = len(batch_shape)
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

if not diag:
# Scale the inputs by the lengthscale (for stability)
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)

# Form all possible rank-1 products for the gradient and Hessian blocks
outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
outer = outer / self.lengthscale.unsqueeze(-2)
outer = torch.transpose(outer, -1, -2).contiguous()

# 1) Kernel block
diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
K_11 = postprocess_rbf(diff)
K[..., :n1, :n2] = K_11

# 2) First gradient block
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

# 3) Second gradient block
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

# 4) Hessian block
outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
kp = KroneckerProductLinearOperator(
torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
)
chain_rule = kp.to_dense() - outer3
K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

# 5) 1-3 block
douter1dx2 = KroneckerProductLinearOperator(
torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
).to_dense()

K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
[*([1] * (n_batch_dims + 1)), d]
) # verified for n1=n2=1 case
K[..., :n1, (n2 * (d + 1)) :] = K_13

K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat(
[*([1] * n_batch_dims), d, 1]
) # verified for n1=n2=1 case
K[..., (n1 * (d + 1)) :, :n2] = K_31

# rest of the blocks are all of size (n1*d,n2*d)
outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
# II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
kp2 = KroneckerProductLinearOperator(
torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
).to_dense()

# II may not be the correct thing to use. It might be more appropriate to use kp instead??
II = kp.to_dense()
K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd # verified for n1=n2=1 case

K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

K_32 = (
(-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
) * K_11dd # verified for n1=n2=1 case

K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

K_33 = (
(-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * (II) ** 2
) * K_11dd + (
(-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
) * outer1 * K_11dd # verified for n1=n2=1 case

K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33

# Symmetrize for stability
if n1 == n2 and torch.eq(x1, x2).all():
K = 0.5 * (K.transpose(-1, -2) + K)

# Apply a perfect shuffle permutation to match the MultiTask ordering
pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
K = K[..., pi1, :][..., :, pi2]

return K

else:
if not (n1 == n2 and torch.eq(x1, x2).all()):
raise RuntimeError("diag=True only works when x1 == x2")

kernel_diag = super(RBFKernelGradGrad, self).forward(x1, x2, diag=True)
grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
gradgrad_diag = (
3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(4)
)
gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1)
pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
return k_diag[..., pi]

def num_outputs_per_input(self, x1, x2):
return x1.size(-1) * 2 + 1
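For reference, a minimal sketch of how RBFKernelGradGrad is intended to be wired into a full model. This is not part of the diff; it follows the existing RBFKernelGrad derivative-GP pattern (multitask distribution over d + 1 tasks), generalized here to 2*d + 1 tasks, and the model class and variable names are illustrative only.

```python
# Sketch only: a GP over values, first derivatives, and second (non-mixed) derivatives,
# i.e. 2*d + 1 outputs per input point, using the kernel and mean added in this PR.
import torch
import gpytorch


class GPWithSecondDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGradGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)    # shape (..., n, 2*d + 1)
        covar_x = self.covar_module(x)  # shape (..., n*(2*d + 1), n*(2*d + 1)), multitask-ordered
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


d = 2
train_x = torch.rand(20, d)
train_y = torch.randn(20, 2 * d + 1)  # columns: [f, df/dx_1, df/dx_2, d2f/dx_1^2, d2f/dx_2^2]
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2 * d + 1)
model = GPWithSecondDerivatives(train_x, train_y, likelihood)
```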
15 changes: 14 additions & 1 deletion gpytorch/means/__init__.py
@@ -2,9 +2,22 @@

from .constant_mean import ConstantMean
from .constant_mean_grad import ConstantMeanGrad
from .constant_mean_gradgrad import ConstantMeanGradGrad
from .linear_mean import LinearMean
from .linear_mean_grad import LinearMeanGrad
from .linear_mean_gradgrad import LinearMeanGradGrad
from .mean import Mean
from .multitask_mean import MultitaskMean
from .zero_mean import ZeroMean

__all__ = ["Mean", "ConstantMean", "ConstantMeanGrad", "LinearMean", "MultitaskMean", "ZeroMean"]
__all__ = [
"Mean",
"ConstantMean",
"ConstantMeanGrad",
"ConstantMeanGradGrad",
"LinearMean",
"LinearMeanGrad",
"LinearMeanGradGrad",
"MultitaskMean",
"ZeroMean",
]
20 changes: 20 additions & 0 deletions gpytorch/means/constant_mean_gradgrad.py
@@ -0,0 +1,20 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class ConstantMeanGradGrad(Mean):
def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
Member:
Please add a doc string and type hints.

Contributor Author:
Added.

super(ConstantMeanGradGrad, self).__init__()
self.batch_shape = batch_shape
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
if prior is not None:
self.register_prior("mean_prior", prior, "constant")

def forward(self, input):
batch_shape = torch.broadcast_shapes(self.batch_shape, input.shape[:-2])
mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2 * input.size(-1) + 1).contiguous()
mean[..., 1:] = 0
return mean
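An illustrative shape check of what this mean module produces (not from the diff): the first column carries the learned constant, and all 2*d derivative columns are zero.

```python
# Illustrative only: ConstantMeanGradGrad returns one mean column for the value
# plus d first-derivative and d second-derivative columns, which are zero by construction.
import torch
import gpytorch

mean_module = gpytorch.means.ConstantMeanGradGrad()
x = torch.randn(10, 3)  # n = 10 points, d = 3 input dimensions
out = mean_module(x)

print(out.shape)       # torch.Size([10, 7]), i.e. (n, 2*d + 1)
print(out[..., 0])     # the learned constant (0.0 at initialization) for every point
print(out[..., 1:])    # all zeros: the derivative means vanish
```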
23 changes: 23 additions & 0 deletions gpytorch/means/linear_mean_grad.py
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGrad(Mean):
def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
Member:
Doc string and type hints.

Contributor Author:
Added.

super().__init__()
self.dim = input_size
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None

def forward(self, x):
res = x.matmul(self.weights)
if self.bias is not None:
res = res + self.bias.unsqueeze(-1)
dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
return torch.cat((res, dres), -1)
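Since the mean is linear, w^T x + b, its gradient with respect to every input dimension is just w; that is exactly what the dres block encodes. An illustrative check (not from the diff):

```python
# Illustrative only: LinearMeanGrad returns [value, dvalue/dx_1, ..., dvalue/dx_d] per point.
import torch
import gpytorch

mean_module = gpytorch.means.LinearMeanGrad(input_size=3)
x = torch.randn(10, 3)
out = mean_module(x)  # shape (10, 4) == (n, d + 1)

value = x.matmul(mean_module.weights).squeeze(-1) + mean_module.bias
grads = mean_module.weights.squeeze(-1).expand(10, 3)
print(torch.allclose(out[..., 0], value))   # True: first column is w^T x + b
print(torch.allclose(out[..., 1:], grads))  # True: remaining columns are the weights
```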
24 changes: 24 additions & 0 deletions gpytorch/means/linear_mean_gradgrad.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGradGrad(Mean):
def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
Member:
Doc string and type hints.

Contributor Author:
Added.

super().__init__()
self.dim = input_size
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None

def forward(self, x):
res = x.matmul(self.weights)
if self.bias is not None:
res = res + self.bias.unsqueeze(-1)
dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
ddres = torch.zeros_like(dres)
return torch.cat((res, dres, ddres), -1)
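LinearMeanGradGrad extends the same idea: a linear mean has constant gradient w and zero curvature, so the last d columns are identically zero. An illustrative check (not from the diff):

```python
# Illustrative only: LinearMeanGradGrad returns [value, d gradient entries, d zeros] per point.
import torch
import gpytorch

mean_module = gpytorch.means.LinearMeanGradGrad(input_size=3, bias=False)
x = torch.randn(10, 3)
out = mean_module(x)

print(out.shape)  # torch.Size([10, 7]) == (n, 2*d + 1)
print(torch.allclose(out[..., 1:4], mean_module.weights.squeeze(-1).expand(10, 3)))  # True
print(torch.count_nonzero(out[..., 4:]))  # tensor(0): the second-derivative block is all zeros
```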