Skip to content

Commit

Permalink
[SPARK-1442] [SQL] Window Function Support for Spark SQL
Browse files Browse the repository at this point in the history
Adding more information about the implementation...

This PR is adding the support of window functions to Spark SQL (specifically OVER and WINDOW clause). For every expression having a OVER clause, we use a WindowExpression as the container of a WindowFunction and the corresponding WindowSpecDefinition (the definition of a window frame, i.e. partition specification, order specification, and frame specification appearing in a OVER clause).
# Implementation #
The high level work flow of the implementation is described as follows.

*	Query parsing: In the query parse process, all WindowExpressions are originally placed in the projectList of a Project operator or the aggregateExpressions of an Aggregate operator. It makes our changes to simple and keep all of parsing rules for window functions at a single place (nodesToWindowSpecification). For the WINDOWclause in a query, we use a WithWindowDefinition as the container as the mapping from the name of a window specification to a WindowSpecDefinition. This changes is similar with our common table expression support.

*	Analysis: The query analysis process has three steps for window functions.

 *	Resolve all WindowSpecReferences by replacing them with WindowSpecReferences according to the mapping table stored in the node of WithWindowDefinition.
 *	Resolve WindowFunctions in the projectList of a Project operator or the aggregateExpressions of an Aggregate operator. For this PR, we use Hive's functions for window functions because we will have a major refactoring of our internal UDAFs and it is better to switch our UDAFs after that refactoring work.
 *	Once we have resolved all WindowFunctions, we will use ResolveWindowFunction to extract WindowExpressions from projectList and aggregateExpressions and then create a Window operator for every distinct WindowSpecDefinition. With this choice, at the execution time, we can rely on the Exchange operator to do all of work on reorganizing the table and we do not need to worry about it in the physical Window operator. An example analyzed plan is shown as follows

```
sql("""
SELECT
  year, country, product, sales,
  avg(sales) over(partition by product) avg_product,
  sum(sales) over(partition by country) sum_country
FROM sales
ORDER BY year, country, product
""").explain(true)

== Analyzed Logical Plan ==
Sort [year#34 ASC,country#35 ASC,product#36 ASC], true
 Project [year#34,country#35,product#36,sales#37,avg_product#27,sum_country#28]
  Window [year#34,country#35,product#36,sales#37,avg_product#27], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum(sales#37) WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sum_country#28], WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
   Window [year#34,country#35,product#36,sales#37], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage(sales#37) WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS avg_product#27], WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    Project [year#34,country#35,product#36,sales#37]
     MetastoreRelation default, sales, None
```

*	Query planning: In the process of query planning, we simple generate the physical Window operator based on the logical Window operator. Then, to prepare the executedPlan, the EnsureRequirements rule will add Exchange and Sort operators if necessary. The EnsureRequirements rule will analyze the data properties and try to not add unnecessary shuffle and sort. The physical plan for the above example query is shown below.

```
== Physical Plan ==
Sort [year#34 ASC,country#35 ASC,product#36 ASC], true
 Exchange (RangePartitioning [year#34 ASC,country#35 ASC,product#36 ASC], 200), []
  Window [year#34,country#35,product#36,sales#37,avg_product#27], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum(sales#37) WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sum_country#28], WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
   Exchange (HashPartitioning [country#35], 200), [country#35 ASC]
    Window [year#34,country#35,product#36,sales#37], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage(sales#37) WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS avg_product#27], WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
     Exchange (HashPartitioning [product#36], 200), [product#36 ASC]
      HiveTableScan [year#34,country#35,product#36,sales#37], (MetastoreRelation default, sales, None), None
```

*	Execution time: At execution time, a physical Window operator buffers all rows in a partition specified in the partition spec of a OVER clause. If necessary, it also maintains a sliding window frame. The current implementation tries to buffer the input parameters of a window function according to the window frame to avoid evaluating a row multiple times.

# Future work #

Here are three improvements that are not hard to add:
*	Taking advantage of the window frame specification to reduce the number of rows buffered in the physical Window operator. For some cases, we only need to buffer the rows appearing in the sliding window. But for other cases, we will not be able to reduce the number of rows buffered (e.g. ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING).

*	When aRAGEN frame is used, for <value> PRECEDING and <value> FOLLOWING, it will be great if the <value> part is an expression (we can start with Literal). So, when the data type of ORDER BY expression is a FractionalType, we can support FractionalType as the type <value> (<value> still needs to be evaluated as a positive value).

*	When aRAGEN frame is used, we need to support DateType and TimestampType as the data type of the expression appearing in the order specification. Then, the <value> part of <value> PRECEDING and <value> FOLLOWING can support interval types (once we support them).

This is a joint work with guowei2 and yhuai
Thanks hbutani hvanhovell for his comments
Thanks scwf for his comments and unit tests

Author: Yin Huai <yhuai@databricks.com>

Closes apache#5604 from guowei2/windowImplement and squashes the following commits:

76fe1c8 [Yin Huai] Implementation.
aa2b0ae [Yin Huai] Tests.
  • Loading branch information
