From 25a93c2f87e19234c0a7e7014fd3b893d9ec0b0e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 22 Feb 2019 07:55:36 -0500 Subject: [PATCH] Add new functions `jax.ops.index_add` and `jax.ops.index_update` for NumPy-style indexed updates. Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy. Progress on issue #101. Fixes #122. --- docs/jax.rst | 1 + jax/ops/__init__.py | 17 ++ jax/ops/scatter.py | 246 ++++++++++++++++++ tests/lax_numpy_indexing_test.py | 411 +++++++++++++++++++------------ 4 files changed, 515 insertions(+), 160 deletions(-) create mode 100644 jax/ops/__init__.py create mode 100644 jax/ops/scatter.py diff --git a/docs/jax.rst b/docs/jax.rst index 744aa0c61345..8a84eb992343 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -11,6 +11,7 @@ Subpackages jax.scipy jax.experimental jax.lax + jax.ops jax.random Module contents diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py new file mode 100644 index 000000000000..42ae2f9eb311 --- /dev/null +++ b/jax/ops/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from .scatter import index, index_add, index_update \ No newline at end of file diff --git a/jax/ops/scatter.py b/jax/ops/scatter.py new file mode 100644 index 000000000000..13b34834606f --- /dev/null +++ b/jax/ops/scatter.py @@ -0,0 +1,246 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Helpers for indexed updates. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as onp + +from ..abstract_arrays import ShapedArray, ConcreteArray +from .. import core +from .. import lax +from ..numpy import lax_numpy as np + + +def _scatter_update(x, idx, y, scatter_op): + """Helper for indexed updates. + + Computes the value of x that would result from computing:: + x[idx] op= y + except in a pure functional way, with no in-place updating. + + Support NumPy-style basic indexing only, i.e., `idx` must be + `None`, an integer, a `slice` object, or ellipses, or a tuple of the above. + + TODO(phawkins): support advanced indexing. + """ + + x = np.asarray(x) + y = np.asarray(y) + x_shape = np.shape(x) + y_shape = np.shape(y) + y = lax.convert_element_type(y, lax._dtype(x)) + + if not isinstance(idx, tuple): + idx = (idx,) + + # Test for unsupported advanced indexing and report an error. + if any(onp.ndim(elt) != 0 for elt in idx): + raise NotImplementedError("Unimplemented case for indexed update. Advanced " + "indexing is not yet implemented.") + + # Remove ellipses and add trailing slice(None)s. + idx = np._canonicalize_tuple_index(x, idx) + + _int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer) + + x_axis = 0 + y_axis = 0 # Current axis in y, before collapsing. See below. + collapsed_y_axis = 0 # Current axis in y, after collapsing. + + # Scatter dimension numbers. + update_window_dims = [] + inserted_window_dims = [] + scatter_dims_to_operand_dims = [] + + scatter_indices = np.zeros((0,), dtype=np.int32) + + # We perform three transformations to y before the scatter op, in order: + # First, y is broadcast to slice_shape. In general `y` only need broadcast to + # the right shape. + slice_shape = [] + # Next, y is reshaped to collapsed_slice_shape. This is to handle `None` + # indices, which the scatter cannot remove itself. + collapsed_slice_shape = [] + # Finally, we reverse reversed_y_dims to handle slices with negative strides. + reversed_y_dims = [] + + for i in idx: + try: + abstract_i = core.get_aval(i) + except TypeError: + abstract_i = None + if (isinstance(abstract_i, ConcreteArray) or + isinstance(abstract_i, ShapedArray)) and _int(abstract_i): + i = np.mod(i, np._constant_like(i, x.shape[x_axis])) + i = lax.convert_element_type(i, np.int32) + i = np.broadcast_to(i, tuple(scatter_indices.shape[:-1]) + (1,)) + scatter_indices = np.concatenate((scatter_indices, i), -1) + inserted_window_dims.append(x_axis) + scatter_dims_to_operand_dims.append(x_axis) + x_axis += 1 + elif i is None: + slice_shape.append(1) + y_axis += 1 + elif np._is_slice_none(i): + slice_shape.append(x_shape[x_axis]) + collapsed_slice_shape.append(x_shape[x_axis]) + update_window_dims.append(collapsed_y_axis) + collapsed_y_axis += 1 + y_axis += 1 + x_axis += 1 + elif isinstance(i, slice): + start, limit, stride, needs_rev = np._static_idx(i, x.shape[x_axis]) + if needs_rev: + reversed_y_dims.append(collapsed_y_axis) + if stride == 1: + i = lax.convert_element_type(start, np.int32) + i = np.broadcast_to(i, tuple(scatter_indices.shape[:-1]) + (1,)) + scatter_indices = np.concatenate((scatter_indices, i), -1) + slice_shape.append(limit - start) + collapsed_slice_shape.append(limit - start) + update_window_dims.append(collapsed_y_axis) + scatter_dims_to_operand_dims.append(x_axis) + else: + i = np.arange(start, limit, stride, dtype=np.int32) + size = i.shape[0] + slice_shape.append(size) + collapsed_slice_shape.append(size) + scatter_indices_shape = tuple(scatter_indices.shape[:-1]) + (size,) + i = lax.broadcast_in_dim( + i, shape=scatter_indices_shape + (1,), + broadcast_dimensions=(len(scatter_indices_shape) - 1,)) + scatter_indices = lax.broadcast_in_dim( + scatter_indices, + shape=scatter_indices_shape + (len(scatter_dims_to_operand_dims),), + broadcast_dimensions=( + tuple(range(len(scatter_indices_shape) - 1)) + + (len(scatter_indices_shape),))) + scatter_indices = np.concatenate( + (scatter_indices, i), len(scatter_indices_shape)) + scatter_dims_to_operand_dims.append(x_axis) + inserted_window_dims.append(x_axis) + + collapsed_y_axis += 1 + y_axis += 1 + x_axis += 1 + else: + raise IndexError("Unknown index type ", i) + + y = np.broadcast_to(y, tuple(slice_shape)) + y = lax.reshape(y, collapsed_slice_shape) + if reversed_y_dims: + y = lax.rev(y, reversed_y_dims) + + dnums = lax.ScatterDimensionNumbers( + update_window_dims = tuple(update_window_dims), + inserted_window_dims = tuple(inserted_window_dims), + scatter_dims_to_operand_dims = tuple(scatter_dims_to_operand_dims) + ) + return scatter_op(x, scatter_indices, y, dnums) + + +class _Indexable(object): + """Helper object for building indexes for indexed update functions. + + This is a singleton object that overrides the :code:`__getitem__` method + to return the index it is passed. + + >>> jax.ops.index[1:2, 3, None, ..., ::2] + (slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2)) + """ + __slots__ = () + + def __getitem__(self, index): + return index + +#: Index object singleton +index = _Indexable() + + +def index_add(x, idx, y): + """Pure equivalent of :code:`x[idx] += y`. + + Returns the the value of `x` that would result from the + NumPy-style :mod:`indexed assignment `:: + x[idx] += y + + Note the `index_add` operator is pure; `x` itself is + not modified, instead the new value that `x` would have taken is returned. + + Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the + same location the updates will be summed. (NumPy would only apply the last + update, rather than summing the updates.) The order in which conflicting + updates are applied is implementation-defined and may be nondeterministic + (e.g., due to concurrency on some hardware platforms). + + Args: + x: an array. + idx: a Numpy-style basic index, consisting of `None`, integers, `slice` + objects, ellipses, or a tuple of the above. A convenient syntactic sugar + for forming indices is via the :data:`jax.ops.index` object. + y: the array of updates. `y` must be broadcastable to the shape of the + array that would be returned by `x[idx]`. + + Returns: + An array. + + >>> x = jax.numpy.ones((5, 6)) + >>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.) + array([[1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 7., 7., 7.], + [1., 1., 1., 7., 7., 7.], + [1., 1., 1., 1., 1., 1.]], dtype=float32) + """ + return _scatter_update(x, idx, y, lax.scatter_add) + +def index_update(x, idx, y): + """Pure equivalent of :code:`x[idx] = y`. + + Returns the the value of `x` that would result from the + NumPy-style :mod:`indexed assignment `:: + x[idx] += y + + Note the `index_update` operator is pure; `x` itself is + not modified, instead the new value that `x` would have taken is returned. + + Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same + location it is undefined which update is chosen; JAX may choose the order of + updates arbitrarily and nondeterministically (e.g., due to concurrent + updates on some hardware platforms). + + Args: + x: an array. + idx: a Numpy-style basic index, consisting of `None`, integers, `slice` + objects, ellipses, or a tuple of the above. A convenient syntactic sugar + for forming indices is via the :data:`jax.ops.index` object. + y: the array of updates. `y` must be broadcastable to the shape of the + array that would be returned by `x[idx]`. + + Returns: + An array. + + >>> x = jax.numpy.ones((5, 6)) + >>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.) + array([[1., 1., 1., 6., 6., 6.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 6., 6., 6.], + [1., 1., 1., 1., 1., 1.], + [1., 1., 1., 6., 6., 6.]], dtype=float32) + """ + return _scatter_update(x, idx, y, lax.scatter) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 9386f3e03996..286f216c80a6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -17,9 +17,10 @@ from __future__ import print_function import collections +import enum from functools import partial import itertools -from unittest import skip +import unittest from absl.testing import absltest from absl.testing import parameterized @@ -29,6 +30,7 @@ from jax import api from jax import lax from jax import numpy as lnp +from jax import ops from jax import test_util as jtu from jax.config import config @@ -59,102 +61,172 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): jtu.check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps) +STATIC_INDEXING_TESTS = [ + ("OneIntIndex", [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ]), + ("TwoIntIndices", [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ]), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ("OneSliceIndex", [ + IndexSpec(shape=(10,), indexer=slice(1, 3)), + IndexSpec(shape=(10,), indexer=slice(1, -1)), + IndexSpec(shape=(10,), indexer=slice(None, -1)), + IndexSpec(shape=(10,), indexer=slice(None, None, None)), + IndexSpec(shape=(10, 8), indexer=slice(1, 3)), + IndexSpec(shape=(10, 8), indexer=slice(1, None)), + IndexSpec(shape=(10, 8), indexer=slice(None, 3)), + IndexSpec(shape=(10, 8), indexer=slice(-3, None)), + ]), + ("OneSliceIndexNegativeStride", [ + IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10,), indexer=slice(None, None, -1)), + IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result + IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), + ]), + ("OneSliceIndexNonUnitStride", [ + IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), + IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), + IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), + IndexSpec(shape=(10,), indexer=slice(1, None, 2)), + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), + IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), + ]), + ("TwoSliceIndices", [ + IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), + IndexSpec( + shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), + ]), + ("OneColonIndex", [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ]), + ("MultipleColonIndices", [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ]), + ("MixedSliceIndices", [ + IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2))), + IndexSpec(shape=(10, 4), indexer=(1, slice(None))), + ]), + ("EllipsisIndex", [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ]), + ("NoneIndex", [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ]), + ("EmptyIndex", [ + IndexSpec(shape=(), indexer=()), + IndexSpec(shape=(3,), indexer=()), + IndexSpec(shape=(3, 4), indexer=()), + ]), +] + +STATIC_INDEXING_GRAD_TESTS = [ + ("OneIntIndex", [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ]), + ("TwoIntIndices", [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ]), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ("OneSliceIndex", [ + IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5,), indexer=slice(1, -1)), + IndexSpec(shape=(5,), indexer=slice(None, -1)), + IndexSpec(shape=(5,), indexer=slice(None, None, None)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, None)), + IndexSpec(shape=(5, 4), indexer=slice(None, 3)), + IndexSpec(shape=(5, 4), indexer=slice(-3, None)), + ]), + ("TwoSliceIndices", [ + IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), + ]), + ("OneColonIndex", [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ]), + ("MultipleColonIndices", [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ]), + ("MixedSliceIndices", [ + IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(1, slice(None))), + ]), + ("EllipsisIndex", [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ]), + ("NoneIndex", [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ]), + # TODO(mattjj): these fail for uninteresting dtype reasons + # ("EmptyIndex", + # [IndexSpec(shape=(), indexer=()), + # IndexSpec(shape=(3,), indexer=()), + # IndexSpec(shape=(3, 4), indexer=()), + # ]), +] + class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" - @parameterized.named_parameters({ + @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string( shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer - } for name, index_specs in [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(10,), indexer=slice(1, 3)), - IndexSpec(shape=(10,), indexer=slice(1, -1)), - IndexSpec(shape=(10,), indexer=slice(None, -1)), - IndexSpec(shape=(10,), indexer=slice(None, None, None)), - IndexSpec(shape=(10, 8), indexer=slice(1, 3)), - IndexSpec(shape=(10, 8), indexer=slice(1, None)), - IndexSpec(shape=(10, 8), indexer=slice(None, 3)), - IndexSpec(shape=(10, 8), indexer=slice(-3, None)), - ]), - ("OneSliceIndexNegativeStride", [ - IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10,), indexer=slice(None, None, -1)), - IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result - IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), - ]), - ("OneSliceIndexNonUnitStride", [ - IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), - IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), - IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), - IndexSpec(shape=(10,), indexer=slice(1, None, 2)), - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), - IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), - IndexSpec( - shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(10, 4), indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - ("EmptyIndex", [ - IndexSpec(shape=(), indexer=()), - IndexSpec(shape=(3,), indexer=()), - IndexSpec(shape=(3, 4), indexer=()), - ]), - ] for shape, indexer in index_specs for dtype in all_dtypes - for rng in [jtu.rand_default()]) - @jtu.skip_on_devices("tpu") + } for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng in [jtu.rand_default()])) def testStaticIndexing(self, shape, dtype, rng, indexer): args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] @@ -166,74 +238,10 @@ def testStaticIndexing(self, shape, dtype, rng, indexer): jtu.format_shape_dtype_string( shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer - } for name, index_specs in [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5,), indexer=slice(1, -1)), - IndexSpec(shape=(5,), indexer=slice(None, -1)), - IndexSpec(shape=(5,), indexer=slice(None, None, None)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, None)), - IndexSpec(shape=(5, 4), indexer=slice(None, 3)), - IndexSpec(shape=(5, 4), indexer=slice(-3, None)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - # TODO(mattjj): these fail for uninteresting dtype reasons - # ("EmptyIndex", - # [IndexSpec(shape=(), indexer=()), - # IndexSpec(shape=(3,), indexer=()), - # IndexSpec(shape=(3, 4), indexer=()), - # ]), - ] for shape, indexer in index_specs for dtype in float_dtypes - for rng in [jtu.rand_default()]) - @jtu.skip_on_devices("tpu") + } for name, index_specs in STATIC_INDEXING_GRAD_TESTS + for shape, indexer in index_specs + for dtype in float_dtypes + for rng in [jtu.rand_default()]) def testStaticIndexingGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) @@ -322,7 +330,7 @@ def fun(x, unpacked_indexer): args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) - @skip + @unittest.skip @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), @@ -644,5 +652,88 @@ def testIssue187(self): self.assertAllClose(ans, expected, check_dtypes=False) +def _broadcastable_shapes(shape): + """Returns all shapes that broadcast to `shape`.""" + def f(rshape): + yield [] + if rshape: + yield from (rshape[0:1] + list(s) for s in f(rshape[1:])) + if rshape[0] != 1: + yield from ([1] + s for s in f(rshape[1:])) + yield from (list(reversed(x)) for x in f(list(reversed(shape)))) + +def _update_shape(shape, indexer): + return onp.zeros(shape)[indexer].shape + + +class UpdateOps(enum.Enum): + UPDATE = 0 + ADD = 1 + +class IndexedUpdateTest(jtu.JaxTestCase): + + @parameterized.named_parameters(jtu.cases_from_list({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op + } for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for op in [UpdateOps.UPDATE, UpdateOps.ADD] + for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes) + for rng in [jtu.rand_default()])) + def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, + rng, indexer, op): + if FLAGS.jax_test_dut == "cpu" and not shape: + # TODO(b/127315062): this case causes an XLA crash on CPU. Reenable when + # fixed. + raise unittest.SkipTest("Test case crashes on CPU") + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + def onp_fn(x, y): + x = x.copy() + if op == UpdateOps.UPDATE: + x[indexer] = y + else: + x[indexer] += y + return x + + jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add + jax_fn = lambda x, y: jax_op(x, indexer, y) + self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) + self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) + + + @parameterized.named_parameters(jtu.cases_from_list({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op + } for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for op in [UpdateOps.UPDATE, UpdateOps.ADD] + for dtype in float_dtypes + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes) + for rng in [jtu.rand_default()])) + def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, + rng, indexer, op): + if FLAGS.jax_test_dut == "cpu" and not shape: + # TODO(b/127315062): this case causes an XLA crash on CPU. Reenable when + # fixed. + raise unittest.SkipTest("Test case crashes on CPU") + + jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add + jax_fn = lambda x, y: jax_op(x, indexer, y) + x = rng(shape, dtype) + y = rng(update_shape, update_dtype) + check_grads(jax_fn, (x, y), 2, eps=1.) + + if __name__ == "__main__": absltest.main()