Merge pull request #16397 from gnecula:poly_dynamic_slice
PiperOrigin-RevId: 540798903
jax authors committed Jun 16, 2023
2 parents 907782d + 645b3c4 commit 9fdaf5a
Showing 2 changed files with 34 additions and 16 deletions.
jax/_src/interpreters/mlir.py (24 additions, 16 deletions)

```diff
@@ -1356,6 +1356,7 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
 
 def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
                   start_indices) -> ir.Value:
+  x_aval = ctx.avals_in[0]
   if dtypes.is_opaque_dtype(aval_out.dtype):
     elt_shape = aval_out.dtype._rules.physical_element_aval(
         aval_out.dtype).shape
@@ -1364,23 +1365,30 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
         index_avals[0].dtype if index_avals else 'int64')  # type: ignore
     trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape)
     start_indices = (*start_indices, *trailing_zeros)
-    physical_aval_out = core.physical_aval(aval_out)
-    return dynamic_slice(ctx, physical_aval_out, x,
-                         start_indices=start_indices)
+    aval_out = core.physical_aval(aval_out)
+    x_aval = core.physical_aval(x_aval)
+
+  slice_sizes = aval_out.shape
+  if not core.is_constant_shape(slice_sizes):
+    # lax.dynamic_slice clamps the start indices, but we are going to
+    # lower to RealDynamicSliceOp, which is a version of SliceOp, and does
+    # not have the clamping behavior. We clamp start ourselves.
+    slice_sizes = shape_tensor(eval_dynamic_shape(ctx, slice_sizes))
+    clamped_start = hlo.ClampOp(
+      shape_tensor([0] * len(start_indices)),
+      shape_tensor(start_indices),
+      hlo.SubtractOp(
+        shape_tensor(eval_dynamic_shape(ctx, x_aval.shape)),  # type: ignore
+        slice_sizes))
+    return hlo.RealDynamicSliceOp(
+        aval_to_ir_type(aval_out), x,
+        clamped_start,
+        hlo.AddOp(clamped_start, slice_sizes).result,
+        shape_tensor([1] * len(start_indices))
+    ).result
   else:
-    slice_sizes = aval_out.shape
-    if not core.is_constant_shape(slice_sizes):
-      slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
-      return hlo.RealDynamicSliceOp(
-          aval_to_ir_type(aval_out), x,
-          shape_tensor(start_indices),
-          hlo.AddOp(shape_tensor(start_indices),
-                    shape_tensor(slice_sizes)).result,
-          shape_tensor([1] * len(slice_sizes))
-      ).result
-    else:
-      return hlo.DynamicSliceOp(x, start_indices,
-                                dense_int_elements(slice_sizes)).result
+    return hlo.DynamicSliceOp(x, start_indices,
+                              dense_int_elements(slice_sizes)).result
 
 def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
                          start_indices) -> ir.Value:
```
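
For context on the comment in the new branch: lax.dynamic_slice clamps out-of-bounds start indices so the requested slice always fits inside the operand (each start index is adjusted into `[0, dim - slice_size]`), and the added ClampOp reproduces that rule when shapes are only known dynamically. A minimal sketch of the clamping semantics, illustrative only and not part of the commit:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

# Row start 1 with slice size 3 would overrun a 3-row array, so the row
# start is clamped to 3 - 3 = 0; the result equals x[0:3, 1:3].
print(lax.dynamic_slice(x, (1, 1), (3, 2)))

# A negative start index is clamped up to 0 the same way: x[0:2, 1:3].
print(lax.dynamic_slice(x, (-1, 1), (2, 2)))
```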
jax/experimental/jax2tf/tests/shape_poly_test.py (10 additions, 0 deletions)

```diff
@@ -2054,6 +2054,16 @@ def f_jax(operand, start_indices, x):
                 lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
                 poly_axes=[0, None]).both_enable_and_disable_xla(),
+    PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
+                # x:shape: (b, 4)
+                lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
+    PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
+                # x:shape: (b, 4)
+                lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
     PolyHarness("dynamic_slice_in_dim", "idx=0",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
```
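
The two new harnesses cover exactly these out-of-bounds starts under shape polymorphism ("oob_large" starts at 1 while slicing all b rows, "oob_small" starts at -1). A rough standalone equivalent of the first case, assuming the public jax2tf.convert API with polymorphic_shapes (the PolyHarness machinery lives in the test suite itself):

```python
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def f(x):  # x: f32[b, 4]
  # The slice spans all b rows, so the row start 1 must be clamped to 0.
  return lax.dynamic_slice(x, (1, 1), (x.shape[0], 2))

f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])
x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(f_tf(x))  # same values as f(x): rows 0:3, columns 1:3
```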
