<a href="https://colab.research.google.com/github/dschlesinger/experiments/blob/main/LU_factorization/LU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import numpy as np

In [5]:
A = np.round(np.random.normal(loc=0, scale=2, size=(5,5))).astype(np.float32)

print(
  A,
  # Check that matrix is non-singular
  np.linalg.det(A),
  sep='\n'
)

[[-0.  3.  1. -0.  0.]
 [-1. -2. -2.  0. -2.]
 [-3. -0. -1.  3.  5.]
 [ 2. -1.  1. -0. -3.]
 [-3.  1.  1. -0.  1.]]
-210.0


In [None]:
import numpy as np
from typing import Tuple
from itertools import repeat

def find_lead_zeros(r: np.ndarray) -> int:
    """Takes row and finds how many leading zeros, returns 0 if none."""
    i: int = 0
    while i < r.shape[0] and r[i] == 0:
        i += 1
    return i

def LU(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Returns (L, U). Raises an error if A is not computable."""

    n = A.shape[0]
    L = np.identity(n)
    U = A.copy().astype(np.float32)  # Work with a copy of A to avoid modifying the original

    if np.linalg.det(A) == 0:

      raise ZeroDivisionError("Matrix is singular")

    # Sort Matrix to Triangle
    leading_zeros = np.apply_along_axis(find_lead_zeros, axis=1, arr=U)
    rows = sorted(range(n), key=lambda i: leading_zeros[i])

    # Keep L matrix in sync with
    U = U[rows]
    L = L[rows]

    # Pivot and elimination
    for head_index in range(n - 1):
        for action_index in range(head_index + 1, n):

            if U[head_index, head_index] == 0 or U[action_index, head_index] == 0:
              continue

            multiplier = U[action_index, head_index] / U[head_index, head_index]
            U[action_index] = U[action_index] - (multiplier * U[head_index])
            L[action_index, head_index] = multiplier


    return L, U

# A = np.array([
#     [4, 3, 2],
#     [1, 3, 2],
#     [0, 1, 5]
# ])
for _ in repeat(None, 100):

  A = np.round(np.random.normal(loc=0, scale=2, size=(5,5))).astype(np.float32)

  # print(
  #   A,
  #   np.linalg.det(A),
  #   np.linalg.inv(A),
  #   sep='\n'
  # )

  try:
    L, U = LU(A)
  except ZeroDivisionError:
    continue
  # print("L:\n", L)
  # print("U:\n", U)

  # Verify that A = LU
  comp = L @ U

  # print("A:\n", A)
  # print("Product L * U:\n", comp)

  correct = np.all(np.isclose(A, comp, atol=1e-2, rtol=1e-2))

  print(correct)

  if not correct:

    print(L, U, comp, sep='\n')

True
True
False
[[ 1.          0.          0.          0.          0.        ]
 [ 1.          0.          1.          0.          0.        ]
 [-2.         -8.          0.          1.          0.        ]
 [-3.         -6.          0.7647059   0.          1.        ]
 [ 0.         -1.          0.         -0.71830964  0.        ]]
[[  1.         -2.          0.         -1.         -1.       ]
 [  0.          1.         -2.          4.         -2.       ]
 [  0.          0.        -17.         29.        -18.       ]
 [  0.          0.          0.         -4.1764717  -3.2352943]
 [  0.          0.          0.          0.         -5.323943 ]]
[[  1.          -2.           0.          -1.          -1.        ]
 [  1.          -2.         -17.          28.         -19.        ]
 [ -2.          -4.          16.         -34.17647171  14.76470566]
 [ -3.           0.          -1.00000024   1.17647099  -4.08864927]
 [  0.          -1.           2.          -1.00000011   4.32394312]]
False
[[ 1.

In [None]:
print(*A)

[-0. -2. -1.] [-3.  2.  2.] [-4. -2. -2.]
