From d7b25908057a04587d6407baa2a5ba78357dc290 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Wed, 13 Dec 2023 22:08:36 +0000 Subject: [PATCH] [ROCm]: Lower sparse(some) ops correctly for ROCm -Lower coo_spmv, coo_spmm, csr_spmv and csr_spmm correctly for ROCm --- jax/experimental/sparse/_lowerings.py | 61 +++++++++++++++++++-------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 832cdc44407b..f4fe0b9040e6 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -17,6 +17,8 @@ are used internally in GPU translation rules of higher-level primitives. """ +from functools import partial + from jax import core from jax._src import dispatch from jax._src.interpreters import mlir @@ -52,9 +54,9 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): data_aval, row_aval, _, x_aval = ctx.avals_in - return [gpu_sparse.cuda_coo_matvec( + return [coo_spmv_hlo( data, row, col, x, shape=shape, transpose=transpose, @@ -65,9 +67,15 @@ def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape): coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) if gpu_sparse.cuda_is_supported: - mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='cuda') + mlir.register_lowering( + coo_spmv_p, + partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), + platform='cuda') if gpu_sparse.rocm_is_supported: - mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='rocm') + mlir.register_lowering( + coo_spmv_p, + partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), + platform='rocm') # coo_spmm_p @@ -95,9 +103,9 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): data_aval, row_aval, _, x_aval = ctx.avals_in - return [gpu_sparse.cuda_coo_matmat( + return [coo_spmm_hlo( data, row, col, x, shape=shape, transpose=transpose, @@ -108,9 +116,15 @@ def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape): coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval) dispatch.simple_impl(coo_spmm_p) if gpu_sparse.cuda_is_supported: - mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='cuda') + mlir.register_lowering( + coo_spmm_p, + partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat), + platform='cuda') if gpu_sparse.rocm_is_supported: - mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='rocm') + mlir.register_lowering( + coo_spmm_p, + partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat), + platform='rocm') # csr_spmv_p # This is an internal-only primitive that calls into cusparse csr SpMV. @@ -137,9 +151,9 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape): data_aval, indices_aval, _, x_aval = ctx.avals_in - return [gpu_sparse.cuda_csr_matvec( + return [csr_spmv_hlo( data, indices, indptr, x, shape=shape, transpose=transpose, @@ -150,12 +164,17 @@ def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape): csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval) dispatch.simple_impl(csr_spmv_p) if gpu_sparse.cuda_is_supported: - mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='cuda') + mlir.register_lowering( + csr_spmv_p, + partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec), + platform='cuda') if gpu_sparse.rocm_is_supported: - mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='rocm') - + mlir.register_lowering( + csr_spmv_p, + partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec), + platform='rocm') -# csr_spmm_p + # csr_spmm_p # This is an internal-only primitive that calls into cusparse CSR SpMM. # This is a raw lowering that does no validation of inputs; the indices are # assumed to be lexicographically sorted, deduplicated, and in-bounds. @@ -180,9 +199,9 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape): data_aval, indices_aval, _, x_aval = ctx.avals_in - return [gpu_sparse.cuda_csr_matmat( + return [csr_spmm_hlo( data, indices, indptr, x, shape=shape, transpose=transpose, @@ -193,6 +212,12 @@ def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape): csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval) dispatch.simple_impl(csr_spmm_p) if gpu_sparse.cuda_is_supported: - mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='cuda') + mlir.register_lowering( + csr_spmm_p, + partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat), + platform='cuda') if gpu_sparse.rocm_is_supported: - mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='rocm') + mlir.register_lowering( + csr_spmm_p, + partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat), + platform='rocm')