Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in reading hash table initializers for tables with names othe… #1223

Merged
merged 1 commit into from Dec 9, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 19 additions & 12 deletions tf2onnx/tf_loader.py
Expand Up @@ -45,6 +45,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
except ImportError:
function_def_to_graph = _not_implemented_tf_placeholder('function_def_to_graph')

try:
# pylint: disable=protected-access
from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
except ImportError:
TfRestoredResourceType = None

if is_tf2():
convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
Expand Down Expand Up @@ -364,18 +370,19 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d

table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
placeholder_to_table_info = {}
if hasattr(imported, '_table') and hasattr(imported._table, '_create_resource'): # pylint: disable=protected-access
# Add tables from saved_model table initializers
# pylint: disable=protected-access
initializer = imported._table._create_resource.concrete_functions[0].function_def
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
table_names.extend(new_names)
key_dtypes.extend(new_k_dtypes)
value_dtypes.extend(new_v_dtypes)
table_handle = id(imported._table.resource_handle)
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
for r in imported.__dict__.values():
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource') and hasattr(r, 'resource_handle'):
# Add tables from saved_model table initializers
# pylint: disable=protected-access
initializer = r._create_resource.concrete_functions[0].function_def
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
table_names.extend(new_names)
key_dtypes.extend(new_k_dtypes)
value_dtypes.extend(new_v_dtypes)
table_handle = id(r.resource_handle)
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info

initialized_tables = {}
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
Expand Down