In [None]:
import torch
import torch.linalg as tla
import torch.autograd
import torch.nn as tNN
import numml.sparse as sp
import numml.iterative as it
import numml.nn as nNN
import matplotlib.pyplot as plt

In [None]:
class Network(tNN.Module):
    """Two-layer graph convolutional network producing a pair of scalars.

    Each node gets three input features (the current iterate, the previous
    iterate and a third per-node vector supplied by the caller). Two GCN
    layers map these to two outputs per node, which are averaged over all
    nodes. ``f_hb`` unpacks the result as the heavy-ball step sizes
    ``(alpha, beta)``.

    Parameters
    ----------
    H : int
        Width of the hidden layer.
    """

    def __init__(self, H):
        super().__init__()

        # 3 node features -> H hidden channels -> 2 outputs per node
        self.conv1 = nNN.GCNConv(3, H, normalize=False)
        self.conv2 = nNN.GCNConv(H, 2, normalize=False)

    def forward(self, A, xk, xkp, b):
        """Return the node-averaged (2,) output of the network on graph ``A``."""
        features = torch.column_stack((xk, xkp, b))
        hidden = torch.tanh(self.conv1(A, features))
        # sigmoid keeps every per-node prediction in (0, 1)
        per_node = torch.sigmoid(self.conv2(A, hidden))
        return per_node.mean(dim=0)

Consider the linear system of equations,
$$ Ax = b, $$
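for $A \in \mathbb{R}^{n \times n}$ and $x, b \in \mathbb{R}^n$, where we assume $A$ is symmetric positive definite.  Defining the quadratic loss function and its gradient (the gradient formula uses the symmetry of $A$),
$$ \begin{aligned} f\left(A, b, x\right) &= \frac{1}{2}x^T A x - x^T b, \\ \nabla f\left(A, b, x\right) &= Ax - b, \end{aligned} $$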
we can write Polyak's heavy-ball iteration as
$$ \begin{aligned} x^{(k+1)} &= x^{(k)} - \alpha \nabla f\left(A, b, x^{(k)}\right) + \beta\left(x^{(k)} - x^{(k-1)}\right) \\ &= x^{(k)} + \alpha \left(b-Ax^{(k)}\right) + \beta\left(x^{(k)} - x^{(k-1)}\right). \end{aligned} $$

Since we have an adjoint solver for fixed-point (FP) problems, we convert this two-step recurrence into a one-step FP problem by stacking the current and previous iterates,
$$ \bar{x}^{(k)} = \begin{bmatrix} x^{(k)} \\ x^{(k-1)} \end{bmatrix}, $$
and defining the restriction operators $R_1, R_2$ by
$$ \begin{aligned} R_1\bar{x}^{(k)} &= x^{(k)}, \\ R_2\bar{x}^{(k)} &= x^{(k-1)}. \end{aligned} $$
This gives the fixed-point map
$$ \begin{aligned}
g\left(A, b, \bar{x}\right) &= R_1^T\left(R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right)\right) + R_2^TR_1\bar{x} \\
&= \begin{bmatrix}
R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right) \\ R_1\bar{x}
\end{bmatrix},
\end{aligned}
$$
which at a fixed point returns $\begin{bmatrix}x^\star \\ x^\star\end{bmatrix}$, where $x^\star = A^{-1}b$.  Indeed, at a fixed point the second block forces $R_2\bar{x} = R_1\bar{x}$, so the momentum term vanishes and the first block reduces to $\alpha\left(b - AR_1\bar{x}\right) = 0$, i.e. $AR_1\bar{x} = b$ for $\alpha \neq 0$.  Showing that, for suitable $\alpha$ and $\beta$, this map is a contraction and $x^\star$ is an attracting fixed point is left as an exercise to the reader :-)

In [None]:
# Problem size (number of unknowns / graph nodes) and GNN hidden-layer width.
N = 12
network_H = 6

# Tridiagonal test matrix: 2 on the diagonal, -1 on the first sub- and
# super-diagonals (the standard 1D finite-difference Laplacian stencil).
A = sp.eye(N)*2 - sp.eye(N,k=-1) - sp.eye(N,k=1)

# Fixed right-hand side used for the test loss and the plots below.
b = torch.ones(N)

In [None]:
def f(A, b, x):
    """Quadratic objective ``1/2 x^T A x - x^T b``."""
    quadratic = x @ A @ x
    linear = x @ b
    return quadratic / 2 - linear

def gradf(A, b, x):
    """Gradient of ``f``: ``A x - b`` (exact when ``A`` is symmetric)."""
    Ax = A @ x
    return Ax - b

def Hf(A, b, x):
    """Hessian of ``f``: the constant matrix ``A`` (``b`` and ``x`` are unused)."""
    hessian = A
    return hessian

In [None]:
def f_hb(x, A, b, net_param):
    """One heavy-ball step written as a fixed-point map on the stacked state.

    Parameters
    ----------
    x : torch.Tensor
        Stacked state ``[x^{(k)}; x^{(k-1)}]`` of even length ``2n``.
    A : matrix-like supporting ``A @ v`` for length-``n`` vectors
    b : torch.Tensor
        Right-hand side of length ``n``.
    net_param : torch.Tensor or tuple(alpha, beta)
        Either the flattened ``Network`` parameters (then ``alpha`` and
        ``beta`` are predicted by the network at every step) or a fixed
        ``(alpha, beta)`` pair.

    Returns
    -------
    torch.Tensor
        New stacked state ``[x^{(k+1)}; x^{(k)}]``.

    Raises
    ------
    ValueError
        If ``x`` does not have even length.
    """
    # Derive the half-length from x itself rather than relying on a global N.
    if x.shape[0] % 2 != 0:
        raise ValueError(f'stacked state must have even length, got {x.shape[0]}')
    n = x.shape[0] // 2

    # Grab x^{(k)} and x^{(k-1)}
    x_k = x[:n]
    x_kp = x[n:]

    # Residual b - A x^{(k)} (= -grad f); used both as a network feature and in the update,
    # so the matvec is done only once.
    r_k = b - A @ x_k

    if isinstance(net_param, torch.Tensor):
        # Rebuild a Network from the flat parameter vector; presumably vector_to_model keeps
        # the link to net_param so gradients reach it (numml semantics not visible here).
        net = Network(network_H)
        nNN.vector_to_model(net, net_param)
        alpha, beta = net(A, x_k, x_kp, r_k)
    else:
        alpha, beta = net_param

    # x^{(k+1)} = x^{(k)} + alpha (b - A x^{(k)}) + beta (x^{(k)} - x^{(k-1)})
    x_kn = x_k + alpha * r_k + beta * (x_k - x_kp)

    # Re-pack into \bar{x}
    return torch.cat((x_kn, x_k))

In [None]:
# Fix the RNG so the network initialisation and the random training
# right-hand sides in B are reproducible under Restart & Run All.
torch.manual_seed(0)

# Initial stacked state [x^{(0)}; x^{(-1)}] = [1; 1].
x = torch.cat((torch.ones(N), torch.ones(N)))
net = Network(network_H)
# Optimise the flat parameter vector directly; f_hb rebuilds a Network from it.
net_param = nNN.model_to_vector(net).detach().clone()
net_param.requires_grad = True

opt = torch.optim.Adam([net_param], lr=0.025)
tr_lh = []
te_lh = []
N_b = 10
B = torch.randn((N_b, N))

TEST_TOL = 0.05         # stop once the test relative residual drops below this
MAX_TRAIN_ITERS = 500   # safety cap so the loop cannot run forever if training stalls

# A @ x0 for the initial guess x0 = ones (loop-invariant); losses are
# normalised by the initial residual norm.
A_x0 = A @ x[:N]

def test_loss(net_param):
    """Relative residual after N heavy-ball steps on the fixed right-hand side ``b``."""
    with torch.no_grad():
        xk, xkp = it.fp_wrapper(f_hb, x, A, b, net_param, max_iter=N).reshape((2, -1))
    return (tla.norm(b - A @ xk) / tla.norm(b - A_x0)).item()

print('| It | Train Loss | Test Loss |')
for i in range(MAX_TRAIN_ITERS):
    opt.zero_grad()

    # Mean relative residual over the batch of random right-hand sides.
    loss = 0.
    for j in range(N_b):
        b_rand = B[j]
        xk, xkp = it.fp_wrapper(f_hb, x, A, b_rand, net_param, max_iter=N).reshape((2, -1))
        loss += (tla.norm(b_rand - A @ xk) / tla.norm(b_rand - A_x0)) / N_b
    loss.backward()

    opt.step()
    tl = test_loss(net_param)
    print(f'| {i:2} | {loss.item():10.3f} | {tl:9.3f} |')

    # Record history before the stopping check so the final iteration is kept too.
    tr_lh.append(loss.item())
    te_lh.append(tl)

    if tl < TEST_TOL:
        break
else:
    print(f'Test loss did not reach {TEST_TOL} within {MAX_TRAIN_ITERS} iterations.')

In [None]:
# Training and test loss on twin y-axes so each curve keeps its own scale.
fig, ax1 = plt.subplots()
train_line, = ax1.plot(tr_lh, 'k', label='Train Loss')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Training Loss')
# Grid on the primary axis (plt.grid() after twinx() would target the twin axis).
ax1.grid()

ax2 = ax1.twinx()
test_line, = ax2.plot(te_lh, 'r', label='Test loss')
ax2.set_ylabel('Testing loss')

# One combined legend: two independent legend() calls on twin axes can overlap.
ax1.legend(handles=[train_line, test_line])
ax1.set_title('Training history');

In [None]:
# Solution plots

# Fixed (alpha, beta) baseline to compare the GNN-predicted step sizes against.
static_ab = (0.7, 0.7)

# Inference only: no need to build an autograd graph through the fixed-point solves.
with torch.no_grad():
    xg = it.fp_wrapper(f_hb, x, A, b, net_param, max_iter=N)[:N]
    xg2 = it.fp_wrapper(f_hb, x, A, b, static_ab, max_iter=N)[:N]
    x_true = sp.spsolve(A, b)

plt.figure()
plt.plot(x_true.detach(), 'k', label='True Soln')
plt.plot(xg.detach(), 'r--', label='GNN Heavy-ball Soln.')
plt.plot(xg2.detach(), 'b--', label=f'Heavy-ball Soln. (a={static_ab[0]:0.2f}, b={static_ab[1]:0.2f})')
plt.xlabel('Node index')
plt.ylabel('Solution value')
plt.legend();

In [None]:
# Residual plots

def heavyball_res(x, A, b, net_param):
    """Run ``n = A.shape[0]`` heavy-ball steps and record residuals and step sizes.

    The iteration matches ``f_hb``: the network (if used) receives the
    current iterate, the previous iterate and the residual ``b - A x``,
    the same features it was trained on.

    Parameters
    ----------
    x : torch.Tensor
        Initial guess; the previous iterate is initialised to the same value.
    A : matrix-like supporting ``A @ v``
    b : torch.Tensor
        Right-hand side (assumed nonzero, since residuals are divided by ``||b||``).
    net_param : torch.Tensor or tuple(alpha, beta)
        Flat ``Network`` parameters, or a fixed ``(alpha, beta)`` pair.

    Returns
    -------
    res : torch.Tensor, shape (n + 1,)
        Relative residuals ``||b - A x^{(k)}|| / ||b||`` for k = 0..n.
    alphas, betas : torch.Tensor, shape (n,)
        Step sizes used at each iteration.
    """
    n = A.shape[0]
    b_norm = tla.norm(b)

    res = torch.empty(n + 1)
    alphas = torch.empty(n)
    betas = torch.empty(n)

    use_net = isinstance(net_param, torch.Tensor)
    if use_net:
        net = Network(network_H)
        nNN.vector_to_model(net, net_param)

    x_p = x
    r = b - A @ x
    res[0] = tla.norm(r) / b_norm
    for i in range(1, n + 1):
        if use_net:
            # Feed the residual (not b), exactly as f_hb does during training.
            alpha, beta = net(A, x, x_p, r)
        else:
            alpha, beta = net_param

        # x^{(k+1)} = x^{(k)} + alpha (b - A x^{(k)}) + beta (x^{(k)} - x^{(k-1)})
        x_n = x + alpha * r + beta * (x - x_p)
        x_p, x = x, x_n

        r = b - A @ x
        res[i] = tla.norm(r) / b_norm
        alphas[i - 1] = alpha
        betas[i - 1] = beta

    return res, alphas, betas

# Residual histories and per-iteration step sizes: GNN-predicted vs. fixed (alpha, beta).
with torch.no_grad():
    res_opt, a_opt, b_opt = heavyball_res(torch.ones(N), A, b, net_param)
    res_naive, a_naive, b_naive = heavyball_res(torch.ones(N), A, b, static_ab)

iters = torch.arange(1, N+1)

plt.figure()
plt.semilogy(res_opt, 'r', label='GNN')
plt.semilogy(res_naive, 'b', label=f'a={static_ab[0]:.2f}, b={static_ab[1]:.2f}')
plt.xlabel('Iteration')
plt.ylabel('Relative Residual')
plt.legend()
plt.grid()

plt.figure()
plt.plot(iters, a_opt, '.-', label='Opt A')
plt.plot(iters, b_opt, '.-', label='Opt B')
plt.plot(iters, a_naive, '.-', label='Static A')
# Dashed so it stays visible when it coincides with Static A (e.g. static_ab = (0.7, 0.7)).
plt.plot(iters, b_naive, '--', label='Static B')
plt.xlabel('Iteration')
plt.ylabel('Step-size parameter value')
plt.legend()
plt.grid()