In [None]:
import torch
import torch.linalg as tla
import torch.autograd
import numml.sparse as sp
import numml.iterative as it
import matplotlib.pyplot as plt

In [None]:
# Model problem.  Assuming sp.eye(N, k) puts ones on the k-th diagonal (as numpy.eye does),
# A is tridiagonal with 2 on the main diagonal and -1 on the first sub-/super-diagonals.
N = 12                # problem size
b = torch.ones(N)     # right-hand side: all ones
A = (sp.eye(N) * 2
     - sp.eye(N, k=-1)
     - sp.eye(N, k=1))

In [None]:
def f(A, b, x):
    """Quadratic loss 1/2 x^T A x - x^T b, evaluated at x."""
    quadratic_term = x @ A @ x
    linear_term = x @ b
    return quadratic_term / 2 - linear_term

def gradf(A, b, x):
    """Gradient of the quadratic loss f with respect to x: A x - b.

    This is the exact gradient only for symmetric A; in general the gradient of
    1/2 x^T A x - x^T b is (A + A^T)/2 x - b.
    """
    Ax = A @ x
    return Ax - b

def Hf(A, b, x):
    """Hessian of the quadratic loss f: the (constant) matrix A itself."""
    del b, x  # unused: the Hessian of a quadratic does not depend on b or x
    return A

Consider the linear system of equations,
$$ Ax = b, $$
for $A \in \mathbb{R}^{n \times n}$ and $x, b \in \mathbb{R}^n$.  Defining the quadratic loss function
$$ \begin{aligned} f\left(A, b, x\right) &= \frac{1}{2}x^T A x - x^T b, \\ \nabla_x f\left(A, b, x\right) &= Ax - b \quad \text{(for symmetric } A\text{)}, \end{aligned} $$
we can write Polyak's heavy-ball iteration, with step size $\alpha$ and momentum weight $\beta$, as
$$ \begin{aligned} x^{(k+1)} &= x^{(k)} - \alpha \nabla_x f\left(A, b, x^{(k)}\right) + \beta\left(x^{(k)} - x^{(k-1)}\right) \\ &= x^{(k)} + \alpha \left(b-Ax^{(k)}\right) + \beta\left(x^{(k)} - x^{(k-1)}\right). \end{aligned} $$

Since we have a convenient adjoint solver for fixed-point problems, we convert this into a fixed-point problem by stacking two consecutive iterates,
$$ \bar{x}^{(k)} = \begin{bmatrix} x^{(k)} \\ x^{(k-1)} \end{bmatrix}, $$
together with the restriction operators $R_1, R_2$ that extract each half:
$$ \begin{aligned} R_1\bar{x}^{(k)} &= x^{(k)}, \\ R_2\bar{x}^{(k)} &= x^{(k-1)}. \end{aligned} $$
This gives the fixed-point map
$$ \begin{aligned}
g\left(A, b, \bar{x}\right) &= R_1^T\left(R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right)\right) + R_2^TR_1\bar{x} \\
&= \begin{bmatrix}
R_1\bar{x} + \alpha \left(b - AR_1\bar{x}\right) + \beta\left(R_1\bar{x} - R_2\bar{x}\right) \\ R_1\bar{x}
\end{bmatrix},
\end{aligned}
$$
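whose fixed point is $\begin{bmatrix}x^\star \\ x^\star\end{bmatrix}$ with $x^\star = A^{-1}b$. At a fixed point both halves agree, so the momentum term vanishes, and for $\alpha \neq 0$ we get $Ax^\star = b$.  Showing that, for suitable $\alpha$ and $\beta$, $g$ is a contraction and $x^\star$ is an attracting fixed point is left as an exercise to the reader :-)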

In [None]:
def f_hb(x, A, b, alpha, beta):
    """One heavy-ball step for Ax = b, written as a fixed-point map on the stacked state.

    Args:
        x: stacked state [x^{(k)}; x^{(k-1)}] (two iterates concatenated along dim 0).
        A, b: system matrix and right-hand side.
        alpha: step size applied to the gradient.
        beta: momentum weight applied to x^{(k)} - x^{(k-1)}.

    Returns:
        The stacked state [x^{(k+1)}; x^{(k)}], where
        x^{(k+1)} = x^{(k)} - alpha * gradf(A, b, x^{(k)}) + beta * (x^{(k)} - x^{(k-1)}).

    Raises:
        ValueError: if x has odd length and cannot be split into two iterates.
    """
    # Split point comes from x itself rather than the global N, so the map works for any size.
    n = x.shape[0] // 2
    if 2 * n != x.shape[0]:
        raise ValueError(f'expected an even-length stacked state, got length {x.shape[0]}')

    # Unpack the current and previous iterates x^{(k)}, x^{(k-1)}
    x_k, x_km1 = x[:n], x[n:]

    # Heavy-ball update: gradient step plus momentum
    x_kp1 = x_k - alpha * gradf(A, b, x_k) + beta * (x_k - x_km1)

    # Shift the window: the new iterate goes first, the current one becomes the "previous"
    return torch.cat((x_kp1, x_k))

In [None]:
# Optimize the heavy-ball parameters (alpha, beta) by minimizing the relative residual
#     l = ||b - A @ x^{(k)}|| / ||b - A @ x^{(0)}||.
# At each step we draw N_b random right-hand sides and average the loss over them.
# The test loss uses the fixed right-hand side b = 1 (all ones).

torch.manual_seed(0)  # reproducible random right-hand sides -> reproducible training run

N_EPOCHS = 40  # number of Adam steps on (alpha, beta)
N_b = 10       # random right-hand sides averaged per step

x0 = torch.ones(N)       # initial guess x^{(0)}
x = torch.cat((x0, x0))  # stacked state [x^{(0)}; x^{(-1)}] with x^{(-1)} = x^{(0)} (no initial momentum)
Ax0 = A @ x0             # loop-invariant; shared by every initial-residual normalization

alpha = torch.tensor(0.5, requires_grad=True)
beta = torch.tensor(0.5, requires_grad=True)

opt = torch.optim.Adam([alpha, beta], lr=0.01)
lh = []  # test loss after each step
ah = []  # alpha after each step
bh = []  # beta after each step


def _heavyball_solve(rhs, alpha, beta):
    """Run the fixed-point solver on f_hb from x (max_iter=N) and return the x^{(k)} half."""
    x_k, _ = it.fp_wrapper(f_hb, x, A, rhs, alpha, beta, max_iter=N).reshape((2, -1))
    return x_k


def _relative_residual(rhs, x_k):
    """||rhs - A x_k|| / ||rhs - A x^{(0)}||: residual reduction relative to the initial guess."""
    return tla.norm(rhs - A @ x_k) / tla.norm(rhs - Ax0)


def test_loss(alpha, beta):
    """Relative residual for the fixed right-hand side b, evaluated without building a graph."""
    with torch.no_grad():
        return _relative_residual(b, _heavyball_solve(b, alpha, beta)).item()


print('| It | Train Loss | Test Loss |')
for i in range(N_EPOCHS):
    opt.zero_grad()

    # Average the relative residual over N_b fresh random right-hand sides
    loss = 0.
    for _ in range(N_b):
        b_rand = torch.randn(N)
        loss += _relative_residual(b_rand, _heavyball_solve(b_rand, alpha, beta)) / N_b
    loss.backward()

    opt.step()
    # The train loss was computed with the pre-step parameters; the test loss uses the updated ones.
    tl = test_loss(alpha, beta)
    print(f'| {i:2} | {loss.item():10.3f} | {tl:9.3f} |')

    lh.append(tl)
    ah.append(alpha.item())
    bh.append(beta.item())

In [None]:
# Test loss (left axis) and heavy-ball weights (right axis) over the optimization.
# Explicit axes are used because twinx() makes the twin axes "current", so a bare
# plt.legend()/plt.grid() would only see the Alpha/Beta curves.
fig, ax_loss = plt.subplots()
ax_loss.plot(lh, 'k--', label='Test Loss')
ax_loss.set_xlabel('Iteration')
ax_loss.set_ylabel('Test Loss')
ax_loss.grid()

ax_weights = ax_loss.twinx()
ax_weights.plot(ah, 'r', label='Alpha')
ax_weights.plot(bh, 'b', label='Beta')
ax_weights.set_ylabel('Heavy-ball Weight')

ax_loss.set_title('Heavy-ball parameter optimization')
# Combine handles from both y-axes so the legend also identifies the dashed test-loss curve
ax_weights.legend(handles=ax_loss.get_lines() + ax_weights.get_lines());

In [None]:
# Solution plots: optimized vs. naive heavy-ball iterates against the exact solution

ALPHA_NAIVE, BETA_NAIVE = 0.5, 0.5  # the untrained initial parameters

with torch.no_grad():  # plot-only solves; no autograd graph needed
    xg = it.fp_wrapper(f_hb, x, A, b, alpha, beta, max_iter=N)[:N]
    xg2 = it.fp_wrapper(f_hb, x, A, b, ALPHA_NAIVE, BETA_NAIVE, max_iter=N)[:N]

fig, ax = plt.subplots()
ax.plot(sp.spsolve(A, b).detach(), 'k', label='True Soln')
# Use alpha/beta symbols in the labels: "b" already names the right-hand side
ax.plot(xg, 'r--',
        label=rf'Opt. Heavy-ball Soln. ($\alpha$={alpha.item():0.2f}, $\beta$={beta.item():0.2f})')
ax.plot(xg2, 'b--',
        label=rf'Naive Heavy-ball Soln. ($\alpha$={ALPHA_NAIVE:0.2f}, $\beta$={BETA_NAIVE:0.2f})')
ax.set_xlabel('Index $i$')
ax.set_ylabel('$x_i$')
ax.set_title(f'Heavy-ball solutions (max_iter={N})')
ax.legend();

In [None]:
# Residual plots

def heavyball_res(x, A, b, alpha, beta, max_iter=None):
    """Relative residual history of the explicitly unrolled heavy-ball iteration.

    Starts from x^{(0)} = x with x^{(-1)} = x (no initial momentum), and applies
        x^{(k+1)} = x^{(k)} - alpha * gradf(A, b, x^{(k)}) + beta * (x^{(k)} - x^{(k-1)}).

    Args:
        x: initial guess x^{(0)}.
        A, b: system matrix and right-hand side (b must be nonzero for the normalization).
        alpha: step size applied to the gradient.
        beta: momentum weight.
        max_iter: number of iterations; defaults to A.shape[0].

    Returns:
        Tensor of length max_iter + 1 whose entry k is ||b - A x^{(k)}|| / ||b||.
    """
    n_iter = A.shape[0] if max_iter is None else max_iter
    b_norm = tla.norm(b)  # loop-invariant normalization
    res = torch.empty(n_iter + 1, dtype=x.dtype)
    res[0] = tla.norm(b - A @ x) / b_norm

    x_prev = x
    for k in range(1, n_iter + 1):
        x_next = x - alpha * gradf(A, b, x) + beta * (x - x_prev)
        x_prev, x = x, x_next
        res[k] = tla.norm(b - A @ x) / b_norm

    return res

with torch.no_grad():
    res_opt = heavyball_res(torch.ones(N), A, b, alpha, beta)
    res_naive = heavyball_res(torch.ones(N), A, b, 0.5, 0.5)  # untrained initial parameters

# Label both curves: without a legend the optimized and naive histories are indistinguishable
fig, ax = plt.subplots()
ax.semilogy(res_opt, 'r',
            label=rf'Optimized ($\alpha$={alpha.item():0.2f}, $\beta$={beta.item():0.2f})')
ax.semilogy(res_naive, 'b', label=r'Naive ($\alpha$=0.50, $\beta$=0.50)')
ax.set_xlabel('Iteration')
ax.set_ylabel('Relative Residual')
ax.grid(True, which='both')
ax.legend();