Skip to content

Commit

Permalink
Grid bounds can be automatically determined for KISS-GP regression
Browse files Browse the repository at this point in the history
  • Loading branch information
gpleiss committed Sep 18, 2018
1 parent d19a04d commit bd28a67
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 42 deletions.
81 changes: 71 additions & 10 deletions gpytorch/kernels/additive_grid_interpolation_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,91 @@
from __future__ import print_function
from __future__ import unicode_literals

import warnings
from .grid_interpolation_kernel import GridInterpolationKernel
from ..utils import Interpolation


class AdditiveGridInterpolationKernel(GridInterpolationKernel):
    r"""
    A variant of :class:`~gpytorch.kernels.GridInterpolationKernel` designed specifically
    for additive kernels. If a kernel decomposes additively, then this module will be much more
    computationally efficient.

    A kernel function `k` decomposes additively if it can be written as

    .. math::
        \begin{equation*}
            k(\mathbf{x_1}, \mathbf{x_2}) = k'(x_1^{(1)}, x_2^{(1)}) + \ldots + k'(x_1^{(d)}, x_2^{(d)})
        \end{equation*}

    for some kernel :math:`k'` that operates on a subset of dimensions.

    The groupings of dimensions are specified by the :attr:`dim_groups` attribute.

    * `dim_groups=d` (d is the dimensionality of :math:`\mathbf x`): the kernel
      :math:`k` will be the sum of `d` sub-kernels, each operating on one dimension of :math:`\mathbf x`.
    * `dim_groups=d/2`: the first sub-kernel operates on dimensions 1 and 2, the second sub-kernel
      operates on dimensions 3 and 4, etc.
    * `dim_groups=1`: there is no additive decomposition

    .. note::
        `AdditiveGridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
        Periodic, Spectral Mixture, etc.)

    Args:
        :attr:`base_kernel_module` (Kernel):
            The kernel to approximate with KISS-GP
        :attr:`grid_size` (int):
            The size of the grid (in each dimension)
        :attr:`num_dims` (int):
            The dimension of the input data. Required if `grid_bounds=None`
        :attr:`dim_groups` (int):
            The number of additive components
        :attr:`grid_bounds` (tuple(float, float), optional):
            The bounds of the grid, if known (high performance mode).
            The length of the tuple must match the size of the dim group (num_dims // dim_groups).
            The entries represent the min/max values for each dimension.
        :attr:`active_dims` (tuple of ints, optional):
            Passed down to the `base_kernel_module`.
        :attr:`n_components` (int, optional):
            Deprecated alias of :attr:`dim_groups`.
    """

    def __init__(
        self,
        base_kernel_module,
        grid_size,
        dim_groups=None,
        num_dims=None,
        grid_bounds=None,
        active_dims=None,
        n_components=None,
    ):
        # `n_components` is the old name of `dim_groups`; keep accepting it, but warn.
        if n_components is not None:
            warnings.warn("n_components is deprecated. Use dim_groups instead.", DeprecationWarning)
            dim_groups = n_components
        if dim_groups is None:
            raise RuntimeError("Must supply dim_groups")

        # The shared grid only spans the dimensions of a single group. `num_dims` may
        # legitimately be None when `grid_bounds` is given; in that case pass None through
        # and let the parent class derive the group size from len(grid_bounds).
        dims_per_group = None if num_dims is None else num_dims // dim_groups

        super(AdditiveGridInterpolationKernel, self).__init__(
            base_kernel_module, grid_size, dims_per_group, grid_bounds, active_dims=active_dims
        )
        self.dim_groups = dim_groups

    def _compute_grid(self, inputs):
        """
        Compute interpolation indices/values of `inputs` against the shared grid, one group at a time.

        The last dimension of `inputs` is split into `self.dim_groups` equally sized groups, and the
        groups are moved into the leading dimension, so both returned tensors are viewed as
        `(dim_groups * batch_size) x n_data x -1`.
        """
        # Split the feature dimension into (dim_groups, dims-per-group)
        inputs = inputs.view(inputs.size(0), inputs.size(1), self.dim_groups, -1)
        batch_size, n_data, dim_groups, n_dimensions = inputs.size()
        # Swap the batch and group axes, then flatten everything but the per-group features
        inputs = inputs.transpose(0, 2).contiguous().view(dim_groups * batch_size * n_data, n_dimensions)
        interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
        interp_indices = interp_indices.view(dim_groups * batch_size, n_data, -1)
        interp_values = interp_values.view(dim_groups * batch_size, n_data, -1)
        return interp_indices, interp_values

    def _inducing_forward(self):
        """Repeat the parent's inducing-point result `dim_groups` times along the first dimension."""
        res = super(AdditiveGridInterpolationKernel, self)._inducing_forward()
        return res.repeat(self.dim_groups, 1, 1)

    def forward(self, x1, x2):
        """Evaluate the parent kernel, then sum the result over the `dim_groups` batch entries."""
        res = super(AdditiveGridInterpolationKernel, self).forward(x1, x2)
        return res.sum_batch(sum_batch_size=self.dim_groups)
146 changes: 133 additions & 13 deletions gpytorch/kernels/grid_interpolation_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,117 @@


