In [170]:
# in my own setting, this is pre-imported
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 5]
import scipy.sparse as sp
np.set_printoptions(linewidth=200, precision=5)
n = 20

In [171]:
def smw_solver(a, b, d, V):
    """Solve ``(a a^T + diag(V)) x = b*a + d`` with the Sherman-Morrison formula.

    With ``W = 1/V`` (the inverse of the diagonal) the solution reduces to

        x = W*d + (b - a^T W d) / (1 + a^T W a) * W*a,

    which needs only elementwise products and two dot products, so the
    n x n system matrix is never formed.

    Parameters
    ----------
    a : ndarray, shape (n,)
        Dense rank-one vector.
    b : float
        Scalar multiplying ``a`` on the right-hand side.
    d : ndarray, shape (n,)
        Remaining right-hand-side term.
    V : ndarray, shape (n,)
        Diagonal entries (assumed nonzero, since they are inverted).

    Returns
    -------
    ndarray, shape (n,)
        The solution ``x``.
    """
    W = 1.0 / V
    Wd = W * d
    Wa = W * a
    # Scalar coefficient of the rank-one correction W*a.
    ratio = (b - a @ Wd) / (1.0 + a @ Wa)
    # Reuse Wd rather than recomputing W*d.
    return Wd + ratio * Wa

In [172]:
def naive_solver(a, b, d, V):
    """Reference solver for ``(a a^T + diag(V)) x = b*a + d``.

    Builds the full n x n system matrix densely and hands it to
    ``np.linalg.solve``; used in this notebook as the ground truth the
    Sherman-Morrison solvers are compared against.
    """
    system = np.outer(a, a) + np.diag(V)
    rhs = b * a + d
    return np.linalg.solve(system, rhs)

In [173]:
# Random test problem. `a` is nonzero only on columns 4-9 and 14-16, so the
# sparse solver further down has real sparsity to exploit.
np.random.seed(0)  # fixed seed so Restart & Run All reproduces the same numbers
support_mask = np.zeros(n)
support_mask[4:10] = 1
support_mask[14:17] = 1
a = (np.random.rand(n) - 0.5) * 10 * support_mask
d = (np.random.rand(n) - 0.5) * 10
V = np.ones(n) * 3.0
V[6:15] += 2  # V is 5 on columns 6-14 and 3 elsewhere
b = 0.32
x_naive = naive_solver(a, b, d, V)
x_smw = smw_solver(a, b, d, V)
print(x_naive)
print(x_smw)
# Enforce (not just eyeball) that the fast solver matches the dense reference.
np.testing.assert_allclose(x_smw, x_naive, rtol=1e-10, atol=1e-12)

[ 1.59284  1.10331 -1.15913  1.00418  0.6663  -0.56612 -0.66303 -0.48176  0.5111   0.30298  0.28735 -0.54992 -0.9861   0.74486 -0.32044 -0.01337  0.11784  0.13905 -1.3535   0.95671]
[ 1.59284  1.10331 -1.15913  1.00418  0.6663  -0.56612 -0.66303 -0.48176  0.5111   0.30298  0.28735 -0.54992 -0.9861   0.74486 -0.32044 -0.01337  0.11784  0.13905 -1.3535   0.95671]


In [174]:
# Inverse of the diagonal entries (W = 1/V elementwise), printed for inspection.
W = np.divide(1.0, V)
print(W)

[0.33333 0.33333 0.33333 0.33333 0.33333 0.33333 0.2     0.2     0.2     0.2     0.2     0.2     0.2     0.2     0.2     0.33333 0.33333 0.33333 0.33333 0.33333]


In [175]:
aWd = a @ (W * d)  # the scalar a^T W d used in the Sherman-Morrison numerator
print(aWd)

7.675265422525664


In [176]:
def sp_smw_solver(a, b, d, V):
    """Solve ``(a a^T + diag(V)) x = b*a + d`` via Sherman-Morrison, using the sparsity of ``a``.

    The diagonal solve ``x = d / V`` is dense, but the rank-one correction is
    only computed and applied on the stored (nonzero) columns of ``a``.

    Parameters
    ----------
    a : sparse matrix or array-like
        The rank-one vector as a single row (shape ``(1, n)``), typically a
        ``csr_matrix``. Anything ``scipy.sparse.csr_matrix`` accepts works
        (CSC, COO, a dense 1-D array). Non-canonical input (unsorted or
        duplicate column indices) is canonicalised on a copy; the caller's
        matrix is never modified.
    b : float
        Scalar multiplying ``a`` on the right-hand side.
    d : ndarray, shape (n,)
        Remaining right-hand-side term.
    V : ndarray, shape (n,)
        Diagonal entries (assumed nonzero, since they are inverted).

    Returns
    -------
    ndarray, shape (n,)
        The solution ``x``.

    Raises
    ------
    ValueError
        If ``a`` does not have exactly one row.
    """
    a = sp.csr_matrix(a)
    if a.shape[0] != 1:
        raise ValueError(f"a must have exactly one row, got shape {a.shape}")
    if not a.has_canonical_format:
        # Duplicate column entries would be double-counted in a^T W a and
        # applied only once by the fancy-index update below. Copy first:
        # sum_duplicates works in place and csr_matrix(a) may share arrays.
        a = a.copy()
        a.sum_duplicates()

    idx = a.indices
    a_nz = a.data
    W = 1.0 / V
    x = W * d  # diagonal-only solution; corrected on the support of `a` below
    Wa_nz = W[idx] * a_nz
    ratio = (b - a_nz @ x[idx]) / (1.0 + a_nz @ Wa_nz)
    x[idx] += ratio * Wa_nz
    return x

In [177]:
# Store `a` as a 1 x n CSR matrix; its column indices are the stored (nonzero) positions of `a`.
spa = sp.csr_matrix(a)
support = spa.indices
print(support)

[ 4  5  6  7  8  9 14 15 16]


In [178]:
# Sparse SMW vs dense reference: must agree to round-off, not just look close.
x_sparse = sp_smw_solver(spa, b, d, V)
x_dense = naive_solver(a, b, d, V)
np.testing.assert_allclose(x_sparse, x_dense, rtol=1e-10, atol=1e-12)
np.abs(x_sparse - x_dense).sum()

1.412064909445121e-15

In [179]:
# Dense view of the sparse row, to eyeball that it matches `a`.
dense_row = spa.todense()
print(dense_row)

[[ 0.       0.       0.       0.      -4.86725  2.4542  -4.81585  1.82589  2.80994  4.32394  0.       0.       0.       0.      -1.18385  0.3735  -1.30271  0.       0.       0.     ]]


In [180]:
# Dense SMW vs sparse SMW: same formula, so they must agree to round-off.
x_smw_dense = smw_solver(a, b, d, V)
x_smw_sparse = sp_smw_solver(spa, b, d, V)
print("residual error between the two solvers = ", np.abs(x_smw_dense - x_smw_sparse).sum())
np.testing.assert_allclose(x_smw_sparse, x_smw_dense, rtol=1e-10, atol=1e-12)

residual error between the two solvers =  0.0
