Skip to content

Commit

Permalink
[SPARK-10389] [SQL] support order by non-attribute grouping expressio…
Browse files Browse the repository at this point in the history
…n on Aggregate

For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#8548 from cloud-fan/support-order-by-non-attribute.
  • Loading branch information
cloud-fan authored and markhamstra committed Oct 9, 2015
1 parent 9a625f3 commit 5a10e10
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -576,43 +576,47 @@ class Analyzer(
filter
}

case sort @ Sort(sortOrder, global,
aggregate @ Aggregate(grouping, originalAggExprs, child))
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved && !sort.resolved =>

// Try resolving the ordering as though it is in the aggregate clause.
try {
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions

// Expressions that have an aggregate can be pushed down.
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)

// Attribute references, that are missing from the order but are present in the grouping
// expressions can also be pushed down.
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
val missingAttributes = requiredAttributes -- aggregate.outputSet
val validPushdownAttributes =
missingAttributes.filter(a => grouping.exists(a.semanticEquals))

// If resolution was successful and we see the ordering either has an aggregate in it or
// it is missing something that is projected away by the aggregate, add the ordering
// the original aggregate operator.
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
}
val aggExprsWithOrdering: Seq[NamedExpression] =
resolvedAggregateOrdering ++ originalAggExprs

Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
} else {
sort
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]

// If we pass the analysis check, then the ordering expressions should only reference to
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)

val originalAggExprs = aggregate.aggregateExpressions.map(
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])

// If the ordering expression is same with original aggregate expression, we don't need
// to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}

if (index == -1) {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
}

Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
Expand All @@ -621,9 +625,7 @@ class Analyzer(
}

protected def containsAggregate(condition: Expression): Boolean = {
condition
.collect { case ae: AggregateExpression => ae }
.nonEmpty
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}

Expand Down
19 changes: 15 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1712,10 +1712,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}

test("SPARK-10130 type coercion for IF should have children resolved first") {
val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
df.registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
withTempTable("src") {
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
}
}

test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
withTempTable("src") {
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
Seq(Row(1), Row(1)))
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
Seq(Row(1), Row(1)))
}
}

test("SortMergeJoin returns wrong results when using UnsafeRows") {
Expand Down

0 comments on commit 5a10e10

Please sign in to comment.