# Factoring block tridiagonal symmetric positive definite matrices.

We implement Intel's algorithm (described at https://www.intel.com/content/www/us/en/docs/onemkl/cookbook/2023-2/factor-block-tridiag-symm-pos-def-matrices.html) in the function `dpbltrf()` below.

The block tridiagonal matrix is parametrized as a list of ND blocks on the diagonal and a list of (ND-1) blocks on the subdiagonal (and their transposes on the superdiagonal). The blocks are square and all of the same size.

If the blocks are of size NB x NB, this reduces the complexity of the factorization from O([ND NB]^3) to O(ND [NB]^3).

**Note**: The Intel algorithm (Fortran) works in place -- Cholesky factorizations and triangular solves can overwrite their inputs. JAX arrays are immutable, however, so we didn't bother. You can still do it in Python by using SciPy's Cholesky and triangular solvers (which accept `overwrite_a` / `overwrite_b`) rather than the NumPy versions.

In [2]:
import numpy as np

def make_spd_matrix(n_dim, *, random_state=None):
    """Generate a random symmetric, positive-definite matrix.

    The algorithm and this docstring follow scikit-learn's
    ``make_spd_matrix``.

    Parameters
    ----------
    n_dim : int
        The matrix dimension.

    random_state : int, numpy.random.RandomState, numpy.random.Generator or None, default=None
        Source of randomness. ``None`` draws from NumPy's global random
        state (so ``np.random.seed`` controls it). An int seeds a fresh
        ``RandomState`` for reproducible output across multiple function
        calls. A ``RandomState`` or ``Generator`` instance is used as-is.

    Returns
    -------
    X : ndarray of shape (n_dim, n_dim)
        The random symmetric, positive-definite matrix.
    """
    if random_state is None:
        # Module-level functions share NumPy's global state (original behavior).
        rng = np.random
    elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    A = rng.random(size=(n_dim, n_dim))
    # A^T A is symmetric PSD, so its SVD has U == V up to rounding.
    U, _, Vt = np.linalg.svd(np.dot(A.T, A))
    # Note: 1.0 + diag(r) adds 1 to EVERY entry (all-ones matrix plus
    # diag(r)). That matrix is still positive definite, and so is
    # U (.) V^T.
    X = np.dot(np.dot(U, 1.0 + np.diag(rng.random(size=n_dim))), Vt)

    return X

# Show one random 3x3 sample; the bare expression is displayed as the cell output.
make_spd_matrix(3)

array([[0.83016087, 0.31760871, 0.39762778],
       [0.31760871, 1.44804131, 1.55298979],
       [0.39762778, 1.55298979, 2.24556612]])

In [5]:
def reconstruct_matrix_A(D, B):
    """
    Assemble the symmetric block tridiagonal matrix A from its blocks.

    The diagonal blocks come from ``D`` and the sub-diagonal blocks from
    ``B``; the super-diagonal blocks are the transposes of ``B``.

    Only the lower triangle (including the diagonal) of each diagonal block
    is read and then mirrored. Therefore:
    - a symmetric D_k is reproduced exactly;
    - a lower-triangular D_k (e.g. a Cholesky factor) appears unchanged in
      the lower triangle of the result.

    Parameters:
    - D: Diagonal blocks stacked horizontally, D = [D_0 | D_1 | ...].
      Shape: (NB, N * NB)
    - B: Sub-diagonal blocks stacked horizontally, B = [B_0 | B_1 | ...];
      B_k is placed directly below D_k.
      Shape: (NB, (N-1) * NB)

    Returns:
    - A: Symmetric block tridiagonal matrix.
      Shape: (N * NB, N * NB)
    """
    NB = D.shape[0]
    N = D.shape[1] // NB

    A = np.zeros((N * NB, N * NB))

    # Fill the lower triangle only: the lower part of each diagonal block
    # plus every sub-diagonal block. Using np.tril on the diagonal block
    # keeps its off-diagonal entries from being counted twice when the
    # matrix is mirrored below.
    for k in range(N):
        cols = slice(k * NB, (k + 1) * NB)
        A[cols, cols] = np.tril(D[:, cols])
        if k < N - 1:
            A[(k + 1) * NB:(k + 2) * NB, cols] = B[:, cols]

    # Mirror the strict lower triangle into the upper triangle.
    return A + A.T - np.diag(np.diag(A))

