Matmul lazy variable handles arbitrary LVs
gpleiss committed Jan 24, 2018
1 parent eaac85f commit f01591a
Showing 4 changed files with 68 additions and 33 deletions.
92 changes: 61 additions & 31 deletions gpytorch/lazy/matmul_lazy_variable.py
@@ -1,5 +1,7 @@
 import torch
+from torch.autograd import Variable
 from .lazy_variable import LazyVariable
+from .non_lazy_variable import NonLazyVariable
 
 
 def _inner_repeat(tensor, amt):
@@ -12,24 +14,22 @@ def _outer_repeat(tensor, amt):
 
 class MatmulLazyVariable(LazyVariable):
     def __init__(self, lhs, rhs):
+        if not isinstance(lhs, LazyVariable):
+            lhs = NonLazyVariable(lhs)
+        if not isinstance(rhs, LazyVariable):
+            rhs = NonLazyVariable(rhs)
+
         super(MatmulLazyVariable, self).__init__(lhs, rhs)
         self.lhs = lhs
         self.rhs = rhs
 
-    def _matmul_closure_factory(self, lhs, rhs):
-        def closure(tensor):
-            return torch.matmul(lhs, rhs).matmul(tensor)
-
-        return closure
+    def _matmul_closure_factory(self, *args):
+        len_lhs_repr = len(self.lhs.representation())
+        lhs_matmul_closure = self.lhs._matmul_closure_factory(*args[:len_lhs_repr])
+        rhs_matmul_closure = self.rhs._matmul_closure_factory(*args[len_lhs_repr:])
 
-    def _derivative_quadratic_form_factory(self, lhs, rhs):
-        def closure(left_factor, right_factor):
-            if left_factor.ndimension() == 1:
-                left_factor = left_factor.unsqueeze(0)
-                right_factor = right_factor.unsqueeze(0)
-            left_grad = left_factor.transpose(-1, -2).matmul(right_factor.matmul(rhs.transpose(-1, -2)))
-            right_grad = lhs.transpose(-1, -2).matmul(left_factor.transpose(-1, -2)).matmul(right_factor)
-            return left_grad, right_grad
+        def closure(tensor):
+            return lhs_matmul_closure(rhs_matmul_closure(tensor))
 
         return closure
 
@@ -43,37 +43,67 @@ def closure(tensor):
 
         return closure
 
+    def _derivative_quadratic_form_factory(self, *args):
+        len_lhs_repr = len(self.lhs.representation())
+        lhs_t_matmul_closure = self.lhs.transpose(-1, -2)._t_matmul_closure_factory(*args[:len_lhs_repr])
+        rhs_matmul_closure = self.rhs._matmul_closure_factory(*args[len_lhs_repr:])
+        lhs_derivative_closure = self.lhs._derivative_quadratic_form_factory(*args[:len_lhs_repr])
+        rhs_derivative_closure = self.rhs._derivative_quadratic_form_factory(*args[len_lhs_repr:])
+
+        def closure(left_factor, right_factor):
+            if left_factor.ndimension() == 1:
+                left_factor = left_factor.unsqueeze(0)
+                right_factor = right_factor.unsqueeze(0)
+            left_grad, = lhs_derivative_closure(left_factor, right_factor)
+            left_grad = rhs_matmul_closure(left_grad.transpose(-1, -2)).transpose(-1, -2)
+            right_grad, = rhs_derivative_closure(left_factor, right_factor)
+            right_grad = lhs_t_matmul_closure(right_grad)
+            return left_grad, right_grad
+
+        return closure
+
     def _size(self):
         if self.lhs.ndimension() > 2:
-            return torch.Size((self.lhs.size()[0], self.lhs.size()[1], self.lhs.size()[1]))
+            return torch.Size((self.lhs.size(0), self.lhs.size(1), self.rhs.size(2)))
         else:
-            return torch.Size((self.lhs.size()[0], self.lhs.size()[0]))
+            return torch.Size((self.lhs.size(0), self.rhs.size(1)))
 
-    def _transpose_nonbatch(self):
-        return MatmulLazyVariable(self.rhs.transpose(-1, -2), self.lhs.transpose(-1, -2))
+    def _transpose_nonbatch(self, *args):
+        return self.__class__(self.rhs._transpose_nonbatch(), self.lhs._transpose_nonbatch())
 
     def _batch_get_indices(self, batch_indices, left_indices, right_indices):
         outer_size = batch_indices.size(0)
-        batch_indices = batch_indices.data
-        left_indices = left_indices.data
-        right_indices = right_indices.data
-
         inner_size = self.lhs.size(-1)
-        inner_indices = right_indices.new(inner_size)
-        torch.arange(0, inner_size, out=inner_indices)
+        inner_indices = Variable(right_indices.data.new(inner_size))
+        torch.arange(0, inner_size, out=inner_indices.data)
+
+        left_vals = self.lhs._batch_get_indices(_outer_repeat(batch_indices, inner_size),
+                                                _outer_repeat(left_indices, inner_size),
+                                                _inner_repeat(inner_indices, outer_size))
+        right_vals = self.rhs._batch_get_indices(_outer_repeat(batch_indices, inner_size),
+                                                 _inner_repeat(inner_indices, outer_size),
+                                                 _outer_repeat(right_indices, inner_size))
 
-        left_vals = self.lhs[_outer_repeat(batch_indices, inner_size), _outer_repeat(left_indices, inner_size),
-                             _inner_repeat(inner_indices, outer_size)]
-        right_vals = self.rhs[_outer_repeat(batch_indices, inner_size), _inner_repeat(inner_indices, outer_size),
-                              _outer_repeat(right_indices, inner_size)]
         return (left_vals.view(-1, inner_size) * right_vals.view(-1, inner_size)).sum(-1)
 
     def _get_indices(self, left_indices, right_indices):
-        res = self.lhs.index_select(-2, left_indices) * self.rhs.index_select(-1, right_indices).transpose(-1, -2)
-        return res.sum(-1)
+        outer_size = left_indices.size(0)
+        inner_size = self.lhs.size(-1)
+        inner_indices = Variable(right_indices.data.new(inner_size))
+        torch.arange(0, inner_size, out=inner_indices.data)
+
+        left_vals = self.lhs._get_indices(_outer_repeat(left_indices, inner_size),
+                                          _inner_repeat(inner_indices, outer_size))
+        right_vals = self.lhs._get_indices(_inner_repeat(inner_indices, outer_size),
+                                           _outer_repeat(right_indices, inner_size))
+
+        return (left_vals.view(-1, inner_size) * right_vals.view(-1, inner_size)).sum(-1)
 
     def diag(self):
-        return (self.lhs * self.rhs.transpose(-1, -2)).sum(-1)
+        if isinstance(self.lhs, NonLazyVariable) and isinstance(self.rhs, NonLazyVariable):
+            return (self.lhs.var * self.rhs.var.transpose(-1, -2)).sum(-1)
+        else:
+            return super(MatmulLazyVariable, self).diag()
 
     def evaluate(self):
-        return torch.matmul(self.lhs, self.rhs)
+        return torch.matmul(self.lhs.evaluate(), self.rhs.evaluate())
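
A minimal usage sketch (not part of the commit) of what the reworked constructor allows: plain Variables and existing LazyVariables can now be mixed freely, since non-lazy inputs are wrapped in NonLazyVariable. The tensors, shapes, and import paths below are illustrative assumptions based on the 2018-era API shown in the diff above.

import torch
from torch.autograd import Variable

from gpytorch.lazy.matmul_lazy_variable import MatmulLazyVariable
from gpytorch.lazy.non_lazy_variable import NonLazyVariable

# Illustrative inputs: one plain Variable, one already-lazy variable.
lhs = Variable(torch.randn(5, 3))                    # wrapped in NonLazyVariable by __init__
rhs = NonLazyVariable(Variable(torch.randn(3, 5)))   # used as-is

lazy_product = MatmulLazyVariable(lhs, rhs)          # represents lhs @ rhs without forming it

dense = lazy_product.evaluate()                      # torch.matmul(lhs.evaluate(), rhs.evaluate())
diagonal = lazy_product.diag()                       # fast path when both sides are NonLazyVariable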
4 changes: 4 additions & 0 deletions gpytorch/lazy/root_lazy_variable.py
@@ -5,6 +5,10 @@ class RootLazyVariable(MatmulLazyVariable):
     def __init__(self, root):
         super(RootLazyVariable, self).__init__(root, root.transpose(-1, -2))
 
+    @property
+    def root(self):
+        return self.lhs
+
     def chol_approx_size(self):
         return self.lhs.size(-1)
 
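A short sketch (again not from the commit) of the new root property: RootLazyVariable(root) lazily represents root @ root^T, and root exposes the factor as a LazyVariable, so callers evaluate it explicitly rather than assuming .lhs is a plain tensor, which is exactly the switch made in the two files below. The names and shapes here are illustrative.

import torch
from torch.autograd import Variable

from gpytorch.lazy.root_lazy_variable import RootLazyVariable

factor = Variable(torch.randn(4, 4).tril())   # illustrative root (e.g. a Cholesky-like factor)
covar = RootLazyVariable(factor)              # lazily represents factor @ factor^T

chol = covar.root.evaluate()                  # dense root factor, replacing direct .lhs access
full = covar.evaluate()                       # factor @ factor^T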
3 changes: 2 additions & 1 deletion gpytorch/models/variational_gp.py
@@ -80,7 +80,8 @@ def __call__(self, inputs, **kwargs):
             q_mat, t_mat = lq_object.lanczos_batch(induc_induc_matmul, init_vector)
             self.prior_chol = Variable(q_mat[0].matmul(t_mat[0].potrf().inverse()))
 
-            self.variational_chol = gpytorch.inv_matmul(induc_induc_covar, variational_output.covar().lhs)
+            chol_variational_output = variational_output.covar().root.evaluate()
+            self.variational_chol = gpytorch.inv_matmul(induc_induc_covar, chol_variational_output)
             self.has_computed_chol = True
 
         # Test mean
2 changes: 1 addition & 1 deletion gpytorch/variational/mvn_variational_strategy.py
@@ -12,7 +12,7 @@ def kl_divergence(self):
         variational_covar = self.variational_dist.covar()
         if not isinstance(variational_covar, RootLazyVariable):
             raise RuntimeError('The variational covar for an MVN distribution should be a RootLazyVariable')
-        chol_variational_covar = variational_covar.lhs
+        chol_variational_covar = variational_covar.root.evaluate()
 
         mean_diffs = prior_mean - variational_mean
         chol_variational_covar = chol_variational_covar
