In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy as sp
import pyamg
import random
import math
from torch.autograd import Variable
from pyamg.multilevel import multilevel_solver
from pyamg.relaxation.smoothing import change_smoothers
from scipy.sparse import csr_matrix, coo_matrix, lil_matrix, isspmatrix_csr, SparseEfficiencyWarning
import collections
# Torch device for the GNN layers (GNN_prob / GNN_value move their Linear layers here).
# NOTE(review): the data tensors built elsewhere in this notebook (torch.zeros,
# torch.from_numpy) are created on the CPU, so inputs must be moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def map_2_to_1_numpy(grid_size=8):
    """Return the flat (1D) index of every point of each 3x3 neighbourhood.

    The result ``K`` has shape (grid_size, grid_size, 3, 3) and dtype float64.
    ``K[a, b, i, j]`` is the flat index of grid point
    ((a + i - 1) % grid_size, (b + j - 1) % grid_size), where point (r, c) is
    numbered ``c * grid_size + r`` (column-major).  Neighbours wrap around
    periodically at the grid edges.
    """
    size = grid_size
    # Column-major numbering, tiled 2x2 so wrapped neighbours can be sliced directly.
    numbering = np.arange(size ** 2).reshape(size, size).T
    tiled = np.tile(numbering, (2, 2))
    K = np.zeros((size, size, 3, 3))
    for di in range(3):
        row0 = (di - 1) % size
        for dj in range(3):
            col0 = (dj - 1) % size
            K[:, :, di, dj] = tiled[row0:row0 + size, col0:col0 + size]
    return K
def get_p_matrix_indices_one_numpy(grid_size):
    """Return (fine_index, coarse_index) pairs for the 3x3 transfer stencils.

    Coarse point (ic, jc) sits at fine point (2*ic + 1, 2*jc + 1) and gets the
    column-major coarse index ``(grid_size // 2) * jc + ic``.  For each coarse
    point its 9 fine neighbours are listed in stencil order (k outer, m inner).
    Returns an integer array of shape (9 * (grid_size // 2) ** 2, 2).
    """
    neighbours = map_2_to_1_numpy(grid_size=grid_size)
    n_coarse = grid_size // 2
    pairs = [
        [int(neighbours[2 * ic + 1, 2 * jc + 1, k, m]), int(n_coarse * jc + ic)]
        for ic in range(n_coarse)
        for jc in range(n_coarse)
        for k in range(3)
        for m in range(3)
    ]
    return np.array(pairs)

def compute_stencil_numpy(A, grid_size):
    """Extract the 3x3 stencil of every grid point from matrix ``A``.

    Reads ``A[row, col]`` for every (point, neighbour) pair produced by
    get_indices_compute_A_one_numpy (paired integer-array indexing; callers in
    this notebook pass scipy sparse matrices) and returns an ndarray of shape
    (grid_size, grid_size, 3, 3).
    """
    pairs = get_indices_compute_A_one_numpy(grid_size)
    rows, cols = pairs[:, 0], pairs[:, 1]
    entries = np.array(A[rows, cols])
    return entries.reshape((grid_size, grid_size, 3, 3))
def get_indices_compute_A_one_numpy(grid_size):
    """Return the (row, col) matrix position of every 3x3 stencil entry.

    The row is the flat index of grid point (i, j) and the col the flat index of
    its neighbour (k, m), both taken from map_2_to_1_numpy.  Pairs are ordered
    i, j, k, m (row-major), giving an array of shape (9 * grid_size ** 2, 2).
    """
    K = map_2_to_1_numpy(grid_size=grid_size)
    return np.array([
        [int(K[i, j, 1, 1]), int(K[i, j, k, m])]
        for i in range(grid_size)
        for j in range(grid_size)
        for k in range(3)
        for m in range(3)
    ])
