diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6311784422a91..0a3f5a7b5cade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -192,49 +192,17 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). - */ - private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - g.bitmasks.foreach { bitmask => - // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprs = ArrayBuffer.empty[Expression] - var bit = g.groupByExprs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit) - bit -= 1 - } - - val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined => - // if the input attribute in the Invalid Grouping Expression set of for this group - // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == g.gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - - result += substitution - } - - result.toSeq - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a: Cube if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case a: Rollup if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case x: GroupingSets if x.resolved => + case a: Cube => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case a: Rollup => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case x: GroupingSets => + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() Aggregate( - x.groupByExprs :+ x.gid, + x.groupByExprs :+ VirtualColumn.groupingIdAttribute, x.aggregations, - Expand(expand(x), x.child.output :+ x.gid, x.child)) + Expand(x.bitmasks, x.groupByExprs, gid, x.child)) } } @@ -368,12 +336,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && - resolver(nameParts(0), VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics - q.asInstanceOf[GroupingAnalytics].gid + q transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 58dbeaf89cad5..9cacdceb13837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E object VirtualColumn { val groupingIdName: String = "grouping__id" - def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9132a786f77a7..98b4476076854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) + if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b009a200b920f..a505fa5df34e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -207,7 +207,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) + Option((attribute, nameParts.tail.toList)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7814e51628db6..fae339808c233 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashSet case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -228,24 +229,76 @@ case class Window( /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for a input row. - * @param projections The group of expressions, all of the group expressions should - * output the same schema specified by the parameter `output` - * @param output The output Schema + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions * @param child Child operator */ case class Expand( - projections: Seq[Seq[Expression]], - output: Seq[Attribute], + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + val projections: Seq[Seq[Expression]] = expand() + + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] + + bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + + val substitution = (child.output :+ gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal.create(null, expr.dataType) + case x if x == gid => + // replace the groupingId with concrete value (the bit mask) + Literal.create(bitmask, IntegerType) + }) + + result += substitution + } + + result.toSeq + } + + override def output: Seq[Attribute] = { + child.output :+ gid + } } trait GroupingAnalytics extends UnaryNode { self: Product => - def gid: AttributeReference def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] @@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode { * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. - * The associated output will be one of the value in `bitmasks` */ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) @@ -290,15 +338,11 @@ case class GroupingSets( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) @@ -313,15 +357,11 @@ case class Cube( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c420eb9d761f..1ff1cc224de8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case logical.Expand(projections, output, child) => - execution.Expand(projections, output, planLater(child)) :: Nil + case e @ logical.Expand(_, _, _, child) => + execution.Expand(e.projections, e.output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Window(projectList, windowExpressions, spec, child) =>