Fix DPO (dynamic partition overwrite) to work with multiple partition data types

Fixes DPO to work with partition columns of more than one data type, and adds a test to `DeltaSuite`.

GitOrigin-RevId: f6db69052d6b3d7115d4671052f31a25b16fa84d
allisonport-db authored and vkorukanti committed Dec 21, 2022
1 parent 22889ec commit b6a1c50
Showing 7 changed files with 301 additions and 32 deletions.
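
For orientation before the diff, here is a minimal sketch of the user-facing behaviour this commit fixes, adapted from the new DeltaSuite test below. It assumes a SparkSession `spark` with Delta configured, the dynamic-partition-overwrite conf enabled, and a placeholder `path`:

import org.apache.spark.sql.functions.col
import spark.implicits._

// First write: "part1" is an integer partition column, "part2" a string one.
Seq((1, "x"), (2, "y"), (3, "z")).toDF("value", "part2")
  .withColumn("part1", col("value") % 2)
  .write.format("delta")
  .partitionBy("part1", "part2")
  .mode("append")
  .save(path)

// Dynamic partition overwrite: only the partitions touched by the new rows are replaced.
Seq((5, "x"), (7, "y")).toDF("value", "part2")
  .withColumn("part1", col("value") % 2)
  .write.format("delta")
  .partitionBy("part1", "part2")
  .mode("overwrite")
  .option("partitionOverwriteMode", "dynamic")
  .save(path)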
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
*/
private[delta] class CurrentTransactionInfo(
val txnId: String,
val readPredicates: Seq[Expression],
val readPredicates: Seq[DeltaTablePartitionReadPredicate],
val readFiles: Set[AddFile],
val readWholeTable: Boolean,
val readAppIds: Set[String],
@@ -194,16 +194,28 @@ private[delta] class ConflictChecker(
}

import org.apache.spark.sql.delta.implicits._
val predicatesMatchingAddedFiles = ExpressionSet(
currentTransactionInfo.readPredicates).iterator.flatMap { p =>
// ES-366661: use readSnapshot's partitionSchema as that is what we read in the
// beginning.
val conflictingFile = DeltaLog.filterFileList(
partitionSchema = currentTransactionInfo.partitionSchemaAtReadTime,
addedFilesToCheckForConflicts.toDF(spark), p :: Nil).as[AddFile].take(1)

conflictingFile.headOption.map(f => getPrettyPartitionMessage(f.partitionValues))
}.take(1).toArray
// Canonicalize the read predicates separately for each group: rewrites vs. non-rewrites.
val canonicalPredicates = currentTransactionInfo.readPredicates
.partition(_.shouldRewriteFilter) match {
case (rewrites, nonRewrites) =>
val canonicalRewrites = ExpressionSet(rewrites.map(_.predicate))
val canonicalNonRewrites = ExpressionSet(nonRewrites.map(_.predicate))
canonicalRewrites.map(DeltaTablePartitionReadPredicate(_)) ++
canonicalNonRewrites.map(DeltaTablePartitionReadPredicate(_, shouldRewriteFilter = false))
}

val predicatesMatchingAddedFiles = canonicalPredicates.iterator
.flatMap { readPredicate =>

val conflictingFile = DeltaLog.filterFileList(
partitionSchema = currentTransactionInfo.partitionSchemaAtReadTime,
files = addedFilesToCheckForConflicts.toDF(spark),
partitionFilters = readPredicate.predicate :: Nil,
shouldRewritePartitionFilters = readPredicate.shouldRewriteFilter
).as[AddFile].take(1)

conflictingFile.headOption.map(f => getPrettyPartitionMessage(f.partitionValues))
}.take(1).toArray

if (predicatesMatchingAddedFiles.nonEmpty) {
val isWriteSerializable = isolationLevel == WriteSerializable
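A side note on the canonicalization above: ExpressionSet deduplicates semantically equal predicates, and the rewritten and non-rewritten groups are canonicalized separately because their predicates are phrased over different schemas. A standalone Catalyst sketch (not Delta code) of that deduplication:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, ExpressionSet, Literal}
import org.apache.spark.sql.types.IntegerType

val part1 = AttributeReference("part1", IntegerType)()
// `part1 = 1` and `1 = part1` canonicalize to the same entry, so the set has size 1.
val dedup = ExpressionSet(Seq(EqualTo(part1, Literal(1)), EqualTo(Literal(1), part1)))
assert(dedup.size == 1)
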
20 changes: 14 additions & 6 deletions core/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala
@@ -833,17 +833,25 @@ object DeltaLog extends DeltaLogging {
* information
* @param partitionFilters Filters on the partition columns
* @param partitionColumnPrefixes The path to the `partitionValues` column, if it's nested
* @param shouldRewritePartitionFilters Whether to rewrite `partitionFilters` to be over the
* [[AddFile]] schema
*/
def filterFileList(
partitionSchema: StructType,
files: DataFrame,
partitionFilters: Seq[Expression],
partitionColumnPrefixes: Seq[String] = Nil): DataFrame = {
val rewrittenFilters = rewritePartitionFilters(
partitionSchema,
files.sparkSession.sessionState.conf.resolver,
partitionFilters,
partitionColumnPrefixes)
partitionColumnPrefixes: Seq[String] = Nil,
shouldRewritePartitionFilters: Boolean = true): DataFrame = {

val rewrittenFilters = if (shouldRewritePartitionFilters) {
rewritePartitionFilters(
partitionSchema,
files.sparkSession.sessionState.conf.resolver,
partitionFilters,
partitionColumnPrefixes)
} else {
partitionFilters
}
val expr = rewrittenFilters.reduceLeftOption(And).getOrElse(Literal.TrueLiteral)
val columnFilter = new Column(expr)
files.filter(columnFilter)
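To make the new flag concrete, a sketch of the two call shapes of `filterFileList` (assuming a Snapshot value `snapshot`; the column name "a" and physical name "XX" are hypothetical):

import org.apache.spark.sql.functions.col

// Logical predicate over partition column "a": rewritten to the AddFile schema internally.
DeltaLog.filterFileList(
  partitionSchema = snapshot.metadata.partitionSchema,
  files = snapshot.allFiles.toDF(),
  partitionFilters = Seq((col("a") === 0).expr))

// Predicate already phrased over AddFile.partitionValues: passed through unchanged.
DeltaLog.filterFileList(
  partitionSchema = snapshot.metadata.partitionSchema,
  files = snapshot.allFiles.toDF(),
  partitionFilters = Seq((col("partitionValues")("XX") === "0").expr),
  shouldRewritePartitionFilters = false)
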
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/sql/delta/DeltaUDF.scala
@@ -44,6 +44,9 @@ object DeltaUDF {
def stringFromMap(f: Map[String, String] => String): UserDefinedFunction =
createUdfFromTemplateUnsafe(stringFromMapTemplate, f, udf(f))

def booleanFromMap(f: Map[String, String] => Boolean): UserDefinedFunction =
createUdfFromTemplateUnsafe(booleanFromMapTemplate, f, udf(f))

private lazy val stringFromStringTemplate =
udf[String, String](identity).asInstanceOf[SparkUserDefinedFunction]

@@ -58,6 +61,9 @@ object DeltaUDF {
private lazy val stringFromMapTemplate =
udf((_: Map[String, String]) => "").asInstanceOf[SparkUserDefinedFunction]

private lazy val booleanFromMapTemplate =
udf((_: Map[String, String]) => true).asInstanceOf[SparkUserDefinedFunction]

/**
* Return a `UserDefinedFunction` for the given `f` from `template` if
* `INTERNAL_UDF_OPTIMIZATION_ENABLED` is enabled. Otherwise, `orElse` will be called to create a
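A usage sketch for the new helper, mirroring how the DPO path below applies it (the DataFrame `filesDf`, its `partitionValues` map column, and the partition values are assumed):

import org.apache.spark.sql.functions.col

val touched = Set(Map("part1" -> "0", "part2" -> "x"))
// booleanFromMap wraps a Map[String, String] => Boolean function as a Spark UDF.
val isTouched = DeltaUDF.booleanFromMap(touched.contains)
val matching = filesDf.filter(isTouched(col("partitionValues")))
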
@@ -31,6 +31,7 @@ import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.cdc.CDCReader
import org.apache.spark.sql.delta.files._
import org.apache.spark.sql.delta.hooks.{CheckpointHook, GenerateSymlinkManifest, PostCommitHook}
import org.apache.spark.sql.delta.implicits.addFileEncoder
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -87,6 +88,31 @@ case class CommitStats(
txnId: Option[String] = None
)

/**
* Represents a partition read predicate on a Delta table.
*
 * Partition predicates can reference either the table's logical partition columns or the
 * physical [[AddFile]] schema. A predicate over the logical partition columns needs to be
 * rewritten to be over the [[AddFile]] schema before filtering files. This is indicated
* with shouldRewriteFilter=true.
*
* Currently the only path for a predicate with shouldRewriteFilter=false is through DPO
* (dynamic partition overwrite) since we filter directly on [[AddFile.partitionValues]].
*
* For example, consider a table with the schema below and partition column "a"
* |-- a: integer {physicalName = "XX"}
* |-- b: integer {physicalName = "YY"}
*
 * An example of a predicate that needs to be rewritten is: (a = 0)
* Before filtering the [[AddFile]]s, this predicate needs to be rewritten to:
* (partitionValues.XX = 0)
*
* An example of a predicate that does not need to be rewritten is:
* (partitionValues = Map(XX -> 0))
*/
private[delta] case class DeltaTablePartitionReadPredicate(
predicate: Expression, shouldRewriteFilter: Boolean = true)
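
For illustration only (not part of this diff), the two flavours could be constructed as follows; the column name "a" and physical name "XX" are hypothetical:

import org.apache.spark.sql.functions.expr

// Logical form: must be rewritten over the AddFile schema before filtering files.
val logical = DeltaTablePartitionReadPredicate(expr("a = 0").expr)
// Physical form (the DPO path): already phrased over partitionValues, used as-is.
val physical = DeltaTablePartitionReadPredicate(
  expr("partitionValues['XX'] = '0'").expr, shouldRewriteFilter = false)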

/**
* Used to perform a set of reads in a transaction and then commit a set of updates to the
* state of the log. All reads from the [[DeltaLog]], MUST go through this instance rather
@@ -190,7 +216,7 @@ trait OptimisticTransactionImpl extends TransactionalWrite
* Tracks the data that could have been seen by recording the partition
* predicates by which files have been queried by this transaction.
*/
protected val readPredicates = new ArrayBuffer[Expression]
protected val readPredicates = new ArrayBuffer[DeltaTablePartitionReadPredicate]

/** Tracks specific files that have been seen by this transaction. */
protected val readFiles = new HashSet[AddFile]
@@ -558,7 +584,8 @@ trait OptimisticTransactionImpl extends TransactionalWrite
val partitionFilters = filters.filter { f =>
DeltaTableUtils.isPredicatePartitionColumnsOnly(f, metadata.partitionColumns, spark)
}
readPredicates += partitionFilters.reduceLeftOption(And).getOrElse(Literal(true))
readPredicates += DeltaTablePartitionReadPredicate(
partitionFilters.reduceLeftOption(And).getOrElse(Literal(true)))
readFiles ++= scan.files
scan
}
@@ -581,14 +608,16 @@ trait OptimisticTransactionImpl extends TransactionalWrite
s" expected, found $f")
}
val scan = snapshot.filesForScan(limit, partitionFilters)
readPredicates += partitionFilters.reduceLeftOption(And).getOrElse(Literal(true))
readPredicates += DeltaTablePartitionReadPredicate(
partitionFilters.reduceLeftOption(And).getOrElse(Literal(true)))
readFiles ++= scan.files
scan
}

override def filesWithStatsForScan(partitionFilters: Seq[Expression]): DataFrame = {
val metadata = snapshot.filesWithStatsForScan(partitionFilters)
readPredicates += partitionFilters.reduceLeftOption(And).getOrElse(Literal(true))
readPredicates += DeltaTablePartitionReadPredicate(
partitionFilters.reduceLeftOption(And).getOrElse(Literal(true)))
withFilesRead(filterFiles(partitionFilters))
metadata
}
@@ -602,26 +631,36 @@ trait OptimisticTransactionImpl extends TransactionalWrite
val partitionFilters = filters.filter { f =>
DeltaTableUtils.isPredicatePartitionColumnsOnly(f, metadata.partitionColumns, spark)
}
readPredicates += partitionFilters.reduceLeftOption(And).getOrElse(Literal.TrueLiteral)
readPredicates += DeltaTablePartitionReadPredicate(
partitionFilters.reduceLeftOption(And).getOrElse(Literal.TrueLiteral))
readFiles ++= scan.files
scan.files
}

/** Returns files within the given partitions. */
/**
* Returns files within the given partitions.
*
* `partitions` is a set of the `partitionValues` stored in [[AddFile]]s. This means they refer to
* the physical column names, and values are stored as strings.
 */
def filterFiles(partitions: Set[Map[String, String]]): Seq[AddFile] = {
import org.apache.spark.sql.functions.{array, col}
val partitionValues = partitions.map { partition =>
metadata.physicalPartitionColumns.map(partition).toArray
}
val predicate = array(metadata.partitionColumns.map(col): _*)
.isInCollection(partitionValues)
.expr
filterFiles(Seq(predicate))
import org.apache.spark.sql.functions.col
val df = snapshot.allFiles.toDF()
val isFileInTouchedPartitions =
DeltaUDF.booleanFromMap(partitions.contains)(col("partitionValues"))
val filteredFiles = df
.filter(isFileInTouchedPartitions)
.withColumn("stats", DataSkippingReader.nullStringLiteral)
.as[AddFile]
.collect()
readPredicates += DeltaTablePartitionReadPredicate(isFileInTouchedPartitions.expr,
shouldRewriteFilter = false)
filteredFiles
}
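
A short sketch of the call contract described above (the transaction `txn`, physical names, and values are assumed):

// Keys are physical partition column names; values are strings, as in AddFile.partitionValues.
val touchedPartitions: Set[Map[String, String]] = Set(Map("XX" -> "0", "YY" -> "x"))
val addFiles = txn.filterFiles(touchedPartitions)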

/** Mark the entire table as tainted by this transaction. */
def readWholeTable(): Unit = {
readPredicates += Literal.TrueLiteral
readPredicates += DeltaTablePartitionReadPredicate(Literal.TrueLiteral)
readTheWholeTable = true
}

109 changes: 109 additions & 0 deletions core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala
@@ -824,6 +824,32 @@ class DeltaSuite extends QueryTest
}
}

test("batch write: append, dynamic partition overwrite string and integer partition column") {
withSQLConf(DeltaSQLConf.DYNAMIC_PARTITION_OVERWRITE_ENABLED.key -> "true") {
withTempDir { tempDir =>
def data: DataFrame = spark.read.format("delta").load(tempDir.toString)

Seq((1, "x"), (2, "y"), (3, "z")).toDF("value", "part2")
.withColumn("part1", $"value" % 2)
.write
.format("delta")
.partitionBy("part1", "part2")
.mode("append")
.save(tempDir.getCanonicalPath)

Seq((5, "x"), (7, "y")).toDF("value", "part2")
.withColumn("part1", $"value" % 2)
.write
.format("delta")
.partitionBy("part1", "part2")
.mode("overwrite")
.option(DeltaOptions.PARTITION_OVERWRITE_MODE_OPTION, "dynamic")
.save(tempDir.getCanonicalPath)
checkDatasetUnorderly(data.select("value").as[Int], 2, 3, 5, 7)
}
}
}

test("batch write: append, dynamic partition overwrite overwrites nothing") {
withSQLConf(DeltaSQLConf.DYNAMIC_PARTITION_OVERWRITE_ENABLED.key -> "true") {
withTempDir { tempDir =>
@@ -1068,6 +1094,32 @@ class DeltaSuite extends QueryTest
}
}

test("batch write: append, dynamic partition with 'partitionValues' column") {
withSQLConf(DeltaSQLConf.DYNAMIC_PARTITION_OVERWRITE_ENABLED.key -> "true") {
withTempDir { tempDir =>
def data: DataFrame = spark.read.format("delta").load(tempDir.toString)

Seq(1, 2, 3).toDF
.withColumn("partitionValues", $"value" % 2)
.write
.format("delta")
.partitionBy("partitionValues")
.mode("append")
.save(tempDir.getCanonicalPath)

Seq(1, 5).toDF
.withColumn("partitionValues", $"value" % 2)
.write
.format("delta")
.partitionBy("partitionValues")
.mode("overwrite")
.option(DeltaOptions.PARTITION_OVERWRITE_MODE_OPTION, "dynamic")
.save(tempDir.getCanonicalPath)
checkDatasetUnorderly(data.select("value").as[Int], 1, 2, 5)
}
}
}

test("batch write: ignore") {
withTempDir { tempDir =>
def data: DataFrame = spark.read.format("delta").load(tempDir.toString)
@@ -2336,6 +2388,8 @@ class DeltaSuite extends QueryTest
class DeltaNameColumnMappingSuite extends DeltaSuite
with DeltaColumnMappingEnableNameMode {

import testImplicits._

override protected def runOnlyTests = Seq(
"handle partition filters and data filters",
"query with predicates should skip partitions",
@@ -2345,4 +2399,59 @@ class DeltaNameColumnMappingSuite extends DeltaSuite
"isBlindAppend with save and saveAsTable"
)


test(
"dynamic partition overwrite with conflicting logical vs. physical named partition columns") {
// It isn't sufficient to just test with column mapping enabled because the physical names are
// generated automatically and thus are unique w.r.t. the logical names.
// Instead we need to have: ColA.logicalName = ColB.physicalName,
// which means we need to start with columnMappingMode=None, and then upgrade to
// columnMappingMode=name and rename our columns

withSQLConf(DeltaSQLConf.DYNAMIC_PARTITION_OVERWRITE_ENABLED.key -> "true",
DeltaConfigs.COLUMN_MAPPING_MODE.defaultTablePropertyKey -> NoMapping.name) {
withTempDir { tempDir =>
def data: DataFrame = spark.read.format("delta").load(tempDir.toString)

Seq(("a", "x", 1), ("b", "y", 2), ("c", "x", 3)).toDF("part1", "part2", "value")
.write
.format("delta")
.partitionBy("part1", "part2")
.mode("append")
.save(tempDir.getCanonicalPath)

spark.sql(
s"""
|ALTER TABLE delta.`${tempDir.getCanonicalPath}` SET TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name'
|)
|""".stripMargin)

spark.sql(
s"""
|ALTER TABLE delta.`${tempDir.getCanonicalPath}` RENAME COLUMN part1 TO temp
|""".stripMargin)
spark.sql(
s"""
|ALTER TABLE delta.`${tempDir.getCanonicalPath}` RENAME COLUMN part2 TO part1
|""".stripMargin)
spark.sql(
s"""
|ALTER TABLE delta.`${tempDir.getCanonicalPath}` RENAME COLUMN temp TO part2
|""".stripMargin)

Seq(("a", "x", 4), ("d", "x", 5)).toDF("part2", "part1", "value")
.write
.format("delta")
.partitionBy("part2", "part1")
.mode("overwrite")
.option(DeltaOptions.PARTITION_OVERWRITE_MODE_OPTION, "dynamic")
.save(tempDir.getCanonicalPath)
checkDatasetUnorderly(data.select("part2", "part1", "value").as[(String, String, Int)],
("a", "x", 4), ("b", "y", 2), ("c", "x", 3), ("d", "x", 5))
}
}
}
}
@@ -96,4 +96,9 @@ class DeltaUDFSuite extends QueryTest with SharedSparkSession {
func = DeltaUDF.stringFromMap(x => x.toString),
input = Map("foo" -> "bar"),
expected = "Map(foo -> bar)")
testUDF(
name = "booleanFromMap",
func = DeltaUDF.booleanFromMap(x => x.isEmpty),
input = Map("foo" -> "bar"),
expected = false)
}
