Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only load specified signatures when loading saved models in tf1 #1409

Merged
merged 3 commits into from
Mar 19, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,29 +231,32 @@ def from_checkpoint(model_path, input_names, output_names):
return frozen_graph, input_names, output_names


def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures):
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names):
"""Load tensorflow graph from saved_model."""

wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
wrn_empty_tag = "'--tag' value is empty string. Using tags = []"
wrn_empty_sig = "'--signature_def' not provided. Using all signatures."

if tag is None:
tag = [tf.saved_model.tag_constants.SERVING]
logger.warning(wrn_no_tag)

if not signature_names:
logger.warning(wrn_empty_sig)

if tag == '':
tag = [[]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TF1 complains that list is an unhashable type. We actually want to specify an empty set, which is set([]) not set([[]])

tag = []
logger.warning(wrn_empty_tag)

if not isinstance(tag, list):
tag = [tag]

imported = tf.saved_model.loader.load(sess, tag, model_path)
signatures = []
for k in imported.signature_def.keys():
if k.startswith("_"):
# consider signatures starting with '_' private
continue
signatures.append(k)
if k in signature_names or (not signature_names and not k.startswith("_")):
signatures.append(k)
try:
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
# pylint: disable=unnecessary-lambda
Expand Down