In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import cho_factor as scipy_chol_fac
from scipy.linalg import cho_solve as scipy_chol_solve

In [12]:
# Generate a random matrix A
n = 100
A = np.random.normal(size=(n,n))
A = A.T @ A + 1e-3

In [3]:
chol = scipy_chol_fac(A)

In [4]:
z = np.ones(n)

In [5]:
scipy_chol_solve(chol, z) - np.linalg.solve(A, z)

array([ 1.93622895e-12, -1.30899735e-11,  1.37809764e-11, -1.52178270e-12,
        4.22284430e-12, -4.83435514e-12, -1.01785247e-12, -6.04138961e-12,
        1.49231738e-11,  3.98259203e-11, -9.27613542e-12, -9.15889586e-12,
        6.88604729e-12,  1.54756208e-11, -1.14397380e-11,  8.64197602e-12,
       -3.75521836e-11, -1.10063070e-11,  3.04964942e-11,  2.12132534e-11,
        2.75122147e-11, -7.89945886e-12, -1.12088117e-12,  1.50421897e-11,
       -2.11137774e-11, -2.47224463e-12,  2.94697600e-12, -1.47135637e-11,
        3.00870440e-12, -7.06212866e-13,  6.67732536e-12,  1.47224455e-11,
        2.75051093e-11, -6.81926737e-13,  4.09983159e-11,  1.11040066e-11,
        7.35278505e-12, -1.39657175e-11, -1.94688710e-11, -7.81952281e-12,
        3.26636496e-11, -1.00719433e-11, -2.63042921e-11,  4.32454073e-12,
       -1.63886682e-11, -4.87521135e-12,  7.87014898e-12,  2.55724331e-11,
        1.32782674e-11, -1.39834810e-11,  6.85451695e-13, -9.67226299e-12,
        1.07238662e-11,  

In [6]:
def naive_diaginv(A, sample_size=1000, method="cholesky"):
    """Naive unbiased estimator for the diagonal of an inverse matrix, see [5]. A must be SPD.
    """

    valid_methods = ["cholesky"]
    assert method in valid_methods, f"method must be one of {valid_methods}"

    # Get shape
    n = A.shape[0]

    # Setup
    diaginv_estimate = np.zeros(n)
    tk = np.zeros(n)
    qk = np.zeros(n)

    if method == "cholesky":
        chol = scipy_chol_fac(A)
    else:
        raise NotImplementedError

    for j in range(sample_size):
        
        # Draw random vector
        vk = np.random.choice([-1, 1], size=n)

        # Update tk
        if method == "cholesky":
            Ainv_vk = scipy_chol_solve(chol, vk)
        else:
            raise NotImplementedError
    
        tk = tk + (( Ainv_vk ) * vk)

        # Update qk
        qk = qk + (vk*vk)

        # Update diag_estimate
        diaginv_estimate = tk / qk

    return diaginv_estimate

In [7]:
true_diaginv = np.diag(np.linalg.inv(A))

In [8]:
true_diaginv[:5]

array([0.79306065, 1.3395026 , 1.49225971, 0.34982241, 0.24981129])

In [9]:
diaginv_est = naive_diaginv(A, sample_size=10000)

In [10]:
diaginv_est[:5]

array([0.82199926, 1.24619743, 1.4366137 , 0.36932097, 0.38807513])

# Chatgpt method?

In [15]:
import numpy as np

def estimate_diagonal_inverse(A, n_iter=100):
    n = A.shape[0]
    x = np.random.rand(n)
    y = np.zeros(n)

    for _ in range(n_iter):
        z = np.dot(A, x)
        y += z * z
        z /= np.dot(x, z)
        x = z

    diagonal_inverse = 1.0 / y

    return diagonal_inverse

# Example usage
diagonal_inverse = estimate_diagonal_inverse(A)
print(diagonal_inverse)


[5.87924915e-06 1.43336100e-06 1.36202003e-06 6.93958530e-05
 3.53325113e-06 1.24925900e-05 6.60042141e-06 1.63201599e-06
 4.02721173e-07 1.58842288e-06 6.54198452e-07 1.55487764e-06
 3.34966332e-07 7.86422929e-07 6.22357304e-06 3.92576390e-06
 3.45962091e-07 1.19053432e-05 8.76829610e-07 1.52428887e-07
 1.68922239e-06 6.98286768e-06 4.28942883e-06 4.55346637e-07
 2.17170195e-07 1.08801063e-06 2.87356692e-06 7.44146409e-07
 6.30145268e-06 1.62683903e-07 5.46320155e-06 2.65952833e-07
 1.42308872e-06 2.89783742e-06 4.69089936e-07 3.79828858e-07
 8.83071997e-06 2.35935462e-07 9.28359808e-06 9.31397237e-07
 3.46448948e-07 1.82221484e-07 1.14236583e-06 4.32079970e-06
 2.40692082e-07 2.72933046e-06 1.84046088e-06 2.13368586e-07
 8.70258013e-07 2.19173150e-06 5.74187007e-06 4.76109485e-06
 2.43359422e-06 8.66781192e-06 4.90716465e-06 1.50850447e-05
 2.45040577e-06 7.60788657e-07 5.95052433e-07 9.99976608e-07
 7.04871893e-07 4.13553237e-07 2.07701181e-05 1.06241263e-06
 2.16597434e-06 6.206934

# Explicit diag inverse probe

In [69]:
def explicit_diaginv_probe(A, method="cholesky"):
    """Computes the diagonal of inv(A) using an explicit probe. A must be SPD.
    """

    valid_methods = ["cholesky"]
    assert method in valid_methods, f"method must be one of {valid_methods}"

    # Setup
    n = A.shape[0]
    diagonal_inv = np.zeros(n)

    if method == "cholesky":
        chol = scipy_chol_fac(A)
    else:
        raise NotImplementedError

    for j in range(n):

        # jth column of the identity
        w = np.zeros(n)
        w[j] = 1.0

        # Compute w^T inv(A) w
        if method == "cholesky":
            Ainv_w = scipy_chol_solve(chol, w)
        else:
            raise NotImplementedError

        diagonal_inv[j] = w.T @ Ainv_w

    return diagonal_inv

In [70]:
true_diaginv

array([ 8.57586596,  0.7267294 ,  0.85203419,  2.76756246,  0.2214591 ,
        7.20166679,  9.1764802 ,  0.93558599,  1.04791352, 20.07114549,
        1.19428919,  0.32162996,  0.46371162,  0.51850962, 12.66586429,
        3.18986117, 20.71376633,  1.05068503,  2.58187642,  3.59578936,
       10.51153871,  2.94140851, 29.81763933,  1.02600227,  2.55142791,
        5.37002981, 10.93699529,  3.3122536 ,  9.99018729,  1.58450314,
        0.46893677, 11.53063194,  1.38695438,  5.99991168,  0.64716075,
        1.90461363,  4.10883592,  4.04324138,  4.55086426,  1.96278145,
       29.01828566,  4.18241753,  1.54951477, 15.30661395,  0.80910166,
        4.12588324, 14.0889585 ,  8.9417561 ,  3.17568107,  0.83490542,
        1.67432212,  0.78163145,  1.04055025,  2.06221432,  1.56951583,
        1.70784948,  0.70855028,  1.89357841,  1.98412581,  1.59130454,
        9.43910867,  0.75570913, 13.78749808,  4.93897568,  5.20844652,
        4.90017951,  1.05144977,  0.79222167,  8.69030863,  5.25

In [71]:
explicit_diaginv_probe(A)

array([ 8.57586596,  0.7267294 ,  0.85203419,  2.76756246,  0.2214591 ,
        7.20166679,  9.1764802 ,  0.93558599,  1.04791352, 20.07114549,
        1.19428919,  0.32162996,  0.46371162,  0.51850962, 12.66586429,
        3.18986117, 20.71376633,  1.05068503,  2.58187642,  3.59578936,
       10.51153871,  2.94140851, 29.81763933,  1.02600227,  2.55142791,
        5.37002981, 10.93699529,  3.3122536 ,  9.99018729,  1.58450314,
        0.46893677, 11.53063194,  1.38695438,  5.99991168,  0.64716075,
        1.90461363,  4.10883592,  4.04324138,  4.55086426,  1.96278145,
       29.01828566,  4.18241753,  1.54951477, 15.30661395,  0.80910166,
        4.12588324, 14.0889585 ,  8.9417561 ,  3.17568107,  0.83490542,
        1.67432212,  0.78163145,  1.04055025,  2.06221432,  1.56951583,
        1.70784948,  0.70855028,  1.89357841,  1.98412581,  1.59130454,
        9.43910867,  0.75570913, 13.78749808,  4.93897568,  5.20844652,
        4.90017951,  1.05144977,  0.79222167,  8.69030863,  5.25