Commit ed40be9
Fix kernel memory usage
jacobrgardner committed Dec 17, 2018
1 parent 62bd0af commit ed40be9
Showing 7 changed files with 62 additions and 47 deletions.
4 changes: 1 addition & 3 deletions gpytorch/kernels/cosine_kernel.py
@@ -98,8 +98,6 @@ def _set_period_length(self, value):
     def forward(self, x1, x2, **params):
         x1_ = x1.div(self.period_length)
         x2_ = x2.div(self.period_length)
-        x1_, x2_ = self._create_input_grid(x1_, x2_, **params)
-
-        diff = torch.norm((x1_ - x2_).abs(), 2, -1)
+        diff = self._covar_sq_dist(x1_, x2_, **params).sqrt_()
         res = torch.cos(diff.mul(math.pi))
         return res
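
Note: this substitution preserves the kernel's values. The removed line computed an L2 norm of elementwise differences (the .abs() was a no-op under an L2 norm), and the replacement takes the square root of the squared Euclidean distance, which is the same quantity. A quick standalone check, illustrative only:

import torch

d = torch.randn(5, 4, 3)           # stand-in for the broadcast difference x1_ - x2_
old = d.abs().norm(2, -1)          # the removed computation
new = d.pow(2).sum(-1).sqrt()      # sqrt of the squared Euclidean distance
print(torch.allclose(old, new))    # True: the two reductions agree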
48 changes: 17 additions & 31 deletions gpytorch/kernels/kernel.py
@@ -187,42 +187,28 @@ def forward(self, x1, x2, diag=False, batch_dims=None, **params):
         """
         raise NotImplementedError()
 
-    def _create_input_grid(self, x1, x2, diag=False, batch_dims=None, **params):
-        """
-        This is a helper method for creating a grid of the kernel's inputs.
-        Use this helper rather than manually creating a meshgrid.
-        The grid dimensions depend on the kernel's evaluation mode.
+    def _covar_sq_dist(self, x1, x2, **params):
+        if params.get('batch_dims') == (0, 2):
+            x1 = x1.unsqueeze(0).transpose(0, -1)
+            x2 = x2.unsqueeze(0).transpose(0, -1)
 
-        Args:
-            :attr:`x1` (Tensor `n x d` or `b x n x d`)
-            :attr:`x2` (Tensor `m x d` or `b x m x d`) - for diag mode, these must be the same inputs
+        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
+        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
 
-        Returns:
-            (:class:`Tensor`, :class:`Tensor`) corresponding to the gridded `x1` and `x2`.
-            The shape depends on the kernel's mode
+        if params.get('diag'):
+            mid = (x1 * x2).sum(dim=-1, keepdim=True)
+            res = (x1_norm - 2 * mid + x2_norm).squeeze(-1)
+        else:
+            mid = x1.matmul(x2.transpose(-2, -1))
+            res = x1_norm - 2 * mid + x2_norm.transpose(-2, -1)
 
-            * `full_covar`: (`b x n x 1 x d` and `b x 1 x m x d`)
-            * `full_covar` with `batch_dims=(0, 2)`: (`b x k x n x 1 x 1` and `b x k x 1 x m x 1`)
-            * `diag`: (`b x n x d` and `b x n x d`)
-            * `diag` with `batch_dims=(0, 2)`: (`b x k x n x 1` and `b x k x n x 1`)
-        """
-        x1_, x2_ = x1, x2
-        if batch_dims == (0, 2):
-            x1_ = x1_.view(*x1.size()[:-1], -1, 1)
-            x1_ = x1_.permute(0, -2, *list(range(1, x1_.dim() - 2)), -1).contiguous()
-            x1_ = x1_.view(-1, *x1_.size()[2:])
-            if torch.equal(x1, x2):
-                x2_ = x1_
+        if params.get('batch_dims') == (0, 2):
+            if params.get('diag'):
+                res = res.transpose(0, 1).contiguous().view(-1, res.shape[-1])
             else:
-                x2_ = x2_.view(*x2.size()[:-1], -1, 1)
-                x2_ = x2_.permute(0, -2, *list(range(1, x2_.dim() - 2)), -1).contiguous()
-                x2_ = x2_.view(-1, *x2_.size()[2:])
+                res = res.transpose(0, 1).contiguous().view(-1, *res.shape[-2:])
 
-        if diag:
-            return x1_, x2_
-        else:
-            return x1_.unsqueeze(-2), x2_.unsqueeze(-3)
+        return res
 
     def __call__(self, x1, x2=None, diag=False, batch_dims=None, **params):
         x1_, x2_ = x1, x2
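
Note: this replacement is the heart of the commit. A minimal standalone sketch of the memory trade, with illustrative helper names that are not from the repository: the old gridded path materializes an n x m x d difference tensor before reducing over d, while the quadratic expansion ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2 used by _covar_sq_dist only ever allocates n x m intermediates.

import torch

def sq_dist_expanded(x1, x2):
    # Old-style: broadcast to an (n x m x d) difference tensor, then reduce over d.
    return (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(dim=-1)

def sq_dist_matmul(x1, x2):
    # New-style: two norm vectors plus a single (n x m) matmul result.
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    return x1_norm - 2 * x1.matmul(x2.transpose(-2, -1)) + x2_norm.transpose(-2, -1)

x1, x2 = torch.randn(100, 50), torch.randn(80, 50)
print(torch.allclose(sq_dist_expanded(x1, x2), sq_dist_matmul(x1, x2), atol=1e-4))  # True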
4 changes: 1 addition & 3 deletions gpytorch/kernels/matern_kernel.py
@@ -112,9 +112,7 @@ def forward(self, x1, x2, **params):
 
         x1_ = (x1 - mean).div(self.lengthscale)
         x2_ = (x2 - mean).div(self.lengthscale)
-        x1_, x2_ = self._create_input_grid(x1_, x2_, **params)
-
-        distance = (x1_ - x2_).norm(2, dim=-1)
+        distance = self._covar_sq_dist(x1_, x2_, **params).sqrt_()
         exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)
 
         if self.nu == 0.5:
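
Note: the distance now feeds the usual Matern closed forms; the nu-specific branches are truncated out of this hunk. A sketch of the nu = 2.5 case, using the standard expression (1 + sqrt(5)d + 5d^2/3) * exp(-sqrt(5)d) rather than quoting the file:

import math
import torch

nu = 2.5
distance = torch.tensor([0.25, 0.5, 1.0])                 # illustrative distances
exp_component = torch.exp(-math.sqrt(nu * 2) * distance)  # matches the line in the diff
constant_component = 1 + math.sqrt(5) * distance + 5.0 / 3.0 * distance ** 2
print(constant_component * exp_component)                 # Matern-5/2 covariance values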
4 changes: 1 addition & 3 deletions gpytorch/kernels/periodic_kernel.py
@@ -121,9 +121,7 @@ def _set_period_length(self, value):
     def forward(self, x1, x2, **params):
         x1_ = x1.div(self.period_length)
         x2_ = x2.div(self.period_length)
-        x1_, x2_ = self._create_input_grid(x1_, x2_, **params)
-
-        diff = torch.sum((x1_ - x2_).abs(), -1)
+        diff = self._covar_sq_dist(x1_, x2_, **params).sqrt_()
         res = torch.sin(diff.mul(math.pi)).pow(2).mul(-2 / self.lengthscale).exp_()
         if diff.ndimension() == 2:
             res = res.squeeze(0)
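
Note: unlike the other kernels in this commit, this hunk is not value-preserving for multi-dimensional inputs: the removed line summed absolute differences (an L1 reduction over d), while the replacement is the Euclidean (L2) distance, and the two agree only when d == 1. A quick illustration:

import torch

d = torch.randn(5, 3)              # stand-in differences with d = 3 > 1
l1 = d.abs().sum(-1)               # removed: torch.sum((x1_ - x2_).abs(), -1)
l2 = d.pow(2).sum(-1).sqrt()       # new: sqrt of the squared Euclidean distance
print(torch.allclose(l1, l2))      # False in general; always True when d == 1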
6 changes: 2 additions & 4 deletions gpytorch/kernels/rbf_kernel.py
@@ -92,7 +92,5 @@ def __init__(
     def forward(self, x1, x2, **params):
         x1_ = x1.div(self.lengthscale)
         x2_ = x2.div(self.lengthscale)
-        x1_, x2_ = self._create_input_grid(x1_, x2_, **params)
-
-        diff = (x1_ - x2_).norm(2, dim=-1)
-        return diff.pow(2).div_(-2).exp_()
+        diff = self._covar_sq_dist(x1_, x2_, **params)
+        return diff.div_(-2).exp_()
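
Note: the RBF forward needs only the squared distance, since k = exp(-||x1 - x2||^2 / 2) on lengthscale-scaled inputs, so the new code also avoids the norm-then-square round trip. A standalone sketch of the equivalence between the old and new paths, illustrative only:

import torch

x1_, x2_ = torch.randn(6, 2), torch.randn(4, 2)
old = (x1_.unsqueeze(-2) - x2_.unsqueeze(-3)).norm(2, dim=-1).pow(2).div(-2).exp()
sq = (x1_.pow(2).sum(-1, keepdim=True)
      - 2 * x1_.matmul(x2_.transpose(-2, -1))
      + x2_.pow(2).sum(-1, keepdim=True).transpose(-2, -1))
print(torch.allclose(old, sq.div(-2).exp(), atol=1e-5))  # True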
37 changes: 37 additions & 0 deletions gpytorch/kernels/spectral_mixture_kernel.py
@@ -178,6 +178,43 @@ def initialize_from_data(self, train_x, train_y, **kwargs):
         self.raw_mixture_weights.data.fill_(train_y.std() / self.num_mixtures)
         self.raw_mixture_weights.data = self._inv_param_transform(self.raw_mixture_weights.data)
 
+    def _create_input_grid(self, x1, x2, diag=False, batch_dims=None, **params):
+        """
+        This is a helper method for creating a grid of the kernel's inputs.
+        Use this helper rather than manually creating a meshgrid.
+        The grid dimensions depend on the kernel's evaluation mode.
+
+        Args:
+            :attr:`x1` (Tensor `n x d` or `b x n x d`)
+            :attr:`x2` (Tensor `m x d` or `b x m x d`) - for diag mode, these must be the same inputs
+
+        Returns:
+            (:class:`Tensor`, :class:`Tensor`) corresponding to the gridded `x1` and `x2`.
+            The shape depends on the kernel's mode
+
+            * `full_covar`: (`b x n x 1 x d` and `b x 1 x m x d`)
+            * `full_covar` with `batch_dims=(0, 2)`: (`b x k x n x 1 x 1` and `b x k x 1 x m x 1`)
+            * `diag`: (`b x n x d` and `b x n x d`)
+            * `diag` with `batch_dims=(0, 2)`: (`b x k x n x 1` and `b x k x n x 1`)
+        """
+        x1_, x2_ = x1, x2
+        if batch_dims == (0, 2):
+            x1_ = x1_.view(*x1.size()[:-1], -1, 1)
+            x1_ = x1_.permute(0, -2, *list(range(1, x1_.dim() - 2)), -1).contiguous()
+            x1_ = x1_.view(-1, *x1_.size()[2:])
+            if torch.equal(x1, x2):
+                x2_ = x1_
+            else:
+                x2_ = x2_.view(*x2.size()[:-1], -1, 1)
+                x2_ = x2_.permute(0, -2, *list(range(1, x2_.dim() - 2)), -1).contiguous()
+                x2_ = x2_.view(-1, *x2_.size()[2:])
+
+        if diag:
+            return x1_, x2_
+        else:
+            return x1_.unsqueeze(-2), x2_.unsqueeze(-3)
+
     def forward(self, x1, x2, **params):
         batch_size, n, num_dims = x1.size()
         _, m, _ = x2.size()
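
Note: _create_input_grid moves into SpectralMixtureKernel rather than being deleted, presumably because this kernel still needs per-dimension input grids in its forward pass rather than a single distance matrix. A small shape sketch of the non-batch full_covar case described in the docstring, illustrative only:

import torch

x1, x2 = torch.randn(5, 3), torch.randn(4, 3)    # n x d and m x d
x1_, x2_ = x1.unsqueeze(-2), x2.unsqueeze(-3)    # n x 1 x d and 1 x m x d
diff = x1_ - x2_                                 # broadcasts to n x m x d
print(diff.shape)                                # torch.Size([5, 4, 3])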
6 changes: 3 additions & 3 deletions test/kernels/test_grid_kernel.py
@@ -28,23 +28,23 @@ def test_grid_grid(self):
         self.assertIsInstance(grid_covar, KroneckerProductLazyTensor)
         grid_eval = kernel(grid_data, grid_data).evaluate()
         actual_eval = base_kernel(grid_data, grid_data).evaluate()
-        self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5)
+        self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5)
 
     def test_nongrid_grid(self):
         base_kernel = RBFKernel()
         data = torch.randn(5, 2)
         kernel = GridKernel(base_kernel, grid)
         grid_eval = kernel(grid_data, data).evaluate()
         actual_eval = base_kernel(grid_data, data).evaluate()
-        self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5)
+        self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5)
 
     def test_nongrid_nongrid(self):
         base_kernel = RBFKernel()
         data = torch.randn(5, 2)
         kernel = GridKernel(base_kernel, grid)
         grid_eval = kernel(data, data).evaluate()
         actual_eval = base_kernel(data, data).evaluate()
-        self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5)
+        self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5)
 
 
 if __name__ == "__main__":
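
Note: the looser tolerance (1e-5 to 2e-5) is consistent with the new distance computation: the expansion ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2 cancels large terms and is less accurate in float32 than differencing first. A standalone illustration; the magnitudes are representative, not taken from the test suite:

import torch

a = torch.randn(1000, 10)
b = a + 1e-4 * torch.randn(1000, 10)     # nearby points: worst case for cancellation
direct = (a - b).pow(2).sum(-1)
expanded = a.pow(2).sum(-1) - 2 * (a * b).sum(-1) + b.pow(2).sum(-1)
print((direct - expanded).abs().max())   # nonzero; can exceed the true squared distances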
