In [None]:
import torch
import torch.linalg as tla
import torch.autograd
import torch.nn as tNN
import numml.sparse as sp
import numml.nn as nNN
import numml.utils as utils
import numml.krylov as kry
import matplotlib.pyplot as plt
import numpy as np

Consider the linear system of equations,
$$ Ax = b, $$
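for a symmetric positive definite $A \in \mathbb{R}^{n \times n}$ and $x, b \in \mathbb{R}^n$.  Defining the quadratic loss function (whose unique minimizer is the solution of the linear system) together with its gradient with respect to $x$,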
$$ \begin{align} f\left(A, b, x\right) &= \frac{1}{2}x^T A x - x^T b \\ \nabla_f\left(A, b, x\right) &= Ax - b \end{align}, $$
we can write Polyak's heavy-ball iteration as
$$ \begin{align} x^{(k+1)} &= x^{(k)} - \alpha \nabla_f + \beta\left(x^{(k)} - x^{(k-1)}\right) \\ &= x^{(k)} + \alpha \left(b-Ax^{(k)}\right) + \beta\left(x^{(k)} - x^{(k-1)}\right) \end{align}. $$

Since we have a nice adjoint solver for fixed-point problems, we will convert this to a fixed-point (FP) problem by defining
$$ \bar{x}^{(k)} = \begin{bmatrix} x^{(k)} \\ x^{(k-1)} \end{bmatrix}, $$
as well as the restriction operators $R_1, R_2$ as
$$ \begin{align} R_1\bar{x}^{(k)} &= x^{(k)}, \\ R_2\bar{x}^{(k)} &= x^{(k-1)}. \end{align} $$
This gives the fixed-point map
$$ \begin{align}
g\left(A, b, \bar{x}\right) &= R_1^T\left(R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right)\right) + R_2^TR_1\bar{x} \\
&= \begin{bmatrix}
R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right) \\ R_1\bar{x}
\end{bmatrix}
\end{align},
$$
which at a fixed point will return $\begin{bmatrix}x^\star \\ x^\star\end{bmatrix}$, where $x^\star = A^{-1}b$.  Showing that, for suitable choices of $\alpha$ and $\beta$, this is a contraction and that $x^\star$ is an attracting fixed point is left as an exercise to the reader :-)

In [None]:
# Problem setup

# Seed PyTorch's RNG up front so that the random initial guesses and training
# data drawn in later cells are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = (torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu'))
# Uncomment to force CPU execution:
#device = torch.device('cpu')

# System matrix: 2 on the main diagonal and -1 on the first sub-/super-diagonal
# (assuming sp.eye follows the NumPy `k` offset convention).
N = 24
A = sp.eye(N)*2 - sp.eye(N,k=-1) - sp.eye(N,k=1)
A = A.to(device)

# Right-hand sides: an all-zero vector and an all-ones vector.
b = torch.zeros(N).to(device)
bt = torch.ones(N).to(device)

# Reference solution of A x = bt from conjugate gradient (first returned value).
xt = kry.conjugate_gradient(A, bt)[0]

# Number of heavy-ball iterations unrolled when training / testing.
max_it = 16

# Hidden feature width passed to the network constructor.
network_H = 8

In [None]:
# Network definition.  We'll use a shallow network composed of three GCN layers.

class Network(tNN.Module):
    """Small GCN that predicts the heavy-ball parameters (alpha, beta).

    Three `GCNConv` layers map 3 per-node input features to `H` hidden
    features, then to `H` again, and finally to 2 output features.
    """

    def __init__(self, H):
        """Build the three layers; `H` is the hidden feature width."""
        super().__init__()

        # Layers are created in input -> output order.
        self.conv1 = nNN.GCNConv(3, H, normalize=True)
        self.conv2 = nNN.GCNConv(H, H, normalize=True)
        self.conv3 = nNN.GCNConv(H, 2, normalize=True)

    def forward(self, A, xk, xkp, r):
        """Return a length-2 tensor that callers unpack as (alpha, beta).

        `xk`, `xkp` and `r` are stacked column-wise into the node feature
        matrix; `A` is handed to every convolution layer.  The final layer is
        squashed with a sigmoid and averaged over the node dimension, so both
        outputs lie in (0, 1).
        """
        feats = torch.column_stack((xk, xkp, r))

        # Two hidden layers with ReLU activations.
        for layer in (self.conv1, self.conv2):
            feats = torch.relu(layer(A, feats))

        # Output layer: sigmoid per node, then mean over nodes (dim 0).
        gates = torch.sigmoid(self.conv3(A, feats))
        return gates.mean(dim=0)

In [None]:
# Objective functions

def f(A, b, x):
    """Quadratic objective 0.5 * x^T A x - x^T b."""
    quadratic = x @ A @ x
    linear = x @ b
    return quadratic / 2 - linear

def gradf(A, b, x):
    """Gradient of the quadratic objective with respect to x: A x - b."""
    Ax = A @ x
    return Ax - b

def Hf(A, b, x):
    """Hessian of the quadratic objective: the matrix A itself.

    `b` and `x` are accepted only to mirror the signatures of `f` and `gradf`;
    they are not used.
    """
    hessian = A
    return hessian

In [None]:
# Heavyball iteration

def _hb_split(x):
    """Split the stacked state [x^{(k)}; x^{(k-1)}] into its two halves."""
    # Use the state's own length instead of the notebook-level N so the
    # iteration works for any problem size.
    n = x.shape[0] // 2
    return x[:n], x[n:]


def _hb_update(x_k, x_kp, A, b, alpha, beta):
    """One heavy-ball update, re-packed into the stacked state [x^{(k+1)}; x^{(k)}]."""
    x_kn = x_k - alpha * gradf(A, b, x_k) + beta * (x_k - x_kp)
    return torch.cat((x_kn, x_k))


def f_hb(x, A, b):
    """Heavy-ball fixed-point map with (alpha, beta) predicted by the global `net`.

    Parameters
    ----------
    x : stacked state [x^{(k)}; x^{(k-1)}] (1-D, even length).
    A : system matrix.
    b : right-hand side; it is also passed to the network as its third
        feature column.

    Returns
    -------
    The next stacked state [x^{(k+1)}; x^{(k)}].
    """
    x_k, x_kp = _hb_split(x)
    alpha, beta = net(A, x_k, x_kp, b)
    return _hb_update(x_k, x_kp, A, b, alpha, beta)


def f_hb_fixed(x, A, b, alphabeta):
    """Heavy-ball fixed-point map with a fixed `(alpha, beta)` pair.

    Same contract as `f_hb`, except that the step parameters come from the
    two-element `alphabeta` instead of the network.
    """
    x_k, x_kp = _hb_split(x)
    alpha, beta = alphabeta
    return _hb_update(x_k, x_kp, A, b, alpha, beta)

In [None]:
def fp_wrapper(f, x, *args, max_iter=1):
    """Apply the map `f` to `x` repeatedly and return the final iterate.

    Each application is `x = f(x, *args)`.  `max_iter` is truncated with
    `int()`; a value of zero or less returns `x` unchanged.
    """
    remaining = int(max_iter)
    while remaining > 0:
        x = f(x, *args)
        remaining -= 1
    return x

In [None]:
# Optimize the network to output alpha and beta values at each iteration of heavyball

# Random initial guess, scaled by its own norm.
xg = torch.randn(N, device=device)
xg /= tla.norm(xg)
# Stacked state [x^{(k)}; x^{(k-1)}]: both halves start from the same guess.
x = torch.cat((xg, xg)).to(device)
net = Network(network_H).to(device)

opt = torch.optim.Adam(net.parameters(), lr=0.01)
# Loss histories, appended to once per pass of the training loop.
tr_lh = []
te_lh = []
# Number of random training problems evaluated per optimizer step.
N_b = 100

# Training iteration counter; defined here so that re-running the training
# cell continues the count.
i = 0

In [None]:
# Header row for the progress table that the training loop prints line by line.
print('| It  | Train Loss | Test Loss |')

def test_loss():
    """Squared A-norm error e^T A e of the learned iteration on the test problem.

    Runs `max_it` steps of `f_hb` from the global start state `x` with
    right-hand side `bt` (no gradients are tracked) and compares the newest
    iterate against the reference solution `xt`.
    """
    with torch.no_grad():
        stacked = fp_wrapper(f_hb, x, A, bt, max_iter=max_it)
        # First row is x^{(k)}, second row is x^{(k-1)}.
        x_final, _x_prev = stacked.reshape((2, -1))
    err = xt - x_final
    return err @ (A @ err)

# Synthetic training set: random true solutions Xt, random initial guesses Xg,
# and the matching right-hand sides B[j] = A @ Xt[j].
Xt = torch.randn((N_b, N)).to(device)
Xg = torch.randn((N_b, N)).to(device)
B = (A@Xt.T).T

# Number of optimizer steps taken per execution of this cell.  A bounded loop
# (instead of `while True`) lets Restart & Run All reach the plotting cells;
# re-run the cell to keep training, since `i` and the histories persist.
n_epochs = 200

for _ in range(n_epochs):
    opt.zero_grad()

    loss = 0.
    tr_loss = torch.zeros(N_b)

    for j in range(N_b):
        # Unroll max_it heavy-ball steps from the j-th initial guess.
        x0 = torch.cat((Xg[j], Xg[j]))
        xk, xkp = fp_wrapper(f_hb, x0, A, B[j], max_iter=max_it).reshape((2, -1))

        # Squared A-norm of the error against the known solution.
        e = Xt[j] - xk
        loss_it = (e@(A@e))
        # Store a detached copy: the per-sample history is only used for plotting.
        tr_loss[j] = loss_it.detach()

        loss += loss_it / N_b
    loss.backward()

    opt.step()
    tl = test_loss()
    print(f'| {i:3} | {loss.item():10.5f} | {tl:9.3f} |')

    tr_lh.append(tr_loss.detach().cpu().numpy())
    te_lh.append(tl.detach().cpu().item())
    i += 1

In [None]:
# Individual training losses: one line per training problem, one point per iteration.
fig_train, ax_train = plt.subplots()
ax_train.plot(np.array(tr_lh))
ax_train.set_ylabel('Training Loss')
ax_train.set_xlabel('Iteration')
ax_train.set_title('Ind. Training Loss')
ax_train.grid()

# Test loss history.
fig_test, ax_test = plt.subplots()
ax_test.plot(te_lh, 'r', label='Test loss')
ax_test.set_ylabel('Testing loss')
ax_test.set_xlabel('Iteration')
ax_test.set_title('Test Loss')
ax_test.grid()

In [None]:
# Solution plots

# Hand-picked (alpha, beta) used as the non-learned baseline.
static_ab = (0.7, 0.7)

# Evaluation only: no autograd graph is needed for the N-step rollouts.
with torch.no_grad():
    xg = fp_wrapper(f_hb, x, A, bt, max_iter=N)[:N]
    xg2 = fp_wrapper(f_hb_fixed, x, A, bt, static_ab, max_iter=N)[:N]

# Start a fresh figure so the curves never land on a previously active one.
plt.figure()
plt.plot(kry.conjugate_gradient(A, bt, x[:N])[0].detach().cpu().numpy(), 'k', label='True Soln')
plt.plot(xg.detach().cpu().numpy(), 'r--', label='GNN Heavy-ball Soln.')
plt.plot(xg2.detach().cpu().numpy(), 'b--', label=f'Heavy-ball Soln. (a={static_ab[0]:0.2f}, b={static_ab[1]:0.2f})')
plt.legend()

In [None]:
# Residual plots

def heavyball_res(x, A, b, max_it=None, alphabeta=None):
    """Run heavy-ball on A x = b and record the relative residual history.

    Parameters
    ----------
    x : initial guess (also used as the "previous" iterate for the first step).
    A : system matrix.
    b : right-hand side.
    max_it : number of iterations; defaults to ``A.shape[0]`` when omitted.
    alphabeta : fixed ``(alpha, beta)`` pair.  When ``None`` the pair is
        queried from the global ``net`` at every iteration.

    Returns
    -------
    (res, alphas, betas) : ``res`` has ``max_it + 1`` entries
        ``||b - A x_k|| / ||b||`` (entry 0 is the initial guess); ``alphas``
        and ``betas`` hold the ``max_it`` parameter values that were used.
    """
    if max_it is None:
        max_it = A.shape[0]

    # ||b|| does not change between iterations.
    b_norm = tla.norm(b)

    # r = b - A x is both the residual and the negative gradient of the
    # quadratic objective, so it is computed once per iterate and reused.
    r = b - A@x

    res = torch.empty(max_it+1)
    res[0] = tla.norm(r) / b_norm

    alphas = torch.empty(max_it)
    betas = torch.empty(max_it)

    x_p = x
    for i in range(1, max_it+1):
        if alphabeta is None:
            # Use the caller's right-hand side, not a notebook-level global.
            alpha, beta = net(A, x, x_p, b)
        else:
            alpha, beta = alphabeta

        # x - alpha * (A x - b) + beta * (x - x_p), written with r = b - A x.
        x_n = x + alpha * r + beta * (x - x_p)
        x_p = x
        x = x_n

        r = b - A@x
        res[i] = tla.norm(r) / b_norm
        alphas[i-1] = alpha
        betas[i-1] = beta

    return res, alphas, betas

res_max_it = N*3

# Residual histories of the learned and the fixed-parameter heavy-ball
# iterations, both started from the all-ones vector.
with torch.no_grad():
    res_opt, a_opt, b_opt = heavyball_res(torch.ones(N).to(device), A, bt, max_it=res_max_it)
    res_naive, a_naive, b_naive = heavyball_res(torch.ones(N).to(device), A, bt, alphabeta=static_ab, max_it=res_max_it)

# CG reference on the same right-hand side `bt` as the heavy-ball curves.
# (Using the all-zero `b` here would divide the residuals by ||b|| = 0.)
cg_res = kry.conjugate_gradient(A, bt, x[:N], iterations=N, rtol=1e-10)[1]
cg_rel_res = torch.tensor(cg_res).detach().cpu().numpy() / np.linalg.norm(bt.cpu().numpy())

plt.figure()
plt.semilogy(res_opt, 'r', label='GNN')
plt.semilogy(res_naive, 'b', label=f'a={static_ab[0]:.2f}, b={static_ab[1]:.2f}')
plt.semilogy(cg_rel_res, 'k', label='CG')

plt.xlabel('Iteration')
plt.ylabel('Relative Residual')
plt.legend()
plt.grid()

# Heavy-ball parameters used at each iteration.
plt.figure()
plt.plot(torch.arange(1, res_max_it+1), a_opt, '.-', label='Opt A')
plt.plot(torch.arange(1, res_max_it+1), b_opt, '.-', label='Opt B')
plt.plot(torch.arange(1, res_max_it+1), a_naive, '.-', label='Static A')
plt.plot(torch.arange(1, res_max_it+1), b_naive, '--', label='Static B')
plt.xlabel('Iteration')
plt.ylabel('Heavyball Parameters')

plt.legend()
plt.grid()