Skip to content

Commit

Permalink
Test and implement ragged slicing.
Browse files Browse the repository at this point in the history
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
  q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
axch and mattjj committed Jul 7, 2023
1 parent 6f09fe8 commit 89dd69e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
6 changes: 5 additions & 1 deletion jax/_src/lax/lax.py
Expand Up @@ -3265,8 +3265,12 @@ def _squeeze_transpose_rule(t, operand, *, dimensions):
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
operand, = batched_args
bdim, = batch_dims
operand, bdim_out = batching.move_stacked_axis(operand, bdim, 0)
operand, bdim = batching.move_stacked_axis(operand, bdim, 0)
dimensions = tuple(np.add(1, dimensions))
out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
bdim_out = batching.shape_as_bdim(
out_stack_dim,
_compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions))
return squeeze(operand, dimensions=dimensions), bdim_out

squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
Expand Down
49 changes: 32 additions & 17 deletions jax/_src/lax/slicing.py
Expand Up @@ -1161,7 +1161,9 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
mlir.register_lowering(slice_p, _slice_lower)


def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
def _dynamic_slice_shape_rule(
operand, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
if operand.ndim != len(start_indices):
msg = ("dynamic_slice start_indices must have length equal to the number "
"of dimensions of the operand, got indices {} for operand shape {}.")
Expand All @@ -1170,20 +1172,21 @@ def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
msg = ("dynamic_slice slice_sizes must have the same length as "
"start_indices, got start_indices length {} and slice_sizes {}.")
raise TypeError(msg.format(len(start_indices), slice_sizes))
if not core.greater_equal_shape(operand.shape, slice_sizes):
if not dyn and not core.greater_equal_shape(operand.shape, slice_sizes):
msg = ("slice slice_sizes must be less than or equal to operand shape, "
"got slice_sizes {} for operand shape {}.")
raise TypeError(msg.format(slice_sizes, operand.shape))
if not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes):
if not dyn and not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes):
msg = ("slice slice_sizes must be greater than or equal to zero, "
"got slice_sizes of {}.")
raise TypeError(msg.format(slice_sizes))
if any(idx.ndim != 0 for idx in start_indices):
raise TypeError("start_indices arguments to dynamic_slice must be scalars, "
f" got indices {start_indices}")
return tuple(slice_sizes)
return tuple(lax._merge_dyn_shape(slice_sizes, dyn))

def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
if any(i.dtype != start_indices[0].dtype or
not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
msg = ("index arguments to dynamic_slice must be integers of the same "
Expand Down Expand Up @@ -1228,16 +1231,20 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
# batching rule.
# TODO(phawkins): consider removing dynamic_slice entirely and using gather
# always.
operand, *start_indices = batched_args
operand_bd, *start_idx_bds = batch_dims
operand_shape = (operand.shape if operand_bd is batching.not_mapped
else tuple(np.delete(operand.shape, operand_bd)))
dims = tuple(range(len(operand_shape)))
# TODO(mattjj): Alternately, we could add jnp.unstack and an unstack_p,
# since it should have easier rules (especially compared to gather).
operand, *start_indices_and_dyn = batched_args
operand_bd, *start_idx_and_dyn_bds = batch_dims
ndims = operand.ndim - (0 if operand_bd is batching.not_mapped else 1)
dims = tuple(range(ndims))
start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims])
start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims])
dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
start_index_map=dims)
index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
return _gather_batching_rule(
[operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
[operand, index, *dyn_slice_sizes],
[operand_bd, index_bdim, *dyn_slice_size_bds], dimension_numbers=dnums,
slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True,
mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None)

Expand Down Expand Up @@ -1607,11 +1614,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers,
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
operand, indices = batched_args
operand_bdim, indices_bdim = batch_dims
operand, indices, *dyn_slice_sizes = batched_args
operand_bdim, indices_bdim, *dyn_slice_size_bds = batch_dims
dyn_slice_size_bounds = [b.dtype.bound for b in dyn_slice_sizes]

if operand_bdim is not None and indices_bdim is None:
operand = batching.moveaxis(operand, operand_bdim, 0)
operand, operand_bdim = batching.move_stacked_axis(operand, operand_bdim, 0)
slice_sizes = (operand.shape[0],) + slice_sizes
offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims))
collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
Expand All @@ -1620,10 +1628,17 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
offset_dims=offset_dims,
collapsed_slice_dims=collapsed_slice_dims,
start_index_map=start_index_map)
return gather(operand, indices, dimension_numbers=dnums,
slice_sizes=slice_sizes, unique_indices=unique_indices,
# TODO(reviewer): This should be correct for `operand_bdim` being a
# `RaggedAxis` as long as we are not gathering from any ragged axis, and as
# long as we have no collapsed_slice_dims. (The latter could require
# adjusting the `ragged_axes` field of `operand_bdim`). Should we put a
# check to confirm?
ans = gather(operand, indices, dimension_numbers=dnums,
slice_sizes=lax._merge_dyn_shape(slice_sizes, dyn_slice_size_bounds),
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode,
fill_value=fill_value), 0
fill_value=fill_value), operand_bdim
return ans

elif operand_bdim is None and indices_bdim is not None:
indices = batching.moveaxis(indices, indices_bdim, 0)
Expand Down
12 changes: 12 additions & 0 deletions tests/dynamic_api_test.py
Expand Up @@ -1652,6 +1652,18 @@ def fprop_layer(x_size):
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))

def test_split_while_ragged(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (2, size))
part_1, part_2 = two_d
return part_1
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)

def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
Expand Down

0 comments on commit 89dd69e

Please sign in to comment.