Add style checks and refactor suggestions #121

Merged: 32 commits, merged on Apr 25, 2018.

Commits (the diff below shows changes from 31 of the 32 commits):
4062166
[ML-3487] add accepted and suggested pylint rc files
yogeshg Apr 20, 2018
f3b8f63
space, comments, messages
yogeshg Apr 20, 2018
9edd45b
refactor: rename or ignore invalid-names; remove unused import
yogeshg Apr 23, 2018
b7f8c5a
bugfix: variables named without updating in error message
yogeshg Apr 23, 2018
4681189
space, comment, alignment changes
yogeshg Apr 23, 2018
5683cd7
pylint disable protected-access
yogeshg Apr 23, 2018
6328650
ignore style case wise
yogeshg Apr 23, 2018
a8a5062
bugfix
yogeshg Apr 23, 2018
6a92e56
pylint snake_case and camelCase attribute names allowed
yogeshg Apr 23, 2018
e6e2ce2
ignore style case wise
yogeshg Apr 23, 2018
50010e0
add spaces, indent, expressions to simplify reading
yogeshg Apr 23, 2018
6e69803
pylint disable what can be kept
yogeshg Apr 23, 2018
fb35644
if len if -> if any; imports, ignore importing issues
yogeshg Apr 23, 2018
1637a09
add spaces, indent, group imports
yogeshg Apr 23, 2018
50b6dba
requires refactoring
yogeshg Apr 23, 2018
06013ec
Undefined variable name 'imageType' in __all__
yogeshg Apr 23, 2018
65c2e2e
add spaces, indent, group imports
yogeshg Apr 23, 2018
92c20d1
add and ignore todos
yogeshg Apr 23, 2018
852625b
no len as comparison, indent, simple expressions
yogeshg Apr 23, 2018
a66e6f6
add spaces, indent
yogeshg Apr 23, 2018
07ba657
no-else-return seems weird on local machine, disabling
yogeshg Apr 23, 2018
ecb4181
group imports, ignore g, op names
yogeshg Apr 23, 2018
f92fc4a
remove trailing white spaces
yogeshg Apr 23, 2018
018e36d
ignore import error, add refactor todo, lazy logging
yogeshg Apr 23, 2018
2054849
fix imports, disable some
yogeshg Apr 23, 2018
6df2b2e
use output of _validateParam, ignore too few in ThreadSafeIterator
yogeshg Apr 23, 2018
862246b
optimize imports, ignore fixme, no-self-use
yogeshg Apr 23, 2018
5cbcd7c
optimize imports
yogeshg Apr 23, 2018
9c170e8
optimize imports, space
yogeshg Apr 23, 2018
5cd8161
fix cyclic import (sparkdl.param -> sparkdl.param.image_params)
yogeshg Apr 23, 2018
5ce5b56
optimize imports
yogeshg Apr 23, 2018
e90d3ad
fixmes in code are acceptable; throw away value of validate_params; s…
yogeshg Apr 24, 2018
556 changes: 556 additions & 0 deletions python/.pylint/accepted.rc

Large diffs are not rendered by default.

547 changes: 547 additions & 0 deletions python/.pylint/suggested.rc

Large diffs are not rendered by default.
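The two pylint configuration files themselves are not rendered above. As a hedged illustration only (the PR does not show how the project wires pylint into its build), a configuration such as accepted.rc would typically be exercised along these lines; the target path and invocation below are assumptions:

# Hypothetical invocation of the stricter "accepted" pylint configuration added
# by this PR; the actual command used by the project's build is not shown here.
import subprocess

result = subprocess.run(
    ["pylint", "--rcfile=python/.pylint/accepted.rc", "python/sparkdl"],
    capture_output=True,
    text=True,
)
# pylint exits non-zero whenever it emits messages, so inspect the report directly.
print(result.stdout)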

4 changes: 2 additions & 2 deletions python/sparkdl/__init__.py
@@ -22,6 +22,6 @@
from .estimators.keras_image_file_estimator import KerasImageFileEstimator

__all__ = [
'imageType', 'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
'TFImageTransformer', 'TFInputGraph', 'TFTransformer', 'DeepImagePredictor',
'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
'imageInputPlaceholder', 'KerasImageFileEstimator']
3 changes: 2 additions & 1 deletion python/sparkdl/estimators/keras_image_file_estimator.py
@@ -34,6 +34,7 @@
__all__ = ['KerasImageFileEstimator']


# pylint: disable=too-few-public-methods
class _ThreadSafeIterator(object):
"""
Utility iterator class used by KerasImageFileEstimator.fitMultiple to serve models in a thread
@@ -264,7 +265,7 @@ def fitMultiple(self, dataset, paramMaps):
existence of a sufficiently large (and writable) file system, users are
advised to not train too many models in a single Spark job.
"""
[self._validateParams(pm) for pm in paramMaps]
assert all([self._validateParams(pm) for pm in paramMaps])
Collaborator:

Q: why the assert? looks like we already raise a ValueError if a paramMap is invalid in _validateParams

Collaborator:

yeah I agree, you could move the assert above to signal that it is always expected to be there

Contributor Author:

assert all([...]) was changed to _ = [...] for readability, because it suggested that the error would be thrown because of the assert, but in reality the error was thrown in _validateParams
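For context, a minimal sketch of the two styles discussed in this thread; the helper below is illustrative and stands in for _validateParams, which raises on an invalid paramMap:

# Illustrative helper standing in for _validateParams: it raises on bad input,
# so its return value is never actually needed.
def _validate(param_map):
    if not param_map:
        raise ValueError("invalid paramMap: %s" % (param_map,))
    return True

param_maps = [{"lr": 0.1}, {"lr": 0.01}]

# Style at this point in the PR: the assert suggests the failure would come
# from the assertion itself.
assert all([_validate(pm) for pm in param_maps])

# Style adopted in a later commit: the list is evaluated purely for its side
# effect of raising, and its value is explicitly discarded.
_ = [_validate(pm) for pm in param_maps]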


def _name_value_map(paramMap):
"""takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
38 changes: 21 additions & 17 deletions python/sparkdl/graph/builder.py
@@ -16,17 +16,18 @@
import logging
import os
import shutil
import six
from tempfile import mkdtemp

import keras.backend as K
from keras.models import Model as KerasModel, load_model
import six
import tensorflow as tf

import sparkdl.graph.utils as tfx

logger = logging.getLogger('sparkdl')

# pylint: disable=fixme

class IsolatedSession(object):
"""
@@ -83,15 +84,15 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):

:param inputs: list, graph elements representing the inputs
:param outputs: list, graph elements representing the outputs
:param strip_and_freeze: bool, should we remove unused part of the graph and freee its values
:param strip_and_freeze: bool, should we remove unused part of the graph and free its values
Collaborator:

lol I think it's actually supposed to be freeze instead of free :P

Contributor Author:

no idea how this happened :P

"""
if strip_and_freeze:
gdef = tfx.strip_and_freeze_until(outputs, self.graph, self.sess)
else:
gdef = self.graph.as_graph_def(add_shapes=True)
return GraphFunction(graph_def=gdef,
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])
input_names = [tfx.validated_input(elem, self.graph) for elem in inputs]
output_names = [tfx.validated_output(elem, self.graph) for elem in outputs]
return GraphFunction(graph_def=gdef, input_names=input_names, output_names=output_names)

def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
"""
@@ -100,9 +101,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k

.. _a link: https://www.tensorflow.org/api_docs/python/tf/import_graph_def

:param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and outputs
:param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and
outputs
:param input_map: dict, mapping from input names to existing graph elements
:param prefix: str, the scope for all the variables in the :py:class:`GraphFunction` elements
:param prefix: str, the scope for all the variables in the :py:class:`GraphFunction`
elements

.. _a link: https://www.tensorflow.org/programmers_guide/variable_scope

@@ -119,13 +122,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
input_names = gfn.input_names
output_names = gfn.output_names
scope_name = prefix
if prefix is not None:
if prefix:
scope_name = prefix.strip()
if len(scope_name) > 0:
output_names = [
scope_name + '/' + op_name for op_name in gfn.output_names]
input_names = [
scope_name + '/' + op_name for op_name in gfn.input_names]
if scope_name:
output_names = [scope_name + '/' + op_name for op_name in gfn.output_names]
input_names = [scope_name + '/' + op_name for op_name in gfn.input_names]

# When importing, provide the original output op names
tf.import_graph_def(gfn.graph_def,
@@ -142,7 +143,8 @@ class GraphFunction(object):
"""
Represent a TensorFlow graph with its GraphDef, input and output operation names.

:param graph_def: GraphDef, a static ProtocolBuffer object holding informations of a TensorFlow graph
:param graph_def: GraphDef, a static ProtocolBuffer object holding information of a
TensorFlow graph
:param input_names: names to the input graph elements (must be of Placeholder type)
:param output_names: names to the output graph elements
"""
@@ -179,7 +181,8 @@ def fromKeras(cls, model_or_file_path):
"""
Build a GraphFunction from a Keras model

:param model_or_file_path: KerasModel or str, either a Keras model or the file path name to one
:param model_or_file_path: KerasModel or str, either a Keras model or the file path name
to one
"""
if isinstance(model_or_file_path, KerasModel):
model = model_or_file_path
@@ -214,7 +217,7 @@ def fromList(cls, functions):
:param functions: a list of tuples (scope name, GraphFunction object).
"""
assert len(functions) >= 1, ("must provide at least one function", functions)
if 1 == len(functions):
if len(functions) == 1:
return functions[0]
# Check against each intermediary layer input output function pairs
for (scope_in, gfn_in), (scope_out, gfn_out) in zip(functions[:-1], functions[1:]):
@@ -252,7 +255,8 @@ def fromList(cls, functions):

for idx, (scope, gfn) in enumerate(functions):
# Give a scope to each function to avoid name conflict
if scope is None or len(scope.strip()) == 0:
if scope is None or len(scope.strip()) == 0: # pylint: disable=len-as-condition
# TODO: refactor above and test: if not (scope and scope.strip())
scope = 'GFN-BLK-{}'.format(idx)
_msg = 'merge: stage {}, scope {}'.format(idx, scope)
logger.info(_msg)
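The TODO in fromList above proposes a simpler truthiness check for the scope; a small standalone sketch of the suggested equivalence (not part of the PR) follows:

# Standalone sketch of the equivalence proposed by the TODO in fromList above;
# not part of the PR itself.
def needs_default_scope_current(scope):
    return scope is None or len(scope.strip()) == 0

def needs_default_scope_proposed(scope):
    return not (scope and scope.strip())

for candidate in [None, "", "   ", "GFN-BLK-0", "  block1  "]:
    assert needs_default_scope_current(candidate) == needs_default_scope_proposed(candidate)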
2 changes: 1 addition & 1 deletion python/sparkdl/graph/input.py
@@ -77,7 +77,7 @@ class TFInputGraph(object):
inference, i.e. the variables are converted to constants and operations like
BatchNormalization_ are converted to be independent of input batch.

.. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
.. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

:param input_tensor_name_from_signature: dict, signature key names mapped to tensor names.
Please see the example above.
4 changes: 2 additions & 2 deletions python/sparkdl/graph/pieces.py
@@ -55,8 +55,8 @@ def buildSpImageConverter(channelOrder, img_dtype):
elif img_dtype == 'float32':
image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw")
else:
raise ValueError(
'unsupported image data type "%s", currently only know how to handle uint8 and float32' % img_dtype)
raise ValueError('''unsupported image data type "%s", currently only know how to
handle uint8 and float32''' % img_dtype)
image_reshaped = tf.reshape(image_float, shape, name="reshaped")
image_reshaped = imageIO.fixColorChannelOrdering(channelOrder, image_reshaped)
image_input = tf.expand_dims(image_reshaped, 0, name="image_input")
5 changes: 4 additions & 1 deletion python/sparkdl/graph/tensorframes_udf.py
@@ -16,13 +16,14 @@

import logging

import tensorframes as tfs
import tensorframes as tfs # pylint: disable=import-error

import sparkdl.graph.utils as tfx
from sparkdl.utils import jvmapi as JVMAPI

logger = logging.getLogger('sparkdl')

# pylint: disable=fixme

def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=False, register=True):
"""
@@ -85,6 +86,8 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal
placeholder_names = []
placeholder_shapes = []
for node in graph.as_graph_def(add_shapes=True).node:
# pylint: disable=len-as-condition
# todo: refactor if not(node.input) and ...
if len(node.input) == 0 and str(node.op) == 'Placeholder':
tnsr_name = tfx.tensor_name(node.name, graph)
tnsr = graph.get_tensor_by_name(tnsr_name)
11 changes: 5 additions & 6 deletions python/sparkdl/graph/utils.py
@@ -15,8 +15,8 @@
#

import logging
import six

import six
import tensorflow as tf

logger = logging.getLogger('sparkdl')
@@ -74,7 +74,7 @@ def get_op(tfobj_or_name, graph):
if not isinstance(name, six.string_types):
raise TypeError('invalid op request for [type {}] {}'.format(type(name), name))
_op_name = op_name(name, graph=None)
op = graph.get_operation_by_name(_op_name)
op = graph.get_operation_by_name(_op_name) # pylint: disable=invalid-name
err_msg = 'cannot locate op {} in the current graph, got [type {}] {}'
assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op)
return op
@@ -190,9 +190,8 @@ def validated_input(tfobj_or_name, graph):
"""
graph = validated_graph(graph)
name = op_name(tfobj_or_name, graph)
op = graph.get_operation_by_name(name)
assert 'Placeholder' == op.type, \
('input must be Placeholder, but get', op.type)
op = graph.get_operation_by_name(name) # pylint: disable=invalid-name
assert 'Placeholder' == op.type, ('input must be Placeholder, but get', op.type)
return name


@@ -223,7 +222,7 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
sess.close()

if return_graph:
g = tf.Graph()
g = tf.Graph() # pylint: disable=invalid-name
with g.as_default():
tf.import_graph_def(gdef_frozen, name='')
return g
40 changes: 20 additions & 20 deletions python/sparkdl/image/imageIO.py
@@ -25,8 +25,7 @@
from pyspark import SparkContext
from pyspark.ml.image import ImageSchema
from pyspark.sql.functions import udf
from pyspark.sql.types import (
BinaryType, IntegerType, StringType, StructField, StructType)
from pyspark.sql.types import BinaryType, StringType, StructField, StructType


# ImageType represents supported OpenCV types
@@ -39,8 +38,7 @@
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])


_supportedOcvTypes = (
_SUPPORTED_OCV_TYPES = (
_OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
_OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
_OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
@@ -50,22 +48,22 @@
)

# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}
_OCV_TYPES_BY_NAME = {m.name: m for m in _SUPPORTED_OCV_TYPES}
_OCV_TYPES_BY_ORDINAL = {m.ord: m for m in _SUPPORTED_OCV_TYPES}


def imageTypeByOrdinal(ord):
if not ord in _ocvTypesByOrdinal:
def imageTypeByOrdinal(ordinal):
if not ordinal in _OCV_TYPES_BY_ORDINAL:
raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
ord, str(_supportedOcvTypes)))
return _ocvTypesByOrdinal[ord]
ordinal, str(_SUPPORTED_OCV_TYPES)))
return _OCV_TYPES_BY_ORDINAL[ordinal]


def imageTypeByName(name):
if not name in _ocvTypesByName:
raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
name, str(_supportedOcvTypes)))
return _ocvTypesByName[name]
if not name in _OCV_TYPES_BY_NAME:
raise KeyError("unsupported image type with name '%s', supported OpenCV types = %s" % (
name, str(_SUPPORTED_OCV_TYPES)))
return _OCV_TYPES_BY_NAME[name]


def imageArrayToStruct(imgArray, origin=""):
@@ -151,13 +149,13 @@ def fixColorChannelOrdering(currentOrder, imgAry):
elif currentOrder == 'BGR':
return imgAry
elif currentOrder == 'L':
if len(img.shape) != 1:
if len(imgAry.shape) != 1:
raise ValueError(
"channel order suggests only one color channel but got shape " + str(img.shape))
"channel order suggests only one color channel but got shape " + str(imgAry.shape))
return imgAry
else:
raise ValueError(
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder)
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentOrder)


def _reverseChannels(ary):
@@ -176,6 +174,7 @@ def createResizeImageUDF(size):
if len(size) != 2:
raise ValueError(
"New image size should have format [height, width] but got {}".format(size))
# pylint: disable=invalid-name
sz = (size[1], size[0])

def _resizeImageAsRow(imgAsRow):
@@ -228,11 +227,12 @@ def _decode(raw_bytes):

def readImagesWithCustomFn(path, decode_f, numPartition=None):
"""
Read a directory of images (or a single image) into a DataFrame using a custom library to decode the images.
Read a directory of images (or a single image) into a DataFrame using a custom library to
decode the images.

:param path: str, file path.
:param decode_f: function to decode the raw bytes into an array compatible with one of the supported OpenCv modes.
see @imageIO.PIL_decode for an example.
:param decode_f: function to decode the raw bytes into an array compatible with one of the
supported OpenCv modes. see @imageIO.PIL_decode for an example.
:param numPartition: [optional] int, number or partitions to use for reading files.
:return: DataFrame with schema == ImageSchema.imageSchema.
"""
18 changes: 7 additions & 11 deletions python/sparkdl/param/converters.py
@@ -13,8 +13,6 @@
# limitations under the License.
#

# pylint: disable=invalid-name,import-error

""" SparkDLTypeConverters

Type conversion utilities for defining MLlib `Params` used in Spark Deep Learning Pipelines.
@@ -25,12 +23,9 @@
"""

import six

import tensorflow as tf

from pyspark.ml.param import TypeConverters

from sparkdl.graph.input import *
from sparkdl.graph.input import TFInputGraph
import sparkdl.utils.keras_model as kmutil

__all__ = ['SparkDLTypeConverters']
@@ -129,7 +124,8 @@ def buildSupportedItemConverter(supportedList):
Create a "converter" that try to check if a value is part of the supported list of values.

:param supportedList: list, containing supported objects.
:return: a converter that try to check if a value is part of the `supportedList` and return it.
:return: a converter that try to check if a value is part of the `supportedList` and
return it.
Raise an error otherwise.
"""

@@ -171,8 +167,8 @@ def toKerasOptimizer(value):
@staticmethod
def toChannelOrder(value):
if not value in ('L', 'RGB', 'BGR'):
raise ValueError(
"Unsupported channel order. Expected one of ('L', 'RGB', 'BGR') but got '%s'") % value
raise ValueError("""Unsupported channel order. Expected one of ('L', 'RGB',
'BGR') but got '%s'""" % value)
return value


Expand All @@ -189,8 +185,8 @@ def _check_is_tensor_name(_maybe_tnsr_name):
# may optionally be followed by control inputs that have the format
# "^node".
# Reference:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
# https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
# https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors
try:
_, src_idx = _maybe_tnsr_name.split(":")
_ = int(src_idx)
Expand Down