In [50]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap
from jax import jit, lax
import matplotlib.pyplot as plt
import cmocean as cmo
import importlib

In [51]:
from jax.config import config
config.update("jax_enable_x64", True)

In [102]:
## import modules
import preconditioner as precond
import conjugate_gradient as cg
import pivoted_cholesky as pc
import pivoted_cholesky_ref as pc_ref # to use this script we need "torch", please comment out if not needed.
import calc_logdet
import calc_trace
def reload():
    importlib.reload(precond)
    importlib.reload(cg)
    importlib.reload(pc)
    importlib.reload(pc_ref)
    importlib.reload(calc_logdet)
    importlib.reload(calc_trace)
reload()

In [53]:
import gpytorch
import torch
import linear_operator
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DiagLinearOperator,
    LinearOperator,
    DenseLinearOperator,
)

In [54]:
def generate_K(N, seed=0, noise=1e-06):
    """
    generate positive definite symmetric matrix
    """
    K = jax.random.normal(jax.random.PRNGKey(seed), (N, N))
    # K = K @ K.T + 30* jnp.eye(N) + noise*jnp.eye(N)
    # K = jnp.dot(K, K.T) + noise*jnp.eye(N)
    # K = jnp.dot(K, K.T) / N
    K = jnp.dot(K, K.T)/N
    # K += (noise+30)*jnp.eye(N) ## ??
    K += (5)*jnp.eye(N)
    K += (noise)*jnp.eye(N)
    if not is_positive_definite(K):
        raise Exception("K is not positive definite !")
    return K

In [55]:
def is_positive_definite(matrix):
    # 行列の固有値を計算
    eigenvalues = np.linalg.eigvals(matrix)

    # 全ての固有値が正であるかをチェック
    if np.all(eigenvalues > 0):
        return True
    else:
        return False

In [56]:
def rel_error(true, pred):
    nonzero_index = jnp.where(true != 0.)
    true = true[nonzero_index]
    pred = pred[nonzero_index]
    return jnp.mean(jnp.abs((true-pred)/true))

## 5. modified preconditioned conjugate gradient

### N=200

In [105]:
N = 200
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
K_torch = torch.from_numpy(np.array(K))
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [106]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)

In [107]:
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=True, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))

j=0 r1norm: 0.04779651779671358
j=1 r1norm: 0.0070787538528101985
j=2 r1norm: 0.0010483181904175681
j=3 r1norm: 0.00015188492381236252
j=4 r1norm: 2.1406674187782856e-05
j=5 r1norm: 3.035221326906993e-06
j=6 r1norm: 4.3804727259127377e-07
j=7 r1norm: 6.439827677567317e-08
j=8 r1norm: 9.139817103941224e-09
j=9 r1norm: 1.2921999962780968e-09
j=10 r1norm: 1.9049306333819918e-10
j=11 r1norm: 2.658450826418492e-11
j=12 r1norm: 3.7855524678967105e-12
j=13 r1norm: 5.247903752545748e-13
j=14 r1norm: 7.537008125950595e-14
j=15 r1norm: 1.0463439833023137e-14
j=16 r1norm: 1.4900959429290623e-15
j=17 r1norm: 2.0316301585777383e-16
j=18 r1norm: 2.909394311291974e-17
j=19 r1norm: 4.0592935991440206e-18
j=20 r1norm: 5.526383546131225e-19
converged
1.8058385748678646e-15


In [108]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)

In [73]:
preconditioner_torch, _, precond_logdet_torch = added_diag._preconditioner()
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
linear_operator.settings.cg_tolerance._set_value(1)

1.2146112032319772e-05


In [109]:
%%timeit
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=False, tolerance=1, n_tridiag=n_tridiag)

245 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [92]:
%%timeit
preconditioner_torch, _, _ = added_diag._preconditioner()
Kinvy_torch, t_mat = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)

3.88 ms ± 56.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [224]:
%%timeit
Kinvy_linalg = jnp.linalg.solve(K, rhs)

1.07 ms ± 85.6 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [110]:
calc_logdet.calc_logdet(K.shape, t_mat, preconditioner)

DeviceArray(353.07246999, dtype=float64)

In [111]:
added_diag.inv_quad_log_det(inv_quad_rhs=rhs_torch[:, -1:], logdet=True)

(tensor(34.8014, dtype=torch.float64), tensor(355.7268, dtype=torch.float64))

In [112]:
eval_torch, evec_torch = linear_operator.utils.lanczos.lanczos_tridiag_to_diag(t_mat_torch)
slq = linear_operator.utils.stochastic_lq.StochasticLQ()
(logdet_term,) = slq.to_dense(added_diag.matrix_shape, eval_torch, evec_torch,  [lambda x: x.log()])

In [113]:
logdet_term

tensor(353.8197, dtype=torch.float64)

In [114]:
dKdtheta = jax.random.normal(jax.random.PRNGKey(1), (N, N))+5*jnp.eye(N)

In [115]:
## trace term
calc_trace.calc_trace(Kinvy, dKdtheta, zs, n_tridiag)

DeviceArray(162.27705925, dtype=float64)

In [116]:
Kinv_linalg = jnp.linalg.inv(K)
trace_linalg = jnp.sum(jnp.diag(jnp.matmul(Kinv_linalg, dKdtheta)))
trace_linalg

DeviceArray(172.82511299, dtype=float64)

### N=5000

In [117]:
N = 5000
rank=5
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
K_torch = torch.from_numpy(np.array(K))
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [119]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)

In [120]:
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=True, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))

