# Creating a JAX solver for sparse linear systems
About a year ago, I [implemented a sparse solver for PyTorch](solving-sparse-linear-systems-in-pytorch.ipynb). I want to pick up this subject again and attempt to do the same for JAX. However, doing the same in JAX seems to be a bit more painful...

Anyway... let's start the adventure. The goal of the discussion below is to define a custom JAX operation that solves the linear system of the form

$$
\begin{align}
Ax &= b
\end{align}
$$

where $A$ is a sparse $m \times m$ matrix and $x$ and $b$ are dense $m \times n$ matrices.
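To make the shapes concrete: a right-hand side with $n$ columns simply bundles $n$ independent systems $A x_j = b_j$ that share the same matrix $A$. Here is a small NumPy sketch with made-up values (with $A$ stored densely, for simplicity):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 3

# A mostly-zero m x m matrix: a nonzero diagonal plus two off-diagonal entries
# (lower triangular with a nonzero diagonal, so A is invertible).
A = np.diag(rng.uniform(1.0, 2.0, size=m))
A[1, 0] = 0.5
A[3, 2] = -0.25
b = rng.standard_normal((m, n))

# Solving with the full m x n right-hand side at once...
x = np.linalg.solve(A, b)

# ...gives the same result as solving A x_j = b_j column by column.
x_cols = np.stack([np.linalg.solve(A, b[:, j]) for j in range(n)], axis=1)
```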

## Imports

In [None]:
import numpy as np
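import scipy as sp
import scipy.linalg  # make sure sp.linalg is loaded (older SciPy versions don't import submodules automatically)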

import jax
import jax.numpy as jnp
import jax.scipy as jsp

## Solving a system of equations

Let's pick up near the middle of the [PyTorch post](solving-sparse-linear-systems-in-pytorch.ipynb), where we derived the gradient rules for a matrix-matrix system in which $x$ and $b$ are $m\times n$ matrices and $A$ is an $m\times m$ matrix (repeated indices are summed over):

$$
\begin{align*}
    \frac{\partial L}{\partial b}
        &= \frac{\partial L}{\partial x_{ij}} \frac{\partial x_{ij}}{\partial b_{lm}}  \\
        &= \frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \frac{\partial b_{kj}}{\partial b_{lm}}\\
        &= \frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \delta_{kl}\delta_{jm}\\
        &= \frac{\partial L}{\partial x_{im}} A^{-1}_{il}\\
        &= \big(A^{-1}\big)^{T} \frac{\partial L}{\partial x} \\
        &= \mathrm{solve}\big( A^T\,,\,\,  \frac{\partial L}{\partial x} \big) \\\\
\frac{\partial L}{\partial A} 
        &= \frac{\partial L}{\partial x_{ij}} \frac{\partial x_{ij}}{\partial A_{mn}}  \\
        &= \frac{\partial L}{\partial x_{ij}} \frac{\partial}{\partial A_{mn}} ( A^{-1}_{ik} b_{kj} ) \\
        &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \frac{\partial A_{kl}}{\partial A_{mn}} A^{-1}_{lp} b_{pj} \\
        &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \delta_{km} \delta_{ln} A^{-1}_{lp} b_{pj} \\
        &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{im} A^{-1}_{np} b_{pj} \\
        &= -\left(\big(A^{-1}\big)^{T} \frac{\partial L}{\partial x}\right)\left( A^{-1} b \right)^T \\
        &= -\frac{\partial L}{\partial b} x^T
\end{align*}
$$
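As a sanity check of these two results, here is a minimal NumPy sketch that compares them against central finite differences. The scalar test loss $L = \sum_{ij} W_{ij} x_{ij}$ (chosen so that $\partial L/\partial x = W$), the sizes and the random values are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)
m, n = 3, 2
A = rng.standard_normal((m, m)) + 5 * np.eye(m)  # diagonal shift keeps A well conditioned
b = rng.standard_normal((m, n))
W = rng.standard_normal((m, n))  # fixed weights of the test loss

def loss(A, b):
    """Test loss L = sum_ij W_ij x_ij with x = solve(A, b), so that dL/dx = W."""
    return np.sum(W * np.linalg.solve(A, b))

def numerical_grad(f, X, eps=1e-6):
    """Central finite-difference gradient of the scalar function f with respect to X."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

# Analytical gradients from the derivation above.
x = np.linalg.solve(A, b)
grad_b = np.linalg.solve(A.T, W)  # solve(A^T, dL/dx)
grad_A = -grad_b @ x.T            # -(dL/db) x^T

# Numerical gradients for comparison.
num_grad_b = numerical_grad(lambda bb: loss(A, bb), b)
num_grad_A = numerical_grad(lambda AA: loss(AA, b), A)
```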

We can define custom functions for this forward and backward operation by following the [JAX tutorial on primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).

In [None]:
solve_prim = jax.core.Primitive("solve")

def solve(A, b):
    return solve_prim.bind(A, b)

def solve_impl(A, b):
    return sp.linalg.solve(A, b)

solve_prim.def_impl(solve_impl)

def solve_abstract_eval(A, b):
    assert A.ndim == 2
    assert b.ndim in (1, 2)
    assert A.shape[-1] == b.shape[0]
    assert A.dtype == b.dtype
    shape = A.shape[:-1] + b.shape[1:]
    return jax.core.ShapedArray(shape, A.dtype)

solve_prim.def_abstract_eval(solve_abstract_eval)

A = np.random.randn(3,3)
b = np.random.randn(3,2)
solve(A, b)

In [None]:
key_A, key_b = jax.random.split(jax.random.PRNGKey(42))
mask = jnp.array([[1,0,0],[1,1,0],[0,0,1]])
A = (mask * jax.random.normal(key_A, (3, 3)))
b = jax.random.normal(key_b, (3, 2))
x = jsp.linalg.solve(A, b)

In [None]:
sp.linalg.solve(A, b)

In [None]:
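import torch
from torch.autograd import Function, gradcheck
from scipy.linalg import solve as scipy_solve

class Solve(Function):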
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]):
            raise ValueError("A should be a batch of square 2D matrices with shape (b, m, m)")
        A_np = A.data.numpy()
        b_np = b.data.numpy()
        x_np = np.stack([scipy_solve(A_np[i], b_np[i]) for i in range(A.shape[0])], 0)
        x = torch.tensor(x_np, requires_grad=True)
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve.apply(A.transpose(-1,-2), grad)
        gradA = -torch.bmm(gradb, x.transpose(-1,-2))
        return gradA, gradb

solve = Solve.apply

In [None]:
A = torch.randn(4,3,3, requires_grad=True)
b = torch.randn(4,3,2, requires_grad=True)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision

## Making a C++ extension

Let's now implement the above module in C++. We can't use the scipy function there, so for now we'll just use `torch.inverse` where we need it. Just know that once this simple C++ PyTorch extension works, we'll swap out `torch.inverse` for our own C++ solver.

We'll start by creating a file called `solve.cpp` that includes the following three headers:
```cpp
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
```

Generally speaking, `torch/extension.h` provides C++ equivalents of the functions `torch` offers in Python, while `ATen/ATen.h` exposes Python Tensor *methods* as C++ *functions*. The `vector` header is needed to return more than one tensor (as we'll do in the backward pass).

The forward pass for a batch of matrices times a batch of matrices can then be implemented in C++ as follows (remember that we'll swap out `torch::inverse` for an actual solver later):
```cpp
torch::Tensor solve_forward(torch::Tensor A, torch::Tensor b){
  auto result = torch::zeros_like(b);
  for (int i = 0; i < at::size(b, 0); i++){
      result[i] = torch::mm(torch::inverse(A[i]), b[i]); // we'll use an actual solver later.
  }
  return result;
}
```
 
Implementing the backward pass can also be done relatively easily:
```cpp
std::vector<torch::Tensor> solve_backward(torch::Tensor grad, torch::Tensor A, torch::Tensor b, torch::Tensor x){
    auto gradb = solve_forward(at::transpose(A, -1, -2), grad); // solve(A^T, dL/dx), same shape as b
    auto gradA = -torch::bmm(gradb, at::transpose(x, -1, -2)); // -(dL/db) x^T, same shape as A
    return {gradA, gradb};
}
```
Notice that we're returning a *vector* of tensors here, which is why we needed to include `vector` earlier.
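For reference, the same batched forward and backward passes can be written in a few lines of NumPy. The hypothetical `solve_forward_ref` and `solve_backward_ref` helpers below are only meant for checking shapes and values against the C++ version; they make explicit that `gradb` must have the same shape as `b`, namely `(p, m, n)`, and `gradA` the same shape as `A`, namely `(p, m, m)`:

```python
import numpy as np

def solve_forward_ref(A, b):
    """Batched reference forward: A has shape (p, m, m), b has shape (p, m, n)."""
    return np.linalg.solve(A, b)

def solve_backward_ref(grad, A, b, x):
    """Batched reference backward, following the gradient derivation of this post."""
    gradb = np.linalg.solve(np.swapaxes(A, -1, -2), grad)  # solve(A^T, dL/dx): (p, m, n)
    gradA = -gradb @ np.swapaxes(x, -1, -2)                # -(dL/db) x^T:     (p, m, m)
    return gradA, gradb

rng = np.random.default_rng(1)
p, m, n = 4, 3, 2
A = rng.standard_normal((p, m, m)) + 5 * np.eye(m)  # well-conditioned batch of matrices
b = rng.standard_normal((p, m, n))
x = solve_forward_ref(A, b)
grad = np.ones_like(x)  # upstream gradient dL/dx for the loss L = sum(x)
gradA, gradb = solve_backward_ref(grad, A, b, x)
```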

Finally, we need to expose the C++ functions defined above as Python functions. This is done with pybind11:
```cpp
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &solve_forward, "solve forward");
  m.def("backward", &solve_backward, "solve backward");
}
```

Here the macro `TORCH_EXTENSION_NAME` will be replaced during compilation by the name of the torch extension defined in the `setup.py` file. The `setup.py` file looks as follows:
```python
from setuptools import setup, Extension
from torch.utils import cpp_extension

solve_cpp = Extension(
    name="solve_cpp",
    sources=["solve.cpp"],
    include_dirs=cpp_extension.include_paths(),
    library_dirs=cpp_extension.library_paths(),
    extra_compile_args=[],
    libraries=[
        "c10",
        "torch",
        "torch_cpu",
        "torch_python",
    ],
    language="c++",
)

setup(
    name="solve",
    ext_modules=[solve_cpp],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```

The C++ extension can now be compiled as follows:
```bash
python setup.py install
```
This builds a compiled Python extension module (a `.so` file on Linux, a `.pyd` file on Windows) and installs it into your Python's `site-packages` folder (i.e. it will be on your Python path).
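If you want to check where the compiled module ended up, the standard library can tell you. This small sketch (the `locate_module` helper is hypothetical, not part of PyTorch) lists the file suffixes your interpreter accepts for compiled extension modules and resolves a module name to the file it would be imported from:

```python
import importlib.machinery
import importlib.util

# File suffixes this interpreter accepts for compiled extension modules
# (ending in ".so" on Linux and macOS, ".pyd" on Windows).
suffixes = list(importlib.machinery.EXTENSION_SUFFIXES)

def locate_module(name):
    """Return the file a module would be imported from, or None if it cannot be found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin

# After `python setup.py install`, locate_module("solve_cpp") should point at the
# compiled extension inside site-packages; it returns None if nothing was installed.
```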

The only thing left is to create a thin wrapper in Python:

In [None]:
import torch  # always import torch BEFORE your custom torch extension
from torch.autograd import Function, gradcheck
import solve_cpp  # the custom torch extension we just created

class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]):
            raise ValueError("A should be a batch of square 2D matrices with shape (b, m, m)")
        if b.ndim != 3:
            raise ValueError("b should be a batch of matrices with shape (b, m, n)")
        x = solve_cpp.forward(A, b)
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradA, gradb = solve_cpp.backward(grad, A, b, x)
        return gradA, gradb

solve = Solve.apply

In [None]:
A = torch.randn(4,3,3, requires_grad=True)
b = torch.randn(4,3,2)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision

## Towards a sparse solver for CPU tensors

OK, so now we have an expression for the backward pass and we know how to make a PyTorch C++ extension. Let's go on to make an actual sparse solver.

To create such a sparse solver, we'll wrap the [SuiteSparse](https://github.com/DrTimothyAldenDavis/SuiteSparse) routines in our C++ extension. In particular, we'll wrap the [KLU algorithm](https://ufdc.ufl.edu/UFE0011721/00001) provided by this library. KLU is a very efficient solver for the sparse matrices that arise from circuit simulation netlists. This means it is most efficient for very sparse systems, often with only a handful of nonzero elements per column of the (connection-)matrix. It can of course solve other sparse linear systems as well, but circuit simulation is what it was originally designed for.

To use the KLU algorithm, we'll have to add the `klu` header from the `SuiteSparse` library to the three headers we had in the C++ extension before.
```c++
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <klu.h>
```

Note that with the [Anaconda distribution](https://www.anaconda.com/) you can simply do a
```sh
conda install suitesparse
```
to install all the SuiteSparse libraries and headers.

However, before we can go on to solving the sparse system with the KLU algorithm, there's one more hurdle to overcome:

### Sparse COO -> Sparse CSC

The KLU algorithm expects the sparse matrix to be in [CSC format](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)) rather than in [COO format](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_(COO)), the standard sparse representation in PyTorch. To do this conversion, we can look at how scipy performs it and do something similar for our PyTorch COO tensors:

```c++

std::vector<at::Tensor> _coo_to_csc(int ncol, at::Tensor Ai, at::Tensor Aj, at::Tensor Ax) {
    int nnz = at::size(Ax, 0);
    at::TensorOptions options = at::TensorOptions().dtype(torch::kInt32).device(at::device_of(Ai));
    at::Tensor Bp = at::zeros(ncol+1, options);
    at::Tensor Bi = at::zeros_like(Ai);
    at::Tensor Bx = at::zeros_like(Ax);

    int* ai = Ai.data_ptr<int>();
    int* aj = Aj.data_ptr<int>();
    double* ax = Ax.data_ptr<double>();

    int* bp = Bp.data_ptr<int>();
    int* bi = Bi.data_ptr<int>();
    double* bx = Bx.data_ptr<double>();

    //compute number of non-zero entries per column of A
    for (int n = 0; n < nnz; n++) {
        bp[aj[n]] += 1;
    }

    //cumsum the nnz per column to get Bp
    int cumsum = 0;
    int temp = 0;
    for(int j = 0; j < ncol; j++) {
        temp = bp[j];
        bp[j] = cumsum;
        cumsum += temp;
    }
    bp[ncol] = nnz;

    //write Ai, Ax into Bi, Bx
    int col = 0;
    int dest = 0;
    for(int n = 0; n < nnz; n++) {
        col = aj[n];
        dest = bp[col];
        bi[dest] = ai[n];
        bx[dest] = ax[n];
        bp[col] += 1;
    }

    //shift Bp one place to the right, so that bp[j] is the start of column j again
    int last = 0;
    for(int i = 0; i <= ncol; i++) {
        temp = bp[i];
        bp[i] = last;
        last = temp;
    }

    return {Bp, Bi, Bx};
}

```

Note that we access the PyTorch CPU tensors as plain C arrays (by asking for their data pointers). This means the conversion only works for CPU tensors. However, performing the same conversion with native PyTorch tensor operations would be **a lot** slower.

This function returns three PyTorch tensors: `Bp`, the column pointers; `Bi`, the row indices of the nonzero elements in each column; and `Bx`, the corresponding values of the sparse tensor.
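To see what the three loops do, here is a direct NumPy port of `_coo_to_csc` (a sketch for illustration only; pure-Python loops are far too slow for real use). It is applied to a 3x3 matrix with the same lower-triangular sparsity pattern as the mask used earlier, with made-up values, and cross-checked against SciPy's own sparse matrix classes:

```python
import numpy as np
import scipy.sparse

def coo_to_csc(ncol, ai, aj, ax):
    """Pure-NumPy port of `_coo_to_csc`: COO (rows ai, cols aj, values ax) -> CSC (bp, bi, bx)."""
    nnz = len(ax)
    bp = np.zeros(ncol + 1, dtype=np.int32)
    bi = np.zeros(nnz, dtype=np.int32)
    bx = np.zeros(nnz, dtype=np.float64)

    # count the nonzeros in each column
    for k in range(nnz):
        bp[aj[k]] += 1

    # exclusive cumulative sum: bp[j] becomes the start of column j
    cumsum = 0
    for j in range(ncol):
        count = bp[j]
        bp[j] = cumsum
        cumsum += count
    bp[ncol] = nnz

    # scatter rows/values into their column; bp[col] advances past each written entry
    for k in range(nnz):
        col = aj[k]
        dest = bp[col]
        bi[dest] = ai[k]
        bx[dest] = ax[k]
        bp[col] += 1

    # bp[j] now holds the end of column j: shift one place right to restore the starts
    bp[1:] = bp[:-1].copy()
    bp[0] = 0
    return bp, bi, bx

# The lower-triangular 3x3 sparsity pattern of the mask used earlier, in COO form.
ai = np.array([0, 1, 1, 2], dtype=np.int32)  # row indices
aj = np.array([0, 0, 1, 2], dtype=np.int32)  # column indices
ax = np.array([1.0, 2.0, 3.0, 4.0])          # values
bp, bi, bx = coo_to_csc(3, ai, aj, ax)
```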

### KLU Solver

Using these three tensors, we can build a KLU solver by wrapping the KLU routines as follows:

```c++
void _klu_solve(at::Tensor Ap, at::Tensor Ai, at::Tensor Ax, at::Tensor b) {
    int ncol = at::size(Ap, 0) - 1;
    int nb = at::size(b, 0);
    int* ap = Ap.data_ptr<int>();
    int* ai = Ai.data_ptr<int>();
    double* ax = Ax.data_ptr<double>();
    double* bb = b.data_ptr<double>();
    klu_symbolic* Symbolic;
    klu_numeric* Numeric;
    klu_common Common;
    klu_defaults(&Common);
    Symbolic = klu_analyze(ncol, ap, ai, &Common);
    Numeric = klu_factor(ap, ai, ax, Symbolic, &Common);
    klu_solve(Symbolic, Numeric, ncol, nb/ncol, bb, &Common);
    klu_free_symbolic(&Symbolic, &Common);
    klu_free_numeric(&Numeric, &Common);
}
```

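Using KLU comes down to three steps: a symbolic analysis of the sparse matrix `A` with `klu_analyze` (which only needs the sparsity pattern `Ap` and `Ai`, and computes a fill-reducing ordering), a numerical LU factorization with `klu_factor` (which also needs the values `Ax`), and finally the actual solve with `klu_solve`. Since `b` holds all columns of the right-hand side one after another, the number of right-hand sides passed to `klu_solve` is `nb/ncol`. Note that `klu_solve` is an in-place operation on `b`: after solving, the solution `x` is stored in the `b` tensor.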
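SciPy does not wrap KLU, but `scipy.sparse.linalg.splu` (a wrapper around SuperLU, a different sparse LU solver) follows the same factor-once, solve-many pattern, which makes it handy for prototyping or for checking results in Python. A minimal sketch, using the CSC arrays of a 3x3 matrix with the lower-triangular sparsity pattern of the mask used earlier (the values are made up):

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

# CSC arrays (column pointers, row indices, values) of the 3x3 lower-triangular example.
indptr = np.array([0, 2, 3, 4])
indices = np.array([0, 1, 1, 2])
data = np.array([1.0, 2.0, 3.0, 4.0])
A = scipy.sparse.csc_matrix((data, indices, indptr), shape=(3, 3))

# Factor once: SuperLU picks a fill-reducing column ordering and computes the LU factors...
lu = scipy.sparse.linalg.splu(A)

# ...then reuse the factors for a right-hand side with several columns.
b = np.arange(6, dtype=np.float64).reshape(3, 2)
x = lu.solve(b)  # returns a new array; b itself is left untouched
```

Unlike `klu_solve`, `lu.solve` does not overwrite its input, so there is no need to clone `b` first.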

### Updated Forward

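Finally, we can update the forward method to use our `_klu_solve` wrapper instead of `torch::inverse`. Since the indices and values are reshaped into `p` equally sized chunks, this version assumes that `A` is coalesced (so that its entries are sorted by batch index) and that every matrix in the batch has the same number of nonzero elements: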

```c++
at::Tensor solve_forward(at::Tensor A, at::Tensor b) {
    int p = at::size(b, 0);
    int m = at::size(b, 1);
    int n = at::size(b, 2);
    at::Tensor bflat = at::clone(at::reshape(at::transpose(b, 1, 2), {p, m*n}));
    at::Tensor Ax = at::reshape(A._values(), {p, -1});
    at::Tensor Ai = at::reshape(at::_cast_Int(A._indices()[1]), {p, -1});
    at::Tensor Aj = at::reshape(at::_cast_Int(A._indices()[2]), {p, -1});
    for (int i = 0; i < p; i++) {
        std::vector<at::Tensor> Ap_Ai_Ax = _coo_to_csc(m, Ai[i], Aj[i], Ax[i]);
        _klu_solve(Ap_Ai_Ax[0], Ap_Ai_Ax[1], Ap_Ai_Ax[2], bflat[i]); // result will be in bflat
    }
    return at::transpose(bflat.view({p,n,m}), 1, 2);
}
```
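The same logic can be prototyped in Python before touching C++. The sketch below is a hypothetical `batched_sparse_solve` helper that mirrors `solve_forward`, with SciPy's `splu` standing in for KLU. Like the C++ version, it splits the batched COO data into `p` equal chunks and keeps each right-hand side in the column-major layout that `klu_solve` works on:

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

def batched_sparse_solve(indices, values, b):
    """Sketch of `solve_forward`: indices (3, nnz) batch/row/col, values (nnz,), b (p, m, n).

    Assumes the entries are sorted by batch index and every batch element has the same nnz.
    """
    p, m, n = b.shape
    # Same layout trick as in C++: row i of bflat stores b[i] in column-major order,
    # i.e. bflat[i, j*m:(j+1)*m] is the j-th column of b[i].
    bflat = np.swapaxes(b, 1, 2).reshape(p, m * n).copy()
    # Split the batched COO data into p equally sized chunks (one per matrix).
    Ax = values.reshape(p, -1)
    Ai = indices[1].reshape(p, -1)
    Aj = indices[2].reshape(p, -1)
    for i in range(p):
        A_i = scipy.sparse.csc_matrix((Ax[i], (Ai[i], Aj[i])), shape=(m, m))
        B_i = bflat[i].reshape(n, m).T  # (m, n) view of the column-major buffer
        bflat[i] = scipy.sparse.linalg.splu(A_i).solve(B_i).T.reshape(m * n)
    return np.swapaxes(bflat.reshape(p, n, m), 1, 2)

# Batched example: p = 2 matrices with the lower-triangular pattern (4 nonzeros each).
rng = np.random.default_rng(0)
p, m, n = 2, 3, 2
mask = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 1]])
A_dense = mask * rng.uniform(1.0, 2.0, size=(p, m, m))
indices = np.array(np.nonzero(A_dense))  # (3, nnz), sorted by batch, then row, then column
values = A_dense[tuple(indices)]
b = rng.standard_normal((p, m, n))
x = batched_sparse_solve(indices, values, b)
```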

### Updated Setup

The `setup.py` for this extension also becomes a bit more complex, as it now needs to find the SuiteSparse headers and link against the SuiteSparse libraries:

```python
import os
import glob
from setuptools import setup, Extension
from torch.utils import cpp_extension

libroot = os.path.dirname(os.path.dirname(os.__file__))
if os.name == "nt":  # Windows
    suitesparse_lib = os.path.join(libroot, "Library", "lib")
    suitesparse_include = os.path.join(libroot, "Library", "include", "suitesparse")
else:  # Linux / Mac OS
    suitesparse_lib = os.path.join(os.path.dirname(libroot), "lib")
    suitesparse_include = os.path.join(os.path.dirname(libroot), "include")

torch_sparse_solve_cpp = Extension(
    name="torch_sparse_solve_cpp",
    sources=["torch_sparse_solve.cpp"],
    include_dirs=[*cpp_extension.include_paths(), suitesparse_include],
    library_dirs=[*cpp_extension.library_paths(), suitesparse_lib],
    extra_compile_args=[],
    libraries=[
        "c10",
        "torch",
        "torch_cpu",
        "torch_python",
        "klu",
        "btf",
        "amd",
        "colamd",
        "suitesparseconfig",
    ],
    language="c++",
)

setup(
    name="torch_sparse_solve",
    ext_modules=[torch_sparse_solve_cpp],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```

### Updated python wrapper

We can now update the Python wrapper to use the newly built C++ extension:

In [None]:
from torch_sparse_solve_cpp import solve_forward, solve_backward

class Solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]) or not A.is_sparse:
            raise ValueError(
                "'A' should be a batch of square 2D sparse matrices with shape (b, m, m)."
            )
        if b.ndim != 3:
            raise ValueError("'b' should be a batch of matrices with shape (b, m, n).")
        if not A.dtype == torch.float64:
            raise ValueError("'A' should be a sparse float64 tensor (for now). Please first convert to float64.")
        if not b.dtype == torch.float64:
            raise ValueError("'b' should be a float64 tensor (for now). Please first convert to float64")

        x = solve_forward(A, b)
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradA, gradb = solve_backward(grad, A, b, x)
        return gradA, gradb
    
solve = Solve.apply

In [None]:
mask = torch.tensor([[[1, 0, 0], [1, 1, 0], [0, 0, 1]]], dtype=torch.float64)
A = mask * torch.randn(4, 3, 3, dtype=torch.float64)
Asp = A.to_sparse()
Asp.requires_grad_()
b = torch.randn(4, 3, 2, dtype=torch.float64, requires_grad=True)
gradcheck(solve, [Asp, b], check_sparse_nnz=True) 

## That's it!

Those were the steps I went through creating my first PyTorch C++ extension. Please check it out on GitHub: https://github.com/flaport/torch_sparse_solve and consider giving it a star 😉