[API] TensorFlow Graph Transformer #39

Closed · wants to merge 25 commits
8 changes: 5 additions & 3 deletions python/sparkdl/__init__.py
@@ -13,15 +13,17 @@
# limitations under the License.
#

from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from .transformers.keras_image import KerasImageFileTransformer
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
from .transformers.tf_image import TFImageTransformer
from .transformers.tf_tensor import TFTransformer
from .transformers.utils import imageInputPlaceholder


__all__ = [
'imageSchema', 'imageType', 'readImages',
'TFImageTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer',
'KerasImageFileTransformer',
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
'imageInputPlaceholder']
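
Note: with these exports in place, user code can reach the new API from the package root. A minimal sketch (only the import surface is defined by this diff; TFTransformer itself lives in python/sparkdl/transformers/tf_tensor.py):

    # New names exported by this PR; TFTransformer consumes a TFInputGraph.
    from sparkdl import TFInputGraph, TFTransformer
    # Existing exports remain available unchanged.
    from sparkdl import imageSchema, readImages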
8 changes: 4 additions & 4 deletions python/sparkdl/graph/builder.py
@@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
self.graph = graph or tf.Graph()
self.sess = tf.Session(graph=self.graph)
if using_keras:
self.using_keras = True
self.keras_prev_sess = K.get_session()
else:
self.using_keras = False
self.keras_prev_sess = None

def __enter__(self):
self.sess.as_default()
self.sess.__enter__()
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.sess)
return self

def __exit__(self, *args):
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.keras_prev_sess)
self.sess.__exit__(*args)

@@ -268,4 +269,3 @@ def fromList(cls, functions):
gfn = issn.asGraphFunction(first_inputs, last_outputs)

return gfn
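
The two fixes above are subtle: tf.Session.as_default() merely returns a context manager, so calling it without entering it (the deleted line) never changed the default session; and gating the Keras session swap on keras_prev_sess is not None conflated "not using Keras" with "Keras had no previous session", which the explicit using_keras flag disentangles. A usage sketch of the corrected semantics (assuming the enclosing class is this module's IsolatedSession; the assertion is illustrative):

    import keras.backend as K
    from sparkdl.graph.builder import IsolatedSession

    with IsolatedSession(using_keras=True) as issn:
        # Inside the block, issn.sess is both the TensorFlow default session
        # and the Keras backend session, so Keras layers build into issn.graph.
        assert K.get_session() is issn.sess
    # On exit, the previous Keras session is restored.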

254 changes: 254 additions & 0 deletions python/sparkdl/graph/input.py
@@ -0,0 +1,254 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2

import sparkdl.graph.utils as tfx

__all__ = ["TFInputGraph"]

class TFInputGraph(object):
"""
An opaque, serializable object containing a TensorFlow graph.

[WARNING] This class should not be constructed directly by user code.
"""
def __init__(self):
raise NotImplementedError(
"Please do NOT construct TFInputGraph directly. Instead, use one of the helper functions")

@classmethod
def _new_obj_internal(cls):
# pylint: disable=attribute-defined-outside-init
obj = object.__new__(cls)
# TODO: for (de-)serialization, the class should correspond to a ProtocolBuffer definition.
obj.graph_def = None
obj.input_tensor_name_from_signature = None
obj.output_tensor_name_from_signature = None
return obj

def translateInputMapping(self, input_mapping):
assert self.input_tensor_name_from_signature is not None
_input_mapping = {}
if isinstance(input_mapping, dict):
input_mapping = list(input_mapping.items())
assert isinstance(input_mapping, list)
for col_name, sig_key in input_mapping:
tnsr_name = self.input_tensor_name_from_signature[sig_key]
_input_mapping[col_name] = tnsr_name
return _input_mapping

def translateOutputMapping(self, output_mapping):
assert self.output_tensor_name_from_signature is not None
_output_mapping = {}
if isinstance(output_mapping, dict):
output_mapping = list(output_mapping.items())
assert isinstance(output_mapping, list)
for sig_key, col_name in output_mapping:
tnsr_name = self.output_tensor_name_from_signature[sig_key]
_output_mapping[tnsr_name] = col_name
return _output_mapping

@classmethod
def fromGraph(cls, graph, sess, feed_names, fetch_names):
"""
Construct a TFInputGraph from an in-memory tf.Graph object.
"""
assert isinstance(graph, tf.Graph), \
('expected tf.Graph type but got', type(graph))

def import_graph_fn(_sess):
assert _sess == sess, 'must have the same session'
return _GinBuilderInfo()

return _GinBuilder(import_graph_fn, sess, graph).build(feed_names, fetch_names)

@classmethod
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
"""
Construct a TFInputGraph from a tf.GraphDef object.
"""
assert isinstance(graph_def, tf.GraphDef), \
('expected tf.GraphDef type but got', type(graph_def))

def import_graph_fn(sess):
with sess.as_default():
tf.import_graph_def(graph_def, name='')
return _GinBuilderInfo()

return _GinBuilder(import_graph_fn).build(feed_names, fetch_names)

@classmethod
def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names):
return cls._from_checkpoint_impl(checkpoint_dir,
signature_def_key=None,
feed_names=feed_names, fetch_names=fetch_names)

@classmethod
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
assert signature_def_key is not None
return cls._from_checkpoint_impl(checkpoint_dir,
signature_def_key,
feed_names=None, fetch_names=None)

@classmethod
def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names):
return cls._from_saved_model_impl(saved_model_dir, tag_set,
signature_def_key=None,
feed_names=feed_names, fetch_names=fetch_names)

@classmethod
def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key):
assert signature_def_key is not None
return cls._from_saved_model_impl(saved_model_dir, tag_set,
signature_def_key=signature_def_key,
feed_names=None, fetch_names=None)

@classmethod
def _from_checkpoint_impl(cls,
checkpoint_dir,
signature_def_key=None,
feed_names=None,
fetch_names=None):
"""
Construct a TFInputGraph from a model checkpoint.
"""
assert (feed_names is None) == (fetch_names is None), \
'feed_names and fetch_names, if provided, must appear together'
assert (feed_names is None) != (signature_def_key is None), \
'must provide either feed_names or signature_def_key'

def import_graph_fn(sess):
# Load checkpoint and import the graph
with sess.as_default():
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

# NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def
# the current `import_graph_def` function seems to ignore
# any signature_def fields in a checkpoint's meta_graph_def.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
with open("{}.meta".format(ckpt_path), 'rb') as fin:
meta_graph_def.ParseFromString(fin.read())

saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True)
saver.restore(sess, ckpt_path)

sig_def = None
if signature_def_key is not None:
sig_def = meta_graph_def.signature_def[signature_def_key]
assert sig_def, 'signature_def_key {} provided, '.format(signature_def_key) + \
'but failed to find it in the meta_graph_def ' + \
'loaded from checkpoint {}'.format(checkpoint_dir)

return _GinBuilderInfo(sig_def=sig_def)

return _GinBuilder(import_graph_fn).build(feed_names, fetch_names)

@classmethod
def _from_saved_model_impl(cls, saved_model_dir, tag_set,
signature_def_key=None,
feed_names=None,
fetch_names=None):
"""
Construct a TFInputGraph from a SavedModel.
"""
assert (feed_names is None) == (fetch_names is None), \
'feed_names and fetch_names, if provided, must appear together'
assert (feed_names is None) != (signature_def_key is None), \
'must provide either feed_names or signature_def_key'

def import_graph_fn(sess):
tag_sets = tag_set.split(',')
meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir)

sig_def = None
if signature_def_key is not None:
sig_def = tf.contrib.saved_model.get_signature_def_by_key(
meta_graph_def, signature_def_key)

return _GinBuilderInfo(sig_def=sig_def)

return _GinBuilder(import_graph_fn).build(feed_names, fetch_names)


class _GinBuilderInfo(object):
def __init__(self, sig_def=None):
self.sig_def = sig_def
self.feed_names = None
self.feed_mapping = None
self.fetch_names = None
self.fetch_mapping = None

def extract_signatures(self):
assert self.sig_def is not None, \
"ask to find sigdef mapping, but not found any"

self.feed_mapping = {}
self.feed_names = []
for sigdef_key, tnsr_info in self.sig_def.inputs.items():
tnsr_name = tnsr_info.name
self.feed_mapping[sigdef_key] = tnsr_name
self.feed_names.append(tnsr_name)

self.fetch_mapping = {}
self.fetch_names = []
for sigdef_key, tnsr_info in self.sig_def.outputs.items():
tnsr_name = tnsr_info.name
self.fetch_mapping[sigdef_key] = tnsr_name
self.fetch_names.append(tnsr_name)

class _GinBuilder(object):
def __init__(self, import_graph_fn, sess=None, graph=None):
self.import_graph_fn = import_graph_fn
assert (sess is None) == (graph is None)
if sess is not None:
self.graph = graph
self.sess = sess
self._should_clean = False
else:
self.graph = tf.Graph()
self.sess = tf.Session(graph=self.graph)
self._should_clean = True

def _build_impl(self, feed_names, fetch_names):
# pylint: disable=protected-access,attribute-defined-outside-init
gin = TFInputGraph._new_obj_internal()
assert (feed_names is None) == (fetch_names is None)
must_have_sig_def = fetch_names is None
# NOTE(phi-dbq): both have to be set to default
with self.sess.as_default(), self.graph.as_default():
_ginfo = self.import_graph_fn(self.sess)
if must_have_sig_def:
_ginfo.extract_signatures()
feed_names = _ginfo.feed_names
fetch_names = _ginfo.fetch_names
gin.input_tensor_name_from_signature = _ginfo.feed_mapping
gin.output_tensor_name_from_signature = _ginfo.fetch_mapping

for tnsr_name in feed_names:
assert tfx.get_op(self.graph, tnsr_name)
fetches = [tfx.get_tensor(self.graph, tnsr_name) for tnsr_name in fetch_names]
gin.graph_def = tfx.strip_and_freeze_until(fetches, self.graph, self.sess)
return gin

def build(self, feed_names=None, fetch_names=None):
try:
gin = self._build_impl(feed_names, fetch_names)
finally:
if self._should_clean:
self.sess.close()
return gin
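
Taken together, the factory methods above all funnel into _GinBuilder.build, which freezes the graph and records any signature_def mappings. A hedged usage sketch (the call shapes come from this diff; the fromGraphDef round trip assumes tfx.strip_and_freeze_until returns a tf.GraphDef, as its use here suggests):

    import tensorflow as tf
    from sparkdl.graph.input import TFInputGraph

    # Build a trivial graph: y = mean(x, axis=1).
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        x = tf.placeholder(tf.float64, shape=[None, 10], name='x')
        _ = tf.reduce_mean(x, axis=1, name='y')

        # Explicit feeds/fetches; no signature_def is consulted.
        gin = TFInputGraph.fromGraph(graph, sess,
                                     feed_names=['x:0'], fetch_names=['y:0'])

    # gin.graph_def holds the stripped, frozen graph, so it can round-trip:
    gin2 = TFInputGraph.fromGraphDef(gin.graph_def, ['x:0'], ['y:0'])

    # The *WithSignature variants instead discover feeds/fetches from the
    # signature_def and populate the mappings consumed by
    # translateInputMapping / translateOutputMapping, e.g.:
    #   gin3 = TFInputGraph.fromSavedModelWithSignature(dir, tags, sig_key)
    #   gin3.translateInputMapping({'input_col': 'input_sig_key'})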
52 changes: 35 additions & 17 deletions python/sparkdl/graph/utils.py
@@ -95,31 +95,49 @@ def get_tensor(graph, tfobj_or_name):
'cannot locate tensor {} in current graph'.format(_tensor_name)
return tnsr

def as_tensor_name(name):
[Review comment from the author]
There are multiple ways to get tensor/operation names in this utility module.
We should definitely consolidate the functions as_tensor_name and tensor_name.
Right now, the semantics of the two functions are:

  1. as_tensor_name: returns the tensor name derived from the input (without checking whether it exists in any graph).
  2. tensor_name: returns the tensor name only if the input exists in the given graph.
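
A two-line sketch of that distinction (tensor_name is not shown in this diff; its graph-checked signature is assumed to mirror op_name(graph, tfobj_or_name) below):

    import sparkdl.graph.utils as tfx

    tfx.as_tensor_name('no_such_op')       # => 'no_such_op:0'; purely syntactic
    # tfx.tensor_name(graph, 'no_such_op') # would fail: not in the given graph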

def as_tensor_name(tfobj_or_name):
"""
Derive tf.Tensor name from an op/tensor name.
We do not check if the tensor exist (as no graph parameter is passed in).
If the input is a name, we do not check if the tensor exist
(as no graph parameter is passed in).

:param name: op name or tensor name
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
"""
assert isinstance(name, six.string_types)
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
if len(name_parts) < 2:
name += ":0"
return name
if isinstance(tfobj_or_name, six.string_types):
# If input is a string, assume it is a name and infer the corresponding tensor name.
# WARNING: this depends on TensorFlow's tensor naming convention
name = tfobj_or_name
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
if len(name_parts) < 2:
name += ":0"
return name
elif hasattr(tfobj_or_name, 'graph'):
tfobj = tfobj_or_name
return get_tensor(tfobj.graph, tfobj).name
[Review comment from a contributor]
tensor_name
else:
raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name)))

def as_op_name(name):
def as_op_name(tfobj_or_name):
[Review comment from a contributor]
We have 5 functions that compute some names in this file; this is at least two too many.

"""
Derive tf.Operation name from an op/tensor name
We do not check if the operation exist (as no graph parameter is passed in).
Derive tf.Operation name from an op/tensor name.
If the input is a name, we do not check if the operation exist
(as no graph parameter is passed in).

:param name: op name or tensor name
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
"""
assert isinstance(name, six.string_types)
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
return name_parts[0]
if isinstance(tfobj_or_name, six.string_types):
# If input is a string, assume it is a name and infer the corresponding operation name.
# WARNING: this depends on TensorFlow's operation naming convention
name = tfobj_or_name
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
return name_parts[0]
elif hasattr(tfobj_or_name, 'graph'):
tfobj = tfobj_or_name
return get_op(tfobj.graph, tfobj).name
else:
raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name)))

def op_name(graph, tfobj_or_name):
"""
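A short sketch of the name helpers after this change; string inputs are handled syntactically, while tf.Tensor / tf.Operation inputs are resolved through their own graph (return values follow TensorFlow's <op_name>:<output_index> convention):

    import tensorflow as tf
    import sparkdl.graph.utils as tfx

    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float64, shape=[], name='x')

    assert tfx.as_tensor_name('x') == 'x:0'   # string: no graph consulted
    assert tfx.as_op_name('x:0') == 'x'
    assert tfx.as_tensor_name(x) == 'x:0'     # tf.Tensor: resolved via x.graph
    assert tfx.as_op_name(x.op) == 'x'        # tf.Operation: resolved via graph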
8 changes: 6 additions & 2 deletions python/sparkdl/param/__init__.py
@@ -14,7 +14,11 @@
#

from sparkdl.param.shared_params import (
keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel,
HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters)
keyword_only, HasInputCol, HasOutputCol, HasLabelCol,
# TFTransformer Params
HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams,
# Keras Estimator Params
HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName)
from sparkdl.param.converters import SparkDLTypeConverters
from sparkdl.param.image_params import (
CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES)
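
With this regrouping, a transformer picks up exactly the Params it needs as mixins. A hedged sketch of the shape TFTransformer takes in this PR (illustrative only; the real class lives in python/sparkdl/transformers/tf_tensor.py):

    from pyspark.ml import Transformer
    from sparkdl.param import (HasInputMapping, HasOutputMapping,
                               HasTFInputGraph, HasTFHParams)

    class MyTransformer(Transformer, HasTFInputGraph, HasTFHParams,
                        HasInputMapping, HasOutputMapping):
        """Sketch: a TFInputGraph plus column<->tensor mappings fully
        specify the computation applied in _transform."""
        def _transform(self, dataset):
            raise NotImplementedError("illustrative sketch only")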