Skip to content

Commit

Permalink
Add gather/scatter mode support to jax2tf.
Browse files Browse the repository at this point in the history
Use xla.lower_fun() to implement gather/scatter modes so we can share the implementation between the XLA translation and jax2tf.

Add an undocumented "fill" mode to jnp.take() that corresponds to the "fill" mode of `lax.gather`.

PiperOrigin-RevId: 407169324
  • Loading branch information
hawkinsp authored and jax authors committed Nov 2, 2021
1 parent e28bb23 commit 6a44baf
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 80 deletions.
137 changes: 75 additions & 62 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,45 +4735,32 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
return tuple(next(slice_sizes) if i in offset_dims
else next(indices_shape) for i in range(output_shape_rank))

def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
operand_aval, indices_aval = avals_in
aval_out, = avals_out
c = ctx.builder
dimensions = _gather_dimensions_proto(indices_aval.shape, dimension_numbers)
if (mode == GatherScatterMode.CLIP or
mode == GatherScatterMode.PROMISE_IN_BOUNDS):
# XLA's Gather has clamp semantics, so we can just call it directly.
return [xops.Gather(operand, indices, dimensions, slice_sizes,
indices_are_sorted=indices_are_sorted)]

# Otherwise, we need to mask out out-of-bounds indices and replace those
# slices with `fill_value`.
assert mode == GatherScatterMode.FILL_OR_DROP, mode

def _shape_as_value(shape):
  """Turns a shape that may contain Poly values into a single JAX value.

  Each dimension is passed through `core.dimension_as_value`, converted to
  `np.int64`, expanded at axis 0 with `expand_dims`, and the resulting pieces
  are concatenated along dimension 0.
  """
  pieces = []
  for dim in shape:
    dim_value = convert_element_type(core.dimension_as_value(dim), np.int64)
    pieces.append(expand_dims(dim_value, (0,)))
  return concatenate(pieces, dimension=0)

def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
unique_indices, indices_are_sorted, fill_value,
output_shape):
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
dnums = dimension_numbers
intarray = partial(np.array, dtype=np.int64)
operand_dims = intarray(operand_aval.shape)
indices = xops.ConvertElementType(
indices, xla.dtype_to_primitive_type(dtypes.canonicalize_dtype(np.int64)))
num_batch_dims = len(indices_aval.shape) - 1

upper_bound = operand_dims[intarray(dnums.start_index_map)]
upper_bound -= intarray(slice_sizes)[intarray(dnums.start_index_map)]
mask = xops.And(xops.Ge(indices, xla.pyval_to_ir_constant(c, intarray(0))),
xops.Le(indices, xla.pyval_to_ir_constant(c, upper_bound),
broadcast_dimensions=[num_batch_dims]))

# Compute the conjunction of the mask elements across the dimensions in which
# we are slicing.
and_builder = xc.XlaBuilder("and_reduction")
scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ())
xops.And(xb.parameter(and_builder, 0, scalar_pred),
xb.parameter(and_builder, 1, scalar_pred))
mask = xops.Reduce(c, [mask], [xla.pyval_to_ir_constant(c, True)],
and_builder.build(), [num_batch_dims])
operand_dims = _shape_as_value(operand.shape)
indices = convert_element_type(indices, np.int64)
num_batch_dims = len(indices.shape) - 1

upper_bound = (operand_dims[intarray(dnums.start_index_map)] -
intarray(slice_sizes)[intarray(dnums.start_index_map)])
mask = bitwise_and(
ge(indices, np.int64(0)),
le(indices, expand_dims(upper_bound, tuple(range(num_batch_dims)))))
mask = _reduce_and(mask, [num_batch_dims])


# Computes the output shape and the positions of the batch dimensions in the
# output
Expand All @@ -4783,13 +4770,36 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,

# We don't consume unique_indices directly in gather(), only in its transpose
# (scatter).
return [xops.Select(
xops.BroadcastInDim(mask, aval_out.shape, batch_dims_in_output),
xops.Gather(operand, indices, dimensions, slice_sizes,
indices_are_sorted=indices_are_sorted),
xops.Broadcast(
xla.pyval_to_ir_constant(c, np.array(fill_value, operand_aval.dtype)),
aval_out.shape))]
gather_out = gather(operand, indices, dnums, slice_sizes,
indices_are_sorted=indices_are_sorted,
mode=GatherScatterMode.PROMISE_IN_BOUNDS)
return select(
broadcast_in_dim(mask, output_shape, batch_dims_in_output),
gather_out, full_like(gather_out, fill_value=fill_value))


def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
                             dimension_numbers,
                             slice_sizes, unique_indices, indices_are_sorted,
                             mode, fill_value):
  """XLA translation rule for gather.

  FILL_OR_DROP is lowered by passing `_gather_fill` through `xla.lower_fun`,
  which lets the same implementation be shared with other lowerings. CLIP and
  PROMISE_IN_BOUNDS emit `xops.Gather` directly.
  """
  out_aval, = avals_out

  if mode == GatherScatterMode.FILL_OR_DROP:
    lowered_fill = xla.lower_fun(_gather_fill, multiple_results=False,
                                 new_style=True)
    fill_kwargs = dict(
        dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
        unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
        fill_value=fill_value, output_shape=out_aval.shape)
    return lowered_fill(ctx, avals_in, avals_out, operand, indices,
                        **fill_kwargs)

  _, indices_aval = avals_in
  gather_dims = _gather_dimensions_proto(indices_aval.shape, dimension_numbers)
  is_clip = mode == GatherScatterMode.CLIP
  assert (is_clip or
          mode == GatherScatterMode.PROMISE_IN_BOUNDS), mode
  # XLA's Gather has clamp semantics, so we can just call it directly.
  gathered = xops.Gather(operand, indices, gather_dims, slice_sizes,
                         indices_are_sorted=indices_are_sorted)
  return [gathered]


def _gather_jvp_rule(g, operand, indices, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted, mode,
Expand Down Expand Up @@ -5026,40 +5036,41 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,

return operand.shape

def _clamp_scatter_indices(c, indices, operand_shape, updates_shape, dnums):

def _clamp_scatter_indices(operand, indices, updates, *, dnums):
"""Clamps `indices` to be in-range for a scatter."""
indices_shape = c.get_shape(indices)
indices_dtype = indices_shape.numpy_dtype()
intarray = partial(np.array, dtype=np.int64)
operand_dims = intarray(operand_shape)
operand_dims = intarray(operand.shape)
upper_bound = operand_dims[intarray(dnums.scatter_dims_to_operand_dims)]

slice_sizes = []
pos = 0
for i in range(len(operand_shape)):
for i in range(len(operand.shape)):
if i in dnums.inserted_window_dims:
slice_sizes.append(1)
else:
slice_sizes.append(updates_shape[dnums.update_window_dims[pos]])
slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
pos += 1

upper_bound -= intarray(slice_sizes)[intarray(dnums.scatter_dims_to_operand_dims)]
upper_bound = np.minimum(upper_bound, np.iinfo(indices_dtype).max)
return xops.Min(
xops.Max(xla.pyval_to_ir_constant(c, np.array(0, dtype=indices_dtype)),
indices),
xla.pyval_to_ir_constant(c, upper_bound.astype(indices_dtype)),
broadcast_dimensions=[len(indices_shape.dimensions()) - 1])
upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))
return clamp(np.int64(0), convert_element_type(indices, np.int64),
upper_bound)

def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
updates, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted,
unique_indices, mode):
c = ctx.builder
operand_aval, indices_aval, updates_aval = avals_in
if mode == GatherScatterMode.CLIP:
indices = _clamp_scatter_indices(c, indices, operand_aval.shape,
updates_aval.shape, dimension_numbers)
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)

c = ctx.builder

init_value = xla.pyval_to_ir_constant(c, np.array(0, operand_aval.dtype))
update_computation = _reduction_computation(
Expand All @@ -5073,11 +5084,13 @@ def _scatter_add_translation_rule(
ctx, avals_in, avals_out, operand, indices, updates, *, update_jaxpr,
update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode,
expand_complex128=False):
c = ctx.builder
operand_aval, indices_aval, updates_aval = avals_in
if mode == GatherScatterMode.CLIP:
indices = _clamp_scatter_indices(
c, indices, operand_aval.shape, updates_aval.shape, dimension_numbers)
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)

dtype = operand_aval.dtype
scatter_dims = _scatter_dimensions_proto(
indices_aval.shape, dimension_numbers)
Expand Down
15 changes: 13 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5432,14 +5432,25 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
axis_idx = _canonicalize_axis(axis, ndim(a))

if mode is None:
# TODO(phawkins): change default mode to "fill" and delete this case.
# lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
gather_mode = lax.GatherScatterMode.CLIP
elif mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
elif mode == "wrap":
indices = mod(indices, _constant_like(indices, a.shape[axis_idx]))
elif mode != "clip":
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
elif mode == "fill":
# Undocumented non-standard mode corresponding to the fill_or_drop mode on
# lax.gather()
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
elif mode == "clip":
gather_mode = lax.GatherScatterMode.CLIP
else:
raise ValueError("Invalid mode '{}' for np.take".format(mode))

index_dims = len(shape(indices))
Expand All @@ -5463,7 +5474,7 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
start_index_map=(axis_idx,))
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes),
mode="clip")
mode=gather_mode)


def _normalize_index(index, axis_size):
Expand Down
11 changes: 7 additions & 4 deletions jax/experimental/jax2tf/impl_no_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,14 @@ def _gather(operand, start_indices, *, dimension_numbers,
fill_value, _in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Tensorflow implementation of gather."""
del unique_indices, fill_value

if mode == lax.GatherScatterMode.FILL_OR_DROP:
raise NotImplementedError("FILL_OR_DROP gather mode is not implemented in "
"jax2tf")
gather_fill_fn = jax2tf._convert_jax_impl(lax._gather_fill,
multiple_results=False)
return gather_fill_fn(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value,
output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)

# TODO(marcvanzee): Check if we need more tests in shape_poly for gather with
# enable_xla=False.
Expand Down
10 changes: 6 additions & 4 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,11 +2021,13 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shap
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Tensorflow implementation of gather."""
del unique_indices, fill_value

if mode == lax.GatherScatterMode.FILL_OR_DROP:
raise NotImplementedError("FILL_OR_DROP gather mode is not implemented in "
"jax2tf")
gather_fill_fn = _convert_jax_impl(lax._gather_fill, multiple_results=False)
return gather_fill_fn(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value,
output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)

proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
slice_sizes_tf = _eval_shape(slice_sizes)
Expand Down
19 changes: 11 additions & 8 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,14 +1128,17 @@ def _make_broadcast_in_dim_harness(name,
]:
for axis in [0, 1, 2]:
for enable_xla in [True, False]:
define(
lax.gather_p,
f"from_take_indices_name={indices_name}_axis={axis}_enable_xla={enable_xla}",
lambda a, i, axis: jnp.take(a, i, axis=axis),
[_gather_input, indices, StaticArg(axis)],
dtype=_gather_input.dtype,
enable_xla=enable_xla,
index_oob=index_oob)
for mode in ["clip", "fill"]:
  # The harness callable is stored by define() and invoked later, so `mode`
  # is bound as a default argument here. A plain closure would look up the
  # loop variable at call time and see whatever value it holds by then,
  # rather than the mode named in this harness.
  define(
      lax.gather_p,
      f"from_take_indices_name={indices_name}_axis={axis}"
      f"_enable_xla={enable_xla}_mode={mode}",
      lambda a, i, axis, mode=mode: jnp.take(a, i, axis=axis, mode=mode),
      [_gather_input, indices, StaticArg(axis)],
      dtype=_gather_input.dtype,
      enable_xla=enable_xla,
      index_oob=index_oob,
      mode=mode)

# Construct gather harnesses using array indexing and slicing.
for slices, name in [
Expand Down

0 comments on commit 6a44baf

Please sign in to comment.