In [None]:
import torch
import torch.nn as tNN
import torch.linalg as tla
import numml.sparse as sp
import numml.nn as nNN
import numml.krylov as kry
import matplotlib.pyplot as plt

In [None]:
# Problem size; passed to sp.eye below, so A is presumably N x N.
N = 16
# System matrix: 2 * eye(N) minus the k=-1 and k=+1 shifted identities.
# Assuming sp.eye's `k` is a diagonal offset (as in numpy.eye), this is the
# tridiagonal [-1, 2, -1] stencil -- TODO confirm against numml.sparse.
A = sp.eye(N) * 2 - sp.eye(N,k=-1) - sp.eye(N,k=1)

In [None]:
# Use our GCN implementation to build a network that maps a right-hand side b
# to an approximate solution x of the linear system Ax = b.

class Network(tNN.Module):
    """Four-layer graph-convolutional network with skip connections.

    Every layer applies ``tanh`` to a ``GCNConv`` output and adds a skip term.
    The first and last layers route the skip through a linear map
    (``upscale``: 1 -> H, ``downscale``: H -> 1) so that feature widths match;
    the two middle layers add their own input back unchanged.
    """

    def __init__(self, H):
        """Build the layers; ``H`` is the hidden feature width."""
        super().__init__()

        def make_conv(n_in, n_out):
            # All graph convolutions share the same normalize=False setting.
            return nNN.GCNConv(n_in, n_out, normalize=False)

        # Layers are created in the same order as they are applied in forward().
        self.conv1 = make_conv(1, H)
        self.conv2 = make_conv(H, H)
        self.conv3 = make_conv(H, H)
        self.conv4 = make_conv(H, 1)
        # Linear maps used for the skip connections of the first / last layer.
        self.upscale = tNN.Linear(1, H)
        self.downscale = tNN.Linear(H, 1)

    def forward(self, A, X):
        """Run the network.

        ``A`` is handed as the first argument to every ``GCNConv`` layer.
        ``X`` is the input feature tensor; it is unsqueezed along dim 1 before
        entering ``upscale`` (so presumably a 1-D vector -- verify with caller).
        Returns the final activations with all size-1 dimensions squeezed out.
        """
        # Input layer: lift the scalar features to width H.
        lifted = torch.tanh(self.conv1(A, X))
        hidden = lifted + self.upscale(X.unsqueeze(1))

        # Hidden layers: plain residual updates at constant width.
        for conv in (self.conv2, self.conv3):
            hidden = torch.tanh(conv(A, hidden)) + hidden

        # Output layer: project back down to a single feature.
        projected = torch.tanh(self.conv4(A, hidden))
        out = projected + self.downscale(hidden)
        return out.squeeze()

In [None]:
# Optimize the parameters of the network so that x_g = network(A, b)
# approximately solves A x = b for randomly drawn right-hand sides b.
# Objective per sample: squared relative residual ||b - A x_g||^2 / ||b||^2,
# summed over N_b freshly sampled right-hand sides per optimizer step.
# NOTE: still experimental -- not totally working yet.

# Seed the RNG before the network is built so that both the weight
# initialization and the sampled right-hand sides are reproducible under
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

network = Network(5)
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)

N_e = 1_000      # number of optimizer steps
N_b = 100        # right-hand sides sampled per step
LOG_EVERY = 50   # print the mean loss every LOG_EVERY steps (and on the last)

for i in range(N_e):
    optimizer.zero_grad()
    loss = 0.
    for _ in range(N_b):
        b = torch.randn(N)
        x_g = network(A, b)
        r = b - (A @ x_g)       # residual of the guessed solution
        rr = (r@r)/(b@b)        # squared residual norm relative to ||b||^2
        loss += rr
    loss.backward()
    optimizer.step()
    # Report the per-sample mean; throttled to avoid flooding the cell output.
    if i % LOG_EVERY == 0 or i == N_e - 1:
        print(i, loss.item()/N_b)

In [None]:
# Probe the network with a unit impulse at the middle index and compare the
# right-hand side against the network's guess for the solution of A x = b.
b = torch.zeros(N)
b[N//2] = 1.

# Detach from the autograd graph and move to CPU so matplotlib can consume it.
x_guess = network(A, b).detach().cpu()

fig, ax = plt.subplots()
ax.plot(b, label="right-hand side b")
ax.plot(x_guess, label="network guess for x")
ax.set(xlabel="index", ylabel="value", title="Network response to a unit impulse")
ax.legend();