Code for the paper "Learning to Optimize Multigrid PDE Solvers", available at https://arxiv.org/abs/1902.10248.

Grid indexing convention: in every grid-related tensor, the (0,0) cell is the bottom-left grid cell, and the (I,J) cell is the grid cell in the I'th column and J'th row of the grid.

In [None]:
import numpy as np
from scipy.sparse import csr_matrix

In [None]:
# internal function
def extend_hierarchy(levels, prolongation_fn, prolongation_args):
    """Append one coarser level to ``levels`` (modified in place).

    Parameters
    ----------
    levels : list
        Hierarchy built so far; the last entry must expose its operator as
        the attribute ``A``.
    prolongation_fn : callable
        Called as ``prolongation_fn(A, prolongation_args)``; must return the
        coarse-to-fine interpolation matrix as a sparse matrix (its
        transpose is converted with ``.tocsr()``).
    prolongation_args : object
        Passed through to ``prolongation_fn`` untouched.
    """
    finest = levels[-1]
    fine_operator = finest.A

    # Interpolation (coarse -> fine) and its transpose as the restriction
    # (fine -> coarse).
    interpolation = prolongation_fn(fine_operator, prolongation_args)
    restriction = interpolation.T.tocsr()
    finest.P, finest.R = interpolation, restriction

    coarsest = multilevel_solver.level()
    levels.append(coarsest)

    # Galerkin coarse operator R A P; the cast drops the imaginary part,
    # which is expected to be zero.
    coarsest.A = (restriction * fine_operator * interpolation).astype(np.float64)

In [None]:
# similar to "ruge_stuben_solver" in pyamg
def geometric_solver(A, prolongation_function, prolongation_args,
                     presmoother=('gauss_seidel', {'sweep': 'forward'}),
                     postsmoother=('gauss_seidel', {'sweep': 'forward'}),
                     max_levels=10, max_coarse=10, **kwargs):
    """Create a multilevel solver from a user-supplied prolongation rule.

    Parameters
    ----------
    A : csr_matrix
        Square matrix in CSR format.  Other inputs are converted to CSR
        (with a ``SparseEfficiencyWarning``) when possible.
    prolongation_function : callable
        Called on every level as ``prolongation_function(A, prolongation_args)``
        and expected to return the prolongation matrix for that level.
    prolongation_args : object
        Opaque argument forwarded to ``prolongation_function``.
    presmoother : string or tuple
        Method used for presmoothing at each level.  Method-specific parameters
        may be passed in using a tuple, e.g.
        presmoother=('gauss_seidel', {'sweep': 'forward'}), the default.
    postsmoother : string or tuple
        Postsmoothing method with the same usage as presmoother.
    max_levels : integer
        Maximum number of levels to be used in the multilevel solver.
    max_coarse : integer
        Maximum number of variables permitted on the coarse grid.
    **kwargs
        Forwarded unchanged to ``multilevel_solver``.

    Returns
    -------
    ml : multilevel_solver
        Multigrid hierarchy of matrices and prolongation operators

    Raises
    ------
    TypeError
        If ``A`` is not CSR and cannot be converted to CSR.
    ValueError
        If ``A`` is not square.

    Notes
    -----
    "coarse_solver" is an optional argument and is the solver used at the
    coarsest grid.  The default is a pseudo-inverse.  Most simply,
    coarse_solver can be one of ['splu', 'lu', 'cholesky', 'pinv',
    'gauss_seidel', ... ].  Additionally, coarse_solver may be a tuple
    (fn, args), where fn is a string such as ['splu', 'lu', ...] or a callable
    function, and args is a dictionary of arguments to be passed to fn.
    See [2001TrOoSc]_ for additional details.


    References
    ----------
    .. [2001TrOoSc] Trottenberg, U., Oosterlee, C. W., and Schuller, A.,
       "Multigrid" San Diego: Academic Press, 2001.  Appendix A

    See Also
    --------
    aggregation.smoothed_aggregation_solver, multilevel_solver,
    aggregation.rootnode_solver

    """
    levels = [multilevel_solver.level()]

    # convert A to csr
    if not isspmatrix_csr(A):
        try:
            A = csr_matrix(A)
        except Exception:
            # Only conversion failures are translated; KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            raise TypeError('Argument A must have type csr_matrix, '
                            'or be convertible to csr_matrix')
        # Warn outside the try block so that a warning promoted to an error
        # is not misreported as a conversion failure.
        warn("Implicit conversion of A to CSR",
             SparseEfficiencyWarning)

    # preprocess A
    A = A.asfptype()
    if A.shape[0] != A.shape[1]:
        raise ValueError('expected square matrix')

    levels[-1].A = A

    # Coarsen until either the level budget is used up or the coarsest
    # operator is small enough.
    while len(levels) < max_levels and levels[-1].A.shape[0] > max_coarse:
        extend_hierarchy(levels, prolongation_function, prolongation_args)

    ml = multilevel_solver(levels, **kwargs)
    change_smoothers(ml, presmoother, postsmoother)
    return ml

