In [270]:
%%capture
import numpy as np
from scipy.linalg import qr as scipy_qr
from scipy.linalg import solve_triangular as scipy_solve_triangular
# from scipy.sparse.linalg._interface import MatrixLinearOperator, _CustomLinearOperator
import scipy.sparse as sps
import scipy.sparse.linalg as splinalg
import math

from scipy.linalg import null_space

# from .matrix import MatrixOperator, SparseMatrixOperator
# from .util import banded_cholesky_factorization
# from .diagonal import DiagonalOperator
# from .derivatives import DiscreteGradientNeumann2D
# from .dct import build_dct_Lpinv

import cupy as cp

from jlinops import MatrixLinearOperator, MatrixLinearOperator, cg, _CustomLinearOperator
from jlinops import banded_cholesky
from jlinops import issparse, tosparse
from jlinops import scipy_superlu_to_cupy_superlu
from jlinops import check_adjoint
from jlinops import dct_sqrt_pinv

In [2]:
A = np.random.normal(size=(40,40))

In [3]:
tosparse(A)

<40x40 sparse matrix of type '<class 'numpy.float64'>'
	with 1600 stored elements in Compressed Sparse Column format>

In [4]:
A = cp.random.normal(size=(40,40))

In [5]:
tosparse(A)

<cupyx.scipy.sparse._csc.csc_matrix at 0x7f18be998910>

# Figure out superlu

In [13]:
from cupyx.scipy.sparse.linalg import SuperLU as cp_SuperLU

In [14]:
A = np.random.normal(size=(40,40))
A = tosparse(A)

In [15]:
# Sparse LU
LU = splinalg.splu(A, diag_pivot_thresh=0, permc_spec="NATURAL") 

In [16]:
cp_SuperLU(LU)

<cupyx.scipy.sparse.linalg._solve.SuperLU at 0x7f18b1b875d0>

In [80]:
class BandedCholeskyPinvOperator(_CustomLinearOperator):
    """Approximate pseudo-inverse of a (possibly non-square) MatrixLinearOperator A.

    The forward action is ``x -> (A^T A + delta*I)^{-1} A^T x``; the shifted
    normal-equations matrix is factorized once with ``banded_cholesky`` and the
    resulting SuperLU object is reused for every matvec/rmatvec. This is most
    efficient if A^T A is sparse and (already) banded.

    Parameters
    ----------
    A : MatrixLinearOperator
        Operator of shape (k, n); the resulting operator has shape (n, k).
    delta : float
        Diagonal shift added to A^T A before it is factorized.
    _superlu : optional
        Precomputed factorization of ``A^T A + delta*I`` exposing
        ``solve(b, trans=...)`` and living on the same device as ``A``. When
        given, the factorization step is skipped. Intended for internal use by
        ``to_gpu`` / ``to_cpu``.
    """

    def __init__(self, A, delta=1e-3, _superlu=None):

        assert isinstance(A, MatrixLinearOperator), "Must give MatrixOperator as an input."

        # Bind
        self.original_op = A
        self.original_shape = A.shape
        self.delta = delta

        # Device
        device = A.device

        # Original shape
        k, n = A.shape

        # Host-side factorization, kept so that device transfers can reuse it.
        self._superlu_cpu = None

        if _superlu is None:

            # Enforce that the underlying A^T A is a sparse type
            if not issparse(A.A):
                AtA = MatrixLinearOperator( tosparse(A.A.T @ A.A).tocsc() )
            else:
                AtA = MatrixLinearOperator(A.A.T @ A.A)

            # Even if on GPU, factorize and make the superlu object on CPU
            AtA_cpu = AtA if device == "cpu" else AtA.to_cpu()

            # Matrix we will factorize
            mat = AtA_cpu.A + self.delta*sps.eye(n)

            # Perform factorization (the explicit Cholesky factor is not needed here)
            _, superlu = banded_cholesky( mat )
            self._superlu_cpu = superlu

            # Make GPU superlu object if applicable
            if device == "gpu":
                superlu = cp_SuperLU(superlu)

        else:
            # Caller supplied a factorization that already lives on A's device.
            superlu = _superlu
            if device == "cpu":
                self._superlu_cpu = _superlu

        # Bind superlu and A
        self.superlu = superlu
        self.A = A

        # Build matvec and rmatvec
        def _matvec(x):
            tmp = self.A.rmatvec(x)
            tmp = self.superlu.solve(tmp, trans="N")
            return tmp

        def _rmatvec(x):
            tmp = self.superlu.solve(x, trans="T")
            tmp = self.A.matvec(tmp)
            return tmp

        super().__init__( (n,k), _matvec, _rmatvec, device=device, dtype=A.dtype)


    def to_gpu(self):
        """Return an equivalent operator on the GPU.

        Reuses the stored CPU factorization (converted to a CuPy SuperLU) when one
        is available; otherwise the new operator factorizes again.
        """
        gpu_superlu = None
        if self._superlu_cpu is not None:
            gpu_superlu = cp_SuperLU(self._superlu_cpu)

        # Use the stored operator rather than whatever global `A` happens to exist.
        return BandedCholeskyPinvOperator(self.original_op.to_gpu(), delta=self.delta, _superlu=gpu_superlu)


    def to_cpu(self):
        """Return an equivalent operator on the CPU, reusing the CPU factorization if stored."""
        return BandedCholeskyPinvOperator(self.original_op.to_cpu(), delta=self.delta, _superlu=self._superlu_cpu)


