## Preserving sparsity of matrix inverse square root

The posterior joint distribution of a Gaussian probabilistic graphical models has a sparse precision matrix $P$, with sparsity structure determined by the dependency graph. In plated graphical models, the sparsity structure is blockwise, allowing within-block use of fast dense linear algebra operations on CPU and GPU. During Bayesian inference, we need to compute the inverse square root matrix $P^{-1/2}$, or at least compute a Cholesky factor $P^{1/2}$ and solve $P^{1/2} \backslash z$ for white noise vectors $z$.

This notebook explores sparsity preserving representations of precision matrices, Cholesky factors, and Cholesky factor inverses. By contrast, a naive Cholesky decomposition does not preserve sparsity structure, leading to so-called fill-in elements.

In [1]:
import operator
from functools import reduce
import numpy as np
import matplotlib.pyplot as plt
import torch

torch.manual_seed(20210920)
torch.set_default_dtype(torch.double)
np.set_printoptions(precision=3, suppress=True)

Consider a representation that decomposes a Cholesky factor into a product of lower-triangular matrices each of which is the identity on all except one row, e.g.
$$
\begin{pmatrix}
a & 0 & 0 \\ b & c & 0 \\ d & e & f
\end{pmatrix}
= 
\begin{pmatrix}
1 & 0 & 0 \\ 0 & 1 & 0 \\ d & e & f
\end{pmatrix}
\times
\begin{pmatrix}
1 & 0 & 0 \\ b & c & 0 \\ 0 & 0 & 1
\end{pmatrix}
\times
\begin{pmatrix}
a & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1
\end{pmatrix}
$$

In [2]:
def triangular_decompose(L):
    assert L.size(-1) == L.size(-2)
    assert (L == L.tril()).all()
    N, N = L.shape
    factors = []
    for i, row in enumerate(L):
        Li = torch.eye(N)
        Li[i] = L[i]
        factors.append(Li)
    return factors

In [3]:
N = 3
L = torch.randn(N, N).tril_()
L.diag().exp_()
print("L:", L.numpy(), sep="\n")
factors = triangular_decompose(L)
for i, Li in enumerate(factors):
    print(f"L_{i}", Li.numpy(), sep="\n")
prod = reduce(operator.matmul, factors)
assert torch.allclose(prod, L)
print("product(factors):", prod.numpy(), sep="\n")

L:
[[-1.061  0.     0.   ]
 [-0.668  1.006  0.   ]
 [-0.055 -0.861  0.878]]
L_0
[[-1.061  0.     0.   ]
 [ 0.     1.     0.   ]
 [ 0.     0.     1.   ]]
L_1
[[ 1.     0.     0.   ]
 [-0.668  1.006  0.   ]
 [ 0.     0.     1.   ]]
L_2
[[ 1.     0.     0.   ]
 [ 0.     1.     0.   ]
 [-0.055 -0.861  0.878]]
product(factors):
[[-1.061  0.     0.   ]
 [-0.668  1.006  0.   ]
 [-0.055 -0.861  0.878]]


Now note that each factor can be inverted and/or squared while preserving sparsity structure.

In [4]:
N = 8
R = 5  # row
L = torch.eye(N)
L[R, :R].normal_()
L[R, R].abs_()
print("L:", L.numpy(), sep="\n")
print("L.inverse():", L.inverse().numpy(), sep="\n")
print("L @ L.T:", (L @ L.T).numpy(), sep="\n")