In [None]:
class Utils:
    """Helpers that assemble sparse matrices from stencils and run the solver."""

    def compute_p2(self, P_stencil, grid_size):
        """Assemble the sparse prolongation matrix from a stencil tensor.

        ``P_stencil`` must provide ``.numpy()``; its flattened values are
        placed at the (row, column) pairs returned by
        ``self.get_p_matrix_indices_one(grid_size)``.  The result has shape
        ``(grid_size ** 2, (grid_size // 2) ** 2)``.
        """
        pairs = self.get_p_matrix_indices_one(grid_size)
        values = P_stencil.numpy().reshape(-1)
        fine_size = grid_size ** 2
        coarse_size = (grid_size // 2) ** 2
        return csr_matrix(arg1=(values, (pairs[:, 0], pairs[:, 1])),
                          shape=(fine_size, coarse_size))

    @memoize
    def compute_A_indices(self, grid_size):
        """Return the matrix coordinates of every stencil entry.

        Returns a pair ``(A_idx, stencil_idx)``: ``A_idx`` is an array of
        ``[row, column]`` pairs and ``stencil_idx`` the matching list of
        ``[i, j, k, m]`` stencil positions, both in the same order.
        """
        lookup = self.map_2_to_1(grid_size=grid_size)
        stencil_idx = [[i, j, k, m]
                       for i in range(grid_size)
                       for j in range(grid_size)
                       for k in range(3)
                       for m in range(3)]
        # The row is the cell itself (stencil centre), the column is the
        # neighbour addressed by (k, m).
        A_idx = [[int(lookup[i, j, 1, 1]), int(lookup[i, j, k, m])]
                 for i, j, k, m in stencil_idx]
        return np.array(A_idx), stencil_idx

    def compute_csr_matrices(self, stencils, grid_size=8):
        """Convert one stencil tensor, or a batch of them, to CSR matrices.

        A 5-dimensional ``stencils`` is treated as a batch and yields a NumPy
        array of matrices; anything else yields a single matrix.  Every
        matrix has shape ``(grid_size ** 2, grid_size ** 2)``.
        """
        A_idx, _ = self.compute_A_indices(grid_size)
        side = grid_size ** 2

        def assemble(values):
            # One sparse operator from one stencil tensor.
            return csr_matrix(arg1=(values.reshape((-1)), (A_idx[:, 0], A_idx[:, 1])),
                              shape=(side, side))

        if len(stencils.shape) != 5:
            return assemble(stencils)
        return np.asarray([assemble(single) for single in stencils])

    def solve_with_model(self, model, A_matrices, b, initial_guess, max_iterations, max_depth=3, blackbox=False,
                         w_cycle=False):
        """Solve ``A x = b`` by multigrid with prolongations produced by ``model``.

        Returns ``(x, residual_norms, error_norms, solver)``.
        ``error_norms`` holds the norm of the iterate recorded after each
        iteration (this equals the error only if the exact solution is zero
        -- presumably the intended setup; confirm against the caller).
        """
        def prolongation_fn(A, args):
            use_blackbox = args["is_blackbox"]
            grid_size = int(math.sqrt(A.shape[0]))
            pairs = self.get_indices_compute_A_one(grid_size)
            # Recover the per-cell 3x3 stencils from the assembled operator.
            A_stencil = np.array(A[pairs[:, 0], pairs[:, 1]]).reshape((grid_size, grid_size, 3, 3))
            model.grid_size = grid_size  # TODO: infer grid_size automatically

            stencil_batch = tf.convert_to_tensor([A_stencil])
            if use_blackbox:
                call_kwargs = {"black_box": True}
            else:
                call_kwargs = {"black_box": False, "phase": "Test"}
            with tf.device(self.device):
                P_stencil = model(inputs=stencil_batch, **call_kwargs)
            # imaginary part should be zero
            return self.compute_p2(P_stencil, grid_size).astype(np.double)

        error_norms = []

        def error_callback(x_k):
            # Invoked by the solver after each iteration.
            error_norms.append(pyamg.util.linalg.norm(x_k))

        solver = geometric_solver(A_matrices, prolongation_fn, {"is_blackbox": blackbox},
                                  max_levels=max_depth)

        residual_norms = []
        cycle = 'W' if w_cycle else 'V'
        x = solver.solve(b, x0=initial_guess, maxiter=max_iterations, cycle=cycle, residuals=residual_norms, tol=0,
                         callback=error_callback)
        return x, residual_norms, error_norms, solver


In [None]:
class Pnetwork(tf.keras.Model):
    """Model mapping 3x3 operator stencils to 3x3 prolongation stencils.

    ``call`` expects ``inputs`` of rank 5, indexed as
    ``(batch, grid column, grid row, 3, 3)``, and returns a complex128
    tensor of rank 5 holding one 3x3 prolongation stencil per coarse cell.
    """

    def __init__(self, grid_size=8, device="/cpu:0"):
        super(Pnetwork, self).__init__()
        self.grid_size = grid_size
        self.device = device

    @staticmethod
    def _gather_block(inputs, axis1_indices, axis2_indices):
        """Select the given indices along axis 1 and then along axis 2."""
        block = tf.gather(inputs, axis1_indices, axis=1)
        return tf.gather(block, axis2_indices, axis=2)

    def call(self, inputs, black_box=False, index=None, pos=-1., phase='Training'):
        """Compute the prolongation stencils for ``inputs``.

        Parameters
        ----------
        inputs : tensor
            Operator stencils, rank 5 (see the class docstring).
        black_box : bool
            If true, the four edge weights are computed directly from the
            input stencils; otherwise they are derived from the network
            output ``x``.
        index, pos, phase
            Not used by the code in this method.

        Returns
        -------
        tensor
            complex128 tensor of 3x3 prolongation stencils.
        """
        # inputs are stencils
        with tf.device(self.device):
            odd = list(range(1, self.grid_size, 2))
            even = list(range(0, self.grid_size, 2))
            # `idx` was referenced below without ever being defined, so every
            # call raised NameError.  It is taken to be the periodic
            # predecessor of each even index, mirroring `odd` (the successor)
            # and the modulo wrap used for jm1/jp1.
            # NOTE(review): confirm against the original training code.
            idx = [(i - 1) % self.grid_size for i in even]

            if not black_box:
                # NOTE(review): `flattended`, `linear0`, `num_layers`,
                # `output_layer` and the bias_*/linear*/multiplier_*
                # attributes are not defined anywhere in this class as shown;
                # this branch cannot run until they are restored.
                x = self.linear0(flattended)
                x = tf.nn.relu(x)
                for i in range(1, self.num_layers, 2):
                    x1 = getattr(self, "bias_1%i" % i) + x
                    x1 = getattr(self, "linear%i" % i)(x1)
                    # NOTE(review): x1 is added twice here and two lines
                    # below; kept as is because trained weights may depend
                    # on it.
                    x1 = x1 + getattr(self, "bias_2%i" % i) + x1
                    x1 = tf.nn.relu(x1)
                    x1 = x1 + getattr(self, "bias_3%i" % i) + x1
                    x1 = getattr(self, "linear%i" % (i + 1))(x1)
                    x1 = tf.multiply(x1, getattr(self, "multiplier_%i" % i))
                    x = x + x1
                    x = x + getattr(self, "bias_4%i" % i)
                    x = tf.nn.relu(x)

                x = self.output_layer(x)

            if black_box:
                # Edge weights straight from the operator: minus the sum of
                # one stencil row/column divided by the sum of the central
                # row/column.
                up_block = self._gather_block(inputs, even, odd)
                up_contributions_output = -tf.reduce_sum(up_block[:, :, :, :, 0], axis=-1) / tf.reduce_sum(
                    up_block[:, :, :, :, 1], axis=-1)

                left_block = self._gather_block(inputs, idx, even)
                left_contributions_output = -tf.reduce_sum(left_block[:, :, :, 2, :], axis=-1) / tf.reduce_sum(
                    left_block[:, :, :, 1, :], axis=-1)

                right_block = self._gather_block(inputs, odd, even)
                right_contributions_output = -tf.reduce_sum(right_block[:, :, :, 0, :], axis=-1) / tf.reduce_sum(
                    right_block[:, :, :, 1, :], axis=-1)

                down_block = self._gather_block(inputs, even, idx)
                down_contributions_output = -tf.reduce_sum(down_block[:, :, :, :, 2], axis=-1) / tf.reduce_sum(
                    down_block[:, :, :, :, 1], axis=-1)
            else:
                # Edge weights from the network output, normalised against
                # the neighbouring coarse cell (periodic wrap).
                jm1 = [(i - 1) % (self.grid_size // 2) for i in range(self.grid_size // 2)]
                jp1 = [(i + 1) % (self.grid_size // 2) for i in range(self.grid_size // 2)]
                right_contributions_output = x[:, :, :, 0] / (tf.gather(x[:, :, :, 1], jp1, axis=1) + x[:, :, :, 0])
                left_contributions_output = x[:, :, :, 1] / (x[:, :, :, 1] + tf.gather(x[:, :, :, 0], jm1, axis=1))
                up_contributions_output = x[:, :, :, 2] / (x[:, :, :, 2] + tf.gather(x[:, :, :, 3], jp1, axis=2))
                down_contributions_output = x[:, :, :, 3] / (tf.gather(x[:, :, :, 2], jm1, axis=2) + x[:, :, :, 3])
            ones = tf.ones_like(down_contributions_output)

            # based on rule 2 given rule 1: each corner weight is built from
            # the stencil of the diagonal fine cell and the two adjacent edge
            # weights, divided by that stencil's centre entry.
            ru_block = self._gather_block(inputs, odd, odd)
            ru_contribution = -tf.expand_dims(
                (ru_block[:, :, :, 0, 0]
                 + tf.multiply(ru_block[:, :, :, 1, 0], right_contributions_output)
                 + tf.multiply(ru_block[:, :, :, 0, 1], up_contributions_output))
                / ru_block[:, :, :, 1, 1], -1)

            lu_block = self._gather_block(inputs, idx, odd)
            lu_contribution = -tf.expand_dims(
                (lu_block[:, :, :, 2, 0]
                 + tf.multiply(lu_block[:, :, :, 2, 1], up_contributions_output)
                 + tf.multiply(lu_block[:, :, :, 1, 0], left_contributions_output))
                / lu_block[:, :, :, 1, 1], -1)

            ld_block = self._gather_block(inputs, idx, idx)
            ld_contribution = -tf.expand_dims(
                (ld_block[:, :, :, 2, 2]
                 + tf.multiply(ld_block[:, :, :, 2, 1], down_contributions_output)
                 + tf.multiply(ld_block[:, :, :, 1, 2], left_contributions_output))
                / ld_block[:, :, :, 1, 1], -1)

            rd_block = self._gather_block(inputs, odd, idx)
            rd_contribution = -tf.expand_dims(
                (rd_block[:, :, :, 0, 2]
                 + tf.multiply(rd_block[:, :, :, 0, 1], down_contributions_output)
                 + tf.multiply(rd_block[:, :, :, 1, 2], right_contributions_output))
                / rd_block[:, :, :, 1, 1], -1)

            first_row = tf.concat([ld_contribution, tf.expand_dims(left_contributions_output, -1),
                                   lu_contribution], -1)
            second_row = tf.concat([tf.expand_dims(down_contributions_output, -1),
                                    tf.expand_dims(ones, -1), tf.expand_dims(up_contributions_output, -1)], -1)
            third_row = tf.concat([rd_contribution, tf.expand_dims(right_contributions_output, -1),
                                   ru_contribution], -1)

            # Stack the three rows on a new leading axis, then move that axis
            # between the grid axes and the per-row axis.
            output = tf.stack([first_row, second_row, third_row], 0)
            output = tf.transpose(output, (1, 2, 3, 0, 4))

            # tf.cast is the version-independent equivalent of the TF1-only
            # tf.to_complex128.
            return tf.cast(output, tf.complex128)