# Test it out

In [81]:
# Dense Gaussian test matrix; `Amat` stays dense for the reference pinv below,
# while `A` wraps a sparse copy of it as a linear operator.
Amat = np.random.normal(size=(40, 40))
A = MatrixLinearOperator(tosparse(Amat.copy()))

In [86]:
Apinv = BandedCholeskyPinvOperator(A, delta=1e-13)

In [87]:
u = np.random.normal(size=Apinv.shape[1])
np.linalg.norm( (Apinv @ u) - (np.linalg.pinv(Amat) @ u) )

8.030267318020561e-10

In [88]:
check_adjoint(Apinv)

True

In [89]:
Apinv = Apinv.to_gpu()

In [90]:
u = cp.random.normal(size=Apinv.shape[1])
cp.linalg.norm( (Apinv @ u) - (cp.asarray(np.linalg.pinv(Amat)) @ u) )

array(1.46050069e-09)

In [91]:
check_adjoint(Apinv)

True

In [92]:
# Same setup as the CPU test, but the matrix is generated directly on the GPU.
Amat = cp.random.normal(size=(40, 40))
A = MatrixLinearOperator(tosparse(Amat.copy()))

In [93]:
Apinv = BandedCholeskyPinvOperator(A, delta=1e-13)

In [94]:
u = cp.random.normal(size=Apinv.shape[1])
cp.linalg.norm( (Apinv @ u) - (cp.asarray(np.linalg.pinv(Amat)) @ u) )

array(1.42238999e-10)

# QR Pseudoinverse

In [100]:
from cupy.linalg import qr as cp_qr

In [139]:
from scipy.linalg import qr as sp_qr
from cupy.linalg import qr as cp_qr
from scipy.linalg import solve_triangular as sp_solve_triangular
from cupyx.scipy.linalg import solve_triangular as cp_solve_triangular


