From 08989f01e1082c20e80e6c67afe9b7ccd0465e7c Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Thu, 20 Jan 2022 12:13:00 +0800
Subject: [PATCH] [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?
`max`, `min`, `count`, `sum`, `avg` are the most commonly used aggregation functions.
Currently, DS V2 supports complete aggregate push-down of `avg`. But supporting partial aggregate push-down of `avg` is very useful as well.

The aggregate push-down algorithm is:
1. Spark translates the group expressions of `Aggregate` to DS V2 `Aggregation`.
2. Spark calls `supportCompletePushDown` to check whether the aggregate can be completely pushed down.
3. If `supportCompletePushDown` returns true, we preserve the aggregate expressions as the final aggregate expressions. Otherwise, we split `AVG` into two functions: `SUM` and `COUNT`.
4. Spark translates the final aggregate expressions and the group expressions of `Aggregate` to DS V2 `Aggregation` again, and pushes the `Aggregation` to the JDBC source.
5. Spark constructs the final aggregate on top of the pushed-down scan.

### Why are the changes needed?
DS V2 should also support partial aggregate push-down of `AVG`.

### Does this PR introduce _any_ user-facing change?
'Yes'. DS V2 can now partially push down `AVG`.

### How was this patch tested?
New tests.

Closes #35130 from beliefer/SPARK-37839.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan
---
 .../connector/expressions/aggregate/Avg.java  |  49 ++++++++
 .../aggregate/GeneralAggregateFunc.java       |   1 -
 .../expressions/aggregate/Average.scala       |   2 +-
 .../datasources/DataSourceStrategy.scala      |  29 ++++-
 .../datasources/v2/PushDownUtils.scala        |  40 +------
 .../v2/V2ScanRelationPushDown.scala           | 108 ++++++++++++++----
 .../apache/spark/sql/jdbc/JdbcDialects.scala  |  11 +-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   |  82 ++++++++++++-
 8 files changed, 250 insertions(+), 72 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
new file mode 100644
index 0000000000000..5e10ec9ee1644
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the mean of all the values in a group.
+ * + * @since 3.3.0 + */ +@Evolving +public final class Avg implements AggregateFunc { + private final NamedReference column; + private final boolean isDistinct; + + public Avg(NamedReference column, boolean isDistinct) { + this.column = column; + this.isDistinct = isDistinct; + } + + public NamedReference column() { return column; } + public boolean isDistinct() { return isDistinct; } + + @Override + public String toString() { + if (isDistinct) { + return "AVG(DISTINCT " + column.describe() + ")"; + } else { + return "AVG(" + column.describe() + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 32615e201643b..0ff26c8875b7a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -31,7 +31,6 @@ *

 * The currently supported SQL aggregate functions:
 * <ol>
- *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
 *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
 *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
 *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
 *  <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
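For readers implementing a connector against this API, here is a minimal, hypothetical sketch (not part of this patch) of how a `SupportsPushDownAggregates` scan builder interacts with the new `Avg` class and the `supportCompletePushDown` check described in the PR description. The class name `MyScanBuilder` and its behavior are illustrative assumptions only.

```scala
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

// Hypothetical connector-side scan builder, for illustration only.
class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushed: Option[Aggregation] = None

  // Step 2 of the push-down algorithm: Spark asks whether the whole aggregate
  // can be evaluated by the source. If this returns false and the aggregate
  // contains AVG, Spark rewrites AVG into SUM/COUNT before calling
  // pushAggregation, so only a partial aggregate is pushed down.
  override def supportCompletePushDown(aggregation: Aggregation): Boolean = false

  // Step 4: Spark offers the (possibly rewritten) aggregation for push-down.
  override def pushAggregation(aggregation: Aggregation): Boolean = {
    val allSupported = aggregation.aggregateExpressions.forall {
      case _: Min | _: Max | _: Sum | _: Count | _: CountStar | _: Avg => true
      case _ => false
    }
    if (allSupported) pushed = Some(aggregation)
    allSupported
  }

  // A real implementation would build a Scan whose readSchema() lists the
  // group-by columns followed by one column per pushed aggregate function.
  override def build(): Scan = throw new UnsupportedOperationException("sketch only")
}
```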
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 9714a096a69a2..05f7edaeb5d48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,7 +69,7 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = child.dataType match {
+  lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
     case _: YearMonthIntervalType => YearMonthIntervalType()
     case _: DayTimeIntervalType => DayTimeIntervalType()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 990a00ca918fb..1934ef9f03228 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -717,7 +717,7 @@ object DataSourceStrategy
       case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new Sum(FieldReference(name), aggregates.isDistinct))
       case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
-        Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
+        Some(new Avg(FieldReference(name), aggregates.isDistinct))
       case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name))))
       case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
@@ -746,6 +746,31 @@
     }
   }
 
+  /**
+   * Translate aggregate expressions and group by expressions.
+   *
+   * @return translated aggregation.
+ */ + protected[sql] def translateAggregation( + aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case PushableColumnWithoutNestedColumn(name) => + Some(FieldReference.column(name).asInstanceOf[FieldReference]) + case _ => None + } + + val translatedAggregates = aggregates.flatMap(translateAggregate) + val translatedGroupBys = groupBy.flatMap(columnAsString) + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)) + } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2b26eee45221d..b54917e49ed3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -103,38 +101,6 @@ object PushDownUtils extends PredicateHelper { } } - /** - * Pushes down aggregates to the data source reader - * - * @return pushed aggregation. 
- */ - def pushAggregates( - scanBuilder: ScanBuilder, - aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Option[Aggregation] = { - - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference(name).asInstanceOf[FieldReference]) - case _ => None - } - - scanBuilder match { - case r: SupportsPushDownAggregates if aggregates.nonEmpty => - val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) - - if (translatedAggregates.length != aggregates.length || - translatedGroupBys.length != groupBy.length) { - return None - } - - val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray) - Some(agg).filter(r.pushAggregation) - case _ => None - } - } - /** * Pushes down TableSample to the data source Scan */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 67002e50e4680..05857c545cdf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { @@ -86,27 +86,68 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => sHolder.builder match { - case _: SupportsPushDownAggregates => + case r: SupportsPushDownAggregates => val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] 
- var ordinal = 0 - val aggregates = resultExpressions.flatMap { expr => - expr.collect { - // Do not push down duplicated aggregate expressions. For example, - // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one - // `max(a)` to the data source. - case agg: AggregateExpression - if !aggExprToOutputOrdinal.contains(agg.canonicalized) => - aggExprToOutputOrdinal(agg.canonicalized) = ordinal - ordinal += 1 - agg - } - } + val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( groupingExpressions, sHolder.relation.output) - val pushedAggregates = PushDownUtils.pushAggregates( - sHolder.builder, normalizedAggregates, normalizedGroupingExpressions) + val translatedAggregates = DataSourceStrategy.translateAggregation( + normalizedAggregates, normalizedGroupingExpressions) + val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + if (translatedAggregates.isEmpty || + r.supportCompletePushDown(translatedAggregates.get) || + translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { + (resultExpressions, aggregates, translatedAggregates) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = resultExpressions.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + // Closely follow `Average.evaluateExpression` + avg.dataType match { + case _: YearMonthIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count)) + case _: DayTimeIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) + case _ => + // TODO deal with the overflow issue + Divide(addCastIfNeeded(sum, avg.dataType), + addCastIfNeeded(count, avg.dataType), false) + } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. 
+ aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( + newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] + (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( + newNormalizedAggregates, normalizedGroupingExpressions)) + } + } + + val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { aggNode // return original plan node } else if (!supportPartialAggPushDown(pushedAggregates.get) && @@ -129,7 +170,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + aggregates.length) + assert(newOutput.length == groupingExpressions.length + finalAggregates.length) val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b @@ -164,7 +205,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectExpressions, scanRelation) } else { val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) + output.take(groupingExpressions.length), finalResultExpressions, scanRelation) // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate @@ -210,16 +251,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def collectAggregates(resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression + if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. 
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) } - private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = - if (aggAttribute.dataType == aggDataType) { - aggAttribute + private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) = + if (expression.dataType == expectedDataType) { + expression } else { - Cast(aggAttribute, aggDataType) + Cast(expression, expectedDataType) } def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e516960bb6746..7456b390c616e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -217,10 +217,11 @@ abstract class JdbcDialect extends Serializable with Logging{ Some(s"SUM($distinct$column)") case _: CountStar => Some("COUNT(*)") - case f: GeneralAggregateFunc if f.name() == "AVG" => - assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"AVG($distinct${f.inputs().head})") + case avg: Avg => + if (avg.column.fieldNames.length != 1) return None + val distinct = if (avg.isDistinct) "DISTINCT " else "" + val column = quoteIdentifier(avg.column.fieldNames.head) + Some(s"AVG($distinct$column)") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 72dde8fa13222..637e01c260c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{lit, sum, udf} +import org.apache.spark.sql.functions.{avg, count, lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -831,4 +831,84 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1))) } + + test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + 
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } }
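
As an end-to-end illustration (not part of the patch), the snippet below mirrors the partial push-down test above. It assumes an H2 catalog named `h2` registered the way `JDBCV2Suite` does in its test harness (a `JDBCTableCatalog` plus a JDBC URL, which is not shown in this diff) and the same `h2.test.employee` table; with more than one partition the source only receives `SUM` and `COUNT`, and Spark computes the final `AVG` from them.

```scala
// Illustrative usage, assuming the JDBCV2Suite-style "h2" catalog and test data.
import org.apache.spark.sql.functions.avg
import spark.implicits._

val df = spark.read
  .option("partitionColumn", "dept")
  .option("lowerBound", "0")
  .option("upperBound", "2")
  .option("numPartitions", "2")   // more than one partition, so complete push-down is impossible
  .table("h2.test.employee")
  .agg(avg($"SALARY").as("avg"))

// The optimized plan should contain only the partial aggregates, e.g.
//   PushedAggregates: [SUM(SALARY), COUNT(SALARY)]
// while the final division that produces AVG stays in Spark.
df.explain(true)

df.show()   // for the test data above this still returns the exact average, 10600.000000
```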