Skip to content

Commit 61dd5d1

Browse files
johanl-db and allisonport-db
authored and committed
Correct case handling in MERGE with schema evolution
This fixes an issue where inconsistently using case-insensitive column names with schema evolution and generated columns can trigger an assertion during analysis. If `new_column` is a column present in the source but not the target of a MERGE operation and the target contains a generated column, the following INSERT clause will fail, as `NEW_column` and `new_column` are wrongly considered different column names when computing the final schema after evolution: ``` WHEN NOT MATCHED THEN INSERT (NEW_column) VALUES (source.new_column); ``` Added tests for schema evolution and generated columns to cover case-(in)sensitive column names. Closes #2272 GitOrigin-RevId: 5f4e3f1294ca2538484de7238c294236cfc8a8b5
1 parent 8b768b6 commit 61dd5d1

File tree

4 files changed

+72
-5
lines changed

4 files changed

+72
-5
lines changed

Diff for: spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/deltaMerge.scala

+15-4
Original file line numberDiff line numberDiff line change
@@ -602,15 +602,26 @@ object DeltaMergeInto {
602602
// clause, then merge this schema with the target to give the final schema.
603603
def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
604604
StructType(sourceSchema.flatMap { field =>
605-
val fieldPath = basePath :+ field.name.toLowerCase(Locale.ROOT)
606-
val childAssignedInMergeClause = assignments.exists(_.startsWith(fieldPath))
605+
val fieldPath = basePath :+ field.name
606+
607+
// Helper method to check if a given field path is a prefix of another path. Delegates
608+
// equality to conf.resolver to correctly handle case sensitivity.
609+
def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
610+
prefix.length <= path.length && prefix.zip(path).forall {
611+
case (prefixNamePart, pathNamePart) => conf.resolver(prefixNamePart, pathNamePart)
612+
}
613+
614+
// Helper method to check if a given field path is equal to another path.
615+
def isEqual(path1: Seq[String], path2: Seq[String]): Boolean =
616+
path1.length == path2.length && isPrefix(path1, path2)
617+
607618

608619
field.dataType match {
609620
// Specifically assigned to in one clause: always keep, including all nested attributes
610-
case _ if assignments.contains(fieldPath) => Some(field)
621+
case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
611622
// If this is a struct and one of the children is being assigned to in a merge clause,
612623
// keep it and continue filtering children.
613-
case struct: StructType if childAssignedInMergeClause =>
624+
case struct: StructType if assignments.exists(isPrefix(fieldPath, _)) =>
614625
Some(field.copy(dataType = filterSchema(struct, fieldPath)))
615626
// The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
616627
// clause. Check if it should be kept with any * action.

Diff for: spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,9 @@ case class PreprocessTableMerge(override val conf: SQLConf)
410410
if (implicitColumns.isEmpty) {
411411
return (allActions, Set[String]())
412412
}
413-
assert(finalSchema.size == allActions.size)
413+
assert(finalSchema.size == allActions.size,
414+
"Invalid number of columns in INSERT clause with generated columns. Expected schema: " +
415+
s"$finalSchema, INSERT actions: $allActions")
414416

415417
val track = mutable.Set[String]()
416418

Diff for: spark/src/test/scala/org/apache/spark/sql/delta/GeneratedColumnSuite.scala

+36
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,42 @@ trait GeneratedColumnSuiteBase extends GeneratedColumnTest {
17241724
}
17251725
}
17261726

1727+
test("MERGE INSERT with schema evolution on different name case") {
1728+
withTableName("source") { src =>
1729+
withTableName("target") { tgt =>
1730+
createTable(
1731+
tableName = src,
1732+
path = None,
1733+
schemaString = "c1 INT, c2 INT",
1734+
generatedColumns = Map.empty,
1735+
partitionColumns = Seq.empty
1736+
)
1737+
sql(s"INSERT INTO ${src} values (2, 4);")
1738+
createTable(
1739+
tableName = tgt,
1740+
path = None,
1741+
schemaString = "c1 INT, c3 INT",
1742+
generatedColumns = Map("c3" -> "c1 + 1"),
1743+
partitionColumns = Seq.empty
1744+
)
1745+
sql(s"INSERT INTO ${tgt} values (1, 2);")
1746+
1747+
withSQLConf(("spark.databricks.delta.schema.autoMerge.enabled", "true")) {
1748+
sql(s"""
1749+
|MERGE INTO ${tgt}
1750+
|USING ${src}
1751+
|on ${tgt}.c1 = ${src}.c1
1752+
|WHEN NOT MATCHED THEN INSERT (c1, C2) VALUES (${src}.c1, ${src}.c2)
1753+
|""".stripMargin)
1754+
}
1755+
checkAnswer(
1756+
sql(s"SELECT * FROM ${tgt}"),
1757+
Seq(Row(1, 2, null), Row(2, 3, 4))
1758+
)
1759+
}
1760+
}
1761+
}
1762+
17271763
test("generated columns with cdf") {
17281764
val tableName1 = "gcEnabledCDCOn"
17291765
val tableName2 = "gcEnabledCDCOff"

Diff for: spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSchemaEvolutionSuite.scala

+18
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,24 @@ trait MergeIntoSchemaEvolutionBaseTests {
451451
expectedWithoutEvolution = ((0, 0) +: (2, 2) +: (3, 30) +: (1, 1) +: Nil).toDF("key", "value")
452452
)
453453

454+
testEvolution(s"case-insensitive insert")(
455+
targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
456+
sourceData = Seq((1, 1), (2, 2)).toDF("key", "VALUE"),
457+
clauses = insert("(key, value, VALUE) VALUES (s.key, s.value, s.VALUE)") :: Nil,
458+
expected = ((0, 0) +: (1, 10) +: (3, 30) +: (2, 2) +: Nil).toDF("key", "value"),
459+
expectedWithoutEvolution = ((0, 0) +: (1, 10) +: (3, 30) +: (2, 2) +: Nil).toDF("key", "value"),
460+
confs = Seq(SQLConf.CASE_SENSITIVE.key -> "false")
461+
)
462+
463+
testEvolution(s"case-sensitive insert")(
464+
targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
465+
sourceData = Seq((1, 1), (2, 2)).toDF("key", "VALUE"),
466+
clauses = insert("(key, value, VALUE) VALUES (s.key, s.value, s.VALUE)") :: Nil,
467+
expectErrorContains = "Cannot resolve s.value in INSERT clause",
468+
expectErrorWithoutEvolutionContains = "Cannot resolve s.value in INSERT clause",
469+
confs = Seq(SQLConf.CASE_SENSITIVE.key -> "true")
470+
)
471+
454472
testEvolution("evolve partitioned table")(
455473
targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
456474
sourceData = Seq((1, 1, "extra1"), (2, 2, "extra2")).toDF("key", "value", "extra"),

0 commit comments

Comments (0)