Merge pull request #20940 from piotrfilipiuk:changelist/623910451
PiperOrigin-RevId: 633170419
jax authors committed May 13, 2024
2 parents 1c6855a + 93dfe05 commit 1fed784
Showing 7 changed files with 173 additions and 0 deletions.
83 changes: 83 additions & 0 deletions jax/_src/lax/lax.py
@@ -784,6 +784,33 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type)


def ragged_dot(
lhs: Array,
rhs: Array,
group_sizes: Array,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
) -> Array:
"""Ragged matrix multiplication.
Args:
lhs: (m, k) shaped array.
rhs: (g, k, n) shaped array.
group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group.
precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0].
Results:
(m, n) shaped array with preferred_element_type element type.
"""
return ragged_dot_p.bind(lhs, rhs, group_sizes,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type, group_offset=group_offset)

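# Illustrative usage sketch (an assumption for exposition, not part of this change):
# group_sizes partitions the m rows of lhs into g contiguous blocks, and the rows
# of block i are multiplied by rhs[i].
import jax
import jax.numpy as jnp

lhs = jnp.arange(5 * 4, dtype=jnp.float32).reshape(5, 4)         # (m=5, k=4)
rhs = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)  # (g=2, k=4, n=3)
group_sizes = jnp.array([3, 2], dtype=jnp.int32)                 # sizes sum to m
out = jax.lax.ragged_dot(lhs, rhs, group_sizes)                  # shape (5, 3)
# Rows 0..2 use rhs[0] and rows 3..4 use rhs[1], i.e.
# out[:3] == lhs[:3] @ rhs[0] and out[3:] == lhs[3:] @ rhs[1].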

def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
"""Broadcasts an array, adding new leading dimensions
@@ -2946,6 +2973,62 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
platform=platform)


def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
m, k = lhs.shape
group_count, rk, n = rhs.shape
if k != rk:
raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk))
num_groups = group_sizes.shape[0]
if group_count != num_groups:
raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups))
return (m, n)

# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))

def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type)

ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
_ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))

def _ragged_dot_impl(
lhs: Array,
rhs: Array,
group_sizes: Array,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
) -> Array:
if group_offset is not None:
raise NotImplementedError("Unimplemented group_offset support.")
shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1])
lhs = broadcast_in_dim(lhs, shape, [1, 2])
iota = broadcasted_iota(group_sizes.dtype, shape, 1)
group_ends = jax.lax.cumsum(group_sizes)
group_starts = concatenate(
[_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0,
)
group_ends = broadcast_in_dim(group_ends, shape, (0,))
group_starts = broadcast_in_dim(group_starts, shape, (0,))
mask = bitwise_and(group_starts <= iota, iota < group_ends)
lhs = select(mask, lhs, _zeros(lhs))
return dot_general(
lhs,
rhs,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision,
preferred_element_type=preferred_element_type,
)

mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))

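# Equivalence sketch (illustrative assumption, not part of this change): with
# group_offset=None, the masked broadcast in _ragged_dot_impl computes the same
# result as looping over groups and applying an ordinary matmul to each
# contiguous row block.
import numpy as np

def ragged_dot_loop(lhs, rhs, group_sizes):
  """Per-group reference: out[start:end] = lhs[start:end] @ rhs[i]."""
  out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
  start = 0
  for i, size in enumerate(group_sizes):
    out[start:start + size] = lhs[start:start + size] @ rhs[i]
    start += size
  return out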

def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
_check_shapelike('broadcast_in_dim', 'shape', shape)
_check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
29 changes: 29 additions & 0 deletions jax/_src/lax_reference.py
@@ -239,6 +239,35 @@ def dot_general(lhs, rhs, dimension_numbers):
dtype=dtype)
return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out

def ragged_dot(
lhs,
rhs,
group_sizes,
):
"""Reference ragged dot implementation."""
m, lk = lhs.shape
group_count, rk, n = rhs.shape
assert lk == rk
assert group_count == group_sizes.shape[0]
assert lhs.dtype == rhs.dtype

out = np.zeros((m, n), dtype=lhs.dtype)
result_iota = np.expand_dims(np.arange(out.shape[0]), list(range(1, out.ndim)))
start = 0
for i, size in enumerate(group_sizes):
out += np.where(
np.logical_and(start <= result_iota, result_iota < (start + size)),
np.einsum(
"nk,km->nm",
lhs,
rhs[i, :, :],
dtype=np.float32 if lhs.dtype == dtypes.bfloat16 else None,
),
np.zeros(out.shape, dtype=out.dtype),
)
start += size
return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out
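# Quick agreement check (illustrative sketch, not part of this change): the
# reference above should match jax.lax.ragged_dot on a small float32 example.
import jax

_lhs = np.arange(5 * 4, dtype=np.float32).reshape(5, 4)
_rhs = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
_gs = np.array([3, 2], dtype=np.int32)
np.testing.assert_allclose(ragged_dot(_lhs, _rhs, _gs),
                           np.asarray(jax.lax.ragged_dot(_lhs, _rhs, _gs)))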

def broadcast(operand, sizes):
return np.broadcast_to(operand, sizes + np.shape(operand))

1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1527,6 +1527,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"platform_index",
"assert_consumed_value",
"consume",
"ragged_dot",
]

tf_impl[random_internal.random_clone_p] = lambda x: x
9 changes: 9 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -1730,6 +1730,15 @@ def f_switch(x):
f_switch_tf = jax2tf.convert(f_switch, enable_xla=False)
self.assertIn("switch_case", self.TfToHlo(f_switch_tf, np.pi))

@jtu.skip_on_flag("jax2tf_default_native_serialization", False)
def test_ragged_dot(self):
dtype = np.float32
m, k, n, num_groups = 5, 4, 3, 2
lhs = np.arange(m * k, dtype=dtype).reshape((m, k))
rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
group_sizes = np.array([3, 2], dtype=np.int32)
self.ConvertAndCompare(jax.lax.ragged_dot, lhs, rhs, group_sizes)


@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
1 change: 1 addition & 0 deletions jax/lax/__init__.py
@@ -152,6 +152,7 @@
population_count_p as population_count_p,
pow as pow,
pow_p as pow_p,
ragged_dot as ragged_dot,
real as real,
real_p as real_p,
reciprocal as reciprocal,
23 changes: 23 additions & 0 deletions tests/export_test.py
@@ -1396,5 +1396,28 @@ def f_jax(x):
with self.assertRaisesRegex(Exception, expect_error):
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))

@jtu.parameterized_filterable(
kwargs=[
{"m": 5, "k": 4, "n": 3, "group_sizes": [5]},
{"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]},
])
def test_ragged_dot(self, m, k, n, group_sizes):
def f_jax(x, y, gs):
return jax.lax.ragged_dot(x, y, gs)
dtype = np.float32
group_sizes = np.array(group_sizes, dtype=np.int32)
lhs = np.arange(m * k, dtype=dtype).reshape((m, k))
num_groups = group_sizes.shape[0]
rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
res_native = f_jax(lhs, rhs, group_sizes)

exp_f = get_exported(f_jax)(
jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype),
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
)
res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes)
self.assertAllClose(res_native, res_exported)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
27 changes: 27 additions & 0 deletions tests/lax_test.py
@@ -1177,6 +1177,33 @@ def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
self._CheckAgainstNumpy(numpy_op, op, args_maker)

@jtu.sample_product(
[
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
],
dtype=jtu.dtypes.numeric,
)
def testRaggedDot(self, m, k, n, num_groups, dtype):
"""Tests ragged_dot.
The ragged_dot is tested against numpy reference implementation, and by running JAX compilation.
Raises:
SkipTest: in the case dtype is not supported.
"""
lhs_shape = (m, k)
rhs_shape = (num_groups, k, n)
def group_sizes(m, num_groups):
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
return ends - starts
rng = jtu.rand_small(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)]
self._CompileAndCheck(lax.ragged_dot, args_maker)
self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker)

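  # Illustrative note (not part of this change): the group_sizes helper above
  # always produces a partition of m; e.g. with m=10, num_groups=2, a draw of
  # ends_no_final=[3] yields ends=[3, 10], starts=[0, 3], sizes=[3, 7].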
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=lax_test_util.default_dtypes,
