In [None]:
import torch
import torch.nn as tNN
import torch.linalg as tla
import numml.sparse as sp
import numml.nn as nNN
import numml.krylov as kry
import matplotlib.pyplot as plt

In [None]:
# Problem size and system matrix.
# NOTE(review): assuming sp.eye follows numpy.eye's `k` offset convention, A is
# the N x N tridiagonal matrix with 2 on the diagonal and -1 on the first
# sub-/super-diagonals — confirm against numml.
N = 16
A = sp.eye(N) * 2 - sp.eye(N, k=-1) - sp.eye(N, k=1)

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU. The name `gpu` is kept because later cells
# refer to it.
gpu = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Copy of the system matrix on the compute device.
A_c = A.to(gpu)

In [None]:
# Use our GCN implementation to create a network that maps a right-hand side b
# (for a given matrix A) to a guess at the solution of Ax = b

class Network(tNN.Module):
    """Stack of four TAGConv layers with skip connections.

    The network takes a matrix ``A`` and a 1-D vector ``X`` and returns a
    vector. Each layer applies ``tanh`` to a ``TAGConv`` output and adds a skip
    term; linear layers lift the input to ``H`` features for the first skip
    connection and project back to one feature for the last.
    """

    def __init__(self, H):
        """Create the convolution and skip-projection layers.

        Args:
            H: Number of hidden features per entry used by the inner layers.
        """
        super().__init__()

        self.conv1 = nNN.TAGConv(1, H, normalize=False)
        self.conv2 = nNN.TAGConv(H, H, normalize=False)
        self.conv3 = nNN.TAGConv(H, H, normalize=False)
        self.conv4 = nNN.TAGConv(H, 1, normalize=False)
        # Linear maps so the first/last skip connections match the feature
        # width of the layer they are added to (1 -> H and H -> 1).
        self.upscale = tNN.Linear(1, H)
        self.downscale = tNN.Linear(H, 1)

    def forward(self, A, X):
        """Evaluate the network.

        Args:
            A: Matrix handed unchanged to every ``TAGConv`` layer.
            X: 1-D input vector with one value per row of ``A``
               (presumably of shape ``(N,)`` — it is unsqueezed to ``(N, 1)``
               before the ``Linear(1, H)`` skip projection).

        Returns:
            The network output with the trailing feature axis removed.
        """
        X = torch.tanh(self.conv1(A, X)) + self.upscale(torch.unsqueeze(X, 1))
        X = torch.tanh(self.conv2(A, X)) + X
        X = torch.tanh(self.conv3(A, X)) + X
        X = torch.tanh(self.conv4(A, X)) + self.downscale(X)
        # Squeeze only the trailing feature axis. A bare squeeze() would also
        # collapse the leading axis when it has length 1 and return a 0-d
        # tensor instead of a length-1 vector.
        X = torch.squeeze(X, -1)
        return X

In [None]:
# Optimize over the entries of the network, not totally working yet...
#
# Each optimizer step accumulates the relative squared residual
# (r.r)/(b.b), with r = b - A x, over N_b random right-hand sides b, where the
# guess x is built by N_it residual-correction passes through the network.

network = Network(16).to(gpu)
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)

N_e = 1_000  # number of optimizer steps
N_b = 100    # random right-hand sides accumulated per step
N_it = 1     # network correction passes per right-hand side
lh = []      # per-step loss divided by N_b

for i in range(N_e):
    optimizer.zero_grad()
    l = 0.
    for j in range(N_b):
        # Create the tensors directly on the compute device instead of on the
        # CPU followed by a host-to-device copy for every sample.
        b = torch.randn(N, device=gpu)
        x_g = torch.zeros(N, device=gpu)
        for k in range(N_it):
            # Update the guess with the network applied to the current residual.
            x_g = x_g + network(A_c, b - A_c @ x_g)
        r = b - (A_c @ x_g)
        rr = (r @ r) / (b @ b)
        l += rr
    l.backward()
    optimizer.step()
    # Read the loss back once per step; each .item() call copies the value
    # to the host.
    l_avg = l.item() / N_b
    lh.append(l_avg)
    if i % 10 == 0:
        print(i, l_avg)

In [None]:
# Training history on a logarithmic y-axis. Each entry of `lh` is the summed
# relative squared residual of one optimizer step divided by N_b.
fig, ax = plt.subplots()
ax.semilogy(lh)
ax.set(
    xlabel="Optimizer step",
    ylabel="Mean relative squared residual",
    title="Training loss",
);

In [None]:
# Probe the trained network with a unit impulse at the middle entry and plot
# the input together with the network's output for it.
impulse = torch.zeros(N)
impulse[N // 2] = 1.
b = impulse.to(gpu)

# Detach from the autograd graph and move to the host before plotting.
net_out = network(A_c, b).detach().cpu()

plt.plot(b.cpu())
plt.plot(net_out)