In [None]:
import torch
import torch.nn as tNN
import torch.linalg as tla
import numml.sparse as sp
import numml.nn as nNN
import matplotlib.pyplot as plt

In [None]:
# Size of the model problem (number of unknowns).
N = 10

# Symmetric positive-definite tridiagonal matrix tridiag(-1, 2, -1), i.e. the
# standard 1-D finite-difference Laplacian stencil, assembled from shifted
# sparse identities.
A = (
    sp.eye(N)*2
    - sp.eye(N, k=-1)
    - sp.eye(N, k=1)
)

In [None]:
def f(x, b):
    """Quadratic objective 0.5 * x^T A x - b^T x, using the notebook-level matrix ``A``.

    For symmetric ``A`` its gradient is ``A x - b``, so the minimiser solves A x = b.
    """
    Ax = A@x
    quadratic = x@Ax
    return quadratic/2 - x@b

def gradf(x, b):
    """Gradient of ``f`` for symmetric ``A``: the negated residual ``A x - b``."""
    Ax = A@x
    return Ax - b

def Hf(x, b):
    """Hessian of ``f`` for symmetric ``A``.

    The Hessian of a quadratic is constant, so ``x`` and ``b`` are ignored and the
    notebook-level matrix ``A`` itself (not a copy) is returned.
    """
    hessian = A
    return hessian

In [None]:
def heavyball(x, xp, b, alpha, beta):
    """Perform one heavy-ball (Polyak momentum) step on ``f``.

    x_next = x - alpha * gradf(x, b) + beta * (x - xp)

    Args:
        x: current iterate.
        xp: previous iterate.
        b: right-hand side forwarded to ``gradf``.
        alpha: step size (a float or a tensor; gradients flow through it).
        beta: momentum coefficient (a float or a tensor; gradients flow through it).

    Returns:
        The next iterate.
    """
    descent = alpha * gradf(x, b)
    momentum = beta*(x - xp)
    return x - descent + momentum

In [None]:
class Network(tNN.Module):
    """Two-layer TAGConv graph network predicting heavy-ball parameters (alpha, beta).

    The two input features per row are the corresponding entries of ``x`` and ``xp``
    (assumes both are 1-D vectors of matching length; TODO confirm). The second
    convolution emits two features per row; these are averaged over rows and passed
    through ``tanh`` to produce ``alpha`` and ``beta``.
    """

    def __init__(self, H):
        super().__init__()
        # H is the hidden feature width between the two graph convolutions.
        self.conv1 = nNN.TAGConv(2, H, normalize=True)
        self.conv2 = nNN.TAGConv(H, 2, normalize=True)

    def forward(self, A, x, xp):
        features = torch.column_stack((x, xp))
        for layer in (self.conv1, self.conv2):
            features = torch.relu(layer(A, features))
        # NOTE(review): because ReLU is also applied after the last layer, the row
        # mean is >= 0 and so alpha, beta lie in [0, 1). If every output unit is
        # inactive, both parameters are 0 and receive zero gradient.
        pooled = torch.mean(features, dim=0)
        alpha, beta = torch.tanh(pooled)
        return alpha, beta

In [None]:
# Learn constant heavy-ball parameters (alpha, beta) as follows: unroll a fixed
# number of iterations from x = xp = 0, then minimise the mean relative residual
# ||b - A x|| / ||b|| over random right-hand sides b.
#
# Variant: to predict (alpha, beta) from the current iterates instead, create
# `network = Network(10)` and optimise `network.parameters()`. Then call
# `alpha, beta = network(A, x, xp)` before each heavy-ball step.
torch.manual_seed(0)  # fixed seed so the random right-hand sides (and results) are reproducible

alpha = torch.tensor(1., requires_grad=True)
beta = torch.tensor(1., requires_grad=True)
optimizer = torch.optim.Adam([alpha, beta], lr=0.01)

N_e = 1_000       # optimisation epochs
N_b = 100         # random right-hand sides per epoch
N_it = 2          # heavy-ball iterations unrolled per right-hand side
print_every = 10  # epochs between progress lines

lh = []  # mean relative residual per epoch
ah = []  # alpha used in each epoch (recorded before the optimizer step)
bh = []  # beta used in each epoch (recorded before the optimizer step)

for i in range(N_e):
    optimizer.zero_grad()

    loss = 0.
    for j in range(N_b):
        b = torch.randn(N)
        xp = torch.zeros(N)
        x = torch.zeros(N)

        for _ in range(N_it):
            # The right-hand side is evaluated first, so xp receives the old x.
            x, xp = heavyball(x, xp, b, alpha, beta), x

        loss += tla.norm(b - A@x) / tla.norm(b)
    loss /= N_b
    loss.backward()

    lh.append(loss.item())
    ah.append(alpha.item())
    bh.append(beta.item())

    optimizer.step()
    if i % print_every == 0:
        # Report the parameters that produced this loss, not the freshly updated ones.
        print(i, lh[-1], ah[-1], bh[-1])

In [None]:
# Training curves. The first figure shows the mean relative residual after the
# unrolled iterations on a log scale. The second shows the trajectories of the
# learned heavy-ball parameters.
fig_loss, ax_loss = plt.subplots()
ax_loss.semilogy(lh)
ax_loss.set(title="Heavy-ball parameter training",
            xlabel="epoch",
            ylabel="mean relative residual ||b - Ax|| / ||b||")

fig_params, ax_params = plt.subplots()
ax_params.plot(ah, label=r"$\alpha$")
ax_params.plot(bh, label=r"$\beta$")
ax_params.set(title="Learned heavy-ball parameters", xlabel="epoch", ylabel="value")
ax_params.legend();