
Commit bb08162
Add a batched QR decomposition implementation on GPU.
PiperOrigin-RevId: 449583027
hawkinsp authored and jax authors committed May 18, 2022
1 parent 6110be4 commit bb08162
Showing 10 changed files with 291 additions and 25 deletions.
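
This change adds no new user-facing API: it accelerates the existing QR path for stacks of small matrices. As a rough sketch of the workload it targets (shapes chosen purely for illustration):

# Illustration only: a stack of many small matrices is the case this
# commit routes to cuBLAS's batched QR kernels on GPU.
import jax.numpy as jnp
from jax import random

a = random.normal(random.PRNGKey(0), (1000, 16, 16))  # 1000 tiny matrices
q, r = jnp.linalg.qr(a)   # lowers through the geqrf_p primitive changed below
print(q.shape, r.shape)   # (1000, 16, 16) (1000, 16, 16)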
47 changes: 28 additions & 19 deletions jax/_src/lax/linalg.py
@@ -1300,26 +1300,31 @@ def _geqrf_batching_rule(batched_args, batch_dims):
 def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
   return xops.QrDecomposition(operand)
 
-def _geqrf_cpu_gpu_lowering(geqrf_impl, ctx, a):
+def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
   a_aval, taus_aval = ctx.avals_out
   *batch_dims, m, n = a_aval.shape
+  batch = prod(batch_dims)
 
-  if m == 0 or n == 0:
+  if batch == 0 or m == 0 or n == 0:
     return mlir.full_like_aval(0, a_aval), mlir.full_like_aval(0, taus_aval)
 
-  a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
-  zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
-  ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
-  ok_a = mhlo.BroadcastInDimOp(
-      ir.RankedTensorType.get((*batch_dims, 1, 1),
-                              ir.IntegerType.get_signless(1)),
-      ok, mlir.dense_int_elements(range(len(batch_dims)))).result
-  a_out = _broadcasting_select_mhlo(ok_a, a_out, _nan_like_mhlo(a_aval))
-  ok_taus = mhlo.BroadcastInDimOp(
-      ir.RankedTensorType.get((*batch_dims, 1,),
-                              ir.IntegerType.get_signless(1)),
-      ok, mlir.dense_int_elements(range(len(batch_dims)))).result
-  taus = _broadcasting_select_mhlo(ok_taus, taus, _nan_like_mhlo(taus_aval))
+  if (batched_geqrf_impl is not None and batch > 1 and m // batch <= 128 and
+      n // batch <= 128):
+    a_out, taus = batched_geqrf_impl(a_aval.dtype, a)
+  else:
+    a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
+    zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
+    ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
+    ok_a = mhlo.BroadcastInDimOp(
+        ir.RankedTensorType.get((*batch_dims, 1, 1),
+                                ir.IntegerType.get_signless(1)),
+        ok, mlir.dense_int_elements(range(len(batch_dims)))).result
+    a_out = _broadcasting_select_mhlo(ok_a, a_out, _nan_like_mhlo(a_aval))
+    ok_taus = mhlo.BroadcastInDimOp(
+        ir.RankedTensorType.get((*batch_dims, 1,),
+                                ir.IntegerType.get_signless(1)),
+        ok, mlir.dense_int_elements(range(len(batch_dims)))).result
+    taus = _broadcasting_select_mhlo(ok_taus, taus, _nan_like_mhlo(taus_aval))
   return a_out, taus
 
 geqrf_p = Primitive('geqrf')
@@ -1330,22 +1335,26 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, ctx, a):
 xla.register_translation(geqrf_p, _geqrf_translation_rule)
 
 mlir.register_lowering(
-    geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo),
+    geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None),
     platform='cpu')
 if gpu_solver is not None:
+  # TODO(phawkins): make cuda_geqrf_batched and rocm_geqrf_batched
+  # unconditional when jaxlib 0.3.11 is the minimum.
   mlir.register_lowering(
       geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf),
+      partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
+              getattr(gpu_solver, 'cuda_geqrf_batched', None)),
       platform='cuda')
   mlir.register_lowering(
       geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf),
+      partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
+              getattr(gpu_solver, 'rocm_geqrf_batched', None)),
       platform='rocm')
 
 if solver_apis is not None:
   mlir.register_lowering(
       geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo),
+      partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo, None),
       platform='gpu')


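Note the dispatch heuristic above: the batched kernel is used only when a batched implementation exists and the problem looks like many small matrices; everything else keeps the per-matrix geqrf path with its info-based NaN masking. A minimal sketch of the predicate as written (would_use_batched is a hypothetical helper, not part of the commit):

from math import prod

def would_use_batched(batch_dims, m, n, have_batched_impl=True):
  batch = prod(batch_dims)
  return (have_batched_impl and batch > 1
          and m // batch <= 128 and n // batch <= 128)

print(would_use_batched((1000,), 32, 32))  # True: many small matrices
print(would_use_batched((2,), 512, 512))   # False: stays on the cusolver path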
2 changes: 2 additions & 0 deletions jaxlib/cuda/BUILD
@@ -61,7 +61,9 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cublas_headers",
"@local_config_cuda//cuda:cuda_headers",
10 changes: 10 additions & 0 deletions jaxlib/cuda/cublas.cc
@@ -76,17 +76,27 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}

// Returns the descriptor for a GeqrfBatched operation.
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
int b, int m, int n) {
CublasType type = DtypeToCublasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
}

py::dict Registrations() {
py::dict dict;
dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
}

PYBIND11_MODULE(_cublas, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}

} // namespace
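The Registrations() dict is how these kernels reach XLA: jaxlib's Python side iterates over it and registers each encapsulated function as a custom-call target. Roughly, following the usual jaxlib pattern (the import path here is an assumption for illustration, not part of this diff):

from jaxlib import xla_client
from jaxlib.cuda import _cublas  # assumed module path, for illustration

for _name, _value in _cublas.registrations().items():
  xla_client.register_custom_call_target(_name, _value, platform="CUDA")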
88 changes: 88 additions & 0 deletions jaxlib/cuda/cublas_kernels.cc
@@ -22,6 +22,8 @@ limitations under the License.

#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
@@ -218,4 +220,90 @@ void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
}
}

// Batched QR decomposition: geqrfbatched

static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}

std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfCublasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
auto tau_ptrs_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfCublasType(d.type) * std::min(d.m, d.n));
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case CublasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::C64: {
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
cuComplex** tau_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::C128: {
cuDoubleComplex** a_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[3]);
cuDoubleComplex** tau_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
}
auto it =
std::find_if(info.begin(), info.end(), [](int i) { return i != 0; });

if (it != info.end()) {
return absl::InvalidArgumentError(
absl::StrFormat("QR decomposition failed with status %d for batch "
"element %d",
*it, std::distance(info.begin(), it)));
}

return absl::OkStatus();
}

void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

} // namespace jax
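
For readers unfamiliar with geqrf's convention: each batch element comes back in LAPACK's compact form, with R stored in and above the diagonal, the Householder reflectors below it, and one tau scalar per reflector in the second output. A small SciPy sketch of the same representation (illustrative only; scipy's mode="raw" follows the LAPACK geqrf convention that cuBLAS mirrors):

import numpy as np
from scipy.linalg import qr

a = np.random.default_rng(0).standard_normal((5, 3))
qr_compact, tau = qr(a, mode="raw")  # compact geqrf form of a
r = np.triu(qr_compact[:3])          # R lives in the upper triangle
print(tau.shape)                     # (3,): min(m, n) Householder scalars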
12 changes: 12 additions & 0 deletions jaxlib/cuda/cublas_kernels.h
@@ -58,6 +58,18 @@ struct GetrfBatchedDescriptor {
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);


// Batched QR decomposition: geqrfbatched

struct GeqrfBatchedDescriptor {
CublasType type;
int batch, m, n;
};

void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);


} // namespace jax

#endif // JAXLIB_CUBLAS_KERNELS_H_
36 changes: 36 additions & 0 deletions jaxlib/gpu_solver.py
@@ -223,6 +223,42 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_mhlo, "hip", _hipsolver)

def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
"""Batched QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)

lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)

layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = custom_call(
f"{platform}blas_geqrf_batched",
[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
],
[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
[0],
[0],
])
return out[:2]

cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_mhlo, "hip", _hipblas)


def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
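One design note on _geqrf_batched_mhlo above: cuBLAS expects column-major matrices, so the custom call declares Fortran order for the trailing two dimensions while keeping the batch dimensions major. Evaluating the layout expressions for a hypothetical (8, 16, 16) operand (num_bd = 1):

num_bd = 1  # one batch dimension, e.g. an (8, 16, 16) operand
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
print(layout)                        # (1, 2, 0): minor-to-major; the row index
                                     # varies fastest (column-major matrices)
print(tuple(range(num_bd, -1, -1)))  # (1, 0): default row-major layout for taus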
2 changes: 2 additions & 0 deletions jaxlib/rocm/BUILD.bazel
@@ -57,7 +57,9 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:rocm_headers",
10 changes: 10 additions & 0 deletions jaxlib/rocm/hipblas.cc
@@ -76,17 +76,27 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}

// Returns the descriptor for a GeqrfBatched operation.
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
int b, int m, int n) {
HipblasType type = DtypeToHipblasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
}

py::dict Registrations() {
py::dict dict;
dict["hipblas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
}

PYBIND11_MODULE(_hipblas, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}

} // namespace