<a href="https://colab.research.google.com/github/dschlesinger/experiments/blob/main/LU_factorization/LU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np

In [96]:
genA = lambda: np.round(np.random.normal(loc=0, scale=2, size=(5,5))).astype(np.float32)
genb = lambda: np.round(np.random.normal(loc=0, scale=2, size=(5,))).astype(np.float32)

A = genA()

print(
  A,
  # Check that matrix is non-singular
  np.linalg.det(A),
  sep='\n'
)

[[ 1.  0.  4. -0. -3.]
 [-2.  5.  1. -4.  0.]
 [ 1.  1. -0. -2. -4.]
 [-1. -1. -1.  2.  2.]
 [-2.  3.  3.  0.  4.]]
-84.0


In [54]:
from typing import Tuple, List

class solver:

  def __init__(self, A: np.ndarray) -> None:

    if np.linalg.det(A) == 0:

      raise ValueError("Matrix is singular")

    self.PLU(A)

  def reconstruct(self) -> np.ndarray:
    """f(P, L, U) -> A"""
    return self.P.T @ self.L @ self.U

  def solve(self, b: np.ndarray) -> np.ndarray:
    """Find x in Ax=b"""
    pb = self.P @ b

    y = self.forward(pb)

    x = self.backward(y)

    return x

  def forward(self, b: np.ndarray) -> np.ndarray:
    """f(L, b) -> y
    Ly = b
    """

    y = np.zeros(self.n, dtype=np.float64)
    y[0] = b[0]

    for i in range(1, self.n):

      y[i] = b[i] - np.dot(y[:i], self.L[i, :i])

    return y


  def backward(self, y: np.ndarray) -> np.ndarray:
    """f(U, y: from forward) -> x"""

    x = np.zeros(self.n, dtype=np.float64)
    x[-1] = y[-1] / self.U[-1, -1]

    for i in range(-2, -(self.n + 1), -1):

      x[i] = (y[i] - np.dot(x[i:], self.U[i, i:])) / self.U[i, i]

    return x

  def PLU(self, A: np.ndarray) -> Tuple[np.ndarray]:
        """
        Returns and sets as instance attributes:
          P: Permutation Matrix
          L: Lower Triangle Matrix, with diagonal elements = 1
          U: Upper Triangle Matrix
        """
        n = A.shape[0]
        self.n = n

        P = np.identity(n)
        L = np.identity(n, dtype=np.float64)
        U = A.copy().astype(np.float64)

        for head in range(n):
            # Find the pivot element (row with max absolute value in current column)
            pivot_row = head + np.argmax(np.abs(U[head:, head]))

            # If the pivot is zero, matrix is singular
            if np.abs(U[pivot_row, head]) < 1e-10:
                continue

            # Swap rows if needed
            if pivot_row != head:
                # Swap rows in U
                U[[head, pivot_row]] = U[[pivot_row, head]]

                # Swap rows in P
                P[[head, pivot_row]] = P[[pivot_row, head]]

                # Only swap the ROWS of L for the COLUMNS that have already been computed
                #if head > 0:
                L[[head, pivot_row], :head] = L[[pivot_row, head], :head]

            # Perform elimination
            for i in range(head + 1, n):
                if U[head, head] == 0:
                    continue

                factor = U[i, head] / U[head, head]
                L[i, head] = factor
                U[i, head:] -= factor * U[head, head:]

        self.P = P
        self.L = L
        self.U = U

        return P, L, U

In [28]:
from itertools import repeat
from IPython.display import clear_output

# Test accurate reconstruction

all_cor = []
for _ in repeat(None, 1000):

  clear_output()

  A = genA()

  try:
    s = solver(A)
  except ValueError:
    continue

  correct_grid = np.isclose(A, s.reconstruct() ,rtol=1e-4, atol=1e-4)

  correct = np.all(correct_grid)

  all_cor.append(correct)

  if not correct:

    print("----------")

    print(A, np.round(s.reconstruct()), correct_grid, s.P, s.L, s.U, sep='\n')

    break

print(np.unique(all_cor, return_counts=True))

(array([ True]), array([991]))


In [95]:
# Test foward and backward

all_cor = []
for _ in repeat(None, 1000):

  clear_output()

  A = genA()
  b = genb()

  try:
    s = solver(A)
  except ValueError:
    continue

  x = s.solve(b)

  correct = np.allclose(A @ x, b, rtol=1e-4, atol=1e-4)

  all_cor.append(correct)

  if not correct:

    print("----------")

    print(A @ x, b, A, x, sep='\n')

    break

# Not all test cases pass likely due to numerical instability
print(np.unique(all_cor, return_counts=True))

----------
[0. 0. 2. 0. 2.]
[ 1. -1.  4. -2. -0.]
[[-1. -0. -1. -1.  0.]
 [ 0. -2.  0. -0. -2.]
 [-1.  2.  0.  2. -3.]
 [-4.  2.  0.  0. -2.]
 [-1. -1. -3. -1. -3.]]
[ 1.16342990e+16  1.16342990e+16  1.16342990e+16 -2.32685981e+16
 -1.16342990e+16]
(array([False,  True]), array([  1, 495]))
