### Implementing Pivoted Cholesky

- Without jax.lax.while_loop, since `m` decides the shape of different tensors, it is hard to use `jax.lax.while_loop` to implement pivoted Cholesky. One possibility is `jax.lax.dynamic_slice` but that also doesn't jit well.

In [5]:
import jax
import jax.numpy as jnp

from functools import partial
def pivoted_cholesky(kernel, x, max_rank, diag_rtol=1e-3, jitter=1e-3, name=None):
    n = x.shape[0]
    assert max_rank <= n

    orig_error = kernel.get_signal_scale() ** 2 + jitter
    print(f'orig_error: {orig_error}')
    matrix_diag = orig_error * jnp.ones((n,))

    m = 0
    pchol = jnp.zeros((max_rank, n))
    perm = jnp.arange(n)
    
    @partial(jax.jit, static_argnums=(0))
    def _body_fn(m, pchol, perm, matrix_diag):
        maxi = jnp.argmax(matrix_diag[perm[m:]]) + m
        maxval = matrix_diag[perm][maxi]

        perm = perm.at[..., [m, maxi]].set(perm[..., [maxi, m]])

        # TODO: Figure out where jitter gets added, only where row is computed for same index kernel_fn(i, i)
        row = kernel.kernel_fn(x[perm[m]], x[perm[m + 1:]]).squeeze()

        row -= jnp.sum(pchol[:m+1, perm[m + 1:]] * pchol[:m+1, perm[m:m+1]], axis=-2)
        pivot = jnp.sqrt(maxval)
        row /= pivot

        row = jnp.concatenate([pivot[None], row], axis=-1)
        matrix_diag = matrix_diag.at[perm[m:]].set(matrix_diag[perm[m:]] - row**2)

        pchol = pchol.at[m, perm[m:]].set(row)
        
        return pchol, perm, matrix_diag

    cond = True
    while cond:
        pchol, perm, matrix_diag = _body_fn(m, pchol, perm, matrix_diag)
        m = m + 1
        error = jnp.linalg.norm(matrix_diag, ord=1, axis=-1)
        max_err = jnp.max(error / orig_error)
        print(f'Iteration: {m}, error : {max_err}')
        cond = (m < max_rank) and (max_err > diag_rtol)
        
    
    pchol = jnp.swapaxes(pchol, -1, -2)
    return pchol

In [6]:
import jax
import tensorflow_probability as tfp
import jax.numpy as jnp
from kernels import Matern32Kernel, RBFKernel
import numpy as np

kernel = RBFKernel({'length_scale': jnp.array([0.3]), 'signal_scale': 0.4, 'noise_scale': 0.7})

N, rank = 10000, 100
jitter = 0.

x = jax.random.normal(jax.random.PRNGKey(1), (N, 1))

A = kernel.kernel_fn(x, x) + jitter * jnp.eye(N)

Lk = jnp.array(tfp.math.pivoted_cholesky(np.array(A), rank))

Lk_jax = pivoted_cholesky(kernel, x, rank, jitter=jitter)



orig_error: 0.16000000000000003
Iteration: 1, error : 8086.99267578125
Iteration: 2, error : 7220.5732421875
Iteration: 3, error : 6691.69677734375
Iteration: 4, error : 6668.92236328125
Iteration: 5, error : 6666.6201171875
Iteration: 6, error : 6578.4365234375
Iteration: 7, error : 4671.46435546875
Iteration: 8, error : 4666.80322265625
Iteration: 9, error : 4537.4189453125
Iteration: 10, error : 3460.97802734375
Iteration: 11, error : 3444.007080078125
Iteration: 12, error : 3119.874267578125
Iteration: 13, error : 2284.379150390625
Iteration: 14, error : 1137.1146240234375
Iteration: 15, error : 1029.4822998046875
Iteration: 16, error : 1024.9822998046875
Iteration: 17, error : 581.0130004882812
Iteration: 18, error : 524.5748291015625
Iteration: 19, error : 514.0660400390625
Iteration: 20, error : 513.4921875
Iteration: 21, error : 364.0638427734375
Iteration: 22, error : 356.0989990234375
Iteration: 23, error : 249.79214477539062
Iteration: 24, error : 249.14634704589844
Iteratio

In [7]:
print(jnp.linalg.norm(Lk - Lk_jax))

print(Lk - Lk_jax)

0.06732793
[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00 -2.71050543e-20 ... -1.15539682e-04
   1.08598848e-04  3.12456337e-04]
 ...
 [-2.58493941e-26  0.00000000e+00  0.00000000e+00 ...  1.03876140e-04
  -1.15495066e-04 -1.09427077e-04]
 [-1.49011612e-08  1.77635684e-15  0.00000000e+00 ... -4.35540351e-05
  -1.52960900e-04 -1.27421954e-05]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  1.92492848e-06
   1.05673236e-04  1.55344766e-04]]
