In [1]:
import torch

from linear_operator.operators import RootLinearOperator, DiagLinearOperator

In [2]:
root = torch.randn(5000, 10)
root_lo = RootLinearOperator(root)

In [3]:
diag = torch.rand(5000)
diag_lo = DiagLinearOperator(diag)

In [4]:
sum_lo = root_lo + diag_lo

In [5]:
sum_lo.to_dense()

tensor([[ 8.4470,  0.4707,  2.9372,  ...,  3.7270,  6.0950,  4.3886],
        [ 0.4707, 11.7690,  4.3348,  ..., -2.2922,  3.1082, -5.4692],
        [ 2.9372,  4.3348,  9.2938,  ...,  0.9028, -0.0155, -0.3508],
        ...,
        [ 3.7270, -2.2922,  0.9028,  ..., 14.5677,  3.2758,  3.7564],
        [ 6.0950,  3.1082, -0.0155,  ...,  3.2758,  9.3680,  1.5262],
        [ 4.3886, -5.4692, -0.3508,  ...,  3.7564,  1.5262, 10.4140]])

In [6]:
root_lo.to_dense() + diag_lo.to_dense()

tensor([[ 8.4470,  0.4707,  2.9372,  ...,  3.7270,  6.0950,  4.3886],
        [ 0.4707, 11.7690,  4.3348,  ..., -2.2922,  3.1082, -5.4692],
        [ 2.9372,  4.3348,  9.2938,  ...,  0.9028, -0.0155, -0.3508],
        ...,
        [ 3.7270, -2.2922,  0.9028,  ..., 14.5677,  3.2758,  3.7564],
        [ 6.0950,  3.1082, -0.0155,  ...,  3.2758,  9.3680,  1.5262],
        [ 4.3886, -5.4692, -0.3508,  ...,  3.7564,  1.5262, 10.4140]])

In [7]:
sum_lo

<linear_operator.operators.low_rank_plus_diag_linear_operator.LowRankPlusDiagLinearOperator at 0x7fe37b31a3d0>

In [8]:
rhs = torch.randn(diag.shape[0])

So, this cholesky is an unoptimized for loop currently. I'd assume that there's probably a LAPACK-style function instead of my python only for loop.

In [9]:
%time lazy_cholesky = sum_lo.cholesky()

CPU times: user 1.29 s, sys: 170 ms, total: 1.46 s
Wall time: 1.53 s


In [10]:
dense_sum = sum_lo.to_dense()

%time eager_cholesky = dense_sum.cholesky()

CPU times: user 551 ms, sys: 8.49 ms, total: 559 ms
Wall time: 579 ms


L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  /Users/distiller/project/conda/conda-bld/pytorch_1623459064158/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  """Entry point for launching an IPython kernel.


In [11]:
torch.norm(lazy_cholesky - eager_cholesky)

tensor(0.0113)

This is pretty accurate overall. The relative error is going to be much smaller.

Just to ensure that I didn't actually do something weird.

In [12]:
a = torch.randn(diag.shape[0], diag.shape[0])
%time (a.matmul(a.t())).cholesky()

RuntimeError: cholesky: U(5000,5000) is zero, singular U.

In [18]:
from torch.distributions import MultivariateNormal

In [19]:
mean = torch.zeros(diag.shape[0])

But, there's a speedup below when we get to re-use the caches as expected.

In [20]:
def dist_and_log_prob(covar):
    dist = MultivariateNormal(mean, covariance_matrix = covar)
    return dist.log_prob(rhs)

In [21]:
%time dist_and_log_prob(sum_lo)

CPU times: user 28.1 ms, sys: 11.9 ms, total: 40 ms
Wall time: 35.4 ms


tensor(-21202.5703)

In [22]:

%time dist_and_log_prob(dense_sum)

CPU times: user 1.33 s, sys: 30.5 ms, total: 1.36 s
Wall time: 1.26 s


tensor(-21200.1191)