From 6ea5925d55f7e85755770fb8cdac34970e587174 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Jun 2015 08:10:07 -0700 Subject: [PATCH] address comments --- python/pyspark/sql/context.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5be6c714b3488..2ad0ffe5012b5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -189,7 +189,13 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def _inferSchemaLocally(self, data): + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ if not data: raise ValueError("can not infer schema from empty dataset") first = data[0] @@ -207,9 +213,13 @@ def _inferSchemaLocally(self, data): return schema def _inferSchema(self, rdd, samplingRatio=None): - if not isinstance(rdd, RDD): - return self._inferSchemaLocally(rdd) + """ + Infer schema from an RDD of Row or tuple. + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -326,29 +336,36 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): schema = list(data.columns) data = [r.tolist() for r in data.to_records(index=False)] - if not isinstance(data, RDD): + if isinstance(data, RDD): + rdd = data + else: try: # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) except Exception: raise TypeError("cannot create an RDD from type: %s" % type(data)) - else: - rdd = data - if not isinstance(schema, StructType): - struct = self._inferSchema(data, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) if isinstance(schema, (list, tuple)): for i, name in enumerate(schema): struct.fields[i].name = name schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - else: + + elif isinstance(schema, StructType): # take the first few rows to verify schema rows = rdd.take(10) for row in rows: _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") + # convert python objects to sql data converter = _python_to_sql_converter(schema) rdd = rdd.map(converter)