diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
new file mode 100644
index 0000000000000..98a7190ba984e
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.spark.sql.FileScanSuiteBase
+import org.apache.spark.sql.v2.avro.AvroScan
+
+class AvroScanSuite extends FileScanSuiteBase {
+  val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+    ("AvroScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, f, pf, df),
+      Seq.empty))
+
+  run(scanBuilders)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 363dd154b5fbb..ac63725b774d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -24,8 +24,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
@@ -84,11 +85,24 @@ trait FileScan extends Scan
 
   protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
 
+  private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
+    val output = readSchema().toAttributes
+    val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap
+    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap
+    val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => partitionFilterAttributes.getOrElse(a.name, a)))))
+    val normalizedDataFilters = ExpressionSet(dataFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => dataFiltersAttributes.getOrElse(a.name, a)))))
+    (normalizedPartitionFilters, normalizedDataFilters)
+  }
+
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan =>
-      fileIndex == f.fileIndex && readSchema == f.readSchema
-        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
-        ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
+      fileIndex == f.fileIndex && readSchema == f.readSchema &&
+        normalizedPartitionFilters == f.normalizedPartitionFilters &&
+        normalizedDataFilters == f.normalizedDataFilters
 
     case _ => false
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
new file mode 100644
index 0000000000000..4e7fe8455ff93
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import com.google.common.collect.ImmutableMap
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNull, LessThan}
+import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec}
+import org.apache.spark.sql.execution.datasources.v2.FileScan
+import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.datasources.v2.text.TextScan
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+trait FileScanSuiteBase extends SharedSparkSession {
+  private def newPartitioningAwareFileIndex() = {
+    new PartitioningAwareFileIndex(spark, Map.empty, None) {
+      override def partitionSpec(): PartitionSpec = {
+        PartitionSpec.emptySpec
+      }
+
+      override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = {
+        mutable.LinkedHashMap.empty
+      }
+
+      override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = {
+        Map.empty
+      }
+
+      override def rootPaths: Seq[Path] = {
+        Seq.empty
+      }
+
+      override def refresh(): Unit = {}
+    }
+  }
+
+  type ScanBuilder = (
+    SparkSession,
+    PartitioningAwareFileIndex,
+    StructType,
+    StructType,
+    StructType,
+    Array[Filter],
+    CaseInsensitiveStringMap,
+    Seq[Expression],
+    Seq[Expression]) => FileScan
+
+  def run(scanBuilders: Seq[(String, ScanBuilder, Seq[String])]): Unit = {
+    val dataSchema = StructType.fromDDL("data INT, partition INT, other INT")
+    val dataSchemaNotEqual = StructType.fromDDL("data INT, partition INT, other INT, new INT")
+    val readDataSchema = StructType.fromDDL("data INT")
+    val readDataSchemaNotEqual = StructType.fromDDL("data INT, other INT")
+    val readPartitionSchema = StructType.fromDDL("partition INT")
+    val readPartitionSchemaNotEqual = StructType.fromDDL("partition INT, other INT")
+    val pushedFilters =
+      Array[Filter](sources.And(sources.IsNull("data"), sources.LessThan("data", 0)))
+    val pushedFiltersNotEqual =
+      Array[Filter](sources.And(sources.IsNull("data"), sources.LessThan("data", 1)))
+    val optionsMap = ImmutableMap.of("key", "value")
+    val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap))
+    val optionsNotEqual =
+      new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2")))
+    val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+    val partitionFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
+    val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+    val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
+
+    scanBuilders.foreach { case (name, scanBuilder, exclusions) =>
+      test(s"SPARK-33482: Test $name equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanEquals = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema.copy(),
+          readDataSchema.copy(),
+          readPartitionSchema.copy(),
+          pushedFilters.clone(),
+          new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap)),
+          Seq(partitionFilters: _*),
+          Seq(dataFilters: _*))
+
+        assert(scan === scanEquals)
+      }
+
+      test(s"SPARK-33482: Test $name fileIndex not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val partitioningAwareFileIndexNotEqual = newPartitioningAwareFileIndex()
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndexNotEqual,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      if (!exclusions.contains("dataSchema")) {
+        test(s"SPARK-33482: Test $name dataSchema not equals") {
+          val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+          val scan = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          val scanNotEqual = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchemaNotEqual,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          assert(scan !== scanNotEqual)
+        }
+      }
+
+      test(s"SPARK-33482: Test $name readDataSchema not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchemaNotEqual,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name readPartitionSchema not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchemaNotEqual,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      if (!exclusions.contains("pushedFilters")) {
+        test(s"SPARK-33482: Test $name pushedFilters not equals") {
+          val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+          val scan = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          val scanNotEqual = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFiltersNotEqual,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          assert(scan !== scanNotEqual)
+        }
+      }
+
+      test(s"SPARK-33482: Test $name options not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          optionsNotEqual,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name partitionFilters not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFiltersNotEqual,
+          dataFilters)
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name dataFilters not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFiltersNotEqual)
+        assert(scan !== scanNotEqual)
+      }
+    }
+  }
+}
+
+class FileScanSuite extends FileScanSuiteBase {
+  val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+    ("ParquetScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) =>
+        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+      Seq.empty),
+    ("OrcScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) =>
+        OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df),
+      Seq.empty),
+    ("CSVScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df),
+      Seq.empty),
+    ("JsonScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df),
+      Seq.empty),
+    ("TextScan",
+      (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df),
+      Seq("dataSchema", "pushedFilters")))
+
+  run(scanBuilders)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c29eac2c7b1f3..aa673dc666510 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupporte
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -3945,6 +3946,29 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-33482: Fix FileScan canonicalization") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark.range(5).toDF().write.mode("overwrite").parquet(path.toString)
+        withTempView("t") {
+          spark.read.parquet(path.toString).createOrReplaceTempView("t")
+          val df = sql(
+            """
+              |SELECT *
+              |FROM t AS t1
+              |JOIN t AS t2 ON t2.id = t1.id
+              |JOIN t AS t3 ON t3.id = t2.id
+              |""".stripMargin)
+          df.collect()
+          val reusedExchanges = collect(df.queryExecution.executedPlan) {
+            case r: ReusedExchangeExec => r
+          }
+          assert(reusedExchanges.size == 1)
+        }
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])