In [1]:
import numpy as np
from scipy.sparse.linalg import aslinearoperator
from scipy.stats import ortho_group

import tracelogdetdiag as tld
from tracelogdetdiag.logdet import logdet_via_cholesky
from tracelogdetdiag.diag import naive_diag, explicit_diag_probe
from tracelogdetdiag.diaginv import naive_diaginv, explicit_diaginv_probe
from tracelogdetdiag.trace import hutch_plus_plus_trace, hutch_plus_plus_epsilon_delta_trace, hutchinson_trace, hutchinson_epsilon_delta_trace
from tracelogdetdiag.logdet import logdet_stochastic_chebyshev_approx, logdet_stochastic_chebyshev_epsilon_delta_approx

# Test matrix

In [2]:
# Side length of the square test matrix.
d = 100

# Seed the global NumPy RNG before any draw: ortho_group.rvs is called without
# a random_state, and the order of the two draws below determines the matrix.
np.random.seed(0)
basis = ortho_group.rvs(dim=d)
eigvals = np.random.uniform(5.0, 10.0, d)

# B = Q diag(eigvals) Q^T: symmetric up to round-off, spectrum drawn from [5, 10).
B = np.matmul(np.matmul(basis, np.diag(eigvals)), basis.T)

# Exact reference values computed from the dense matrix; the estimators in the
# cells below are compared against these.
true_logdet = logdet_via_cholesky(B)
true_diag = np.diagonal(B)
true_diaginv = np.diagonal(np.linalg.inv(B))
true_trace = B.trace()

# Keep a dense copy for the routines that are handed an explicit matrix, then
# rebind B to a LinearOperator wrapper for the operator-based estimators.
B_explicit = np.copy(B)
B = aslinearoperator(B)


# Diagonal

In [3]:
# Elementwise error of explicit_diag_probe, applied to the LinearOperator B,
# against the exact diagonal true_diag taken from the dense matrix.
# NOTE(review): this displays all d entries; a scalar summary such as the
# maximum absolute error would be easier to read.
explicit_diag_probe(B) - true_diag

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [4]:
# Elementwise error of naive_diag on the same operator. No sample size is
# passed, so the library default is used.
# NOTE(review): presumably a randomized estimator whose result depends on RNG
# state left by earlier cells - confirm, and seed explicitly if so.
naive_diag(B) - true_diag

