In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap
from jax import jit, lax
import matplotlib.pyplot as plt
import cmocean as cmo
import importlib

In [2]:
from jax.config import config
config.update("jax_enable_x64", True)

In [3]:
## import modules
import preconditioner as precond
import conjugate_gradient as cg
import pivoted_cholesky as pc
import pivoted_cholesky_ref as pc_ref # to use this script we need "torch", please comment out if not needed.
def reload():
    importlib.reload(precond)
    importlib.reload(cg)
    importlib.reload(pc)
    importlib.reload(pc_ref)
reload()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import gpytorch
import torch
import linear_operator
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DiagLinearOperator,
    LinearOperator,
    DenseLinearOperator,
)

In [5]:
def generate_K(N, seed=0, noise=1e-06):
    """
    generate positive definite symmetric matrix
    """
    K = jax.random.normal(jax.random.PRNGKey(seed), (N, N))
    # K = K @ K.T + 30* jnp.eye(N) + noise*jnp.eye(N)
    # K = jnp.dot(K, K.T) + noise*jnp.eye(N)
    # K = jnp.dot(K, K.T) / N
    K = jnp.dot(K, K.T)/N
    # K += (noise+30)*jnp.eye(N) ## ??
    K += (5)*jnp.eye(N)
    K += (noise)*jnp.eye(N)
    if not is_positive_definite(K):
        raise Exception("K is not positive definite !")
    return K

In [6]:
def is_positive_definite(matrix):
    # 行列の固有値を計算
    eigenvalues = np.linalg.eigvals(matrix)

    # 全ての固有値が正であるかをチェック
    if np.all(eigenvalues > 0):
        return True
    else:
        return False

In [7]:
def rel_error(true, pred):
    nonzero_index = jnp.where(true != 0.)
    true = true[nonzero_index]
    pred = pred[nonzero_index]
    return jnp.mean(jnp.abs((true-pred)/true))

## 4. modified preconditioned conjugate gradient

In [8]:
N = 100
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [9]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [10]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)

In [11]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)

In [12]:
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1)
print(rel_error(Kinvy_linalg, Kinvy))

2.796436867782459e-09


In [13]:
linear_operator.settings.cg_tolerance._set_value(1)
Kinvy_torch = added_diag.solve(rhs_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
linear_operator.settings.cg_tolerance._set_value(1)

2.475563402797503e-07


### tol=1e-04

In [44]:
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1e-04)
print(rel_error(Kinvy_linalg, Kinvy))

0.0004721970976828854


In [46]:
linear_operator.settings.cg_tolerance._set_value(1e-04)
Kinvy_torch = added_diag.solve(rhs_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
# linear_operator.settings.cg_tolerance._set_value(1)

0.00047221962753508294


In [50]:
N = 7000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [51]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [52]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)

In [53]:
%%timeit
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1e-04)

258 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
%%timeit
Kinvy_torch = added_diag.solve(rhs_torch)

1.52 s ± 4.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
%%timeit
Kinvy_linalg = jnp.linalg.solve(K, rhs)

751 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [56]:
print(rel_error(Kinvy_linalg, Kinvy))
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

0.0004721970976828854
0.00047221962753508294


In [57]:
%%timeit
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=None, print_process=False, tolerance=1e-04)

156 ms ± 2.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### やはりpreconditionしない方がはやいのが現状

## chceking trace term

### with preconditioner

In [86]:
N = 3000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [87]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [88]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, _, _ = added_diag._preconditioner()

In [89]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1)
print(rel_error(Kinvy_linalg, Kinvy))
linear_operator.settings.cg_tolerance._set_value(1)
Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))
linear_operator.settings.cg_tolerance._set_value(1)

0.1485656644481834
0.14856548077918746


In [90]:
%%timeit
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=None, print_process=False, tolerance=1e-04)

73.9 ms ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [91]:
dKdtheta = jax.random.normal(jax.random.PRNGKey(1), (N, N))+5*jnp.eye(N)

In [92]:
trace = jnp.einsum("ij, ij ->", Kinvy[:, :n_tridiag], jnp.einsum("ij, jk->ik", dKdtheta, zs))/n_tridiag

In [93]:
Kinvy_linalg = jnp.linalg.inv(K)
trace_linalg = jnp.sum(jnp.diag(jnp.matmul(Kinvy_linalg, dKdtheta)))

In [99]:
# trace_torch, _ = added_diag.inv_quad_logdet(rhs_torch, logdet=True)

In [98]:
trace, trace_linalg, #trace_torch/2

(DeviceArray(2674.62575194, dtype=float64),
 DeviceArray(2587.2629889, dtype=float64))

In [89]:
(trace_linalg-trace)/trace_linalg