def extract_D_B_from_A(A, NB):
    """
    Split a block tridiagonal matrix into its diagonal and sub-diagonal blocks.

    Each block family is returned as one wide float array, with blocks laid
    side by side in order of increasing block index.

    Parameters:
    - A: Block tridiagonal matrix.
      Shape: (N * NB, N * NB)
    - NB: Edge length of each square block.

    Returns:
    - D: Diagonal blocks [D_0 | D_1 | ...].
      Shape: (NB, N * NB)
    - B: Sub-diagonal blocks [B_0 | B_1 | ...], where B_k lies below D_k.
      Shape: (NB, (N-1) * NB)
    """
    n_blocks = A.shape[0] // NB

    D = np.zeros((NB, n_blocks * NB))
    B = np.zeros((NB, (n_blocks - 1) * NB))

    for k in range(n_blocks):
        lo, hi = k * NB, (k + 1) * NB
        D[:, lo:hi] = A[lo:hi, lo:hi]

    # The block below diagonal block k starts at row (k + 1) * NB.
    for k in range(n_blocks - 1):
        lo, hi = k * NB, (k + 1) * NB
        B[:, lo:hi] = A[hi:hi + NB, lo:hi]

    return D, B

# Example usage: blocks -> dense symmetric matrix -> blocks.
# N diagonal blocks, each of size NB x NB.
N = 3
NB = 2

# Random diagonal blocks, stacked horizontally. NOTE: np.random.randn gives
# blocks that are neither symmetric nor positive definite; this cell only
# exercises the block layout.
D = np.random.randn(NB, N * NB)

# Random sub-diagonal blocks, stacked horizontally
B = np.random.randn(NB, (N - 1) * NB)

# Assemble the dense (symmetrized) block tridiagonal matrix A
original_A = reconstruct_matrix_A(D, B)

# Extract D and B from A again. A is symmetric, so the extracted diagonal
# blocks are symmetric and need not equal the random D above.
extracted_D, extracted_B = extract_D_B_from_A(original_A, NB)

# Display the results
print("Original Matrix A:")
print(original_A)
print("\nExtracted Diagonal Blocks D:")
print(extracted_D)
print("\nExtracted Sub-diagonal Blocks B:")
print(extracted_B)

Original Matrix A:
[[ 0.8597252   1.33460823 -0.92816239 -0.38895811  0.          0.        ]
 [ 1.33460823  0.44478738 -0.1566678   1.7657079   0.          0.        ]
 [-0.92816239 -0.1566678   0.4067891  -0.63108444  0.76375694  1.90686374]
 [-0.38895811  1.7657079  -0.63108444  0.1027884  -0.34411442  1.2802755 ]
 [ 0.          0.          0.76375694 -0.34411442 -0.35267889 -1.23034714]
 [ 0.          0.          1.90686374  1.2802755  -1.23034714  0.58565449]]

Extracted Diagonal Blocks D:
[[ 0.8597252   1.33460823  0.4067891  -0.63108444 -0.35267889 -1.23034714]
 [ 1.33460823  0.44478738 -0.63108444  0.1027884  -1.23034714  0.58565449]]

Extracted Sub-diagonal Blocks B:
[[-0.92816239 -0.1566678   0.76375694 -0.34411442]
 [-0.38895811  1.7657079   1.90686374  1.2802755 ]]


