[API] TensorFlow Graph Transformer #39
Closed
25 commits
42c6e6e
flat param API impl
phi-dbq ecbefb9
support input graph scenarios
phi-dbq ab89bd2
(WIP) new interface implementation
phi-dbq 8c7d72e
docs and cleanup
phi-dbq eb543c6
using tensorflow API instead of our utilities
phi-dbq 4743bb9
automatic type conversion
phi-dbq 622c788
cleanup
phi-dbq 07f1cec
PR comments
phi-dbq 692b0eb
(WIP) address comments
phi-dbq 66d44e9
(WIP) respond to PR comments
phi-dbq 9b3fe86
test refactor
phi-dbq 8c32501
Merge remote-tracking branch 'upstream/master' into tf-1d-transformer
phi-dbq dbd9aaa
(wip) consolidating params
phi-dbq 4572205
rebase upstream
phi-dbq 1cc7591
import params fix
phi-dbq 2fc6787
(wip) TFInputGraph impl
phi-dbq 889df0a
(wip) moving to new API
phi-dbq 86cd6d9
(wip) enable saved_model tests
phi-dbq ac09182
(wip) enable checkpoint test
phi-dbq 6b22eed
(wip) enable multiple tensor tests
phi-dbq a3517d6
enable all tests
phi-dbq 457a4c2
params and converters
phi-dbq 323939a
tests
phi-dbq 6e46073
Merge branch 'tf-transformer-part1' into api-tf-transformer
phi-dbq b232b3c
optimize graph for inference
phi-dbq
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
# Copyright 2017 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import tensorflow as tf | ||
from tensorflow.core.protobuf import meta_graph_pb2 | ||
|
||
import sparkdl.graph.utils as tfx | ||
|
||
__all__ = ["TFInputGraph"] | ||
|
||
class TFInputGraph(object): | ||
""" | ||
An opaque serializable object containing TensorFlow graph. | ||
|
||
[WARNING] This class should not be called by any user code. | ||
""" | ||
def __init__(self): | ||
raise NotImplementedError( | ||
"Please do NOT construct TFInputGraph directly. Instead, use one of the helper functions") | ||
|
||
@classmethod | ||
def _new_obj_internal(cls): | ||
# pylint: disable=attribute-defined-outside-init | ||
obj = object.__new__(cls) | ||
# TODO: for (de-)serialization, the class should correspond to a ProtocolBuffer definition. | ||
obj.graph_def = None | ||
obj.input_tensor_name_from_signature = None | ||
obj.output_tensor_name_from_signature = None | ||
return obj | ||
|
||
def translateInputMapping(self, input_mapping): | ||
assert self.input_tensor_name_from_signature is not None | ||
_input_mapping = {} | ||
if isinstance(input_mapping, dict): | ||
input_mapping = list(input_mapping.items()) | ||
assert isinstance(input_mapping, list) | ||
for col_name, sig_key in input_mapping: | ||
tnsr_name = self.input_tensor_name_from_signature[sig_key] | ||
_input_mapping[col_name] = tnsr_name | ||
return _input_mapping | ||
|
||
def translateOutputMapping(self, output_mapping): | ||
assert self.output_tensor_name_from_signature is not None | ||
_output_mapping = {} | ||
if isinstance(output_mapping, dict): | ||
output_mapping = list(output_mapping.items()) | ||
assert isinstance(output_mapping, list) | ||
for sig_key, col_name in output_mapping: | ||
tnsr_name = self.output_tensor_name_from_signature[sig_key] | ||
_output_mapping[tnsr_name] = col_name | ||
return _output_mapping | ||
|
||
@classmethod | ||
def fromGraph(cls, graph, sess, feed_names, fetch_names): | ||
""" | ||
Construct a TFInputGraphBuilder from an in-memory tf.Graph object | ||
""" | ||
assert isinstance(graph, tf.Graph), \ | ||
('expect tf.Graph type but got', type(graph)) | ||
|
||
def import_graph_fn(_sess): | ||
assert _sess == sess, 'must have the same session' | ||
return _GinBuilderInfo() | ||
|
||
return _GinBuilder(import_graph_fn, sess, graph).build(feed_names, fetch_names) | ||
|
||
@classmethod | ||
def fromGraphDef(cls, graph_def, feed_names, fetch_names): | ||
""" | ||
Construct a TFInputGraphBuilder from a tf.GraphDef object | ||
""" | ||
assert isinstance(graph_def, tf.GraphDef), \ | ||
('expect tf.GraphDef type but got', type(graph_def)) | ||
|
||
def import_graph_fn(sess): | ||
with sess.as_default(): | ||
tf.import_graph_def(graph_def, name='') | ||
return _GinBuilderInfo() | ||
|
||
return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) | ||
|
||
@classmethod | ||
def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names): | ||
return cls._from_checkpoint_impl(checkpoint_dir, | ||
signature_def_key=None, | ||
feed_names=feed_names, fetch_names=fetch_names) | ||
|
||
@classmethod | ||
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key): | ||
assert signature_def_key is not None | ||
return cls._from_checkpoint_impl(checkpoint_dir, | ||
signature_def_key, | ||
feed_names=None, fetch_names=None) | ||
|
||
@classmethod | ||
def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names): | ||
return cls._from_saved_model_impl(saved_model_dir, tag_set, | ||
signature_def_key=None, | ||
feed_names=feed_names, fetch_names=fetch_names) | ||
|
||
@classmethod | ||
def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key): | ||
assert signature_def_key is not None | ||
return cls._from_saved_model_impl(saved_model_dir, tag_set, | ||
signature_def_key=signature_def_key, | ||
feed_names=None, fetch_names=None) | ||
|
||
@classmethod | ||
def _from_checkpoint_impl(cls, | ||
checkpoint_dir, | ||
signature_def_key=None, | ||
feed_names=None, | ||
fetch_names=None): | ||
""" | ||
Construct a TFInputGraphBuilder from a model checkpoint | ||
""" | ||
assert (feed_names is None) == (fetch_names is None), \ | ||
'feed_names and fetch_names, if provided must appear together' | ||
assert (feed_names is None) != (signature_def_key is None), \ | ||
'must either provide feed_names or signature_def_key' | ||
|
||
def import_graph_fn(sess): | ||
# Load checkpoint and import the graph | ||
with sess.as_default(): | ||
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir) | ||
|
||
# NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def | ||
# the current `import_graph_def` function seems to ignore | ||
# any signature_def fields in a checkpoint's meta_graph_def. | ||
meta_graph_def = meta_graph_pb2.MetaGraphDef() | ||
with open("{}.meta".format(ckpt_path), 'rb') as fin: | ||
meta_graph_def.ParseFromString(fin.read()) | ||
|
||
saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True) | ||
saver.restore(sess, ckpt_path) | ||
|
||
sig_def = None | ||
if signature_def_key is not None: | ||
sig_def = meta_graph_def.signature_def[signature_def_key] | ||
assert sig_def, 'signature_def_key {} provided, '.format(signature_def_key) + \ | ||
'but failed to find it from the meta_graph_def ' + \ | ||
'from checkpoint {}'.format(checkpoint_dir) | ||
|
||
return _GinBuilderInfo(sig_def=sig_def) | ||
|
||
return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) | ||
|
||
@classmethod | ||
def _from_saved_model_impl(cls, saved_model_dir, tag_set, | ||
signature_def_key=None, | ||
feed_names=None, | ||
fetch_names=None): | ||
""" | ||
Construct a TFInputGraphBuilder from a SavedModel | ||
""" | ||
assert (feed_names is None) == (fetch_names is None), \ | ||
'feed_names and fetch_names, if provided must appear together' | ||
assert (feed_names is None) != (signature_def_key is None), \ | ||
'must either provide feed_names or signature_def_key' | ||
|
||
def import_graph_fn(sess): | ||
tag_sets = tag_set.split(',') | ||
meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir) | ||
|
||
sig_def = None | ||
if signature_def_key is not None: | ||
sig_def = tf.contrib.saved_model.get_signature_def_by_key( | ||
meta_graph_def, signature_def_key) | ||
|
||
return _GinBuilderInfo(sig_def=sig_def) | ||
|
||
return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) | ||
|
||
|
||
class _GinBuilderInfo(object): | ||
def __init__(self, sig_def=None): | ||
self.sig_def = sig_def | ||
self.feed_names = None | ||
self.feed_mapping = None | ||
self.fetch_names = None | ||
self.fetch_mapping = None | ||
|
||
def extract_signatures(self): | ||
assert self.sig_def is not None, \ | ||
"ask to find sigdef mapping, but not found any" | ||
|
||
self.feed_mapping = {} | ||
self.feed_names = [] | ||
for sigdef_key, tnsr_info in self.sig_def.inputs.items(): | ||
tnsr_name = tnsr_info.name | ||
self.feed_mapping[sigdef_key] = tnsr_name | ||
self.feed_names.append(tnsr_name) | ||
|
||
self.fetch_mapping = {} | ||
self.fetch_names = [] | ||
for sigdef_key, tnsr_info in self.sig_def.outputs.items(): | ||
tnsr_name = tnsr_info.name | ||
self.fetch_mapping[sigdef_key] = tnsr_name | ||
self.fetch_names.append(tnsr_name) | ||
|
||
class _GinBuilder(object): | ||
def __init__(self, import_graph_fn, sess=None, graph=None): | ||
self.import_graph_fn = import_graph_fn | ||
assert (sess is None) == (graph is None) | ||
if sess is not None: | ||
self.graph = graph | ||
self.sess = sess | ||
self._should_clean = False | ||
else: | ||
self.graph = tf.Graph() | ||
self.sess = tf.Session(graph=self.graph) | ||
self._should_clean = True | ||
|
||
def _build_impl(self, feed_names, fetch_names): | ||
# pylint: disable=protected-access,attribute-defined-outside-init | ||
gin = TFInputGraph._new_obj_internal() | ||
assert (feed_names is None) == (fetch_names is None) | ||
must_have_sig_def = fetch_names is None | ||
# NOTE(phi-dbq): both have to be set to default | ||
with self.sess.as_default(), self.graph.as_default(): | ||
_ginfo = self.import_graph_fn(self.sess) | ||
if must_have_sig_def: | ||
_ginfo.extract_signatures() | ||
feed_names = _ginfo.feed_names | ||
fetch_names = _ginfo.fetch_names | ||
gin.input_tensor_name_from_signature = _ginfo.feed_mapping | ||
gin.output_tensor_name_from_signature = _ginfo.fetch_mapping | ||
|
||
for tnsr_name in feed_names: | ||
assert tfx.get_op(self.graph, tnsr_name) | ||
fetches = [tfx.get_tensor(self.graph, tnsr_name) for tnsr_name in fetch_names] | ||
gin.graph_def = tfx.strip_and_freeze_until(fetches, self.graph, self.sess) | ||
return gin | ||
|
||
def build(self, feed_names=None, fetch_names=None): | ||
try: | ||
gin = self._build_impl(feed_names, fetch_names) | ||
finally: | ||
if self._should_clean: | ||
self.sess.close() | ||
return gin |
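The two mapping translators on `TFInputGraph` run in opposite directions: `translateInputMapping` turns (column name, signature key) pairs into a column-to-tensor-name dict, while `translateOutputMapping` turns (signature key, column name) pairs into a tensor-name-to-column dict. The following is a minimal pure-Python sketch of that behavior, not the PR's code: the function names, the signature keys, and the tensor names are hypothetical stand-ins, and no TensorFlow graph is involved.

```python
# Hypothetical lookup tables standing in for the attributes
# input_tensor_name_from_signature / output_tensor_name_from_signature.
input_tensor_name_from_signature = {"images": "input_1:0"}
output_tensor_name_from_signature = {"scores": "dense/Softmax:0"}


def translate_input_mapping(input_mapping, sig_to_tensor):
    """(column name, signature key) pairs -> {column name: tensor name}."""
    if isinstance(input_mapping, dict):
        input_mapping = list(input_mapping.items())
    return {col: sig_to_tensor[sig_key] for col, sig_key in input_mapping}


def translate_output_mapping(output_mapping, sig_to_tensor):
    """(signature key, column name) pairs -> {tensor name: column name}."""
    if isinstance(output_mapping, dict):
        output_mapping = list(output_mapping.items())
    return {sig_to_tensor[sig_key]: col for sig_key, col in output_mapping}


# Both a dict and a list of pairs are accepted, as in the diff above.
feeds = translate_input_mapping(
    {"image_col": "images"}, input_tensor_name_from_signature)
fetches = translate_output_mapping(
    [("scores", "prediction_col")], output_tensor_name_from_signature)
```

Note that an unknown signature key raises `KeyError` in this sketch, which matches the plain dict lookup used in the diff.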
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,31 +95,49 @@ def get_tensor(graph, tfobj_or_name): | |
'cannot locate tensor {} in current graph'.format(_tensor_name) | ||
return tnsr | ||
|
||
def as_tensor_name(name): | ||
def as_tensor_name(tfobj_or_name): | ||
""" | ||
Derive tf.Tensor name from an op/tensor name. | ||
We do not check if the tensor exist (as no graph parameter is passed in). | ||
If the input is a name, we do not check if the tensor exists | ||
(as no graph parameter is passed in). | ||
|
||
:param name: op name or tensor name | ||
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either | ||
""" | ||
assert isinstance(name, six.string_types) | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
if len(name_parts) < 2: | ||
name += ":0" | ||
return name | ||
if isinstance(tfobj_or_name, six.string_types): | ||
# If input is a string, assume it is a name and infer the corresponding tensor name. | ||
# WARNING: this depends on TensorFlow's tensor naming convention | ||
name = tfobj_or_name | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
if len(name_parts) < 2: | ||
name += ":0" | ||
return name | ||
elif hasattr(tfobj_or_name, 'graph'): | ||
tfobj = tfobj_or_name | ||
return get_tensor(tfobj.graph, tfobj).name | ||
|
||
else: | ||
raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) | ||
|
||
def as_op_name(name): | ||
def as_op_name(tfobj_or_name): | ||
We have 5 functions that compute some names in this file, this is at least two too many.
||
""" | ||
Derive tf.Operation name from an op/tensor name | ||
We do not check if the operation exist (as no graph parameter is passed in). | ||
Derive tf.Operation name from an op/tensor name. | ||
If the input is a name, we do not check if the operation exists | ||
(as no graph parameter is passed in). | ||
|
||
:param name: op name or tensor name | ||
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either | ||
""" | ||
assert isinstance(name, six.string_types) | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
return name_parts[0] | ||
if isinstance(tfobj_or_name, six.string_types): | ||
# If input is a string, assume it is a name and infer the corresponding operation name. | ||
# WARNING: this depends on TensorFlow's operation naming convention | ||
name = tfobj_or_name | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
return name_parts[0] | ||
elif hasattr(tfobj_or_name, 'graph'): | ||
tfobj = tfobj_or_name | ||
return get_op(tfobj.graph, tfobj).name | ||
else: | ||
raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) | ||
|
||
def op_name(graph, tfobj_or_name): | ||
""" | ||
|
There are multiple ways to get tensor/operation names in this utility module.
We should definitely consolidate the functions as_tensor_name and tensor_name.
Right now, the semantics of the two functions are:
as_tensor_name: returns the tensor name of the input, without necessarily checking whether it exists in any graph.
tensor_name: returns the tensor name only if the input exists in the given graph.
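One possible shape for that consolidation, sketched in plain Python for the string case only: a single parsing helper that both name functions share, plus a checked variant. This is an illustration under stated assumptions, not the change made in the PR. `split_tensor_name` is a hypothetical helper, and the checked `tensor_name` here takes a plain set of known tensor names as a stand-in for a `tf.Graph` lookup.

```python
def split_tensor_name(name):
    """Split 'op' or 'op:index' into (op name, output index as a string).

    Relies on TensorFlow's '<op_name>:<output_index>' naming convention.
    """
    parts = name.split(":")
    if len(parts) > 2:
        raise ValueError("malformed op/tensor name: {!r}".format(name))
    op = parts[0]
    index = parts[1] if len(parts) == 2 else "0"
    return op, index


def as_op_name(name):
    """Unchecked: derive the operation name from an op/tensor name."""
    return split_tensor_name(name)[0]


def as_tensor_name(name):
    """Unchecked: derive the tensor name, defaulting to output 0."""
    return "{}:{}".format(*split_tensor_name(name))


def tensor_name(known_tensor_names, name):
    """Checked: return the tensor name only if it is in the given collection.

    known_tensor_names is a hypothetical stand-in for a graph lookup.
    """
    candidate = as_tensor_name(name)
    if candidate not in known_tensor_names:
        raise KeyError("cannot locate tensor {} in current graph".format(candidate))
    return candidate


known = {"conv/Relu:0", "split:1"}
unchecked = as_tensor_name("missing_op")   # succeeds without any graph
checked = tensor_name(known, "conv/Relu")  # succeeds only because it is known
```

With this layout the checked function is the unchecked one plus an existence test, so the naming convention lives in exactly one place.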