Add style checks and refactor suggestions #121
Changes from 31 commits
@@ -16,17 +16,18 @@
 import logging
 import os
 import shutil
-import six
 from tempfile import mkdtemp

 import keras.backend as K
 from keras.models import Model as KerasModel, load_model
+import six
 import tensorflow as tf

 import sparkdl.graph.utils as tfx

 logger = logging.getLogger('sparkdl')

+# pylint: disable=fixme

 class IsolatedSession(object):
     """
@@ -83,15 +84,15 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):

         :param inputs: list, graph elements representing the inputs
         :param outputs: list, graph elements representing the outputs
-        :param strip_and_freeze: bool, should we remove unused part of the graph and freee its values
+        :param strip_and_freeze: bool, should we remove unused part of the graph and free its values
Review comment: lol I think it's actually supposed to be

Review comment: no idea how this happened :P
         """
         if strip_and_freeze:
             gdef = tfx.strip_and_freeze_until(outputs, self.graph, self.sess)
         else:
             gdef = self.graph.as_graph_def(add_shapes=True)
-        return GraphFunction(graph_def=gdef,
-                             input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
-                             output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])
+        input_names = [tfx.validated_input(elem, self.graph) for elem in inputs]
+        output_names = [tfx.validated_output(elem, self.graph) for elem in outputs]
+        return GraphFunction(graph_def=gdef, input_names=input_names, output_names=output_names)

     def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
         """
@@ -100,9 +101,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):

         .. _a link: https://www.tensorflow.org/api_docs/python/tf/import_graph_def

-        :param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and outputs
+        :param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and
+            outputs
         :param input_map: dict, mapping from input names to existing graph elements
-        :param prefix: str, the scope for all the variables in the :py:class:`GraphFunction` elements
+        :param prefix: str, the scope for all the variables in the :py:class:`GraphFunction`
+            elements

         .. _a link: https://www.tensorflow.org/programmers_guide/variable_scope

@@ -119,13 +122,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
         input_names = gfn.input_names
         output_names = gfn.output_names
         scope_name = prefix
-        if prefix is not None:
+        if prefix:
             scope_name = prefix.strip()
-            if len(scope_name) > 0:
-                output_names = [
-                    scope_name + '/' + op_name for op_name in gfn.output_names]
-                input_names = [
-                    scope_name + '/' + op_name for op_name in gfn.input_names]
+            if scope_name:
+                output_names = [scope_name + '/' + op_name for op_name in gfn.output_names]
+                input_names = [scope_name + '/' + op_name for op_name in gfn.input_names]

         # When importing, provide the original output op names
         tf.import_graph_def(gfn.graph_def,
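The simplified truthiness checks cover the same cases as the old explicit is-not-None and len(...) > 0 tests. A minimal standalone sketch of the scoping logic this hunk implements (the helper name scope_op_names is mine, not from the codebase):

def scope_op_names(op_names, prefix):
    """Prepend 'prefix/' to each op name when a non-blank prefix is given."""
    scope_name = prefix
    if prefix:                  # same outcome as the old `prefix is not None` check
        scope_name = prefix.strip()
        if scope_name:          # same outcome as the old `len(scope_name) > 0` check
            return [scope_name + '/' + name for name in op_names]
    return list(op_names)

# None, empty, and whitespace-only prefixes leave the names untouched:
assert scope_op_names(['x', 'y'], None) == ['x', 'y']
assert scope_op_names(['x', 'y'], '') == ['x', 'y']
assert scope_op_names(['x', 'y'], '   ') == ['x', 'y']
# A real prefix is prepended:
assert scope_op_names(['x', 'y'], 'GFN-IMPORT') == ['GFN-IMPORT/x', 'GFN-IMPORT/y']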
@@ -142,7 +143,8 @@ class GraphFunction(object):
     """
     Represent a TensorFlow graph with its GraphDef, input and output operation names.

-    :param graph_def: GraphDef, a static ProtocolBuffer object holding informations of a TensorFlow graph
+    :param graph_def: GraphDef, a static ProtocolBuffer object holding information of a
+        TensorFlow graph
     :param input_names: names to the input graph elements (must be of Placeholder type)
     :param output_names: names to the output graph elements
     """
@@ -179,7 +181,8 @@ def fromKeras(cls, model_or_file_path):
         """
         Build a GraphFunction from a Keras model

-        :param model_or_file_path: KerasModel or str, either a Keras model or the file path name to one
+        :param model_or_file_path: KerasModel or str, either a Keras model or the file path name
+            to one
         """
         if isinstance(model_or_file_path, KerasModel):
             model = model_or_file_path
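Per the docstring, fromKeras accepts either an in-memory Keras model or a path to a saved one. A hedged usage sketch (the import path, model definition, and file path are illustrative assumptions, not taken from this PR):

from keras.layers import Dense
from keras.models import Sequential

from sparkdl.graph.builder import GraphFunction  # module path assumed

# Either hand over the model object directly ...
model = Sequential([Dense(2, input_shape=(4,), activation='softmax')])
gfn = GraphFunction.fromKeras(model)

# ... or point at a saved HDF5 file (path is hypothetical)
model.save('/tmp/tiny_model.h5')
gfn_from_file = GraphFunction.fromKeras('/tmp/tiny_model.h5')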
@@ -214,7 +217,7 @@ def fromList(cls, functions):
         :param functions: a list of tuples (scope name, GraphFunction object).
         """
         assert len(functions) >= 1, ("must provide at least one function", functions)
-        if 1 == len(functions):
+        if len(functions) == 1:
             return functions[0]
         # Check against each intermediary layer input output function pairs
         for (scope_in, gfn_in), (scope_out, gfn_out) in zip(functions[:-1], functions[1:]):
@@ -252,7 +255,8 @@ def fromList(cls, functions):

         for idx, (scope, gfn) in enumerate(functions):
             # Give a scope to each function to avoid name conflict
-            if scope is None or len(scope.strip()) == 0:
+            if scope is None or len(scope.strip()) == 0:  # pylint: disable=len-as-condition
+                # TODO: refactor above and test: if not (scope and scope.strip())
                 scope = 'GFN-BLK-{}'.format(idx)
             _msg = 'merge: stage {}, scope {}'.format(idx, scope)
             logger.info(_msg)
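The TODO suggests replacing the len-based check with a plain truthiness test. A small standalone sketch with a quick equivalence check (function names are illustrative, not from the codebase):

def needs_default_scope_old(scope):
    # The condition currently in the diff
    return scope is None or len(scope.strip()) == 0

def needs_default_scope_new(scope):
    # The refactor proposed in the TODO: falsy scope or whitespace-only scope
    return not (scope and scope.strip())

# The two conditions agree on the cases that matter here
for candidate in [None, '', '   ', '\t', 'GFN-BLK-0', ' block ']:
    assert needs_default_scope_old(candidate) == needs_default_scope_new(candidate)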
Review comment: Q: why the assert? looks like we already raise a ValueError if a paramMap is invalid in _validateParams

Review comment: yeah I agree, you could move the assert above to signal that it is always expected to be there

Review comment: assert any([...]) was changed to _ = [...] for readability, because it suggested that the error would be thrown because of the assert, but in reality the error was thrown in _validateParams
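A minimal sketch of the pattern being discussed (the class, method, and parameter names are hypothetical, not the project's actual API): validation lives in a _validateParams-style helper that raises ValueError itself, so wrapping the comprehension in an assert wrongly suggested the assert was doing the checking.

class ExampleStage(object):
    """Illustrative only; not the sparkdl class the review thread refers to."""

    _known_params = {'inputCol', 'outputCol'}

    def _validateParams(self, paramMap):
        # The real guard: raises ValueError on any unknown parameter name.
        for name in paramMap:
            if name not in self._known_params:
                raise ValueError("unknown param: {}".format(name))
        return True

    def setParams(self, **paramMap):
        # Before: `assert any([...])` implied the assert was the guard, but
        # _validateParams raises first, so the assert could never fail here.
        # After: keep the calls for their side effect and discard the results.
        _ = [self._validateParams({name: value}) for name, value in paramMap.items()]
        self._paramMap = dict(paramMap)
        return self

stage = ExampleStage().setParams(inputCol='image')   # fine
# ExampleStage().setParams(bogus=1)                  # would raise ValueError from _validateParams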