Skip to content

Commit

Permalink
add the unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wjmaddox committed Jan 19, 2021
1 parent da3cab3 commit 90d9f97
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
24 changes: 21 additions & 3 deletions gpytorch/lazy/kronecker_product_added_diag_lazy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ def _logdet(self):
first_term = (const_times_evals.diag() + 1).log().sum(dim=-1)
return diag_term + first_term

else:
# we use the same matrix determinant identity: |K + D| = |D| |I + D^{-1}K|
# but have to symmetrize the second matrix because torch.eig may not be
# completely differentiable.
kron_times_diag_list = [
tfull.matmul(tdiag.inverse()) for tfull, tdiag in zip(lt.lazy_tensors, dlt.lazy_tensors)
]

# We symmetrize the sub-components K_i D_i^{-1}
kron_times_diag_symm = KroneckerProductLazyTensor(
*[k.matmul(k.transpose(-1, -2)) for k in kron_times_diag_list]
)
evals_square, _ = kron_times_diag_symm.symeig(eigenvectors=True)
evals_plus_i = DiagLazyTensor(evals_square.sqrt() + 1)

diag_term = self.diag_tensor.diag().clamp(min=1e-7).log().sum(dim=-1)
return diag_term + evals_plus_i.sum(dim=-1)

return super()._logdet()

def _preconditioner(self):
Expand Down Expand Up @@ -133,18 +151,18 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
return res.to(rhs_dtype)

# in all other cases we fall back to the default
super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)
return super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)

def _root_decomposition(self):
    """Return R such that R R^T = K + D (a root of this lazy tensor).

    When the added diagonal D is a constant multiple of the identity
    (``self._diag_is_constant``), the eigendecomposition of the Kronecker
    product K = Q E Q^T gives the closed form root
    Q (E + d I)^{1/2}, avoiding Lanczos entirely.
    """
    if self._diag_is_constant:
        # Eigenvalues of K + dI are evals + d; eigenvectors are unchanged,
        # so a root is Q diag(evals + d)^{1/2}.
        evals, q_matrix = self.lazy_tensor.symeig(eigenvectors=True)
        updated_evals = DiagLazyTensor((evals + self.diag_tensor.diag()).pow(0.5))
        return MatmulLazyTensor(q_matrix, updated_evals)
    # Non-constant diagonal: defer to the generic implementation. The
    # original dropped this return and implicitly returned None.
    return super()._root_decomposition()

def _root_inv_decomposition(self, initial_vectors=None):
    """Return R such that R R^T = (K + D)^{-1} (a root of the inverse).

    Mirrors ``_root_decomposition``: for a constant added diagonal the
    eigendecomposition K = Q E Q^T yields the closed form inverse root
    Q (E + d I)^{-1/2}, with no Cholesky or Lanczos needed.

    :param initial_vectors: probe vectors forwarded to the generic
        fallback; unused on the constant-diagonal fast path.
    """
    if self._diag_is_constant:
        evals, q_matrix = self.lazy_tensor.symeig(eigenvectors=True)
        # (evals + d)^{-1/2} — assumes evals + d > 0, i.e. K + dI is PD.
        inv_sqrt_evals = DiagLazyTensor((evals + self.diag_tensor.diag()).pow(-0.5))
        return MatmulLazyTensor(q_matrix, inv_sqrt_evals)
    # Non-constant diagonal: defer to the generic implementation. The
    # original dropped this return and implicitly returned None.
    return super()._root_inv_decomposition(initial_vectors=initial_vectors)
12 changes: 8 additions & 4 deletions test/lazy/test_kronecker_product_added_diag_lazy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
class TestKroneckerProductAddedDiagLazyTensor(unittest.TestCase, LazyTensorTestCase):
# this lazy tensor has an explicit inverse so we don't need to run these
skip_slq_tests = True
should_call_lanczos = False
should_call_cg = False

def create_lazy_tensor(self, constant_diag=True):
def create_lazy_tensor(self, constant_diag=False):
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
c = torch.tensor([[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float)
Expand All @@ -43,6 +41,13 @@ def evaluate_lazy_tensor(self, lazy_tensor):
diag = lazy_tensor._diag_tensor._diag
return tensor + diag.diag()

class TestKroneckerProductAddedConstDiagLazyTensor(TestKroneckerProductAddedDiagLazyTensor):
    """Re-run the Kronecker-plus-diagonal test suite with a constant diagonal.

    A constant diagonal admits eigen-based closed forms, so neither CG nor
    Lanczos should ever be invoked on this fixture.
    """

    should_call_cg = False
    should_call_lanczos = False

    def create_lazy_tensor(self):
        # Identical fixture to the parent class, but with constant_diag on.
        return super().create_lazy_tensor(constant_diag=True)

def test_if_cholesky_used(self):
lazy_tensor = self.create_lazy_tensor()
rhs = torch.randn(lazy_tensor.size(-1))
Expand All @@ -63,6 +68,5 @@ def test_root_inv_decomposition_no_cholesky(self):
self.assertAllClose(res, actual, rtol=0.05, atol=0.02)
chol_mock.assert_not_called()


if __name__ == "__main__":
unittest.main()

0 comments on commit 90d9f97

Please sign in to comment.