<a href="https://colab.research.google.com/github/dschlesinger/experiments/blob/main/QR_decomposition/QR_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import numpy as np

from collections.abc import Callable
from typing import Tuple
from functools import partial

In [11]:
def generate_full_rank(rank: int = 5, num_func: Callable = np.random.randint, **kwargs) -> np.ndarray:
  """Generates a guaranteed full rank square matrix

  Candidates are drawn from num_func until one passes a numerical rank test.
  The test runs on the float32 array that is actually returned.

  Args:
    rank: int -> Rank (and side length) of generated matrix
    num_func: Callable -> the function used to generate numbers, called as
      num_func(size=(rank, rank), **kwargs)
    **kwargs: dict -> passed into num_func

  Returns:
    A: np.ndarray -> (rank, rank) float32 matrix, always full rank
  """

  def draw() -> np.ndarray:
    # Cast before testing so the rank check sees exactly what the caller gets
    return np.asarray(num_func(size=(rank, rank), **kwargs)).astype(np.float32)

  A = draw()

  # An exact `det(A) == 0` test misses singular matrices whose floating-point
  # determinant is merely tiny (e.g. [[1,2,3],[4,5,6],[7,8,9]]), so use the
  # tolerance-aware SVD rank instead. A 0x0 matrix is trivially full rank.
  while rank > 0 and np.linalg.matrix_rank(A) < rank:
    A = draw()

  return A

def matrix_generator() -> np.ndarray:
  """Draws one matrix from generate_full_rank with low=-5 and high=6"""
  return generate_full_rank(low=-5, high=6)

In [12]:
def qr(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  """Returns column orthonormal matrix Q and upper triangular matrix R
  such that Q @ R = A, built by classical Gram-Schmidt on the columns of A

  Args:
    A: np.ndarray (m, n) -> Matrix to decompose; its columns must be
      linearly independent (which requires m >= n)

  Returns:
    Q: np.ndarray (m, n) -> float32 matrix with orthonormal columns
    R: np.ndarray (n, n) -> upper triangular matrix holding the projection
      coefficients (above the diagonal) and the column norms (diagonal)

  Raises:
    ValueError: if A is not 2-D or does not have full column rank
  """

  A = np.asarray(A)

  if A.ndim != 2:
    raise ValueError("A must be a 2-D matrix")

  m, n = A.shape

  # Gram-Schmidt needs n independent columns. Comparing the rank against m
  # would reject valid tall matrices and would let wide matrices through,
  # where a later column reduces to ~0 and the normalization blows up.
  if np.linalg.matrix_rank(A) < n:
    raise ValueError("A is not full rank")

  # NOTE: Q is stored in float32 while R is float64; kept as-is so the
  # dtypes callers receive do not change.
  Q = np.zeros((m, n), dtype=np.float32)
  R = np.zeros((n, n))

  for column in range(n):

    q = A[:, column]

    # Orthonormal columns found so far (empty for the first column, which
    # makes the projection below a no-op)
    basis = Q[:, :column]

    # Coefficients of q along each previous direction
    proj = basis.T @ q

    # Subtract projections, leaving the component orthogonal to the basis
    q = q - basis @ proj

    # Record coefficients and the residual norm in R, then normalize into Q
    R[:column, column] = proj

    R[column, column] = np.linalg.norm(q)
    Q[:, column] = q / R[column, column]

  return Q, R


In [13]:
# Testing: decompose 1000 random full rank matrices and check Q @ R == A
from itertools import repeat

for _ in repeat(None, 1000):
  A = matrix_generator()
  Q, R = qr(A)

  # Loose tolerances: A and Q are both float32
  if np.allclose(A, Q @ R, atol=1e-4, rtol=1e-4):
    continue

  # First mismatch: dump the input, the rounded reconstruction and both
  # factors, then stop
  print("Fails")
  print(A, np.round(Q @ R), Q, R, sep='\n')
  break
else:
  # Reached only when the loop finished without hitting `break`
  print("Correct")

Correct
