In [16]:
import numpy as np
import matplotlib.pyplot as plt

from jlinops import _CustomLinearOperator
import jlinops

In [17]:
class Dirichlet2D(_CustomLinearOperator):
    """Matrix-free anisotropic discrete gradient R on a 2D grid with Dirichlet boundary conditions.

    The input vector x (length M*N) is reshaped to an (M, N) grid X. The operator
    applies forward differences along each axis, treating the value just past the
    last row / last column as zero:

        dx[i, j] = X[i+1, j] - X[i, j]   for i < M-1,      dx[M-1, j] = -X[M-1, j]
        dy[i, j] = X[i, j+1] - X[i, j]   for j < N-1,      dy[i, N-1] = -X[i, N-1]

    The result is the concatenation ``[dx.ravel(), dy.ravel()]`` of length 2*M*N,
    so the operator has shape (2*M*N, M*N). Because each 1D difference matrix is
    upper bidiagonal with -1 on its diagonal, it is invertible; R therefore has
    full column rank and a trivial null space (unlike the Neumann variant, whose
    null space contains the constant vector).

    Parameters
    ----------
    grid_shape : tuple of int
        The (M, N) shape of the underlying grid.
    device : str, optional
        ``"cpu"`` (default) uses NumPy. Any other value selects the GPU path,
        which requires CuPy to be importable.

    Raises
    ------
    ImportError
        If a non-"cpu" device is requested and CuPy cannot be imported.
    """
    def __init__(self, grid_shape, device="cpu"):

        # Handle grid shape
        self.grid_shape = grid_shape
        self.M, self.N = self.grid_shape
        shape = (2 * self.M * self.N, self.M * self.N)

        # Select the array module once so both devices share one implementation.
        if device == "cpu":
            xp = np
        else:
            # Imported lazily so that CPU-only environments never need CuPy.
            try:
                import cupy as xp
            except ImportError as err:
                raise ImportError(
                    "Dirichlet2D with device != 'cpu' requires CuPy to be installed."
                ) from err

        def matvec(x):
            # Reshape the vector into a 2D grid
            grid = x.reshape(self.M, self.N)

            # Forward difference along axis 0; the last row sees a zero beyond the boundary
            dx = xp.zeros_like(grid)
            dx[:-1, :] = grid[1:, :] - grid[:-1, :]
            dx[-1, :] = -grid[-1, :]

            # Forward difference along axis 1; the last column sees a zero beyond the boundary
            dy = xp.zeros_like(grid)
            dy[:, :-1] = grid[:, 1:] - grid[:, :-1]
            dy[:, -1] = -grid[:, -1]

            # Flatten and stack the two derivative fields
            return xp.hstack((dx.ravel(), dy.ravel()))

        def rmatvec(y):
            # Split the vector into the two derivative fields and restore the grid layout
            dx, dy = xp.split(y, 2)
            dx = dx.reshape(self.M, self.N)
            dy = dy.reshape(self.M, self.N)

            # Backward differences dx[i] - dx[i-1] (with dx[-1] taken as 0); the
            # adjoint of the forward difference is the negative of this quantity.
            dxt = dx.copy()
            dxt[1:, :] -= dx[:-1, :]

            dyt = dy.copy()
            dyt[:, 1:] -= dy[:, :-1]

            # Combine, flatten, and apply the sign correction
            return -(dxt + dyt).ravel()

        super().__init__(shape, matvec, rmatvec, device=device)


    def to_gpu(self):
        """Return a new Dirichlet2D on the same grid using the GPU path."""
        return Dirichlet2D(self.grid_shape, device="gpu")

    def to_cpu(self):
        """Return a new Dirichlet2D on the same grid using the CPU (NumPy) path."""
        return Dirichlet2D(self.grid_shape, device="cpu")



In [18]:
# Instantiate the operator on a small 4-by-5 grid (device defaults to "cpu").
R = Dirichlet2D(grid_shape=(4, 5))

In [19]:
# Check that R's rmatvec is consistent with its matvec as an adjoint pair
# (presumably a numerical inner-product test; confirm against jlinops.check_adjoint).
jlinops.check_adjoint(R)

True

In [1]:
from jlinops import Dirichlet2D