In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jlinops

import numpy as np
import math
from scipy.fft import dstn as sp_dstn
from scipy.fft import idstn as sp_idstn
from scipy.sparse.linalg import LinearOperator
# from scipy.sparse.linalg._interface import _CustomLinearOperator


from jlinops import _CustomLinearOperator, DiagonalOperator, get_device

# from ..base import _CustomLinearOperator
# from ..diagonal import DiagonalOperator
# from ..util import get_device

from jlinops import CUPY_INSTALLED
# from .. import CUPY_INSTALLED
if CUPY_INSTALLED:
    import cupy as cp
    from cupyx.scipy.fft import dctn as cp_dctn
    from cupyx.scipy.fft import idctn as cp_idctn


In [2]:
# Scratch check: joining two 1-D ranges end to end (for 1-D inputs,
# np.concatenate and np.hstack give the same result).
v1, v2 = np.arange(5), np.arange(8)
np.concatenate((v1, v2))

array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7])

In [3]:
# Integer index array used to try out fancy indexing in the next cell.
b = np.array([1, 2])


In [4]:
# Fancy indexing: the integer index array b selects the entries of v1 at
# positions 1 and 2 (fancy indexing returns a copy, not a view).
v1[b]

array([1, 2])

In [5]:
# A single-argument arange starts at 0, i.e. this is arange(0, 4).
np.arange(0, 4)

array([0, 1, 2, 3])

In [6]:
# arange(start, start + length) yields `length` consecutive integers from
# `start`; the stop value itself is excluded.
np.arange(3, 3 + 3)

array([3, 4, 5])

In [18]:
class BlockDiagonalOperator(_CustomLinearOperator):
    r"""Block diagonal operator diag(A_1, \ldots, A_p).

    Given square operators A_1, \ldots, A_p with A_i of shape n_i \times n_i,
    represents the square operator of size n = n_1 + \cdots + n_p that applies
    A_i to the i-th contiguous chunk (of length n_i) of its input and stacks
    the results in order.

    Parameters
    ----------
    As : iterable of operators
        Square operators exposing ``shape``, ``matvec`` and ``rmatvec``
        (and ``to_gpu`` / ``to_cpu`` for the device-transfer methods).
    device : str, optional
        Device tag forwarded to ``_CustomLinearOperator`` ("cpu" or "gpu").
    type : int, optional
        Unused; accepted only for backward compatibility.

    Raises
    ------
    ValueError
        If ``As`` is empty.
    AssertionError
        If some A_i is not square (checked with ``assert``).
    """

    def __init__(self, As, device="cpu", type=1):

        # Setup
        self.As = list(As)  # list of operators (any iterable is accepted)
        self.p = len(self.As)  # number of operators
        if self.p == 0:
            raise ValueError("BlockDiagonalOperator requires at least one operator.")

        self.ns = []  # block sizes n_i
        self.idxs = []  # integer positions of each block, kept for inspection
        self._slices = []  # the same positions as slices, used by the matvecs
        start = 0
        for op in self.As:
            nj = op.shape[0]
            assert op.shape[1] == nj, "Not all A_i are square!"
            self.ns.append(nj)
            self.idxs.append(np.arange(start, start + nj))
            self._slices.append(slice(start, start + nj))
            start += nj

        # Set shape (start now equals sum(self.ns))
        shape = (start, start)

        def _stack(pieces):
            # Join block outputs along axis 0. Unlike np.hstack, this stacks
            # (n_i, 1) column-vector pieces into an (n, 1) column instead of
            # trying to place them side by side along axis 1.
            xp = cp if device == "gpu" else np
            return xp.concatenate(pieces, axis=0)

        # Define matvecs. Basic slicing returns views, so no input chunk is
        # copied (fancy indexing with self.idxs would copy each chunk).
        def _matvec(x):
            return _stack([op.matvec(x[s]) for op, s in zip(self.As, self._slices)])

        def _rmatvec(x):
            return _stack([op.rmatvec(x[s]) for op, s in zip(self.As, self._slices)])

        super().__init__(shape, _matvec, _rmatvec, device=device)

    def to_gpu(self):
        """Return a new block operator built from each block's ``to_gpu()``, tagged device="gpu"."""
        return BlockDiagonalOperator([op.to_gpu() for op in self.As], device="gpu")

    def to_cpu(self):
        """Return a new block operator built from each block's ``to_cpu()``, tagged device="cpu"."""
        return BlockDiagonalOperator([op.to_cpu() for op in self.As], device="cpu")

In [19]:
# Two square blocks of different sizes, so the block boundaries matter.
n1 = 5
n2 = 9
# A local RandomState(0) draws the same numbers as np.random.seed(0) followed
# by np.random.normal, without reseeding numpy's global generator.
rng = np.random.RandomState(0)
A1 = jlinops.MatrixLinearOperator(rng.normal(size=(n1, n1)))
A2 = jlinops.MatrixLinearOperator(rng.normal(size=(n2, n2)))
# Build the operator with the class defined in the previous cell (not
# jlinops.BlockDiagonalOperator), so the checks below actually exercise it.
A = BlockDiagonalOperator([A1, A2])

In [20]:
# Test vector split into the two block pieces; the piece lengths follow
# n1 and n2 from the setup cell instead of repeating them as literals.
v1 = np.arange(n1)
v2 = np.arange(n1, n1 + n2)
v = np.hstack([v1, v2])

In [21]:
# Transpose of the full block operator applied to the stacked vector v;
# compare with A1.T @ v1 and A2.T @ v2 computed in the next two cells.
A.T @ v

array([ -9.90012702,  10.95534703,   4.21298825,  -1.8893258 ,
         7.81505625,  15.24971629, -16.3875941 ,   9.30814004,
        43.05642643, -20.51545741, -19.86233757,  22.22613981,
       -29.88856451,  -2.04708387])

In [22]:
# Transpose of the first block applied to the first piece of v.
A1.T @ v1

array([-9.90012702, 10.95534703,  4.21298825, -1.8893258 ,  7.81505625])

In [23]:
# Verify the blockwise action instead of comparing printed arrays by eye:
# A and its transpose must match applying A1 and A2 to the corresponding
# pieces of v and concatenating the results.
np.testing.assert_allclose(A @ v, np.concatenate([A1 @ v1, A2 @ v2]))
np.testing.assert_allclose(A.T @ v, np.concatenate([A1.T @ v1, A2.T @ v2]))
# Transpose of the second block applied to the second piece of v.
A2.T @ v2

array([ 15.24971629, -16.3875941 ,   9.30814004,  43.05642643,
       -20.51545741, -19.86233757,  22.22613981, -29.88856451,
        -2.04708387])