In [124]:
import jax.numpy as jnp
import jax
from jax import random
from jax.config import config
import scipy
import time
config.update('jax_platform_name', 'cpu')
config.update("jax_enable_x64", True)


# Example usage
rng_key = random.PRNGKey(int(time.time()))
matrix = random.normal(rng_key, (300, 300))
matrix = matrix @ matrix.T  # Make it positive definite
eps = 1.0

true_inv = jnp.linalg.inv(matrix + eps * jnp.eye(matrix.shape[0]))


def nystrom_inv(matrix, eps):
    rng_key = random.PRNGKey(int(time.time()))
    n = matrix.shape[0]
    m = int(n / 2)
    matrix_mean = jnp.mean(matrix)
    print(matrix_mean)
    matrix = matrix / matrix_mean  # Scale the matrix to avoid numerical issues

    # Randomly select m columns
    idx = jax.random.choice(rng_key, n, (m, ), replace=False)

    W = matrix[idx, :][:, idx]
    U, s, V = jnp.linalg.svd(W)

    U_recon = jnp.sqrt(m / n) * matrix[:, idx] @ U @ jnp.diag(1. / s)
    S_recon = s * (n / m)

    Sigma_inv = (1. / eps) * jnp.eye(n)
    approx_inv = Sigma_inv - Sigma_inv @ U_recon @ jnp.linalg.inv(jnp.diag(1. / S_recon) + U_recon.T @ Sigma_inv @ U_recon) @ U_recon.T @ Sigma_inv
    approx_inv = approx_inv / matrix_mean  # Don't forget the scaling!
    return approx_inv

approx_inv = nystrom_inv(matrix, eps)


1.3199613975955582


In [125]:
approx_inv @ (matrix + eps * jnp.eye(matrix.shape[0]))

Array([[ 1.08978205e+02, -1.23373788e+01, -8.13510099e+00, ...,
         4.04342212e-03, -3.56795748e-03,  1.29558924e+01],
       [-1.63963069e+01,  1.10719134e+02,  7.12116009e+00, ...,
        -7.74298233e-04, -7.67300391e-03,  4.33908983e+00],
       [-6.51365148e+00,  4.71249212e+00,  9.51654783e+01, ...,
         1.04407898e-02, -2.05619611e-03,  2.54714329e+00],
       ...,
       [-3.02762780e+00,  1.94127847e+00, -6.08021967e+00, ...,
         9.20572149e-01,  5.25785276e-03, -5.66609409e+00],
       [ 7.30895218e+00,  8.93898723e-02, -1.85999542e-01, ...,
         5.25785276e-03,  9.03769115e-01,  3.99285519e-01],
       [ 1.33902740e+01,  3.18381293e+00, -3.63972978e+00, ...,
         9.76412491e-03, -2.22158272e-03,  8.58261381e+01]],      dtype=float64)

In [117]:
true_inv @ (matrix + eps * jnp.eye(matrix.shape[0]))

Array([[ 1.00000000e+00,  1.01307851e-15,  8.34228520e-15, ...,
        -1.79977561e-15,  7.07767178e-16, -1.22124533e-15],
       [-1.07552856e-15,  1.00000000e+00, -7.20951077e-15, ...,
         9.12464548e-16,  3.02535774e-15, -1.33226763e-15],
       [-2.19269047e-15,  1.08246745e-15,  1.00000000e+00, ...,
         6.80011603e-16, -1.88737914e-15,  8.88178420e-16],
       ...,
       [-9.06284596e-16, -5.41190356e-15, -1.59377719e-16, ...,
         1.00000000e+00, -2.26164573e-16,  1.30104261e-15],
       [ 2.28289609e-15, -4.64905892e-16,  7.14706072e-16, ...,
        -9.20270804e-16,  1.00000000e+00, -3.33066907e-16],
       [ 2.77555756e-16,  0.00000000e+00, -2.85882429e-15, ...,
        -2.91433544e-16, -2.66453526e-15,  1.00000000e+00]],      dtype=float64)