diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ffe177576f363..cb83e89176823 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -30,7 +30,7 @@ class SQLContext:
     tables, execute SQL over tables, cache tables, and read parquet files.
     """
 
-    def __init__(self, sparkContext, sqlContext = None):
+    def __init__(self, sparkContext, sqlContext=None):
         """Create a new SQLContext.
 
         @param sparkContext: The SparkContext to wrap.
@@ -137,7 +137,6 @@ def parquetFile(self, path):
         jschema_rdd = self._ssql_ctx.parquetFile(path)
         return SchemaRDD(jschema_rdd, self)
 
-
     def jsonFile(self, path):
         """Loads a text file storing one JSON object per line,
            returning the result as a L{SchemaRDD}.
@@ -234,8 +233,8 @@ def _ssql_ctx(self):
                 self._scala_HiveContext = self._get_hive_ctx()
             return self._scala_HiveContext
         except Py4JError as e:
-            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \
-                    "sbt/sbt assembly" , e)
+            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
+                            "sbt/sbt assembly", e)
 
     def _get_hive_ctx(self):
         return self._jvm.HiveContext(self._jsc.sc())
@@ -377,7 +376,7 @@ def registerAsTable(self, name):
         """
         self._jschema_rdd.registerAsTable(name)
 
-    def insertInto(self, tableName, overwrite = False):
+    def insertInto(self, tableName, overwrite=False):
         """Inserts the contents of this SchemaRDD into the specified table.
 
         Optionally overwriting any existing data.
@@ -420,7 +419,7 @@ def _toPython(self):
         # in Java land in the javaToPython function. May require a custom
         # pickle serializer in Pyrolite
         return RDD(jrdd, self._sc, BatchedSerializer(
-                        PickleSerializer())).map(lambda d: Row(d))
+            PickleSerializer())).map(lambda d: Row(d))
 
     # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
     # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
@@ -483,6 +482,7 @@ def subtract(self, other, numPartitions=None):
         else:
             raise ValueError("Can only subtract another SchemaRDD")
 
+
 def _test():
     import doctest
     from array import array
@@ -493,20 +493,25 @@ def _test():
     sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
-    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
-        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
-    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
-       '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
-       '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
+    globs['rdd'] = sc.parallelize(
+        [{"field1": 1, "field2": "row1"},
+         {"field1": 2, "field2": "row2"},
+         {"field1": 3, "field2": "row3"}]
+    )
+    jsonStrings = [
+        '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
+        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
+        '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}'
+    ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
     globs['nestedRdd1'] = sc.parallelize([
-        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
-        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
+        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
     globs['nestedRdd2'] = sc.parallelize([
-        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
-        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
-    (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
+        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
+        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
@@ -514,4 +519,3 @@ def _test():
 
 if __name__ == "__main__":
     _test()
-
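
Note (not part of the patch): the reworked raise in _ssql_ctx drops the backslash continuation and relies on Python's implicit concatenation of adjacent string literals inside the open parentheses, so the message text is unchanged. A minimal standalone sketch of that behaviour, using a hypothetical build_hint helper:

    # Minimal sketch (illustration only, not from the patch): adjacent string
    # literals inside parentheses are joined at compile time, so no backslash
    # line continuation is needed.
    def build_hint():
        return ("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                "sbt/sbt assembly")

    assert build_hint().endswith("run sbt/sbt assembly")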