Commit 552eba4
fix python
cloud-fan committed Jun 22, 2015
1 parent 5b5786d commit 552eba4
Showing 6 changed files with 25 additions and 12 deletions.
9 changes: 5 additions & 4 deletions python/pyspark/sql/context.py
@@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None):
 >>> df.registerTempTable("allTypes")
 >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
 ... 'from allTypes where b and i > 0').collect()
-[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
+    time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
 >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
 [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
 """
@@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()):
 >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
 >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-[Row(c0=u'4')]
+[Row(_c0=u'4')]
 >>> from pyspark.sql.types import IntegerType
 >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
 >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-[Row(c0=4)]
+[Row(_c0=4)]
 >>> from pyspark.sql.types import IntegerType
 >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
 >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-[Row(c0=4)]
+[Row(_c0=4)]
 """
 func = lambda _, it: map(lambda x: f(*x), it)
 ser = AutoBatchedSerializer(PickleSerializer())
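
These doctest updates track the new naming convention for unnamed projections: auto-generated column names are now prefixed with an underscore (`c0` becomes `_c0`), following Hive's convention, and the name encodes the item's position in the select list. As a rough illustration, here is a minimal, self-contained Scala sketch of positional auto-aliasing; the types `Named`/`Unnamed` and the function `autoAlias` are invented for this example, not Spark's classes:

// Toy model of positional auto-aliasing for a select list: unnamed
// expressions get "_c<index>" names, matching the doctest output above.
sealed trait Expr
case class Named(name: String) extends Expr   // e.g. a plain column reference
case class Unnamed(sql: String) extends Expr  // e.g. "i + 1" or "not b"

def autoAlias(selectList: Seq[Expr]): Seq[String] =
  selectList.zipWithIndex.map {
    case (Named(name), _) => name     // keep user-visible names as-is
    case (Unnamed(_), i)  => s"_c$i"  // fall back to a positional name
  }

// autoAlias(Seq(Unnamed("i + 1"), Unnamed("d + 1"), Named("time")))
// returns Seq("_c0", "_c1", "time")
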
2 changes: 1 addition & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -68,11 +68,11 @@ class Analyzer(
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
-      ResolveAliases ::
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
+      ResolveAliases ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
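
Within a fixed-point batch, rules are applied in declared order on every pass, so moving `ResolveAliases` after `ResolveFunctions` lets alias assignment see function calls that have already been resolved; one plausible reading is that default names can only be derived once the wrapped expression is resolved. Below is a minimal sketch of such a batch runner, modeled loosely on the shape of Catalyst's RuleExecutor; the `Rule` trait and `runFixedPoint` are illustrative, not Spark's API:

// Applies the rules in declared order, repeating until the plan stops
// changing or an iteration budget is exhausted (a fixed point).
trait Rule[T] { def apply(plan: T): T }

def runFixedPoint[T](rules: Seq[Rule[T]], maxIterations: Int)(plan: T): T = {
  var current = plan
  var changed = true
  var iteration = 0
  while (changed && iteration < maxIterations) {
    val next = rules.foldLeft(current)((p, rule) => rule(p)) // declared order
    changed = next != current
    current = next
    iteration += 1
  }
  current
}
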
12 changes: 10 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -629,7 +629,15 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = {
     val namedExpressions = cols.map {
-      case Column(expr: Expression) => UnresolvedAlias(expr)
+      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+      // make it a NamedExpression.
+      case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
+      case Column(expr: NamedExpression) => expr
+      // Leave an unaliased explode with an empty list of names since the analyzer will generate the
+      // correct defaults after the nested expression's type has been resolved.
+      case Column(explode: Explode) => MultiAlias(explode, Nil)
+      case Column(expr: Expression) => Alias(expr, expr.prettyString)()
     }
     // When users continuously call `select`, speed up analysis by collapsing `Project`
     import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
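
The blanket `UnresolvedAlias` wrapping is replaced by a four-way dispatch: defer naming for unresolved attributes, keep existing names, leave `explode` to be named after type resolution, and eagerly alias everything else. Here is a self-contained Scala sketch of that dispatch; the case-class names mirror Catalyst's, but this is a simplified toy model (`RawExpr` and `toNamed` are invented), not Spark's code:

sealed trait Expression
sealed trait NamedExpression extends Expression
case class UnresolvedAttribute(name: String) extends Expression
case class Explode(childSql: String) extends Expression
case class RawExpr(sql: String) extends Expression // stand-in for any other expression
case class Alias(child: Expression, name: String) extends NamedExpression
case class UnresolvedAlias(child: Expression) extends NamedExpression
case class MultiAlias(child: Expression, names: Seq[String]) extends NamedExpression

def toNamed(expr: Expression): NamedExpression = expr match {
  // 1. Defer naming: resolving the attribute may strip intermediate aliases
  //    on an ExtractValue chain, so the analyzer must re-alias it afterwards.
  case u: UnresolvedAttribute => UnresolvedAlias(u)
  // 2. Already named: pass through untouched.
  case n: NamedExpression => n
  // 3. explode() yields multiple output columns; names are defaulted only
  //    after the nested expression's type is known.
  case e: Explode => MultiAlias(e, Nil)
  // 4. Everything else is eagerly aliased with its pretty-printed form.
  case RawExpr(sql) => Alias(RawExpr(sql), sql)
}

// toNamed(UnresolvedAttribute("row.a")) returns
// UnresolvedAlias(UnresolvedAttribute("row.a"))
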
6 changes: 5 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
 import scala.language.implicitConversions

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
@@ -78,6 +78,10 @@ class GroupedData protected[sql](
     }

     val aliasedAgg = aggregates.map {
+      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+      // make it a NamedExpression.
+      case u: UnresolvedAttribute => UnresolvedAlias(u)
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
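
The same wrapping now applies to aggregates built by `agg`. Below is a hypothetical spark-shell session showing the kind of query this supports; the DataFrame, its JSON layout, and the column names are assumptions for illustration:

import org.apache.spark.sql.functions.max

// A record with a nested field: aggregating "nested.value" parses to an
// UnresolvedAttribute over an ExtractValue chain.
val df = sqlContext.read.json(sc.parallelize(
  """{"key": "k", "nested": {"value": 1}}""" :: Nil))

// Without the UnresolvedAlias wrapping, resolving "nested.value" would drop
// the intermediate Alias, leaving an aggregate that is not a NamedExpression.
df.groupBy("key").agg(max("nested.value")).show()
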
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan

-    case plan: LogicalPlan =>
+    case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
       val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
       if (udfs.isEmpty) {
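
The added `if plan.resolved` guard makes UDF extraction wait until the analyzer has fully resolved an operator, instead of firing on half-analyzed plans. A toy sketch of the guard pattern with simplified shapes (`Node` and this `extractPythonUdfs` are invented for illustration, not Spark's types):

// A rule that fires only on resolved nodes; unresolved plans are left for a
// later pass rather than having UDFs pulled out of unresolved expressions.
case class Node(resolved: Boolean, pythonUdfs: Seq[String])

def extractPythonUdfs(plan: Node): Node = plan match {
  case p if !p.resolved          => p // not resolved yet: skip this pass
  case p if p.pythonUdfs.isEmpty => p // nothing to extract
  case p =>
    // Spark would push UDF evaluation into a separate EvaluatePython node;
    // here we just mark the UDFs as extracted.
    p.copy(pythonUdfs = Nil)
}
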
6 changes: 3 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {

   test("SPARK-6145: special cases") {
     sqlContext.read.json(sqlContext.sparkContext.makeRDD(
-      """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
-    checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
-    checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
+      """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+    checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
+    checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
   }

   test("SPARK-6898: complete support for special chars in column names") {