yhuai authored and nemccarthy committed Jun 19, 2015
1 parent bbfcc37 commit 92492f3
Show file tree
Hide file tree
Showing 101 changed files with 34,768 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -61,6 +63,7 @@ class Analyzer(
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
Expand Down Expand Up @@ -529,6 +532,203 @@ class Analyzer(
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}

/**
* Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
* aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
* operators for every distinct [[WindowSpecDefinition]].
*
* This rule handles three cases:
* - A [[Project]] having [[WindowExpression]]s in its projectList;
* - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
* - An [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
* clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
* Note: If there is a GROUP BY clause in the query, aggregations and corresponding
* filters (expressions in the HAVING clause) should be evaluated before any
* [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
* evaluated after all [[WindowExpression]]s.
*
* For every case, the transformation works as follows:
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
* it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
* all regular expressions.
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
* 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)

def hasWindowFunction(expr: NamedExpression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
}.isDefined
}

/**
* From a Seq of [[NamedExpression]]s, extract window expressions and
* other regular expressions.
*/
def extract(
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
// First, we simple partition the input expressions to two part, one having
// WindowExpressions and another one without WindowExpressions.
val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)

// Then, we need to extract those regular expressions used in the WindowExpression.
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
// we need to make sure that col1 to col5 are all projected from the child of the Window
// operator.
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
// If a named expression is not in regularExpressions, add extract it and replace it
// with an AttributeReference.
val missingExpr =
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
if (missingExpr.nonEmpty) {
extractedExprBuffer += ne
}
ne.toAttribute
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: Expression =>
// For other expressions, we extract it and replace it with an AttributeReference (with
// an interal column name, e.g. "_w0").
val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
}

// Now, we extract expressions from windowExpressions by using extractExpr.
val newWindowExpressions = windowExpressions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
case wf : WindowFunction =>
val newChildren = wf.children.map(extractExpr(_))
wf.withNewChildren(newChildren)

// Extracts expressions from the partition spec and order spec.
case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
val newPartitionSpec = partitionSpec.map(extractExpr(_))
val newOrderSpec = orderSpec.map { so =>
val newChild = extractExpr(so.child)
so.copy(child = newChild)
}
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression =>
val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
}.asInstanceOf[NamedExpression]
}

(newWindowExpressions, regularExpressions ++ extractedExprBuffer)
}

/**
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
*/
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
// First, we group window expressions based on their Window Spec.
val groupedWindowExpression = windowExpressions.groupBy { expr =>
val windowExpression = expr.find {
case window: WindowExpression => true
case other => false
}.map(_.asInstanceOf[WindowExpression].windowSpec)
windowExpression.getOrElse(
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
}.toSeq

// For every Window Spec, we add a Window operator and set currentChild as the child of it.
var currentChild = child
var i = 0
while (i < groupedWindowExpression.size) {
val (windowSpec, windowExpressions) = groupedWindowExpression(i)
// Set currentChild to the newly created Window operator.
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)

// Move to next WindowExpression.
i += 1
}

// We return the top operator.
currentChild
}

// We have to use transformDown at here to make sure the rule of
// "Aggregate with Having clause" will be triggered.
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
case plan => plan.transformExpressions {
case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
val errorMessage =
s"Window specification $windowName is not defined in the WINDOW clause."
val windowSpecDefinition =
windowDefinitions
.get(windowName)
.getOrElse(failAnalysis(errorMessage))
WindowExpression(c, windowSpecDefinition)
}
}

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
!a.expressions.exists(!_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add a Filter operator for conditions in the Having clause.
val withFilter = Filter(condition, withAggregate)
val withWindow = addWindow(windowExpressions, withFilter)

// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map (_.toAttribute)
Project(finalProjectList, withWindow)

case p: LogicalPlan if !p.childrenResolved => p

// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
!a.expressions.exists(!_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withAggregate)

// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map (_.toAttribute)
Project(finalProjectList, withWindow)

// We only extract Window Expressions after all expressions of the Project
// have been resolved.
case p @ Project(projectList, child)
if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
val (windowExpressions, regularExpressions) = extract(projectList)
// We add a project to get all needed expressions for window expressions from the child
// of the original Project operator.
val withProject = Project(regularExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withProject)

// Finally, generate output columns according to the original projectList.
val finalProjectList = projectList.map (_.toAttribute)
Project(finalProjectList, withWindow)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ trait CheckAnalysis {
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")

case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty =>
// The window spec is not valid.
val reason = windowSpec.validate.get
failAnalysis(s"Window specification $windowSpec is not valid because $reason")
}

operator match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,97 @@ class JoinedRow5 extends Row {
}
}
}

/**
* JIT HACK: Replace with macros
*/
class JoinedRow6 extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _

def this(left: Row, right: Row) = {
this()
row1 = left
row2 = right
}

/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
row2 = r2
this
}

/** Updates this JoinedRow by updating its left base row. Returns itself. */
def withLeft(newLeft: Row): Row = {
row1 = newLeft
this
}

/** Updates this JoinedRow by updating its right base row. Returns itself. */
def withRight(newRight: Row): Row = {
row2 = newRight
this
}

override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)

override def getInt(i: Int): Int =
if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)

override def getLong(i: Int): Long =
if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)

override def getDouble(i: Int): Double =
if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)

override def getBoolean(i: Int): Boolean =
if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)

override def getShort(i: Int): Short =
if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)

override def getByte(i: Int): Byte =
if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)

override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)

override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)

override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)

override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
}

override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
} else if (row1 eq null) {
row2.mkString("[", ",", "]")
} else if (row2 eq null) {
row1.mkString("[", ",", "]")
} else {
mkString("[", ",", "]")
}
}
}
Loading

0 comments on commit 92492f3

Please sign in to comment.