Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim),
so that there is a single place where we pick between the dynamic and
the static ops. These routines can also handle opaque types. Handling
an opaque type results in a recursive call to, e.g., mlir.slice_op,
but the inner call uses the physical avals, which should no longer be
opaque.
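
As an illustration only, here is a minimal, self-contained sketch of that
dispatch pattern. It is not the JAX implementation: Aval, KeyDtype,
slice_op, is_opaque_dtype, all_constant and the returned op-name strings
are simplified stand-ins, meant only to show how one routine can pick
between the opaque, dynamic, and static cases and recurse onto the
physical avals.

from dataclasses import dataclass
from typing import Any, Tuple

@dataclass
class Aval:
  shape: Tuple[Any, ...]   # entries may be ints or symbolic dimension objects
  dtype: Any

class KeyDtype:
  """Stand-in for an opaque dtype that knows its physical representation."""
  key_shape = (2,)  # e.g., two uint32 words per key

  def physical_aval(self, aval: Aval) -> Aval:
    return Aval(shape=(*aval.shape, *self.key_shape), dtype="uint32")

def is_opaque_dtype(dtype) -> bool:
  return isinstance(dtype, KeyDtype)

def all_constant(dims) -> bool:
  return all(isinstance(d, int) for d in dims)

def slice_op(x, aval_out: Aval, *, start_indices, limit_indices, strides) -> str:
  # Single place that picks the op.
  if is_opaque_dtype(aval_out.dtype):
    # Recurse once on the physical aval, extending the indices over the
    # trailing physical dimensions; the inner call sees a non-opaque dtype.
    dt = aval_out.dtype
    extra = len(dt.key_shape)
    return slice_op(x, dt.physical_aval(aval_out),
                    start_indices=(*start_indices, *([0] * extra)),
                    limit_indices=(*limit_indices, *dt.key_shape),
                    strides=(*strides, *([1] * extra)))
  if not all(all_constant(ixs) for ixs in (start_indices, limit_indices, strides)):
    return "mhlo.real_dynamic_slice"  # dynamic variant
  return "mhlo.slice"                 # static variant

if __name__ == "__main__":
  keys = Aval(shape=(3,), dtype=KeyDtype())
  # The opaque branch recurses onto the uint32 physical aval and then picks
  # the static op, since all the indices here are constant.
  print(slice_op("x", keys, start_indices=(0,), limit_indices=(1,), strides=(1,)))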

While making this change I was confused by the fact that the custom
KeyTyRules in prng.py had lowerings that return multiple MHLO ops
(see #11768 (comment)), so I changed those rules to return a single op.

gnecula committed Dec 12, 2022
1 parent 2f1354e commit 27f5bd0
Showing 7 changed files with 141 additions and 109 deletions.
11 changes: 4 additions & 7 deletions jax/_src/lax/lax.py
@@ -2893,15 +2893,12 @@ def _broadcast_in_dim_partial_eval(
out_tracer.recipe = eqn
return out_tracer

def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) -> Sequence[ir.Value]:
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.broadcast_in_dim_mlir(
ctx, x, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
if dyn_shape:
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))


return [mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=broadcast_dimensions)]

@@ -3372,7 +3369,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.transpose_mlir(ctx, x, permutation=permutation)
return [aval_out.dtype._rules.transpose_mlir(ctx, aval_out, x, permutation=permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results

transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
@@ -4831,7 +4828,7 @@ def empty(dtype):
empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
def _empty_lower(ctx, *, dtype):
if core.is_opaque_dtype(dtype):
return dtype._rules.empty_mlir(ctx)
return dtype._rules.empty_mlir(ctx, ctx.avals_out[0])
return mlir.ir_constants(np.zeros((), np.dtype(dtype)))
mlir.register_lowering(empty_p, _empty_lower)

51 changes: 9 additions & 42 deletions jax/_src/lax/slicing.py
@@ -813,23 +813,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.slice_mlir(
ctx, x, start_indices, limit_indices, strides)
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
start_indices = mlir.eval_dynamic_shape(ctx, start_indices)
limit_indices = mlir.eval_dynamic_shape(ctx, limit_indices)
strides = mlir.eval_dynamic_shape(ctx, strides)
return mhlo.RealDynamicSliceOp(mlir.aval_to_ir_type(aval_out),
x,
mlir.shape_tensor(start_indices),
mlir.shape_tensor(limit_indices),
mlir.shape_tensor(strides)).results
else:
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).results
return [mlir.slice_op(ctx, x, aval_out,
start_indices=start_indices, limit_indices=limit_indices, strides=strides)]

mlir.register_lowering(slice_p, _slice_lower)

@@ -963,24 +948,9 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
x_aval, *_ = ctx.avals_in
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x_aval.ndim])
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype) and dyn: raise NotImplementedError
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(ctx, x, start_indices,
slice_sizes)
if dyn:
slice_sizes = lax._merge_dyn_shape(slice_sizes, dyn)

if not core.is_constant_shape(slice_sizes):
slice_sizes = mlir.eval_dynamic_shape(ctx, slice_sizes)
return mhlo.RealDynamicSliceOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.shape_tensor(start_indices),
mlir.shape_tensor(slice_sizes),
mlir.shape_tensor([1] * len(slice_sizes))
).results
else:
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results
aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)]

mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)

@@ -1083,11 +1053,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):

def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_update_slice_mlir(
ctx, x, update, *start_indices)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
return [mlir.dynamic_update_slice(ctx, aval_out, x, update,
start_indices=start_indices)]

mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)

@@ -1411,10 +1378,10 @@ def _gather_lower(ctx, operand, indices, *,
indices_are_sorted, mode, fill_value):
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.gather_mlir(
ctx, operand, indices, dimension_numbers=dimension_numbers,
return [aval_out.dtype._rules.gather_mlir(
ctx, ctx.avals_in, aval_out, operand, indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)]

if mode == GatherScatterMode.FILL_OR_DROP:
gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
61 changes: 28 additions & 33 deletions jax/_src/prng.py
@@ -282,13 +282,13 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
class KeyTyRules:

@staticmethod
def physical_avals(aval): # TODO(frostig): rename to `grounded_avals`
def physical_avals(aval) -> Sequence[core.AbstractValue]: # TODO(frostig): rename to `grounded_avals`
# TODO(frostig): dedup with `keys_aval_to_base_arr_aval``
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape), # type: ignore
jnp.dtype('uint32'))]

@staticmethod
def aval_to_ir_types(aval):
def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[mlir.ir.Type]:
phys_aval, = KeyTyRules.physical_avals(aval)
return mlir.aval_to_ir_types(phys_aval)

@@ -393,70 +393,64 @@ def handler(bufs):
# element-type-polymorphic primitive lowering rules

@staticmethod
def empty_mlir(ctx):
aval_out, = ctx.avals_out
def empty_mlir(ctx, aval_out) -> Sequence[mlir.ir.Value]:
return mlir.ir_constants(np.zeros(aval_out.dtype.impl.key_shape,
dtype=np.dtype('uint32')))

@staticmethod
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
aval_out, = ctx.avals_out
def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides) -> mlir.ir.Value:
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [0] * len(key_shape)
trailing_ones = [1] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
limit_indices = (*limit_indices, *key_shape)
strides = (*strides, *trailing_ones)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).results
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
return mlir.slice_op(ctx, x, physical_aval_out,
start_indices=start_indices, limit_indices=limit_indices, strides=strides)

@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
aval_out, = ctx.avals_out
def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> mlir.ir.Value:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
return mlir.dynamic_slice(ctx, physical_aval_out, x,
start_indices=start_indices)

@staticmethod
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices) -> mlir.ir.Value:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
start_indices = (*start_indices, *zeros)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
return mlir.dynamic_update_slice(ctx, physical_aval_out, x, update,
start_indices=start_indices)

@staticmethod
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
aval_out, = ctx.avals_out
def broadcast_in_dim_mlir(ctx, aval_out, x,
broadcast_dimensions) -> mlir.ir.Value:
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).results
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
return mlir.broadcast_in_dim(ctx, x, physical_aval_out, broadcast_dimensions=broadcast_dimensions)

@staticmethod
def transpose_mlir(ctx, x, *, permutation):
aval_out, = ctx.avals_out
def transpose_mlir(ctx, aval_out, x, *, permutation) -> mlir.ir.Value:
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
perm = [*permutation, *trailing_dims]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result

@staticmethod
def gather_mlir(ctx, x, indices, *,
def gather_mlir(ctx, avals_in, aval_out, x, indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
aval_x, aval_indices = ctx.avals_in
aval_y, = ctx.avals_out
indices_are_sorted, mode, fill_value) -> mlir.ir.Value:
aval_x, aval_indices = avals_in
aval_y = aval_out
key_shape = aval_x.dtype.impl.key_shape
trailing_offset_dims = [aval_y.ndim + i for i in range(len(key_shape))]
dimension_numbers = dimension_numbers._replace(
@@ -466,10 +460,11 @@ def gather_mlir(ctx, x, indices, *,
lax_internal.slicing._gather_lower, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
return mlir.delegate_lowering(
res, = mlir.delegate_lowering(
ctx, gather_lower, x, indices,
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices],
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
return res


class KeyTy:
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -693,6 +693,7 @@ def _out_type(jax_type):
if "in_shardings" in lowered.compile_args:
args_tf = tuple(
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))

res = tfxla.call_module(
args_tf,
version=xla_call_module_version,
29 changes: 26 additions & 3 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -29,12 +29,11 @@
from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
from jax import lax
from jax import linear_util as lu
import jax.numpy as jnp
from jax import random
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src import util
import numpy as np

from jax.experimental.jax2tf.tests import tf_test_util
@@ -894,6 +893,30 @@ def f_jax(xi_yf, zb):  # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
self.assertAllClose(g_tf[0][1], (xi * 2).astype(yf.dtype))
self.assertAllClose(g_tf[1], np.zeros_like(zb))

def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
try:
prev_custom_prng = config.jax_enable_custom_prng
config.update("jax_enable_custom_prng", True)

def f_jax(x): # x: f32[b1, b2]
key = random.PRNGKey(123) # key
# Exercise key operations that have custom lowering rules
broadcast_keys = lax.broadcast_in_dim(key, x.shape, ()) # key[b1, b2]
gather_keys = lax.broadcast_in_dim(broadcast_keys[0], (1, x.shape[1]), (1,)) # : key[1, b2]
slice_keys1 = lax.slice(broadcast_keys, (0, 0), (1, x.shape[1]), (1, 1)) # key[1, b2]
slice_keys2 = lax.dynamic_slice(broadcast_keys, (0, 0), slice_sizes=(1, x.shape[1])) # key[1, b2]
upd1 = lax.dynamic_update_slice(slice_keys2, slice_keys1, start_indices=(0, 0)) # key[1, b2]
_ = lax.dynamic_update_slice(upd1, gather_keys, start_indices=(0, 0))
return x

self.CheckShapePolymorphism(f_jax,
input_signature=[tf.TensorSpec([None, None], dtype=tf.float32)],
polymorphic_shapes=["b1, b2"])
finally:
config.update("jax_enable_custom_prng", prev_custom_prng)


def test_saved_model(self):
f_jax = jnp.sin
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
@@ -2113,7 +2136,7 @@ class ShapePolyVmapPrimitivesTest(tf_test_util.JaxToTfTestCase):
# to parameterized below.
@primitive_harness.parameterized(
_flatten_harnesses(_POLY_SHAPE_VMAP_TEST_HARNESSES),
#one_containing="gather_from_slicing_name=[0,1]_enable_xla=True_poly_axes=[0]"
one_containing=""
)
def test_vmap_prim(self, harness: Harness):
return _test_one_harness(self, harness)
63 changes: 57 additions & 6 deletions jax/interpreters/mlir.py
@@ -1253,6 +1253,10 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:
# Lower a possibly-dynamic broadcast_in_dim
if core.is_opaque_dtype(aval_out.dtype): # type: ignore
return aval_out.dtype._rules.broadcast_in_dim_mlir( # type: ignore
ctx, aval_out, op,
broadcast_dimensions=broadcast_dimensions)
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicBroadcastInDimOp(
@@ -1284,19 +1288,66 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext,
return out

def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
aval_out_shape = aval_out.shape # type: ignore
if not core.is_constant_shape(aval_out_shape):
if core.is_opaque_dtype(aval_out.dtype): # type: ignore
# TODO(necula)
raise NotImplementedError("reshaping opaque types")
shape = eval_dynamic_shape(ctx, aval_out_shape)
if core.is_opaque_dtype(aval_out.dtype): # type: ignore
aval_out, = aval_out.dtype._rules.physical_avals(aval_out) # type: ignore
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicReshapeOp(
aval_to_ir_type(aval_out), op,
shape_tensor(shape),
).result
else:
return mhlo.ReshapeOp(aval_to_ir_type(aval_out), op).result

def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
if core.is_opaque_dtype(aval_out.dtype):
return [aval_out.dtype._rules.slice_mlir(
ctx, aval_out, x, start_indices, limit_indices, strides)]

if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
start_indices = eval_dynamic_shape(ctx, start_indices)
limit_indices = eval_dynamic_shape(ctx, limit_indices)
strides = eval_dynamic_shape(ctx, strides)
return mhlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
x,
shape_tensor(start_indices),
shape_tensor(limit_indices),
shape_tensor(strides)).result
else:
return mhlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result

def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
start_indices) -> ir.Value:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(ctx, aval_out, x,
start_indices)
slice_sizes = aval_out.shape
if not core.is_constant_shape(slice_sizes):
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
return mhlo.RealDynamicSliceOp(
aval_to_ir_type(aval_out), x,
shape_tensor(start_indices),
shape_tensor(slice_sizes),
shape_tensor([1] * len(slice_sizes))
).result
else:
return mhlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result

def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
start_indices) -> ir.Value:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_update_slice_mlir(
ctx, aval_out, x, update, *start_indices)

# TODO(necula): handle dynamic shapes
return mhlo.DynamicUpdateSliceOp(aval_to_ir_type(aval_out), x, update,
start_indices).result

def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
zero = ir_constant(np.array(value, aval.dtype))
