Skip to content

Commit

Permalink
[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial …
Browse files Browse the repository at this point in the history
…aggregations

## What changes were proposed in this pull request?
Partial aggregations are generated in `EnsureRequirements`, but the planner fails to
check if partial aggregation satisfies sort requirements.
For the following query:
```
val df2 = (0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2")
spark.sql("select max(b) from t2 group by a").explain(true)
```
Now, the SortAggregator won't insert Sort operator before partial aggregation, this will break sort-based partial aggregation.
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)alteryx#17])
+- *Sort [a#5 ASC], false, 0
   +- Exchange hashpartitioning(a#5, 200)
      +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
         +- LocalTableScan [a#5, b#6]
```
Actually, a correct plan is:
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)alteryx#17])
+- *Sort [a#5 ASC], false, 0
   +- Exchange hashpartitioning(a#5, 200)
      +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
         +- *Sort [a#5 ASC], false, 0
            +- LocalTableScan [a#5, b#6]
```

## How was this patch tested?
Added tests in `PlannerSuite`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes apache#14865 from maropu/SPARK-17289.
  • Loading branch information
maropu authored and liancheng committed Aug 30, 2016
1 parent 8fb445d commit 94922d7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// If an aggregation needs a shuffle and support partial aggregations, a map-side partial
// aggregation and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
(mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
(mergeAgg, createShuffleExchange(
requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
case _ =>
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
Expand Down Expand Up @@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}

test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
withTempView("testSortBasedPartialAggregation") {
val schema = StructType(
StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
spark.createDataFrame(rowRDD, schema)
.createOrReplaceTempView("testSortBasedPartialAggregation")

// This test assumes a query below uses sort-based aggregations
val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
.queryExecution.executedPlan
// This line extracts both SortAggregate and Sort operators
val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
assert(extractedOps.size == 4 && aggOps.size == 2,
s"The plan $planned does not have correct sort-based partial aggregate pairs.")
}
}

test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
Expand Down

0 comments on commit 94922d7

Please sign in to comment.