In [59]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import cho_factor as scipy_chol_fac
from scipy.linalg import cho_solve as scipy_chol_solve

In [60]:
# Generate a random matrix A
n = 100
A = np.random.normal(size=(n,n))
A = A.T @ A + 1e-3

In [61]:
chol = scipy_chol_fac(A)

In [62]:
z = np.ones(n)

In [63]:
scipy_chol_solve(chol, z) - np.linalg.solve(A, z)

array([-4.01900735e-11, -1.20978783e-11,  5.56887869e-12, -2.55132582e-11,
        4.97824004e-13,  4.13105106e-11, -4.95212760e-11,  1.42055256e-11,
        1.13811183e-11,  7.39852624e-11,  3.29514194e-12,  1.73905335e-12,
       -2.40896192e-12,  1.62181379e-12,  5.83746385e-11,  2.88604696e-11,
        7.20972171e-11,  9.81970061e-12,  2.38293829e-11, -3.02016190e-11,
       -5.32018873e-11, -2.57740496e-11, -8.83151330e-11, -7.77156117e-12,
       -2.10382822e-11, -1.80229165e-11, -4.74393858e-11, -2.88595814e-11,
        5.17843546e-11, -1.71311854e-11,  9.07718345e-13,  5.45981038e-11,
       -1.20508048e-11, -3.67137432e-11, -1.58539848e-12,  2.22204477e-11,
       -2.99156255e-11, -2.91384694e-11, -3.31841221e-11,  1.80446769e-11,
       -8.91056118e-11,  3.30349081e-11,  2.00213179e-11, -6.42685904e-11,
       -8.68283223e-12, -3.02945447e-11,  5.68540770e-11,  4.92228480e-11,
       -2.58877364e-11,  1.21795907e-11, -1.77706738e-11,  4.07140988e-12,
       -5.49071899e-12, -

In [64]:
def naive_diaginv(A, sample_size=1000, method="cholesky"):
    """Naive unbiased estimator for the diagonal of an inverse matrix, see [5]. A must be SPD.
    """

    valid_methods = ["cholesky"]
    assert method in valid_methods, f"method must be one of {valid_methods}"

    # Get shape
    n = A.shape[0]

    # Setup
    diaginv_estimate = np.zeros(n)
    tk = np.zeros(n)
    qk = np.zeros(n)

    if method == "cholesky":
        chol = scipy_chol_fac(A)
    else:
        raise NotImplementedError

    for j in range(sample_size):
        
        # Draw random vector
        vk = np.random.choice([-1, 1], size=n)

        # Update tk
        if method == "cholesky":
            Ainv_vk = scipy_chol_solve(chol, vk)
        else:
            raise NotImplementedError
    
        tk = tk + (( Ainv_vk ) * vk)

        # Update qk
        qk = qk + (vk*vk)

        # Update diag_estimate
        diaginv_estimate = tk / qk

    return diaginv_estimate

In [65]:
true_diaginv = np.diag(np.linalg.inv(A))

In [66]:
true_diaginv[:5]

array([8.57586596, 0.7267294 , 0.85203419, 2.76756246, 0.2214591 ])

In [67]:
diaginv_est = naive_diaginv(A, sample_size=10000)

In [68]:
diaginv_est[:5]

array([8.45785869, 0.7514646 , 0.8901894 , 2.54556302, 0.29175854])

# Explicit diag inverse probe

In [69]:
def explicit_diaginv_probe(A, method="cholesky"):
    """Computes the diagonal of inv(A) using an explicit probe. A must be SPD.
    """

    valid_methods = ["cholesky"]
    assert method in valid_methods, f"method must be one of {valid_methods}"

    # Setup
    n = A.shape[0]
    diagonal_inv = np.zeros(n)

    if method == "cholesky":
        chol = scipy_chol_fac(A)
    else:
        raise NotImplementedError

    for j in range(n):

        # jth column of the identity
        w = np.zeros(n)
        w[j] = 1.0

        # Compute w^T inv(A) w
        if method == "cholesky":
            Ainv_w = scipy_chol_solve(chol, w)
        else:
            raise NotImplementedError

        diagonal_inv[j] = w.T @ Ainv_w

    return diagonal_inv

In [70]:
true_diaginv

array([ 8.57586596,  0.7267294 ,  0.85203419,  2.76756246,  0.2214591 ,
        7.20166679,  9.1764802 ,  0.93558599,  1.04791352, 20.07114549,
        1.19428919,  0.32162996,  0.46371162,  0.51850962, 12.66586429,
        3.18986117, 20.71376633,  1.05068503,  2.58187642,  3.59578936,
       10.51153871,  2.94140851, 29.81763933,  1.02600227,  2.55142791,
        5.37002981, 10.93699529,  3.3122536 ,  9.99018729,  1.58450314,
        0.46893677, 11.53063194,  1.38695438,  5.99991168,  0.64716075,
        1.90461363,  4.10883592,  4.04324138,  4.55086426,  1.96278145,
       29.01828566,  4.18241753,  1.54951477, 15.30661395,  0.80910166,
        4.12588324, 14.0889585 ,  8.9417561 ,  3.17568107,  0.83490542,
        1.67432212,  0.78163145,  1.04055025,  2.06221432,  1.56951583,
        1.70784948,  0.70855028,  1.89357841,  1.98412581,  1.59130454,
        9.43910867,  0.75570913, 13.78749808,  4.93897568,  5.20844652,
        4.90017951,  1.05144977,  0.79222167,  8.69030863,  5.25

In [71]:
explicit_diaginv_probe(A)

array([ 8.57586596,  0.7267294 ,  0.85203419,  2.76756246,  0.2214591 ,
        7.20166679,  9.1764802 ,  0.93558599,  1.04791352, 20.07114549,
        1.19428919,  0.32162996,  0.46371162,  0.51850962, 12.66586429,
        3.18986117, 20.71376633,  1.05068503,  2.58187642,  3.59578936,
       10.51153871,  2.94140851, 29.81763933,  1.02600227,  2.55142791,
        5.37002981, 10.93699529,  3.3122536 ,  9.99018729,  1.58450314,
        0.46893677, 11.53063194,  1.38695438,  5.99991168,  0.64716075,
        1.90461363,  4.10883592,  4.04324138,  4.55086426,  1.96278145,
       29.01828566,  4.18241753,  1.54951477, 15.30661395,  0.80910166,
        4.12588324, 14.0889585 ,  8.9417561 ,  3.17568107,  0.83490542,
        1.67432212,  0.78163145,  1.04055025,  2.06221432,  1.56951583,
        1.70784948,  0.70855028,  1.89357841,  1.98412581,  1.59130454,
        9.43910867,  0.75570913, 13.78749808,  4.93897568,  5.20844652,
        4.90017951,  1.05144977,  0.79222167,  8.69030863,  5.25