Skip to content

Commit

Permalink
[SPARK-34720][SQL] MERGE ... UPDATE/INSERT * should do by-name resolu…
Browse files Browse the repository at this point in the history
…tion

In Spark, we have an extension to the MERGE syntax: INSERT/UPDATE *. This is not part of the ANSI standard or any other mainstream database, so we need to define its behavior ourselves.

The behavior today is very weird: assume the source table has `n1` columns, target table has `n2` columns. We generate the assignments by taking the first `min(n1, n2)` columns from source & target tables and pairing them by ordinal.

This PR proposes a more reasonable behavior: take all the columns from target table as keys, and find the corresponding columns from source table by name as values.

Fix MERGE INSERT/UPDATE * to be more user-friendly and to make schema evolution easier.

Yes, but MERGE is only supported by very few data sources.

new tests

Closes apache#32192 from cloud-fan/merge.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cloud-fan authored and sunchao committed Jul 30, 2021
1 parent abf48b1 commit 3424ec7
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1646,14 +1646,18 @@ class Analyzer(override val catalogManager: CatalogManager)
case UpdateAction(updateCondition, assignments) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanChildren(_, m))
// The update value can access columns from both target and source tables.
UpdateAction(
resolvedUpdateCondition,
resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = false))
// The update value can access columns from both target and source tables.
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false))
case UpdateStarAction(updateCondition) =>
val assignments = targetTable.output.map { attr =>
Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
}
UpdateAction(
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = false))
// For UPDATE *, the value must come from the source table.
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
Expand All @@ -1664,15 +1668,18 @@ class Analyzer(override val catalogManager: CatalogManager)
resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable)))
InsertAction(
resolvedInsertCondition,
resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = true))
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
case InsertStarAction(insertCondition) =>
// The insert action is used when not matched, so its condition and value can only
// access columns from the source table.
val resolvedInsertCondition = insertCondition.map(
resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable)))
val assignments = targetTable.output.map { attr =>
Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
}
InsertAction(
resolvedInsertCondition,
resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = true))
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
case o => o
}
val resolvedMergeCondition = resolveExpressionByPlanChildren(m.mergeCondition, m)
Expand All @@ -1690,33 +1697,38 @@ class Analyzer(override val catalogManager: CatalogManager)
}

def resolveAssignments(
assignments: Option[Seq[Assignment]],
assignments: Seq[Assignment],
mergeInto: MergeIntoTable,
resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = {
if (assignments.isEmpty) {
val expandedColumns = mergeInto.targetTable.output
val expandedValues = mergeInto.sourceTable.output
expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2))
} else {
assignments.get.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.targetTable))
case o => o
}
val resolvedValue = assign.value match {
// The update values may contain target and/or source references.
case c if !c.resolved =>
if (resolveValuesWithSourceOnly) {
resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.sourceTable))
} else {
resolveExpressionByPlanChildren(c, mergeInto)
}
case o => o
}
Assignment(resolvedKey, resolvedValue)
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
case o => o
}
val resolvedValue = assign.value match {
// The update values may contain target and/or source references.
case c if !c.resolved =>
if (resolveValuesWithSourceOnly) {
resolveMergeExprOrFail(c, Project(Nil, mergeInto.sourceTable))
} else {
resolveMergeExprOrFail(c, mergeInto)
}
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}

/**
 * Resolves a MERGE expression through `resolveExpressionByPlanChildren` and fails analysis
 * if any attribute it references is still unresolved afterwards.
 *
 * @param e the expression to resolve
 * @param p the plan passed to the resolver; its input columns are listed in the error message
 * @return the expression returned by the resolver when all of its references are resolved
 */
private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
  val expr = resolveExpressionByPlanChildren(e, p)
  // Only leftover unresolved attributes are reported here; other resolution problems
  // (for example mismatched data types) are not detected by this check.
  for (a <- expr.references if !a.resolved) {
    val inputColumns = p.inputSet.toSeq.map(_.sql)
    val cols = inputColumns.mkString(", ")
    a.failAnalysis(s"cannot resolve ${a.sql} in MERGE command given columns [$cols]")
  }
  expr
}

def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ class PlanResolutionSuite extends AnalysisTest {
t
}

