I'm interested in understanding [`torch.linalg.lu_solve`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.lu_solve.html) with the aim of adding support for MPS.

The purpose of `lu_solve` is to solve a matrix equation $A X = B$ for $X$, given $A$ and $B$.

`lu_solve` relies on a strategy called LU decomposition. To paraphrase the [Wikipedia article](https://en.wikipedia.org/wiki/LU_decomposition) on LU decomposition, $A$ is factored into two matrices: $L$, a lower triangular matrix, and $U$, an upper triangular matrix. Often, $A = LU$, though sometimes the rows and columns of $A$ must be reordered to prevent division by zero or to control errors, in which case $P A Q = L U$, where $P$ and $Q$ are permutation matrices that reorder the rows and columns of $A$, respectively. However, it is usually sufficient to use only row permutations, $P A = L U$. This strategy is called "LU factorization with partial pivoting".

In PyTorch, however, $A = P L U$ by definition. So PyTorch's $P$ is actually the inverse (equivalently, the transpose) of the $P$ matrix in the Wikipedia article's $P A = L U$.

[`torch.linalg.lu`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.lu.html) can be used to compute the matrices $P$, $L$, and $U$ given $A$.
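To make the two conventions concrete, here is a minimal sketch that checks $A = P L U$ on a random matrix and then checks that the transpose of PyTorch's $P$ plays the role of the $P$ in the Wikipedia form $P A = L U$. The seed, the 4×4 size, and the variable names are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)

# torch.linalg.lu returns the three factors as separate matrices.
P, L, U = torch.linalg.lu(A)

# PyTorch's convention: A = P @ L @ U
reconstructs_A = torch.allclose(P @ L @ U, A)

# Wikipedia's convention P_w A = L U holds with P_w = P^T.
matches_wikipedia = torch.allclose(P.T @ A, L @ U)

print(reconstructs_A, matches_wikipedia)
```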

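However, it is not actually necessary to compute $L$ and $U$ separately in order to solve $A X = B$. So [`torch.linalg.lu_factor`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.lu_factor.html) computes $LU$ and a `pivots` tensor. Here $LU$ is a single matrix that packs both factors: $U$ occupies the upper triangle including the diagonal, and the strictly lower triangle holds $L$, whose unit diagonal is implied rather than stored. `pivots` is just a compact representation of the $P$ matrix. $P$ is just the identity matrix with reordered rows, so it is mostly filled with zeros, wasting a lot of space. `pivots` is a vector of 1-based row indices that follows the LAPACK convention: element $i$ says which row was swapped with row $i$ at step $i$ of the factorization. It's very easy to generate $P$ from `pivots`: take the identity matrix, apply those row swaps in order, and transpose the result. The transpose is needed because PyTorch's $P$ is the inverse of the permutation that was applied to the rows of $A$.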
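As a rough illustration of how `pivots` encodes $P$, the sketch below rebuilds $P$ by replaying the row swaps and compares it with what `torch.lu_unpack` returns. The helper name `pivots_to_matrix` is hypothetical (it is not a PyTorch function), and the code assumes a single square, unbatched matrix.

```python
import torch

def pivots_to_matrix(pivots, dtype=torch.float64):
    """Hypothetical helper: build the P of A = P L U from LAPACK-style pivots."""
    n = pivots.shape[0]
    order = list(range(n))
    for i in range(n):
        j = int(pivots[i]) - 1  # pivots are 1-based
        order[i], order[j] = order[j], order[i]
    # Indexing the identity with `order` applies the row swaps in sequence;
    # the transpose turns that into the P used in A = P L U.
    return torch.eye(n, dtype=dtype)[order].T

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

P_manual = pivots_to_matrix(pivots)
P_ref, L, U = torch.lu_unpack(LU, pivots)

print(torch.allclose(P_manual, P_ref))
print(torch.allclose(P_manual @ L @ U, A))
```

The list `order` costs $O(n)$ storage, which is the same saving over a dense $P$ that `pivots` itself provides.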

Since $A = P L U$, then $A X = B$ becomes $P L U X = B$, and solving for $X$, we get $X = U^{-1} L^{-1} P^{-1} B$. Since $P$ is a permutation matrix, its inverse is just its transpose, $P^{-1} = P^T$, so:

$$
X = U^{-1} L^{-1} P^T B
= (LU)^{-1} P^T B
$$

`lu_solve` takes $LU$ and `pivots` (as returned by `lu_factor`) along with $B$ as arguments and evaluates $X = (LU)^{-1} P^T B$, which is the solution of $A X = B$.
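A small usage sketch of the formula above, assuming a hand-picked 3×3 matrix whose zero in the top-left corner forces a row swap. It solves the system once with `lu_factor` plus `lu_solve`, and once by spelling out $U^{-1} L^{-1} P^T B$ with the explicit factors from `torch.linalg.lu`. The matrices are made up for illustration.

```python
import torch

# The zero in the top-left corner forces partial pivoting to swap rows.
A = torch.tensor([[0.0, 2.0, 1.0],
                  [1.0, 1.0, 0.0],
                  [2.0, 0.0, 3.0]], dtype=torch.float64)
B = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]], dtype=torch.float64)

# Factor once, then solve with the packed factors and pivots.
LU, pivots = torch.linalg.lu_factor(A)
X = torch.linalg.lu_solve(LU, pivots, B)

# The same computation with explicit factors: X = U^{-1} L^{-1} P^T B
P, L, U = torch.linalg.lu(A)
Y = torch.linalg.solve_triangular(L, P.T @ B, upper=False, unitriangular=True)
X_explicit = torch.linalg.solve_triangular(U, Y, upper=True)

print(torch.allclose(A @ X, B))
print(torch.allclose(X, X_explicit))
```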

In PyTorch, the CPU impl of `lu_solve` just calls into the LAPACK function [`[s|d|z|c]getrs`](https://www.netlib.org/lapack/explore-html/df/d36/group__getrs.html). LAPACK is very difficult to understand in my opinion, and I won't attempt it. SciPy also just calls the LAPACK function. The CUDA impl in PyTorch uses either cuBLAS, MAGMA, or a custom implementation, depending on the situation. cuBLAS is closed source, so I cannot learn much about it. The MAGMA function is named [`magma_[s|d|z|c]getrs_batched`](https://icl.utk.edu/projectsfiles/magma/doxygen/group__magma__getrs__batched.html).

The MAGMA implementation of `lu_solve` for complex numbers is essentially [this](https://github.com/icl-utk-edu/magma/blob/f360e055c1f54322807f5dcda75d41b2d6c1410a/src/zgetrs_gpu.cpp#L186-L216). The meat of it is a pair of calls to `magma_ztrs[v|m]`, which are triangular solve functions.

The custom CUDA impl of `lu_solve` in PyTorch is [this](https://github.com/pytorch/pytorch/blob/8e8cbb85ee927776210f7872e3d0286d5d40dc14/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp#L2473-L2494), which is also a pair of triangular solve function calls.

MLX also has a solver [here](https://github.com/ml-explore/mlx/blob/27778156dcbabbd7077985e8ea0683cf3ce04cfb/mlx/linalg.cpp#L680-L694), which accepts $A$ rather than $LU$ and `pivots`, but it decomposes $A$ into $LU$ and runs a pair of triangular solves.

JAX also has an `lu_solve` function which is implemented [here](https://github.com/jax-ml/jax/blob/30582db24e8794abd09df2b3120aa5b58af8e9fe/jax/_src/lax/linalg.py#L1604-L1621) with a pair of triangular solve function calls.

So everywhere I look for open source implementations of `lu_solve`, I see a pair of triangular solves. PyTorch's `triangular_solve` function currently supports MPS inputs, so I should be able to mimic what all these other `lu_solve` impls are doing.
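The derivation earlier only covers $A X = B$. For reference, the same steps applied to the adjoint system, which corresponds to the `trans == 2` branch of the code that follows, give the result below. It uses $P^{-1} = P^T$ and the fact that $P$ is real, so $P^H = P^T$:

$$
A^H X = B
\;\Rightarrow\; (P L U)^H X = B
\;\Rightarrow\; U^H L^H P^T X = B
\;\Rightarrow\; X = P \, (L^H)^{-1} (U^H)^{-1} B
$$

Here $U^H$ is lower triangular and $L^H$ is upper triangular with a unit diagonal, so the first solve is a lower triangular one, the second is an upper unit-triangular one, and the permutation is applied last, in the opposite direction from the non-transposed case. The plain transpose case (`trans == 1`) is identical with every $H$ replaced by $T$.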

In [1]:
import torch

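def my_pivots_to_permutation(pivots, size, *, inverse=False):
    # Turn LAPACK-style pivots (1-based, sequential row swaps) into a
    # permutation index vector. With inverse=True the swaps are replayed in
    # reverse order, which yields the inverse permutation.
    perm = torch.arange(size, dtype=torch.int32)
    indices = range(size)
    if inverse:
        indices = reversed(indices)

    for i in indices:
        j = pivots[i] - 1  # pivots are 1-based
        perm_i = perm[i].item()
        perm_j = perm[j].item()
        perm[i] = perm_j
        perm[j] = perm_i

    return perm

def my_lu_solve(lu, pivots, b, trans):
    # Unbatched only: lu is (m, m) and b is (m, k).
    # trans: 0 solves A X = B, 1 solves A^T X = B, 2 solves A^H X = B.
    m = lu.shape[0]
    x = b

    if trans == 0:
        # X = U^{-1} L^{-1} P^T B
        perm = my_pivots_to_permutation(pivots, m)
        x = x[perm, :]
        x = torch.linalg.solve_triangular(lu, x, left=True, upper=False, unitriangular=True)
        x = torch.linalg.solve_triangular(lu, x, left=True, upper=True)

    elif trans == 1 or trans == 2:
        # X = P L^{-T} U^{-T} B, with conjugation as well when trans == 2
        lu_ = lu.T
        if trans == 2:
            lu_ = lu_.conj()

        # After transposing, U sits in the lower triangle and the
        # unit-diagonal L sits in the upper triangle.
        x = torch.linalg.solve_triangular(lu_, x, left=True, upper=False)
        x = torch.linalg.solve_triangular(lu_, x, left=True, upper=True, unitriangular=True)
        inv_perm = my_pivots_to_permutation(pivots, m, inverse=True)
        x = x[inv_perm, :]

    else:
        raise ValueError(f"trans must be 0, 1, or 2, got {trans}")

    return x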

In [2]:
import itertools

for adjoint, m in itertools.product([False, True], [4, 5]):
    print(f'{m=} {adjoint=}')
    for _ in range(100):
        a = torch.randn(m, m, dtype=torch.cfloat)
        b = torch.randn(m, m, dtype=torch.cfloat)
        lu, pivots = torch.linalg.lu_factor(a)

        r = my_lu_solve(lu, pivots, b, 2 if adjoint else 0)
        r_check = torch.linalg.lu_solve(lu, pivots, b, adjoint=adjoint)
        is_match = torch.allclose(r, r_check, atol=1e-4, rtol=1e-4)

        if not is_match:
            break

    if not is_match:
        print("  FAIL")
        print(f'{r=}')
        print(f'{r_check=}')
        raise RuntimeError("not a match")
    else:
        print("  OK")


m=4 adjoint=False
  OK
m=5 adjoint=False
  OK
m=4 adjoint=True
  OK
m=5 adjoint=True
  OK
