Add JAX API that provides sparse matmul support (2:4 structured sparsity)

Usage:

    from jax.experimental.sparse import nm
    res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:

    lhs.shape = [M, K/2]
    rhs.shape = [K, N]
    `mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be
set to:

    ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on NVIDIA GPUs that provide hardware support for
sparse dots.

PiperOrigin-RevId: 627640553
Commit aebe82a (parent b5fdc0d). Showing 3 changed files with 458 additions and 0 deletions.
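To make the usage above concrete, here is a minimal sketch (not part of the commit; the shapes M=4, K=16, N=8 and the values are made up for illustration). It assumes a CUDA GPU with hardware 2:4 sparse-dot support, since that is the only backend with a lowering. Note that the shape checks in the new file expect the boolean mask to cover the full dense contracting dimension K, i.e. its last dimension is twice that of the packed lhs.

    import jax.numpy as jnp
    from jax.experimental.sparse import nm

    M, K, N = 4, 16, 8

    # 2:4 mask over the dense [M, K] operand: exactly 2 of every 4 consecutive
    # elements along K are kept.
    mask = jnp.tile(jnp.array([True, True, False, False]), (M, K // 4))

    # The packed lhs stores only the kept values, so it has K // 2 columns.
    lhs = jnp.arange(M * K // 2).reshape(M, K // 2).astype(jnp.bfloat16)
    rhs = jnp.ones((K, N), dtype=jnp.bfloat16)

    # Pack the mask into uint16 metadata of shape [M, K // 16], then run the
    # sparse dot. The result has shape [M, N] and dtype bfloat16 by default.
    meta = nm.nm_pack(mask)
    out = nm.nm_spmm(lhs, rhs, meta)

    # With a leading batch dimension on both operands, one would pass e.g.
    # dimension_numbers=(((2,), (1,)), ((0,), (0,))) explicitly.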
@@ -0,0 +1,241 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""N:M-sparsity associated primitives."""

from jax import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.typing import Array, DTypeLike
from jax.interpreters import mlir
import jax.numpy as jnp
import numpy as np

# --------------------------------------------------------------------
# nm_spmm

nm_spmm_p = core.Primitive("sparse_dense_matmul")

_supported_input_types = (jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16)
_supported_output_types = (jnp.bfloat16, jnp.float32)


def nm_spmm(
    lhs: Array,
    rhs: Array,
    metadata: Array,
    dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())),
    sparse_operand_idx: int = 0,
    output_dtype: DTypeLike = jnp.bfloat16,
) -> Array:
  """Dot operation where one of the operands has N:M sparsity.

  Args:
    lhs: An ndarray (first dot operand).
    rhs: An ndarray (second dot operand).
    metadata: An ndarray with structured sparsity metadata for the contracting
      dimension. For 2:4 sparsity it should contain (N=2) two-bit index values
      for each (M=4) element group.
    dimension_numbers: a tuple of tuples of the form `((lhs_contracting_dims,
      rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
    sparse_operand_idx: index of the sparse operand (0 or 1).
    output_dtype: result type.

  Returns:
    A dense ndarray containing the result.
""" | ||
return nm_spmm_p.bind( | ||
lhs, | ||
rhs, | ||
metadata, | ||
dimension_numbers=dimension_numbers, | ||
sparse_operand_idx=sparse_operand_idx, | ||
output_dtype=output_dtype, | ||
) | ||
|
||
|
||
def _calc_groups_per_element(n, m): | ||
group_bits = n * (m.bit_length() - 1) # 4 bits per group for 2:4 | ||
return 16 // group_bits | ||
|
||
|
||
def _validate_dnums(rank, contract, batch, name): | ||
non_contract = tuple(sorted(set(range(rank)) - set(contract + batch))) | ||
if sorted(non_contract + contract + batch) != list(range(rank)): | ||
raise TypeError(f"Incorrect dimension numbers for {name}") | ||
return non_contract | ||
|
||
|
||
def _validate_metadata(lhs, rhs, metadata, dimension_numbers, index, n=2, m=4): | ||
assert index in (0, 1) | ||
size_factor = n * _calc_groups_per_element(n, m) | ||
|
||
sparse = [lhs, rhs][index] | ||
sparse_contract = dimension_numbers[0][index] | ||
if metadata.dtype != np.uint16: | ||
raise TypeError(f"Metadata must be uint16, got {metadata.dtype}") | ||
if sparse_contract[0] != sparse.ndim - 1: | ||
raise TypeError("Contracting dimension must be the minor one") | ||
  if metadata.shape[:-1] != sparse.shape[:-1]:
    raise TypeError(
        "Metadata shape must match the operand shape (except for the"
        " contracting dimension)"
    )
  if metadata.shape[-1] * size_factor != sparse.shape[-1]:
    raise TypeError(
        f"Metadata must be exactly {size_factor} times less than the"
        f" contracting dimension for {n}:{m} structured sparsity (expected"
        f" {sparse.shape[-1] // size_factor}, got {metadata.shape[-1]})"
    )
  if sparse.shape[-1] % size_factor != 0:
    raise NotImplementedError("Metadata with padding is not supported")

  dense = [lhs, rhs][1 - index]
  dense_contract = dimension_numbers[0][1 - index]
  a, b = sparse.shape[sparse_contract[0]], dense.shape[dense_contract[0]]
  if n * b != m * a:
    raise TypeError(
        f"Contracting dimension sizes should have {n}:{m} ratio, got {a}:{b}"
    )


def _infer_result_shape(lhs, rhs, dimension_numbers):
  ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
  if len(lhs_contract) != 1 or len(rhs_contract) != 1:
    raise TypeError("Only single contracting dimension is supported")
  lhs_dims = _validate_dnums(lhs.ndim, lhs_contract, lhs_batch, "lhs")
  rhs_dims = _validate_dnums(rhs.ndim, rhs_contract, rhs_batch, "rhs")
  if len(lhs_dims) != 1 or len(rhs_dims) != 1:
    raise TypeError("Only single non-contracting dimension is supported")
  batch = [lhs.shape[i] for i in lhs_batch]
  if batch != [rhs.shape[i] for i in rhs_batch]:
    raise TypeError("Batch dimension sizes do not match")
  return tuple(batch + [lhs.shape[lhs_dims[0]], rhs.shape[rhs_dims[0]]])


def _nm_spmm_default_lowering(*_args, **_kwargs):
  raise NotImplementedError("Sparse N:M matmul is only implemented on GPU")


def _nm_spmm_gpu_lowering(
    ctx,
    lhs,
    rhs,
    metadata,
    *,
    dimension_numbers,
    sparse_operand_idx,
    output_dtype,
):
  assert sparse_operand_idx in (0, 1)
  sparsity_descriptor = mhlo.SparsityDescriptor.get(
      dimension=dimension_numbers[0][sparse_operand_idx][0], n=2, m=4
  )
  dot_dnums = mhlo.DotDimensionNumbers.get(
      lhs_batching_dimensions=dimension_numbers[1][sparse_operand_idx],
      rhs_batching_dimensions=dimension_numbers[1][1 - sparse_operand_idx],
      lhs_contracting_dimensions=dimension_numbers[0][sparse_operand_idx],
      rhs_contracting_dimensions=dimension_numbers[0][1 - sparse_operand_idx],
  )
  dot_type = ctx.avals_out[0]
  key = ["lhs_sparsity", "rhs_sparsity"][sparse_operand_idx]
  kwargs = {key: sparsity_descriptor}
  op = mhlo.SparseDotOp(
      mlir.aval_to_ir_type(dot_type), lhs, rhs, [metadata], dot_dnums, **kwargs
  )
  return op.results


@nm_spmm_p.def_abstract_eval
def _nm_spmm_abstract_eval(
    lhs, rhs, metadata, *, dimension_numbers, sparse_operand_idx, output_dtype
):
  if lhs.dtype not in _supported_input_types:
    raise TypeError(f"Unsupported lhs input type: {lhs.dtype}")
  if rhs.dtype not in _supported_input_types:
    raise TypeError(f"Unsupported rhs input type: {rhs.dtype}")
  if output_dtype not in _supported_output_types:
    raise TypeError(f"Unsupported output type: {output_dtype}")

  res_shape = _infer_result_shape(lhs, rhs, dimension_numbers)
  _validate_metadata(lhs, rhs, metadata, dimension_numbers, sparse_operand_idx)
  return core.ShapedArray(res_shape, output_dtype)


mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering)
dispatch.simple_impl(nm_spmm_p)

if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda")

# --------------------------------------------------------------------
# nm_pack

nm_pack_p = core.Primitive("sparse_pack_nm")


def nm_pack(mask: Array, n=2, m=4) -> Array:
  """Generate metadata tensor for an N:M mask.

  Args:
    mask: Predicates for the input tensor, where the elements are grouped in
      the minor dimension. In each group of size M there should be exactly N
      true values, which mark the data elements to keep.
    n: Number of non-zero elements in a group.
    m: Group size.

  Returns:
    An ndarray with the packed sparsity metadata, where each uint16 element
    encodes the two-bit in-group indices of the kept values.
""" | ||
return nm_pack_p.bind(mask, n=n, m=m) | ||
|
||
|
||
def _compress(data, n, m, k): | ||
result = [] | ||
expected = n * (k // m) | ||
for i in range(0, len(data), k): | ||
index = tuple(jnp.nonzero(data[i : i + k], size=expected)[0] % m) | ||
value = sum(j * pow(m, i) for i, j in enumerate(index)) | ||
result.append(value) | ||
  return jnp.array(result, dtype=np.uint16)


@nm_pack_p.def_impl
def _nm_pack_impl(mask, *, n, m):
  batch_size = m * _calc_groups_per_element(n, m)
  return jnp.apply_along_axis(
      lambda x: _compress(x, n, m, batch_size), -1, mask
  )


@nm_pack_p.def_abstract_eval
def _nm_pack_abstract_eval(mask, *, n, m):
  size_factor = m * _calc_groups_per_element(n, m)
  if mask.dtype != bool:
    raise TypeError(f"Mask should be bool, got {mask.dtype}")
  if mask.shape[-1] % size_factor != 0:
    raise TypeError(
        f"Inner dimension size should be divisible by {size_factor}, got"
        f" {mask.shape}"
    )
  res_shape = list(mask.shape)
  res_shape[-1] //= size_factor
  return core.ShapedArray(res_shape, np.uint16)


_nm_pack_lowering = mlir.lower_fun(_nm_pack_impl, multiple_results=False)
mlir.register_lowering(nm_pack_p, _nm_pack_lowering)
dispatch.simple_impl(nm_pack_p)
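For reference, a small sketch of the metadata layout produced by nm_pack (again not part of the commit): with 2:4 sparsity each uint16 element covers 16 mask entries, storing the in-group positions of the 8 kept values as two-bit base-4 digits, least significant first.

    import jax.numpy as jnp
    from jax.experimental.sparse import nm

    # Keep positions 0 and 1 in every group of 4; 16 mask entries fill one
    # uint16 metadata element.
    mask = jnp.tile(jnp.array([True, True, False, False]), 4)
    meta = nm.nm_pack(mask)

    # Kept in-group positions are [0, 1, 0, 1, 0, 1, 0, 1]; encoded in base 4
    # (least significant digit first) this is 1*4 + 1*64 + 1*1024 + 1*16384.
    print(meta)  # [17476], dtype uint16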