class QRPinvOperator(_CustomLinearOperator):
    """Pseudo-inverse of a dense matrix A with full column rank, via the QR method.

    With the reduced factorization A = Q R, the forward action is
    ``x -> R^{-1} Q^T x`` and the adjoint action is ``x -> Q R^{-T} x``.

    Parameters
    ----------
    A : MatrixLinearOperator
        Operator of shape (k, n) wrapping a dense array; the resulting operator
        has shape (n, k).
    """

    def __init__(self, A):

        assert isinstance(A, MatrixLinearOperator), "must give MatrixOperator as an input."

        # Store original operator
        self.original_op = A
        k, n = A.shape

        # Device
        device = A.device

        # Pick the factorization and triangular solver matching the device.
        if device == "cpu":
            Q_fac, R_fac = sp_qr(A.A, mode="economic")
            solve_triangular = sp_solve_triangular
        else:
            # economic is deprecated
            Q_fac, R_fac = cp_qr(A.A, mode="reduced")
            solve_triangular = cp_solve_triangular

        # Build matvec and rmatvec
        def _matvec(vec):
            tmp = Q_fac.T @ vec
            return solve_triangular(R_fac, tmp, lower=False)

        def _rmatvec(vec):
            tmp = solve_triangular(R_fac.T, vec, lower=True)
            return Q_fac @ tmp

        super().__init__( (n, k), _matvec, _rmatvec , device=device)


    def to_gpu(self):
        """Return an equivalent operator built from the GPU copy of the original operator."""
        return QRPinvOperator(self.original_op.to_gpu())

    def to_cpu(self):
        """Return an equivalent operator built from the CPU copy of the original operator."""
        return QRPinvOperator(self.original_op.to_cpu())


In [140]:
# Tall random test matrix wrapped as an operator, and its QR-based pseudo-inverse.
Wmat = np.random.normal(size=(30,3))
W = MatrixLinearOperator(Wmat)
Wpinv = QRPinvOperator(W)

# Compare the operator's action with NumPy's dense pseudo-inverse on a random vector;
# the displayed value is the norm of the difference.
u = np.random.normal(size=Wpinv.shape[1])
np.linalg.norm( (Wpinv @ u) - (np.linalg.pinv(Wmat) @ u) )

1.2719202621569003e-16

In [141]:
Wpinv = Wpinv.to_gpu()

In [142]:
u = cp.random.normal(size=Wpinv.shape[1])
cp.linalg.norm( (Wpinv @ u) - (cp.asarray(np.linalg.pinv(Wmat)) @ u) )

array(6.35960131e-17)

# CG Pseudoinverse

In [149]:
from jlinops import cg as jlinops_cg
from scipy.sparse.linalg import cg as scipy_cg
from cupyx.scipy.sparse.linalg import cg as cupy_cg