DeviceArray(-0.03376648, dtype=float64)

### without preconditioner

In [100]:
Kinvy = cg.bcg_bbmm(K, rhs, preconditioner=None, print_process=False, tolerance=1)
print(rel_error(Kinvy_linalg, Kinvy))

3448.798220099444


In [101]:
dKdtheta = jax.random.normal(jax.random.PRNGKey(1), (N, N))+5*jnp.eye(N)

In [106]:
trace = jnp.einsum("ij, ij ->", Kinvy[:, :n_tridiag], jnp.einsum("ij, jk->ik", dKdtheta, zs))/n_tridiag

In [107]:
Kinvy_linalg = jnp.linalg.inv(K)
trace_linalg = jnp.sum(jnp.diag(jnp.matmul(Kinvy_linalg, dKdtheta)))

In [108]:
trace, trace_linalg

(DeviceArray(2671.43610977, dtype=float64),
 DeviceArray(2587.2629889, dtype=float64))

In [109]:
(trace_linalg-trace)/trace_linalg

DeviceArray(-0.03253365, dtype=float64)

### checking t_mat

### N = 100

In [303]:
N = 100
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [15]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [16]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, _, _ = added_diag._preconditioner()

In [17]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch,num_tridiag=n_tridiag)
# Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

9.241723714046066e-16
7.371856260187996e-06


#### linear_operatorでは, solveの際にn_tridiagをありにするとなぜか精度が向上する

In [18]:
t_mat.shape, t_mat_torch.numpy().shape

((10, 20, 20), (10, 9, 9))

In [19]:
end = 9
jnp.mean(jnp.abs(t_mat[:, :end , :end]-t_mat_torch[:, :end, :end].numpy()))

DeviceArray(0.1816411, dtype=float64)

### N = 3000

In [38]:
N = 3000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [39]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [40]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, _, _ = added_diag._preconditioner()

In [41]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch,num_tridiag=n_tridiag)
# Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

7.393597288661838e-10
2.7001585109192347e-07


In [42]:
t_mat.shape, t_mat_torch.numpy().shape

((10, 20, 20), (10, 20, 20))

In [43]:
end = 7
jnp.mean(jnp.abs(t_mat[:, :end , :end]-t_mat_torch[:, :end, :end].numpy()))

DeviceArray(90717.27715001, dtype=float64)

### Tについてはエラーが大きい: 原因はおそらくzero divisionの除外やstop updating afterの実装の有無によるalpha, betaの違いであるはず一旦sum diagだけを計算するので問題ないのでは?

### checking log determinant

In [8]:
N = 100
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
preconditioner = precond.Preconditioner(K, rank=rank, noise=noise)
K_torch = torch.from_numpy(np.array(K))

In [9]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [10]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, precond_lt_torch, precond_logdet_torch = added_diag._preconditioner()

In [11]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch,num_tridiag=n_tridiag)
# Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

9.241723714046066e-16
7.371856260187996e-06


In [12]:
inv_quad_torch, logdet_torch = added_diag.inv_quad_logdet(rhs_torch[:, -1:], logdet=True)

In [13]:
inv_quad = jnp.dot(y[:, 0], Kinvy[: ,-1])

In [14]:
inv_quad, inv_quad_torch

(DeviceArray(19.46686596, dtype=float64), tensor(19.4669, dtype=torch.float64))

### check eigvals and eigvecs

In [16]:
def lanczos_tridiag_to_diag(t_mat):
    eigvals, eigvectors = jnp.linalg.eigh(t_mat)
    mask = eigvals >= 0.
    eigvectors = eigvectors * jnp.expand_dims(mask, axis=-2)
    
    eigvals = eigvals.at[~mask].set(1.)
    return eigvals, eigvectors

In [17]:
def to_dense(matrix_shape, eigenvalues, eigenvectors, funcs):
    results = [jnp.zeros(eigenvalues.shape[1:-1], dtype=eigenvalues.dtype) for _ in funcs]
    num_random_probes = eigenvalues.shape[0]
    for j in range(num_random_probes):
        eigenvalues_for_probe = eigenvalues[j]
        eigenvectors_for_probe = eigenvectors[j]
        for i, func in enumerate(funcs):
            eigenvecs_first_component = eigenvectors_for_probe[..., 0, :]
            func_eigenvalues = func(eigenvalues_for_probe)

            dot_products = (eigenvecs_first_component ** 2 * func_eigenvalues).sum(-1)
            results[i] = results[i] + matrix_shape[-1] / float(num_random_probes) * dot_products

    return results

