Add support for ARD. Simplify RBF Kernel
Adds support for ARD at the level of the Kernel class.
Also simplifies the RBF kernel and makes it compatible with multiple
lengthscales.
Balandat committed Mar 18, 2018
1 parent 2cdb3d3 commit 6563425
Showing 10 changed files with 167 additions and 70 deletions.
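In rough terms, ARD (automatic relevance determination) gives each input dimension its own lengthscale rather than a single shared value. A minimal usage sketch against the API introduced in this commit (ard_num_dims and the lengthscale property come from the diff below; the surrounding code is illustrative):

from gpytorch.kernels import RBFKernel

# Isotropic kernel: a single lengthscale shared across all input dimensions.
iso_kernel = RBFKernel()

# ARD kernel: one lengthscale per input dimension (here, 3 dimensions).
ard_kernel = RBFKernel(ard_num_dims=3)

# The base Kernel class now registers log_lengthscale with shape
# (1, 1, ard_num_dims), so each dimension is scaled independently.
print(ard_kernel.lengthscale.shape)  # expected: torch.Size([1, 1, 3])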
55 changes: 41 additions & 14 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -11,41 +11,63 @@


class GridInterpolationKernel(GridKernel):
def __init__(self, base_kernel_module, grid_size, grid_bounds, active_dims=None):

def __init__(
self,
base_kernel_module,
grid_size,
grid_bounds,
active_dims=None,
):
grid = torch.zeros(len(grid_bounds), grid_size)
for i in range(len(grid_bounds)):
grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
grid[i] = torch.linspace(grid_bounds[i][0] - grid_diff,
grid_bounds[i][1] + grid_diff,
grid_size)
grid[i] = torch.linspace(
grid_bounds[i][0] - grid_diff,
grid_bounds[i][1] + grid_diff,
grid_size,
)

inducing_points = torch.zeros(int(pow(grid_size, len(grid_bounds))), len(grid_bounds))
inducing_points = torch.zeros(
int(pow(grid_size, len(grid_bounds))),
len(grid_bounds),
)
prev_points = None
for i in range(len(grid_bounds)):
for j in range(grid_size):
inducing_points[j * grid_size ** i:(j + 1) * grid_size ** i, i].fill_(grid[i, j])
inducing_points[
j * grid_size ** i:(j + 1) * grid_size ** i, i
].fill_(grid[i, j])
if prev_points is not None:
inducing_points[j * grid_size ** i:(j + 1) * grid_size ** i, :i].copy_(prev_points)
inducing_points[
j * grid_size ** i:(j + 1) * grid_size ** i, :i
].copy_(prev_points)
prev_points = inducing_points[:grid_size ** (i + 1), :(i + 1)]

super(GridInterpolationKernel, self).__init__(
base_kernel_module,
inducing_points,
grid,
base_kernel_module=base_kernel_module,
inducing_points=inducing_points,
grid=grid,
active_dims=active_dims,
)

def _compute_grid(self, inputs):
batch_size, n_data, n_dimensions = inputs.size()
inputs = inputs.view(batch_size * n_data, n_dimensions)
interp_indices, interp_values = Interpolation().interpolate(Variable(self.grid), inputs)
interp_indices, interp_values = Interpolation().interpolate(
Variable(self.grid),
inputs,
)
interp_indices = interp_indices.view(batch_size, n_data, -1)
interp_values = interp_values.view(batch_size, n_data, -1)
return interp_indices, interp_values

def _inducing_forward(self):
inducing_points_var = Variable(self.inducing_points)
return super(GridInterpolationKernel, self).forward(inducing_points_var, inducing_points_var)
return super(GridInterpolationKernel, self).forward(
inducing_points_var,
inducing_points_var,
)

def forward(self, x1, x2, **kwargs):
base_lazy_var = self._inducing_forward()
@@ -58,5 +80,10 @@ def forward(self, x1, x2, **kwargs):
right_interp_values = left_interp_values
else:
right_interp_indices, right_interp_values = self._compute_grid(x2)
return InterpolatedLazyVariable(base_lazy_var, left_interp_indices, left_interp_values,
right_interp_indices, right_interp_values)
return InterpolatedLazyVariable(
base_lazy_var,
left_interp_indices,
left_interp_values,
right_interp_indices,
right_interp_values,
)
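The __init__ above builds the inducing points as the Cartesian product of the per-dimension grids via the nested fill_/copy_ loop. A standalone sketch of that construction with illustrative values (independent of the class, but the same loop structure):

import torch

grid_size, n_dims = 3, 2
grid = torch.stack([torch.linspace(0., 1., grid_size) for _ in range(n_dims)])

inducing_points = torch.zeros(grid_size ** n_dims, n_dims)
prev_points = None
for i in range(n_dims):
    for j in range(grid_size):
        # Block j along dimension i gets the j-th grid value...
        inducing_points[j * grid_size ** i:(j + 1) * grid_size ** i, i].fill_(grid[i, j])
        if prev_points is not None:
            # ...paired with every point already enumerated for dimensions < i.
            inducing_points[j * grid_size ** i:(j + 1) * grid_size ** i, :i].copy_(prev_points)
    prev_points = inducing_points[:grid_size ** (i + 1), :(i + 1)]

print(inducing_points)  # all 9 points of the 3 x 3 grid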
23 changes: 19 additions & 4 deletions gpytorch/kernels/grid_kernel.py
@@ -11,7 +11,14 @@


class GridKernel(Kernel):
def __init__(self, base_kernel_module, inducing_points, grid, active_dims=None):

def __init__(
self,
base_kernel_module,
inducing_points,
grid,
active_dims=None,
):
super(GridKernel, self).__init__(active_dims=active_dims)
self.base_kernel_module = base_kernel_module
if inducing_points.ndimension() != 2:
@@ -25,8 +32,13 @@ def train(self, mode=True):
return super(GridKernel, self).train(mode)

def forward(self, x1, x2, **kwargs):
if not torch.equal(x1.data, self.inducing_points) or not torch.equal(x2.data, self.inducing_points):
raise RuntimeError('The kernel should only receive the inducing points as input')
if (
not torch.equal(x1.data, self.inducing_points)
or not torch.equal(x2.data, self.inducing_points)
):
raise RuntimeError(
'The kernel should only receive the inducing points as input'
)

if not self.training and hasattr(self, '_cached_kernel_mat'):
return self._cached_kernel_mat
@@ -38,7 +50,10 @@ def forward(self, x1, x2, **kwargs):
if settings.use_toeplitz.on():
first_item = grid_var[:, 0:1].contiguous()
covar_columns = self.base_kernel_module(first_item, grid_var, **kwargs)
covars = [ToeplitzLazyVariable(covar_columns[i:i + 1].squeeze(-2)) for i in range(n_dim)]
covars = [
ToeplitzLazyVariable(covar_columns[i:i + 1].squeeze(-2))
for i in range(n_dim)
]
else:
grid_var = grid_var.view(n_dim, -1, 1)
covars = self.base_kernel_module(grid_var, grid_var, **kwargs)
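The use_toeplitz branch above relies on the fact that a stationary kernel evaluated on an evenly spaced grid produces a Toeplitz matrix, so its first column determines the entire matrix. A quick illustration (not part of the diff; assumes a recent PyTorch):

import torch

grid = torch.linspace(0., 1., 5)
dist = (grid.unsqueeze(0) - grid.unsqueeze(1)).abs()
K = torch.exp(-dist.pow(2))  # an RBF-style stationary kernel on the grid

# Every diagonal of a Toeplitz matrix is constant.
for offset in range(1, 5):
    diag = torch.diagonal(K, offset)
    assert torch.allclose(diag, diag[0].expand_as(diag))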
1 change: 1 addition & 0 deletions gpytorch/kernels/index_kernel.py
@@ -9,6 +9,7 @@


class IndexKernel(Kernel):

def __init__(
self,
n_tasks,
29 changes: 27 additions & 2 deletions gpytorch/kernels/kernel.py
@@ -3,13 +3,36 @@
from __future__ import print_function
from __future__ import unicode_literals

import torch
from ..module import Module


class Kernel(Module):
def __init__(self, active_dims=None):
self.active_dims = active_dims

def __init__(
self,
has_lengthscale=False,
ard_num_dims=None,
log_lengthscale_bounds=(-10000, 10000),
active_dims=None,
):
super(Kernel, self).__init__()
self.active_dims = active_dims
self.ard_num_dims = ard_num_dims
if has_lengthscale:
lengthscale_num_dims = 1 if ard_num_dims is None else ard_num_dims
self.register_parameter(
'log_lengthscale',
torch.nn.Parameter(torch.Tensor(1, 1, lengthscale_num_dims)),
bounds=log_lengthscale_bounds,
)

@property
def lengthscale(self):
if 'log_lengthscale' in self.named_parameters().keys():
return self.log_lengthscale.exp()
else:
return None

def forward(self, x1, x2, **params):
raise NotImplementedError()
@@ -53,6 +76,7 @@ def __mul__(self, other):


class AdditiveKernel(Kernel):

def __init__(self, kernel_1, kernel_2):
super(AdditiveKernel, self).__init__()
self.kernel_1 = kernel_1
@@ -63,6 +87,7 @@ def forward(self, x1, x2):


class ProductKernel(Kernel):

def __init__(self, kernel_1, kernel_2):
super(ProductKernel, self).__init__()
self.kernel_1 = kernel_1
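With this change a subclass no longer registers log_lengthscale itself: it opts in with has_lengthscale=True and reads the transformed value back through the new lengthscale property. A hypothetical subclass, sketched against the API in this diff:

from gpytorch.kernels.kernel import Kernel

class MyStationaryKernel(Kernel):

    def __init__(self, ard_num_dims=None, active_dims=None):
        super(MyStationaryKernel, self).__init__(
            has_lengthscale=True,       # base class registers log_lengthscale
            ard_num_dims=ard_num_dims,  # parameter shape (1, 1, ard_num_dims or 1)
            active_dims=active_dims,
        )

    def forward(self, x1, x2):
        # Scale each input dimension by its own lengthscale before comparing.
        diff = (x1.unsqueeze(2) - x2.unsqueeze(1)).div(self.lengthscale.sqrt())
        return diff.pow(2).sum(-1).mul(-1).exp()

kernel = MyStationaryKernel(ard_num_dims=2)
print(kernel.lengthscale.shape)  # expected: torch.Size([1, 1, 2])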
34 changes: 20 additions & 14 deletions gpytorch/kernels/matern_kernel.py
@@ -5,26 +5,27 @@

import math
import torch
from torch import nn
from gpytorch.kernels import Kernel
from .kernel import Kernel


class MaternKernel(Kernel):

def __init__(
self,
nu,
ard_num_dims=None,
log_lengthscale_bounds=(-10000, 10000),
active_dims=None,
):
super(MaternKernel, self).__init__(active_dims=active_dims)
if nu not in [0.5, 1.5, 2.5]:
if nu not in {0.5, 1.5, 2.5}:
raise RuntimeError('nu expected to be 0.5, 1.5, or 2.5')
self.nu = nu
self.register_parameter(
'log_lengthscale',
nn.Parameter(torch.zeros(1, 1, 1)),
bounds=log_lengthscale_bounds,
super(MaternKernel, self).__init__(
has_lengthscale=True,
ard_num_dims=ard_num_dims,
log_lengthscale_bounds=log_lengthscale_bounds,
active_dims=active_dims,
)
self.nu = nu

def forward(self, x1, x2):
lengthscale = (self.log_lengthscale.exp()).sqrt()
@@ -36,7 +37,11 @@ def forward(self, x1, x2):
x2_squared = x2_normed.norm(2, -1).pow(2)
x1_t_x_2 = torch.matmul(x1_normed, x2_normed.transpose(-1, -2))

distance_over_rho = (x1_squared.unsqueeze(-1) + x2_squared.unsqueeze(-2) - x1_t_x_2.mul(2))
distance_over_rho = (
x1_squared.unsqueeze(-1)
+ x2_squared.unsqueeze(-2)
- x1_t_x_2.mul(2)
)
distance_over_rho = distance_over_rho.clamp(0, 1e10).sqrt()
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance_over_rho)

@@ -45,8 +50,9 @@
elif self.nu == 1.5:
constant_component = ((math.sqrt(3) * distance_over_rho).add(1))
elif self.nu == 2.5:
constant_component = ((math.sqrt(5) * distance_over_rho).
add(1).
add(5. / 3. * distance_over_rho ** 2))

constant_component = (
(math.sqrt(5) * distance_over_rho).
add(1).
add(5. / 3. * distance_over_rho ** 2)
)
return constant_component * exp_component
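For reference, the three branches above implement the standard half-integer Matérn covariances; writing $r$ for distance_over_rho:

$$k_{1/2}(r) = e^{-r}, \qquad k_{3/2}(r) = \bigl(1 + \sqrt{3}\,r\bigr)\,e^{-\sqrt{3}\,r}, \qquad k_{5/2}(r) = \Bigl(1 + \sqrt{5}\,r + \tfrac{5}{3}\,r^{2}\Bigr)\,e^{-\sqrt{5}\,r}$$

In the code, each form is factored as constant_component * exp_component, with exp_component = exp(-sqrt(2 * nu) * distance_over_rho).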
7 changes: 4 additions & 3 deletions gpytorch/kernels/multiplicative_grid_interpolation_kernel.py
@@ -9,6 +9,7 @@


class MultiplicativeGridInterpolationKernel(GridInterpolationKernel):

def __init__(
self,
base_kernel_module,
@@ -18,9 +19,9 @@ def __init__(
active_dims=None,
):
super(MultiplicativeGridInterpolationKernel, self).__init__(
base_kernel_module,
grid_size,
grid_bounds,
base_kernel_module=base_kernel_module,
grid_size=grid_size,
grid_bounds=grid_bounds,
active_dims=active_dims
)
self.n_components = n_components
15 changes: 7 additions & 8 deletions gpytorch/kernels/periodic_kernel.py
@@ -19,16 +19,15 @@ def __init__(
eps=1e-5,
active_dims=None,
):
super(PeriodicKernel, self).__init__(active_dims=active_dims)
self.eps = eps
self.register_parameter(
'log_lengthscale',
nn.Parameter(torch.zeros(1, 1, 1)),
bounds=log_lengthscale_bounds,
super(PeriodicKernel, self).__init__(
has_lengthscale=True,
log_lengthscale_bounds=log_lengthscale_bounds,
active_dims=active_dims,
)
self.eps = eps
self.register_parameter(
'log_period_length',
nn.Parameter(torch.zeros(1, 1, 1)),
nn.Parameter(torch.zeros(1, 1)),
bounds=log_period_length_bounds,
)

@@ -37,4 +36,4 @@ def forward(self, x1, x2):
period_length = (self.log_period_length.exp() + self.eps).sqrt_()
diff = torch.sum((x1.unsqueeze(2) - x2.unsqueeze(1)).abs(), -1)
res = - 2 * torch.sin(math.pi * diff / period_length).pow(2) / lengthscale
return res.exp()
return res.exp().unsqueeze(1)
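The extra unsqueeze(1) gives the output an explicit batch dimension, matching the (batch, n, m) convention of the other kernels. The function computed is the standard periodic kernel, with lengthscale $\ell$ and period $p$ as the transformed values computed at the top of forward:

$$k(x_1, x_2) = \exp\!\left(-\,\frac{2\,\sin^{2}\bigl(\pi\,\lVert x_1 - x_2 \rVert_1 / p\bigr)}{\ell}\right)$$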
27 changes: 11 additions & 16 deletions gpytorch/kernels/rbf_kernel.py
@@ -3,32 +3,27 @@
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch import nn
from .kernel import Kernel


class RBFKernel(Kernel):

def __init__(
self,
ard_num_dims=None,
log_lengthscale_bounds=(-10000, 10000),
eps=1e-5,
active_dims=None,
):
super(RBFKernel, self).__init__(active_dims=active_dims)
super(RBFKernel, self).__init__(
has_lengthscale=True,
ard_num_dims=ard_num_dims,
log_lengthscale_bounds=log_lengthscale_bounds,
active_dims=active_dims,
)
self.eps = eps
self.register_parameter('log_lengthscale', nn.Parameter(torch.zeros(1, 1)),
bounds=log_lengthscale_bounds)

def forward(self, x1, x2):
lengthscale = (self.log_lengthscale.exp() + self.eps).sqrt_()
mean = x1.mean(1).mean(0)
x1_normed = (x1 - mean.unsqueeze(0).unsqueeze(1)).div_(lengthscale)
x2_normed = (x2 - mean.unsqueeze(0).unsqueeze(1)).div_(lengthscale)

x1_squared = x1_normed.norm(2, -1).pow(2)
x2_squared = x2_normed.norm(2, -1).pow(2)
x1_t_x_2 = torch.matmul(x1_normed, x2_normed.transpose(-1, -2))
res = (x1_squared.unsqueeze(-1) - x1_t_x_2.mul_(2) + x2_squared.unsqueeze(-2)).mul_(-1)
res = res.exp()
return res
lengthscales = self.log_lengthscale.exp() + self.eps
diff = (x1.unsqueeze(2) - x2.unsqueeze(1)).div_(lengthscales.sqrt())
return diff.pow_(2).sum(-1).mul_(-1).exp_()
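The simplified forward drops the mean-centering and norm-expansion trick of the old implementation in favor of direct pairwise differences, which broadcast cleanly over per-dimension lengthscales. A shape sketch with illustrative values (tensor names here are mine, not the diff's):

import torch

n, m, d = 4, 5, 3
x1 = torch.randn(1, n, d)  # (batch, n, d)
x2 = torch.randn(1, m, d)  # (batch, m, d)
lengthscales = torch.tensor([0.5, 1.0, 2.0]).view(1, 1, d)  # one per dimension

# (1, n, 1, d) - (1, 1, m, d) -> (1, n, m, d): one difference per dimension,
# each divided by the square root of its lengthscale.
diff = (x1.unsqueeze(2) - x2.unsqueeze(1)) / lengthscales.sqrt()
K = diff.pow(2).sum(-1).mul(-1).exp()  # (1, n, m)
print(K.shape)  # torch.Size([1, 4, 5])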
