diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index fee7010e8e033..66e99ded24886 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
         // aggregation and a shuffle are added as children.
         val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+        (mergeAgg, createShuffleExchange(
+          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
       case _ =>
         // Ensure that the operator's children satisfy their output distribution requirements:
         val childrenWithDist = operator.children.zip(requiredChildDistributions)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 07efc72bf6296..b0aa3378e5f66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
       s"The plan of query $query does not have partial aggregations.")
   }
 
+  test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
+    withTempView("testSortBasedPartialAggregation") {
+      val schema = StructType(
+        StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
+      val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
+      spark.createDataFrame(rowRDD, schema)
+        .createOrReplaceTempView("testSortBasedPartialAggregation")
+
+      // This test assumes a query below uses sort-based aggregations
+      val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
+        .queryExecution.executedPlan
+      // This line extracts both SortAggregate and Sort operators
+      val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
+      val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
+      assert(extractedOps.size == 4 && aggOps.size == 2,
+        s"The plan $planned does not have correct sort-based partial aggregate pairs.")
+    }
+  }
+
   test("non-partial aggregation for aggregates") {
     withTempView("testNonPartialAggregation") {
       val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)