In [252]:
class CGPinvOperator(_CustomLinearOperator):
    """Linear operator that approximately applies the pseudoinverse of A using a
    conjugate gradient method.

    The forward action solves ``A^T A z = A^T x``; the adjoint action solves
    ``A A^T z = A x``.

    Parameters
    ----------
    A : linear operator
        Operator of shape (m, n) exposing ``device``, ``dtype``, ``T``,
        ``matvec`` and ``rmatvec``; the resulting operator has shape (n, m).
    warmstart_prev : bool
        If True, each solve starts from the previous solution of the same kind
        (forward or adjoint); otherwise every solve starts from zero.
    which : {"jlinops", "scipy"}
        CG backend. "scipy" uses SciPy's ``cg`` on the CPU and CuPy's ``cg`` otherwise.
    check : bool
        If True, assert that CG reported convergence.
    *args, **kwargs
        Forwarded to the CG routine.
    """

    def __init__(self, A, warmstart_prev=False, which="jlinops", check=False, *args, **kwargs):

        assert which in ["jlinops", "scipy"], "Invalid choice for which!"

        # Store operator
        self.original_op = A
        self.A = A

        # Device
        device = A.device

        # Shape
        m, n = A.shape
        shape = (n, m)

        # Setup
        self.which = which
        self.warmstart_prev = warmstart_prev
        self.check = check
        self.args = args
        self.kwargs = kwargs
        self.in_shape = A.shape[0]
        self.out_shape = A.shape[1]

        # Initial guesses live on the same device as A.
        xp = np if device == "cpu" else cp
        self.prev_eval = xp.zeros(self.out_shape)
        self.prev_eval_t = xp.zeros(self.in_shape)

        # Build both operators we need
        self.AtA = self.original_op.T @ self.original_op
        self.AAt = self.original_op @ self.original_op.T

        solve = self._make_cg_solver(device)

        def _matvec(x):
            sol = solve(self.AtA, self.A.rmatvec(x), self.prev_eval)
            if self.warmstart_prev:
                self.prev_eval = sol.copy()
            return sol

        def _rmatvec(x):
            sol = solve(self.AAt, self.A.matvec(x), self.prev_eval_t)
            if self.warmstart_prev:
                self.prev_eval_t = sol.copy()
            return sol

        super().__init__( shape, _matvec, _rmatvec, device=device, dtype=self.A.dtype)


    def _make_cg_solver(self, device):
        """Return ``solve(op, rhs, x0) -> solution`` for the configured backend and device.

        The returned callable also performs the optional convergence check.
        """
        args, kwargs = self.args, self.kwargs

        if self.which == "jlinops":

            def solve(op, rhs, x0):
                solver_data = jlinops_cg(op, rhs, *args, x0=x0, **kwargs)
                sol = solver_data["x"]
                if self.check:
                    assert solver_data["converged"], "CG algorithm did not converge"
                return sol

        elif self.which == "scipy":

            backend_cg = scipy_cg if device == "cpu" else cupy_cg

            def solve(op, rhs, x0):
                sol, converged = backend_cg(op, rhs, *args, x0=x0, **kwargs)
                if self.check:
                    assert converged == 0, "CG algorithm did not converge!"
                return sol

        else:
            raise NotImplementedError

        return solve


    def to_gpu(self):
        """Return an equivalent operator built from the GPU copy of A."""
        # Options are passed positionally so stored *args do not collide with them.
        return CGPinvOperator(self.A.to_gpu(), self.warmstart_prev, self.which, self.check, *self.args, **self.kwargs)

    def to_cpu(self):
        """Return an equivalent operator built from the CPU copy of A."""
        return CGPinvOperator(self.A.to_cpu(), self.warmstart_prev, self.which, self.check, *self.args, **self.kwargs)


