Skip to content

Commit

Permalink
[SPARK-7235] [SQL] Refactor the grouping sets
Browse files Browse the repository at this point in the history
The logical plan `Expand` takes the `output` as constructor argument, which break the references chain. We need to refactor the code, as well as the column pruning.

Author: Cheng Hao <hao.cheng@intel.com>

Closes apache#5780 from chenghao-intel/expand and squashes the following commits:

76e4aa4 [Cheng Hao] revert the change for case insenstive
7c10a83 [Cheng Hao] refactor the grouping sets
  • Loading branch information
chenghao-intel authored and marmbrus committed Jun 23, 2015
1 parent 4f7fbef commit 7b1450b
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,49 +192,17 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
var bit = g.groupByExprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
bit -= 1
}

val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
case x if x == g.gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a: Cube if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case a: Rollup if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case x: GroupingSets if x.resolved =>
case a: Cube =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case a: Rollup =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
Aggregate(
x.groupByExprs :+ x.gid,
x.groupByExprs :+ VirtualColumn.groupingIdAttribute,
x.aggregations,
Expand(expand(x), x.child.output :+ x.gid, x.child))
Expand(x.bitmasks, x.groupByExprs, gid, x.child))
}
}

Expand Down Expand Up @@ -368,12 +336,7 @@ class Analyzer(

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
q.asInstanceOf[GroupingAnalytics].gid
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E

object VirtualColumn {
val groupingIdName: String = "grouping__id"
def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))

// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
Expand Down Expand Up @@ -228,24 +229,76 @@ case class Window(
/**
* Apply the all of the GroupExpressions to every input row, hence we will get
* multiple output rows for a input row.
* @param projections The group of expressions, all of the group expressions should
* output the same schema specified by the parameter `output`
* @param output The output Schema
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param child Child operator
*/
case class Expand(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
gid: Attribute,
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
val sizeInBytes = child.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}

val projections: Seq[Seq[Expression]] = expand()

/**
* Extract attribute set according to the grouping id
* @param bitmask bitmask to represent the selected of the attribute sequence
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
bit -= 1
}

set
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)

val substitution = (child.output :+ gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
case x if x == gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
}

override def output: Seq[Attribute] = {
child.output :+ gid
}
}

trait GroupingAnalytics extends UnaryNode {
self: Product =>
def gid: AttributeReference
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]

Expand All @@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
* The associated output will be one of the value in `bitmasks`
*/
case class GroupingSets(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand All @@ -290,15 +338,11 @@ case class GroupingSets(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
*/
case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand All @@ -313,15 +357,11 @@ case class Cube(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
*/
case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Expand(projections, output, child) =>
execution.Expand(projections, output, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Window(projectList, windowExpressions, spec, child) =>
Expand Down

0 comments on commit 7b1450b

Please sign in to comment.