Skip to content

Commit

Permalink
Merge pull request #429 from hawkinsp/docs
Browse files Browse the repository at this point in the history
Add docstrings for lax.gather/scatter.
  • Loading branch information
hawkinsp committed Feb 22, 2019
2 parents 6a843aa + 969391e commit e2681ab
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jax.numpy package
flip
fliplr
flipud
float_power
floor
floor_divide
fmod
Expand Down
93 changes: 86 additions & 7 deletions jax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,12 +503,52 @@ def dynamic_update_slice(operand, update, start_indices):
return dynamic_update_slice_p.bind(operand, update, start_indices,
update_shape=update.shape)

def gather(operand, start_indices, dimension_numbers=None, slice_sizes=None):
def gather(operand, start_indices, dimension_numbers, slice_sizes):
"""Gather operator.
Wraps `XLA's Gather operator
<https://www.tensorflow.org/xla/operation_semantics#gather>`_.
The semantics of gather are complicated, and its API might change in the
future. For most use cases, you should prefer `Numpy-style indexing
<https://docs.scipy.org/doc/numpy-1.16.0/reference/arrays.indexing.html>`_
(e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly.
Args:
operand: an array from which slices should be taken
start_indices: the indices at which slices should be taken
dimension_numbers: a `lax.GatherDimensionNumbers` object that describes
how dimensions of `operand`, `start_indices` and the output relate.
slice_sizes: the size of each slice. Must be a sequence of non-negative
integers with size equal to `ndim(operand)`.
Returns:
An array containing the gather output.
"""
return gather_p.bind(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=tuple(slice_sizes), operand_shape=operand.shape)

def scatter_add(operand, scatter_indices, updates, dimension_numbers=None):
def scatter_add(operand, scatter_indices, updates, dimension_numbers):
"""Scatter operator.
Wraps `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_.
The semantics of scatter are complicated and its API is subject to change.
Args:
operand: an array to which the scatter should be applied
scatter_indices: an array that gives the indices in `operand` to which each
update in `updates` should be applied.
updates: the updates that should be scattered onto `operand`.
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
how dimensions of `operand`, `start_indices`, `updates` and the output
relate.
Returns:
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = _reduction_jaxpr(add, _const(operand, 0))
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
Expand Down Expand Up @@ -2364,10 +2404,30 @@ def _dynamic_update_slice_translation_rule(c, operand, update, start_indices,



GatherDimensionNumbers = collections.namedtuple(
class GatherDimensionNumbers(collections.namedtuple(
"GatherDimensionNumbers",
["offset_dims", "collapsed_slice_dims", "start_index_map",
"index_vector_dim"])
"index_vector_dim"])):
"""
Describes the dimension number arguments to an `XLA's Gather operator
<https://www.tensorflow.org/xla/operation_semantics#gather>`_. See the XLA
documentation for more details of what the dimension numbers mean.
Args:
offset_dims: the set of dimensions in the `gather` output that offset into
an array sliced from `operand`. Must be a tuple of integers in ascending
order, each representing a dimension number of the output.
collapsed_slice_dims: the set of dimensions `i` in `operand` that have
`slice_sizes[i] == 1` and that should not have a corresponding dimension
in the output of the gather. Must be a tuple of integers in ascending
order.
start_index_map: for each dimension in `start_indices`, gives the
corresponding dimension in `operand` that is to be sliced. Must be a
tuple of integers with size equal to `ndim(start_indices)`.
index_vector_dim: describes which dimension of `start_indices` "contains"
the start indices. If equal to `len(start_indices)` the indices are
taken to be scalars.
"""

def _gather_dimensions_proto(dimension_numbers):
assert type(dimension_numbers) is GatherDimensionNumbers
Expand Down Expand Up @@ -2506,10 +2566,29 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
batching.primitive_batchers[gather_p] = _gather_batching_rule


ScatterDimensionNumbers = collections.namedtuple(
class ScatterDimensionNumbers(collections.namedtuple(
"ScatterDimensionNumbers",
["update_window_dims", "inserted_window_dims",
"scatter_dims_to_operand_dims", "index_vector_dim"])
"scatter_dims_to_operand_dims", "index_vector_dim"])):
"""
Describes the dimension number arguments to an `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_. See the XLA
documentation for more details of what the dimension numbers mean.
Args:
update_window_dims: the set of dimensions in the `updates` that are window
dimensions. Must be a tuple of integers in ascending
order, each representing a dimension number.
inserted_window_dims: the set of size 1 window dimensions that must be inserted
into the shape of `updates`. Must be a tuple of integers in ascending
order, each representing a dimension number of the output. These are the
mirror image of `collapsed_slice_dims` in the case of `gather`.
scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
the corresponding dimension in `operand`.
index_vector_dim: describes which dimension of `scatter_indices` "contains"
the start indices. If equal to `len(scatter_indices)` the indices are
taken to be scalars.
"""

def _scatter_dimensions_proto(dimension_numbers):
assert type(dimension_numbers) is ScatterDimensionNumbers
Expand All @@ -2523,7 +2602,7 @@ def _scatter_dimensions_proto(dimension_numbers):

def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):
if not onp.issubdtype(scatter_indices.dtype, onp.integer):
raise ValueError("start_indices must have an integer type")
raise ValueError("scatter_indices must have an integer type")
_check_same_dtypes("scatter", False, operand.dtype, updates.dtype)
return xla_bridge.canonicalize_dtype(operand.dtype)

Expand Down

0 comments on commit e2681ab

Please sign in to comment.