From 4fba0e787f464cc3e68358154f88b1ac13453667 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 15 Mar 2022 07:50:17 -0700
Subject: [PATCH] [JAX] Update ann to use XLA-based fallback ApproxTopK.

Other small changes:
* Restrict the operand type to float.
* Add more format annotations to the docstring.

PiperOrigin-RevId: 434749705
---
 jax/_src/lax/ann.py | 101 +++++++++++++++++++++++---------------------
 1 file changed, 54 insertions(+), 47 deletions(-)

diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py
index b8126b482abc..3c699dfaaaa9 100644
--- a/jax/_src/lax/ann.py
+++ b/jax/_src/lax/ann.py
@@ -92,15 +92,15 @@ def approx_max_k(operand: Array,
   """Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.
 
   Args:
-    operand : Array to search for max-k.
+    operand : Array to search for max-k. Must be of floating-point type.
     k : Specifies the number of max-k.
     reduction_dimension : Integer dimension along which to search. Default: -1.
     recall_target : Recall target for the approximation.
     reduction_input_size_override : When set to a positive value, it overrides
-      the size determined by ``operands[reduction_dim]`` for evaluating the
-      recall. This option is useful when the given operand is only a subset of
-      the overall computation in SPMD or distributed pipelines, where the true
-      input size cannot be deferred by the operand shape.
+      the size determined by ``operand[reduction_dim]`` for evaluating the
+      recall. This option is useful when the given ``operand`` is only a subset
+      of the overall computation in SPMD or distributed pipelines, where the
+      true input size cannot be inferred from the operand shape.
     aggregate_to_topk : When true, aggregates approximate results to top-k. When
       false, returns the approximate results. The number of the approximate
       results is implementation defined and is greater equals to the specified
@@ -108,14 +108,14 @@ def approx_max_k(operand: Array,
 
   Returns:
     Tuple of two arrays. The arrays are the max ``k`` values and the
-    corresponding indices along the reduction_dimension of the input operand.
-    The arrays' dimensions are the same as the input operand except for the
-    ``reduction_dimension``: when ``aggregate_to_topk`` is true, the reduction
-    dimension is ``k``; otherwise, it is greater equals to k where the size is
-    implementation-defined.
+    corresponding indices along the ``reduction_dimension`` of the input
+    ``operand``. The arrays' dimensions are the same as the input ``operand``
+    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
+    the reduction dimension is ``k``; otherwise, it is greater than or equal
+    to ``k``, where the size is implementation-defined.
 
-  We encourage users to wrap the approx_*_k with jit. See the following example
-  for maximal inner production search (MIPS):
+  We encourage users to wrap ``approx_max_k`` with jit. See the following
+  example for maximal inner product search (MIPS):
 
   >>> import functools
   >>> import jax
@@ -151,15 +151,15 @@ def approx_min_k(operand: Array,
   """Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.
 
   Args:
-    operand : Array to search for min-k.
+    operand : Array to search for min-k. Must be of floating-point type.
     k : Specifies the number of min-k.
     reduction_dimension: Integer dimension along which to search. Default: -1.
     recall_target: Recall target for the approximation.
     reduction_input_size_override : When set to a positive value, it overrides
-      the size determined by ``operands[reduction_dim]`` for evaluating the
-      recall. This option is useful when the given operand is only a subset of
+      the size determined by ``operand[reduction_dim]`` for evaluating the
+      recall. This option is useful when the given operand is only a subset of
       the overall computation in SPMD or distributed pipelines, where the true
-      input size cannot be deferred by the operand shape.
+      input size cannot be inferred from the ``operand`` shape.
     aggregate_to_topk: When true, aggregates approximate results to top-k. When
       false, returns the approximate results. The number of the approximate
       results is implementation defined and is greater equals to the specified
@@ -167,13 +167,13 @@ def approx_min_k(operand: Array,
 
   Returns:
     Tuple of two arrays. The arrays are the least ``k`` values and the
-    corresponding indices along the reduction_dimension of the input operand.
-    The arrays' dimensions are the same as the input operand except for the
-    ``reduction_dimension``: when ``aggregate_to_topk`` is true, the reduction
-    dimension is ``k``; otherwise, it is greater equals to k where the size is
-    implementation-defined.
+    corresponding indices along the ``reduction_dimension`` of the input
+    ``operand``. The arrays' dimensions are the same as the input ``operand``
+    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
+    the reduction dimension is ``k``; otherwise, it is greater than or equal
+    to ``k``, where the size is implementation-defined.
 
-  We encourage users to wrap the approx_*_k with jit. See the following example
+  We encourage users to wrap ``approx_min_k`` with jit. See the following example
   for nearest neighbor search over the squared l2 distance:
 
   >>> import functools
@@ -189,7 +189,7 @@ def approx_min_k(operand: Array,
   >>> half_db_norms = jax.numpy.linalg.norm(db, axis=1) / 2
   >>> dists, neighbors = l2_ann(qy, db, half_db_norms, k=10)
 
-  We compute ``db_norms/2 - dot(qy, db^T)`` instead of
+  In the example above, we compute ``db_norms/2 - dot(qy, db^T)`` instead of
   ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses
   less arithmetics and produces the same set of neighbors.
""" @@ -219,6 +219,8 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, raise ValueError( 'k must be smaller than the size of reduction_dim {}, got {}'.format( dims[reduction_dimension], k)) + if not dtypes.issubdtype(operand.dtype, np.floating): + raise ValueError('operand must be a floating type') if xc._version >= 45: reduction_input_size = dims[reduction_dimension] dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( @@ -231,7 +233,7 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, operand.update(shape=dims, dtype=np.dtype(np.int32))) -def _comparator_builder(operand, op_type, is_max_k): +def _comparator_builder(op_type, is_max_k): c = xc.XlaBuilder( 'top_k_{}_comparator'.format('gt' if is_max_k else 'lt')) p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type)) @@ -245,6 +247,9 @@ def _comparator_builder(operand, op_type, is_max_k): return c.build(cmp_result) +def _get_init_val_literal(op_type, is_max_k): + return np.array(np.NINF if is_max_k else np.Inf, dtype=op_type) + def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, @@ -257,20 +262,11 @@ def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, op_type = op_shape.element_type() if reduction_dimension < 0: reduction_dimension = len(op_dims) + reduction_dimension - comparator = _comparator_builder(operand, op_type, is_max_k) - if is_max_k: - if dtypes.issubdtype(op_type, np.floating): - init_literal = np.array(np.NINF, dtype=op_type) - else: - init_literal = np.iinfo(op_type).min() - else: - if dtypes.issubdtype(op_type, np.floating): - init_literal = np.array(np.Inf, dtype=op_type) - else: - init_literal = np.iinfo(op_type).max() + comparator = _comparator_builder(op_type, is_max_k) + init_val_literal = _get_init_val_literal(op_type, is_max_k) iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), reduction_dimension) - init_val = xc.ops.Constant(c, init_literal) + init_val = xc.ops.Constant(c, init_val_literal) init_arg = xc.ops.Constant(c, np.int32(-1)) out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k, reduction_dimension, comparator, recall_target, @@ -288,21 +284,32 @@ def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k, raise ValueError('operand must be an array, but was {}'.format(op_shape)) op_dims = op_shape.dimensions() op_type = op_shape.element_type() + if reduction_dimension < 0: reduction_dimension = len(op_dims) + reduction_dimension - comparator = _comparator_builder(operand, op_type, is_max_k) + comparator = _comparator_builder(op_type, is_max_k) iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), reduction_dimension) - val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension) - vals = xc.ops.GetTupleElement(val_arg, 0) - args = xc.ops.GetTupleElement(val_arg, 1) - sliced_vals = xc.ops.SliceInDim(vals, 0, - avals_out[0].shape[reduction_dimension], 1, - reduction_dimension) - sliced_args = xc.ops.SliceInDim(args, 0, - avals_out[0].shape[reduction_dimension], 1, - reduction_dimension) - return sliced_vals, sliced_args + if xc._version >= 60: + init_val_literal = _get_init_val_literal(op_type, is_max_k) + init_val = xc.ops.Constant(c, init_val_literal) + init_arg = xc.ops.Constant(c, np.int32(-1)) + out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg], k, + reduction_dimension, comparator, + recall_target, aggregate_to_topk, + 
+                                    reduction_input_size_override)
+    return xla.xla_destructure(c, out)
+  else:
+    val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension)
+    vals = xc.ops.GetTupleElement(val_arg, 0)
+    args = xc.ops.GetTupleElement(val_arg, 1)
+    sliced_vals = xc.ops.SliceInDim(vals, 0,
+                                    avals_out[0].shape[reduction_dimension], 1,
+                                    reduction_dimension)
+    sliced_args = xc.ops.SliceInDim(args, 0,
+                                    avals_out[0].shape[reduction_dimension], 1,
+                                    reduction_dimension)
+    return sliced_vals, sliced_args
 
 
 def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
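
A minimal usage sketch of the public ``jax.lax.approx_max_k`` API documented above, mirroring the MIPS pattern from the docstring; the shapes, PRNG keys, and the 0.95 recall target are illustrative assumptions, not taken from the patch:

import functools
import jax

@functools.partial(jax.jit, static_argnames=['k', 'recall_target'])
def mips(qy, db, k=10, recall_target=0.95):
  # Scores are plain inner products; approx_max_k returns the k largest
  # scores and their indices along the last (reduction) dimension.
  dists = qy @ db.transpose()
  return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)

qy = jax.random.normal(jax.random.PRNGKey(0), (16, 64))     # queries
db = jax.random.normal(jax.random.PRNGKey(1), (1024, 64))   # database
scores, neighbors = mips(qy, db, k=10)   # f32[16, 10], i32[16, 10]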
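
With the new floating-point restriction in ``_approx_top_k_abstract_eval``, integer operands need an explicit cast before the call. A small sketch under that assumption; the int32 scores below are made up for illustration:

import jax
import jax.numpy as jnp

int_scores = jnp.arange(1024, dtype=jnp.int32).reshape(4, 256)
# Passing int_scores directly would now fail the abstract-eval check
# ('operand must be a floating type'), so cast to a float dtype first.
vals, idxs = jax.lax.approx_max_k(int_scores.astype(jnp.float32), k=8)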
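
On XLA client versions without ``ApproxTopKFallback``, the lowering above reduces to a full sort plus a slice, i.e. exact top-k. A rough equivalence sketch in terms of public ops; the shapes and key are illustrative, and tie ordering may differ:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(42), (4, 128))
k = 8
# Descending sort along the reduction dimension, then keep the first k
# entries: this is what the Sort + SliceInDim branch computes for max-k.
order = jnp.argsort(-x, axis=-1)[..., :k]
vals = jnp.take_along_axis(x, order, axis=-1)
# jax.lax.top_k(x, k) should agree with (vals, order) up to tie ordering.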