j=0 r1norm: 0.051392514398530424
j=1 r1norm: 0.01333079274852888
j=2 r1norm: 0.011034137787164111
j=3 r1norm: 0.01097629305048675
j=4 r1norm: 0.01097478115420941
j=5 r1norm: 0.010974764171941301
j=6 r1norm: 0.010974756509927647
j=7 r1norm: 0.010974653805830958
j=8 r1norm: 0.010971016807683462
j=9 r1norm: 0.010799128364232786
j=10 r1norm: 0.007002732823990947
j=11 r1norm: 0.0014046497253585777
j=12 r1norm: 0.0002082332161509357
j=13 r1norm: 3.0311422732543305e-05
j=14 r1norm: 4.422529618970216e-06
j=15 r1norm: 6.451839428772055e-07
j=16 r1norm: 9.363943904745411e-08
j=17 r1norm: 1.3609670318680369e-08
j=18 r1norm: 1.9853718283074162e-09
j=19 r1norm: 2.9154399393473475e-10
j=20 r1norm: 5.871076741373119e-11
converged
1.000818364353847e-09


In [26]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)

In [27]:
preconditioner_torch, _, precond_logdet_torch = added_diag._preconditioner()
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
linear_operator.settings.cg_tolerance._set_value(1)

3.5123671374860206e-07


In [121]:
%%timeit
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=False, tolerance=1, n_tridiag=n_tridiag)

325 ms ± 6.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [235]:
%%timeit
preconditioner_torch, _, _ = added_diag._preconditioner()
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)

1.09 s ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [236]:
%%timeit
Kinvy_linalg = jnp.linalg.solve(K, rhs)

306 ms ± 39.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
calc_logdet.calc_logdet(K.shape, t_mat, preconditioner)

DeviceArray(8953.09277797, dtype=float64)

In [29]:
added_diag.inv_quad_log_det(inv_quad_rhs=rhs_torch[:, -1:], logdet=True)

(tensor(869.8112, dtype=torch.float64), tensor(8869.7888, dtype=torch.float64))

In [30]:
eval_torch, evec_torch = linear_operator.utils.lanczos.lanczos_tridiag_to_diag(t_mat_torch)
slq = linear_operator.utils.stochastic_lq.StochasticLQ()
(logdet_term,) = slq.to_dense(added_diag.matrix_shape, eval_torch, evec_torch,  [lambda x: x.log()])

In [31]:
logdet_term + precond_logdet_torch

tensor(9127.5788, dtype=torch.float64)

In [32]:
## trace term
dKdtheta = jax.random.normal(jax.random.PRNGKey(1), (N, N))+5*jnp.eye(N)

In [33]:
calc_trace.calc_trace(Kinvy, dKdtheta, zs, n_tridiag)

DeviceArray(5083.61541929, dtype=float64)

In [34]:
Kinv_linalg = jnp.linalg.inv(K)
trace_linalg = jnp.sum(jnp.diag(jnp.matmul(Kinv_linalg, dKdtheta)))
trace_linalg

DeviceArray(4255.61158368, dtype=float64)

### N=10000

In [209]:
N = 10000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
K_torch = torch.from_numpy(np.array(K))
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

KeyboardInterrupt: 

In [124]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)

In [185]:
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=True, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))

j=0 r1norm: 0.05273127789083324
j=1 r1norm: 0.018229989448990253
j=2 r1norm: 0.016712777654460265
j=3 r1norm: 0.016678237420476163
j=4 r1norm: 0.01667748463554563
j=5 r1norm: 0.016677473148952585
j=6 r1norm: 0.016677456685424495
j=7 r1norm: 0.016677112063466847
j=8 r1norm: 0.016665606870972682
j=9 r1norm: 0.016125945002406064
j=10 r1norm: 0.008159012203917657
j=11 r1norm: 0.0013833460631000965
j=12 r1norm: 0.00020257472772172632
j=13 r1norm: 2.942618446057154e-05
j=14 r1norm: 4.2913264357712855e-06
j=15 r1norm: 6.251606662586217e-07
j=16 r1norm: 9.07054767621244e-08
converged
2.6688080648450045e-06


In [186]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)

In [187]:
preconditioner_torch, _, precond_logdet_torch = added_diag._preconditioner()
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
linear_operator.settings.cg_tolerance._set_value(1)

3.5123671374860206e-07


In [176]:
%%timeit
precondition, precond_lt, precond_logdet_cache = precond.setup_preconditioner(K, rank=rank, noise=noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, precondition=precondition, print_process=True, tolerance=1, n_tridiag=n_tridiag)

298 ms ± 2.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [170]:
%%timeit
preconditioner_torch, _, _ = added_diag._preconditioner()
Kinvy_torch, t_mat = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch, num_tridiag=n_tridiag)

1.09 s ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [171]:
%%timeit
Kinvy_linalg = jnp.linalg.solve(K, rhs)

306 ms ± 615 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [188]:
calc_logdet.calc_logdet(K.shape, t_mat, preconditioner)

DeviceArray(9108.50948291, dtype=float64)

In [189]:
added_diag.inv_quad_log_det(inv_quad_rhs=rhs_torch[:, -1:], logdet=True)

(tensor(869.8112, dtype=torch.float64), tensor(8848.4105, dtype=torch.float64))

In [190]:
eval_torch, evec_torch = linear_operator.utils.lanczos.lanczos_tridiag_to_diag(t_mat_torch)
slq = linear_operator.utils.stochastic_lq.StochasticLQ()
(logdet_term,) = slq.to_dense(added_diag.matrix_shape, eval_torch, evec_torch,  [lambda x: x.log()])

In [191]:
logdet_term + precond_logdet_torch

tensor(9127.5788, dtype=torch.float64)