In [1]:
import torch

from linear_operator.operators import RootLinearOperator, DiagLinearOperator

In [2]:
# Seed the global RNG so that Restart & Run All draws the same random
# matrices (and therefore produces the same results) on every run.
torch.manual_seed(0)

# Tall 1500 x 300 factor wrapped in a RootLinearOperator.
root = torch.randn(1500, 300)
root_lo = RootLinearOperator(root)

In [3]:
# Size the diagonal from the root's row count instead of repeating the
# literal 1500, so the two operands of the later sum cannot drift apart.
diag = torch.rand(root.shape[0])
diag_lo = DiagLinearOperator(diag)

In [4]:
# Add the two operators; the result stays an operator object rather than
# a dense tensor (its concrete class is displayed a few cells below).
sum_lo = root_lo + diag_lo

In [5]:
# Materialize the summed operator as a dense tensor for inspection.
sum_lo.to_dense()

tensor([[ 2.8139e+02, -5.8629e+01, -2.5748e-01,  ...,  1.3839e+01,
         -4.2263e+00, -3.0929e+01],
        [-5.8629e+01,  3.0433e+02,  1.2438e+01,  ..., -2.0872e+01,
         -1.5603e+01,  2.1450e+01],
        [-2.5748e-01,  1.2438e+01,  2.8209e+02,  ...,  1.8188e+01,
         -3.2094e+00,  2.6696e+00],
        ...,
        [ 1.3839e+01, -2.0872e+01,  1.8188e+01,  ...,  2.8850e+02,
          7.8028e+00, -5.2010e+00],
        [-4.2263e+00, -1.5603e+01, -3.2094e+00,  ...,  7.8028e+00,
          2.6911e+02, -6.0806e+00],
        [-3.0929e+01,  2.1450e+01,  2.6696e+00,  ..., -5.2010e+00,
         -6.0806e+00,  2.4189e+02]])

In [6]:
# Reference: densify each operand separately and add them; this should
# display the same values as the densified sum above.
root_lo.to_dense() + diag_lo.to_dense()

tensor([[ 2.8139e+02, -5.8629e+01, -2.5748e-01,  ...,  1.3839e+01,
         -4.2263e+00, -3.0929e+01],
        [-5.8629e+01,  3.0433e+02,  1.2438e+01,  ..., -2.0872e+01,
         -1.5603e+01,  2.1450e+01],
        [-2.5748e-01,  1.2438e+01,  2.8209e+02,  ...,  1.8188e+01,
         -3.2094e+00,  2.6696e+00],
        ...,
        [ 1.3839e+01, -2.0872e+01,  1.8188e+01,  ...,  2.8850e+02,
          7.8028e+00, -5.2010e+00],
        [-4.2263e+00, -1.5603e+01, -3.2094e+00,  ...,  7.8028e+00,
          2.6911e+02, -6.0806e+00],
        [-3.0929e+01,  2.1450e+01,  2.6696e+00,  ..., -5.2010e+00,
         -6.0806e+00,  2.4189e+02]])

In [7]:
# Show the repr to see which concrete operator class the sum produced.
sum_lo

<linear_operator.operators.low_rank_plus_diag_linear_operator.LowRankPlusDiagLinearOperator at 0x7fad6d9ddb50>

In [8]:
# Random vector with one entry per row of the operator; it is the point at
# which the log-density is evaluated further down.
rhs = torch.randn(diag.size(0))

Note that this Cholesky is currently an unoptimized Python for loop. I'd assume there is a LAPACK-style routine that could replace my pure-Python loop.

In [9]:
# Time the Cholesky factorization computed through the operator itself.
%time lazy_cholesky = sum_lo.cholesky()

CPU times: user 3.55 s, sys: 330 ms, total: 3.88 s
Wall time: 3.77 s


In [10]:
dense_sum = sum_lo.to_dense()

%time eager_cholesky = dense_sum.cholesky()

CPU times: user 21.7 ms, sys: 4.44 ms, total: 26.2 ms
Wall time: 45.3 ms


L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  /Users/distiller/project/conda/conda-bld/pytorch_1623459064158/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  """Entry point for launching an IPython kernel.


In [11]:
# Norm of the difference between the operator-based and the dense Cholesky
# factors (an absolute error, not a relative one).
torch.norm(lazy_cholesky - eager_cholesky)

tensor(0.0145)

This is pretty accurate overall: the absolute difference between the two factors is small, and the error relative to the norm of the factors will be much smaller still.

Just to make sure I didn't do something weird, here is the timing of a dense Cholesky on a random matrix of the same size.

In [12]:
a = torch.randn(diag.shape[0], diag.shape[0])
%time (a.matmul(a.t())).cholesky()

CPU times: user 87.9 ms, sys: 6.38 ms, total: 94.3 ms
Wall time: 102 ms


tensor([[37.2424,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2050, 39.2814,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4376, -1.5358, 38.1718,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.3922, -0.8449, -0.2205,  ...,  1.7147,  0.0000,  0.0000],
        [-0.3170,  0.3606, -1.5902,  ...,  0.1500,  0.5359,  0.0000],
        [-1.2030,  0.3276,  0.6875,  ...,  1.1023, -0.3045,  0.9772]])

In [13]:
from torch.distributions import MultivariateNormal

In [14]:
# Zero mean vector with one entry per element of the diagonal.
mean = torch.zeros(diag.size(0))

However, as expected, there is a speedup below once we get to re-use the caches.

In [15]:
def dist_and_log_prob(covar, loc=None, value=None):
    """Evaluate the log-density of ``value`` under ``N(loc, covar)``.

    Args:
        covar: Covariance passed to ``MultivariateNormal`` as
            ``covariance_matrix``. The notebook calls this with both the
            ``sum_lo`` operator and its dense counterpart ``dense_sum``.
        loc: Mean vector. When omitted, the notebook-level ``mean`` is used.
        value: Point at which the density is evaluated. When omitted, the
            notebook-level ``rhs`` is used.

    Returns:
        The result of ``MultivariateNormal.log_prob`` at ``value``.
    """
    # Fall back to the notebook globals only when no explicit argument is
    # given, so existing one-argument calls behave exactly as before.
    if loc is None:
        loc = mean
    if value is None:
        value = rhs
    dist = MultivariateNormal(loc, covariance_matrix=covar)
    return dist.log_prob(value)

In [16]:
# Time the log-probability with the structured operator as the covariance.
%time dist_and_log_prob(sum_lo)

CPU times: user 4.81 ms, sys: 4.52 ms, total: 9.32 ms
Wall time: 21 ms


tensor(-4165.1772)

In [17]:

# Time the same log-probability with the dense covariance for comparison.
%time dist_and_log_prob(dense_sum)

CPU times: user 54.7 ms, sys: 9 ms, total: 63.7 ms
Wall time: 60.5 ms


tensor(-4165.2920)