In [1]:
import numpy as np
from scipy.sparse.linalg import aslinearoperator
from scipy.stats import ortho_group

import tracelogdetdiag as tld
from tracelogdetdiag.logdet import logdet_via_cholesky
from tracelogdetdiag.diag import naive_diag, explicit_diag_probe
from tracelogdetdiag.diaginv import naive_diaginv, explicit_diaginv_probe
from tracelogdetdiag.trace import hutch_plus_plus_trace, hutch_plus_plus_epsilon_delta_trace, hutchinson_trace, hutchinson_epsilon_delta_trace
from tracelogdetdiag.logdet import logdet_stochastic_chebyshev_approx, logdet_stochastic_chebyshev_epsilon_delta_approx

# NumPy

## Test matrix

In [2]:
d = 100

np.random.seed(0)
basis = ortho_group.rvs(d)
eigvals = np.random.uniform(low=5.0, high=10.0, size=d)
B = basis @ np.diag(eigvals) @ basis.T
true_logdet = logdet_via_cholesky(B)
true_diag = np.diag(B)
true_diaginv = np.diag(np.linalg.inv(B))
true_trace = np.trace(B)
B_explicit = B.copy()
B = aslinearoperator(B)


## Diagonal

In [3]:
explicit_diag_probe(B) - true_diag

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [4]:
naive_diag(B) - true_diag

array([-3.43559959e-02,  1.84873093e-02, -6.48401868e-03, -1.27983367e-02,
       -6.22031242e-03,  1.44348759e-02, -2.22599250e-02,  6.94515918e-02,
       -4.86491765e-02,  8.96828666e-02, -8.41206338e-02, -1.50461390e-02,
       -1.12136369e-04, -4.19203590e-02,  4.06737013e-02, -5.12897171e-03,
       -7.30109428e-02, -2.54048535e-02,  3.55007192e-03,  7.40461146e-04,
       -2.17098736e-02,  6.93133036e-03,  1.70748628e-02, -6.63547379e-02,
       -7.62758077e-02, -9.64010875e-03,  9.63786176e-02, -4.23267073e-02,
        4.18440077e-02,  7.67070752e-02,  3.50869261e-02, -4.15474795e-02,
        9.34188869e-02,  6.25277077e-02, -1.73917825e-02,  4.51162514e-02,
       -1.72480306e-02,  1.16287557e-02, -4.35860558e-02, -4.61223385e-03,
       -5.58011430e-02,  8.81634955e-02,  4.01418092e-02, -4.02491385e-02,
       -1.83808315e-02, -8.78012618e-02, -1.39030420e-02, -1.52303659e-02,
       -9.63817221e-04, -1.21296345e-02, -2.38118520e-02,  2.56552576e-02,
        6.51874222e-02,  

## Diagonal inverse

In [5]:
explicit_diaginv_probe(B_explicit) - true_diaginv

array([ 2.77555756e-17,  2.77555756e-17,  0.00000000e+00, -2.77555756e-17,
        0.00000000e+00,  2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
       -2.77555756e-17, -8.32667268e-17,  2.77555756e-17,  2.77555756e-17,
        2.77555756e-17, -2.77555756e-17, -5.55111512e-17,  5.55111512e-17,
       -5.55111512e-17,  0.00000000e+00, -5.55111512e-17,  0.00000000e+00,
        0.00000000e+00,  2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
        0.00000000e+00, -2.77555756e-17,  2.77555756e-17, -2.77555756e-17,
        2.77555756e-17,  2.77555756e-17,  5.55111512e-17,  0.00000000e+00,
        0.00000000e+00,  2.77555756e-17, -5.55111512e-17,  2.77555756e-17,
       -2.77555756e-17,  0.00000000e+00,  2.77555756e-17,  0.00000000e+00,
        2.77555756e-17, -2.77555756e-17,  2.77555756e-17,  0.00000000e+00,
        2.77555756e-17, -2.77555756e-17, -2.77555756e-17,  0.00000000e+00,
        2.77555756e-17,  5.55111512e-17,  0.00000000e+00, -2.77555756e-17,
        5.55111512e-17, -

In [6]:
naive_diaginv(B_explicit, sample_size=20) - true_diaginv

array([-0.00471889, -0.00861564, -0.0057984 ,  0.00109973, -0.01342509,
        0.00099314, -0.00453394, -0.00364562, -0.00074555, -0.00224129,
       -0.00698532, -0.0071619 ,  0.00203268, -0.00277637, -0.00529265,
        0.00068859,  0.00515377, -0.00138805, -0.00251945, -0.00493543,
       -0.00104649, -0.00608721,  0.00171939, -0.00164274, -0.00063705,
        0.00915407, -0.00418932,  0.0028758 ,  0.00459588, -0.00393199,
       -0.00841797, -0.00528054, -0.00549151,  0.0047718 ,  0.00274288,
        0.01125847, -0.00879622,  0.00085156,  0.00138484, -0.00566558,
       -0.00267387,  0.0001833 , -0.00442397,  0.00095498,  0.00032642,
        0.00105554, -0.00321066, -0.00107542,  0.00541068,  0.00808827,
       -0.00335957,  0.00252727,  0.00359217,  0.00599683,  0.00383187,
       -0.00442563,  0.0058817 , -0.00102439, -0.01087727, -0.00189004,
        0.00523295,  0.00707617,  0.00578611, -0.00360523, -0.00646649,
        0.00628827,  0.00546547, -0.00671276, -0.00449197, -0.00

## Trace

In [7]:
true_trace

745.0690170504758

In [8]:
hutch_plus_plus_epsilon_delta_trace(B, epsilon=0.01, delta=0.01)

744.8991878941576

In [9]:
hutch_plus_plus_trace(B, sample_size=3*1000)

745.0690170504758

In [10]:
hutchinson_epsilon_delta_trace(B, epsilon=0.1, delta=0.1)

745.0512502213192

In [11]:
hutchinson_trace(B, sample_size=2000)

744.5254308243823

## Logdet

In [12]:
true_logdet

199.06687370895025

In [13]:
logdet_stochastic_chebyshev_approx(B, sample_size=300)

198.9909349774243

In [14]:
logdet_stochastic_chebyshev_epsilon_delta_approx(B)

199.03614961611908

# CuPy

## Test matrix

In [16]:
import cupy as cp
from cupyx.scipy.sparse.linalg import LinearOperator as CuPyLinearOperator
from cupyx.scipy.sparse.linalg import aslinearoperator as cupyaslinearoperator

In [18]:
d = 1000

cp.random.seed(0)
basis = ortho_group.rvs(d)
eigvals = np.random.uniform(low=5.0, high=10.0, size=d)
B_numpy = basis @ np.diag(eigvals) @ basis.T
basis = cp.asarray(basis)
B = basis @ cp.diag(cp.asarray(eigvals)) @ basis.T
true_logdet = logdet_via_cholesky(B_numpy)
true_diag = np.diag(B_numpy)
true_diaginv = np.diag(np.linalg.inv(B_numpy))
true_trace = np.trace(B_numpy)
B_explicit = B_numpy.copy()
B = cupyaslinearoperator(B)

## Diagonal

In [20]:
cp.linalg.norm(explicit_diag_probe(B) - cp.asarray(true_diag))

array(6.24383744e-14)

In [21]:
cp.linalg.norm(naive_diag(B) - cp.asarray(true_diag))/len(true_diag)

array(0.00144637)

## Diagonal inverse

## Trace

## Logdet