L:
[[ 1.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.     0.   ]
 [-0.076 -1.178 -1.087  1.198  0.158  1.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L.inverse():
[[ 1.     0.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     0.     1.    -0.     0.     0.    -0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.    -0.   ]
 [ 0.076  1.178  1.087 -1.198 -0.158  1.     0.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L @ L.T:
[[ 1.     0.     0.     0.   

In [5]:
L = torch.eye(N)
L[R, 2:R].normal_()
L[R, R].abs_()
print("L:", L.numpy(), sep="\n")
print("L.inverse():", L.inverse().numpy(), sep="\n")
print("L @ L.T:", (L @ L.T).numpy(), sep="\n")

L:
[[ 1.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.     0.   ]
 [ 0.     0.    -1.819 -0.3    0.819  1.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L.inverse():
[[ 1.     0.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.    -0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.    -0.   ]
 [ 0.     0.     1.819  0.3   -0.819  1.     0.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L @ L.T:
[[ 1.     0.     0.     0.   

In [6]:
L = torch.eye(N)
L[R, :2].normal_()
L[R, R].normal_().abs_()
print("L:", L.numpy(), sep="\n")
print("L.inverse():", L.inverse().numpy(), sep="\n")
print("L @ L.T:", (L @ L.T).numpy(), sep="\n")

L:
[[ 1.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.     0.   ]
 [-0.728 -0.44   0.     0.     0.     0.54   0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L.inverse():
[[ 1.     0.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     1.     0.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     1.     0.     0.     0.     0.    -0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.    -0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.    -0.   ]
 [ 1.348  0.815  0.     0.     0.     1.852  0.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     1.    -0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     1.   ]]
L @ L.T:
[[ 1.     0.     0.     0.   

Next consider blockwise sparsity structure.

In [7]:
def triangular_decompose_block(L, B):
    assert L.size(-1) == L.size(-2)
    assert (L == L.tril()).all()
    N, N = L.shape
    assert N % B == 0
    factors = []
    for i in range(N // B):
        Li = torch.eye(N)
        Li[i * B: (i + 1) * B] = L[i * B: (i + 1) * B]
        factors.append(Li)
    return factors

In [8]:
N = 9
B = 3
L = torch.randn(N, N).tril_()
L.diag().exp_()
print("L:", L.numpy(), sep="\n")
factors = triangular_decompose_block(L, 3)
for i, Li in enumerate(factors):
    print(f"L_{i}", Li.numpy(), sep="\n")
prod = reduce(operator.matmul, factors)
assert torch.allclose(prod, L)
print("product(factors):", prod.numpy(), sep="\n")

L:
[[ 1.078  0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 1.416 -1.02   0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.871  0.042  1.045  0.     0.     0.     0.     0.     0.   ]
 [-1.881 -0.185  0.715  0.307  0.     0.     0.     0.     0.   ]
 [-1.424  0.892 -1.504 -1.043  0.224  0.     0.     0.     0.   ]
 [ 0.641  0.067 -0.551 -1.975  0.153 -1.462  0.     0.     0.   ]
 [ 0.577 -0.062  0.688  0.339 -0.068  0.825  1.212  0.     0.   ]
 [ 0.852  0.013  0.693 -0.594 -0.465 -0.733 -0.366  0.588  0.   ]
 [ 0.725  0.485 -0.139  0.05   0.902  0.404  0.199 -1.202 -0.567]]
L_0
[[ 1.078  0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 1.416 -1.02   0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.871  0.042  1.045  0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     1.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     1.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     1.     0.     0.     0.   ]
 [

Next we'll see if we can represent a sparse precision matrix as compressed inverse Cholesky factors.

In [9]:
def make_precision(size, density):
    mask = torch.full((size, size), density).bernoulli_()
    x = torch.zeros(size, size)
    for i, row in enumerate(mask):
        for j, m in enumerate(row[:i].tolist()):
            if m:
                xij = torch.randn(2, 2)
                xij = xij @ xij.T
                x[i, i] += xij[0, 0]
                x[i, j] += xij[0, 1]
                x[j, i] += xij[1, 0]
                x[j, j] += xij[1, 1]
    return x

In [10]:
torch.manual_seed(20210920)
precision = make_precision(9, 0.0)
print("Precision:", precision.numpy(), sep="\n")
print("Cholesky:", torch.cholesky(precision).numpy(), sep="\n")
print("inv(Cholesky):", torch.cholesky(precision).inverse().numpy(), sep="\n")

Precision:
[[ 7.716 -2.364  0.     1.228  0.     0.112  0.     0.753  0.   ]
 [-2.364 14.85   1.932  0.    -1.255  0.     0.     3.182  0.   ]
 [ 0.     1.932  8.682  0.841  0.    -0.169  0.     0.187  0.788]
 [ 1.228  0.     0.841  4.663  0.039  0.    -0.747  0.     0.   ]
 [ 0.    -1.255  0.     0.039  7.817  0.    -0.586 -0.054  0.   ]
 [ 0.112  0.    -0.169  0.     0.     5.544  0.     0.273  0.   ]
 [ 0.     0.     0.    -0.747 -0.586  0.     7.006 -0.306  0.   ]
 [ 0.753  3.182  0.187  0.    -0.054  0.273 -0.306 11.761  0.422]
 [ 0.     0.     0.788  0.     0.     0.     0.     0.422  2.606]]
Cholesky:
[[ 2.778  0.     0.     0.     0.     0.     0.     0.     0.   ]
 [-0.851  3.758  0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.514  2.901  0.     0.     0.     0.     0.     0.   ]
 [ 0.442  0.1    0.272  2.094  0.     0.     0.     0.     0.   ]
 [ 0.    -0.334  0.059  0.027  2.775  0.     0.     0.     0.   ]
 [ 0.04   0.009 -0.06  -0.001  0.002  2.353  0.     0.

L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  after removing the cwd from sys.path.


Note that the Cholesky matrix and its inverse do not fully preserve sparsity.

In [11]:
def decompress(param):
    factors = triangular_decompose(param)
    L = reduce(operator.matmul, factors)
    cov = L @ L.T
    precision = torch.inverse(cov)
    return precision

In [12]:
mask = (precision != 0).type_as(precision).tril()
print(mask.numpy())

[[1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 1. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0. 0. 0.]
 [1. 0. 1. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 1. 0. 0.]
 [1. 1. 1. 0. 1. 1. 1. 1. 0.]
 [0. 0. 1. 0. 0. 0. 0. 1. 1.]]


In [13]:
param = torch.eye(len(precision)).requires_grad_()
optim = torch.optim.Adam([param], lr=0.02)
for _ in range(1000):
    optim.zero_grad()
    loss = (decompress(param * mask) - precision).square().sum()
    loss.backward()
    optim.step()
    print(".", end="", flush=True)

........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

In [14]:
print("param:", param.data.numpy(), sep="\n")
print("reconstructed precision:", decompress(param.data * mask), sep="\n")

param:
[[ 0.381  0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.07   0.273  0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.    -0.057  0.346  0.     0.     0.     0.     0.     0.   ]
 [-0.102  0.    -0.06   0.466  0.     0.     0.     0.     0.   ]
 [ 0.     0.042  0.     0.003  0.359  0.     0.     0.     0.   ]
 [-0.006  0.     0.011  0.     0.     0.425  0.     0.     0.   ]
 [ 0.     0.     0.     0.045  0.03   0.     0.378  0.     0.   ]
 [-0.043 -0.072 -0.003  0.     0.003 -0.01   0.009  0.292  0.   ]
 [ 0.     0.    -0.097  0.     0.     0.     0.    -0.04   0.619]]
reconstructed precision:
tensor([[ 7.7164e+00, -2.3306e+00, -1.0719e-01,  1.2156e+00,  2.1680e-01,
          1.0917e-01, -2.0741e-01,  7.4231e-01, -5.6066e-03],
        [-2.3306e+00,  1.4863e+01,  1.8746e+00,  1.7613e-01, -1.2131e+00,
          3.3626e-02, -1.0802e-02,  3.1688e+00,  2.4833e-01],
        [-1.0719e-01,  1.8746e+00,  8.6978e+00,  8.1427e-01,  3.1551e-04,
         -1.7449e-01, -1.21

In [15]:
print((param @ param.T).data.numpy())

[[ 0.145  0.027  0.    -0.039  0.    -0.002  0.    -0.016  0.   ]
 [ 0.027  0.079 -0.016 -0.007  0.011 -0.     0.    -0.023  0.   ]
 [ 0.    -0.016  0.123 -0.021 -0.002  0.004  0.     0.003 -0.034]
 [-0.039 -0.007 -0.021  0.231  0.001 -0.     0.021  0.005  0.006]
 [ 0.     0.011 -0.002  0.001  0.13   0.     0.011 -0.002  0.   ]
 [-0.002 -0.     0.004 -0.     0.     0.181  0.    -0.004 -0.001]
 [ 0.     0.     0.     0.021  0.011  0.     0.146  0.003  0.   ]
 [-0.016 -0.023  0.003  0.005 -0.002 -0.004  0.003  0.093 -0.011]
 [ 0.     0.    -0.034  0.006  0.    -0.001  0.    -0.011  0.395]]
