[SPARK-2627] miscellaneous PEP 8 fixes
Mostly done using autopep8, plus some hand fixes.
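
As a rough sketch of the automated pass described above, the snippet below runs autopep8 over one of the touched modules in memory. The file path, the 79-character line limit, and the choice to skip autopep8's aggressive mode are illustrative assumptions, not a record of the exact options used for this commit; the hand fixes mentioned in the message would still be applied on top of the tool's output.

import autopep8  # assumed to be installed (pip install autopep8)

# Read one of the modules touched by this commit.
with open("python/pyspark/context.py") as f:
    original = f.read()

# Ask autopep8 for PEP 8 fixes; max_line_length=79 is the PEP 8 default and
# an assumption here, not necessarily the setting used for SPARK-2627.
fixed = autopep8.fix_code(original, options={"max_line_length": 79})

if fixed != original:
    # In practice the rewritten source would be written back and then
    # reviewed by hand, as the commit message notes.
    print("autopep8 would reformat python/pyspark/context.py")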
nchammas committed Aug 3, 2014
1 parent beaa9ac commit a31ccc4
Showing 30 changed files with 452 additions and 231 deletions.
python/pyspark/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -53,7 +53,8 @@
# mllib that depend on top level pyspark packages, which transitively depend on python's random.
# Since Python's import logic looks for modules in the current package first, we eliminate
# mllib.random as a candidate for C{import random} by removing the first search path, the script's
# location, in order to force the loader to look in Python's top-level modules for C{random}.
# location, in order to force the loader to look in Python's top-level
# modules for C{random}.
import sys
s = sys.path.pop(0)
import random
python/pyspark/accumulators.py (19 changes: 15 additions & 4 deletions)
@@ -97,7 +97,8 @@
pickleSer = PickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
# the local accumulator updates back to the driver program at the end of a
# task.
_accumulatorRegistry = {}


@@ -110,6 +111,7 @@ def _deserialize_accumulator(aid, zero_value, accum_param):


class Accumulator(object):

"""
A shared variable that can be accumulated, i.e., has a commutative and associative "add"
operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
@@ -139,14 +141,16 @@ def __reduce__(self):
def value(self):
"""Get the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
raise Exception(
"Accumulator.value cannot be accessed inside tasks")
return self._value

@value.setter
def value(self, value):
"""Sets the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
raise Exception(
"Accumulator.value cannot be accessed inside tasks")
self._value = value

def add(self, term):
@@ -166,6 +170,7 @@ def __repr__(self):


class AccumulatorParam(object):

"""
Helper object that defines how to accumulate values of a given type.
"""
@@ -186,6 +191,7 @@ def addInPlace(self, value1, value2):


class AddingAccumulatorParam(AccumulatorParam):

"""
An AccumulatorParam that uses the + operators to add values. Designed for simple types
such as integers, floats, and lists. Requires the zero value for the underlying type
@@ -210,6 +216,7 @@ def addInPlace(self, value1, value2):


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

"""
This handler will keep polling updates from the same socket until the
server is shutdown.
@@ -218,7 +225,8 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
def handle(self):
from pyspark.accumulators import _accumulatorRegistry
while not self.server.server_shutdown:
# Poll every 1 second for new data -- don't block in case of shutdown.
# Poll every 1 second for new data -- don't block in case of
# shutdown.
r, _, _ = select.select([self.rfile], [], [], 1)
if self.rfile in r:
num_updates = read_int(self.rfile)
@@ -228,7 +236,9 @@ def handle(self):
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))


class AccumulatorServer(SocketServer.TCPServer):

"""
A simple TCP server that intercepts shutdown() in order to interrupt
our continuous polling on the handler.
@@ -239,6 +249,7 @@ def shutdown(self):
self.server_shutdown = True
SocketServer.TCPServer.shutdown(self)


def _start_update_server():
"""Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
python/pyspark/broadcast.py (1 change: 1 addition & 0 deletions)
@@ -45,6 +45,7 @@ def _from_id(bid):


class Broadcast(object):

"""
A broadcast variable created with
L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}.
python/pyspark/conf.py (4 changes: 3 additions & 1 deletion)
@@ -56,6 +56,7 @@


class SparkConf(object):

"""
Configuration for a Spark application. Used to set various Spark
parameters as key-value pairs.
@@ -124,7 +125,8 @@ def setSparkHome(self, value):
def setExecutorEnv(self, key=None, value=None, pairs=None):
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise Exception("Either pass one key-value pair or a list of pairs")
raise Exception(
"Either pass one key-value pair or a list of pairs")
elif key is not None:
self._jconf.setExecutorEnv(key, value)
elif pairs is not None:
python/pyspark/context.py (97 changes: 63 additions & 34 deletions)
@@ -47,6 +47,7 @@


class SparkContext(object):

"""
Main entry point for Spark functionality. A SparkContext represents the
connection to a Spark cluster, and can be used to create L{RDD}s and
@@ -59,7 +60,8 @@ class SparkContext(object):
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
# zip and egg files that need to be added to PYTHONPATH
_python_includes = None
_default_batch_size_for_serialized_input = 10

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
@@ -99,13 +101,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
self._callsite = rdd._extract_concise_traceback()
else:
tempNamedTuple = namedtuple("Callsite", "function file linenum")
self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
self._callsite = tempNamedTuple(
function=None, file=None, linenum=None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
# If an error occurs, clean up in order to allow future
# SparkContext creation:
self.stop()
raise

@@ -138,7 +142,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
if not self._conf.contains("spark.master"):
raise Exception("A master URL must be set in your configuration")
if not self._conf.contains("spark.app.name"):
raise Exception("An application name must be set in your configuration")
raise Exception(
"An application name must be set in your configuration")

# Read back our properties from the conf in case we loaded some of them from
# the classpath or an external config file
@@ -179,7 +184,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self.addPyFile(path)

# Deploy code dependencies set by spark-submit; these will already have been added
# with SparkContext.addFile, so we just need to add them to the PYTHONPATH
# with SparkContext.addFile, so we just need to add them to the
# PYTHONPATH
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
@@ -189,9 +195,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
sys.path.append(dirname)

# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(
self._jsc.sc().conf())
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
self._jvm.org.apache.spark.util.Utils.createTempDir(
local_dir).getAbsolutePath()

def _initialize_context(self, jconf):
"""
@@ -213,7 +221,7 @@ def _ensure_initialized(cls, instance=None, gateway=None):

if instance:
if (SparkContext._active_spark_context and
SparkContext._active_spark_context != instance):
SparkContext._active_spark_context != instance):
currentMaster = SparkContext._active_spark_context.master
currentAppName = SparkContext._active_spark_context.appName
callsite = SparkContext._active_spark_context._callsite
@@ -284,7 +292,8 @@ def parallelize(self, c, numSlices=None):
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
# Make sure we distribute data evenly if it's smaller than self.batchSize
# Make sure we distribute data evenly if it's smaller than
# self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = min(len(c) // numSlices, self._batchSize)
@@ -403,10 +412,12 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
Java object. (default sc._default_batch_size_for_serialized_input)
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
batchSize = max(
1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (
batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
keyConverter, valueConverter, minSplits, batchSize)
keyConverter, valueConverter, minSplits, batchSize)
return RDD(jrdd, self, ser)

def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
@@ -434,10 +445,13 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
batchSize = max(
1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (
batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopFile(
self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
return RDD(jrdd, self, ser)

def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
@@ -462,10 +476,13 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
batchSize = max(
1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (
batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(
self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
return RDD(jrdd, self, ser)

def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
@@ -493,10 +510,13 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
batchSize = max(
1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (
batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopFile(
self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter, jconf, batchSize)
return RDD(jrdd, self, ser)

def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
@@ -521,10 +541,12 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
batchSize = max(
1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (
batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass,
keyConverter, valueConverter, jconf, batchSize)
keyConverter, valueConverter, jconf, batchSize)
return RDD(jrdd, self, ser)

def _checkpointFile(self, name, input_deserializer):
@@ -587,7 +609,8 @@ def accumulator(self, value, accum_param=None):
elif isinstance(value, complex):
accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
else:
raise Exception("No default accumulator param for type %s" % type(value))
raise Exception(
"No default accumulator param for type %s" % type(value))
SparkContext._next_accum_id += 1
return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

@@ -632,12 +655,14 @@ def addPyFile(self, path):
HTTP, HTTPS or FTP URI.
"""
self.addFile(path)
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
# dirname may be directory or HDFS/S3 prefix
(dirname, filename) = os.path.split(path)

if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
self._python_includes.append(filename)
# for tests in local mode
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
sys.path.append(
os.path.join(SparkFiles.getRootDirectory(), filename))

def setCheckpointDir(self, dirName):
"""
Expand All @@ -651,7 +676,8 @@ def _getJavaStorageLevel(self, storageLevel):
Returns a Java StorageLevel based on a pyspark.StorageLevel.
"""
if not isinstance(storageLevel, StorageLevel):
raise Exception("storageLevel must be of type pyspark.StorageLevel")
raise Exception(
"storageLevel must be of type pyspark.StorageLevel")

newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
return newStorageLevel(storageLevel.useDisk,
@@ -754,13 +780,15 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
if partitions is None:
partitions = range(rdd._jrdd.partitions().size())
javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
javaPartitions = ListConverter().convert(
partitions, self._gateway._gateway_client)

# Implementation note: This is implemented as a mapPartitions followed
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
it = self._jvm.PythonRDD.runJob(
self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))


Expand All @@ -772,7 +800,8 @@ def _test():
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
(diffs for the remaining 25 changed files are not rendered in this view)
