diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index aeeb9d3d5a8822..a6d566dc551fbe 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -57,6 +57,7 @@ is_offline_mode, is_remote_url, is_safetensors_available, + is_tf_symbolic_tensor, logging, requires_backends, working_or_temp_dir, @@ -511,7 +512,7 @@ def input_processing(func, config, **kwargs): if isinstance(main_input, (tuple, list)): for i, input in enumerate(main_input): # EagerTensors don't allow to use the .name property so we check for a real Tensor - if type(input) == tf.Tensor: + if is_tf_symbolic_tensor(input): # Tensor names have always the pattern `name:id` then we check only the # `name` part tensor_name = input.name.split(":")[0] @@ -572,7 +573,7 @@ def input_processing(func, config, **kwargs): # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) # So to respect the proper output we have to add this exception if "args" in output: - if output["args"] is not None and type(output["args"]) == tf.Tensor: + if output["args"] is not None and is_tf_symbolic_tensor(output["args"]): tensor_name = output["args"].name.split(":")[0] output[tensor_name] = output["args"] else: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 003b29a2db23f9..c79a09a94091d7 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -43,6 +43,7 @@ is_numpy_array, is_tensor, is_tf_tensor, + is_tf_symbolic_tensor, is_torch_device, is_torch_dtype, is_torch_tensor, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 9c620fcc0db2c1..43571e9e896905 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -166,6 +166,22 @@ def is_tf_tensor(x): return False if not is_tf_available() else _is_tensorflow(x) +def _is_tf_symbolic_tensor(x): + import tensorflow as tf + + # the `is_symbolic_tensor` predicate is only available starting with TF 2.14 + if hasattr(tf, 'is_symbolic_tensor'): + return tf.is_symbolic_tensor(x) + return type(x) == tf.Tensor + + +def is_tf_symbolic_tensor(x): + """ + Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not installed. + """ + return False if not is_tf_available() else _is_tf_symbolic_tensor(x) + + def _is_jax(x): import jax.numpy as jnp # noqa: F811