Skip to content

Commit

Permalink
Revert: Sparse direct solver using QR factorization from cuSOLVER. Th…
Browse files Browse the repository at this point in the history
…is is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.

Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
  • Loading branch information
hawkinsp authored and jax authors committed Aug 18, 2022
1 parent 2bc3e39 commit 3bb0030
Show file tree
Hide file tree
Showing 7 changed files with 0 additions and 257 deletions.
52 changes: 0 additions & 52 deletions jax/experimental/sparse/linalg.py
Expand Up @@ -20,12 +20,6 @@
import jax
import jax.numpy as jnp

from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla

from jax._src.lib import gpu_solver

import numpy as np

def lobpcg_standard(
Expand Down Expand Up @@ -507,49 +501,3 @@ def _extend_basis(X, m):
h = -2 * jnp.linalg.multi_dot(
[w, w[k:, :].T, other], precision=jax.lax.Precision.HIGHEST)
return h.at[k:].add(other)


# Sparse direct solve via QR factorization


def _spsolve_abstract_eval(data, indices, indptr, b, tol, reorder):
del data, indices, indptr, tol, reorder
return core.raise_to_shaped(b)


def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
data_aval, _, _, _, = ctx.avals_in

return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
indptr, b, tol, reorder)


spsolve_p = core.Primitive('spsolve')
spsolve_p.def_impl(functools.partial(xla.apply_primitive, spsolve_p))
spsolve_p.def_abstract_eval(_spsolve_abstract_eval)
mlir.register_lowering(spsolve_p, _spsolve_gpu_lowering, platform='cuda')


def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1):
"""A sparse direct solver using QR factorization.
Accepts a sparse matrix in CSR format `data, indices, indptr` arrays.
Currently only the CUDA GPU backend is implemented.
Args:
data : An array containing the non-zero entries of the CSR matrix.
indices : The column indices of the CSR matrix.
indptr : The row pointer array of the CSR matrix.
b : The right hand side of the linear system.
tol : Tolerance to decide if singular or not. Defaults to 1e-6.
reorder : The reordering scheme to use to reduce fill-in. No reordering if
`reorder=0'. Otherwise, symrcm, symamd, or csrmetisnd (`reorder=1,2,3'),
respectively. Defaults to symrcm.
Returns:
An array with the same dtype and size as b representing the solution to
the sparse linear system.
"""
if jax._src.lib.xla_extension_version < 86:
raise ValueError('spsolve requires jaxlib version 86 or above.')
return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)
1 change: 0 additions & 1 deletion jaxlib/cuda/cuda_gpu_kernels.cc
Expand Up @@ -37,7 +37,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32,
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
Expand Down
15 changes: 0 additions & 15 deletions jaxlib/cuda/cusolver.cc
Expand Up @@ -182,19 +182,6 @@ std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
}

// csrlsvqr: Linear system solve via Sparse QR

// Returns a descriptor for a csrlsvqr operation.
py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
int reorder, double tol) {
CusolverType type = DtypeToCusolverType(dtype);
auto h = SpSolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;

return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
}

// orgqr/ungqr: apply elementary Householder transformations

// Returns the workspace size and a descriptor for a geqrf operation.
Expand Down Expand Up @@ -476,7 +463,6 @@ py::dict Registrations() {
dict["cusolver_potrf"] = EncapsulateFunction(Potrf);
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
dict["cusolver_geqrf"] = EncapsulateFunction(Geqrf);
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
dict["cusolver_orgqr"] = EncapsulateFunction(Orgqr);
dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
dict["cusolver_syevj"] = EncapsulateFunction(Syevj);
Expand All @@ -490,7 +476,6 @@ PYBIND11_MODULE(_cusolver, m) {
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
Expand Down
116 changes: 0 additions & 116 deletions jaxlib/cuda/cusolver_kernels.cc
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusolverSp.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
Expand All @@ -53,24 +52,6 @@ template <>
return Handle(pool, handle, stream);
}

template <>
/*static*/ absl::StatusOr<SpSolverHandlePool::Handle>
SpSolverHandlePool::Borrow(cudaStream_t stream) {
SpSolverHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusolverSpHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}

static int SizeOfCusolverType(CusolverType type) {
switch (type) {
case CusolverType::F32:
Expand Down Expand Up @@ -351,103 +332,6 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
}
}

// csrlsvqr: Linear system solve via Sparse QR

static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
int& singularity) {
auto s = UnpackDescriptor<CsrlsvqrDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CsrlsvqrDescriptor& d = **s;

// This is the handle to the CUDA session. Gets a cusolverSp handle.
auto h = SpSolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;

cusparseMatDescr_t matdesc = nullptr;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateMatDescr(&matdesc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO)));

switch (d.type) {
case CusolverType::F32: {
float* csrValA = static_cast<float*>(buffers[0]);
int* csrRowPtrA = static_cast<int*>(buffers[1]);
int* csrColIndA = static_cast<int*>(buffers[2]);
float* b = static_cast<float*>(buffers[3]);
float* x = static_cast<float*>(buffers[4]);

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpScsrlsvqr(
handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA,
b, (float)d.tol, d.reorder, x, &singularity)));

