[SPARK-37839][SQL] DS V2 supports partial aggregate push-down AVG
### What changes were proposed in this pull request?
`max`, `min`, `count`, `sum`, and `avg` are the most commonly used aggregate functions.
Currently, DS V2 supports complete aggregate push-down of `avg`, but supporting partial aggregate push-down of `avg` is also very useful.

The aggregate push-down algorithm is:

1. Spark translates the aggregate expressions and group-by expressions of `Aggregate` into a DS V2 `Aggregation`.
2. Spark calls `supportCompletePushDown` to check whether the source can completely push down the aggregate.
3. If `supportCompletePushDown` returns true, Spark keeps the aggregate expressions as the final aggregate expressions; otherwise it splits each `AVG` into two functions, `SUM` and `COUNT` (see the sketch after this list).
4. Spark translates the final aggregate expressions and the group-by expressions of `Aggregate` into a DS V2 `Aggregation` again, and pushes the `Aggregation` down to the source (e.g. JDBC).
5. Spark constructs the final aggregate on top of the pushed scan.
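
The split in step 3 can be sketched with a few toy case classes (illustrative names only, not Spark's catalyst nodes; the real rewrite lives in `V2ScanRelationPushDown` below):

```scala
object AvgSplitSketch extends App {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Avg(child: Expr) extends Expr
  case class Sum(child: Expr) extends Expr
  case class Count(child: Expr) extends Expr
  case class Divide(left: Expr, right: Expr) extends Expr

  // When the source cannot completely push down the aggregate,
  // rewrite AVG(c) as SUM(c) / COUNT(c) so both parts can be pushed.
  def splitAvg(e: Expr): Expr = e match {
    case Avg(child) => Divide(Sum(child), Count(child))
    case other => other
  }

  println(splitAvg(Avg(Col("c1")))) // Divide(Sum(Col(c1)),Count(Col(c1)))
}
```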

### Why are the changes needed?
So that DS V2 supports partial aggregate push-down of `AVG`.

### Does this PR introduce _any_ user-facing change?
Yes. DS V2 can now partially push down `AVG`.

### How was this patch tested?
New tests.

Closes apache#35130 from beliefer/SPARK-37839.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and chenzhx committed Mar 30, 2022
1 parent 0e50f11 commit 08989f0
Showing 8 changed files with 250 additions and 72 deletions.
org/apache/spark/sql/connector/expressions/aggregate/Avg.java (new file)
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* An aggregate function that returns the mean of all the values in a group.
*
* @since 3.3.0
*/
@Evolving
public final class Avg implements AggregateFunc {
  private final NamedReference column;
  private final boolean isDistinct;

  public Avg(NamedReference column, boolean isDistinct) {
    this.column = column;
    this.isDistinct = isDistinct;
  }

  public NamedReference column() { return column; }
  public boolean isDistinct() { return isDistinct; }

  @Override
  public String toString() {
    if (isDistinct) {
      return "AVG(DISTINCT " + column.describe() + ")";
    } else {
      return "AVG(" + column.describe() + ")";
    }
  }
}
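
For illustration, a small usage sketch of the new node through the public connector API (the column name is made up):

```scala
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.Avg

object AvgToStringDemo extends App {
  // Avg#toString is what surfaces in plan output such as PushedAggregates.
  println(new Avg(Expressions.column("salary"), false)) // AVG(salary)
  println(new Avg(Expressions.column("salary"), true))  // AVG(DISTINCT salary)
}
```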
org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -31,7 +31,6 @@
* <p>
* The currently supported SQL aggregate functions:
* <ol>
-  * <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,7 +69,7 @@ case class Average(
case _ => DoubleType
}

-  private lazy val sumDataType = child.dataType match {
+  lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _: YearMonthIntervalType => YearMonthIntervalType()
case _: DayTimeIntervalType => DayTimeIntervalType()
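
The widened visibility lets the push-down rule reuse `Average`'s sum type: for a `DECIMAL(p, s)` input, the sum side is `DECIMAL(p + 10, s)`. A trivial sketch of that widening (the bounding to the maximum precision is elided here):

```scala
import org.apache.spark.sql.types.DecimalType

object SumTypeSketch extends App {
  val (p, s) = (10, 2)                 // input column: DECIMAL(10, 2)
  val sumType = DecimalType(p + 10, s) // sum side:     DECIMAL(20, 2)
  println(sumType)                     // DecimalType(20,2)
}
```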
org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
- import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -717,7 +717,7 @@ object DataSourceStrategy
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference(name), aggregates.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
-       Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
+       Some(new Avg(FieldReference(name), aggregates.isDistinct))
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
@@ -746,6 +746,31 @@
}
}

+ /**
+  * Translate aggregate expressions and group by expressions.
+  *
+  * @return translated aggregation.
+  */
+ protected[sql] def translateAggregation(
+     aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
+
+   def columnAsString(e: Expression): Option[FieldReference] = e match {
+     case PushableColumnWithoutNestedColumn(name) =>
+       Some(FieldReference.column(name).asInstanceOf[FieldReference])
+     case _ => None
+   }
+
+   val translatedAggregates = aggregates.flatMap(translateAggregate)
+   val translatedGroupBys = groupBy.flatMap(columnAsString)
+
+   if (translatedAggregates.length != aggregates.length ||
+       translatedGroupBys.length != groupBy.length) {
+     return None
+   }
+
+   Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
+ }

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
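
For `SELECT AVG(c1) FROM t GROUP BY c2`, the new `translateAggregation` should yield an `Aggregation` equivalent to this hand-built one (a sketch via the public API; column names assumed):

```scala
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg}

object TranslateSketch extends App {
  val aggregation = new Aggregation(
    Array[AggregateFunc](new Avg(Expressions.column("c1"), false)), // aggregate functions
    Array(Expressions.column("c2")))                                // group-by columns
  println(aggregation.aggregateExpressions().mkString(", "))        // AVG(c1)
}
```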
org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
- import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
- import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
- import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+ import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
- import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
- import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
+ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
+ import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@@ -103,38 +101,6 @@ object PushDownUtils extends PredicateHelper {
}
}

- /**
-  * Pushes down aggregates to the data source reader
-  *
-  * @return pushed aggregation.
-  */
- def pushAggregates(
-     scanBuilder: ScanBuilder,
-     aggregates: Seq[AggregateExpression],
-     groupBy: Seq[Expression]): Option[Aggregation] = {
-
-   def columnAsString(e: Expression): Option[FieldReference] = e match {
-     case PushableColumnWithoutNestedColumn(name) =>
-       Some(FieldReference(name).asInstanceOf[FieldReference])
-     case _ => None
-   }
-
-   scanBuilder match {
-     case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
-       val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
-       val translatedGroupBys = groupBy.flatMap(columnAsString)
-
-       if (translatedAggregates.length != aggregates.length ||
-           translatedGroupBys.length != groupBy.length) {
-         return None
-       }
-
-       val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-       Some(agg).filter(r.pushAggregation)
-     case _ => None
-   }
- }

/**
* Pushes down TableSample to the data source Scan
*/
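
With `pushAggregates` removed, the rule now talks to the `SupportsPushDownAggregates` builder directly: translate, ask `supportCompletePushDown`, then call `pushAggregation`. A minimal fake builder sketching that hand-off (assumes `supportCompletePushDown` has a default implementation returning false, i.e. partial-only):

```scala
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
import org.apache.spark.sql.types.StructType

object PushFlowSketch extends App {
  // Toy builder: accepts any aggregation, but never claims complete push-down.
  class ToyBuilder extends SupportsPushDownAggregates {
    override def pushAggregation(aggregation: Aggregation): Boolean = true
    override def build(): Scan = new Scan {
      override def readSchema(): StructType = new StructType()
    }
  }

  val agg = new Aggregation(
    Array[AggregateFunc](new Avg(Expressions.column("c1"), false)),
    Array(Expressions.column("c2")))
  val builder = new ToyBuilder
  println(builder.supportCompletePushDown(agg)) // false => Spark splits AVG
  println(builder.pushAggregation(agg))         // true  => partial push-down
}
```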
org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

- import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
- import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
+ import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
- import org.apache.spark.sql.types.{DataType, LongType, StructType}
+ import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
import org.apache.spark.sql.util.SchemaUtils._

object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -86,27 +86,68 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
sHolder.builder match {
-       case _: SupportsPushDownAggregates =>
+       case r: SupportsPushDownAggregates =>
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-         var ordinal = 0
-         val aggregates = resultExpressions.flatMap { expr =>
-           expr.collect {
-             // Do not push down duplicated aggregate expressions. For example,
-             // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
-             // `max(a)` to the data source.
-             case agg: AggregateExpression
-                 if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
-               aggExprToOutputOrdinal(agg.canonicalized) = ordinal
-               ordinal += 1
-               agg
-           }
-         }
+         val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
groupingExpressions, sHolder.relation.output)
-         val pushedAggregates = PushDownUtils.pushAggregates(
-           sHolder.builder, normalizedAggregates, normalizedGroupingExpressions)
+         val translatedAggregates = DataSourceStrategy.translateAggregation(
+           normalizedAggregates, normalizedGroupingExpressions)
+         val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
+           if (translatedAggregates.isEmpty ||
+             r.supportCompletePushDown(translatedAggregates.get) ||
+             translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
+             (resultExpressions, aggregates, translatedAggregates)
+           } else {
+             // scalastyle:off
+             // The data source doesn't support the complete push-down of this aggregation.
+             // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+             // pushed, completely or partially.
+             // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+             // SELECT avg(c1) FROM t GROUP BY c2;
+             // The original logical plan is
+             // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+             // +- ScanOperation[...]
+             //
+             // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
+             // we have the following
+             // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+             // +- ScanOperation[...]
+             // scalastyle:on
+             val newResultExpressions = resultExpressions.map { expr =>
+               expr.transform {
+                 case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+                   val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+                   val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+                   // Closely follow `Average.evaluateExpression`
+                   avg.dataType match {
+                     case _: YearMonthIntervalType =>
+                       If(EqualTo(count, Literal(0L)),
+                         Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
+                     case _: DayTimeIntervalType =>
+                       If(EqualTo(count, Literal(0L)),
+                         Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
+                     case _ =>
+                       // TODO deal with the overflow issue
+                       Divide(addCastIfNeeded(sum, avg.dataType),
+                         addCastIfNeeded(count, avg.dataType), false)
+                   }
+               }
+             }.asInstanceOf[Seq[NamedExpression]]
+             // Because aggregate expressions changed, translate them again.
+             aggExprToOutputOrdinal.clear()
+             val newAggregates =
+               collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+             val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
+               newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+             (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
+               newNormalizedAggregates, normalizedGroupingExpressions))
+           }
+         }
+
+         val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
Expand All @@ -129,7 +170,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
-           assert(newOutput.length == groupingExpressions.length + aggregates.length)
+           assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
@@ -164,7 +205,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectExpressions, scanRelation)
} else {
val plan = Aggregate(
-             output.take(groupingExpressions.length), resultExpressions, scanRelation)
+             output.take(groupingExpressions.length), finalResultExpressions, scanRelation)

// scalastyle:off
// Change the optimized logical plan to reflect the pushed down aggregate
@@ -210,16 +251,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

+ private def collectAggregates(resultExpressions: Seq[NamedExpression],
+     aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
+   var ordinal = 0
+   resultExpressions.flatMap { expr =>
+     expr.collect {
+       // Do not push down duplicated aggregate expressions. For example,
+       // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+       // `max(a)` to the data source.
+       case agg: AggregateExpression
+           if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+         aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+         ordinal += 1
+         agg
+     }
+   }
+ }

private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
}

- private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
-   if (aggAttribute.dataType == aggDataType) {
-     aggAttribute
+ private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
+   if (expression.dataType == expectedDataType) {
+     expression
   } else {
-     Cast(aggAttribute, aggDataType)
+     Cast(expression, expectedDataType)
}

def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
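
Putting it together, a hedged end-to-end illustration against a JDBC source (URL, table, and column names are hypothetical, and the exact `PushedAggregates` rendering may vary):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg

object PartialAvgPushDownDemo extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  val df = spark.read
    .format("jdbc")
    .option("url", "jdbc:h2:mem:testdb") // hypothetical H2 database
    .option("dbtable", "employee")       // hypothetical table (DEPT, SALARY)
    .option("pushDownAggregate", "true") // aggregate push-down is opt-in for JDBC
    .load()

  // Against a source that only supports partial aggregate push-down, the scan
  // should report something like PushedAggregates: [SUM(SALARY), COUNT(SALARY)],
  // with Spark computing the final sum / count division itself.
  df.groupBy("DEPT").agg(avg("SALARY")).explain()
}
```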
org/apache/spark/sql/jdbc/JdbcDialect.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
- import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -217,10 +217,11 @@ abstract class JdbcDialect extends Serializable with Logging{
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some("COUNT(*)")
-     case f: GeneralAggregateFunc if f.name() == "AVG" =>
-       assert(f.inputs().length == 1)
-       val distinct = if (f.isDistinct) "DISTINCT " else ""
-       Some(s"AVG($distinct${f.inputs().head})")
+     case avg: Avg =>
+       if (avg.column.fieldNames.length != 1) return None
+       val distinct = if (avg.isDistinct) "DISTINCT " else ""
+       val column = quoteIdentifier(avg.column.fieldNames.head)
+       Some(s"AVG($distinct$column)")
case _ => None
}
}
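
Assuming the case above sits in `JdbcDialect#compileAggregate`, a dialect should now render the node roughly like this (H2 URL and column name are made up):

```scala
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.Avg
import org.apache.spark.sql.jdbc.JdbcDialects

object CompileAvgSketch extends App {
  val dialect = JdbcDialects.get("jdbc:h2:mem:test")
  println(dialect.compileAggregate(new Avg(Expressions.column("salary"), true)))
  // Expected: Some(AVG(DISTINCT "salary")); the quoting comes from quoteIdentifier
}
```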
[Diff for the eighth changed file did not load and is not shown.]