In [56]:
def make_random_A(N, NB, maxtries=100):
    """Draw a random symmetric positive definite block tridiagonal matrix.

    A dense random SPD matrix (from ``make_spd_matrix``, shifted by the
    identity) is truncated to its block tridiagonal band. Truncation can
    destroy positive definiteness, so draws are repeated until
    ``np.linalg.cholesky`` succeeds.

    Parameters:
    - N: Number of diagonal blocks.
    - NB: Size of each (square) block.
    - maxtries: Maximum number of draws; at least one draw is always made.

    Returns:
    - A: Symmetric matrix of shape (N * NB, N * NB). If no draw passed the
      Cholesky check within the allowed attempts, the last draw is still
      returned (it is NOT positive definite) and a RuntimeWarning is issued.
    """
    import warnings

    n = N * NB
    # block_of[i] is the block index of row/column i. Entries whose block
    # row and block column differ by at most one form the tridiagonal band.
    block_of = np.arange(n) // NB
    band = np.abs(block_of[:, None] - block_of[None, :]) <= 1

    attempts = max(1, maxtries)
    A = None
    for _ in range(attempts):
        S = make_spd_matrix(n) + np.eye(n)
        # The SVD-based generator is symmetric only up to rounding.
        S = 0.5 * (S + S.T)
        # Zero everything outside the band but keep the entries inside it
        # as they are. Unlike an extract/reconstruct round trip, this does
        # not alter any of the retained entries.
        A = np.where(band, S, 0.0)
        try:
            np.linalg.cholesky(A)
        except np.linalg.LinAlgError:  # not positive definite: draw again
            continue
        return A

    warnings.warn(
        f"make_random_A: no positive definite draw in {attempts} tries; "
        "returning a matrix that is not positive definite",
        RuntimeWarning,
        stacklevel=2,
    )
    return A

# Draw a test matrix with 3 blocks of size 2x2 and show it next to its dense Cholesky factor.
A = make_random_A(3, 2)
A, np.linalg.cholesky(A)

