In [1]:
import torch

from linear_operator.operators import RootLinearOperator, DiagLinearOperator

In [2]:
root = torch.randn(1500, 300)
root_lo = RootLinearOperator(root)

In [3]:
diag = torch.rand(1500)
diag_lo = DiagLinearOperator(diag)

In [4]:
sum_lo = root_lo + diag_lo

In [5]:
sum_lo.to_dense()

tensor([[334.7018,   2.6060,   7.5849,  ...,   3.3292,  -9.6117,  -2.6092],
        [  2.6060, 274.0935, -14.9300,  ...,  18.8053, -23.6799, -11.7659],
        [  7.5849, -14.9300, 304.6841,  ..., -14.5315,  24.8336,  22.5867],
        ...,
        [  3.3292,  18.8053, -14.5315,  ..., 286.3884,  -0.9479, -22.3044],
        [ -9.6117, -23.6799,  24.8336,  ...,  -0.9479, 285.7664,   4.3336],
        [ -2.6092, -11.7659,  22.5867,  ..., -22.3044,   4.3336, 289.7219]])

In [6]:
root_lo.to_dense() + diag_lo.to_dense()

tensor([[334.7018,   2.6060,   7.5849,  ...,   3.3292,  -9.6117,  -2.6092],
        [  2.6060, 274.0935, -14.9300,  ...,  18.8053, -23.6799, -11.7659],
        [  7.5849, -14.9300, 304.6841,  ..., -14.5315,  24.8336,  22.5867],
        ...,
        [  3.3292,  18.8053, -14.5315,  ..., 286.3884,  -0.9479, -22.3044],
        [ -9.6117, -23.6799,  24.8336,  ...,  -0.9479, 285.7664,   4.3336],
        [ -2.6092, -11.7659,  22.5867,  ..., -22.3044,   4.3336, 289.7219]])

In [7]:
sum_lo

<linear_operator.operators.low_rank_plus_diag_linear_operator.LowRankPlusDiagLinearOperator at 0x7fb7396fa590>

In [8]:
rhs = torch.randn(diag.shape[0])

So, this cholesky is an unoptimized for loop currently. I'd assume that there's probably a LAPACK-style function instead of my python only for loop.

In [9]:
%time lazy_cholesky = sum_lo.cholesky()

CPU times: user 32.1 s, sys: 298 ms, total: 32.4 s
Wall time: 32.5 s


In [10]:
dense_sum = sum_lo.to_dense()

%time eager_cholesky = dense_sum.cholesky()

CPU times: user 22.9 ms, sys: 2.24 ms, total: 25.1 ms
Wall time: 25.3 ms


L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  /Users/distiller/project/conda/conda-bld/pytorch_1623459064158/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  """Entry point for launching an IPython kernel.


In [11]:
torch.norm(lazy_cholesky - eager_cholesky)

tensor(0.0136)

This is pretty accurate overall. The relative error is going to be much smaller.

Just to ensure that I didn't actually do something weird.

In [12]:
a = torch.randn(diag.shape[0], diag.shape[0])
%time (a.matmul(a.t())).cholesky()

CPU times: user 85.2 ms, sys: 5.89 ms, total: 91.1 ms
Wall time: 91.8 ms


tensor([[ 3.8059e+01,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-4.7665e-01,  3.8670e+01,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 7.1737e-01,  8.5909e-02,  3.8566e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [-2.9369e-01,  2.0791e-01,  4.1217e-01,  ...,  2.5676e-01,
          0.0000e+00,  0.0000e+00],
        [-1.1586e+00,  1.0562e+00,  3.3495e-01,  ..., -5.3565e-01,
          1.6429e+00,  0.0000e+00],
        [ 5.9293e-02, -1.3554e+00, -2.0754e+00,  ..., -1.7443e+00,
          7.1475e-03,  5.4692e-01]])

In [13]:
from torch.distributions import MultivariateNormal

In [14]:
mean = torch.zeros(diag.shape[0])

But, there's a speedup below when we get to re-use the caches as expected.

In [15]:
def dist_and_log_prob(covar):
    dist = MultivariateNormal(mean, covariance_matrix = covar)
    return dist.log_prob(rhs)

In [16]:
%time dist_and_log_prob(sum_lo)

CPU times: user 11.7 ms, sys: 2.82 ms, total: 14.5 ms
Wall time: 11.1 ms


tensor(-3976.6140)

In [17]:

%time dist_and_log_prob(dense_sum)

CPU times: user 47 ms, sys: 8.36 ms, total: 55.3 ms
Wall time: 50.2 ms


tensor(-3976.2256)