[SPARK-19020][SQL] Cardinality estimation of aggregate operator

## What changes were proposed in this pull request?

Support cardinality estimation of the aggregate operator: the estimated number of output rows is the product of the distinct counts (ndv's) of the group-by columns, capped at the child plan's row count, and column stats are propagated for the aggregate's output.
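
For example (with illustrative numbers, not taken from the patch): grouping a 4-row child by two columns with distinct counts 2 and 4 gives the upper bound 2 * 4 = 8, which is then capped at the child's row count, so the estimate is min(8, 4) = 4 output rows.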

## How was this patch tested?

Added test cases for the new estimation logic.

Author: Zhenhua Wang <wzh_zju@163.com>
Author: wangzhenhua <wangzhenhua@huawei.com>

Closes apache#16431 from wzhfy/aggEstimation.
wzhfy authored and cmonkey committed Feb 15, 2017
1 parent 5e5b3ad commit 25b6ca0
Showing 4 changed files with 198 additions and 3 deletions.
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -495,7 +495,7 @@ case class Aggregate(
     child.constraints.union(getAliasedConstraints(nonAgg))
   }

-  override lazy val statistics: Statistics = {
+  override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse {
     if (groupingExpressions.isEmpty) {
       super.statistics.copy(sizeInBytes = 1)
     } else {
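
Note that AggregateEstimation.estimate returns an Option[Statistics]: when the child's row count or the column stats of any group-by column are missing, it returns None, and the getOrElse branch above falls back to the previous size-only estimation.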
New file: AggregateEstimation.scala
@@ -0,0 +1,57 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}


object AggregateEstimation {
  import EstimationUtils._

  /**
   * Estimate the number of output rows based on column stats of group-by columns, and propagate
   * column stats for aggregate expressions.
   */
  def estimate(agg: Aggregate): Option[Statistics] = {
    val childStats = agg.child.statistics
    // Check if we have column stats for all group-by columns.
    val colStatsExist = agg.groupingExpressions.forall { e =>
      e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
    }
    if (rowCountsExist(agg.child) && colStatsExist) {
      // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
      // the data contains all combinations of distinct values of group-by columns.
      var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
        (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)

      // Here we set another upper bound for the number of output rows: it must not be larger than
      // child's number of rows.
      outputRows = outputRows.min(childStats.rowCount.get)

      val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
      Some(Statistics(
        sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
        rowCount = Some(outputRows),
        attributeStats = outputAttrStats,
        isBroadcastable = childStats.isBroadcastable))
    } else {
      None
    }
  }
}
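
To make the arithmetic concrete, here is a minimal standalone sketch of the two bounds above, using made-up numbers (the object name and all values are hypothetical, not part of the patch):

object AggEstimationArithmeticSketch extends App {
  // Hypothetical distinct counts of two group-by columns.
  val distinctCounts = Seq(BigInt(2), BigInt(4))
  // Hypothetical row count of the child plan.
  val childRowCount = BigInt(4)

  // Upper bound 1: the product of distinct counts, which assumes the data
  // contains every combination of the group-by columns' distinct values.
  val product = distinctCounts.foldLeft(BigInt(1))(_ * _)
  // Upper bound 2: an aggregate never produces more rows than its child.
  val estimatedRows = product.min(childRowCount)

  println(s"estimated output rows = $estimatedRows")  // prints: estimated output rows = 4
}

Running it shows the cap taking effect: the raw product is 8, but the estimate is 4 because the child only has 4 rows.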
New file: AggEstimationSuite.scala
@@ -0,0 +1,135 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._


class AggEstimationSuite extends StatsEstimationTestBase {

  /** Columns for testing */
  private val columnInfo: Map[Attribute, ColumnStat] =
    Map(
      attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
        avgLen = 4, maxLen = 4),
      attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
        avgLen = 4, maxLen = 4),
      attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
        avgLen = 4, maxLen = 4),
      attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
        avgLen = 4, maxLen = 4),
      attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
        avgLen = 4, maxLen = 4),
      attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
        avgLen = 4, maxLen = 4))

  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
    columnInfo.map(kv => kv._1.name -> kv)

  test("empty group-by column") {
    val colNames = Seq("key11", "key12")
    // Suppose table1 has 2 records: (1, 10), (2, 10)
    val table1 = StatsTestPlan(
      outputList = colNames.map(nameToAttr),
      stats = Statistics(
        sizeInBytes = 2 * (4 + 4),
        rowCount = Some(2),
        attributeStats = AttributeMap(colNames.map(nameToColInfo))))

    checkAggStats(
      child = table1,
      colNames = Nil,
      expectedRowCount = 1)
  }

  test("there's a primary key in group-by columns") {
    val colNames = Seq("key11", "key12")
    // Suppose table1 has 2 records: (1, 10), (2, 10)
    val table1 = StatsTestPlan(
      outputList = colNames.map(nameToAttr),
      stats = Statistics(
        sizeInBytes = 2 * (4 + 4),
        rowCount = Some(2),
        attributeStats = AttributeMap(colNames.map(nameToColInfo))))

    checkAggStats(
      child = table1,
      colNames = colNames,
      // Column key11 is a primary key, so row count = ndv of key11 = child's row count
      expectedRowCount = table1.stats.rowCount.get)
  }

  test("the product of ndv's of group-by columns is too large") {
    val colNames = Seq("key21", "key22")
    // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
    val table2 = StatsTestPlan(
      outputList = colNames.map(nameToAttr),
      stats = Statistics(
        sizeInBytes = 4 * (4 + 4),
        rowCount = Some(4),
        attributeStats = AttributeMap(colNames.map(nameToColInfo))))

    checkAggStats(
      child = table2,
      colNames = colNames,
      // Use child's row count as an upper bound
      expectedRowCount = table2.stats.rowCount.get)
  }

  test("data contains all combinations of distinct values of group-by columns.") {
    val colNames = Seq("key31", "key32")
    // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
    val table3 = StatsTestPlan(
      outputList = colNames.map(nameToAttr),
      stats = Statistics(
        sizeInBytes = 6 * (4 + 4),
        rowCount = Some(6),
        attributeStats = AttributeMap(colNames.map(nameToColInfo))))

    checkAggStats(
      child = table3,
      colNames = colNames,
      // Row count = product of ndv
      expectedRowCount =
        nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2.distinctCount)
  }

  private def checkAggStats(
      child: LogicalPlan,
      colNames: Seq[String],
      expectedRowCount: BigInt): Unit = {

    val columns = colNames.map(nameToAttr)
    val testAgg = Aggregate(
      groupingExpressions = columns,
      aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
      child = child)

    val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
    val expectedStats = Statistics(
      sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
      rowCount = Some(expectedRowCount),
      attributeStats = expectedAttrStats)

    assert(testAgg.statistics == expectedStats)
  }
}
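
A note on the last test: the upper bound is tight for table3 because its data contains all four (key31, key32) combinations, namely (1, 10), (1, 20), (2, 10), (2, 20). The product of the ndv's, 2 * 2 = 4, therefore equals the true number of groups and stays below the child's row count of 6.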
StatsEstimationTestBase.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.sql.catalyst.statsEstimation

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.IntegerType


 class StatsEstimationTestBase extends SparkFunSuite {

+  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
+
   /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
   def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
     : AttributeMap[ColumnStat] = {