(array([[ 1.80932558,  0.94310436,  0.80620673,  0.24765308,  0.        ,
          0.        ],
        [ 0.94310436,  2.35000467,  1.07480342,  0.67719845,  0.        ,
          0.        ],
        [ 0.80620673,  1.07480342,  2.74564025,  2.17826138, -1.70562001,
          1.0734039 ],
        [ 0.24765308,  0.67719845,  2.17826138,  2.28508432, -1.13244158,
          0.57563354],
        [ 0.        ,  0.        , -1.70562001, -1.13244158,  3.50854882,
         -2.46820486],
        [ 0.        ,  0.        ,  1.0734039 ,  0.57563354, -2.46820486,
          2.26811303]]),
 array([[ 1.34511174,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.70113459,  1.36323694,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.59936041,  0.48015945,  1.46828275,  0.        ,  0.        ,
          0.        ],
        [ 0.18411339,  0.40206524,  1.2769037 ,  0.67753012,  0.        ,
          0.        ],
        [ 0.        ,  0.   

In [7]:
def reconstruct_matrix_L(result_D, result_B):
    """
    Assemble the dense block lower-bidiagonal factor from its blocks.

    Parameters:
    - result_D: Diagonal blocks of the factor, stacked horizontally;
      presumably lower-triangular (as returned by dpbltrf) -- confirm for
      other callers.
      Shape: (NB, N * NB)
    - result_B: Sub-diagonal blocks of the factor, stacked horizontally.
      Shape: (NB, (N-1) * NB)

    Returns:
    - Lower triangle of the assembled block matrix.
      Shape: (N * NB, N * NB)
    """
    # reconstruct_matrix_A mirrors the blocks into a symmetric matrix;
    # keeping only its lower triangle recovers the factor itself.
    return np.tril(reconstruct_matrix_A(result_D, result_B))

In [8]:
import numpy as np
import scipy.linalg

# Three constant 3x3 integer blocks (all 1s, 2s and 3s).
a1 = np.full((3, 3), 1)
a2 = np.full((3, 3), 2)
a3 = np.full((3, 3), 3)

# Block-diagonal test matrix with a few marked entries outside the
# diagonal blocks, so off-diagonal block extraction can be checked by eye.
b = scipy.linalg.block_diag(a1, a2, a3)
b[[1, 1, 4, 4, 7, 7], [4, 7, 1, 7, 1, 4]] = [4, 5, 6, 7, 8, 9]
print(b)

def extract_block_diag(a, n, k=0):
    """Return the n x n blocks along the k-th block diagonal of ``a``.

    Parameters:
    - a: 2-D array-like.
    - n: Block size; must be positive.
    - k: Block-diagonal offset. 0 is the main block diagonal, k > 0 is k
      blocks above it, and k < 0 is |k| blocks below it.

    Returns:
    - A strided view of shape (n_blocks, n, n) into ``np.asarray(a)``; no
      data is copied. Element i is the i-th block on the requested block
      diagonal. Because it is a view, writing into it modifies the
      underlying array.

    Raises:
    - ValueError: if ``a`` is not 2-D or ``n`` is not positive.
    """
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError("Only 2-D arrays handled")
    if not (n > 0):
        raise ValueError("Must have n > 0")

    # Shift the origin so the requested block diagonal starts at (0, 0):
    # drop leading columns for super-diagonals, leading rows for sub-diagonals.
    if k > 0:
        a = a[:, n * k:]
    else:
        a = a[-n * k:]

    n_blocks = min(a.shape[0] // n, a.shape[1] // n)

    # Moving to the next block steps n rows down and n columns right.
    new_shape = (n_blocks, n, n)
    new_strides = (n * a.strides[0] + n * a.strides[1],
                   a.strides[0], a.strides[1])

    return np.lib.stride_tricks.as_strided(a, new_shape, new_strides)

# Sub-diagonal (k=-1) 3x3 blocks of b; the bare expression is the cell output.
extract_block_diag(b, 3, -1)

[[1 1 1 0 0 0 0 0 0]
 [1 1 1 0 4 0 0 5 0]
 [1 1 1 0 0 0 0 0 0]
 [0 0 0 2 2 2 0 0 0]
 [0 6 0 2 2 2 0 7 0]
 [0 0 0 2 2 2 0 0 0]
 [0 0 0 0 0 0 3 3 3]
 [0 8 0 0 9 0 3 3 3]
 [0 0 0 0 0 0 3 3 3]]


array([[[0, 0, 0],
        [0, 6, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 9, 0],
        [0, 0, 0]]])

In [15]:
import jax
jax.config.update("jax_enable_x64", True)
from jax import scipy as jscipy

# Factoring block tridiagonal symmetric positive definite matrices.
# Inputs are sequences of the diagonal blocks D and sub-diagonal blocks B.
# Outputs are lists of the factor's diagonal blocks L (lower-triangular
# Cholesky factors) and sub-diagonal blocks C.
# We use the following pseudocode (Intel oneMKL cookbook):
#   L_1 = chol(D_1)
#   do i = 1, N-1
#       C_i     = B_i . L_i^{-T}            // trsm()
#       D_{i+1} := D_{i+1} - C_i . C_i^T    // syrk()
#       L_{i+1} = chol(D_{i+1})
#   end do
def dpbltrf(D, B):
    """Block Cholesky factorization of a block tridiagonal SPD matrix.

    Parameters:
    - D: Sequence of ND square diagonal blocks, each (NB, NB).
    - B: Sequence of ND-1 sub-diagonal blocks, each (NB, NB); B[i] lies
      directly below D[i].

    Returns:
    - L: List of ND lower-triangular blocks, the diagonal of the factor.
    - C: List of ND-1 blocks; C[i] lies directly below L[i] in the factor.

    Raises:
    - ValueError: if D is empty or len(B) != len(D) - 1.

    A non-positive-definite input is not detected here; the JAX Cholesky
    returns NaNs rather than raising.
    """
    n_diag = len(D)
    n_sub = len(B)
    if n_diag == 0:
        raise ValueError("D must contain at least one diagonal block")
    if n_sub != n_diag - 1:
        raise ValueError(
            f"expected len(B) == len(D) - 1, got len(D)={n_diag}, len(B)={n_sub}"
        )

    L = [None] * n_diag
    C = [None] * n_sub

    L[0] = jscipy.linalg.cholesky(D[0], lower=True)
    for i in range(n_sub):
        # trsm: C_i = B_i L_i^{-T}  <=>  C_i^T = L_i^{-1} B_i^T (lower-triangular solve)
        C[i] = jscipy.linalg.solve_triangular(L[i], B[i].T, lower=True).T

        # Schur-complement update, equivalent to syrk() (not exposed by NumPy or JAX)
        U = D[i + 1] - C[i] @ C[i].T

        L[i + 1] = jscipy.linalg.cholesky(U, lower=True)

    return L, C

**IMPORTANT**: The `make_random_A()` function is really bad. It produces a valid (positive definite) matrix only about 1 in 5 times for NB = ND = 3, more often for smaller sizes, and practically never for larger ones. If NaNs appear because the generated matrix is not positive definite, keep trying.

In [63]:
# Factor a random block tridiagonal SPD matrix with the block algorithm.
num_blocks = 3
block_size = 4

A = make_random_A(num_blocks, block_size)

# Diagonal (k=0) and sub-diagonal (k=-1) blocks as strided views into A
D = extract_block_diag(A, block_size, 0)
B = extract_block_diag(A, block_size, -1)

L, C = dpbltrf(D, B)

# Display the factor blocks (cell output)
L, C

([Array([[ 1.23698005,  0.        ,  0.        ,  0.        ],
         [ 0.0072361 ,  1.23375071,  0.        ,  0.        ],
         [ 0.11983359, -0.24278432,  1.13423654,  0.        ],
         [-0.47298397,  0.47540178,  0.48488581,  1.2906931 ]],      dtype=float64),
  Array([[ 1.34840916,  0.        ,  0.        ,  0.        ],
         [ 1.22147935,  1.19665677,  0.        ,  0.        ],
         [-0.78516359, -1.04880962,  0.76457072,  0.        ],
         [-0.41026285, -0.11888511, -0.1681108 ,  1.19610842]],      dtype=float64),
  Array([[ 1.1963427 ,  0.        ,  0.        ,  0.        ],
         [ 0.1839597 ,  1.13304089,  0.        ,  0.        ],
         [ 0.16142017, -0.22434554,  1.11761474,  0.        ],
         [ 0.9593464 ,  0.51350413,  0.54431404,  0.39261148]],      dtype=float64)],
 [Array([[ 0.086724  , -0.07823963, -0.10514684, -0.37162006],
         [ 0.23584591, -0.32351296, -0.30338492, -0.70477748],
         [-0.18512158,  0.11550208,  0.1209523 ,  0

In [64]:
# Compare the block factorization with NumPy's dense Cholesky factorization:
# assemble the dense factor from the blocks, then rebuild A = L L^T from each.
Lnew = reconstruct_matrix_L(np.hstack(L), np.hstack(C))
Lnaieve = np.linalg.cholesky(A)  # reference: dense ("naive") factorization

Anew = Lnew @ Lnew.T
Anaieve = Lnaieve @ Lnaieve.T

# Print the Frobenius-norm reconstruction errors (expected near machine
# precision for a well-conditioned A)
print("Norm difference between A and reconstructed A:")
print(np.linalg.norm(A - Anew))
print("Norm difference between A and naieve reconstructed A:")
print(np.linalg.norm(A - Anaieve))

Norm difference between A and reconstructed A:
9.945640348601413e-16
Norm difference between A and naieve reconstructed A:
2.118625103313206e-15