In [253]:
class CGModPinvOperator(_CustomLinearOperator):
    """Linear operator that approximately applies the pseudoinverse of A using a
    conjugate gradient method, modified so that it only ever solves systems with A^T A.

    The forward action solves ``A^T A z = A^T x``. The adjoint action first removes
    the component of x in span(W), then solves ``A^T A z = x`` and returns ``A z``.

    Parameters
    ----------
    A : linear operator
        Operator of shape (m, n); the resulting operator has shape (n, m).
    W : LinearOperator
        Represents a matrix with linearly independent columns that span null(A).
    Wpinv : LinearOperator
        Represents the pseudoinverse of W.
    warmstart_prev : bool
        If True, each solve starts from the previous solution of the same kind.
    which : {"jlinops", "scipy"}
        CG backend. "scipy" uses SciPy's ``cg`` on the CPU and CuPy's ``cg`` otherwise.
    check : bool
        If True, assert that CG reported convergence.
    *args, **kwargs
        Forwarded to the CG routine.
    """

    def __init__(self, A, W, Wpinv, warmstart_prev=False, which="jlinops", check=False, *args, **kwargs):

        assert which in ["jlinops", "scipy"], "Invalid choice for which!"

        # Store operator
        self.original_op = A
        self.A = A
        self.W = W
        self.Wpinv = Wpinv

        # Device
        device = A.device

        # Shape
        m, n = A.shape
        shape = (n, m)

        # Setup
        self.which = which
        self.warmstart_prev = warmstart_prev
        self.check = check
        self.args = args
        self.kwargs = kwargs
        self.in_shape = A.shape[0]
        self.out_shape = A.shape[1]

        # Both solves are with A^T A (size n), so both initial guesses have length
        # n = A.shape[1]. Sizing the adjoint guess by A.shape[0] breaks non-square A.
        xp = np if device == "cpu" else cp
        self.prev_eval = xp.zeros(self.out_shape)
        self.prev_eval_t = xp.zeros(self.out_shape)

        # Build the operators; AAt is not used by the solves but is kept as an attribute.
        self.AtA = self.original_op.T @ self.original_op
        self.AAt = self.original_op @ self.original_op.T

        solve = self._make_cg_solver(device)

        def _matvec(x):
            sol = solve(self.AtA, self.A.rmatvec(x), self.prev_eval)
            if self.warmstart_prev:
                self.prev_eval = sol.copy()
            return sol

        def _rmatvec(x):

            # Project x onto range(A^T A) = range(A^T).
            x = x - (W @ (Wpinv @ x))

            sol = solve(self.AtA, x, self.prev_eval_t)
            if self.warmstart_prev:
                self.prev_eval_t = sol.copy()
            return self.A @ sol

        super().__init__( shape, _matvec, _rmatvec, device=device, dtype=self.A.dtype)


    def _make_cg_solver(self, device):
        """Return ``solve(op, rhs, x0) -> solution`` for the configured backend and device.

        The returned callable also performs the optional convergence check.
        """
        args, kwargs = self.args, self.kwargs

        if self.which == "jlinops":

            def solve(op, rhs, x0):
                solver_data = jlinops_cg(op, rhs, *args, x0=x0, **kwargs)
                sol = solver_data["x"]
                if self.check:
                    assert solver_data["converged"], "CG algorithm did not converge"
                return sol

        elif self.which == "scipy":

            backend_cg = scipy_cg if device == "cpu" else cupy_cg

            def solve(op, rhs, x0):
                sol, converged = backend_cg(op, rhs, *args, x0=x0, **kwargs)
                if self.check:
                    assert converged == 0, "CG algorithm did not converge!"
                return sol

        else:
            raise NotImplementedError

        return solve


    def to_gpu(self):
        """Return an equivalent operator built from GPU copies of A, W and Wpinv."""
        # Options are passed positionally so stored *args do not collide with them.
        return CGModPinvOperator(self.A.to_gpu(), self.W.to_gpu(), self.Wpinv.to_gpu(), self.warmstart_prev, self.which, self.check, *self.args, **self.kwargs)

    def to_cpu(self):
        """Return an equivalent operator built from CPU copies of A, W and Wpinv."""
        return CGModPinvOperator(self.A.to_cpu(), self.W.to_cpu(), self.Wpinv.to_cpu(), self.warmstart_prev, self.which, self.check, *self.args, **self.kwargs)


In [254]:
# Symmetric 40x40 test matrix that is singular by construction:
# two columns of the tall factor are identical before forming the Gram matrix.
Amat = np.random.normal(size=(60, 40))
Amat[:, -1] = Amat[:, -2]
Amat = Amat.T @ Amat
A = MatrixLinearOperator(Amat.copy())

In [255]:
Apinv = CGPinvOperator(A, tol=1e-9, which="scipy")
check_adjoint(Apinv)

True

In [256]:
u = np.random.normal(size=Apinv.shape[1])
np.linalg.norm( (Apinv @ u) - (np.linalg.pinv(Amat) @ u) )

1.0319538686917293e-10

In [257]:
u = np.random.normal(size=Apinv.shape[0])
np.linalg.norm( (Apinv.T @ u) - (np.linalg.pinv(Amat).T @ u) )

7.396343545536628e-11

In [258]:
Apinv = Apinv.to_gpu()
check_adjoint(Apinv)

True

In [259]:
u = cp.random.normal(size=Apinv.shape[1])
cp.linalg.norm( (Apinv @ u) - (cp.asarray(np.linalg.pinv(Amat)) @ u) )

array(9.11456091e-11)

In [260]:
u = cp.random.normal(size=Apinv.shape[0])
cp.linalg.norm( (Apinv.T @ u) - (cp.asarray(np.linalg.pinv(Amat).T) @ u) )