def compute_p2_numpy(P_stencil, grid_size):
    """Assemble the coarse-by-fine transfer matrix from per-coarse-point stencils.

    ``P_stencil`` holds one 3x3 stencil per coarse point, flattened in the order
    of get_p_matrix_indices_one_numpy.  Entry (coarse J, fine I) receives the
    matching weight; duplicate (J, I) pairs are summed by the CSR constructor.
    Returns a CSR matrix of shape ((grid_size // 2) ** 2, grid_size ** 2).
    """
    pairs = get_p_matrix_indices_one_numpy(grid_size)
    fine_idx, coarse_idx = pairs[:, 0], pairs[:, 1]
    n_coarse, n_fine = (grid_size // 2) ** 2, grid_size ** 2
    return sp.sparse.csr_matrix(
        (P_stencil.reshape(-1), (coarse_idx, fine_idx)),
        shape=(n_coarse, n_fine),
    )
def compute_A_indices_numpy(grid_size):
    """Return matrix positions and stencil positions of every 3x3 stencil entry.

    Returns:
        (A_idx, stencil_idx) where ``A_idx`` is an integer array of shape
        (9 * grid_size ** 2, 2) holding (row, col) = (flat index of point (i, j),
        flat index of its neighbour (k, m)), and ``stencil_idx`` is the matching
        Python list of [i, j, k, m] entries, both in i, j, k, m order.
    """
    K = map_2_to_1_numpy(grid_size=grid_size)
    stencil_idx = [
        [i, j, k, m]
        for i in range(grid_size)
        for j in range(grid_size)
        for k in range(3)
        for m in range(3)
    ]
    A_idx = [[int(K[i, j, 1, 1]), int(K[i, j, k, m])] for i, j, k, m in stencil_idx]
    return np.array(A_idx), stencil_idx
def prolongation_fn(grid_size):
    """Build the full-weighting transfer matrix for a grid_size x grid_size grid.

    Every coarse point uses the stencil [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16.
    NOTE: the returned CSR matrix has shape ((grid_size // 2) ** 2, grid_size ** 2),
    i.e. coarse x fine; callers in this notebook use it as the restriction R and
    take P = R.T.
    """
    weights = np.outer([1.0, 2.0, 1.0], [1.0, 2.0, 1.0]) / 16
    n_coarse = grid_size // 2
    # Same stencil replicated at every coarse point -> shape (n_coarse, n_coarse, 3, 3).
    P_stencils = np.tile(weights, (n_coarse, n_coarse, 1, 1))
    return compute_p2_numpy(P_stencils, grid_size).astype(np.double)

In [None]:
def compute_A_torch(P_stencil, grid_size):
    """Assemble a sparse (grid_size**2, grid_size**2) operator from its stencils.

    Args:
        P_stencil: tensor with 9 * grid_size**2 entries laid out as
            (grid_size, grid_size, 3, 3) stencils, in the order produced by
            compute_A_indices_torch.  It may require grad; autograd flows
            through the sparse construction.

    Returns:
        float64 torch sparse COO tensor (not coalesced; duplicate positions
        are summed when it is used).
    """
    index_pairs, _ = compute_A_indices_torch(grid_size)
    indices = torch.as_tensor(index_pairs.T, dtype=torch.long)
    values = P_stencil.reshape(-1).to(torch.float64)
    # torch.sparse.DoubleTensor(...) is a deprecated legacy constructor;
    # torch.sparse_coo_tensor is the supported replacement.
    return torch.sparse_coo_tensor(indices, values, (grid_size ** 2, grid_size ** 2))
def compute_A_indices_torch(grid_size):
    """Return matrix positions and stencil positions of every 3x3 stencil entry.

    This used to be a line-for-line copy of compute_A_indices_numpy (using the
    equally duplicated map_2_to_1_torch), so it now delegates to it.  Returns
    ``(A_idx, stencil_idx)``: an integer array of (row, col) pairs and the
    matching list of [i, j, k, m] stencil positions.
    """
    return compute_A_indices_numpy(grid_size)
def map_2_to_1_torch(grid_size=8):
    """Neighbourhood index map; identical to map_2_to_1_numpy.

    The body used to be a verbatim copy of map_2_to_1_numpy, so it now
    delegates.  Returns a float64 ndarray of shape (grid_size, grid_size, 3, 3)
    with the column-major flat index of each point's 3x3 neighbours
    (periodic wrap-around at the edges).
    """
    return map_2_to_1_numpy(grid_size=grid_size)


In [None]:
# def g(x,y):
#     return np.sin(np.pi*x*y)
def g(x,y,option=2):
    """Diffusion coefficient used by construct_problem.

    option 1: exp(0.1*x + 3*y + 1.5*y**2)
    option 2: sin(0.1*pi*x*y) + 1.01   (default; always >= 0.01)

    Raises:
        ValueError: for any other ``option``.  The original code fell through
            to ``np.exp(dd)`` with ``dd`` unbound, raising UnboundLocalError.
    """
    if option == 1:
        return np.exp(0.1*x+3*y+1.5*y**2)
    if option == 2:
        return np.sin(0.1*np.pi*x*y)+1.01
    raise ValueError(f'unsupported option {option!r}; expected 1 or 2')
def construct_problem(size,g):
    """Build the 3x3 stencils of a variable-coefficient operator on a size x size grid.

    Grid coordinates are (k + 1) / (size + 1) in both directions.  For point
    (i, j) the coefficient ``g(x, y)`` is sampled at the four surrounding cell
    midpoints (NW, NE, SW, SE).  A midpoint that would lie beyond the last grid
    point falls back to the point's own coordinate.  The stencil weights are:
    corners -g/3, edges -(sum of the two adjacent g)/6, and centre 2/3 times the
    sum of all four.  Entries that couple to points outside the grid are zeroed.

    Returns:
        ndarray of shape (size, size, 3, 3).
    """
    coords = [(k + 1) / (size + 1) for k in range(size)]
    prob = np.zeros((size, size, 3, 3))

    def midpoint(k, step):
        # Midpoint between coords[k] and coords[k + step], or coords[k] at the edge.
        nb = k + step
        return (coords[nb] + coords[k]) / 2 if 0 <= nb < size else coords[k]

    for i in range(size):
        x_w, x_e = midpoint(i, -1), midpoint(i, +1)
        for j in range(size):
            y_s, y_n = midpoint(j, -1), midpoint(j, +1)
            gnw = g(x_w, y_n)
            gne = g(x_e, y_n)
            gsw = g(x_w, y_s)
            gse = g(x_e, y_s)
            prob[i, j] = [
                [-1/3*gsw, -1/6*(gsw+gnw), -1/3*gnw],
                [-1/6*(gse+gsw), 2/3*(gnw+gne+gse+gsw), -1/6*(gnw+gne)],
                [-1/3*gse, -1/6*(gne+gse), -1/3*gne],
            ]

    # Drop stencil entries that would reach past the grid boundary.
    prob[:, 0, :, 0] = 0.
    prob[:, -1, :, -1] = 0.
    prob[0, :, 0, :] = 0.
    prob[-1, :, -1, :] = 0.
    return prob

In [None]:
def coo_to_tensor(coo):
    """Convert a scipy sparse COO matrix into a float64 torch sparse COO tensor.

    Indices and values are copied (no memory is shared with ``coo``), values are
    cast to float64, and the result does not require grad.  The four unused
    locals of the previous version (temp/row/col/data) were removed, and the
    legacy torch.LongTensor / torch.DoubleTensor constructors were replaced.
    """
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float64)
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape), requires_grad=False)

