Skip to content

Commit

Permalink
Allow cumulative sums to be int32 in jax2tf
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595808939
  • Loading branch information
jax authors committed Jan 8, 2024
1 parent 69788d1 commit 32e9068
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion jax/experimental/jax2tf/impl_no_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,11 @@ def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
# tf.math.reduce_min.
raise _reduce_error(f"Min pool does not support operands of type {dtype}")
if computation_name == "add" and dtype not in [
tf.float16, tf.float32, tf.float64
tf.float16,
tf.float32,
tf.float64,
tf.int16,
tf.int32,
]:
raise _reduce_error("Add pooling does not support operands of type "
f"{dtype}")
Expand Down Expand Up @@ -653,13 +657,22 @@ def tf_pool(inputs, pooling_type):
raise NotImplementedError(
f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} "
f" {window_strides=} {dilations=}")
# tf.nn.pool() currently does not suport tf.int32 and so we cast back and
# forth in order to be able to convert.
if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add":
original_dtype = inputs.dtype
inputs = tf.cast(inputs, dtype=tf.float32)
else:
original_dtype = None
result = tf.nn.pool(
inputs,
window_shape=window_dimensions,
pooling_type=pooling_type,
padding=padding_type,
strides=window_strides,
dilations=dilations)
if original_dtype:
result = tf.cast(result, dtype=original_dtype)

if has_only_spatial_dims:
# If the input only had spatial dimensions we need to contract the batch
Expand Down

0 comments on commit 32e9068

Please sign in to comment.