Skip to content

Commit

Permalink
[JAX] Update approx_top_k doc with arxiv link.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 458258457
  • Loading branch information
dryman authored and jax authors committed Jun 30, 2022
1 parent 98ae6aa commit 61b3dc5
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/_src/lax/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def approx_max_k(operand: Array,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
"""Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
Args:
operand : Array to search for max-k. Must be a floating number type.
k : Specifies the number of max-k.
Expand Down Expand Up @@ -150,6 +152,8 @@ def approx_min_k(operand: Array,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
"""Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
Args:
operand : Array to search for min-k. Must be a floating number type.
k : Specifies the number of min-k.
Expand Down

0 comments on commit 61b3dc5

Please sign in to comment.