array(6.31010465e-11)

In [261]:
# Fresh singular symmetric test matrix for the modified-CG operator
# (duplicate column in the tall factor, then form the Gram matrix).
Amat = np.random.normal(size=(60, 40))
Amat[:, -1] = Amat[:, -2]
Amat = Amat.T @ Amat
A = MatrixLinearOperator(Amat.copy())

In [262]:
# Basis for null(Amat) wrapped as an operator, plus its QR-based pseudo-inverse.
W = MatrixLinearOperator(null_space(Amat))
Wpinv = QRPinvOperator(W)

In [263]:
Apinv = CGModPinvOperator(A, W, Wpinv, tol=1e-9, which="scipy")
check_adjoint(Apinv)

True

In [264]:
u = np.random.normal(size=Apinv.shape[1])
np.linalg.norm( (Apinv @ u) - (np.linalg.pinv(Amat) @ u) )

1.2282461368355753e-11

In [265]:
u = np.random.normal(size=Apinv.shape[0])
np.linalg.norm( (Apinv.T @ u) - (np.linalg.pinv(Amat).T @ u) )

2.2379468240116435e-11

In [266]:
Apinv = Apinv.to_gpu()
check_adjoint(Apinv)

True

In [267]:
u = cp.random.normal(size=Apinv.shape[1])
cp.linalg.norm( (Apinv @ u) - (cp.asarray(np.linalg.pinv(Amat)) @ u) )

array(6.81249161e-12)

In [268]:
u = cp.random.normal(size=Apinv.shape[0])
cp.linalg.norm( (Apinv.T @ u) - (cp.asarray(np.linalg.pinv(Amat).T) @ u) )

array(2.82265432e-11)

# Preconditioned CG Pseudoinverse

