[jax2tf] Improves support for lax.scatter ops (enable_xla=False).
We were initially only handling `mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS` (from #10653), but there are two problems with this:

* In eager or graph mode, TF throws an error for out-of-bounds indices. This was undocumented, so I've added a limitation and documented this in the g3doc.

* `PROMISE_IN_BOUNDS` is semantically the same as `FILL_OR_DROP` (for the forward pass). In fact, JAX's `.set()` uses `FILL_OR_DROP`. I've now changed it so that both modes are supported (only `CLIP` is not supported).
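For illustration (this example is not part of the original commit message; the values follow JAX's documented drop-on-OOB update semantics):

```python
# Hedged sketch: JAX's .at[].set() drops out-of-bounds updates on the
# forward pass (FILL_OR_DROP semantics).
import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[jnp.array([0, 7])].set(jnp.array([1., 2.]))
# y == [1., 0., 0.]  -- the update at out-of-bounds index 7 is dropped.
```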

I have also added more tests for the OOB behavior with enable_xla=False, and a few other tests.

PiperOrigin-RevId: 470215381
marcvanzee authored and jax authors committed Aug 26, 2022
1 parent 3c91787 commit 6dd425c
Showing 5 changed files with 212 additions and 168 deletions.
17 changes: 15 additions & 2 deletions jax/experimental/jax2tf/g3doc/no_xla_limitations.md
@@ -187,6 +187,19 @@ This op is called by `lax.scatter`, `lax.scatter_min`, `lax.scatter_max`,
We support all these ops for unique indices. For non-unique indices we
support (min, max, mul, add) for single-depth scatters.

We implement support for this op through
[tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update).

There are a few more limitations:

* Dtypes `np.bool` and `jnp.complex*` are not supported.
* We disallow scatter mode `lax.GatherScatterMode.CLIP` because it may lead to
incorrect behavior for out-of-bounds indices (see next point).
* The behavior for out-of-bounds scatter indices is as follows:
- When running in eager or graph mode, it throws an error. This is because
`tf.scatter` throws an error as well. If this is problematic for your use
case, please let us know and we can add more support for this.
- When running in compile mode, the out-of-bounds indices are dropped, which
is the behavior of both `lax.GatherScatterMode.FILL_OR_DROP` and
`lax.GatherScatterMode.PROMISE_IN_BOUNDS`. This is why `CLIP` is not
allowed.
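
A hedged end-to-end sketch of the behavior above (the function, shapes, and values are illustrative; the exact TF error message may differ):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x, updates):
  # Index 5 is out of bounds for an array of length 4; .at[].set() uses
  # FILL_OR_DROP semantics on the forward pass.
  return x.at[jnp.array([1, 5])].set(updates)

tf_f = jax2tf.convert(f, enable_xla=False)
x, updates = jnp.zeros(4), jnp.array([1., 2.])

# In eager or graph mode this is expected to raise, because
# tf.tensor_scatter_nd_update rejects out-of-bounds indices.
# Under compilation, the out-of-bounds update is dropped instead:
compiled_f = tf.function(tf_f, jit_compile=True)
print(compiled_f(x, updates))  # expected: [0., 1., 0., 0.]
```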
152 changes: 74 additions & 78 deletions jax/experimental/jax2tf/impl_no_xla.py
@@ -51,6 +51,8 @@ def _error(primitive_name: str, suffix_msg: str = "") -> Exception:

_conv_error = lambda msg: _error("conv_general_dilated", msg)
_reduce_error = lambda msg: _error("reduce_window", msg)
_scatter_error = lambda msg: _error("scatter_(update/add/multiply/min/max)", msg)

def _unimplemented(name):

@@ -1016,46 +1018,32 @@ def _dynamic_update_slice(operand, update, *start_indices,
tf_impl_no_xla[lax.dynamic_update_slice_p] = _dynamic_update_slice


def shift_axes_forward(operand,
axes: Tuple[int, ...],
inverse: bool = False,
forward: bool = True):
"""Shifts the tuple of axes to the front of an array"""
other_axes = tuple([i for i in range(len(operand.shape)) if i not in axes])
fwd_order = axes + other_axes if forward else other_axes + axes
order = fwd_order if not inverse else _invert_permutation(fwd_order)
return tf.transpose(operand, order)
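For intuition, a hedged usage sketch of `shift_axes_forward` (shapes are illustrative):

```python
import tensorflow as tf

x = tf.zeros((2, 3, 4, 5))
# Move axes (1, 3) to the front...
y = shift_axes_forward(x, axes=(1, 3))                  # shape (3, 5, 2, 4)
# ...and undo the shift with inverse=True.
x2 = shift_axes_forward(y, axes=(1, 3), inverse=True)   # shape (2, 3, 4, 5)
```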

def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None):

def _sparse_scatter(operand, scatter_indices, updates, unique_indices, mode,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of scatter specialised to indexing from the front axes.
This covers unique indices and non-unique indices of single depth.
Note on unique indices: `tf.tensor_scatter_nd_update` interprets indices as
follows: every axis except the final one encodes a batch dimension, and the
final axis encodes the actual indices to scatter into. It enforces at least
one batch dimension, so we add an empty dimension to indices and updates if
they lack one.
Note on non-unique indices: there is no tf op for non-single-depth indexing,
but if indexing is single depth, this can be viewed as a segment op.
"""
# Infer unique indices from lack of batch dimension
unique_indices = unique_indices or (len(scatter_indices.shape) == 1)
Expand All @@ -1068,72 +1056,80 @@ def _sparse_scatter(
updated_suboperand = updated_suboperand[None]
y = tf.tensor_scatter_nd_update(operand, scatter_indices, updated_suboperand)
else:
if (scatter_indices.shape[-1] == 1) and unsorted_segment_op:
# If only indexing into the first dimension, it's a segment op
operand_update = unsorted_segment_op(updates,
tf.squeeze(scatter_indices, -1),
operand.shape[0])
y = update_op(operand, operand_update)
else:
raise error("Scatter supports unique indices. Scatter also supports non-unique indices with indexing into only one dimension for (add, mul, min, max)")
raise _scatter_error(
    "Scatter only supports non-unique indices with indexing into only "
    "one dimension for (add, mul, min, max)")
return y
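
A hedged illustration of the `tf.tensor_scatter_nd_update` indexing convention described in the docstring above (shapes and values are illustrative):

```python
import tensorflow as tf

operand = tf.zeros((4, 3))
# The final axis of `indices` is the index depth (here 2, i.e. full
# indexing into a rank-2 operand); all leading axes are batch dimensions.
indices = tf.constant([[0, 1], [2, 2]])
updates = tf.constant([10., 20.])
result = tf.tensor_scatter_nd_update(operand, indices, updates)
# result[0, 1] == 10.  and  result[2, 2] == 20.
```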

def sparse_scatter(operand, scatter_indices, updates, update_jaxpr,
update_consts, dimension_numbers, indices_are_sorted: bool,
unique_indices: bool, mode,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""
Wrapper around the scatter function.
The underlying tf ops `tf.tensor_scatter_nd_update` and
`tf.math.unsorted_segment_*` index from the front dimensions.
`tf.math.unsorted_segment_*` indexes to a depth 1 from the front.
`tf.tensor_scatter_nd_update` indexes from the front dimensions onwards,
with no ability to skip a dimension. This function shifts the axes to be
indexed to the front then calls a front-specific implementation, then
inverse-shifts the output.
scatter_dims_to_operand_dims: dimensions which the scatter indexes into.
We shift these to the front to match tf syntax. All other dims are batch
dimensions.
update_window_dims: dimensions which are not batch dimensions. We shift
these to the back as the remaining dimensions are batch dimensions.
"""
del update_jaxpr, update_consts, indices_are_sorted # Unused arguments

update_window_dims = dimension_numbers.update_window_dims
inserted_window_dims = dimension_numbers.inserted_window_dims
scatter_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims

dtype = operand.dtype # assume updates has same dtype as operand
if dtype in [tf.bool, tf.complex64]:
raise error(f"Scatter does not support operands of type {dtype}")
if not (wd == sd):
raise error("Complex scatters are not supported")
if not (mode == lax.GatherScatterMode.PROMISE_IN_BOUNDS):
raise error("Only scatter mode `PROMISE_IN_BOUNDS` is supported")
raise _scatter_error(f"Scatter does not support operands of type {dtype}")

if inserted_window_dims != scatter_to_operand_dims:
raise _scatter_error("Complex scatters are not supported")

if (mode != lax.GatherScatterMode.FILL_OR_DROP and
mode != lax.GatherScatterMode.PROMISE_IN_BOUNDS):
# The OOB behavior for tf.scatter is as follows:
# - When running in eager or graph mode, it throws an error.
# TODO(marcvanzee): Fix this case by removing the OOB indices.
# - When running in compile mode, the OOB indices are dropped, which is
# the same behavior as FILL_OR_DROP and PROMISE_IN_BOUNDS.
# To ensure correctness, we disallow CLIP mode for now.
raise _scatter_error("Only scatter modes `FILL_OR_DROP` and "
"`PROMISE_IN_BOUNDS` are supported.")

# Shift axes to the front to match tf syntax, inverse afterwards
fwd = partial(shift_axes_forward, axes=scatter_to_operand_dims)
inv = partial(fwd, inverse=True)

# Shift update value axes to the back, so batch are at the front
updates_shifted = shift_axes_forward(
updates, axes=update_window_dims, forward=False)

return inv(
_sparse_scatter(
fwd(operand), scatter_indices, updates_shifted, unique_indices,
mode, _in_avals, _out_aval))
return sparse_scatter
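
As a hedged example of the case this wrapper handles, a scatter that indexes into a non-leading axis (that this lowers with `scatter_dims_to_operand_dims == (1,)` is an assumption about the dimension numbers JAX emits):

```python
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  # Indexes into axis 1, so the wrapper is expected to shift axis 1 to the
  # front, scatter on the leading axis, and inverse-shift the result.
  return x.at[:, jnp.array([0, 3])].add(1.0)

tf_f = jax2tf.convert(f, enable_xla=False)
```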

tf_impl_no_xla[lax.scatter_p] = convert_scatter_jax_to_tf(
lambda x, y: y) # just replace with the update
tf_impl_no_xla[lax.scatter_add_p] = convert_scatter_jax_to_tf(tf.add, tf.math.unsorted_segment_sum)
tf_impl_no_xla[lax.scatter_mul_p] = convert_scatter_jax_to_tf(tf.multiply, tf.math.unsorted_segment_prod)
tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.math.unsorted_segment_min)
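
A hedged sketch of the segment-op fallback for non-unique, single-depth indices (e.g. `lax.scatter_add` combines `tf.math.unsorted_segment_sum` with `tf.add`):

```python
import tensorflow as tf

updates = tf.constant([1., 2., 3.])
segment_ids = tf.constant([0, 0, 2])  # non-unique, depth-1 indices
# Duplicate indices are combined by the segment op...
delta = tf.math.unsorted_segment_sum(updates, segment_ids, num_segments=4)
# delta == [3., 0., 3., 0.]
# ...and the result is applied to the operand with the update op (tf.add).
operand = tf.zeros(4)
result = tf.add(operand, delta)
```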
34 changes: 25 additions & 9 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -134,11 +134,10 @@ def limitations_for_harness(
"random_categorical", "random_uniform", "random_randint",
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
"zeros_like"
"reshape", "rev", "rsqrt", "select_n", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
"tie_in", "transpose", "xor", "zeros_like"
}

@classmethod
@@ -970,17 +969,34 @@ def round(cls, harness: primitive_harness.Harness):
modes=("eager", "graph"))
]

@classmethod
def scatter(cls, harness):
return [
Jax2TfLimitation(
"out-of-bounds scatters are not supported in graph and eager mode",
dtypes=jtu.dtypes.all_inexact,
devices=("cpu", "gpu", "tpu"),
modes=("eager", "graph"),
expect_tf_error=True,
skip_comparison=True,
enabled=("modes_out_of_bounds" in harness.name and not harness.params["enable_xla"])),
]

@classmethod
def scatter_add(cls, harness):
return cls.scatter(harness)

@classmethod
def scatter_mul(cls, harness):
return cls.scatter(harness)

@classmethod
def scatter_max(cls, harness):
return cls.scatter(harness)

@classmethod
def scatter_min(cls, harness):
return cls.scatter(harness)

@classmethod
def select_and_gather_add(cls, harness):