In [18]:
def calc_logdet(matrix_shape, t_mat, preconditioner):
    eigvals, eigvectors = lanczos_tridiag_to_diag(t_mat)
    (pinvk_logdet,) =  to_dense(K.shape, eigvals, eigvectors, [lambda x:jnp.log(x)])
    
    try:
        logdet_p = preconditioner._precond_logdet_cache
    except:
        logdet_p = 0.
    
    logdet = pinvk_logdet + logdet_p
    
    return logdet

In [19]:
calc_logdet(K.shape, t_mat, preconditioner)

DeviceArray(177.15528805, dtype=float64)

In [20]:
eval_torch, evec_torch = linear_operator.utils.lanczos.lanczos_tridiag_to_diag(t_mat_torch)

In [21]:
slq = linear_operator.utils.stochastic_lq.StochasticLQ()
(logdet_term,) = slq.to_dense(added_diag.matrix_shape, eval_torch, evec_torch,  [lambda x: x.log()])

In [22]:
logdet_term

tensor(177.5486, dtype=torch.float64)

In [23]:
inv_quad_torch, logdet_torch = added_diag.inv_quad_logdet(rhs_torch[:, -1:], logdet=True)

In [24]:
logdet_torch

tensor(178.0780, dtype=torch.float64)

In [315]:
det_ans = jnp.sum(jnp.log(jnp.diag(K)))
det_ans

DeviceArray(179.3679413, dtype=float64)

In [25]:
from calc_logdet import calc_logdet

#### N = 3000

In [26]:
N = 3000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
K_torch = torch.from_numpy(np.array(K))

In [27]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [28]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, precond_lt_torch, precond_logdet_torch = added_diag._preconditioner()

In [29]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch,num_tridiag=n_tridiag)
# Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

7.393597288661838e-10
2.7001585109192347e-07


In [30]:
inv_quad_torch, logdet_torch = added_diag.inv_quad_logdet(rhs_torch[:, -1:], logdet=True)

In [31]:
inv_quad_torch, logdet_torch

(tensor(529.2999, dtype=torch.float64), tensor(5325.6091, dtype=torch.float64))

In [32]:
jnp.dot(y[:, 0], Kinvy[: ,-1]), calc_logdet(K.shape, t_mat, preconditioner)

(DeviceArray(529.30002498, dtype=float64),
 DeviceArray(5557.76785537, dtype=float64))

#### N = 5000

In [326]:
N = 5000
rank=15
noise = 1e-06
n_tridiag = 10
K = generate_K(N)
y = jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,1))
zs = jax.random.normal(jax.random.PRNGKey(0), (N, n_tridiag))
K_torch = torch.from_numpy(np.array(K))

In [327]:
rhs = jnp.concatenate([zs, y], axis=1)
rhs_torch = torch.from_numpy(np.array(rhs))

In [328]:
K_linear_op = linear_operator.to_linear_operator(K_torch)
diag_tensor = torch.ones(N, dtype=torch.float64)*noise
diag_linear_op = DiagLinearOperator(diag_tensor)
added_diag = AddedDiagLinearOperator(K_linear_op, diag_linear_op)
preconditioner_torch, precond_lt_torch, precond_logdet_torch = added_diag._preconditioner()

In [329]:
Kinvy_linalg = jnp.linalg.solve(K, rhs)
preconditioner = precond.Preconditioner(K, rank, noise)
Kinvy, t_mat = cg.mpcg_bbmm(K, rhs, preconditioner=preconditioner, print_process=False, tolerance=1, n_tridiag=n_tridiag)
print(rel_error(Kinvy_linalg, Kinvy))
Kinvy_torch, t_mat_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch,num_tridiag=n_tridiag)
# Kinvy_torch = added_diag._solve(rhs_torch, preconditioner=preconditioner_torch)
print(rel_error(Kinvy_linalg, Kinvy_torch.numpy()))

1.0524374915648925e-09
3.5123671374860206e-07


In [330]:
inv_quad_torch, logdet_torch = added_diag.inv_quad_logdet(rhs_torch[:, -1:], logdet=True)

In [331]:
inv_quad_torch, logdet_torch

(tensor(869.8112, dtype=torch.float64), tensor(8904.8094, dtype=torch.float64))

In [332]:
jnp.dot(y[:, 0], Kinvy[: ,-1]), calc_logdet(K.shape, t_mat, preconditioner)

(DeviceArray(869.81138294, dtype=float64),
 DeviceArray(9108.50948291, dtype=float64))

In [333]:
preconditioner._precond_logdet_cache, precond_logdet_torch

(DeviceArray(-68842.86687952, dtype=float64),
 tensor(-68842.8669, dtype=torch.float64))

#### t_mat由来の誤差多少はあるが、最適化に使うだけだから大した問題ではない