Skip to content

Commit

Permalink
Temp fix [onnx#2180]
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermecoutinhoJC committed Aug 16, 2023
1 parent 0152029 commit 719e876
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 8 additions & 2 deletions tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,14 @@ def wrap_call(*args, training=False, **kwargs):
tensorflow_core.python.keras.backend.learning_phase = old_get_learning_phase

# These inputs will be removed during freezing (includes resources, etc.)
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for t_val, t_name in graph_captures.values()]

try:
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
except AttributeError:
graph_captures = concrete_func.graph.function_captures.by_val_internal
captured_inputs = [t.name for t in graph_captures.values()]

input_names = [input_tensor.name for input_tensor in concrete_func.inputs
if input_tensor.name not in captured_inputs]
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
Expand Down
8 changes: 6 additions & 2 deletions tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,12 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
tensors_to_rename = {}
if input_names is None:
inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
try:
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
except AttributeError:
graph_captures = concrete_func.graph.function_captures.by_val_internal
captured_inputs = [t.name for t in graph_captures.values()]
inputs = [inp for inp in inputs if inp not in captured_inputs]
if concrete_func.structured_input_signature is not None and not use_graph_names:
flat_structured_inp = tf.nest.flatten(concrete_func.structured_input_signature)
Expand Down

0 comments on commit 719e876

Please sign in to comment.