Skip to content

Commit

Permalink
[JAX] Update ann to use XLA based fallback ApproxTopK.
Browse files Browse the repository at this point in the history
Other small changes:
* Restricts the operand type to float.
* Add more format annotations to the docstring.

PiperOrigin-RevId: 434749705
  • Loading branch information
jax authors committed Mar 15, 2022
1 parent 6355fac commit 4fba0e7
Showing 1 changed file with 54 additions and 47 deletions.
101 changes: 54 additions & 47 deletions jax/_src/lax/ann.py
Expand Up @@ -92,30 +92,30 @@ def approx_max_k(operand: Array,
"""Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.
Args:
operand : Array to search for max-k.
operand : Array to search for max-k. Must be a floating number type.
k : Specifies the number of max-k.
reduction_dimension : Integer dimension along which to search. Default: -1.
recall_target : Recall target for the approximation.
reduction_input_size_override : When set to a positive value, it overrides
the size determined by ``operands[reduction_dim]`` for evaluating the
recall. This option is useful when the given operand is only a subset of
the overall computation in SPMD or distributed pipelines, where the true
input size cannot be deferred by the operand shape.
the size determined by ``operand[reduction_dim]`` for evaluating the
recall. This option is useful when the given ``operand`` is only a subset
of the overall computation in SPMD or distributed pipelines, where the
true input size cannot be deferred by the operand shape.
aggregate_to_topk : When true, aggregates approximate results to top-k. When
false, returns the approximate results. The number of the approximate
results is implementation defined and is greater equals to the specified
``k``.
Returns:
Tuple of two arrays. The arrays are the max ``k`` values and the
corresponding indices along the reduction_dimension of the input operand.
The arrays' dimensions are the same as the input operand except for the
``reduction_dimension``: when ``aggregate_to_topk`` is true, the reduction
dimension is ``k``; otherwise, it is greater equals to k where the size is
implementation-defined.
corresponding indices along the ``reduction_dimension`` of the input
``operand``. The arrays' dimensions are the same as the input ``operand``
except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
where the size is implementation-defined.
We encourage users to wrap the approx_*_k with jit. See the following example
for maximal inner production search (MIPS):
We encourage users to wrap ``approx_max_k`` with jit. See the following
example for maximal inner production search (MIPS):
>>> import functools
>>> import jax
Expand Down Expand Up @@ -151,29 +151,29 @@ def approx_min_k(operand: Array,
"""Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.
Args:
operand : Array to search for min-k.
operand : Array to search for min-k. Must be a floating number type.
k : Specifies the number of min-k.
reduction_dimension: Integer dimension along which to search. Default: -1.
recall_target: Recall target for the approximation.
reduction_input_size_override : When set to a positive value, it overrides
the size determined by ``operands[reduction_dim]`` for evaluating the
recall. This option is useful when the given operand is only a subset of
the size determined by ``operand[reduction_dim]`` for evaluating the
recall. This option is useful when the given operand is only a subset of
the overall computation in SPMD or distributed pipelines, where the true
input size cannot be deferred by the operand shape.
input size cannot be deferred by the ``operand`` shape.
aggregate_to_topk: When true, aggregates approximate results to top-k. When
false, returns the approximate results. The number of the approximate
results is implementation defined and is greater equals to the specified
``k``.
Returns:
Tuple of two arrays. The arrays are the least ``k`` values and the
corresponding indices along the reduction_dimension of the input operand.
The arrays' dimensions are the same as the input operand except for the
``reduction_dimension``: when ``aggregate_to_topk`` is true, the reduction
dimension is ``k``; otherwise, it is greater equals to k where the size is
implementation-defined.
corresponding indices along the ``reduction_dimension`` of the input
``operand``. The arrays' dimensions are the same as the input ``operand``
except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
where the size is implementation-defined.
We encourage users to wrap the approx_*_k with jit. See the following example
We encourage users to wrap ``approx_min_k`` with jit. See the following example
for nearest neighbor search over the squared l2 distance:
>>> import functools
Expand All @@ -189,7 +189,7 @@ def approx_min_k(operand: Array,
>>> half_db_norms = jax.numpy.linalg.norm(db, axis=1) / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norms, k=10)
We compute ``db_norms/2 - dot(qy, db^T)`` instead of
In the example above, we compute ``db_norms/2 - dot(qy, db^T)`` instead of
``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
arithmetics and produces the same set of neighbors.
"""
Expand Down Expand Up @@ -219,6 +219,8 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
raise ValueError(
'k must be smaller than the size of reduction_dim {}, got {}'.format(
dims[reduction_dimension], k))
if not dtypes.issubdtype(operand.dtype, np.floating):
raise ValueError('operand must be a floating type')
if xc._version >= 45:
reduction_input_size = dims[reduction_dimension]
dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
Expand All @@ -231,7 +233,7 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
operand.update(shape=dims, dtype=np.dtype(np.int32)))


def _comparator_builder(operand, op_type, is_max_k):
def _comparator_builder(op_type, is_max_k):
c = xc.XlaBuilder(
'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
Expand All @@ -245,6 +247,9 @@ def _comparator_builder(operand, op_type, is_max_k):
return c.build(cmp_result)


def _get_init_val_literal(op_type, is_max_k):
return np.array(np.NINF if is_max_k else np.Inf, dtype=op_type)

def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
reduction_dimension, recall_target, is_max_k,
reduction_input_size_override,
Expand All @@ -257,20 +262,11 @@ def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
op_type = op_shape.element_type()
if reduction_dimension < 0:
reduction_dimension = len(op_dims) + reduction_dimension
comparator = _comparator_builder(operand, op_type, is_max_k)
if is_max_k:
if dtypes.issubdtype(op_type, np.floating):
init_literal = np.array(np.NINF, dtype=op_type)
else:
init_literal = np.iinfo(op_type).min()
else:
if dtypes.issubdtype(op_type, np.floating):
init_literal = np.array(np.Inf, dtype=op_type)
else:
init_literal = np.iinfo(op_type).max()
comparator = _comparator_builder(op_type, is_max_k)
init_val_literal = _get_init_val_literal(op_type, is_max_k)
iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
reduction_dimension)
init_val = xc.ops.Constant(c, init_literal)
init_val = xc.ops.Constant(c, init_val_literal)
init_arg = xc.ops.Constant(c, np.int32(-1))
out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
reduction_dimension, comparator, recall_target,
Expand All @@ -288,21 +284,32 @@ def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
raise ValueError('operand must be an array, but was {}'.format(op_shape))
op_dims = op_shape.dimensions()
op_type = op_shape.element_type()

if reduction_dimension < 0:
reduction_dimension = len(op_dims) + reduction_dimension
comparator = _comparator_builder(operand, op_type, is_max_k)
comparator = _comparator_builder(op_type, is_max_k)
iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
reduction_dimension)
val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension)
vals = xc.ops.GetTupleElement(val_arg, 0)
args = xc.ops.GetTupleElement(val_arg, 1)
sliced_vals = xc.ops.SliceInDim(vals, 0,
avals_out[0].shape[reduction_dimension], 1,
reduction_dimension)
sliced_args = xc.ops.SliceInDim(args, 0,
avals_out[0].shape[reduction_dimension], 1,
reduction_dimension)
return sliced_vals, sliced_args
if xc._version >= 60:
init_val_literal = _get_init_val_literal(op_type, is_max_k)
init_val = xc.ops.Constant(c, init_val_literal)
init_arg = xc.ops.Constant(c, np.int32(-1))
out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg], k,
reduction_dimension, comparator,
recall_target, aggregate_to_topk,
reduction_input_size_override)
return xla.xla_destructure(c, out)
else:
val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension)
vals = xc.ops.GetTupleElement(val_arg, 0)
args = xc.ops.GetTupleElement(val_arg, 1)
sliced_vals = xc.ops.SliceInDim(vals, 0,
avals_out[0].shape[reduction_dimension], 1,
reduction_dimension)
sliced_args = xc.ops.SliceInDim(args, 0,
avals_out[0].shape[reduction_dimension], 1,
reduction_dimension)
return sliced_vals, sliced_args


def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
Expand Down

0 comments on commit 4fba0e7

Please sign in to comment.