Skip to content

Commit

Permalink
[ROCm]: Lower some sparse ops correctly for ROCm
Browse files Browse the repository at this point in the history
	- Lower coo_spmv, coo_spmm, csr_spmv and csr_spmm through the
	gpu_sparse.rocm_* kernels on ROCm instead of the gpu_sparse.cuda_* ones
  • Loading branch information
Rahul Batra committed Dec 18, 2023
1 parent 29ed3cd commit d7b2590
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions jax/experimental/sparse/_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
are used internally in GPU translation rules of higher-level primitives.
"""

from functools import partial

from jax import core
from jax._src import dispatch
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -52,9 +54,9 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape):
shape=shape[1:] if transpose else shape[:1],
dtype=x.dtype)

def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape):
data_aval, row_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_coo_matvec(
return [coo_spmv_hlo(
data, row, col, x,
shape=shape,
transpose=transpose,
Expand All @@ -65,9 +67,15 @@ def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval)
dispatch.simple_impl(coo_spmv_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='cuda')
mlir.register_lowering(
coo_spmv_p,
partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='rocm')
mlir.register_lowering(
coo_spmv_p,
partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec),
platform='rocm')


# coo_spmm_p
Expand Down Expand Up @@ -95,9 +103,9 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape):
shape=(shape[1] if transpose else shape[0], x.shape[1]),
dtype=x.dtype)

def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape):
data_aval, row_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_coo_matmat(
return [coo_spmm_hlo(
data, row, col, x,
shape=shape,
transpose=transpose,
Expand All @@ -108,9 +116,15 @@ def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval)
dispatch.simple_impl(coo_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='cuda')
mlir.register_lowering(
coo_spmm_p,
partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='rocm')
mlir.register_lowering(
coo_spmm_p,
partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat),
platform='rocm')

# csr_spmv_p
# This is an internal-only primitive that calls into cusparse CSR SpMV.
Expand All @@ -137,9 +151,9 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape):
shape=shape[1:] if transpose else shape[:1],
dtype=x.dtype)

def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
data_aval, indices_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_csr_matvec(
return [csr_spmv_hlo(
data, indices, indptr, x,
shape=shape,
transpose=transpose,
Expand All @@ -150,12 +164,17 @@ def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval)
dispatch.simple_impl(csr_spmv_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='cuda')
mlir.register_lowering(
csr_spmv_p,
partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='rocm')

mlir.register_lowering(
csr_spmv_p,
partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec),
platform='rocm')

# csr_spmm_p
# csr_spmm_p
# This is an internal-only primitive that calls into cusparse CSR SpMM.
# This is a raw lowering that does no validation of inputs; the indices are
# assumed to be lexicographically sorted, deduplicated, and in-bounds.
Expand All @@ -180,9 +199,9 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape):
shape=(shape[1] if transpose else shape[0], x.shape[1]),
dtype=x.dtype)

def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
data_aval, indices_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_csr_matmat(
return [csr_spmm_hlo(
data, indices, indptr, x,
shape=shape,
transpose=transpose,
Expand All @@ -193,6 +212,12 @@ def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval)
dispatch.simple_impl(csr_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='cuda')
mlir.register_lowering(
csr_spmm_p,
partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='rocm')
mlir.register_lowering(
csr_spmm_p,
partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat),
platform='rocm')

0 comments on commit d7b2590

Please sign in to comment.