class GridInterpolationKernel(GridKernel):
def __init__(self, base_kernel_module, grid_size, grid_bounds, active_dims=None):
grid = torch.zeros(len(grid_bounds), grid_size)
for i in range(len(grid_bounds)):
grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
grid[i] = torch.linspace(grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size)
r"""
Implements the KISS-GP (or SKI) approximation for a given kernel.
It was proposed in `Kernel Interpolation for Scalable Structured Gaussian Processes`_,
and offers extremely fast and accurate Kernel approximations for large datasets.
inducing_points = torch.zeros(int(pow(grid_size, len(grid_bounds))), len(grid_bounds))
prev_points = None
for i in range(len(grid_bounds)):
for j in range(grid_size):
inducing_points[j * grid_size ** i : (j + 1) * grid_size ** i, i].fill_(grid[i, j])
if prev_points is not None:
inducing_points[j * grid_size ** i : (j + 1) * grid_size ** i, :i].copy_(prev_points)
prev_points = inducing_points[: grid_size ** (i + 1), : (i + 1)]
Given a base kernel `k`, the covariance :math:`k(\mathbf{x_1}, \mathbf{x_2})` is approximated by
using a grid of regularly spaced *inducing points*:
.. math::
\begin{equation*}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
\end{equation*}
where
* :math:`U` is the set of gridded inducing points
* :math:`K_{U,U}` is the kernel matrix between the inducing points
* :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
:math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
The user should supply the size of the grid (using the :attr:`grid_size` attribute).
To choose a reasonable grid value, we highly recommend using the
:func:`gpytorch.utils.choose_grid_size` helper function.
The bounds of the grid will automatically be determined by data.
(Alternatively, you can hard-code bounds using the :attr:`grid_bounds`, which
will speed up this kernel's computations.)
.. note::
`GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
Periodic, Spectral Mixture, etc.)
Args:
:attr:`base_kernel_module` (Kernel):
The kernel to approximate with KISS-GP
:attr:`grid_size` (int):
The size of the grid (in each dimension)
:attr:`num_dims` (int):
The dimension of the input data. Required if `grid_bounds=None`
:attr:`grid_bounds` (tuple(float, float), optional):
The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.
:attr:`active_dims` (tuple of ints, optional):
Passed down to the `base_kernel_module`.
.. _Kernel Interpolation for Scalable Structured Gaussian Processes:
http://proceedings.mlr.press/v37/wilson15.pdf
"""

def __init__(self, base_kernel_module, grid_size, num_dims=None, grid_bounds=None, active_dims=None):
    """
    Build the initial grid and inducing points, then initialize the parent grid kernel.

    If `grid_bounds` is given, the grid is fixed and `num_dims` (when also given) must equal
    `len(grid_bounds)`. Otherwise `num_dims` is required, placeholder bounds of (-1, 1) are
    used, and the kernel is flagged as dynamic so the grid can be refit in `forward`.

    Raises:
        RuntimeError: if neither `num_dims` nor `grid_bounds` is supplied, or if they disagree.
    """
    bounds_supplied = grid_bounds is not None

    if bounds_supplied:
        n_bounds = len(grid_bounds)
        if num_dims is None:
            num_dims = n_bounds
        elif num_dims != n_bounds:
            raise RuntimeError(
                "num_dims ({}) disagrees with the number of supplied "
                "grid_bounds ({})".format(num_dims, n_bounds)
            )
    elif num_dims is None:
        raise RuntimeError("num_dims must be supplied if grid_bounds is None")
    else:
        # Placeholder bounds only; they are replaced once data has been seen
        grid_bounds = tuple((-1., 1.) for _ in range(num_dims))

    # These attributes must exist before _create_grid() is called
    self.grid_is_dynamic = not bounds_supplied
    self.num_dims = num_dims
    self.grid_size = grid_size
    self.grid_bounds = grid_bounds
    inducing_points, grid = self._create_grid()

    super(GridInterpolationKernel, self).__init__(
        base_kernel_module=base_kernel_module, inducing_points=inducing_points, grid=grid, active_dims=active_dims
    )
    # 1 when the grid came from user-supplied bounds, 0 while it still holds the placeholder
    initialized_flag = torch.tensor(int(bounds_supplied), dtype=torch.uint8)
    self.register_buffer("has_initialized_grid", initialized_flag)

def _create_grid(self):
grid = torch.zeros(len(self.grid_bounds), self.grid_size)
for i in range(len(self.grid_bounds)):
grid_diff = float(self.grid_bounds[i][1] - self.grid_bounds[i][0]) / (self.grid_size - 2)
grid[i] = torch.linspace(
self.grid_bounds[i][0] - grid_diff, self.grid_bounds[i][1] + grid_diff, self.grid_size
)

inducing_points = torch.zeros(int(pow(self.grid_size, len(self.grid_bounds))), len(self.grid_bounds))
prev_points = None
for i in range(len(self.grid_bounds)):
for j in range(self.grid_size):
inducing_points[j * self.grid_size ** i : (j + 1) * self.grid_size ** i, i].fill_(grid[i, j])
if prev_points is not None:
inducing_points[j * self.grid_size ** i : (j + 1) * self.grid_size ** i, :i].copy_(prev_points)
prev_points = inducing_points[: self.grid_size ** (i + 1), : (i + 1)]