In [286]:
class CGPreconditionedPinvOperator(_CustomLinearOperator):
    """Linear operator that approximately applies the pseudoinverse of A using a
    preconditioned conjugate gradient method that only ever solves systems
    involving A^T A.

    CG is run on ``Q = Lpinv @ A^T A @ Lpinv^T``. The forward action solves
    ``Q z = Lpinv A^T x`` and returns ``Lpinv^T z``. The adjoint action first removes
    the component of x in span(W), solves ``Q z = Lpinv x`` and returns ``A Lpinv^T z``.

    Parameters
    ----------
    A : linear operator
        Operator of shape (m, n); the resulting operator has shape (n, m).
    W : LinearOperator
        Represents a matrix with linearly independent columns that span null(A).
    Wpinv : LinearOperator
        Represents the pseudoinverse of W.
    Lpinv : LinearOperator
        Preconditioning operator applied on both sides of A^T A (as ``Lpinv`` and
        ``Lpinv.T``). Presumably the pseudoinverse of a factor L with
        L L^T ~ A^T A -- confirm against the caller.
    warmstart_prev : bool
        If True, each solve starts from the previous solution of the same kind.
    check : bool
        If True, assert that CG reported convergence.
    which : {"jlinops", "scipy"}
        CG backend. "scipy" uses SciPy's ``cg`` on the CPU and CuPy's ``cg`` otherwise.
    *args, **kwargs
        Forwarded to the CG routine.
    """

    def __init__(self, A, W, Wpinv, Lpinv, warmstart_prev=True, check=False, which="jlinops", *args, **kwargs):

        assert which in ["jlinops", "scipy"], "Invalid choice for which!"

        # Device
        device = A.device

        # Store operator
        self.A = A
        self.W = W
        self.Wpinv = Wpinv
        self.Lpinv = Lpinv
        self.Ltpinv = Lpinv.T

        # Shape
        m, n = A.shape
        shape = (n, m)

        # Setup
        self.which = which
        self.check = check
        self.warmstart_prev = warmstart_prev
        # Stored so that to_gpu / to_cpu can rebuild the operator with the same solver options.
        self.args = args
        self.kwargs = kwargs
        self.in_shape = self.A.shape[0]
        self.out_shape = self.A.shape[1]

        # Initial guesses live on the same device as A.
        xp = np if device == "cpu" else cp
        self.prev_eval = xp.zeros(self.out_shape)
        self.prev_eval_t = xp.zeros(self.out_shape)

        # Build both operators we need
        self.AtA = self.A.T @ self.A
        self.Q = self.Lpinv @ self.AtA @ self.Ltpinv

        solve = self._make_cg_solver(device)

        def _matvec(x):
            sol = solve(self.Q, self.Lpinv @ (self.A.rmatvec(x)), self.prev_eval)
            if self.warmstart_prev:
                self.prev_eval = sol.copy()
            return self.Ltpinv @ sol

        def _rmatvec(x):

            # Project x onto range(A^T A) = range(A^T).
            x = x - (W @ (Wpinv @ x))

            sol = solve(self.Q, self.Lpinv @ x, self.prev_eval_t)
            if self.warmstart_prev:
                self.prev_eval_t = sol.copy()
            return self.A @ (self.Ltpinv @ sol)

        super().__init__( shape, _matvec, _rmatvec, dtype=np.float64, device=device)


    def _make_cg_solver(self, device):
        """Return ``solve(op, rhs, x0) -> solution`` for the configured backend and device.

        The returned callable also performs the optional convergence check.
        """
        args, kwargs = self.args, self.kwargs

        if self.which == "jlinops":

            def solve(op, rhs, x0):
                solver_data = jlinops_cg(op, rhs, *args, x0=x0, **kwargs)
                sol = solver_data["x"]
                if self.check:
                    assert solver_data["converged"], "CG algorithm did not converge"
                return sol

        elif self.which == "scipy":

            backend_cg = scipy_cg if device == "cpu" else cupy_cg

            def solve(op, rhs, x0):
                sol, converged = backend_cg(op, rhs, *args, x0=x0, **kwargs)
                if self.check:
                    assert converged == 0, "CG algorithm did not converge!"
                return sol

        else:
            raise NotImplementedError

        return solve


    def to_gpu(self):
        """Return an equivalent operator built from GPU copies of A, W, Wpinv and Lpinv."""
        # Options are passed positionally (warmstart_prev, check, which) so stored
        # *args do not collide with them.
        return CGPreconditionedPinvOperator(self.A.to_gpu(), self.W.to_gpu(), self.Wpinv.to_gpu(), self.Lpinv.to_gpu(), self.warmstart_prev, self.check, self.which, *self.args, **self.kwargs)


    def to_cpu(self):
        """Return an equivalent operator built from CPU copies of A, W, Wpinv and Lpinv."""
        return CGPreconditionedPinvOperator(self.A.to_cpu(), self.W.to_cpu(), self.Wpinv.to_cpu(), self.Lpinv.to_cpu(), self.warmstart_prev, self.check, self.which, *self.args, **self.kwargs)
    

In [304]:

