Skip to content

Commit

Permalink
inline and remove dynamic_slice_mlir rules
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed May 18, 2023
1 parent aed77c5 commit 2dbdf1a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 32 deletions.
37 changes: 23 additions & 14 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1283,21 +1283,30 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
start_indices) -> ir.Value:
if dtypes.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(ctx, aval_out, x,
start_indices)
slice_sizes = aval_out.shape
if not core.is_constant_shape(slice_sizes):
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
return hlo.RealDynamicSliceOp(
aval_to_ir_type(aval_out), x,
shape_tensor(start_indices),
hlo.AddOp(shape_tensor(start_indices),
shape_tensor(slice_sizes)).result,
shape_tensor([1] * len(slice_sizes))
).result
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
index_avals = ctx.avals_in[1:]
dtype = dtypes.canonicalize_dtype(
index_avals[0].dtype if index_avals else 'int64') # type: ignore
trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape)
start_indices = (*start_indices, *trailing_zeros)
physical_aval_out = core.physical_aval(aval_out)
return dynamic_slice(ctx, physical_aval_out, x,
start_indices=start_indices)
else:
return hlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result
slice_sizes = aval_out.shape
if not core.is_constant_shape(slice_sizes):
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
return hlo.RealDynamicSliceOp(
aval_to_ir_type(aval_out), x,
shape_tensor(start_indices),
hlo.AddOp(shape_tensor(start_indices),
shape_tensor(slice_sizes)).result,
shape_tensor([1] * len(slice_sizes))
).result
else:
return hlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result

def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
start_indices) -> ir.Value:
Expand Down
11 changes: 0 additions & 11 deletions jax/_src/prng.py
Expand Up @@ -521,17 +521,6 @@ def make_sharded_array(aval, sharding, arrays, committed):

# element-type-polymorphic primitive lowering rules

@staticmethod
def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> ir.Value:
index_avals = ctx.avals_in[1:]
dtype = dtypes.canonicalize_dtype(index_avals[0].dtype if index_avals else 'int64')
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
physical_aval_out = core.physical_aval(aval_out)
return mlir.dynamic_slice(ctx, physical_aval_out, x,
start_indices=start_indices)

@staticmethod
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices) -> ir.Value:
index_avals = ctx.avals_in[2:]
Expand Down
7 changes: 0 additions & 7 deletions tests/lax_test.py
Expand Up @@ -2869,13 +2869,6 @@ def handler(arr):

# element-type-polymorphic primitive lowering rules

@staticmethod
def dynamic_slice_mlir(ctx, aval_out, x, start_indices):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
slice_sizes_ = mlir.dense_int_elements((*aval_out.shape, 2))
return hlo.DynamicSliceOp(x, start_indices, slice_sizes_).result

@staticmethod
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices):
aval_out, = ctx.avals_out
Expand Down

0 comments on commit 2dbdf1a

Please sign in to comment.