From 89dd69ea2dc65d39e5167eff850c32cfbe81c639 Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Tue, 13 Jun 2023 16:59:24 -0400 Subject: [PATCH] Test and implement ragged slicing. This touches _gather_batching_rule because slicing is implemented as a gather, but we only test the case exercised by the slice that occurs in our test transformer model, namely the unstack operation q, k, v = qkv (which turns into three slices on an non-batched and non-ragged axis). Co-authored-by: Matthew Johnson --- jax/_src/lax/lax.py | 6 ++++- jax/_src/lax/slicing.py | 49 +++++++++++++++++++++++++-------------- tests/dynamic_api_test.py | 12 ++++++++++ 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 062da3c69a03..2133585d9176 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3265,8 +3265,12 @@ def _squeeze_transpose_rule(t, operand, *, dimensions): def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): operand, = batched_args bdim, = batch_dims - operand, bdim_out = batching.move_stacked_axis(operand, bdim, 0) + operand, bdim = batching.move_stacked_axis(operand, bdim, 0) dimensions = tuple(np.add(1, dimensions)) + out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim + bdim_out = batching.shape_as_bdim( + out_stack_dim, + _compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions)) return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 4032e4ffb408..826b033d6d58 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1161,7 +1161,9 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): +def _dynamic_slice_shape_rule( + operand, *starts_and_dyn_sizes, slice_sizes): + start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if operand.ndim != len(start_indices): msg = ("dynamic_slice start_indices must have length equal to the number " "of dimensions of the operand, got indices {} for operand shape {}.") @@ -1170,20 +1172,21 @@ def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): msg = ("dynamic_slice slice_sizes must have the same length as " "start_indices, got start_indices length {} and slice_sizes {}.") raise TypeError(msg.format(len(start_indices), slice_sizes)) - if not core.greater_equal_shape(operand.shape, slice_sizes): + if not dyn and not core.greater_equal_shape(operand.shape, slice_sizes): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) - if not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes): + if not dyn and not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) if any(idx.ndim != 0 for idx in start_indices): raise TypeError("start_indices arguments to dynamic_slice must be scalars, " f" got indices {start_indices}") - return tuple(slice_sizes) + return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) -def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes): +def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes): + start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if any(i.dtype != start_indices[0].dtype or not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices): msg = ("index arguments to dynamic_slice must be integers of the same " @@ -1228,16 +1231,20 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): # batching rule. # TODO(phawkins): consider removing dynamic_slice entirely and using gather # always. - operand, *start_indices = batched_args - operand_bd, *start_idx_bds = batch_dims - operand_shape = (operand.shape if operand_bd is batching.not_mapped - else tuple(np.delete(operand.shape, operand_bd))) - dims = tuple(range(len(operand_shape))) + # TODO(mattjj): Alternately, we could add jnp.unstack and an unstack_p, + # since it should have easier rules (especially compared to gather). + operand, *start_indices_and_dyn = batched_args + operand_bd, *start_idx_and_dyn_bds = batch_dims + ndims = operand.ndim - (0 if operand_bd is batching.not_mapped else 1) + dims = tuple(range(ndims)) + start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims]) + start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims]) dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), start_index_map=dims) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( - [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums, + [operand, index, *dyn_slice_sizes], + [operand_bd, index_bdim, *dyn_slice_size_bds], dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) @@ -1607,11 +1614,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - operand, indices = batched_args - operand_bdim, indices_bdim = batch_dims + operand, indices, *dyn_slice_sizes = batched_args + operand_bdim, indices_bdim, *dyn_slice_size_bds = batch_dims + dyn_slice_size_bounds = [b.dtype.bound for b in dyn_slice_sizes] if operand_bdim is not None and indices_bdim is None: - operand = batching.moveaxis(operand, operand_bdim, 0) + operand, operand_bdim = batching.move_stacked_axis(operand, operand_bdim, 0) slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) @@ -1620,10 +1628,17 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, start_index_map=start_index_map) - return gather(operand, indices, dimension_numbers=dnums, - slice_sizes=slice_sizes, unique_indices=unique_indices, + # TODO(reviewer): This should be correct for `operand_bdim` being a + # `RaggedAxis` as long as we are not gathering from any ragged axis, and as + # long as we have no collapsed_slice_dims. (The latter could require + # adjusting the `ragged_axes` field of `operand_bdim`). Should we put a + # check to confirm? + ans = gather(operand, indices, dimension_numbers=dnums, + slice_sizes=lax._merge_dyn_shape(slice_sizes, dyn_slice_size_bounds), + unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, - fill_value=fill_value), 0 + fill_value=fill_value), operand_bdim + return ans elif operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 75b797ba1560..24ff92878160 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1652,6 +1652,18 @@ def fprop_layer(x_size): self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]') self.assertEqual(p.data.shape, (3, 3, 5, 2, 7)) + def test_split_while_ragged(self): + ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) + def func(size): + one_d = jnp.arange(size, dtype='int32') + two_d = jnp.broadcast_to(one_d, (2, size)) + part_1, part_2 = two_d + return part_1 + p = jax.vmap(func, out_axes=batching.pile_axis)(ins) + self.assertIsInstance(p, batching.Pile) + data = jax.lax.broadcasted_iota('int32', (3, 5), 1) + self.assertAllClose(p.data, data) + def pile_map(f): def mapped(*piles): return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,