Skip to content

Commit

Permalink
[tensorflow] Add support for the is_symbolic_tensor predicate
Browse files Browse the repository at this point in the history
This predicate will become available in tensorflow starting with version
2.14.

Co-authored-by: Russell Power <power@google.com>
  • Loading branch information
Roy Hvaara and rjpower committed Apr 20, 2023
1 parent 474bf50 commit 589494a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
is_offline_mode,
is_remote_url,
is_safetensors_available,
is_tf_symbolic_tensor,
logging,
requires_backends,
working_or_temp_dir,
Expand Down Expand Up @@ -511,7 +512,7 @@ def input_processing(func, config, **kwargs):
if isinstance(main_input, (tuple, list)):
for i, input in enumerate(main_input):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
if is_tf_symbolic_tensor(input):
# Tensor names have always the pattern `name:id` then we check only the
# `name` part
tensor_name = input.name.split(":")[0]
Expand Down Expand Up @@ -572,7 +573,7 @@ def input_processing(func, config, **kwargs):
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception
if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor:
if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"]
else:
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
is_numpy_array,
is_tensor,
is_tf_tensor,
is_tf_symbolic_tensor,
is_torch_device,
is_torch_dtype,
is_torch_tensor,
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ def is_tf_tensor(x):
return False if not is_tf_available() else _is_tensorflow(x)


def _is_tf_symbolic_tensor(x):
import tensorflow as tf

# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
if hasattr(tf, 'is_symbolic_tensor'):
return tf.is_symbolic_tensor(x)
return type(x) == tf.Tensor


def is_tf_symbolic_tensor(x):
"""
Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not installed.
"""
return False if not is_tf_available() else _is_tf_symbolic_tensor(x)


def _is_jax(x):
import jax.numpy as jnp # noqa: F811

Expand Down

0 comments on commit 589494a

Please sign in to comment.