In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ilupp
import scipy
import scipy.sparse
import scipy.sparse.linalg
import scipy.ndimage

import matplotlib
# Enlarge the default font size for every figure produced below.
matplotlib.rcParams["font.size"] = 20

In [None]:
# Each plateau is an axis-aligned rectangle (lx, rx, ly, ry); grid edges whose
# midpoint lies strictly inside one get coefficient `kbig` instead of 1.
default_plateaus = (
    (.5, .9, .5, .8),
    (.1, .3, .1, .3),
    (.1, .3, .6, .9),
)


def _edge_coefficient(x, y, kbig, plateaus):
    """Return `kbig` if (x, y) lies strictly inside any plateau, else 1."""
    if any(lx < x < rx and ly < y < ry for lx, rx, ly, ry in plateaus):
        return kbig
    return 1


def _add_edge(A, p, q, k):
    """Stamp a symmetric edge of weight k between unknowns p and q into A."""
    A[p, p] += k
    A[q, q] += k
    A[q, p] -= k
    A[p, q] -= k


def poisson2d_plateaus(n, kbig=100, plateaus=default_plateaus,
                       top_value=1, bottom_value=2, border_diag=4):
    """Assemble a weighted 5-point stencil on an n x n grid with plateaus.

    Unknown (i, j) -- i the column (x) index, j the row (y) index -- is stored
    at position j*n + i. Every edge between neighbouring unknowns contributes
    +k to both diagonal entries and -k to both off-diagonal entries, so A is
    symmetric. The edge coefficient k is `kbig` when the edge midpoint (in
    coordinates (i/n, j/n)) lies strictly inside one of `plateaus`, else 1.

    Parameters
    ----------
    n : int
        Number of grid points per direction; A has shape (n**2, n**2).
    kbig : number
        Coefficient used on edges inside a plateau.
    plateaus : iterable of (lx, rx, ly, ry)
        Rectangles where the coefficient is raised to `kbig`.
    top_value, bottom_value : number
        Right-hand-side values on the first (j = 0) and last (j = n-1) rows.
    border_diag : number
        Value written onto the diagonal of every border unknown.

    Returns
    -------
    A : scipy.sparse.csr_matrix, shape (n**2, n**2)
    b : numpy.ndarray, shape (n**2,)
        Zero except on the top and bottom rows.
    """
    A = scipy.sparse.lil_matrix((n**2, n**2))

    # Horizontal edges (i, j) -- (i+1, j).
    for i in range(n - 1):
        for j in range(n):
            k = _edge_coefficient((i + 0.5) / n, j / n, kbig, plateaus)
            ind = j * n + i
            _add_edge(A, ind, ind + 1, k)

    # Vertical edges (i, j) -- (i, j+1).
    for i in range(n):
        for j in range(n - 1):
            k = _edge_coefficient(i / n, (j + 0.5) / n, kbig, plateaus)
            ind = j * n + i
            _add_edge(A, ind, ind + n, k)

    # Border unknowns: overwrite the diagonal with `border_diag`. Their
    # off-diagonal couplings to neighbours are kept, so these are NOT identity
    # (pure Dirichlet) rows. Corners are written twice with the same value.
    for i in range(n):
        A[i, i] = border_diag                                  # top row
        A[i * n, i * n] = border_diag                          # left column
        A[n * (n - 1) + i, n * (n - 1) + i] = border_diag      # bottom row
        A[(n - 1) + i * n, (n - 1) + i * n] = border_diag      # right column

    A = A.tocsr()

    # Only the top and bottom rows carry a nonzero right-hand side; the left
    # and right columns stay 0. For n == 1 the single entry ends up as
    # bottom_value, matching the original assignment order.
    b = np.zeros(A.shape[1])
    b[:n] = top_value
    b[n * (n - 1):] = bottom_value

    return A, b


In [None]:
# Residual-norm history of the most recent solve. It is rebound (not mutated)
# on every call, so a later cell can snapshot it, e.g. `CON = residuals`.
residuals = []


def _pcg_iterations(A, M, P, x, b, tol, maxiter):
    """Run preconditioned CG on the projected system P A u = P b.

    Starts from `x`, applies `M` as preconditioner and `P` as projection, and
    stops once ||P r|| < tol or after `maxiter` iterations. The history of
    ||P r|| (initial value first, then one entry per iteration) is stored in
    the module-level `residuals`. Returns the final iterate u.
    """
    global residuals

    uc = x
    rc = P @ (b - A @ x)
    y = M @ rc
    p = y

    residuals = [np.linalg.norm(rc)]
    if residuals[0] < tol:
        # Already converged: iterating would compute alpha = 0 / 0 (NaN).
        return uc

    for j in range(maxiter):
        rcprev, yprev = rc, y

        wc = P @ (A @ p)
        alpha = np.dot(rc, y) / np.dot(wc, p)
        uc = uc + alpha * p
        rc = rc - alpha * wc
        y = M @ rc
        beta = np.dot(rc, y) / np.dot(rcprev, yprev)
        p = y + beta * p

        nrc = np.linalg.norm(rc)
        residuals.append(nrc)
        if nrc < tol:
            print('Itr:', j)
            break

    return uc


def conjgrad(A, M, x, b, tol=1e-8, maxiter=1000):
    """Solve A x = b with plain preconditioned conjugate gradients.

    Equivalent to deflated CG with P = I and no coarse correction. Unlike the
    previous version, it does NOT add the notebook-level deflation term Q b.
    The residual history is left in the module-level `residuals`.
    """
    P = scipy.sparse.linalg.LinearOperator(A.shape, matvec=lambda v: v, rmatvec=lambda v: v)
    # With P = I, P^T u = u and the coarse correction is zero, so u is x.
    return _pcg_iterations(A, M, P, x, b, tol, maxiter)


