### Implementing Pivoted Cholesky

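- Without `jax.lax.while_loop`: because the iteration index `m` determines the shapes of several intermediate tensors, pivoted Cholesky is hard to express with `jax.lax.while_loop`, whose loop carry must keep every array at a fixed shape. One alternative is `jax.lax.dynamic_slice`, but that does not jit well either, since its slice sizes must still be static.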
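A common way around the shape problem, not necessarily what the implementation in this notebook does, is to keep every array at its full size and mask out the pivots already chosen, so the loop can run under `jax.lax.fori_loop` with static shapes. The sketch below assumes a dense symmetric positive-definite matrix `A` that fits in memory. `pivoted_cholesky_masked` is a hypothetical name; it returns an `(N, rank)` factor `L` with `A ≈ L @ L.T`, the same shape as the `tfp.math.pivoted_cholesky` output.

```python
import jax
import jax.numpy as jnp


def pivoted_cholesky_masked(A, rank):
    """Rank-`rank` pivoted Cholesky of a dense SPD matrix A using only static shapes.

    Hypothetical helper: returns L of shape (N, rank) with A ≈ L @ L.T. Rather than
    slicing out the shrinking set of unpivoted rows at step m, every array keeps
    its full size and rows that have already been pivoted are masked.
    """
    N = A.shape[0]
    L0 = jnp.zeros((N, rank), dtype=A.dtype)
    d0 = jnp.diag(A)                     # diagonal of the residual A - L @ L.T
    used0 = jnp.zeros((N,), dtype=bool)  # True for indices already chosen as pivots

    def step(m, carry):
        L, d, used = carry
        # Greedy pivot: largest residual diagonal entry among unused indices.
        i = jnp.argmax(jnp.where(used, -jnp.inf, d))
        pivot = jnp.sqrt(d[i])
        # Residual column i. Columns m.. of L are still zero, so the full-width
        # product equals L[:, :m] @ L[i, :m] without needing a dynamic slice.
        col = (A[:, i] - L @ L[i]) / pivot
        used = used.at[i].set(True)
        # Rows pivoted earlier have zero residual; set them (and the pivot) exactly.
        col = jnp.where(used, 0.0, col).at[i].set(pivot)
        L = L.at[:, m].set(col)
        d = d - col**2
        return L, d, used

    L, _, _ = jax.lax.fori_loop(0, rank, step, (L0, d0, used0))
    return L
```

Under `jax.jit`, `rank` has to be static, e.g. `jax.jit(pivoted_cholesky_masked, static_argnums=1)`. The price of static shapes is that each step multiplies by the full-width `L` (O(N·rank) work) even though only the first `m` columns are non-zero.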

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability as tfp
from scalable_gps.kernels import Matern32Kernel
from scalable_gps.linalg_utils import pivoted_cholesky

kernel = Matern32Kernel({'length_scale': jnp.array([0.3]), 'signal_scale': 0.4, 'noise_scale': 0.7})

N, rank = 10000, 100
jitter = 1.

# N random 1-D inputs
x = jax.random.normal(jax.random.PRNGKey(1), (N, 1))

# Dense N x N kernel matrix with jitter added to the diagonal
A = kernel.kernel_fn(x, x) + jitter * jnp.eye(N)

# Reference: TFP's pivoted Cholesky on the dense matrix, giving an (N, rank) factor
Lk = jnp.array(tfp.math.pivoted_cholesky(np.array(A), rank))

# Our implementation, which takes the kernel and inputs instead of the dense matrix
Lk_jax = pivoted_cholesky(kernel, x, rank, jitter=jitter)



orig_error: 1.1600000000000001
Iteration: 1, error : 9969.2314453125
Iteration: 2, error : 9954.9658203125
Iteration: 3, error : 9952.06640625
Iteration: 4, error : 9951.048828125
Iteration: 5, error : 9949.98828125
Iteration: 6, error : 9947.6162109375
Iteration: 7, error : 9934.2109375
Iteration: 8, error : 9903.54296875
Iteration: 9, error : 9902.224609375
Iteration: 10, error : 9900.9462890625
Iteration: 11, error : 9894.2822265625
Iteration: 12, error : 9887.927734375
Iteration: 13, error : 9866.775390625
Iteration: 14, error : 9845.6884765625
Iteration: 15, error : 9815.50390625
Iteration: 16, error : 9814.001953125
Iteration: 17, error : 9812.8671875
Iteration: 18, error : 9811.3955078125
Iteration: 19, error : 9802.82421875
Iteration: 20, error : 9799.7138671875
Iteration: 21, error : 9798.72265625
Iteration: 22, error : 9797.66015625
Iteration: 23, error : 9789.916015625
Iteration: 24, error : 9768.8232421875
Iteration: 25, error : 9765.3017578125
Iteration: 26, error : 9752.3
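To reproduce the dense matrix `A` without the `scalable_gps` package, a Matérn-3/2 kernel can be written directly in JAX. This is a hypothetical stand-in for `Matern32Kernel.kernel_fn`: it assumes the standard form `k(r) = signal_scale**2 * (1 + sqrt(3) r) * exp(-sqrt(3) r)` with `r = ||x - x'|| / length_scale`, and it ignores `noise_scale`. The package's own parameterisation may differ.

```python
import jax
import jax.numpy as jnp


def matern32(x1, x2, length_scale, signal_scale):
    """Dense Matern-3/2 kernel matrix between x1 (n, d) and x2 (m, d).

    Hypothetical stand-in for Matern32Kernel.kernel_fn; assumes
    k(r) = signal_scale**2 * (1 + sqrt(3) r) * exp(-sqrt(3) r), r = ||x1 - x2|| / length_scale.
    """
    diff = (x1[:, None, :] - x2[None, :, :]) / length_scale  # (n, m, d), per-dimension scaling
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1))  # fine for building K; not safe to differentiate at r = 0
    s3r = jnp.sqrt(3.0) * r
    return signal_scale**2 * (1.0 + s3r) * jnp.exp(-s3r)


# Smaller demo (distinct names so the notebook's N, x and A are not overwritten)
x_demo = jax.random.normal(jax.random.PRNGKey(1), (500, 1))
A_demo = matern32(x_demo, x_demo, jnp.array([0.3]), 0.4) + 1.0 * jnp.eye(500)
```

Broadcasting builds an `(n, m, d)` intermediate, which for `N = 10000` and `d = 1` is about 400 MB in float32; a chunked or distance-based formulation would be needed for larger inputs.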

In [4]:
# Frobenius norm of the elementwise difference between the two factors
print(jnp.linalg.norm(Lk - Lk_jax))

print(Lk - Lk_jax)

6.5183473
[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -1.8626451e-09  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 ...
 [ 0.0000000e+00 -2.3283064e-10  0.0000000e+00 ... -1.6374457e-11
   5.3776428e-17 -2.3010760e-13]
 [ 0.0000000e+00  0.0000000e+00 -7.2759576e-12 ... -5.7043508e-09
  -2.0780135e-08  7.9050660e-06]
 [ 0.0000000e+00 -3.7252903e-09  0.0000000e+00 ... -1.7350976e-07
   2.3092639e-13 -7.3850970e-10]]
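Because choosing a different pivot at one step changes every later column, an elementwise difference between two pivoted Cholesky factors can be large even when both approximate `A` about equally well. A comparison that does not depend on pivot order is the trace of the residual, which for any factor `L` equals `trace(A) - ||L||_F**2`, since `trace(L @ L.T)` is the sum of squared entries of `L`. The helper name `residual_trace` is hypothetical.

```python
import jax.numpy as jnp


def residual_trace(A, L):
    """trace(A - L @ L.T) without forming the N x N product.

    Uses trace(L @ L.T) = sum of squared entries of L. For a pivoted Cholesky
    factor of an SPD matrix the residual is PSD in exact arithmetic, so this
    value is non-negative.
    """
    return jnp.trace(A) - jnp.sum(L**2)


# With A, Lk and Lk_jax from the cells above:
# print(residual_trace(A, Lk), residual_trace(A, Lk_jax))
```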
