In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.sparse.linalg import LinearOperator, aslinearoperator

from tracelogdetdiag.util import relative_resigual_cg
from tracelogdetdiag.trace import hutch_plus_plus_trace, hutchinson_trace
from tracelogdetdiag.diag import naive_diag

In [22]:
# Build a symmetric positive definite (SPD) test matrix, as required by CG:
# G^T G is SPD for a full-rank Gaussian G (almost surely), and adding 100 to
# every entry adds the PSD rank-one matrix 100 * 1 1^T, which keeps it SPD.
SEED = 0
np.random.seed(SEED)  # make this matrix and later np.random draws reproducible

n = 100
gaussian = np.random.normal(size=(n, n))
# Dense copy kept for exact reference computations (solve, inv, trace, diag).
A_explicit = gaussian.T @ gaussian + 100
# Operator wrapper around an independent copy, as before.
A = aslinearoperator(A_explicit.copy())

In [23]:
# Subclassing contract, from the scipy.sparse.linalg.LinearOperator docstring:
# a subclass must implement at least one of the methods ``_matvec`` and
# ``_matmat``, and provide the attributes/properties ``shape`` (a pair of
# integers) and ``dtype`` (may be None). It may call ``__init__`` on the base
# class to have these attributes validated. Implementing ``_matvec``
# automatically provides ``_matmat`` (via a naive column-by-column algorithm)
# and vice versa.

In [24]:
class AinvCGLinearOperator(LinearOperator):
    """LinearOperator representing A^{-1}, applied approximately with conjugate gradients.

    Products A^{-1} x are obtained by solving A y = x with ``relative_resigual_cg``;
    A^{-1} is never formed explicitly.

    Parameters
    ----------
    A : LinearOperator
        Operator to invert. CG needs A to be symmetric positive definite, and the
        transpose methods below rely on that symmetry (A^{-T} = A^{-1}).
    cg_tol : float, default 1e-4
        Passed to ``relative_resigual_cg`` as ``eps``.
    cg_maxits : int, default 1000
        Passed to ``relative_resigual_cg`` as ``maxits``.
    use_prev : bool, default True
        If True, every solve is warm-started (``x0``) from the previous solution,
        so repeated calls with the same input can differ slightly.
    """

    def __init__(self, A, cg_tol=1e-4, cg_maxits=1000, use_prev=True):
        self.A = A
        self.cg_tol = cg_tol
        self.cg_maxits = cg_maxits
        self.use_prev = use_prev
        # Initial guess for the next CG solve; None leaves the choice to the solver.
        self.x0 = None
        # LinearOperator validates and stores dtype/shape; A^{-1} has A's shape.
        super().__init__(A.dtype, A.shape)

    def _matvec(self, x):
        """Approximately solve A y = x with CG and return y as a 1-D array."""
        # LinearOperator.matvec accepts (n,) and (n, 1) inputs and reshapes our
        # result back to the caller's shape, so work in 1-D here. This keeps the
        # stored warm start the same shape as every later right-hand side.
        rhs = np.asarray(x).reshape(-1)
        result = relative_resigual_cg(self.A, rhs, eps=self.cg_tol, maxits=self.cg_maxits, x0=self.x0)
        approx_sol = np.asarray(result["x"]).reshape(-1)
        if self.use_prev:
            # Keep a private copy: the returned array goes to the caller, and an
            # in-place edit of it must not change the next solve's starting point.
            self.x0 = approx_sol.copy()
        return approx_sol

    def _matmat(self, B):
        """Apply A^{-1} to each column of B, one CG solve per column."""
        result = np.zeros((self.shape[0], B.shape[1]))
        for j in range(B.shape[1]):
            result[:, j] = self._matvec(B[:, j])
        return result

    def _rmatmat(self, B):
        """Transpose product; equals the forward product because A is symmetric."""
        return self._matmat(B)

    def _rmatvec(self, x):
        """Transpose matvec; equals the forward solve because A is symmetric."""
        return self._matvec(x)

In [25]:
# CG-backed A^{-1} with a tight tolerance and a generous iteration cap.
Ainv = AinvCGLinearOperator(A, cg_tol=1e-5, cg_maxits=100_000)

In [26]:
# Two standard-normal probe vectors (as columns) to compare CG products with exact solves.
z = np.random.normal(loc=0.0, scale=1.0, size=(n, 2))

In [27]:
# Apply the CG inverse to both probe columns through the public LinearOperator API;
# `matmat` validates the shape of z and then calls the custom `_matmat`.
Ainv.matmat(z)

array([[  5.00654987,   3.5196535 ],
       [  2.0522352 ,   1.89491736],
       [ 11.11740659,   8.12441164],
       [ -0.85904955,  -1.05563439],
       [  6.39073162,   4.86286621],
       [-16.54282322, -11.42945443],
       [  9.63412902,   6.46904691],
       [  6.2021483 ,   4.71579063],
       [  2.61208674,   2.24431741],
       [ -8.7792274 ,  -6.17320712],
       [  0.02844451,  -0.30424996],
       [  1.2199393 ,   0.87467462],
       [ -6.6607148 ,  -5.20813699],
       [ -8.1715952 ,  -5.588308  ],
       [  2.14491599,   0.16591262],
       [  0.44188422,   0.99757897],
       [-12.24342563,  -8.67187103],
       [ 14.52992251,  11.04346779],
       [  2.56797397,   2.5372542 ],
       [ -4.27752809,  -2.66443093],
       [ -5.30619393,  -4.03941634],
       [ -5.23262332,  -3.31039856],
       [  8.16166134,   5.63561011],
       [ 17.0618202 ,  11.60509464],
       [ -3.8108868 ,  -2.93804126],
       [  9.6899563 ,   7.60954002],
       [  4.63994835,   2.90479273],
 

In [28]:
# Same product via the `@` operator, which dispatches to the custom `_matmat`.
# Values agree with the previous cell only up to the CG tolerance (warm starts differ).
Ainv @ z

array([[  5.00654994,   3.51965346],
       [  2.05223531,   1.89491742],
       [ 11.11740663,   8.12441167],
       [ -0.85904956,  -1.05563437],
       [  6.39073163,   4.86286622],
       [-16.54282322, -11.42945441],
       [  9.63412901,   6.46904687],
       [  6.20214821,   4.71579061],
       [  2.61208668,   2.24431734],
       [ -8.77922745,  -6.1732071 ],
       [  0.02844453,  -0.30424997],
       [  1.21993938,   0.87467464],
       [ -6.66071479,  -5.20813702],
       [ -8.1715952 ,  -5.58830797],
       [  2.14491608,   0.16591262],
       [  0.44188419,   0.99757899],
       [-12.24342561,  -8.67187105],
       [ 14.52992251,  11.04346776],
       [  2.56797401,   2.53725421],
       [ -4.27752811,  -2.66443091],
       [ -5.30619402,  -4.03941636],
       [ -5.23262325,  -3.31039849],
       [  8.16166141,   5.63561008],
       [ 17.06182032,  11.60509462],
       [ -3.81088677,  -2.93804125],
       [  9.68995624,   7.60954004],
       [  4.63994834,   2.90479275],
 

In [29]:
# Exact dense solve A^{-1} z: the reference for the CG-based products above.
np.linalg.solve(A_explicit, z)

array([[  5.00654987,   3.51965348],
       [  2.05223529,   1.89491738],
       [ 11.11740656,   8.12441163],
       [ -0.85904955,  -1.05563441],
       [  6.39073161,   4.86286622],
       [-16.54282323, -11.42945443],
       [  9.634129  ,   6.46904689],
       [  6.20214826,   4.71579065],
       [  2.6120867 ,   2.2443174 ],
       [ -8.77922745,  -6.17320713],
       [  0.02844452,  -0.30424995],
       [  1.21993936,   0.87467461],
       [ -6.66071477,  -5.208137  ],
       [ -8.17159517,  -5.58830799],
       [  2.14491606,   0.16591266],
       [  0.44188424,   0.99757898],
       [-12.24342564,  -8.67187103],
       [ 14.52992253,  11.04346779],
       [  2.56797401,   2.53725419],
       [ -4.27752811,  -2.66443093],
       [ -5.30619401,  -4.03941635],
       [ -5.23262324,  -3.31039856],
       [  8.16166137,   5.63561009],
       [ 17.06182028,  11.60509463],
       [ -3.81088677,  -2.93804125],
       [  9.68995625,   7.60954003],
       [  4.63994839,   2.90479275],
 

In [30]:
# Transpose of the CG inverse operator. Its products go through `_rmatvec`/`_rmatmat`,
# which reuse the forward solve; that is valid only because A is symmetric.
K = Ainv.T

In [31]:
# The transposed operator of a square operator keeps the same shape.
K.shape

(100, 100)

In [32]:
# Should match A^{-1} z above (up to the CG tolerance), since A^{-1} is symmetric.
K @ z

array([[  5.00654991,   3.51965356],
       [  2.05223527,   1.89491734],
       [ 11.11740657,   8.12441167],
       [ -0.85904953,  -1.05563438],
       [  6.39073157,   4.86286623],
       [-16.54282322, -11.42945443],
       [  9.634129  ,   6.46904687],
       [  6.20214827,   4.71579063],
       [  2.61208671,   2.24431745],
       [ -8.77922741,  -6.1732071 ],
       [  0.02844453,  -0.30424996],
       [  1.21993936,   0.87467462],
       [ -6.66071477,  -5.20813702],
       [ -8.17159522,  -5.58830807],
       [  2.14491603,   0.1659126 ],
       [  0.44188422,   0.99757895],
       [-12.24342564,  -8.67187106],
       [ 14.52992251,  11.04346777],
       [  2.56797397,   2.53725417],
       [ -4.27752809,  -2.66443092],
       [ -5.30619399,  -4.03941631],
       [ -5.23262331,  -3.31039862],
       [  8.1616614 ,   5.63561015],
       [ 17.06182029,  11.60509464],
       [ -3.81088677,  -2.93804128],
       [  9.68995626,   7.60953999],
       [  4.63994838,   2.90479273],
 

# Trace of the inverse: can `Ainv` be used to estimate $\operatorname{tr}(A^{-1})$?

In [33]:
# Sanity check: the inverse operator has the same shape as A.
Ainv.shape

(100, 100)

In [34]:
# Exact reference value of tr(A^{-1}) from a dense inverse (affordable for this small n).
np.trace(np.linalg.inv(A_explicit))

130.47601361816993

In [35]:
# Hutchinson estimate of tr(A^{-1}) with 1000 samples; presumably one Ainv product
# (i.e. one CG solve) per sample — confirm against the library.
# NOTE(review): in this run the estimate is ~7% below the exact value above; consider
# reporting the spread over several seeds rather than a single draw.
hutchinson_trace(Ainv, sample_size=1000)

120.83562016378814

In [36]:
# Hutch++ estimate of tr(A^{-1}). The budget of 33 is presumably a number of matvecs
# chosen as a multiple of 3 for Hutch++'s three-way split — TODO confirm.
# In this run it lands much closer to the exact value than Hutchinson with 1000 samples.
hutch_plus_plus_trace(Ainv, sample_size=33)

130.10257382676272

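# Diagonal of the inverse: can `Ainv` be used to estimate $\operatorname{diag}(A^{-1})$? (First checked on $A$ itself.)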

In [51]:
# Check the diagonal estimator on A itself first, where the exact answer is cheap
# (see np.diag(A_explicit) below); 100000 samples.
naive_diag(A, sample_size=100000)

array([191.76929834, 187.59792328, 203.40805795, 197.00074114,
       213.2690058 , 200.25636248, 199.37431374, 193.7271652 ,
       211.05345459, 206.1704364 , 224.67356122, 195.33512526,
       196.17527573, 177.64828964, 213.38583728, 189.39970346,
       220.38023354, 212.42302568, 209.29267297, 199.13339732,
       206.73958158, 207.4320074 , 179.93816061, 209.07943047,
       207.34366067, 187.53167946, 194.02155104, 202.24079913,
       178.73716888, 196.55668964, 186.40731044, 195.17724509,
       245.64047079, 221.4658032 , 208.45199649, 232.48811807,
       203.64149373, 197.49558195, 203.94934502, 211.40595437,
       185.37262864, 226.81262913, 183.74640063, 226.90582952,
       220.24482336, 241.47397647, 225.13903788, 184.86451856,
       203.77855654, 214.51862042, 183.18964961, 174.9968005 ,
       213.00321229, 196.05276394, 207.4556256 , 207.55920648,
       199.28836826, 192.78429084, 163.72072828, 196.18561668,
       190.06702883, 202.37062881, 222.17753629, 190.18

In [52]:
# Exact diagonal of A: the reference for the naive_diag(A, ...) estimates around it.
np.diag(A_explicit)

array([190.51818264, 185.08384849, 201.18565956, 192.69523137,
       214.85986746, 197.48049196, 199.44631854, 194.45869579,
       208.19799993, 207.24107003, 216.43745312, 195.22247421,
       196.9151237 , 178.56136556, 214.7757187 , 189.31585524,
       217.6348893 , 214.10934514, 206.54398299, 198.52392599,
       202.72822948, 209.60769479, 179.11470379, 205.39296374,
       208.4074175 , 186.42723215, 194.00427637, 192.09991704,
       176.88007438, 196.13207149, 193.21508866, 200.35737927,
       240.55216558, 223.44800822, 205.52170149, 234.92444719,
       204.60458477, 200.93059355, 207.7872262 , 208.06289156,
       184.02087247, 226.59746755, 183.69703391, 224.81548774,
       226.15344211, 238.91946833, 226.48832354, 187.07868115,
       204.78215115, 210.5970958 , 183.86655235, 173.84495214,
       209.93557007, 192.17585709, 209.30658514, 210.0349478 ,
       201.11161361, 189.70768577, 163.14743822, 195.69014474,
       193.50817522, 208.08369487, 221.21639719, 190.86

In [53]:
# Same estimator with 10x fewer samples; the output is visibly noisier than the
# 100000-sample run when compared with the exact diagonal.
naive_diag(A, sample_size=10000)

array([195.17529869, 189.86201058, 196.12164714, 204.91299754,
       222.13144677, 184.91067222, 193.49113443, 180.87318086,
       206.54491225, 214.85555576, 235.49618049, 197.35126063,
       195.89523037, 185.53177254, 215.71517518, 195.05338441,
       209.13722047, 223.80667146, 192.11394028, 201.54296837,
       195.99938587, 207.28957812, 196.99115543, 200.61684814,
       226.59555868, 183.24644798, 203.00262973, 195.25075867,
       184.94154529, 203.81034627, 182.96051874, 201.34294283,
       237.47528068, 217.01646796, 198.41686983, 232.701571  ,
       193.29963228, 208.3445892 , 196.09867387, 206.46993579,
       181.52687635, 205.02682011, 193.28615007, 238.88424015,
       230.4949124 , 249.22889862, 237.37542297, 188.32165777,
       215.41953996, 225.90108963, 190.77358536, 163.59879508,
       221.88566195, 182.74673493, 209.98810223, 225.84302364,
       207.17200037, 189.46127199, 173.99155101, 201.39883744,
       191.82134274, 221.35963478, 220.59188927, 212.21

In [57]:
# Diagonal estimate of A^{-1} with 200 samples (each needs products with Ainv, i.e. CG solves).
# NOTE(review): this is very noisy — several entries come out negative although the
# diagonal of an SPD matrix's inverse is strictly positive (compare with the next cell).
naive_diag(Ainv, sample_size=200)

array([-1.24605406e-01,  3.18102561e-01,  3.16144599e+00,  4.68387818e-02,
        1.19533436e+00,  2.64035605e+00,  4.47585964e-01,  1.78908056e+00,
        1.16066268e-01,  1.48792152e+00, -2.22500597e-02,  3.88064559e-01,
        7.42454942e-01,  6.63275591e-01,  9.81183489e-01,  6.25738323e-01,
        2.11662665e+00,  3.87752269e+00,  6.26456501e-03,  7.67812599e-01,
        7.22387898e-01,  3.90712634e-01,  1.00289397e+00,  1.59461564e+00,
        3.35061394e-02,  1.21360654e+00,  8.24979026e-01,  5.81816935e+00,
        3.37211492e+00,  5.40553602e+00,  1.32955455e-01,  2.22984431e+00,
       -3.28937006e-01,  6.24643341e-01,  1.28725812e+00, -5.51511075e-01,
       -4.82982401e-01,  2.80787865e+00,  1.70132306e-02,  4.18975019e+00,
        1.44047736e-01,  1.58959662e+00,  3.30359604e-01,  1.11254495e+00,
        1.33029316e-01, -7.50801903e-01, -2.61710265e-03,  4.66468007e+00,
        6.05504657e-01,  8.49579588e-01, -8.48724592e-01,  2.12413599e+00,
        2.01220550e+00,  

In [58]:
# Exact diag(A^{-1}) from a dense inverse: the reference for the estimate above.
np.diag(np.linalg.inv(A_explicit))

array([0.31095738, 0.34085586, 1.37961852, 0.17596005, 0.74799692,
       2.18617789, 0.73978012, 0.74454212, 0.32963367, 1.36294031,
       0.10707933, 0.38854342, 0.86007121, 0.62376148, 0.79932273,
       0.4617486 , 1.43635013, 2.77185226, 0.26649243, 0.67546388,
       0.53005704, 0.2987953 , 0.64802319, 2.40234834, 0.25958594,
       1.76460028, 0.23807739, 2.85787275, 2.47569378, 2.98212726,
       0.24056697, 2.27601113, 0.35406681, 0.52709104, 1.16277711,
       1.47807617, 0.50368088, 4.86088061, 0.135581  , 2.88649466,
       0.18899542, 1.04121171, 0.6660716 , 0.43559954, 0.10192087,
       0.98244519, 0.17219872, 6.27197772, 0.44780189, 0.65832965,
       3.72301292, 2.41254116, 1.23902784, 0.29880534, 1.99150474,
       1.85991051, 0.19277244, 0.18348453, 2.443609  , 1.90489182,
       0.76328967, 0.65390343, 0.30345657, 0.53742429, 1.75871315,
       2.06870805, 2.4641839 , 0.23161719, 0.28063477, 4.52579496,
       1.19166173, 0.32919679, 6.27908474, 3.344783  , 0.59964