diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index fff32c8e1dd6f..d87ddcc0ecab6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
     }

     plan.transform {
-      case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) =>
+      case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) =>
         // We only apply this optimization when only partitioned attributes are scanned.
-        if (a.references.subsetOf(partAttrs)) {
+        if (a.references.subsetOf(attrs)) {
           val aggFunctions = aggExprs.flatMap(_.collect {
             case agg: AggregateExpression => agg
           })
@@ -72,7 +72,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
               "This could result in wrong results when the partition metadata exists but the " +
               "inclusive data files are empty."
           )
-          a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation)))
+          a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters)))
         } else {
           a
         }
@@ -103,14 +103,27 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
    */
   private def replaceTableScanWithPartitionMetadata(
       child: LogicalPlan,
-      relation: LogicalPlan): LogicalPlan = {
+      relation: LogicalPlan,
+      partFilters: Seq[Expression]): LogicalPlan = {
+    // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the
+    // relation's schema. PartitionedRelation ensures that the filters only reference partition cols
+    val relFilters = partFilters.map { e =>
+      e transform {
+        case a: AttributeReference =>
+          a.withName(relation.output.find(_.semanticEquals(a)).get.name)
+      }
+    }
+
     child transform {
       case plan if plan eq relation =>
         relation match {
           case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) =>
             val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
-            val partitionData = fsRelation.location.listFiles(Nil, Nil)
-            LocalRelation(partAttrs, partitionData.map(_.values), isStreaming)
+            val partitionData = fsRelation.location.listFiles(relFilters, Nil)
+            // partition data may be a stream, which can cause serialization to hit stack level too
+            // deep exceptions because it is a recursive structure in memory. converting to array
+            // avoids the problem.
+            LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming)

           case relation: HiveTableRelation =>
             val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
@@ -118,12 +131,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
               CaseInsensitiveMap(relation.tableMeta.storage.properties)
             val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION)
               .getOrElse(SQLConf.get.sessionLocalTimeZone)
-            val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p =>
+            val partitions = if (partFilters.nonEmpty) {
+              catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters)
+            } else {
+              catalog.listPartitions(relation.tableMeta.identifier)
+            }
+
+            val partitionData = partitions.map { p =>
               InternalRow.fromSeq(partAttrs.map { attr =>
                 Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval()
               })
             }
-            LocalRelation(partAttrs, partitionData)
+            // partition data may be a stream, which can cause serialization to hit stack level too
+            // deep exceptions because it is a recursive structure in memory. converting to array
+            // avoids the problem.
+            LocalRelation(partAttrs, partitionData.toArray)

           case _ =>
             throw new IllegalStateException(s"unrecognized table scan node: $relation, " +
@@ -134,35 +156,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic

   /**
    * A pattern that finds the partitioned table relation node inside the given plan, and returns a
-   * pair of the partition attributes and the table relation node.
+   * pair of the partition attributes, partition filters, and the table relation node.
    *
    * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with
    * deterministic expressions, and returns result after reaching the partitioned table relation
    * node.
    */
-  object PartitionedRelation {
-
-    def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match {
-      case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
-        if fsRelation.partitionSchema.nonEmpty =>
-        val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
-        Some((AttributeSet(partAttrs), l))
-
-      case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
-        val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
-        Some((AttributeSet(partAttrs), relation))
-
-      case p @ Project(projectList, child) if projectList.forall(_.deterministic) =>
-        unapply(child).flatMap { case (partAttrs, relation) =>
-          if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None
-        }
+  object PartitionedRelation extends PredicateHelper {
+
+    def unapply(
+        plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = {
+      plan match {
+        case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
+          if fsRelation.partitionSchema.nonEmpty =>
+          val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l))
+          Some((partAttrs, partAttrs, Nil, l))
+
+        case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
+          val partAttrs = AttributeSet(
+            getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation))
+          Some((partAttrs, partAttrs, Nil, relation))
+
+        case p @ Project(projectList, child) if projectList.forall(_.deterministic) =>
+          unapply(child).flatMap { case (partAttrs, attrs, filters, relation) =>
+            if (p.references.subsetOf(attrs)) {
+              Some((partAttrs, p.outputSet, filters, relation))
+            } else {
+              None
+            }
+          }

-      case f @ Filter(condition, child) if condition.deterministic =>
-        unapply(child).flatMap { case (partAttrs, relation) =>
-          if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None
-        }
+        case f @ Filter(condition, child) if condition.deterministic =>
+          unapply(child).flatMap { case (partAttrs, attrs, filters, relation) =>
+            if (f.references.subsetOf(partAttrs)) {
+              Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation))
+            } else {
+              None
+            }
+          }

-      case _ => None
+        case _ => None
+      }
     }
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala
new file mode 100644
index 0000000000000..95f192f0e40e2
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.metrics.source.HiveCatalogMetrics
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
+    with BeforeAndAfter with SQLTestUtils {
+
+  import spark.implicits._
+
+  before {
+    sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
+    (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))
+  }
+
+  test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
+    withTable("metadata_only") {
+      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount
+
+      // verify the number of matching partitions
+      assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)
+
+      // verify that the partition predicate was pushed down to the metastore
+      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)
+    }
+  }
+
+  test("SPARK-23877: filter on projected expression") {
+    withTable("metadata_only") {
+      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount
+
+      // verify the matching partitions
+      val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
+        Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
+          spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child)))
+        .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))
+
+      checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))
+
+      // verify that the partition predicate was not pushed down to the metastore
+      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
+    }
+  }
+}
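
For context, a minimal, hypothetical usage sketch (not part of the patch) of how the optimized path is exercised end to end. It assumes a Hive-enabled SparkSession and the existing spark.sql.optimizer.metadataOnly flag that gates OptimizeMetadataOnlyQuery; the table and column names are illustrative only.

// Hypothetical sketch (not part of the patch): a metadata-only query whose partition
// predicate can now be pushed down to the metastore when partitions are listed.
import org.apache.spark.sql.SparkSession

object MetadataOnlyQueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("metadata-only-query-example")
      .enableHiveSupport()
      .getOrCreate()

    // The rule only fires when the metadata-only optimization is enabled.
    spark.conf.set("spark.sql.optimizer.metadataOnly", "true")

    spark.sql("CREATE TABLE IF NOT EXISTS events (id BIGINT) PARTITIONED BY (day INT)")
    (0 to 9).foreach(d => spark.sql(s"ALTER TABLE events ADD IF NOT EXISTS PARTITION (day=$d)"))

    // Both the DISTINCT and the WHERE clause reference only the partition column, so the scan
    // is answered from partition metadata; with this change the day < 3 predicate is also used
    // to prune the partitions fetched from the metastore.
    spark.sql("SELECT DISTINCT day FROM events WHERE day < 3").show()

    spark.stop()
  }
}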