In [None]:
def compute_A_numpy(stencils,grid_size):
    """Assemble the (grid_size**2, grid_size**2) CSR matrix defined by 3x3 stencils.

    ``stencils`` (an ndarray with 9 * grid_size**2 entries) is flattened in
    (i, j, k, m) order to match the (row, col) pairs from
    compute_A_indices_numpy.  Duplicate (row, col) pairs are summed by the CSR
    constructor, and explicit zero stencil entries stay stored.
    """
    index_pairs = compute_A_indices_numpy(grid_size)[0]
    n_points = grid_size ** 2
    return sp.sparse.csr_matrix(
        (stencils.reshape(-1), (index_pairs[:, 0], index_pairs[:, 1])),
        shape=(n_points, n_points),
    )
def get_p_matrix_indices_one(grid_size):
    """Return (fine_index, coarse_index) pairs of the 3x3 transfer stencils.

    This was a verbatim copy of get_p_matrix_indices_one_numpy, so it now
    delegates to it to keep a single implementation.
    """
    return get_p_matrix_indices_one_numpy(grid_size)

In [None]:
# Build one training problem on an n x n grid:
#   * the fine operator A from construct_problem / compute_A_numpy,
#   * its Galerkin coarse operator R*A*P on the (n//2 x n//2) grid,
#   * the coarse stencils (stencil_train), and
#   * k generalized eigenvectors of the coarse operator (eig_vec), used by train().
A_train = []
stencil_train = []
eig_vec_train = []
Ag_train = []
A0_train = []
n = 15
eps = 10  # NOTE(review): not used anywhere in this notebook
for iii in range(1):
    # theta = np.pi/6*np.random.rand(1).item()+np.pi/6
    s = construct_problem(n,g)
    A = compute_A_numpy(s,n)
    # Fine operator as a torch sparse tensor; only referenced in commented-out code.
    A0 = coo_to_tensor(A.tocoo())
    k=15  # number of eigenpairs requested from eigs
    # eig_value,eig_vec = sp.sparse.linalg.eigs(A,k=k,which = 'SM')
    R = prolongation_fn(n)  # coarse x fine transfer matrix, used as restriction
    P = R.T
    T = R*P  # passed as M in the generalized eigenproblem below
    A = R*A*P  # Galerkin coarse operator; A no longer refers to the fine matrix
    # print(eig_value)
    stencil_train = compute_stencil_numpy(A,n//2)  # (n//2, n//2, 3, 3) coarse stencils
    stencil_train = torch.from_numpy(stencil_train)
    # Solves A v = lambda T v for the k eigenvalues of smallest magnitude.
    eig_value,eig_vec = sp.sparse.linalg.eigs(A,k=k,M=T,which = 'SM')
    eig_vec = eig_vec.T  # one eigenvector per row
    # eig_vec = torch.from_numpy(eig_vec).view(49,1,n//2,n//2)

    # eigs returns complex arrays; only the real part is kept.
    eig_vec = torch.real(torch.from_numpy(eig_vec))
    # Coarse operator as a torch sparse tensor; this global A is what train() receives.
    A = coo_to_tensor(A.tocoo())



    # Full-weighting kernel wrapped in a stride-2 convolution.
    # NOTE(review): `res` is only referenced in commented-out code (and is
    # later rebound to a list by the benchmark cell).
    res_stencil = torch.zeros(1,1,3,3).double()
    res_stencil[0,0,0,0] = 1/16
    res_stencil[0,0,0,1] = 2/16
    res_stencil[0,0,0,2] = 1/16
    res_stencil[0,0,1,0] = 2/16
    res_stencil[0,0,1,1] = 4/16
    res_stencil[0,0,1,2] = 2/16
    res_stencil[0,0,2,0] = 1/16
    res_stencil[0,0,2,1] = 2/16
    res_stencil[0,0,2,2] = 1/16
    res = nn.Conv2d(1, 1, 3, padding = 0, stride=2, bias = False)
    res.weight = nn.Parameter(res_stencil)
    A_train.append(A)
    eig_vec_train.append(eig_vec)
    # Ag_train.append(Ag)
    # A0_train.append(A0)

In [None]:
# Sanity check: rebuild the coarse operator from its extracted stencils and
# print the norm of the difference to A (presumably ~0 if A has a 3x3-stencil
# sparsity pattern — TODO confirm).
AA = compute_A_torch(stencil_train,n//2)
print(torch.norm(A-AA))

In [None]:
# Scratch arithmetic left over from development: 11025 / 49 = 225.0 (= 15**2).
11025/49

In [None]:
# Scratch experiment with softmax / scatter_ / random tensors; none of these
# globals are read by later cells.
logits = torch.rand(10)
# NOTE(review): torch.empty_like returns uninitialized memory, so the printed
# `gumbel` values are arbitrary.
gumbel = torch.empty_like(logits)
gumbel2 = torch.rand(logits.shape)
# Writes 1., 2., 3. into positions 1, 2, 3 of a zero vector.
y_hard = torch.zeros_like(logits).scatter_(-1, torch.Tensor([1,2,3]).type(torch.int64), torch.Tensor([1,2,3]))
print(y_hard)
print(logits.softmax(-1))
print(gumbel)
print(gumbel2)

In [None]:
def top_k(logits,k,tau):
    """Straight-through top-k selection along the last dimension.

    The forward value is a 0/1 mask with ones at the k largest softmax
    probabilities (up to floating-point rounding).  Gradients flow through the
    softmax as if the soft probabilities had been returned.  ``tau`` is
    accepted but not used.
    """
    soft = torch.softmax(logits, dim=-1)
    _, chosen = soft.topk(k)
    hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    hard.scatter_(-1, chosen, 1.0)
    # hard - soft.detach() + soft: value of `hard`, gradient of `soft`.
    return hard - soft.detach() + soft

In [None]:
class GNN_prob(nn.Module):
    """MLP scoring the 8 off-centre entries of a 3x3 stencil.

    forward() maps 8 -> 50 -> 50 -> 50 -> 8 with tanh after every layer, so the
    outputs lie in [-1, 1].  All layers are float64 and placed on the global
    ``device``; inputs must be float64 tensors on that same device.
    """

    def __init__(self):
        super().__init__()

        def dense(n_in, n_out):
            # float64 Linear layer on the global torch `device`.
            return nn.Linear(n_in, n_out).double().to(device)

        self.fc1 = dense(8, 50)
        self.fc2 = dense(50, 50)
        self.fc3 = dense(50, 50)
        self.fc4 = dense(50, 8)
        # NOTE(review): fc5-fc8 are never used by forward(); they are still
        # created so parameter names and initialization order stay unchanged.
        self.fc5 = dense(100, 100)
        self.fc6 = dense(100, 100)
        self.fc7 = dense(100, 100)
        self.fc8 = dense(100, 8)

    def forward(self, X):
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            X = torch.tanh(layer(X))
        return X
class GNN_value(nn.Module):
    """MLP predicting new weights for the 8 off-centre entries of a 3x3 stencil.

    forward() maps 8 -> 50 -> 50 -> 50 -> 8 with tanh after every layer, so the
    outputs lie in [-1, 1].  All layers are float64 and placed on the global
    ``device``; inputs must be float64 tensors on that same device.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): fc5 and fc6 are never used by forward(); they are kept so
        # parameter names and initialization order stay unchanged.
        layer_sizes = [(8, 50), (50, 50), (50, 50), (50, 8), (100, 100), (100, 8)]
        for idx, (n_in, n_out) in enumerate(layer_sizes, start=1):
            setattr(self, f'fc{idx}', nn.Linear(n_in, n_out).double().to(device))

    def forward(self, X):
        hidden = X
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = torch.tanh(layer(hidden))
        return hidden

def solver(X, k=4):
    """Return a float64 CPU 0/1 mask marking the k largest entries of X.

    Selection is along the last dimension, so a 2-D X yields one mask per row.
    For 1-D input with the default k=4 this matches the previous behaviour.
    The previous ``Y[v] = 1`` indexed dimension 0, which was wrong for
    multi-dimensional X, and indexing a CPU tensor with device-resident indices
    failed.

    Args:
        X: tensor whose last dimension has at least k entries.
        k: number of entries to mark (default 4).
    """
    Y = torch.zeros(X.shape, dtype=torch.float64)
    _, idx = torch.topk(X, k)
    return Y.scatter_(-1, idx.cpu(), 1.0)
def sparsify(prob,value,tau):
    """Build a sparse 3x3 stencil from per-entry scores and weights.

    ``prob`` and ``value`` are expected to hold 8 entries (the off-centre
    stencil positions).  The 4 highest-scoring positions of ``prob`` (through
    the straight-through top_k) keep their ``value``; the other positions become
    (numerically) zero.  The centre is set to minus the sum of the kept weights.

    Returns:
        tensor of shape (1, 1, 3, 3).
    """
    kept = top_k(prob, 4, tau).squeeze() * value
    centre = -kept.sum().view(1)
    flat = torch.cat((kept[:4], centre, kept[4:]))
    return flat.view(1, 1, 3, 3)




def train(A,stencil_train,eig_vec,epochs):
    """Train GNN_prob / GNN_value to sparsify the stencils of an operator.

    Each grid point's 3x3 stencil, with its centre removed, gives 8 features.
    ``GNN_prob`` turns them into selection scores and ``GNN_value`` into new
    weights.  ``sparsify`` keeps the top 4 entries and sets the centre to minus
    their sum.  The sparsified stencils are assembled into ``A_c`` and the loss
    is ``sum_v ||A_c v - A v||_2`` over the rows ``v`` of ``eig_vec``.

    Compared with the previous version:
      * the coarse grid size is taken from ``stencil_train`` instead of being
        hard-coded to 7;
      * features are moved to the models' device, which previously crashed when
        CUDA was available;
      * the networks are evaluated once per epoch on all points, and the loss is
        computed with a single sparse matmul per operator.

    Args:
        A: torch sparse operator of shape (m*m, m*m).
        stencil_train: (m, m, 3, 3) tensor with the stencils of ``A``.
        eig_vec: (num_vectors, m*m) tensor; each row is one test vector.
        epochs: number of Adam steps (the loss is printed every 500 epochs).

    Returns:
        (model_prob, model_value), the trained networks.
    """
    model_prob = GNN_prob()
    model_value = GNN_value()
    optimizer = torch.optim.Adam(
        list(model_prob.parameters()) + list(model_value.parameters()), lr=1e-4)
    tau = 0.05  # forwarded to sparsify/top_k

    grid_m = stencil_train.shape[0]
    model_device = next(model_prob.parameters()).device
    data_device = stencil_train.device
    off_centre = [0, 1, 2, 3, 5, 6, 7, 8]  # flat 3x3 positions without the centre (4)
    # One row of 8 features per grid point, in the same (i, j) row-major order as the stencils.
    features = stencil_train.reshape(-1, 9)[:, off_centre].to(model_device, torch.float64)
    eig_cols = eig_vec.detach().t()  # test vectors as columns: (m*m, num_vectors)

    # Zero the couplings that reach outside the grid (same rule as construct_problem).
    boundary_mask = torch.ones(stencil_train.shape, dtype=torch.float64, device=data_device)
    boundary_mask[:, 0, :, 0] = 0.
    boundary_mask[:, -1, :, -1] = 0.
    boundary_mask[0, :, 0, :] = 0.
    boundary_mask[-1, :, -1, :] = 0.

    for epoch in range(epochs):
        optimizer.zero_grad()
        prob = model_prob(features).to(data_device)
        value = model_value(features).to(data_device)
        coarse = torch.stack([sparsify(prob[r], value[r], tau) for r in range(prob.shape[0])])
        stencil_c = coarse.reshape(stencil_train.shape) * boundary_mask

        A_c = compute_A_torch(stencil_c, grid_m)
        residual = torch.sparse.mm(A_c, eig_cols) - torch.sparse.mm(A, eig_cols)
        loss = residual.norm(dim=0).sum()  # sum of per-vector 2-norms
        loss.backward()
        optimizer.step()

        if epoch%500==0:
            print(' epoch: ',epoch,' loss: ',loss)
    return model_prob,model_value

In [None]:
# n = 31
# num_eigenvecs = n*n
# eig_vec = torch.rand(num_eigenvecs,1,n,n).double()
# b = torch.rand(num_eigenvecs,1,n,n).double()
# for k in range(10):
#   # eig_vec = np.matmul(np.linalg.inv(L.toarray()),b-np.matmul(U.toarray(),eig_vec))
#   eig_vec = 2/(3*A0.weight[0,0,1,1])*(b-A0(eig_vec))+eig_vec
# eig_vec = b-A0(eig_vec)
# eig_vec = res(eig_vec)

# Train on the single coarse problem built earlier (A, stencil_train and eig_vec
# are globals from that cell).  extend_hierarchy's 'non-galerkin' branch reads
# the resulting globals model_prob / model_value.
model_prob,model_value = train(A,stencil_train,eig_vec,3000)

# New Section

In [None]:
def geometric_solver(A,option1,option2,models,n,
                     presmoother=('gauss_seidel', {'sweep': 'forward'}),
                     postsmoother=('gauss_seidel', {'sweep': 'forward'}),
                     max_levels=5, max_coarse=10,coarse_solver='splu',stencil=0,**kwargs):
    """Build a pyamg multilevel solver for an n x n grid via extend_hierarchy.

    Args:
        A: fine-grid matrix; converted to CSR (with a SparseEfficiencyWarning)
            if it is not CSR already.
        option1: coarse-operator construction passed to extend_hierarchy.
        option2, models: forwarded unchanged to extend_hierarchy.
        n: fine grid size (A is n**2 x n**2).
        presmoother, postsmoother: smoother specs for change_smoothers.
        max_levels, max_coarse: stop coarsening at max_levels levels or once
            the coarsest matrix has at most max_coarse rows.
        coarse_solver: coarsest-level solver name given to multilevel_solver.
        stencil: fine-grid stencils, stored on the first level.
        **kwargs: extra keyword arguments for multilevel_solver.

    Returns:
        pyamg multilevel_solver.

    Raises:
        TypeError: A is not convertible to CSR.
        ValueError: A is not square.
    """
    # `warn` is not imported at file level; without this the conversion path
    # raised NameError, which the old broad except turned into a bogus TypeError.
    from warnings import warn

    levels = [multilevel_solver.level()]

    # convert A to csr
    if not isspmatrix_csr(A):
        try:
            A = csr_matrix(A)
        except Exception as err:
            raise TypeError('Argument A must have type csr_matrix, '
                            'or be convertible to csr_matrix') from err
        warn("Implicit conversion of A to CSR", SparseEfficiencyWarning)
    # preprocess A
    A = A.asfptype()
    if A.shape[0] != A.shape[1]:
        raise ValueError('expected square matrix')

    levels[-1].A = A
    levels[-1].stencils = stencil
    # extend_hierarchy's stencil-based branch reads the singular `.stencil`.
    levels[-1].stencil = stencil
    levels[-1].n = n

    while len(levels) < max_levels and levels[-1].A.shape[0] > max_coarse:
        extend_hierarchy(levels,option1,option2,models,stencil)

    # coarse_solver was previously accepted but never forwarded, so pyamg's
    # default coarse solver was silently used instead.
    ml = multilevel_solver(levels, coarse_solver=coarse_solver, **kwargs)
    change_smoothers(ml, presmoother, postsmoother)
    return ml

# internal function
def extend_hierarchy(levels,option1,option2,models,stencil):
    """Extend the multigrid hierarchy by one coarser level.

    Every branch uses the fixed full-weighting transfer R = prolongation_fn(n)
    (coarse x fine) and P = R.T, stores them on the current level, and appends a
    new level on the (n//2 x n//2) grid.  The coarse operator depends on option1:

    * 'standard'     -- Galerkin product R*A*P.
    * 'non-galerkin' -- the stencils of R*A*P are sparsified point by point with
      the trained networks and ``sparsify``, then assembled with compute_A_numpy.
    * anything else  -- a 5-point stencil derived from entries [0,0,0,0],
      [0,0,0,1] and [0,0,1,0] of the current level's stencil tensor, assembled
      with pyamg.gallery.stencil_grid.

    Args:
        levels: list of pyamg multilevel_solver.level objects; modified in place.
        option1: coarse-operator construction (see above).
        option2: unused; kept for interface compatibility.
        models: unused.  NOTE(review): the 'non-galerkin' branch reads the
            module-level ``model_prob`` / ``model_value`` instead of this list.
        stencil: fallback stencil tensor for the last branch when the current
            level has no ``.stencil`` attribute (previously an AttributeError).
    """
    A = levels[-1].A
    n = levels[-1].n
    prev_stencil = getattr(levels[-1], 'stencil', stencil)

    R = prolongation_fn(n)
    P = R.T
    levels[-1].P = P  # prolongation operator
    levels[-1].R = R  # restriction operator
    n = n // 2
    levels.append(multilevel_solver.level())

    if option1 == 'standard':
        A_coarse = (R * A * P).astype(np.float64)
        levels[-1].A = A_coarse.tocsr()
    elif option1 == 'non-galerkin':
        A_galerkin = R * A * P
        stencils = torch.from_numpy(compute_stencil_numpy(A_galerkin, n))
        tau = 0.05
        off_centre = [0, 1, 2, 3, 5, 6, 7, 8]  # flat 3x3 positions without the centre
        model_device = next(model_prob.parameters()).device
        features = stencils.reshape(-1, 9)[:, off_centre].to(model_device)
        # Inference only: no autograd graph is needed to build the solver.
        with torch.no_grad():
            prob = model_prob(features).cpu()
            value = model_value(features).cpu()
            stencil_c = torch.stack(
                [sparsify(prob[r], value[r], tau) for r in range(prob.shape[0])]
            ).reshape(stencils.shape)
        # Zero the couplings that reach outside the grid.
        stencil_c[:, 0, :, 0] = 0.
        stencil_c[:, -1, :, -1] = 0.
        stencil_c[0, :, 0, :] = 0.
        stencil_c[-1, :, -1, :] = 0.
        levels[-1].A = compute_A_numpy(stencil_c.numpy(), n)
        # Store the coarse stencils themselves (the old code stored the last
        # per-point feature row here).
        levels[-1].stencil = stencil_c
    else:
        cc = prev_stencil[0, 0, 0, 0].item()
        bb = prev_stencil[0, 0, 0, 1].item()
        aa = prev_stencil[0, 0, 1, 0].item()
        s = np.array([[0, bb+2*cc, 0],
                      [aa+2*cc, -2*(aa+bb)-8*cc, aa+2*cc],
                      [0, bb+2*cc, 0]])/16**(len(levels)-1)
        print('optimal:\n',s)
        levels[-1].A = pyamg.gallery.stencil_grid(s, (n, n), format='csr')
    levels[-1].n = n


In [None]:
# Benchmark: compare iteration counts and solve times of the standard Galerkin
# two-level solver with the GNN-based non-Galerkin one on a
# test_grid_size x test_grid_size problem.
torch.set_printoptions(precision=10)
num_test = 1
t1=time.time()  # NOTE(review): overwritten inside the loop before it is read
num_iter = []
num_iter2 = []
num_iter3 = []
num_iter4 = []
t_iter = []
t_iter2 = []
t_iter3 = []
t_iter4 = []
res_s = []
res2_s = []
res3_s = []
res4_s = []
test_grid_size = 15
# Placeholders passed as `models`; extend_hierarchy does not use this list (its
# GNN branch reads the globals model_prob / model_value).
model1=0
model2=0
model3=0
for i in range(num_test):
    s = construct_problem(test_grid_size,g)

    A_orig = compute_A_numpy(s,test_grid_size)

    solver_standard = geometric_solver(A_orig,'standard',0,[model1,model2,model3],test_grid_size,max_levels=2,coarse_solver='splu')
    solver_non_galerkin = geometric_solver(A_orig,'non-galerkin','GNN',[model1,model2,model3],test_grid_size,max_levels=2,coarse_solver='splu', stencil = torch.from_numpy(s))
    # solver_non_galerkin_NONGNN = geometric_solver(A_orig,'234','123',[model1,model2,model3],test_grid_size,max_levels=2,coarse_solver='splu',stencil = torch.from_numpy(s))

    x0 = np.ones((A_orig.shape[0],1))
    # NOTE(review): no random seed is set, so b (and the iteration counts) vary between runs.
    b = np.random.rand(A_orig.shape[0],1)

    # pyamg appends the residual history to these lists; len() is used as the iteration count.
    res=[]
    res2= []
    res3= []
    res4 = []
    t1 =time.time()
    x = solver_standard.solve(b,x0=x0,maxiter=1000, tol=1e-6,residuals=res)
    t2=time.time()
    x = solver_non_galerkin.solve(b,x0=x0,maxiter=1000, tol=1e-6,residuals=res2)
    t3=time.time()
    # x = solver_non_galerkin_NONGNN.solve(b,x0=x0,maxiter=1000, tol=1e-6,residuals=res3)
    # t4=time.time()
    # x = solver_naive.solve(b,x0=x0,maxiter=1000, tol=1e-6,residuals=res4)
    # t5=time.time()

    res_s.append(res)
    res2_s.append(res2)
    # res3_s.append(res3)
    # res4_s.append(res4)

    num_iter.append(len(res))
    num_iter2.append(len(res2))
    # num_iter3.append(len(res3))
    # num_iter4.append(len(res4))

    t_iter.append(t2-t1)
    t_iter2.append(t3-t2)
    # t_iter3.append(t4-t3)
    # t_iter4.append(t5-t4)


print('standard iter:   ',np.mean(num_iter),'  standard time:    ',np.mean(t_iter))
print('non galerkin iter:   ',np.mean(num_iter2),'  non galerkin time:    ',np.mean(t_iter2))
# print('non galerkin nongnn iter:   ',np.mean(num_iter3),'  non galerkin time:    ',np.mean(t_iter3))
# print('naive iter:   ',np.mean(num_iter4),'  non galerkin time:    ',np.mean(t_iter4))


In [None]:
# Last entry of the residual history of the non-Galerkin solve.  The stray
# leading '+' applied unary plus to print()'s None result and raised TypeError.
print(res2[-1])

In [None]:
# Last entry of the residual history collected by solver_non_galerkin.solve.
print(res2[-1])

In [None]:
# FIXME(review): `XX` is not defined anywhere in this notebook; the bare
# print(XX) raised NameError under Restart & Run All, so it is guarded here.
if 'XX' in globals():
    print(XX)
else:
    print('XX is not defined in this notebook')

In [None]:
# Experiment: classical Ruge-Stuben coarsening of the original fine matrix with
# direct interpolation, followed by a (presumably GNN-based) sparsification of
# the Galerkin operator.  Rebinds the globals A, P and R.
A = A_orig.copy()
C = pyamg.strength.classical_strength_of_connection(A)
splitting = pyamg.classical.split.RS(A)
P = pyamg.classical.direct_interpolation(A, C, splitting)
R = P.T.tocsr()
#     R = prolongation_fn(size)
#     P = R.T.tocsr()*4
# Form next level through Galerkin product
A_g = R * A * P
# NOTE(review): A is converted after A_g was formed, so this has no effect on A_g.
A = A.astype(np.float64)  # convert from complex numbers, should have A.imag==0
# FIXME(review): neither `ru` nor `model` is defined anywhere in this notebook,
# so this line raises NameError on a fresh kernel.
A_c = ru(A_g,splitting,model,'GNN',)
def heatmap2d(arr, title=None, colorbar=False):
    """Show a 2D array as an image with the axes hidden.

    Each call draws into a fresh figure, so it never overlays a previous plot.

    Args:
        arr: 2D array-like passed to ``plt.imshow``.
        title: optional figure title.
        colorbar: if True, draw a colorbar next to the image.  This replaces the
            previously commented-out ``plt.colorbar()`` call; it is off by default.
    """
    plt.figure()
    plt.imshow(arr)
    if colorbar:
        plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.axis('off')
    plt.show()
# Compare the number of non-zero entries and the sparsity pattern of the
# Galerkin operator A_g and the sparsified operator A_c.
print(np.count_nonzero(A_g.toarray()))
print(np.count_nonzero(A_c.toarray()))
heatmap2d(A_g.toarray())
heatmap2d(A_c.toarray())