// Mocked table with schema (s: string, i: int) and no partitioning.
private val table1: Table = {
  val schema = new StructType().add("s", "string").add("i", "int")
  val mocked = mock(classOf[Table])
  when(mocked.schema()).thenReturn(schema)
  when(mocked.partitioning()).thenReturn(Array.empty[Transform])
  mocked
}

// Mocked table with schema (i: int, x: string) and no partitioning.
private val table2: Table = {
  val schema = new StructType().add("i", "int").add("x", "string")
  val mocked = mock(classOf[Table])
  when(mocked.schema()).thenReturn(schema)
  when(mocked.partitioning()).thenReturn(Array.empty[Transform])
  mocked
}

private val tableWithAcceptAnySchemaCapability: Table = {
val t = mock(classOf[Table])
when(t.schema()).thenReturn(new StructType().add("i", "int"))
Expand Down Expand Up @@ -86,7 +100,8 @@ class PlanResolutionSuite extends AnalysisTest {
when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
invocation.getArgument[Identifier](0).name match {
case "tab" => table
case "tab1" => table
case "tab1" => table1
case "tab2" => table2
case name => throw new NoSuchTableException(name)
}
})
Expand All @@ -102,7 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
case "v1Table1" => v1Table
case "v1HiveTable" => v1HiveTable
case "v2Table" => table
case "v2Table1" => table
case "v2Table1" => table1
case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability
case name => throw new NoSuchTableException(name)
}
Expand Down Expand Up @@ -1369,7 +1384,7 @@ class PlanResolutionSuite extends AnalysisTest {
// cte
val sql5 =
s"""
|WITH source(i, s) AS
|WITH source(s, i) AS
| (SELECT * FROM $source)
|MERGE INTO $target AS target
|USING source
Expand All @@ -1389,7 +1404,7 @@ class PlanResolutionSuite extends AnalysisTest {
updateAssigns)),
Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))),
insertAssigns))) =>
assert(source.output.map(_.name) == Seq("i", "s"))
assert(source.output.map(_.name) == Seq("s", "i"))
checkResolution(target, source, mergeCondition, Some(dl), Some(ul), Some(il),
updateAssigns, insertAssigns)

Expand All @@ -1398,8 +1413,7 @@ class PlanResolutionSuite extends AnalysisTest {
}

// no aliases
Seq(("v2Table", "v2Table1"),
("testcat.tab", "testcat.tab1")).foreach { pair =>
Seq(("v2Table", "v2Table1"), ("testcat.tab", "testcat.tab1")).foreach { pair =>

val target = pair._1
val source = pair._2
Expand Down Expand Up @@ -1491,7 +1505,7 @@ class PlanResolutionSuite extends AnalysisTest {
assert(e5.message.contains("Reference 's' is ambiguous"))
}

val sql6 =
val sql1 =
s"""
|MERGE INTO non_exist_target
|USING non_exist_source
Expand All @@ -1500,13 +1514,37 @@ class PlanResolutionSuite extends AnalysisTest {
|WHEN MATCHED THEN UPDATE SET *
|WHEN NOT MATCHED THEN INSERT *
""".stripMargin
val parsed = parseAndResolve(sql6)
val parsed = parseAndResolve(sql1)
parsed match {
case u: MergeIntoTable =>
assert(u.targetTable.isInstanceOf[UnresolvedRelation])
assert(u.sourceTable.isInstanceOf[UnresolvedRelation])
case _ => fail("Expect MergeIntoTable, but got:\n" + parsed.treeString)
}

// UPDATE * with incompatible schema between source and target tables.
val sql2 =
"""
|MERGE INTO testcat.tab
|USING testcat.tab2
|ON 1 = 1
|WHEN MATCHED THEN UPDATE SET *
|""".stripMargin
val e2 = intercept[AnalysisException](parseAndResolve(sql2))
assert(e2.message.contains(
"cannot resolve `s` in MERGE command given columns [testcat.tab2.`i`, testcat.tab2.`x`]"))

// INSERT * with incompatible schema between source and target tables.
val sql3 =
"""
|MERGE INTO testcat.tab
|USING testcat.tab2
|ON 1 = 1
|WHEN NOT MATCHED THEN INSERT *
|""".stripMargin
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
assert(e3.message.contains(
"cannot resolve `s` in MERGE command given columns [testcat.tab2.`i`, testcat.tab2.`x`]"))
}

test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {
Expand Down

0 comments on commit 3424ec7

Please sign in to comment.