def deflated(A, M, P, x, b, tol=1e-8, maxiter=1000, Q=None):
    """Solve A x = b with deflated preconditioned CG.

    Iterates on P A u = P b (see `_pcg_iterations`) and returns
    x = Q b + P^T u.

    Parameters
    ----------
    Q : operator, optional
        Coarse-grid correction paired with the projection P. If omitted, the
        notebook-level `Q` is used, as in earlier versions of this function.

    Raises
    ------
    NameError
        If Q is not passed and no module-level `Q` is defined.
    """
    if Q is None:
        Q = globals().get('Q')
        if Q is None:
            raise NameError("deflated() needs the coarse correction Q: pass Q=... or define a global Q")
    uc = _pcg_iterations(A, M, P, x, b, tol, maxiter)
    return Q @ b + P.T @ uc

In [None]:
# Problem setup: N x N grid whose plateau edges carry coefficient KBIG.
N = 51
KBIG = 100
A, b = poisson2d_plateaus(N, kbig=KBIG)
tol = 1e-8

In [None]:
# Mark grid nodes by their diagonal entry in A. Border rows and nodes with only
# unit-coefficient edges have diagonal 4. An interior node touching at least
# one kbig edge has diagonal >= kbig + 3. The threshold 6 separates the two
# (assumes kbig > 3, as in the setup cell).
B = A.diagonal().reshape(N, N)
G = np.zeros_like(B)
G[B > 6] = 1

# Connected components of the marked nodes (default 4-connectivity). Label 0
# is the unmarked remainder and also becomes a deflation vector.
G2, q = scipy.ndimage.label(G)
G3 = G2.flatten()

# Indicator matrix: column i is 1 exactly on the nodes carrying label i.
GG = (G3[:, None] == np.arange(q + 1)).astype(float)

for i in range(q + 1):
    plt.figure()
    plt.imshow(GG[:, i].reshape(N, N))
    plt.title(f"Deflation vector {i}")

GG.shape

In [None]:
# Two-level deflation from the indicator vectors in GG:
#   E   = U^T A U         (coarse matrix)
#   Q   = U E^{-1} U^T    (coarse-grid correction)
#   P   = I - A Q         (deflation projection)
#   P^T = I - Q^T A^T
U = GG
E = U.T @ A @ U
Einv = np.linalg.inv(E)

def projA(v):
    coarse = U.T @ v
    return U @ (Einv @ coarse)

def projTA(v):
    coarse = U.T @ v
    return U @ (Einv.T @ coarse)

Q = scipy.sparse.linalg.LinearOperator(A.shape, matvec=projA, rmatvec=projTA)

def projAt(v):
    correction = A @ (Q @ v)
    return v - correction

def projTAt(v):
    correction = Q.T @ (A.T @ v)
    return v - correction

P = scipy.sparse.linalg.LinearOperator(A.shape, matvec=projAt, rmatvec=projTAt)


In [None]:
# Preconditioner used by both solvers. Options:
#   "identity"                    - no preconditioning
#   "jacobi"                      - inverse of diag(A)
#   "ichol0" / "ichol3" / "ichol10" - ilupp.ICholTPreconditioner with that add_fill_in
PRECONDITIONER = "jacobi"

_preconditioner_factories = {
    "identity": lambda: scipy.sparse.identity(b.shape[0]),
    "jacobi": lambda: scipy.sparse.diags(1 / A.diagonal()),
    "ichol0": lambda: ilupp.ICholTPreconditioner(A, add_fill_in=0),
    "ichol3": lambda: ilupp.ICholTPreconditioner(A, add_fill_in=3),
    "ichol10": lambda: ilupp.ICholTPreconditioner(A, add_fill_in=10),
}
M = _preconditioner_factories[PRECONDITIONER]()

In [None]:
maxiter = 10000
tol = 1e-8

x = np.zeros_like(b)
%time y1 = conjgrad(A, M, x, b, tol=tol, maxiter=maxiter)
# r = [np.linalg.norm(r) for r in residuals]
CON = residuals

x = np.zeros_like(b)
%time y2 = deflated(A, M, P, x, b, tol=tol, maxiter=maxiter)

# r = [np.linalg.norm(r) for r in residuals]
DEF = residuals

In [None]:
# Residual history of both solvers. The DCG curve is the norm of the projected
# residual P r, not of b - A x, so the two curves measure different quantities.
fig, ax = plt.subplots(figsize=(8, 6))
ax.semilogy(CON, label="CG")
ax.semilogy(DEF, label="DCG")
ax.set(xlabel="Iteration", ylabel="Residual norm", title="CG vs. deflated CG")
ax.legend();

In [None]:
# True residual norms ||A y - b|| of the CG and DCG solutions.
np.linalg.norm(A @ y1 - b), np.linalg.norm(A @ y2 - b)

In [None]:
# (iterations performed, final recorded residual) per solver; each history
# holds the initial residual plus one entry per iteration.
{"CG": (len(CON) - 1, CON[-1]), "DCG": (len(DEF) - 1, DEF[-1])}

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(y1.reshape(N, N))
fig.colorbar(im, ax=ax)
ax.set_title("CG solution");

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(y2.reshape(N, N))
fig.colorbar(im, ax=ax)
ax.set_title("Deflated CG solution");