GP regression with derivative information #462

Merged (25 commits, Jan 27, 2019)

Commits
4dc61b5
Working on GPs with derivative observations.
dme65 Dec 11, 2018
ab65069
WIP
jacobrgardner Dec 11, 2018
7d54fcf
Working gpderiv in 1d
dme65 Jan 4, 2019
6b5188a
Temporarily adding _create_input_grid + fixing diagonal
dme65 Jan 4, 2019
6b02be6
rbf_kernel_grad works in more than 1d + example
dme65 Jan 7, 2019
212e1f2
Reordering operations to get rid of cancellation
dme65 Jan 7, 2019
cf6adac
Adding comments to notebooks
dme65 Jan 10, 2019
556baff
Adding tests + cleaning up batch handling
dme65 Jan 10, 2019
6153d68
Adding comments
dme65 Jan 11, 2019
d6bd723
Docs + linting
dme65 Jan 11, 2019
a0ad6c0
Adding GPU support
dme65 Jan 11, 2019
1ede625
Adding CUDA test
dme65 Jan 11, 2019
11b5c1b
JIT for distance computation, kernels can override with custom JIT sc…
jacobrgardner Jan 14, 2019
a1e8bcc
JIT script internals of Linear CG
jacobrgardner Jan 19, 2019
b33a396
Matern kernel fix
jacobrgardner Jan 22, 2019
80864e1
Code reuse for JIT scripts
jacobrgardner Jan 23, 2019
0ff6416
Kernels now pass scripts for handling postprocessing only
jacobrgardner Jan 23, 2019
8e7d3ed
Call postprocess in diag mode
jacobrgardner Jan 24, 2019
1f88504
Remove double negative variable name
jacobrgardner Jan 24, 2019
6faba6a
Implement Gauss-Hermite quadrature for use with likelihoods
jacobrgardner Jan 23, 2019
195fa8b
remove log_ kwargs deprecation
KeAWang Jan 26, 2019
4108f3b
fix GaussianLikelihood initialize noise bug
KeAWang Jan 26, 2019
fc51fac
update test files to not use log parameters
KeAWang Jan 26, 2019
7f469e4
Adding size method to decorator kernels
dme65 Jan 26, 2019
5edcbaa
Fix squeeze in RBFKernelGrad, update examples
gpleiss Jan 27, 2019
Files changed
6 changes: 6 additions & 0 deletions docs/source/kernels.rst
@@ -127,6 +127,12 @@ Specialty Kernels
.. autoclass:: MultitaskKernel
:members:

:hidden:`RBFKernelGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RBFKernelGrad
:members:


Kernels for Scalable GP Regression Methods
--------------------------------------------
6 changes: 6 additions & 0 deletions docs/source/means.rst
@@ -39,3 +39,9 @@ Specialty Means

.. autoclass:: MultitaskMean
:members:

:hidden:`ConstantMeanGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ConstantMeanGrad
:members:
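
For context on the new mean documented above: a constant mean extended to derivative outputs gives the function value a learned constant prior mean and, because the derivative of a constant is zero, gives each partial derivative a zero prior mean. A hypothetical sketch of that behavior (not the actual `ConstantMeanGrad` implementation, which lives in `gpytorch.means`):

```python
# Hypothetical sketch of a constant mean with derivative outputs; the real
# ConstantMeanGrad in gpytorch.means may differ in details.
import torch
from torch import nn


class ConstantMeanGradSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.constant = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, d = x.shape[-2], x.shape[-1]
        mean = x.new_zeros(n, d + 1)   # one column per output: f, df/dx_1, ..., df/dx_d
        mean[..., 0] = self.constant   # constant prior mean for the function value
        return mean                    # derivative columns stay zero


print(ConstantMeanGradSketch()(torch.randn(4, 2)).shape)  # torch.Size([4, 3])
```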
6 changes: 6 additions & 0 deletions docs/source/utils.rst
@@ -25,6 +25,12 @@ Pivoted Cholesky Utilities
.. automodule:: gpytorch.utils.pivoted_cholesky
:members:

Quadrature Utilities
~~~~~~~~~~~~~~~~~~~~

.. automodule:: gpytorch.utils.quadrature
:members:

Sparse Utilities
~~~~~~~~~~~~~~~~~

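For context on the quadrature module documented above (added by the "Implement Gauss-Hermite quadrature for use with likelihoods" commit): Gauss-Hermite quadrature approximates a Gaussian expectation E[f(x)], x ~ N(mu, sigma^2), as a weighted sum over fixed nodes after the change of variables x = sqrt(2) * sigma * z + mu. The sketch below illustrates the idea only and is not the `gpytorch.utils.quadrature` API:

```python
# Sketch of Gauss-Hermite quadrature for a Gaussian expectation E[f(x)],
# x ~ N(mu, sigma^2). Illustrative only; not gpytorch.utils.quadrature itself.
import math
import torch
from numpy.polynomial.hermite import hermgauss


def gauss_hermite_expectation(f, mu, sigma, num_locs=20):
    # Nodes/weights for integrals of the form  integral exp(-z^2) g(z) dz
    locs, weights = hermgauss(num_locs)
    locs = torch.as_tensor(locs, dtype=mu.dtype)
    weights = torch.as_tensor(weights, dtype=mu.dtype)
    # Change of variables x = sqrt(2) * sigma * z + mu turns the Gaussian
    # expectation into the Hermite weight form, up to a 1/sqrt(pi) factor.
    x = math.sqrt(2.0) * sigma * locs + mu
    return (weights * f(x)).sum() / math.sqrt(math.pi)


# Example: E[x^2] for x ~ N(0, 1) should be 1.
mu, sigma = torch.tensor(0.0), torch.tensor(1.0)
print(gauss_hermite_expectation(lambda x: x ** 2, mu, sigma))  # ~1.0
```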
3 changes: 3 additions & 0 deletions examples/10_GP_Regression_Derivative_Information/README.md
@@ -0,0 +1,3 @@
# GP regression with derivative information

This is a new feature, and is likely unstable. Enjoy!

(The two example notebooks, Simple_GP_Regression_Derivative_Information_1d.ipynb and Simple_GP_Regression_Derivative_Information_2d.ipynb, are large diffs and are not rendered here.)
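
As a minimal sketch of the kind of model those notebooks build (a sketch only; the exact model and likelihood choices below are assumptions, not copied from the notebooks), a GP with derivative observations combines the new `ConstantMeanGrad` and `RBFKernelGrad` and treats the function value and each partial derivative as correlated outputs:

```python
# Minimal, hypothetical sketch of a 1d GP with derivative observations.
# Model/likelihood choices here are assumptions, not copied from the notebooks.
import torch
import gpytorch


class GPWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)    # mean over [f(x), df/dx]
        covar_x = self.covar_module(x)  # joint covariance over values and derivatives
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


# Each training point supplies two "tasks": the function value and its derivative.
train_x = torch.linspace(0, 1, 20)
train_y = torch.stack([torch.sin(2 * train_x), 2 * torch.cos(2 * train_x)], dim=-1)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = GPWithDerivatives(train_x, train_y, likelihood)
```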

9 changes: 9 additions & 0 deletions examples/10_GP_Regression_Derivative_Information/index.rst
@@ -0,0 +1,9 @@
.. mdinclude:: README.md

.. toctree::
:glob:
:maxdepth: 1
:hidden:

Simple_GP_Regression_Derivative_Information_1d.ipynb
Simple_GP_Regression_Derivative_Information_2d.ipynb
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -15,6 +15,7 @@
from .periodic_kernel import PeriodicKernel
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .scale_kernel import ScaleKernel
from .spectral_mixture_kernel import SpectralMixtureKernel
from .white_noise_kernel import WhiteNoiseKernel
@@ -38,6 +39,7 @@
"ProductKernel",
"ProductStructureKernel",
"RBFKernel",
"RBFKernelGrad",
"ScaleKernel",
"SpectralMixtureKernel",
"WhiteNoiseKernel",
3 changes: 3 additions & 0 deletions gpytorch/kernels/additive_structure_kernel.py
@@ -54,3 +54,6 @@ def forward(self, x1, x2, batch_dims=None, **params):
if evaluate:
res = res.evaluate()
return res

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
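
These `size` delegations (here and in the other decorator kernels below, per the "Adding size method to decorator kernels" commit) let wrapper kernels report the base kernel's covariance shape. That matters for `RBFKernelGrad`, whose covariance over n points in d dimensions has n * (d + 1) rows and columns rather than n. A rough illustration of the expected shape, hedged as assumed behavior:

```python
# Rough illustration (assumed behavior): RBFKernelGrad covers the function value
# plus d partial derivatives per point, so its covariance is n * (d + 1) on a side.
import torch
import gpytorch

n, d = 5, 2
x = torch.randn(n, d)
base = gpytorch.kernels.RBFKernelGrad()
print(base.size(x, x))  # expected: n * (d + 1) = 15 rows and columns
# The size() delegations added in this PR make wrapper kernels report the
# same shape as their base kernel instead of assuming an n x n covariance.
```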
4 changes: 0 additions & 4 deletions gpytorch/kernels/cosine_kernel.py
@@ -3,7 +3,6 @@
import math
import torch
from .kernel import Kernel
from ..utils.deprecation import _deprecate_kwarg
from torch.nn.functional import softplus


@@ -66,9 +65,6 @@ def __init__(
inv_param_transform=None,
**kwargs
):
period_length_prior = _deprecate_kwarg(
kwargs, "log_period_length_prior", "period_length_prior", period_length_prior
)
super(CosineKernel, self).__init__(
active_dims=active_dims, param_transform=param_transform, inv_param_transform=inv_param_transform
)
3 changes: 3 additions & 0 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -180,3 +180,6 @@ def forward(self, x1, x2, batch_dims=None, **params):
)

return res

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
3 changes: 3 additions & 0 deletions gpytorch/kernels/grid_kernel.py
@@ -98,3 +98,6 @@ def forward(self, x1, x2, diag=False, batch_dims=None, **params):
return covar
else:
return self.base_kernel.forward(x1, x2, diag=diag, batch_dims=batch_dims, **params)

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
3 changes: 3 additions & 0 deletions gpytorch/kernels/inducing_point_kernel.py
@@ -98,3 +98,6 @@ def forward(self, x1, x2, **kwargs):
self.update_added_loss_term("inducing_point_loss_term", new_added_loss_term)

return covar

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
119 changes: 65 additions & 54 deletions gpytorch/kernels/kernel.py
@@ -6,11 +6,47 @@
from ..lazy import lazify, LazyEvaluatedKernelTensor, ZeroLazyTensor
from ..module import Module
from .. import settings
from ..utils.deprecation import _deprecate_kwarg
from ..utils.transforms import _get_inv_param_transform
from torch.nn.functional import softplus
from numpy import triu_indices
import warnings


@torch.jit.script
def default_postprocess_script(x):
return x


class Distance(torch.jit.ScriptModule):
def __init__(self, postprocess_script=default_postprocess_script):
super().__init__()
self._postprocess = postprocess_script

@torch.jit.script_method
def _jit_sq_dist(self, x1, x2, diag, x1_eq_x2):
# Compute squared distance matrix using quadratic expansion
x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
if bool(x1_eq_x2):
x2_norm = x1_norm
else:
x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)

res = x1.matmul(x2.transpose(-2, -1))
res = res.mul_(-2).add_(x2_norm.transpose(-2, -1)).add_(x1_norm)

if bool(x1_eq_x2):
# Ensure zero diagonal
res.diagonal(dim1=-2, dim2=-1).fill_(0)

# Zero out negative values
res.clamp_min_(0)

return self._postprocess(res)

@torch.jit.script_method
def _jit_dist(self, x1, x2, diag, x1_eq_x2):
res = self._jit_sq_dist(x1, x2, diag, x1_eq_x2)
res = res.clamp_min_(1e-30).sqrt_()
return self._postprocess(res)


class Kernel(Module):
@@ -101,7 +137,6 @@ def __init__(
eps=1e-6,
**kwargs
):
lengthscale_prior = _deprecate_kwarg(kwargs, "log_lengthscale_prior", "lengthscale_prior", lengthscale_prior)
super(Kernel, self).__init__()
if active_dims is not None and not torch.is_tensor(active_dims):
active_dims = torch.tensor(active_dims, dtype=torch.long)
@@ -193,29 +228,20 @@ def __pdist_dist(self, x1):

return res

def __slow_sq_dist(self, x1, x2, diag, x1_eq_x2):
# Compute squared distance matrix using quadratic expansion
x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
@torch.jit.script_method
def _postprocess(self, dist_mat):
return dist_mat

if diag:
mid = (x1 * x2).sum(dim=-1, keepdim=True)
res = (x1_norm - 2 * mid + x2_norm).squeeze(-1)
else:
mid = x1.matmul(x2.transpose(-2, -1))
res = x1_norm - 2 * mid + x2_norm.transpose(-2, -1)

if x1_eq_x2:
# Ensure zero diagonal
diag_inds = torch.arange(x1.shape[-2])
res[..., diag_inds, diag_inds] = 0

# Zero out negative values
res.clamp_min_(0)

return res

def _covar_dist(self, x1, x2, diag=False, batch_dims=None, square_dist=False, **params):
def _covar_dist(
self,
x1,
x2,
diag=False,
batch_dims=None,
square_dist=False,
postprocess_func=default_postprocess_script,
**params
):
"""
This is a helper method for computing the Euclidean distance between
all pairs of points in x1 and x2.
@@ -248,6 +274,8 @@ def _covar_dist(self, x1, x2, diag=False, batch_dims=None, square_dist=False, **

res = None

distance_module = Distance(postprocess_func)

if diag:
# Special case the diagonal because we can return all zeros most of the time.
if x1_eq_x2:
@@ -256,41 +284,18 @@ def _covar_dist(self, x1, x2, diag=False, batch_dims=None, square_dist=False, **
res = torch.zeros(x1.shape[0] * x1.shape[-1], x2.shape[-2], dtype=x1.dtype, device=x1.device)
else:
res = torch.zeros(x1.shape[0], x2.shape[-2], dtype=x1.dtype, device=x1.device)
return res
return postprocess_func(res)
else:
if square_dist:
res = (x1 - x2).pow(2).sum(-1)
else:
res = torch.norm(x1 - x2, p=2, dim=-1)

# TODO: Remove the size check when pytorch/15511 is fixed.
elif x1.size(-2) < 200 and x1_eq_x2:
# Full distance matrix in the square symmetric case
if x1.dim() == 3 and x1.shape[0] == 1:
# If we aren't in batch mode, we can always use torch.pdist
res = self.__pdist_dist(x1.squeeze(0)).unsqueeze(0)
res = res.pow(2) if square_dist else res
elif self.__pdist_supports_batch:
# torch.pdist still works on the latest pytorch-nightly
# TODO: This else branch is not needed on the next PyTorch release (> 1.0.0).
try:
res = self.__pdist_dist(x1)
res = res.pow(2) if square_dist else res
except RuntimeError as e:
if 'pdist only supports 2D tensors, got:' in str(e):
warnings.warn('You are using a version of PyTorch where torch.pdist does not support batch '
'matrices. Falling back on manual distance computation. Updating PyTorch to the '
'latest pytorch-nightly build will offer significant memory savings during kernel'
' computation.')
self.__pdist_supports_batch = False
else:
raise e

if res is None:
if not square_dist:
res = torch.norm(x1.unsqueeze(-2) - x2.unsqueeze(-3), p=2, dim=-1)
else:
res = self.__slow_sq_dist(x1, x2, diag, x1_eq_x2)
res = postprocess_func(res)
elif not square_dist:
res = distance_module._jit_dist(x1, x2, torch.tensor(diag), torch.tensor(x1_eq_x2))
else:
res = distance_module._jit_sq_dist(x1, x2, torch.tensor(diag), torch.tensor(x1_eq_x2))

if batch_dims == (0, 2):
if diag:
@@ -432,6 +437,9 @@ def forward(self, x1, x2, **params):
res = res + lazify(next_term)
return res

def size(self, x1, x2):
return self.kernels[0].size(x1, x2)


class ProductKernel(Kernel):
"""
@@ -453,3 +461,6 @@ def forward(self, x1, x2, **params):
next_term = kern(x1, x2, **params)
res = res * lazify(next_term)
return res

def size(self, x1, x2):
return self.kernels[0].size(x1, x2)
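
For context on the `Distance` script module introduced in this file: it computes pairwise squared distances with the quadratic expansion ||x1 - x2||^2 = ||x1||^2 - 2 * x1.x2 + ||x2||^2, which reduces the whole matrix to one matmul plus two norms. A small standalone check of that identity (plain PyTorch, not the JIT module itself):

```python
# Standalone check of the quadratic-expansion distance trick used by the
# Distance script module above (plain PyTorch, not the JIT code itself).
import torch

x1 = torch.randn(4, 3)
x2 = torch.randn(5, 3)

# Quadratic expansion: ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)        # (4, 1)
x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)        # (5, 1)
sq_dist = x1 @ x2.transpose(-2, -1)                  # (4, 5) of dot products
sq_dist = sq_dist.mul(-2).add(x2_norm.transpose(-2, -1)).add(x1_norm)
sq_dist = sq_dist.clamp_min(0)                       # guard against round-off

# Reference: direct pairwise differences
reference = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
print(torch.allclose(sq_dist, reference, atol=1e-5))  # True
```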
2 changes: 0 additions & 2 deletions gpytorch/kernels/matern_kernel.py
@@ -3,7 +3,6 @@
import math
import torch
from .kernel import Kernel
from ..utils.deprecation import _deprecate_kwarg
from torch.nn.functional import softplus


@@ -92,7 +91,6 @@ def __init__(
eps=1e-6,
**kwargs
):
_deprecate_kwarg(kwargs, "log_lengthscale_prior", "lengthscale_prior", lengthscale_prior)
if nu not in {0.5, 1.5, 2.5}:
raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
super(MaternKernel, self).__init__(
3 changes: 3 additions & 0 deletions gpytorch/kernels/multi_device_kernel.py
@@ -60,3 +60,6 @@ def forward(self, x1, x2, diag=False, **kwargs):

def gather(self, outputs, output_device):
return CatLazyTensor(*[lazify(o) for o in outputs], dim=self.dim, output_device=self.output_device)

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
5 changes: 0 additions & 5 deletions gpytorch/kernels/periodic_kernel.py
@@ -3,7 +3,6 @@
import math
import torch
from .kernel import Kernel
from ..utils.deprecation import _deprecate_kwarg
from torch.nn.functional import softplus


@@ -83,10 +82,6 @@ def __init__(
eps=1e-6,
**kwargs
):
lengthscale_prior = _deprecate_kwarg(kwargs, "log_lengthscale_prior", "lengthscale_prior", lengthscale_prior)
period_length_prior = _deprecate_kwarg(
kwargs, "log_period_length_prior", "period_length_prior", period_length_prior
)
super(PeriodicKernel, self).__init__(
has_lengthscale=True,
active_dims=active_dims,
3 changes: 3 additions & 0 deletions gpytorch/kernels/product_structure_kernel.py
@@ -77,3 +77,6 @@ def __call__(self, x1_, x2_=None, diag=False, batch_dims=None, **params):
.__call__(x1_, x2_, diag=diag, batch_dims=batch_dims, **params)
.evaluate_kernel()
)

def size(self, x1, x2):
return self.base_kernel.size(x1, x2)
14 changes: 9 additions & 5 deletions gpytorch/kernels/rbf_kernel.py
@@ -1,8 +1,13 @@
#!/usr/bin/env python3

from .kernel import Kernel
from ..utils.deprecation import _deprecate_kwarg
from torch.nn.functional import softplus
import torch


@torch.jit.script
def postprocess_rbf(dist_mat):
return dist_mat.div_(-2).exp_()


class RBFKernel(Kernel):
@@ -77,7 +82,6 @@ def __init__(
eps=1e-6,
**kwargs
):
_deprecate_kwarg(kwargs, "log_lengthscale_prior", "lengthscale_prior", lengthscale_prior)
super(RBFKernel, self).__init__(
has_lengthscale=True,
ard_num_dims=ard_num_dims,
@@ -89,8 +93,8 @@ def __init__(
eps=eps,
)

def forward(self, x1, x2, **params):
def forward(self, x1, x2, diag=False, **params):
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)
diff = self._covar_dist(x1_, x2_, square_dist=True, **params)
return diff.div_(-2).exp_()
diff = self._covar_dist(x1_, x2_, square_dist=True, diag=diag, postprocess_func=postprocess_rbf, **params)
return diff
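
With this change `RBFKernel.forward` scales the inputs by the lengthscale and lets `postprocess_rbf` turn the squared distance into exp(-||x1 - x2||^2 / (2 * lengthscale^2)) inside `_covar_dist`, instead of exponentiating afterwards. A quick numeric check of that formula (plain tensors, independent of the kernel class):

```python
# Quick check that dividing inputs by the lengthscale and then applying the
# postprocess_rbf step (divide by -2, exponentiate) reproduces
# exp(-||x1 - x2||^2 / (2 * lengthscale^2)).
import torch

lengthscale = 0.7
x1 = torch.randn(6, 2)
x2 = torch.randn(6, 2)

# Elementwise (diagonal) version: distance between corresponding rows.
sq_dist = (x1 / lengthscale - x2 / lengthscale).pow(2).sum(-1)
kernel_via_postprocess = sq_dist.div(-2).exp()

direct = torch.exp(-(x1 - x2).pow(2).sum(-1) / (2 * lengthscale ** 2))
print(torch.allclose(kernel_via_postprocess, direct))  # True
```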