[SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions

### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down only supports grouping by columns.
However, SQL like the example below is very useful and common.
```
SELECT
  CASE
    WHEN "SALARY" > 8000.00
      AND "SALARY" < 10000.00
    THEN "SALARY"
    ELSE 0.00
  END AS key,
  SUM("SALARY")
FROM "test"."employee"
GROUP BY key
```
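
With this change, such a query can be pushed down even though the grouping key is an expression rather than a plain column. A minimal usage sketch against a JDBC source (the SparkSession setup, JDBC URL, and table name below are assumptions for illustration; `pushDownAggregate` is the existing JDBC data source option that enables aggregate push-down):
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{expr, sum}

val spark = SparkSession.builder().master("local[*]").appName("agg-pushdown-sketch").getOrCreate()

// Assumed connection details; replace with a real database and URL.
val employee = spark.read
  .format("jdbc")
  .option("url", "jdbc:h2:mem:testdb")
  .option("dbtable", "\"test\".\"employee\"")
  .option("pushDownAggregate", "true")
  .load()

// Group by an expression instead of a plain column; with this patch the whole aggregate,
// including the CASE WHEN grouping key, can be pushed to the data source.
val key = expr("CASE WHEN SALARY > 8000.00 AND SALARY < 10000.00 THEN SALARY ELSE 0.00 END")
val grouped = employee.groupBy(key.as("key")).agg(sum("SALARY"))

// The scan node should now report the pushed grouping key under PushedGroupByExpressions
// (previously PushedGroupByColumns).
grouped.explain()
```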

### Why are the changes needed?
Let DS V2 aggregate push-down support group by expressions.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New tests

Closes apache#36325 from beliefer/SPARK-38997.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and chenzhx committed May 17, 2022
1 parent 4da5af0 commit 089e2ad
Showing 10 changed files with 208 additions and 74 deletions.
@@ -20,7 +20,7 @@
import java.io.Serializable;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.Expression;

/**
* Aggregation in SQL statement.
@@ -30,14 +30,14 @@
@Evolving
public final class Aggregation implements Serializable {
private final AggregateFunc[] aggregateExpressions;
private final NamedReference[] groupByColumns;
private final Expression[] groupByExpressions;

public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) {
public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) {
this.aggregateExpressions = aggregateExpressions;
this.groupByColumns = groupByColumns;
this.groupByExpressions = groupByExpressions;
}

public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }

public NamedReference[] groupByColumns() { return groupByColumns; }
public Expression[] groupByExpressions() { return groupByExpressions; }
}
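
A hedged sketch of what the widened API accepts (the column names are illustrative; a plain `FieldReference` keeps working because `NamedReference` is itself a V2 `Expression`):
```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Max}

// Group-by entries are now arbitrary V2 Expressions rather than NamedReferences, so a
// translated CASE WHEN (or any other pushable expression) can appear alongside plain columns.
val aggs = Array[AggregateFunc](new Max(FieldReference("SALARY")))
val groupBy = Array[Expression](FieldReference("DEPT"))
val aggregation = new Aggregation(aggs, groupBy)
```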
@@ -162,7 +162,7 @@ case class RowDataSourceScanExec(
"PushedFilters" -> pushedFilters) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
"PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
"PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.execution.RowToColumnConverter
import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
@@ -82,23 +83,35 @@ object AggregatePushDownUtils {
}
}

if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
if (dataFilters.nonEmpty) {
// Parquet/ORC footer has max/min/count for columns
// e.g. SELECT COUNT(col1) FROM t
// but footer doesn't have max/min/count for a column if max/min/count
// are combined with filter or group by
// e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
// SELECT COUNT(col1) FROM t GROUP BY col2
// However, if the filter is on partition column, max/min/count can still be pushed down
// Todo: add support if groupby column is partition col
// (https://issues.apache.org/jira/browse/SPARK-36646)
return None
}
aggregation.groupByColumns.foreach { col =>

if (aggregation.groupByExpressions.nonEmpty &&
partitionNames.size != aggregation.groupByExpressions.length) {
// If there are group by columns, we only push down if the group by columns are the same as
// the partition columns. In theory, if group by columns are a subset of partition columns,
// we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3,
// SELECT MAX(c) FROM t GROUP BY p1, p2 should still be able to push down. However, the
// partial aggregation pushed down to data source needs to be
// SELECT p1, p2, p3, MAX(c) FROM t GROUP BY p1, p2, p3, and Spark layer
// needs to have a final aggregation such as SELECT MAX(c) FROM t GROUP BY p1, p2, then the
// pushed down query schema is different from the query schema at Spark. We will keep
// aggregate push down simple and don't handle this complicate case for now.
return None
}
aggregation.groupByExpressions.map(extractColName).foreach { colName =>
// don't push down if the group by columns are not the same as the partition columns (orders
// doesn't matter because reorder can be done at data source layer)
if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
if (colName.isEmpty || !isPartitionCol(colName.get)) return None
finalSchema = finalSchema.add(getStructFieldForCol(colName.get))
}

aggregation.aggregateExpressions.foreach {
@@ -125,7 +138,8 @@
def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
a.aggregateExpressions.sortBy(_.hashCode())
.sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
a.groupByExpressions.sortBy(_.hashCode())
.sameElements(b.groupByExpressions.sortBy(_.hashCode()))
}

/**
@@ -145,4 +159,49 @@
converter.convert(aggregatesAsRow, columnVectors.toArray)
new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
}

/**
* Return the schema for aggregates only (exclude group by columns)
*/
def getSchemaWithoutGroupingExpression(
aggSchema: StructType,
aggregation: Aggregation): StructType = {
val numOfGroupByColumns = aggregation.groupByExpressions.length
if (numOfGroupByColumns > 0) {
new StructType(aggSchema.fields.drop(numOfGroupByColumns))
} else {
aggSchema
}
}

/**
* Reorder partition cols if they are not in the same order as group by columns
*/
def reOrderPartitionCol(
partitionSchema: StructType,
aggregation: Aggregation,
partitionValues: InternalRow): InternalRow = {
val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName)
assert(groupByColNames.length == partitionSchema.length &&
groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
s"${groupByColNames.length} should be the same as partition schema length " +
s"${partitionSchema.length} and the number of fields ${partitionValues.numFields} " +
s"in partitionValues")
var reorderedPartColValues = Array.empty[Any]
if (!partitionSchema.names.sameElements(groupByColNames)) {
groupByColNames.foreach { col =>
val index = partitionSchema.names.indexOf(col)
val v = partitionValues.asInstanceOf[GenericInternalRow].values(index)
reorderedPartColValues = reorderedPartColValues :+ v
}
new GenericInternalRow(reorderedPartColValues)
} else {
partitionValues
}
}

private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match {
case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head)
case _ => None
}
}
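
The helpers above only accept group-by expressions that are single-part column references matching the partition columns; `reOrderPartitionCol` then lines the partition values up with the group-by order. A hedged, illustrative sketch of that behaviour (schema and values are made up):
```scala
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Max}
import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
import org.apache.spark.sql.types.{IntegerType, StructType}

// A table partitioned by (p1, p2) queried with GROUP BY p2, p1: the partition values are
// reordered to the group-by order before being prepended to the pushed-down aggregate row.
val partitionSchema = new StructType().add("p1", IntegerType).add("p2", IntegerType)
val aggregation = new Aggregation(
  Array[AggregateFunc](new Max(FieldReference("c"))),
  Array[Expression](FieldReference("p2"), FieldReference("p1")))
val partitionValues = new GenericInternalRow(Array[Any](1, 2)) // p1 = 1, p2 = 2
val reordered =
  AggregatePushDownUtils.reOrderPartitionCol(partitionSchema, aggregation, partitionValues)
// reordered is expected to hold (2, 1), i.e. the values ordered as (p2, p1).
```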
@@ -753,14 +753,13 @@ object DataSourceStrategy
protected[sql] def translateAggregation(
aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {

def columnAsString(e: Expression): Option[FieldReference] = e match {
case PushableColumnWithoutNestedColumn(name) =>
Some(FieldReference(name).asInstanceOf[FieldReference])
def translateGroupBy(e: Expression): Option[V2Expression] = e match {
case PushableExpression(expr) => Some(expr)
case _ => None
}

val translatedAggregates = aggregates.flatMap(translateAggregate)
val translatedGroupBys = groupBy.flatMap(columnAsString)
val translatedGroupBys = groupBy.flatMap(translateGroupBy)

if (translatedAggregates.length != aggregates.length ||
translatedGroupBys.length != groupBy.length) {
@@ -303,7 +303,7 @@ object OrcUtils extends Logging {
partitionSchema: StructType,
aggregation: Aggregation,
aggSchema: StructType): InternalRow = {
require(aggregation.groupByColumns.length == 0,
require(aggregation.groupByExpressions.length == 0,
s"aggregate $aggregation with group-by column shouldn't be pushed down")
var columnsStatistics: OrcColumnStatistics = null
try {
@@ -175,9 +175,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId)
case ((expr, attr), ordinal) =>
if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
}
attr
}
val aggOutput = newOutput.drop(groupAttrs.length)
val output = groupAttrs ++ aggOutput
@@ -188,7 +193,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
|Pushed Aggregate Functions:
| ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
|Pushed Group by:
| ${pushedAggregates.get.groupByColumns.mkString(", ")}
| ${pushedAggregates.get.groupByExpressions.mkString(", ")}
|Output: ${output.mkString(", ")}
""".stripMargin)

@@ -197,14 +202,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
if (r.supportCompletePushDown(pushedAggregates.get)) {
val projectExpressions = finalResultExpressions.map { expr =>
// TODO At present, only push down group by attribute is supported.
// In future, more attribute conversion is extended here. e.g. GetStructField
expr.transform {
expr.transformDown {
case agg: AggregateExpression =>
val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
val child =
addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
}
}.asInstanceOf[Seq[NamedExpression]]
Project(projectExpressions, scanRelation)
@@ -247,6 +253,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
case other => other
}
agg.copy(aggregateFunction = aggFunction)
case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
}
}
}
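
For the query in the PR description, and assuming a source that supports complete push-down, the rewrite above ends up as a Project over the scan in which both the aggregate and the grouping expression refer to scan output columns. A hedged sketch of the resulting shape (illustrative names, not real explain output):
```scala
// Illustrative only; attribute names and ids are made up.
//
//   Project [cast(group_key#0 as <key type>) AS key, SUM(SALARY)#1 AS sum(SALARY)]
//   +- DataSourceV2ScanRelation [group_key#0, SUM(SALARY)#1]
//        PushedAggregates: [SUM(SALARY)]
//        PushedGroupByExpressions:
//          [CASE WHEN SALARY > 8000.00 AND SALARY < 10000.00 THEN SALARY ELSE 0.00 END]
//
// The groupByExprToOutputOrdinal map added above is what lets transformDown swap the CASE WHEN
// expression in the final projection for the scan's first output column, casting if needed.
```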
@@ -20,7 +20,7 @@ import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
@@ -70,12 +70,15 @@ case class JDBCScanBuilder(

private var pushedAggregateList: Array[String] = Array()

private var pushedGroupByCols: Option[Array[String]] = None
private var pushedGroupBys: Option[Array[String]] = None

override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
lazy val fieldNames = aggregation.groupByExpressions()(0) match {
case field: FieldReference => field.fieldNames
case _ => Array.empty[String]
}
jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
(aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
(aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 &&
jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
}

@@ -86,28 +89,26 @@ case class JDBCScanBuilder(
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
if (compiledAggs.length != aggregation.aggregateExpressions.length) return false

val groupByCols = aggregation.groupByColumns.map { col =>
if (col.fieldNames.length != 1) return false
dialect.quoteIdentifier(col.fieldNames.head)
}
val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression)
if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false

// The column names here are already quoted and can be used to build sql string directly.
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
// SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
// GROUP BY "DEPT", "NAME"
val selectList = groupByCols ++ compiledAggs
val groupByClause = if (groupByCols.isEmpty) {
val selectList = compiledGroupBys ++ compiledAggs
val groupByClause = if (compiledGroupBys.isEmpty) {
""
} else {
"GROUP BY " + groupByCols.mkString(",")
"GROUP BY " + compiledGroupBys.mkString(",")
}

val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
s"WHERE 1=0 $groupByClause"
try {
finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
pushedAggregateList = selectList
pushedGroupByCols = Some(groupByCols)
pushedGroupBys = Some(compiledGroupBys)
true
} catch {
case NonFatal(e) =>
@@ -173,6 +174,6 @@ case class JDBCScanBuilder(
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
// be used in sql string.
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate,
pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders)
}
}
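
Following the comment above about already-quoted names, a hedged sketch of the probing query that `pushAggregation` would build for the example in the PR description (the exact text depends on what the dialect's `compileExpression` returns):
```scala
// Illustrative only: assumes the dialect can compile the CASE WHEN grouping expression.
val caseWhenKey =
  """CASE WHEN "SALARY" > 8000.00 AND "SALARY" < 10000.00 THEN "SALARY" ELSE 0.00 END"""
val compiledGroupBys = Seq(caseWhenKey)
val compiledAggs = Seq("""SUM("SALARY")""")
val selectList = compiledGroupBys ++ compiledAggs
val groupByClause = "GROUP BY " + compiledGroupBys.mkString(",")
val aggQuery =
  s"""SELECT ${selectList.mkString(",")} FROM "test"."employee" WHERE 1=0 $groupByClause"""
```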
@@ -83,7 +83,7 @@ case class OrcScan(

lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
(seqToString(pushedAggregate.get.aggregateExpressions),
seqToString(pushedAggregate.get.groupByColumns))
seqToString(pushedAggregate.get.groupByExpressions))
} else {
("[]", "[]")
}
@@ -114,7 +114,7 @@ case class ParquetScan(

lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
(seqToString(pushedAggregate.get.aggregateExpressions),
seqToString(pushedAggregate.get.groupByColumns))
seqToString(pushedAggregate.get.groupByExpressions))
} else {
("[]", "[]")
}