解决$(A+BC)x=b$的问题

$$x=A^{-1}b-A^{-1}B(I+CA^{-1}B)^{-1}CA^{-1}b$$

1.$z=A^{-1}b,Z=A^{-1}B.$
$$x=z-Z(I+CZ)^{-1}Cz$$
2.$d=Cz,D=CZ.$
$$x=z-Z(I+D)^{-1}d$$
3.$y=(I+D)^{-1}d$
$$x=z-Zy$$

In [196]:
import numpy as np
import scipy.sparse as sparse

In [252]:
def solver(A,B,C,b,format='csc'):
    from scipy.sparse.linalg import spsolve
    from scipy.sparse import eye,hstack,csc_matrix
    
    n=A.shape[0]
    k=B.shape[1]
    
    assert A.shape==(n,n)
    assert B.shape==(n,k)
    assert C.shape==(k,n)
    assert b.shape[0]==n
    
    A=A if type(A) is csc_matrix else csc_matrix(A)
    B=B if type(B) is csc_matrix else csc_matrix(B)
    C=C if type(C) is csc_matrix else csc_matrix(C)
    
    if(type(b)==np.ndarray):
        b=csc_matrix(b) if b.ndim!=1 else csc_matrix(b[:,np.newaxis])
    ntarget=b.shape[1]

    Z=spsolve(A,hstack([b,B],format=format))
    
    H=C.dot(Z)
    y=spsolve(H[:,ntarget:]+eye(k,format=format),H[:,0:ntarget])
    if(ntarget==1):
        y=y[:,np.newaxis]
        y=csc_matrix(y)
    x=Z[:,:ntarget]-Z[:,ntarget:].dot(y)
    return x

In [254]:
def native_solver(A,B,C,b):
    return np.linalg.solve(A+B.dot(C),b)

In [297]:
n,k=10000000,40
delta,eta=1,1
e=np.ones((n))
A=np.random.rand(k,n)
b=np.random.randn(k,3)

D=sparse.spdiags([-e,2*e,e],[-1,0,1],n,n,format='csc')
D[0,0]=1
D[-1,-1]=1

I=sparse.eye(n)

In [277]:
x=solver(delta*D+eta*I,A.T,A,A.T.dot(b))
xt=native_solver(delta*D+eta*I,A.T,A,A.T.dot(b))
np.testing.assert_allclose(np.array(delta*D+eta*I+A.T.dot(A)).dot(x.toarray()),A.T.dot(b))
np.testing.assert_allclose(x.toarray(),xt)

## 效率对比

In [298]:
_=%time solver(delta*D+eta*I,A.T,A,A.T.dot(b))

CPU times: user 3min 21s, sys: 50.2 s, total: 4min 11s
Wall time: 3min 59s


In [292]:
_=%time native_solver(delta*D+eta*I,A.T,A,A.T.dot(b))

MemoryError: 