LazyTensor refactor
- Simplify _getitem
- Remove _get_indices
- Add unsqueeze, squeeze
- Add _expand_batch
- Remove sum_batch, mul_batch - add sum and prod
- SumBatchLazyTensor and BlockDiagLazyTensor use explicit (not implicit) batches
- Add broadcasting to the last of the LazyTensors
gpleiss committed Mar 18, 2019
1 parent 3d939f8 commit c517e0b
Showing 36 changed files with 938 additions and 1,193 deletions.
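A rough usage sketch of the API changes listed in the commit message (not part of the diff; the exact call signatures and shapes are assumptions based on the bullets above and the kernel changes below):

```python
import torch
from gpytorch.lazy import NonLazyTensor

# A batch of 4 PSD matrices, each 5 x 5, wrapped as a LazyTensor.
mats = torch.randn(4, 5, 5)
covar = NonLazyTensor(mats @ mats.transpose(-1, -2) + torch.eye(5))

# sum / prod over an explicit batch dimension replace sum_batch / mul_batch:
summed = covar.sum(-3)      # collapses the batch of 4 blocks -> shape (5, 5)
prodded = covar.prod(-3)    # elementwise product over the batch -> shape (5, 5)

# The new unsqueeze / squeeze act on batch dimensions:
expanded = covar.unsqueeze(0)   # shape (1, 4, 5, 5)
squeezed = expanded.squeeze(0)  # back to shape (4, 5, 5)
```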
21 changes: 11 additions & 10 deletions gpytorch/distributions/multitask_multivariate_normal.py
@@ -2,7 +2,7 @@

import torch

from ..lazy import BlockDiagLazyTensor, CatLazyTensor, LazyTensor, NonLazyTensor
from ..lazy import BlockDiagLazyTensor, CatLazyTensor, LazyTensor
from .multivariate_normal import MultivariateNormal


@@ -61,15 +61,16 @@ def from_independent_mvns(cls, mvns):
# covariance matrices. Instead, we want to use the lazies directly in the
# BlockDiagLazyTensor. This will require implementing a new BatchLazyTensor:
# https://github.com/cornellius-gp/gpytorch/issues/468
batch_mode = len(mvns[0].covariance_matrix.shape) == 3
if batch_mode:
covar_blocks_lazy = CatLazyTensor(
*[mvn.lazy_covariance_matrix for mvn in mvns], dim=0, output_device=mean.device
)
else:
covar_blocks_lazy = NonLazyTensor(torch.cat([mvn.covariance_matrix.unsqueeze(0) for mvn in mvns], dim=0))
covar_lazy = BlockDiagLazyTensor(covar_blocks_lazy, num_blocks=len(mvns) if batch_mode else None)
return cls(mean=mean, covariance_matrix=covar_lazy, interleaved=False)
covar_blocks_lazy = CatLazyTensor(
*[mvn.lazy_covariance_matrix.unsqueeze(0) for mvn in mvns],
dim=0,
output_device=mean.device
)
covar_lazy = BlockDiagLazyTensor(
covar_blocks_lazy,
block_dim=0
)
return cls(mean=mean, covariance_matrix=covar_lazy)

def get_base_samples(self, sample_shape=torch.Size()):
"""Get i.i.d. standard Normal samples (to be used with rsample(base_samples=base_samples))"""
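As a rough sketch of the explicit-batch construction used in `from_independent_mvns` above (hypothetical sizes, not part of the commit): stacking the per-task covariances along a leading dimension and passing `block_dim=0` yields one block-diagonal covariance.

```python
import torch
from gpytorch.lazy import NonLazyTensor, BlockDiagLazyTensor

# Three 2 x 2 covariance blocks stacked along dim 0.
blocks = NonLazyTensor(torch.eye(2).unsqueeze(0).repeat(3, 1, 1))  # shape (3, 2, 2)

# The leading dimension is now an explicit block dimension.
covar = BlockDiagLazyTensor(blocks, block_dim=0)
print(covar.shape)        # expected: torch.Size([6, 6])
dense = covar.evaluate()  # dense 6 x 6 matrix with the three blocks on the diagonal
```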
1 change: 0 additions & 1 deletion gpytorch/distributions/multivariate_normal.py
@@ -59,7 +59,6 @@ def _unbroadcasted_scale_tril(self, ust):
self.__unbroadcasted_scale_tril = ust

def expand(self, batch_size):

new_loc = self.loc.expand(torch.Size(batch_size) + self.loc.shape[-1:])
new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
res = self.__class__(new_loc, new_covar)
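A hypothetical usage sketch of the `expand` method shown above, assuming a lazy covariance (the names and shapes here are illustrative, not taken from the commit):

```python
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import NonLazyTensor

# A single 4-dimensional MVN, broadcast to a batch of 3.
mvn = MultivariateNormal(torch.zeros(4), NonLazyTensor(torch.eye(4)))
batched = mvn.expand(torch.Size([3]))
print(batched.mean.shape)               # expected: torch.Size([3, 4])
print(batched.covariance_matrix.shape)  # expected: torch.Size([3, 4, 4])
```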
1 change: 0 additions & 1 deletion gpytorch/functions/_root_decomposition.py
@@ -103,7 +103,6 @@ def forward(self, *matrix_args):
def backward(self, root_grad_output, inverse_grad_output):
# Taken from http://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf
if any(self.needs_input_grad):

def is_empty(tensor):
return tensor.numel() == 0 or (tensor.numel() == 1 and tensor[0] == 0)

2 changes: 1 addition & 1 deletion gpytorch/kernels/additive_structure_kernel.py
@@ -49,7 +49,7 @@ def forward(self, x1, x2, batch_dims=None, **params):
evaluate = True
res = NonLazyTensor(res)

res = res.sum_batch(sum_batch_size=x1.size(-1))
res = res.sum(-3).unsqueeze(0)

if evaluate:
res = res.evaluate()
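A plain-tensor illustration of the replacement above (made-up sizes): the removed `sum_batch(sum_batch_size=d)` summed d implicit sub-batches, while the new code sums over an explicit batch dimension of size d and restores a singleton batch dimension.

```python
import torch

d, n = 3, 4
per_dim_kernels = torch.randn(d, n, n)  # one n x n kernel matrix per input dimension
additive = per_dim_kernels.sum(-3).unsqueeze(0)
print(additive.shape)                   # torch.Size([1, 4, 4])
```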
2 changes: 1 addition & 1 deletion gpytorch/kernels/product_structure_kernel.py
@@ -54,7 +54,7 @@ def forward(self, x1, x2, batch_dims=None, **params):
evaluate = True
res = NonLazyTensor(res)

res = res.mul_batch(mul_batch_size=x1.size(-1))
res = res.prod(-3).unsqueeze(0)

if evaluate:
res = res.evaluate()
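The same idea applies to the product structure kernel above: an elementwise product over the explicit batch dimension replaces `mul_batch` (plain-tensor illustration, made-up sizes).

```python
import torch

d, n = 3, 4
per_dim_kernels = torch.rand(d, n, n)
product = per_dim_kernels.prod(-3).unsqueeze(0)
print(product.shape)   # torch.Size([1, 4, 4])
```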
127 changes: 86 additions & 41 deletions gpytorch/lazy/batch_repeat_lazy_tensor.py
@@ -2,123 +2,165 @@

import itertools
import torch
from .. import settings
from ..utils.memoize import cached
from ..utils.broadcasting import _matmul_broadcast_shape

from .lazy_tensor import LazyTensor
from .root_lazy_tensor import RootLazyTensor


class BatchRepeatLazyTensor(LazyTensor):
def __init__(self, base_lazy_tensor, batch_repeat=torch.Size((1,))):
if not isinstance(batch_repeat, torch.Size):
raise RuntimeError(
"batch_repeat must be a torch.Size, got a {} instead".format(batch_repeat.__class__.__name__)
)
if settings.debug.on():
if not isinstance(batch_repeat, torch.Size):
raise RuntimeError(
"batch_repeat must be a torch.Size, got a {} instead".format(batch_repeat.__class__.__name__)
)
if isinstance(base_lazy_tensor, BatchRepeatLazyTensor):
raise RuntimeError(
"BatchRepeatLazyTensor recieved the following args:\n"
"base_lazy_tensor: {} (size: {}), batch_repeat: {}.".format(
base_lazy_tensor, base_lazy_tensor.shape, batch_repeat
)
)

super(BatchRepeatLazyTensor, self).__init__(base_lazy_tensor, batch_repeat=batch_repeat)
self.base_lazy_tensor = base_lazy_tensor
self.batch_repeat = batch_repeat

def _get_indices(self, row_indices, col_indices, *batch_indices):
num_true_batch_dims = len(self.base_lazy_tensor.batch_shape)
batch_indices = [index % size for index, size in zip(batch_indices, self._padded_base_batch_shape)]
batch_indices = batch_indices[-num_true_batch_dims:] if num_true_batch_dims else []
return self.base_lazy_tensor._get_indices(row_indices, col_indices, *batch_indices)
def _compute_batch_repeat_size(self, batch_shape):
current_batch_shape = self._padded_base_batch_shape(num_batch_dims=len(batch_shape))
batch_repeat = torch.Size(
batch_size // current_batch_size
for batch_size, current_batch_size in zip(batch_shape, current_batch_shape)
)
return batch_repeat

def _expand_batch(self, batch_shape):
return self.__class__(self.base_lazy_tensor, batch_repeat=self._compute_batch_repeat_size(batch_shape))

def _getitem(self, *indices):
def _getitem(self, row_col_are_absorbed, row_index, col_index, *batch_indices):
args = []
kwargs = self.base_lazy_tensor._kwargs
num_base_batch_dims = len(self.base_lazy_tensor.batch_shape)

for arg in self.base_lazy_tensor._args:
if torch.is_tensor(arg):
if torch.is_tensor(arg) or isinstance(arg, LazyTensor):
arg_base_shape_len = max(arg.dim() - num_base_batch_dims, 0)
args.append(arg.repeat(*self.batch_repeat, *[1 for _ in range(arg_base_shape_len)]))
elif isinstance(arg, LazyTensor):
args.append(BatchRepeatLazyTensor(arg, batch_repeat=self.batch_repeat))
else:
args.append(arg)

new_lazy_tensor = self.base_lazy_tensor.__class__(*args, **kwargs)
return new_lazy_tensor._getitem(*indices)
return new_lazy_tensor._getitem(row_col_are_absorbed, row_index, col_index, *batch_indices)

def _matmul(self, rhs):
rhs = self._move_repeat_batches_to_columns(rhs)
output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
if rhs.shape != output_shape:
rhs = rhs.expand(*output_shape)

rhs = self._move_repeat_batches_to_columns(rhs, output_shape)
res = self.base_lazy_tensor._matmul(rhs)
res = self._move_repeat_batches_back(res)
res = self._move_repeat_batches_back(res, output_shape)
return res

def _move_repeat_batches_back(self, batch_matrix):
def _move_repeat_batches_back(self, batch_matrix, output_shape):
"""
The opposite of _move_repeat_batches_to_columns
Takes a b x m x nr tensor and moves the batches associated with repeating
so that the tensor is now rb x m x n.
"""
orig_shape = (*self.batch_shape, batch_matrix.size(-2), -1)
padded_base_batch_shape = self._padded_base_batch_shape
if hasattr(self, "_batch_move_memo"):
padded_base_batch_shape, batch_repeat = self.__batch_move_memo
del self.__batch_move_memo
else:
padded_base_batch_shape = self._padded_base_batch_shape(num_batch_dims=(len(output_shape) - 2))
batch_repeat = self._compute_batch_repeat_size(output_shape[:-2])

# Now we have to move the columns back to their original repeat dimensions
batch_matrix = batch_matrix.view(*padded_base_batch_shape, batch_matrix.size(-2), -1, *self.batch_repeat)
batch_matrix = batch_matrix.view(*padded_base_batch_shape, output_shape[-2], -1, *batch_repeat)
dims = tuple(
itertools.chain.from_iterable([i + len(orig_shape), i] for i in range(len(padded_base_batch_shape)))
) + (self.dim() - 2, self.dim() - 1)
itertools.chain.from_iterable([i + len(output_shape), i] for i in range(len(padded_base_batch_shape)))
) + (len(output_shape) - 2, len(output_shape) - 1)
batch_matrix = batch_matrix.permute(*dims).contiguous()

# Combine the repeat and the batch dimensions, and return the result!
batch_matrix = batch_matrix.view(*orig_shape)
batch_matrix = batch_matrix.view(*output_shape)
return batch_matrix

def _move_repeat_batches_to_columns(self, batch_matrix):
def _move_repeat_batches_to_columns(self, batch_matrix, output_shape):
"""
Takes an rb x m x n tensor and moves the batches associated with repeating
so that the tensor is now b x m x nr.
This allows us to use the base_lazy_tensor routines.
"""
batch_matrix_shape = batch_matrix.shape
padded_base_batch_shape = self._padded_base_batch_shape
padded_base_batch_shape = self._padded_base_batch_shape(num_batch_dims=(len(output_shape) - 2))
batch_repeat = self._compute_batch_repeat_size(output_shape[:-2])

# Reshape batch_matrix so that each batch dimension is split in two:
# The repeated part, and the actual part
split_shape = torch.Size(
tuple(
itertools.chain.from_iterable(
[repeat, size] for repeat, size in zip(self.batch_repeat, padded_base_batch_shape)
[repeat, size] for repeat, size in zip(batch_repeat, padded_base_batch_shape)
)
)
+ batch_matrix_shape[-2:]
+ output_shape[-2:]
)
batch_matrix = batch_matrix.view(*split_shape)

# Now chuck the repeat parts of the batch dimensions into the last dimension of batch_matrix
# These will act like extra columns of the batch matrix that we are multiplying against
# The repeated part, and the actual part
repeat_dims = range(0, len(self.batch_repeat) * 2, 2)
batch_dims = range(1, len(self.batch_repeat) * 2, 2)
repeat_dims = range(0, len(batch_repeat) * 2, 2)
batch_dims = range(1, len(batch_repeat) * 2, 2)
batch_matrix = batch_matrix.permute(*batch_dims, -2, -1, *repeat_dims).contiguous()
batch_matrix = batch_matrix.view(*self.base_lazy_tensor.batch_shape, batch_matrix_shape[-2], -1)
batch_matrix = batch_matrix.view(*self.base_lazy_tensor.batch_shape, output_shape[-2], -1)

self.__batch_move_memo = output_shape, padded_base_batch_shape, batch_repeat
return batch_matrix

@property
def _padded_base_batch_shape(self):
def _padded_base_batch_shape(self, num_batch_dims=None):
if num_batch_dims is None:
num_batch_dims = len(self.batch_repeat)
base_batch_shape = self.base_lazy_tensor.batch_shape
return torch.Size(([1] * (len(self.batch_repeat) - len(base_batch_shape))) + list(base_batch_shape))
return torch.Size(([1] * (num_batch_dims - len(base_batch_shape))) + list(base_batch_shape))

def _quad_form_derivative(self, left_vectors, right_vectors):
left_vectors = self._move_repeat_batches_to_columns(left_vectors)
right_vectors = self._move_repeat_batches_to_columns(right_vectors)
left_output_shape = _matmul_broadcast_shape(self.shape, left_vectors.shape)
if left_output_shape != left_vectors.shape:
left_vectors = left_vectors.expand(left_output_shape)
right_output_shape = _matmul_broadcast_shape(self.shape, right_vectors.shape)
if right_output_shape != right_vectors.shape:
right_vectors = right_vectors.expand(right_output_shape)
left_vectors = self._move_repeat_batches_to_columns(left_vectors, left_output_shape)
right_vectors = self._move_repeat_batches_to_columns(right_vectors, right_output_shape)
return self.base_lazy_tensor._quad_form_derivative(left_vectors, right_vectors)

def _size(self):
repeated_batch_shape = torch.Size(
size * repeat for size, repeat in zip(self._padded_base_batch_shape, self.batch_repeat)
size * repeat for size, repeat in zip(self._padded_base_batch_shape(), self.batch_repeat)
)
res = torch.Size(repeated_batch_shape + self.base_lazy_tensor.matrix_shape)
return res

def _transpose_nonbatch(self):
return self.__class__(self.base_lazy_tensor._transpose_nonbatch(), batch_repeat=self.batch_repeat)

def _unsqueeze_batch(self, dim):
base_lazy_tensor = self.base_lazy_tensor
batch_repeat = list(self.batch_repeat)
batch_repeat.insert(dim, 1)
batch_repeat = torch.Size(batch_repeat)
# If the dim only adds a new padded dimension, then we're done
# Otherwise we have to also unsqueeze the base_lazy_tensor
base_unsqueeze_dim = dim - (len(self._padded_base_batch_shape()) - len(self.base_lazy_tensor.batch_shape))
if base_unsqueeze_dim > 0:
base_lazy_tensor = base_lazy_tensor._unsqueeze_batch(base_unsqueeze_dim)
return self.__class__(base_lazy_tensor, batch_repeat=batch_repeat)

def add_jitter(self, jitter_val=1e-3):
return self.__class__(self.base_lazy_tensor.add_jitter(jitter_val=jitter_val), batch_repeat=self.batch_repeat)

@@ -143,15 +185,18 @@ def inv_quad_logdet(inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
)

if inv_quad_rhs is not None:
inv_quad_rhs = self._move_repeat_batches_to_columns(inv_quad_rhs)
output_shape = _matmul_broadcast_shape(self.shape, inv_quad_rhs.shape)
inv_quad_rhs = self._move_repeat_batches_to_columns(inv_quad_rhs, output_shape)

inv_quad_term, logdet_term = self.base_lazy_tensor.inv_quad_logdet(
inv_quad_rhs, logdet, reduce_inv_quad=False
)

if inv_quad_term is not None and inv_quad_term.numel():
inv_quad_term = inv_quad_term.view(*inv_quad_term.shape[:-1], -1, self.batch_repeat.numel())
inv_quad_term = self._move_repeat_batches_back(inv_quad_term).squeeze(-1)
inv_quad_term = inv_quad_term.view(*inv_quad_term.shape[:-1], -1, 1, self.batch_repeat.numel())
output_shape = list(output_shape)
output_shape[-2] = 1
inv_quad_term = self._move_repeat_batches_back(inv_quad_term, output_shape).squeeze(-2)
if reduce_inv_quad:
inv_quad_term = inv_quad_term.sum(-1)

@@ -169,7 +214,7 @@ def repeat(self, *sizes):

padded_batch_repeat = tuple(1 for _ in range(len(sizes) - 2 - len(self.batch_repeat))) + self.batch_repeat
return self.__class__(
self,
self.base_lazy_tensor,
batch_repeat=torch.Size(
orig_repeat_size * new_repeat_size
for orig_repeat_size, new_repeat_size in zip(padded_batch_repeat, sizes[:-2])
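The `_move_repeat_batches_to_columns` / `_move_repeat_batches_back` pair above folds the repeated batch dimensions into extra columns so that the base tensor's batched routines can be reused, then unfolds them afterward. A standalone sketch of that reshuffle on plain tensors, with one repeated batch dimension and made-up sizes (not part of the commit):

```python
import torch

b, r, n, t = 2, 3, 4, 5          # base batch, repeats, matrix rows, rhs columns
rhs = torch.randn(r * b, n, t)   # rhs for a matmul against the repeated tensor

# Move repeat batches to columns: split the batch dim into (repeat, base),
# then fold the repeat part into extra columns of the right-hand side.
folded = rhs.view(r, b, n, t).permute(1, 2, 3, 0).contiguous().view(b, n, t * r)

# ... the base lazy tensor's _matmul would run on the base batch of size b ...

# Move repeat batches back: undo the fold to recover the repeated batch shape.
restored = folded.view(b, n, t, r).permute(3, 0, 1, 2).contiguous().view(r * b, n, t)
assert torch.equal(restored, rhs)
```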
