
[dynamic-shapes] Add basic slicing support
If, e.g., `x : f32[10, n]`, then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` that allows
dynamic slice sizes (where the result shape depends on those slice sizes).
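
As a concrete illustration, here is a minimal sketch based on the tests added in
this commit (it assumes the experimental `jax_dynamic_shapes` config flag is
enabled, as the new tests do):

import jax
import jax.numpy as jnp

jax.config.update("jax_dynamic_shapes", True)  # experimental flag

def f(x):
  return x[0]  # basic indexing on an array whose trailing axis is dynamic

# Abstracting the second axis as a symbolic size 'n' stages a dynamic_slice
# whose result shape depends on n rather than on a concrete integer.
print(jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4))))
# Per the comment in the new test, this prints a jaxpr like:
# { lambda ; a:i32[] b:f32[3,a]. let
#     c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a
#     d:f32[a] = squeeze[dimensions=(0,)] c
#   in (d,) }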

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
mattjj and sharadmv committed Sep 28, 2022
1 parent 33dbf0e commit a8826e6
Showing 5 changed files with 126 additions and 34 deletions.
7 changes: 4 additions & 3 deletions jax/_src/lax/lax.py
@@ -1456,10 +1456,9 @@ def _iter(tracer):
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
n = int(tracer.shape[0])
# return (index_in_dim(tracer, i, keepdims=False) for i in range(n))
return iter([slicing.index_in_dim(tracer, i, keepdims=False)
for i in range(n)])
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
ShapedArray._iter = staticmethod(_iter)
core.DShapedArray._iter = staticmethod(_iter)

# Add some ad handlers that use (or could use) lax primitives

@@ -2884,6 +2883,8 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
return [lhs, pp.text(" = ", annotation=annotation), *rhs]

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
del dyn_shape
if not any(isinstance(d, core.BInt) for d in shape):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
115 changes: 86 additions & 29 deletions jax/_src/lax/slicing.py
@@ -104,8 +104,13 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
[ 8, 9, 10, 11]], dtype=int32)
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_slice_p.bind(operand, *start_indices,
slice_sizes=core.canonicalize_shape(slice_sizes))
if jax.config.jax_dynamic_shapes:
dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
else:
dynamic_sizes = []
static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
slice_sizes=tuple(static_sizes))

def dynamic_update_slice(operand: Array, update: Array,
start_indices: Array) -> Array:
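
The `lax._extract_tracers_dyn_shape` / `lax._merge_dyn_shape` helpers used in the
hunk above are internal to JAX; the following standalone sketch is only an
assumption about their convention, inferred from the jaxpr in the new test where
`slice_sizes=(1, None)`: tracer-valued sizes are pulled out, leaving `None`
placeholders to be merged back later.

def split_dyn_shape(shape):
  """Hypothetical stand-in: split a mixed shape into tracer entries + statics."""
  dyn = [d for d in shape if not isinstance(d, int)]
  static = tuple(d if isinstance(d, int) else None for d in shape)
  return dyn, static

def merge_dyn_shape(static_shape, dyn):
  """Hypothetical stand-in: refill the None slots with the dynamic entries."""
  it = iter(dyn)
  return tuple(next(it) if d is None else d for d in static_shape)

# e.g. split_dyn_shape((1, n)) -> ([n], (1, None)), and merging restores (1, n).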
@@ -684,7 +689,7 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
def dynamic_slice_in_dim(operand: Array, start_index: Array,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
start_indices = [lax._zero(start_index)] * operand.ndim
start_indices = [np.zeros((), dtype=dtypes.dtype(start_index))] * operand.ndim
slice_sizes = list(operand.shape)

axis = int(axis)
@@ -746,22 +751,24 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
msg = ("slice start_indices must be greater than or equal to zero, "
"got start_indices of {}.")
raise TypeError(msg.format(start_indices))
if not core.greater_equal_shape(limit_indices, start_indices):
msg = ("slice limit_indices must be greater than or equal to start_indices,"
" got start_indices {} and limit_indices {}.")
raise TypeError(msg.format(start_indices, limit_indices))
if strides is None:
strides = np.ones(operand.ndim, np.int32)
else:
lax._check_shapelike("slice", "strides", strides)
if len(strides) != operand.ndim:
msg = ("slice strides must have length equal to the number of dimensions "
"of the operand, got strides {} for operand shape {}.")
raise TypeError(msg.format(strides, operand.shape))
if not core.greater_equal_shape(strides, (0,) * len(strides)):
msg = "slice strides must be positive, got {}"
raise TypeError(msg.format(strides))

if not jax.config.jax_dynamic_shapes:
if not core.greater_equal_shape(limit_indices, start_indices):
msg = ("slice limit_indices must be greater than or equal to start_indices,"
" got start_indices {} and limit_indices {}.")
raise TypeError(msg.format(start_indices, limit_indices))
if strides is None or tuple(strides) == (1,) * len(operand.shape):
shape = [limit if type(start) is int and start == 0 else limit - start
for start, limit in zip(start_indices, limit_indices)]
return tuple(shape)

lax._check_shapelike("slice", "strides", strides)
if len(strides) != operand.ndim:
msg = ("slice strides must have length equal to the number of dimensions "
"of the operand, got strides {} for operand shape {}.")
raise TypeError(msg.format(strides, operand.shape))
if not core.greater_equal_shape(strides, (0,) * len(strides)):
msg = "slice strides must be positive, got {}"
raise TypeError(msg.format(strides))
diff = core.diff_shape(limit_indices, start_indices)
return core.stride_shape(diff, (1,) * len(diff), strides)

@@ -902,24 +909,69 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True,
mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None)

def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim])
if not dyn:
return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices),
dict(slice_sizes=slice_sizes))
shape = lax._merge_dyn_shape(slice_sizes, dyn)
aval = core.DShapedArray(shape, x.dtype, False)
return lax._dyn_shape_staging_rule(trace, dynamic_slice_p, aval, x,
*starts_and_dyn_sizes,
slice_sizes=slice_sizes)

def _dynamic_slice_typecheck_rule(x, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim])
if not dyn:
out_aval, effects = dynamic_slice_p.abstract_eval(
x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes)
return [out_aval], effects
else:
# TODO(mattjj): perform more checks
out_shape = lax._merge_dyn_shape(slice_sizes, dyn)
out_shape = [d.val if type(d) is core.Literal else d for d in out_shape]
out_aval = core.DShapedArray(tuple(out_shape), x.aval.dtype,
x.aval.weak_type)
return [out_aval], core.no_effects


dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
pe.custom_staging_rules[dynamic_slice_p] = _dynamic_slice_staging_rule
core.custom_typechecks[dynamic_slice_p] = _dynamic_slice_typecheck_rule

def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
x_aval, *_ = ctx.avals_in
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x_aval.ndim])
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(
ctx, x, start_indices, slice_sizes)
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results

if core.is_opaque_dtype(aval_out.dtype) and dyn: raise NotImplementedError
if not dyn:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(ctx, x, start_indices,
slice_sizes)
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results
slice_sizes = lax._merge_dyn_shape(slice_sizes, dyn)
return mhlo.RealDynamicSliceOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.shape_tensor(start_indices),
mlir.shape_tensor(slice_sizes),
mlir.shape_tensor([1] * len(slice_sizes))
).results
mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)

# def _getslice_lower(ctx, x, lo, hi):
# aval_out, = ctx.avals_out
# return mhlo.RealDynamicSliceOp(
# mlir.aval_to_ir_type(aval_out), x,
# mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
# ).results
# mlir.register_lowering(getslice_p, _getslice_lower)


def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
if operand.ndim != update.ndim:
@@ -2055,9 +2107,14 @@ def _dynamic_slice_indices(operand, start_indices: Any):
start_indices = list(start_indices)
result = []
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
else:
d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
result.append(lax.select(i < 0, i + d, i))
continue
d = core.dimension_as_value(d)
if isinstance(i, (int, np.integer)):
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
continue
d = lax.convert_element_type(d, _dtype(i))
result.append(lax.select(i < 0, i + d, i))
return result
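
For reference, the wrap-around rule that `_dynamic_slice_indices` applies to each
start index is plain NumPy semantics; a minimal illustrative sketch of the fully
static case (not part of the diff):

def normalize_start_index(i: int, d: int) -> int:
  """Map a possibly-negative start index into [0, d), NumPy-style."""
  return i + d if i < 0 else i

assert normalize_start_index(-1, 10) == 9  # negative indices count from the end
assert normalize_start_index(3, 10) == 3   # non-negative indices pass through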
5 changes: 4 additions & 1 deletion jax/_src/numpy/lax_numpy.py
@@ -3621,7 +3621,10 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
return lax.index_in_dim(arr, idx, keepdims=False)
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
else:
return lax.index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
2 changes: 1 addition & 1 deletion jax/core.py
@@ -1747,7 +1747,7 @@ def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:

def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.greater_equal(*ds)
return handler.symbolic_equal(*ds) or handler.greater_equal(*ds)

def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
return all(map(greater_equal_dim, s1, s2))
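
A likely motivation for checking symbolic equality first (an illustrative sketch,
not the real dimension-handler API): reflexive comparisons such as `n >= n` must
succeed even when a handler cannot order two distinct symbolic dimensions.

class SymbolicDim:
  """Hypothetical symbolic dimension, for illustration only."""
  def __init__(self, name):
    self.name = name

def symbolic_equal(d1, d2):
  return d1 is d2  # a symbol is trivially equal to itself

def greater_equal(d1, d2):
  raise TypeError("cannot order distinct symbolic dimensions")

def greater_equal_dim(d1, d2):
  return symbolic_equal(d1, d2) or greater_equal(d1, d2)

n = SymbolicDim('n')
assert greater_equal_dim(n, n)  # True via symbolic_equal; no ordering needed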
31 changes: 31 additions & 0 deletions tests/dynamic_api_test.py
@@ -1360,6 +1360,37 @@ def f(x):

jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash

def test_slicing_basic_jaxpr(self):
def f(x):
return x[0]

jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4)))
# { lambda ; a:i32[] b:f32[3,a]. let
# c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a
# d:f32[a] = squeeze[dimensions=(0,)] c
# in (d,) }
self.assertLen(jaxpr.jaxpr.invars, 2)
a, _ = jaxpr.jaxpr.invars
self.assertLen(jaxpr.jaxpr.outvars, 1)
d, = jaxpr.jaxpr.outvars
self.assertLen(d.aval.shape, 1)
self.assertEqual(d.aval.shape, (a,))

def test_slicing_basic_lower(self):
@partial(jax.jit, abstracted_axes=(None, 'n'))
def f(x):
return x[0]
f.lower(jnp.zeros((3, 4))).compiler_ir() # doesn't crash

@unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
def test_slicing_basic_execute(self):
@partial(jax.jit, abstracted_axes=(None, 'n'))
def f(x):
return x[0]

y = f(jnp.arange(3 * 4).reshape(3, 4))
self.assertAllClose(y, jnp.array([0, 1, 2, 3]))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