array([-3.43559959e-02,  1.84873093e-02, -6.48401868e-03, -1.27983367e-02,
       -6.22031242e-03,  1.44348759e-02, -2.22599250e-02,  6.94515918e-02,
       -4.86491765e-02,  8.96828666e-02, -8.41206338e-02, -1.50461390e-02,
       -1.12136369e-04, -4.19203590e-02,  4.06737013e-02, -5.12897171e-03,
       -7.30109428e-02, -2.54048535e-02,  3.55007192e-03,  7.40461146e-04,
       -2.17098736e-02,  6.93133036e-03,  1.70748628e-02, -6.63547379e-02,
       -7.62758077e-02, -9.64010875e-03,  9.63786176e-02, -4.23267073e-02,
        4.18440077e-02,  7.67070752e-02,  3.50869261e-02, -4.15474795e-02,
        9.34188869e-02,  6.25277077e-02, -1.73917825e-02,  4.51162514e-02,
       -1.72480306e-02,  1.16287557e-02, -4.35860558e-02, -4.61223385e-03,
       -5.58011430e-02,  8.81634955e-02,  4.01418092e-02, -4.02491385e-02,
       -1.83808315e-02, -8.78012618e-02, -1.39030420e-02, -1.52303659e-02,
       -9.63817221e-04, -1.21296345e-02, -2.38118520e-02,  2.56552576e-02,
        6.51874222e-02,  

# Diagonal of the inverse

In [10]:
# Elementwise error of explicit_diaginv_probe against diag(inv(B)).
# Unlike the diagonal cells, this is given the dense copy B_explicit rather
# than the LinearOperator B.
explicit_diaginv_probe(B_explicit) - true_diaginv

array([ 2.77555756e-17,  2.77555756e-17,  0.00000000e+00, -2.77555756e-17,
        0.00000000e+00,  2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
       -2.77555756e-17, -8.32667268e-17,  2.77555756e-17,  2.77555756e-17,
        2.77555756e-17, -2.77555756e-17, -5.55111512e-17,  5.55111512e-17,
       -5.55111512e-17,  0.00000000e+00, -5.55111512e-17,  0.00000000e+00,
        0.00000000e+00,  2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
        0.00000000e+00, -2.77555756e-17,  2.77555756e-17, -2.77555756e-17,
        2.77555756e-17,  2.77555756e-17,  5.55111512e-17,  0.00000000e+00,
        0.00000000e+00,  2.77555756e-17, -5.55111512e-17,  2.77555756e-17,
       -2.77555756e-17,  0.00000000e+00,  2.77555756e-17,  0.00000000e+00,
        2.77555756e-17, -2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
        2.77555756e-17, -2.77555756e-17, -2.77555756e-17,  0.00000000e+00,
        2.77555756e-17,  5.55111512e-17,  0.00000000e+00, -2.77555756e-17,
        5.55111512e-17, -

In [13]:
# Elementwise error of naive_diaginv on the dense copy, with sample_size=20
# (well below d = 100), against diag(inv(B)).
# NOTE(review): presumably randomized; the result then depends on RNG state
# left by earlier cells - confirm.
naive_diaginv(B_explicit, sample_size=20) - true_diaginv

array([ 8.38903405e-04, -1.96477648e-03,  4.02769432e-03,  2.88891955e-03,
        2.60294094e-03, -7.40352602e-03, -3.24692121e-03,  7.02143646e-03,
       -4.65556680e-03, -1.35460091e-02,  1.19362659e-02,  7.50664401e-03,
        4.00786936e-03,  1.51041183e-03, -1.59646004e-04, -4.30781157e-03,
        3.25219359e-03,  8.78754583e-04,  4.53903806e-03,  2.59469757e-04,
       -3.11321693e-03, -4.18122761e-03,  8.33155342e-03,  2.36586600e-03,
       -3.66821063e-03,  9.67534817e-03, -3.67471092e-03,  8.08277898e-04,
       -8.42977056e-03, -4.33746408e-03,  1.07051016e-02, -4.01899365e-03,
        1.16301819e-03, -2.25211227e-03,  4.12173552e-03, -3.77952174e-03,
        1.42983500e-02,  2.32780508e-03,  7.70397701e-03,  2.64755260e-03,
        1.76146196e-03, -2.83577165e-03,  5.74160991e-03, -2.69964269e-03,
       -2.18393134e-03,  3.27112040e-04,  1.29542316e-02, -1.74183212e-03,
        2.91546886e-03,  6.90765535e-03, -5.13808009e-03,  5.31180059e-03,
        3.85096797e-03, -

# Trace

In [5]:
# Exact trace of the dense test matrix (np.trace), the reference value for the
# trace estimators in the following cells.
true_trace

745.0690170504758

In [8]:
# Hutch++ trace estimate driven by (epsilon, delta) instead of a sample size.
# NOTE(review): epsilon/delta are presumably the error tolerance and failure
# probability - confirm their exact meaning against the tracelogdetdiag docs.
hutch_plus_plus_epsilon_delta_trace(B, epsilon=0.01, delta=0.01)

745.4715057624711

In [6]:
# Hutch++ trace estimate with an explicit sample size of 3000.
# NOTE(review): 3000 is far larger than d = 100. If sample_size counts
# matrix-vector products (TODO confirm), the matrix could be recovered exactly
# with fewer, so a budget below d would be a more telling demonstration.
hutch_plus_plus_trace(B, sample_size=3*1000)

745.0690170504761

In [3]:
# Hutchinson trace estimate driven by (epsilon, delta); note the looser
# setting (0.1, 0.1) compared with the Hutch++ cell (0.01, 0.01).
# NOTE(review): confirm the meaning of epsilon/delta in the library docs.
hutchinson_epsilon_delta_trace(B, epsilon=0.1, delta=0.1)

745.0892468291919

In [4]:
# Hutchinson trace estimate with an explicit sample size of 2000, to compare
# against true_trace.
# NOTE(review): 2000 also exceeds d = 100; see whether a smaller budget is
# intended for the comparison.
hutchinson_trace(B, sample_size=2000)

745.7563518673404

# Log-determinant

In [4]:
# Reference log-determinant obtained from logdet_via_cholesky on the dense
# matrix; the stochastic estimates below are compared against it.
true_logdet

199.06687370895025

In [3]:
# Stochastic Chebyshev log-determinant estimate on the LinearOperator B with
# sample_size=300; all other parameters are left at the library defaults.
logdet_stochastic_chebyshev_approx(B, sample_size=300)

199.10091112942428

In [6]:
# (epsilon, delta) variant of the stochastic Chebyshev log-determinant
# estimate, called with library defaults only.
# NOTE(review): the recorded run was stopped by hand (KeyboardInterrupt), so
# this cell has no result and blocks Restart & Run All. Pass looser explicit
# tolerances if the signature allows (TODO confirm), or drop the cell.
logdet_stochastic_chebyshev_epsilon_delta_approx(B)

KeyboardInterrupt: 