forked from apache/spark
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-19020][SQL] Cardinality estimation of aggregate operator
## What changes were proposed in this pull request? Support cardinality estimation of aggregate operator ## How was this patch tested? Add test cases Author: Zhenhua Wang <wzh_zju@163.com> Author: wangzhenhua <wangzhenhua@huawei.com> Closes apache#16431 from wzhfy/aggEstimation.
- Loading branch information
Showing
4 changed files
with
198 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
...ala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation | ||
|
||
import org.apache.spark.sql.catalyst.expressions.Attribute | ||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} | ||
|
||
|
||
object AggregateEstimation { | ||
import EstimationUtils._ | ||
|
||
/** | ||
* Estimate the number of output rows based on column stats of group-by columns, and propagate | ||
* column stats for aggregate expressions. | ||
*/ | ||
def estimate(agg: Aggregate): Option[Statistics] = { | ||
val childStats = agg.child.statistics | ||
// Check if we have column stats for all group-by columns. | ||
val colStatsExist = agg.groupingExpressions.forall { e => | ||
e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) | ||
} | ||
if (rowCountsExist(agg.child) && colStatsExist) { | ||
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes | ||
// the data contains all combinations of distinct values of group-by columns. | ||
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( | ||
(res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) | ||
|
||
// Here we set another upper bound for the number of output rows: it must not be larger than | ||
// child's number of rows. | ||
outputRows = outputRows.min(childStats.rowCount.get) | ||
|
||
val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output) | ||
Some(Statistics( | ||
sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats), | ||
rowCount = Some(outputRows), | ||
attributeStats = outputAttrStats, | ||
isBroadcastable = childStats.isBroadcastable)) | ||
} else { | ||
None | ||
} | ||
} | ||
} |
135 changes: 135 additions & 0 deletions
135
...yst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.statsEstimation | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.Count | ||
import org.apache.spark.sql.catalyst.plans.logical._ | ||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ | ||
|
||
|
||
class AggEstimationSuite extends StatsEstimationTestBase { | ||
|
||
/** Columns for testing */ | ||
private val columnInfo: Map[Attribute, ColumnStat] = | ||
Map( | ||
attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, | ||
avgLen = 4, maxLen = 4), | ||
attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0, | ||
avgLen = 4, maxLen = 4), | ||
attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, | ||
avgLen = 4, maxLen = 4), | ||
attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, | ||
avgLen = 4, maxLen = 4), | ||
attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, | ||
avgLen = 4, maxLen = 4), | ||
attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, | ||
avgLen = 4, maxLen = 4)) | ||
|
||
private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) | ||
private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = | ||
columnInfo.map(kv => kv._1.name -> kv) | ||
|
||
test("empty group-by column") { | ||
val colNames = Seq("key11", "key12") | ||
// Suppose table1 has 2 records: (1, 10), (2, 10) | ||
val table1 = StatsTestPlan( | ||
outputList = colNames.map(nameToAttr), | ||
stats = Statistics( | ||
sizeInBytes = 2 * (4 + 4), | ||
rowCount = Some(2), | ||
attributeStats = AttributeMap(colNames.map(nameToColInfo)))) | ||
|
||
checkAggStats( | ||
child = table1, | ||
colNames = Nil, | ||
expectedRowCount = 1) | ||
} | ||
|
||
test("there's a primary key in group-by columns") { | ||
val colNames = Seq("key11", "key12") | ||
// Suppose table1 has 2 records: (1, 10), (2, 10) | ||
val table1 = StatsTestPlan( | ||
outputList = colNames.map(nameToAttr), | ||
stats = Statistics( | ||
sizeInBytes = 2 * (4 + 4), | ||
rowCount = Some(2), | ||
attributeStats = AttributeMap(colNames.map(nameToColInfo)))) | ||
|
||
checkAggStats( | ||
child = table1, | ||
colNames = colNames, | ||
// Column key11 a primary key, so row count = ndv of key11 = child's row count | ||
expectedRowCount = table1.stats.rowCount.get) | ||
} | ||
|
||
test("the product of ndv's of group-by columns is too large") { | ||
val colNames = Seq("key21", "key22") | ||
// Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40) | ||
val table2 = StatsTestPlan( | ||
outputList = colNames.map(nameToAttr), | ||
stats = Statistics( | ||
sizeInBytes = 4 * (4 + 4), | ||
rowCount = Some(4), | ||
attributeStats = AttributeMap(colNames.map(nameToColInfo)))) | ||
|
||
checkAggStats( | ||
child = table2, | ||
colNames = colNames, | ||
// Use child's row count as an upper bound | ||
expectedRowCount = table2.stats.rowCount.get) | ||
} | ||
|
||
test("data contains all combinations of distinct values of group-by columns.") { | ||
val colNames = Seq("key31", "key32") | ||
// Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10) | ||
val table3 = StatsTestPlan( | ||
outputList = colNames.map(nameToAttr), | ||
stats = Statistics( | ||
sizeInBytes = 6 * (4 + 4), | ||
rowCount = Some(6), | ||
attributeStats = AttributeMap(colNames.map(nameToColInfo)))) | ||
|
||
checkAggStats( | ||
child = table3, | ||
colNames = colNames, | ||
// Row count = product of ndv | ||
expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2 | ||
.distinctCount) | ||
} | ||
|
||
private def checkAggStats( | ||
child: LogicalPlan, | ||
colNames: Seq[String], | ||
expectedRowCount: BigInt): Unit = { | ||
|
||
val columns = colNames.map(nameToAttr) | ||
val testAgg = Aggregate( | ||
groupingExpressions = columns, | ||
aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(), | ||
child = child) | ||
|
||
val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo)) | ||
val expectedStats = Statistics( | ||
sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats), | ||
rowCount = Some(expectedRowCount), | ||
attributeStats = expectedAttrStats) | ||
|
||
assert(testAgg.statistics == expectedStats) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters