Skip to content

Commit

Permalink
Add workaround for SelectAndScatter padding bug on CPU and GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Mar 10, 2021
1 parent 3e45a83 commit 62a726d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
35 changes: 31 additions & 4 deletions jax/_src/lax/lax.py
Expand Up @@ -5462,15 +5462,34 @@ def _select_and_scatter_add_shape_rule(

def _select_and_scatter_add_translation(
c, source, operand, *, select_prim, window_dimensions, window_strides,
padding):
dtype = c.get_shape(operand).numpy_dtype()
padding, expand_padding):
shape = c.get_shape(operand)
dtype = shape.numpy_dtype()
scalar = ShapedArray((), dtype)
select = xla.primitive_subcomputation(select_prim, scalar, scalar)
scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
zero = xb.constant(c, np.array(0, dtype))
return xops.SelectAndScatterWithGeneralPadding(
# TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
expand_padding = (expand_padding and
not all(lo == 0 and hi == 0 for (lo, hi) in padding))
if expand_padding:
original_padding = padding
identity = (_get_max_identity if select_prim is ge_p
else _get_min_identity)
pads = [(lo, hi, 0) for (lo, hi) in padding]
operand = xops.Pad(operand, xb.constant(c, identity(dtype)),
xc.make_padding_config(pads))
padding = [(0, 0) for _ in padding]
output = xops.SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source, zero,
scatter)
if expand_padding:
start_indices = [lo for (lo, hi) in original_padding]
stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
shape.dimensions())]
output = xops.Slice(output, start_indices, stop_indices,
[1] * len(start_indices))
return output

def _select_and_scatter_add_jvp(
primals, tangents, *, select_prim, window_dimensions, window_strides,
Expand Down Expand Up @@ -5517,13 +5536,21 @@ def _select_and_scatter_add_batch_rule(

select_and_scatter_add_p = standard_primitive(
_select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
_select_and_scatter_add_translation)
partial(_select_and_scatter_add_translation, expand_padding=False))

ad.primitive_transposes[select_and_scatter_add_p] = \
_select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
_select_and_scatter_add_batch_rule

# TODO(b/161704903): workaround for XLA/CPU crash.
xla.backend_specific_translations['cpu'][select_and_scatter_add_p] = partial(
_select_and_scatter_add_translation, expand_padding=True)
# TODO(b/182390722): workaround for XLA/GPU crash.
xla.backend_specific_translations['gpu'][select_and_scatter_add_p] = partial(
_select_and_scatter_add_translation, expand_padding=True)

def _select_and_gather_add_shape_rule(
tangents, operand, *, select_prim, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
Expand Down
9 changes: 4 additions & 5 deletions tests/lax_autodiff_test.py
Expand Up @@ -675,16 +675,14 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], #, [(0, 3), (1, 2)]],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1)] + ([(2, 3)] if op is lax.add else []),
[(1, 1)] + ([(1, 2)] if op is lax.add else [])),
itertools.product(
[(3, 2, 4, 6)],
[(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], # [(0, 1), (1, 0), (2, 3), (0, 2)]],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
for dtype in dtypes))
Expand All @@ -702,7 +700,8 @@ def testReduceWindowGrad(
# depends on FLAGS for the device under test.
# TODO(b/31565929): enable when fixed.
if jtu.device_under_test() == "tpu" and op is not lax.add:
if len(shape) != 4 or dims != (1, 1, 2, 1):
if (len(shape) != 4 or dims != (1, 1, 2, 1)
or not isinstance(padding, str)):
raise SkipTest("Only R4 SelectAndScatter implemented on TPU")

# TODO(b/73062247): need variadic reduce-window for better precision.
Expand Down

0 comments on commit 62a726d

Please sign in to comment.