break;
}
case CusolverType::F64: {
double* csrValA = static_cast<double*>(buffers[0]);
int* csrRowPtrA = static_cast<int*>(buffers[1]);
int* csrColIndA = static_cast<int*>(buffers[2]);
double* b = static_cast<double*>(buffers[3]);
double* x = static_cast<double*>(buffers[4]);

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpDcsrlsvqr(
handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA,
b, d.tol, d.reorder, x, &singularity)));

break;
}
case CusolverType::C64: {
cuComplex* csrValA = static_cast<cuComplex*>(buffers[0]);
int* csrRowPtrA = static_cast<int*>(buffers[1]);
int* csrColIndA = static_cast<int*>(buffers[2]);
cuComplex* b = static_cast<cuComplex*>(buffers[3]);
cuComplex* x = static_cast<cuComplex*>(buffers[4]);

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCcsrlsvqr(
handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA,
b, (float)d.tol, d.reorder, x, &singularity)));

break;
}
case CusolverType::C128: {
cuDoubleComplex* csrValA = static_cast<cuDoubleComplex*>(buffers[0]);
int* csrRowPtrA = static_cast<int*>(buffers[1]);
int* csrColIndA = static_cast<int*>(buffers[2]);
cuDoubleComplex* b = static_cast<cuDoubleComplex*>(buffers[3]);
cuDoubleComplex* x = static_cast<cuDoubleComplex*>(buffers[4]);

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpZcsrlsvqr(
handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA,
b, (float)d.tol, d.reorder, x, &singularity)));

break;
}
}

cusparseDestroyMatDescr(matdesc);
return absl::OkStatus();
}

void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
// Is >= 0 if A is singular.
int singularity = -1;

auto s = Csrlsvqr_(stream, buffers, opaque, opaque_len, singularity);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}

if (singularity >= 0) {
auto s = std::string("Singular matrix in linear solve.");
XlaCustomCallStatusSetFailure(status, s.c_str(), s.length());
}
}

// orgqr/ungqr: apply elementary Householder transformations

static absl::Status Orgqr_(cudaStream_t stream, void** buffers,
Expand Down
17 changes: 0 additions & 17 deletions jaxlib/cuda/cusolver_kernels.h
Expand Up @@ -20,23 +20,17 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusolverSp.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"

namespace jax {

using SolverHandlePool = HandlePool<cusolverDnHandle_t, cudaStream_t>;
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, cudaStream_t>;

template <>
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
cudaStream_t stream);

template <>
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
cudaStream_t stream);

// Set of types known to Cusolver.
enum class CusolverType {
F32,
Expand Down Expand Up @@ -76,17 +70,6 @@ struct GeqrfDescriptor {
void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);

// csrlsvpr: Linear system solve via Sparse QR

struct CsrlsvqrDescriptor {
CusolverType type;
int n, nnz, reorder;
double tol;
};

void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);

// orgqr/ungqr: apply elementary Householder transformations

struct OrgqrDescriptor {
Expand Down
23 changes: 0 additions & 23 deletions jaxlib/gpu_solver.py
Expand Up @@ -259,29 +259,6 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_mhlo, "hip", _hipblas)

def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
data_type = ir.RankedTensorType(data.type)

n = b_type.shape[0]
nnz = data_type.shape[0]
opaque = gpu_solver.build_csrlsvqr_descriptor(
np.dtype(dtype), n, nnz, reorder, tol
)

out = custom_call(
f"{platform}solver_csrlsvqr", # call_target_name
[b.type], # out_types
[data, indptr, indices, b], # operands
backend_config=opaque, # backend_config
operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts
result_layouts=[(0,)] # result_layouts
)
return [out]

cuda_csrlsvqr = partial(_csrlsvqr_mhlo, "cu", _cusolver)

def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
Expand Down
33 changes: 0 additions & 33 deletions tests/sparse_test.py
Expand Up @@ -2319,38 +2319,5 @@ def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense):
self.assertAlmostEqual(int(num_nonzero), approx_expected_num_nonzero, delta=2)


class SparseSolverTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_re{}_({})".format(reorder,
jtu.format_shape_dtype_string((size, size), dtype)),
"size": size, "reorder": reorder, "dtype": dtype}
for size in [20, 50, 100]
for reorder in [0, 1, 2, 3]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@unittest.skipIf(jax._src.lib.xla_extension_version < 86, "test requires jaxlib version 86")
@jtu.skip_on_devices("rocm")
def test_sparse_qr_linear_solver(self, size, reorder, dtype):
rng = rand_sparse(self.rng())
a = rng((size, size), dtype)
nse = (a != 0).sum()
data, indices, indptr = sparse.csr_fromdense(a, nse=nse)

rng_k = jtu.rand_default(self.rng())
b = rng_k([size], dtype)

def args_maker():
return data, indices, indptr, b

tol = 1e-8
def sparse_solve(data, indices, indptr, b):
return sparse.linalg.spsolve(data, indices, indptr, b, tol, reorder)
x = sparse_solve(data, indices, indptr, b)

self.assertAllClose(a @ x, b, rtol=1e-2)
self._CompileAndCheck(sparse_solve, args_maker)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3bb0030

Please sign in to comment.