Implements more slicing support for lax.gather in jax2tf when enable_xla=False, and adds tests.

PiperOrigin-RevId: 380753589
marcvanzee authored and jax authors committed Jun 22, 2021
1 parent 6382cee commit e502760
Showing 3 changed files with 95 additions and 37 deletions.
112 changes: 76 additions & 36 deletions jax/experimental/jax2tf/jax2tf.py
@@ -2214,20 +2214,67 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
  return proto


def _gather_without_xla(operand: TfVal,
                        start_indices: TfVal, *,
                        dimension_numbers, slice_sizes,
                        _in_avals: Sequence[core.ShapedArray]):
  # Attempt to use tf.gather for lax.gather_p.
def _clip(max_indices: Sequence[TfVal], start_indices: TfVal, slice_sizes):
  """Simulates XLA clipping behavior with TF ops.

  Various TF ops have different clipping behavior than XLA:
  * If `start_indices` is OOB, then TF fails but XLA clips the indices to
    [0, max_len].
  * If `start_indices + slice_size` is OOB, then TF fails, but XLA adjusts
    `start_indices` so that a full slice is returned.
  This function clips the start indices correctly.
  """
  max_start = tf.subtract(max_indices, slice_sizes)
  # If `start_indices` and `slice_sizes` are Python tuples of integers,
  # `tf.subtract` returns a Tensor of dtype tf.int32, which may conflict with
  # the dtype of `start_indices` when running in x64 mode, and cause an error
  # when calling `tf.clip_by_value`. Therefore we explicitly cast to the right
  # dtype here.
  max_start = tf.cast(max_start, dtype=start_indices.dtype)
  return tf.clip_by_value(start_indices, 0, max_start)
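
The clipping rule is easy to see with concrete numbers. A minimal sketch (an illustration only, not part of the diff), using plain TF ops the same way `_clip` does:

import tensorflow as tf

# Axis of length 5, requested slice of size 3: any start index above
# 5 - 3 == 2 is out of bounds, so XLA (and hence `_clip`) moves the start
# back to 2 and returns the full slice [2:5].
start = tf.constant([4])
max_start = tf.cast(tf.subtract([5], [3]), start.dtype)  # == [2]
print(tf.clip_by_value(start, 0, max_start))  # -> [2]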


def _gather_using_tf_slice(operand: TfVal, start_indices: TfVal, *,
                           dimension_numbers, slice_sizes,
                           _in_avals: Sequence[core.ShapedArray],
                           _out_aval: core.ShapedArray):
  """Implements 'scalar indexing into arrays' cases of lax.gather using tf.slice.

  E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
  """
  op_shape = _in_avals[0].shape
  indices = tf.expand_dims(dimension_numbers.start_index_map, 1)
  # lax.gather uses an "index map" which maps `start_indices` to the right axes
  # in `operand`. Since tf.strided_slice uses a single array for specifying the
  # start indices, we use a scatter to map the start indices to the right axes.
  begin = tf.scatter_nd(indices, start_indices, [len(op_shape)])
  begin = _clip(_eval_shape(op_shape), begin, slice_sizes)
  end = slice_sizes + begin

  # Convert from a tuple of dimensions to a shrink mask, e.g. (0, 2) --> 5.
  shrink_mask = sum(2 ** x for x in dimension_numbers.collapsed_slice_dims)
  res = tf.strided_slice(operand, begin, end, shrink_axis_mask=shrink_mask)
  # Shape inference doesn't work for tf.strided_slice.
  res.set_shape(_aval_to_tf_shape(_out_aval))
  return res
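
The index-map scatter above can be shown with a small standalone sketch (hypothetical values, not part of the diff):

import tensorflow as tf

# Suppose the operand has rank 3 and start_index_map == (1,): the single
# start index refers to axis 1. Scattering it into a length-3 vector
# produces the `begin` argument that tf.strided_slice expects.
indices = tf.expand_dims(tf.constant([1]), 1)       # shape (1, 1)
start_indices = tf.constant([4])
begin = tf.scatter_nd(indices, start_indices, [3])  # -> [0, 4, 0]
# And collapsed_slice_dims == (0, 2) gives shrink_axis_mask 2**0 + 2**2 == 5.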


def _gather_using_tf_gather(operand: TfVal, start_indices: TfVal, *,
                            dimension_numbers, slice_sizes,
                            _in_avals: Sequence[core.ShapedArray]):
  """Implements 'multi-dimensional indexing into arrays' cases of lax.gather using tf.gather.

  E.g., jnp.take(op, [[0], [1]], axis=0).
  """
  # Handle only the case when the tf.gather argument batch_dims=0.
  # Find the axis that matches the tf.gather semantics:
  # Let I = len(start_indices_shape)
  # Let O = len(op_shape)
  # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
  # collapsed_slice_dims == (axis,)
  # start_index_map == (axis,)
  # offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1)
  # We added a trailing dimension of size 1 to `start_indices`.
  op_shape = _in_avals[0].shape
  start_indices_shape = _in_avals[1].shape
  assert len(op_shape) == len(slice_sizes)
@@ -2259,14 +2306,9 @@ def _gather_without_xla(operand: TfVal,
"gather",
f"unexpected slice_sizes {slice_sizes} != {expected_slice_sizes}")

  start_indices_reshaped = tf.reshape(start_indices,
                                      _eval_shape(start_indices_shape[0:-1]))
  start_indices_clipped = tf.clip_by_value(
      start_indices_reshaped,
      tf.constant(0, dtype=start_indices_reshaped.dtype),
      tf.subtract(tf.cast(_eval_shape(op_shape)[axis], start_indices_reshaped.dtype),
                  tf.constant(1, dtype=start_indices_reshaped.dtype)))
  return tf.gather(operand, start_indices_clipped, axis=axis, batch_dims=0)
  squeezed_indices = tf.squeeze(start_indices, -1)
  start_indices = _clip(_eval_shape(op_shape)[axis], squeezed_indices, 1)
  return tf.gather(operand, start_indices, axis=axis, batch_dims=0)
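
A minimal sketch of this path (an illustration only, not part of the diff), mirroring the docstring example jnp.take(op, [[0], [1]], axis=0):

import tensorflow as tf

# lax.gather start_indices carry a trailing dimension of size 1; squeezing
# it recovers the plain indices that tf.gather expects.
op = tf.reshape(tf.range(6), (3, 2))
start_indices = tf.constant([[[0]], [[1]]])  # shape (2, 1, 1)
indices = tf.squeeze(start_indices, -1)      # shape (2, 1)
out = tf.gather(op, indices, axis=0, batch_dims=0)
print(out.shape)  # (2, 1, 2), as for jnp.take(op, [[0], [1]], axis=0)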


@partial(bool_to_int8, argnums=[0])
@@ -2284,9 +2326,25 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
    out.set_shape(_aval_to_tf_shape(_out_aval))
    return out

  return _gather_without_xla(operand, start_indices,
                             dimension_numbers=dimension_numbers,
                             slice_sizes=slice_sizes, _in_avals=_in_avals)
  # TODO(marcvanzee): Check if we need more tests in shape_poly for gather with
  # enable_xla=False.

  if len(_in_avals[1].shape) == 1:
    # Use tf.slice if `start_indices` is a 1D array.
    try:
      return _gather_using_tf_slice(operand, start_indices,
                                    dimension_numbers=dimension_numbers,
                                    slice_sizes=slice_sizes,
                                    _in_avals=_in_avals,
                                    _out_aval=_out_aval)
    except NotImplementedError:
      # If `_gather_using_tf_slice` fails, don't give up yet.
      pass

  return _gather_using_tf_gather(operand, start_indices,
                                 dimension_numbers=dimension_numbers,
                                 slice_sizes=slice_sizes,
                                 _in_avals=_in_avals)


tf_impl_with_avals[lax.gather_p] = _gather
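
End to end, the new code paths run whenever an indexing function is converted with enable_xla=False. A usage sketch (not part of the diff; it mirrors the "[2:5,5]" test harness added below):

import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

f = lambda x: x[2:5, 5]  # lowers to lax.gather
f_tf = jax2tf.convert(f, enable_xla=False)
x = np.arange(100.0).reshape(10, 10)
np.testing.assert_allclose(f_tf(x).numpy(), f(x))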
@@ -2321,26 +2379,8 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
    res.set_shape(_aval_to_tf_shape(_out_aval))
    return res

  # If XLA is disabled, we use `tf.slice` as a fallback, which has different
  # out-of-bounds (OOB) behavior than `lax.dynamic_slice_p`:
  # * If `slice_size > max_len`, then both `tf.slice` and `lax.dynamic_slice_p`
  #   fail, so we ignore this case.
  # * If `start_indices` is OOB, then `tf.slice` fails but `lax.dynamic_slice_p`
  #   clips the indices to [0, max_len].
  # * If `start_indices + slice_size` is OOB, then `tf.slice` fails, but
  #   `lax.dynamic_slice_p` adjusts `start_indices` so that a full slice is
  #   returned.
  # The code below manually clips the start indices so that the behavior is
  # the same as `lax.dynamic_slice_p`.
  operand_shape = _eval_shape(_in_avals[0].shape)
  max_start = tf.subtract(operand_shape, slice_sizes)
  # If `operand_shape` and `slice_sizes` are Python tuples of integers,
  # `tf.subtract` returns a Tensor of dtype tf.int32, which may conflict with
  # the dtype of `start_indices` when running in x64 mode, and cause an error
  # when calling `tf.clip_by_value`. Therefore we explicitly cast to the right
  # dtype here.
  max_start = tf.cast(max_start, dtype=start_indices.dtype)
  start_indices = tf.clip_by_value(start_indices, 0, max_start)
  start_indices = _clip(operand_shape, start_indices, slice_sizes)
  return tf.slice(operand, start_indices, size=slice_sizes)
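
For reference, the XLA OOB semantics being simulated here (a sketch, not part of the diff):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)
# Start index 8 with slice size 5 is OOB; lax.dynamic_slice clips the
# start back to 10 - 5 == 5 and returns the full slice [5, 6, 7, 8, 9].
print(lax.dynamic_slice(x, (8,), (5,)))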


2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -393,7 +393,7 @@ def dot_general(cls, harness: primitive_harness.Harness):
            jnp.bfloat16, np.float16, np.float32, np.complex64
        ],
        devices="gpu",
        modes="compiled",
        modes=("eager", "graph", "compiled"),
        enabled=(harness.params["preferred_element_type"] is not None),
        skip_comparison=True)
    ]
18 changes: 18 additions & 0 deletions jax/experimental/jax2tf/tests/primitive_harness.py
@@ -1145,6 +1145,24 @@ def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()):
      enable_xla=enable_xla,
      index_oob=index_oob)

# Construct gather harnesses using array indexing and slicing.
for slices, name in [
    ((0,), "[0]"),
    ((0, 1), "[0,1]"),
    ((slice(0, 10), 2, slice(0, 10)), "[:,:2,:]"),
    ((slice(2, 5), 5), "[2:5,5]"),
    ((-1, -5, -200), "[-1,-5,-200]"),
    ((slice(5, -2), 300), "[5:-2,300]"),
]:
  for enable_xla in [False, True]:
    define(
        lax.gather_p,
        f"from_slicing_name={name}_enable_xla={enable_xla}",
        lambda arr, *s: jnp.array(arr).__getitem__(*s),
        [_gather_input, StaticArg(slices)],
        dtype=_gather_input.dtype,
        enable_xla=enable_xla)
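
Here the slice tuple arrives as one static argument, so __getitem__ receives the whole tuple at once. A quick illustration (not part of the diff):

import numpy as np
import jax.numpy as jnp

arr = np.arange(100.0).reshape(10, 10)
slices = (slice(2, 5), 5)  # i.e. the "[2:5,5]" harness
print(jnp.array(arr).__getitem__(slices))  # same as jnp.array(arr)[2:5, 5]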

# Directly from lax.gather in lax_test.py.
for shape, idxs, dnums, slice_sizes, needs_xla in [
    ((5,), np.array([[0], [2]]),
