Skip to content

Commit

Permalink
[JAX] Adds the approx_top_k_p bridge.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 441172779
  • Loading branch information
jax authors committed Apr 12, 2022
1 parent 7be37ab commit a2c2d9a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
18 changes: 17 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,6 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
# Not high priority?
"after_all",
"all_to_all",
"approx_top_k",
"create_token",
"custom_transpose_call",
"custom_vmap_call",
Expand Down Expand Up @@ -2380,6 +2379,23 @@ def promote_tf_dtype(tf_dtype):
tf_impl[lax.top_k_p] = _top_k


def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int,
recall_target: float, is_max_k: bool,
reduction_input_size_override: int,
aggregate_to_topk: bool) -> Tuple[TfVal, TfVal]:
if is_max_k:
return tf.math.approx_max_k(operand, k, reduction_dimension, recall_target,
reduction_input_size_override,
aggregate_to_topk)
else:
return tf.math.approx_min_k(operand, k, reduction_dimension, recall_target,
reduction_input_size_override,
aggregate_to_topk)


tf_impl[lax.approx_top_k_p] = _approx_top_k


def _sort(*operands: TfVal, dimension: int, is_stable: bool,
num_keys: int) -> Tuple[TfVal, ...]:
assert 1 <= num_keys <= len(operands)
Expand Down
10 changes: 10 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@ def acosh(cls, harness: primitive_harness.Harness):
cls.helper_get_trig_custom_limitation(np.cosh)
]

@classmethod
def approx_max_k(cls, harness: primitive_harness.Harness):
supported_dtypes = jtu.supported_dtypes()
return Jax2TfLimitation(
"eager is not supported in CPU or GPU.",
dtypes=[t for t in [jnp.bfloat16, np.float16, np.float32]
if t in supported_dtypes],
devices=("cpu", "gpu", "tpu"),
modes=("graph", "compiled"))

@classmethod
def argmax(cls, harness: primitive_harness.Harness):
return [
Expand Down

0 comments on commit a2c2d9a

Please sign in to comment.