Skip to content

Commit

Permalink
fix agg
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 21, 2015
1 parent 73a90cb commit 5b5786d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ class Analyzer(
case Aggregate(groups, aggs, child)
if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
Aggregate(groups, assignAliases(aggs), child)

case g: GroupingAnalytics
if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
g.withNewAggs(assignAliases(g.aggregations))

case Project(projectList, child)
if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
Project(assignAliases(projectList), child)
Expand Down Expand Up @@ -267,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(child = f.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateArray(args), name) if containsStar(args) =>
UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ object ExtractValue {
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)

case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)

case (_: MapType, _) =>
GetMapValue(child, extraction)

case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode {
def aggregations: Seq[NamedExpression]

override def output: Seq[Attribute] = aggregations.map(_.toAttribute)

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}

/**
Expand All @@ -266,7 +268,11 @@ case class GroupingSets(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}

/**
* Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
Expand All @@ -284,7 +290,11 @@ case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}

/**
* Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
Expand All @@ -303,7 +313,11 @@ case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}

case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
Expand Down

0 comments on commit 5b5786d

Please sign in to comment.