In [None]:
import torch
import torch.linalg as tla
import numpy as np
import matplotlib.pyplot as plt
import numml.sparse as sp

In [None]:
# Discrete 1D Poisson operator on N interior grid points: the tridiagonal
# second-difference matrix with 2 on the main diagonal and -1 on the first
# super- and sub-diagonals.

N = 32
b = torch.zeros(N)  # placeholder right-hand side; a forcing term is assigned in a later cell
A = (
    sp.eye(N) * 2.
    - sp.eye(N, k=1)
    - sp.eye(N, k=-1)
)

In [None]:
# Conjugate Gradient (CG) and Preconditioned CG (PCG), adapted from Y. Saad, "Iterative Methods for Sparse Linear Systems".

def cg(A, b, iterations=15):
    """Run a fixed number of (unpreconditioned) Conjugate Gradient iterations.

    Solves ``A x = b`` starting from ``x0 = 0``, so the initial residual is ``b``.
    CG assumes ``A`` is symmetric positive definite; ``A`` only needs to support
    ``A @ v`` and ``A.shape``.

    Args:
        A: system operator (dense tensor or sparse matrix supporting ``@``).
        b: right-hand side vector; its dtype/device are used for ``x`` and
            the residual history.
        iterations: maximum number of CG steps to take.

    Returns:
        ``(x, res_hist)``: the final iterate and a tensor of length
        ``iterations + 1`` with ``res_hist[k] = ||r_k||_2``. If the residual
        becomes exactly zero (including ``b == 0``), the loop stops early and
        the remaining history entries stay 0 instead of turning into NaN.
    """
    x = torch.zeros(A.shape[1], dtype=b.dtype, device=b.device)
    r = b  # r0 = b - A @ x0 with x0 = 0
    p = r

    res_hist = torch.zeros(iterations + 1, dtype=b.dtype, device=b.device)
    res_hist[0] = tla.norm(r)

    for i in range(iterations):
        rr = r @ r
        if rr == 0:
            # Exact solution reached: another step would compute alpha = 0/0.
            break
        Ap = A @ p
        alpha = rr / (Ap @ p)
        x = x + alpha * p
        r = r - alpha * Ap
        beta = (r @ r) / rr
        p = r + beta * p

        res_hist[i + 1] = tla.norm(r)

    return x, res_hist

def pcg(M, A, b, iterations=15):
    """Run a fixed number of Preconditioned Conjugate Gradient iterations.

    Solves ``A x = b`` starting from ``x0 = 0``. The preconditioner is applied
    directly as ``z = M @ r``, so ``M`` plays the role of an approximate inverse
    of ``A``. Standard PCG theory assumes both ``A`` and ``M`` are symmetric
    positive definite.

    Args:
        M: preconditioner operator supporting ``M @ v``.
        A: system operator supporting ``A @ v`` and ``A.shape``.
        b: right-hand side vector; its dtype/device are used for ``x`` and
            the residual history.
        iterations: maximum number of PCG steps to take.

    Returns:
        ``(x, res_hist)``: the final iterate and a tensor of length
        ``iterations + 1`` with ``res_hist[k] = ||r_k||_2`` (the
        unpreconditioned residual norm). If ``r @ z`` becomes exactly zero
        (e.g. exact convergence or ``b == 0``), the loop stops early and the
        remaining history entries stay 0 instead of turning into NaN.
    """
    x = torch.zeros(A.shape[1], dtype=b.dtype, device=b.device)
    r = b  # r0 = b - A @ x0 with x0 = 0
    z = M @ r
    # The first search direction must be the *preconditioned* residual z0;
    # starting from r0 would make the first step a plain CG step.
    p = z

    res_hist = torch.zeros(iterations + 1, dtype=b.dtype, device=b.device)
    res_hist[0] = tla.norm(r)

    for i in range(iterations):
        rz = r @ z
        if rz == 0:
            # Converged exactly: the next alpha would be 0/0.
            break
        Ap = A @ p
        alpha = rz / (Ap @ p)
        x = x + alpha * p
        r = r - alpha * Ap
        z = M @ r
        beta = (r @ z) / rz
        p = z + beta * p

        res_hist[i + 1] = tla.norm(r)

    return x, res_hist

In [None]:
# Linear forcing term b(x) = x sampled at the N interior points of [-1, 1]; the resulting solution u looks roughly sinusoidal.

b = torch.linspace(-1, 1, N+2)[1:-1]

# Visual sanity check of the forcing term and the direct-solve reference solution.
fig, ax = plt.subplots()
ax.plot(b, label='rhs')
ax.plot(sp.spsolve(A, b), label='true solution')
ax.set_xlabel('Grid index')
ax.set_ylabel('Value')
ax.set_title('Forcing term and reference solution of A u = b')
ax.legend();

In [None]:
# Learn a preconditioner M with the same sparsity pattern as A by minimizing
# the residual norm of the final PCG iterate with respect to M's stored entries.

# Starting from a copy of A fixes M's sparsity pattern to A's.
# NOTE(review): nothing below keeps M symmetric, while PCG theory assumes an
# SPD preconditioner — confirm this is intended.
M = A.copy()
M.requires_grad = True  # presumably marks M's stored values (M.data) as trainable — confirm with numml

optimizer = torch.optim.Adam([M.data], lr=0.01)
epochs = 750
lh = torch.zeros(epochs)  # loss value recorded at every epoch

for i in range(epochs):
    optimizer.zero_grad()

    # Backpropagate through an entire fixed-length PCG run; the objective is
    # the residual norm after the last iteration.
    x_, res_ = pcg(M, A, b)
    loss = res_[-1]
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    lh[i] = loss_value

    if i == epochs - 1 or i % 10 == 0:
        print(i, loss_value)

In [None]:
# Training curve of the learned preconditioner (log scale).
fig, ax = plt.subplots()
ax.semilogy(lh)
ax.set_xlabel('Epoch')
ax.set_ylabel('Final PCG residual norm (loss)')
ax.set_title('Preconditioner training loss');

In [None]:
# Evaluate plain CG and PCG with the learned M for comparison. No gradients
# are needed here, so avoid building an autograd graph through M.
with torch.no_grad():
    x, res = cg(A, b)
    x_m, res_m = pcg(M, A, b)

In [None]:
# Residual norm per iteration for the baseline and the learned preconditioner.
fig, ax = plt.subplots()
ax.semilogy(res, label='Conjugate Gradient')
ax.semilogy(res_m.detach(), label='Optimized PCG')
ax.legend()
ax.grid()
ax.set_xlabel('Iteration')
ax.set_ylabel('Residual')
ax.set_title('CG vs. PCG with learned preconditioner');