In [53]:
import jax
import jax.numpy as jnp

from functools import partial
def pivoted_cholesky(kernel, x, max_rank, diag_rtol=1e-3, jitter=1e-3, name=None):
    n = x.shape[0]
    assert max_rank <= n

    orig_error = kernel.get_signal_scale() ** 2 + jitter
    print(f'orig_error: {orig_error}')
    matrix_diag = orig_error * jnp.ones((n,))

    m = 0
    pchol = jnp.zeros((max_rank, n))
    perm = jnp.arange(n)
    
    @partial(jax.jit, static_argnums=(0))
    def _body_fn(m, pchol, perm, matrix_diag):
        maxi = jnp.argmax(matrix_diag[perm[m:]]) + m
        maxval = matrix_diag[perm, maxi]

        perm = perm.at[..., [m, maxi]].set(perm[..., [maxi, m]])

        row = kernel.kernel_fn(x[perm[m]], x[perm[m + 1:]]).squeeze()
        print(row.shape)

        row -= jnp.sum(pchol[:m+1, perm[m + 1:]] * pchol[:m+1, perm[m:m+1]], axis=-2)
       
        pivot = jnp.sqrt(maxval)
        row /= pivot
        
        row = jnp.concatenate([pivot[None], row], axis=-1)

        matrix_diag = matrix_diag.at[perm[m:]].set(matrix_diag[perm[m:]] - row**2)

        pchol = pchol.at[m, perm[m:]].set(row)
        
        return pchol, perm, matrix_diag

    cond = True
    while cond:
        pchol, perm, matrix_diag = _body_fn(m, pchol, perm, matrix_diag)
        m = m + 1
        error = jnp.linalg.norm(matrix_diag, ord=1, axis=-1)
        max_err = jnp.max(error / orig_error)
        print(m, max_err)
        cond = (m < max_rank) and (max_err > diag_rtol)
        
    
    pchol = jnp.swapaxes(pchol, -1, -2)
    return pchol

In [54]:

# from linalg_utils import pivoted_cholesky
import jax
import tensorflow_probability as tfp
import jax.numpy as jnp
from kernels import Matern32Kernel, RBFKernel
import numpy as np

kernel = RBFKernel({'length_scale': jnp.array([0.3]), 'signal_scale': 0.4, 'noise_scale': 0.7})

x = jax.random.normal(jax.random.PRNGKey(1), (10000, 1))

# x = jnp.linspace(0, 1, 1000)[:, None]

A = kernel.kernel_fn(x, x) + 1e-6 * jnp.eye(10000)

# print()
rank = 100
Lk = jnp.array(tfp.math.pivoted_cholesky(np.array(A), rank))

Lk_jax = pivoted_cholesky(kernel, x, rank, jitter=1e-6)



orig_error: 0.16000100000000003
(9999,)
1 8087.016
(9998,)
2 7220.6084
(9997,)
3 6691.7383
(9996,)
4 6668.965
(9995,)
5 6666.662
(9994,)
6 6578.479
(9993,)
7 4671.5312
(9992,)
8 4666.87
(9991,)
9 4537.487
(9990,)
10 3461.0593
(9989,)
11 3444.0886
(9988,)
12 3119.9595
(9987,)
13 2284.4749
(9986,)
14 1137.2231
(9985,)
15 1029.5929
(9984,)
16 1025.0931
(9983,)
17 581.1304
(9982,)
18 524.6941
(9981,)
19 514.1856
(9980,)
20 513.61176
(9979,)
21 364.18307
(9978,)
22 356.21872
(9977,)
23 249.91315
(9976,)
24 249.2674
(9975,)
25 99.2229
(9974,)
26 54.993763
(9973,)
27 22.583582
(9972,)
28 22.30633
(9971,)
29 20.020258
(9970,)
30 12.343893
(9969,)
31 5.6381707
(9968,)
32 5.0802236
(9967,)
33 1.7481644
(9966,)


KeyboardInterrupt: 

In [51]:
print(jnp.linalg.norm(Lk - Lk_jax))

print(Lk - Lk_jax)

0.031477675
[[0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 1.1641532e-10 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 ...
 [0.0000000e+00 2.9802322e-08 1.1641532e-10 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [1.1641532e-10 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]]