class CGWeightedNeumann2DPinvOperator(_CustomLinearOperator):
    r"""Represents the pseudoinverse (R_w)^\dagger of a linear operator R_w = D_w R, where
    D_w is a diagonal matrix of weights and R is a Neumann2D discrete gradient operator.
    Here matvecs/rmatvecs are applied approximately using a preconditioned conjugate
    gradient method, where the preconditioner is based on the operator with identity weights.

    Parameters
    ----------
    grid_shape : tuple of int
        Shape of the 2D grid; ``len(weights)`` must equal ``2 * prod(grid_shape)``.
    weights : array
        Diagonal of D_w. The device of the operator is taken from this array.
    warmstart_prev, check, which, *args, **kwargs
        Forwarded to ``CGPreconditionedPinvOperator``.
    """

    def __init__(self, grid_shape, weights, warmstart_prev=True, check=False, which="jlinops", *args, **kwargs):

        assert 2*math.prod(grid_shape) == len(weights), "Weights incompatible!"
        self.weights = weights
        self.grid_shape = grid_shape
        self.warmstart_prev = warmstart_prev
        self.check = check
        self.which = which
        self.args = args
        self.kwargs = kwargs

        # Figure out device
        device = get_device(weights)

        # Build R_w
        self.R = Neumann2D(grid_shape, device=device)
        self.Dw = DiagonalOperator(weights)
        self.Rw = self.Dw @ self.R

        # Get Rpinv (with identity weights)
        self.Rpinv = dct_sqrt_pinv(self.R.T @ self.R, grid_shape)

        # Take care of W (columns span the kernel of R): the constant vector
        xp = np if device == "cpu" else cp
        W = xp.ones((self.R.shape[1],1))

        self.W = MatrixLinearOperator(W)
        self.Wpinv = QRPinvOperator(self.W)

        # Make Rwpinv
        self.Rwpinv = CGPreconditionedPinvOperator(self.Rw, self.W, self.Wpinv, self.Rpinv, warmstart_prev, check, which, *args, **kwargs)

        def _matvec(x):
            return self.Rwpinv @ x

        def _rmatvec(x):
            return self.Rwpinv.T @ x

        super().__init__( self.Rwpinv.shape, _matvec, _rmatvec, dtype=np.float64, device=device)


    def to_gpu(self):
        """Return an equivalent operator rebuilt from a GPU copy of the weights."""
        # Options are passed positionally (warmstart_prev, check, which) so stored
        # *args do not collide with them.
        return CGWeightedNeumann2DPinvOperator(self.grid_shape, cp.asarray(self.weights), self.warmstart_prev, self.check, self.which, *self.args, **self.kwargs)

    def to_cpu(self):
        """Return an equivalent operator rebuilt from a host (NumPy) copy of the weights."""
        # cp.asnumpy is the CuPy device-to-host conversion; cp.numpy does not exist.
        return CGWeightedNeumann2DPinvOperator(self.grid_shape, cp.asnumpy(self.weights), self.warmstart_prev, self.check, self.which, *self.args, **self.kwargs)


In [305]:
from jlinops import get_device, Neumann2D, DiagonalOperator

In [351]:
# Grid dimensions for the benchmark problem.
M, N = 1000, 1000
# Random positive weights; the operator's constructor requires len(weights) == 2*M*N.
weights = np.random.uniform(low=1e-1, high=1e1, size=2*M*N)
grid_shape = (M,N)
# Build the weighted pseudo-inverse using the SciPy/CuPy CG backend; `tol` is forwarded to CG.
A = CGWeightedNeumann2DPinvOperator(grid_shape, weights, which="scipy", tol=1e-7)
#A = CGWeightedNeumann2DPinvOperator(grid_shape, weights, which="jlinops", eps=1e-7)
#check_adjoint(A)

In [342]:
%%timeit
u = np.random.normal(size=A.shape[1])
_ = A @ u

43.2 s ± 823 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [352]:
A = A.to_gpu()

In [353]:
check_adjoint(A)

True

In [350]:
%%timeit
u = cp.random.normal(size=A.shape[1])
_ = A @ u

1.15 s ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [344]:
%%timeit
u = cp.random.normal(size=A.shape[1])
_ = A @ u

425 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [349]:
43200/425

101.6470588235294

In [340]:
26200/255

102.74509803921569

In [334]:
9190/216

42.5462962962963

In [325]:
u = np.random.normal(size=A.shape[1])
A @ u

array([0.20068115, 0.20499076, 0.18600398, ..., 0.18497829, 0.37659145,
       0.03446162])

In [326]:
A = A.to_gpu()

In [327]:
u = cp.random.normal(size=A.shape[1])
A @ u

array([0.35298313, 0.11578183, 0.15125623, ..., 0.19635416, 0.34345851,
       0.82423415])

In [328]:
check_adjoint(A)

True

# Old version