Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 3, 2015
1 parent 6ceaeff commit 6ea5925
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,13 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
returnType.json())

def _inferSchemaLocally(self, data):
def _inferSchemaFromList(self, data):
"""
Infer schema from list of Row or tuple.
:param data: list of Row or tuple
:return: StructType
"""
if not data:
raise ValueError("can not infer schema from empty dataset")
first = data[0]
Expand All @@ -207,9 +213,13 @@ def _inferSchemaLocally(self, data):
return schema

def _inferSchema(self, rdd, samplingRatio=None):
if not isinstance(rdd, RDD):
return self._inferSchemaLocally(rdd)
"""
Infer schema from an RDD of Row or tuple.
:param rdd: an RDD of Row or tuple
:param samplingRatio: sampling ratio, or no sampling (default)
:return: StructType
"""
first = rdd.first()
if not first:
raise ValueError("The first row in RDD is empty, "
Expand Down Expand Up @@ -326,29 +336,36 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
schema = list(data.columns)
data = [r.tolist() for r in data.to_records(index=False)]

if not isinstance(data, RDD):
if isinstance(data, RDD):
rdd = data
else:
try:
# data could be list, tuple, generator ...
rdd = self._sc.parallelize(data)
except Exception:
raise TypeError("cannot create an RDD from type: %s" % type(data))
else:
rdd = data

if not isinstance(schema, StructType):
struct = self._inferSchema(data, samplingRatio)
if schema is None or isinstance(schema, (list, tuple)):
if isinstance(data, RDD):
struct = self._inferSchema(rdd, samplingRatio)
else:
struct = self._inferSchemaFromList(data)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
schema = struct
converter = _create_converter(schema)
rdd = rdd.map(converter)
else:

elif isinstance(schema, StructType):
# take the first few rows to verify schema
rows = rdd.take(10)
for row in rows:
_verify_type(row, schema)

else:
raise TypeError("schema should be StructType or list or None")

# convert python objects to sql data
converter = _python_to_sql_converter(schema)
rdd = rdd.map(converter)
Expand Down

0 comments on commit 6ea5925

Please sign in to comment.