From 5b5786d6dc2c672e450e511f5a7b6424f21d3377 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 22 Jun 2015 02:14:27 +0800
Subject: [PATCH] fix agg

---
 .../sql/catalyst/analysis/Analyzer.scala           | 17 ++++++++++------
 .../sql/catalyst/analysis/unresolved.scala         |  1 -
 .../catalyst/expressions/ExtractValue.scala        |  4 ++++
 .../plans/logical/basicOperators.scala             | 20 ++++++++++++++++---
 4 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8b5554a155e89..39aa32b954043 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -154,6 +154,11 @@ class Analyzer(
       case Aggregate(groups, aggs, child)
         if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
         Aggregate(groups, assignAliases(aggs), child)
+
+      case g: GroupingAnalytics
+        if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
+        g.withNewAggs(assignAliases(g.aggregations))
+
       case Project(projectList, child)
         if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
         Project(assignAliases(projectList), child)
@@ -267,24 +272,24 @@ class Analyzer(
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child.output, resolver)
-            case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
+            case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
               }
-              Alias(child = f.copy(children = expandedArgs), name)() :: Nil
-            case Alias(c @ CreateArray(args), name) if containsStar(args) =>
+              UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+            case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
               }
-              Alias(c.copy(children = expandedArgs), name)() :: Nil
-            case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
+              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
+            case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
               }
-              Alias(c.copy(children = expandedArgs), name)() :: Nil
+              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
             case o => o :: Nil
           },
           child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c51c4dc83f51a..ae3adbab05108 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.{errors, trees}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 9832207ee940c..013027b199e63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -45,14 +45,18 @@ object ExtractValue {
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
         GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+
       case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
         GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+
       case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
         GetArrayItem(child, extraction)
+
       case (_: MapType, _) =>
         GetMapValue(child, extraction)
+
       case (otherType, _) =>
         val errorMsg = otherType match {
           case StructType(_) | ArrayType(StructType(_), _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 963c7820914f3..f8e5916d69f9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode {
   def aggregations: Seq[NamedExpression]
 
   override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
+
+  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
 }
 
 /**
@@ -266,7 +268,11 @@ case class GroupingSets(
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
     aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+    this.copy(aggregations = aggs)
+}
 
 /**
  * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -284,7 +290,11 @@ case class Cube(
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
     aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+    this.copy(aggregations = aggs)
+}
 
 /**
  * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -303,7 +313,11 @@ case class Rollup(
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
     aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+    this.copy(aggregations = aggs)
+}
 
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
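
Note on the pattern this patch applies: `ResolveAliases` previously handled only `Aggregate` and `Project`, so `UnresolvedAlias` expressions inside `Cube`, `Rollup`, and `GroupingSets` were never rewritten. The new `withNewAggs` hook on `GroupingAnalytics` lets one analyzer case rebuild any grouping operator with fresh aggregations while preserving its concrete node type. Below is a minimal, self-contained Scala sketch of that polymorphic-copy idea; the toy `NamedExpr`, `UnresolvedAlias`, and `Alias` classes are hypothetical stand-ins for Catalyst's expressions, not Spark's real API.

object WithNewAggsSketch {

  sealed trait NamedExpr
  case class UnresolvedAlias(child: String) extends NamedExpr
  case class Alias(child: String, name: String) extends NamedExpr

  trait GroupingAnalytics {
    def aggregations: Seq[NamedExpr]
    // Polymorphic copy: each concrete operator rebuilds itself with new
    // aggregations, so the rule below never matches on Cube/Rollup directly.
    def withNewAggs(aggs: Seq[NamedExpr]): GroupingAnalytics
  }

  case class Cube(aggregations: Seq[NamedExpr]) extends GroupingAnalytics {
    def withNewAggs(aggs: Seq[NamedExpr]): GroupingAnalytics = copy(aggregations = aggs)
  }

  case class Rollup(aggregations: Seq[NamedExpr]) extends GroupingAnalytics {
    def withNewAggs(aggs: Seq[NamedExpr]): GroupingAnalytics = copy(aggregations = aggs)
  }

  // Analyzer-style rule: replace unresolved aliases with concrete ones.
  def assignAliases(aggs: Seq[NamedExpr]): Seq[NamedExpr] = aggs.map {
    case UnresolvedAlias(child) => Alias(child, child)
    case other                  => other
  }

  // One case covers every GroupingAnalytics subclass, mirroring the single
  // `case g: GroupingAnalytics` added to ResolveAliases in this patch.
  def resolveAliases(plan: GroupingAnalytics): GroupingAnalytics =
    if (plan.aggregations.exists(_.isInstanceOf[UnresolvedAlias]))
      plan.withNewAggs(assignAliases(plan.aggregations))
    else plan

  def main(args: Array[String]): Unit = {
    val cube = Cube(Seq(UnresolvedAlias("sum(a)")))
    // Prints Cube(List(Alias(sum(a),sum(a)))): the node type is preserved.
    println(resolveAliases(cube))
  }
}

Because each case class delegates to its own generated `copy`, adding a new grouping operator only requires implementing `withNewAggs` once; the analyzer rule stays closed against new subclasses.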