return inducing_points, grid

@property
def _tight_grid_bounds(self):
grid_spacings = tuple((bound[1] - bound[0]) / self.grid_size for bound in self.grid_bounds)
return tuple(
(bound[0] + 2.01 * spacing, bound[1] - 2.01 * spacing)
for bound, spacing in zip(self.grid_bounds, grid_spacings)
)

@property
def has_custom_exact_predictions(self):
Expand All @@ -49,6 +142,33 @@ def forward_diag(self, x1, x2, **kwargs):
return super(Kernel, self).__call__(x1, x2, **kwargs).diag().unsqueeze(-1)

def forward(self, x1, x2, **kwargs):
# See if we need to update the grid or not
if self.grid_is_dynamic: # This is true if a grid_bounds wasn't passed in
if torch.equal(x1, x2):
x = x1.view(-1, self.num_dims)
else:
x = torch.cat([x1.view(-1, self.num_dims), x2.view(-1, self.num_dims)])
x_maxs = x.max(0)[0].tolist()
x_mins = x.min(0)[0].tolist()

# We need to update the grid if
# 1) it hasn't ever been initialized, or
# 2) if any of the grid points are "out of bounds"
update_grid = (not self.has_initialized_grid.item()) or any(
x_min < bound[0] or x_max > bound[1]
for x_min, x_max, bound in zip(x_mins, x_maxs, self._tight_grid_bounds)
)

# Update the grid if needed
if update_grid:
grid_spacings = tuple((x_max - x_min) / (self.grid_size - 4.02) for x_min, x_max in zip(x_mins, x_maxs))
self.grid_bounds = tuple(
(x_min - 2.01 * spacing, x_max + 2.01 * spacing)
for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
)
inducing_points, grid = self._create_grid()
self.update_inducing_points_and_grid(inducing_points, grid)

base_lazy_tsr = self._inducing_forward()
if x1.size(0) > 1:
base_lazy_tsr = base_lazy_tsr.repeat(x1.size(0), 1, 1)
Expand Down
42 changes: 42 additions & 0 deletions gpytorch/kernels/grid_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,38 @@


class GridKernel(Kernel):
r"""
If the input data :math:`X` are regularly spaced on a grid, then
`GridKernel` can dramatically speed up computations for stationary kernels.
GridKernel exploits Toeplitz and Kronecker structure within the covariance matrix.
See `Fast kernel learning for multidimensional pattern extrapolation`_ for more info.
`GridKernel` is also the base class of `GridInterpolationKernel`, which
implements the KISS-GP (or SKI) approximation for a given kernel.
KISS-GP was proposed in *Kernel Interpolation for Scalable Structured Gaussian Processes*
(http://proceedings.mlr.press/v37/wilson15.pdf), and offers extremely fast and accurate kernel
approximations for large datasets. Given a base kernel `k`, the covariance
:math:`k(\mathbf{x_1}, \mathbf{x_2})` is approximated by using a grid of regularly spaced *inducing points*.
.. note::
`GridKernel` can only wrap **stationary kernels** (such as RBF, Matern,
Periodic, Spectral Mixture, etc.)
Args:
:attr:`base_kernel_module` (Kernel):
The kernel to speed up with grid methods.
:attr:`inducing_points` (Tensor, n x d):
This will be the set of points that lie on the grid.
:attr:`grid` (Tensor, k x d):
The exact grid points.
:attr:`active_dims` (tuple of ints, optional):
Passed down to the `base_kernel_module`.
.. _Fast kernel learning for multidimensional pattern extrapolation:
http://www.cs.cmu.edu/~andrewgw/manet.pdf
"""

def __init__(self, base_kernel_module, inducing_points, grid, active_dims=None):
super(GridKernel, self).__init__(active_dims=active_dims)
self.base_kernel_module = base_kernel_module
Expand All @@ -23,6 +55,16 @@ def train(self, mode=True):
del self._cached_kernel_mat
return super(GridKernel, self).train(mode)

def update_inducing_points_and_grid(self, inducing_points, grid):
    """
    Overwrite the stored inducing points and grid in place with new values.

    Any `_cached_kernel_mat` attribute is dropped so it is not reused with the old grid.

    Returns:
        self, to allow chaining.
    """
    replacements = ((self.inducing_points, inducing_points), (self.grid, grid))
    for stored, new_values in replacements:
        # Work on a detached handle so the in-place resize/copy is not tracked by autograd
        # NOTE(review): presumably the new tensors always match the stored sizes — confirm
        stored.detach().resize_(new_values.size()).copy_(new_values)

    if hasattr(self, "_cached_kernel_mat"):
        del self._cached_kernel_mat
    return self

def forward(self, x1, x2, **kwargs):
if not torch.equal(x1, self.inducing_points) or not torch.equal(x2, self.inducing_points):
raise RuntimeError("The kernel should only receive the inducing points as input")
Expand Down

0 comments on commit bd28a67

Please sign in to comment.