Skip to content

Commit

Permalink
[SPARK-28893][SQL] Support MERGE INTO in the parser and add the corre…
Browse files Browse the repository at this point in the history
…sponding logical plan

### What changes were proposed in this pull request?
This PR supports MERGE INTO in the parser and add the corresponding logical plan. The SQL syntax likes,
```
MERGE INTO [ds_catalog.][multi_part_namespaces.]target_table [AS target_alias]
USING [ds_catalog.][multi_part_namespaces.]source_table | subquery [AS source_alias]
ON <merge_condition>
[ WHEN MATCHED [ AND <condition> ] THEN <matched_action> ]
[ WHEN MATCHED [ AND <condition> ] THEN <matched_action> ]
[ WHEN NOT MATCHED [ AND <condition> ]  THEN <not_matched_action> ]
```
where
```
<matched_action>  =
  DELETE  |
  UPDATE SET *  |
  UPDATE SET column1 = value1 [, column2 = value2 ...]

<not_matched_action>  =
  INSERT *  |
  INSERT (column1 [, column2 ...]) VALUES (value1 [, value2 ...])
```

### Why are the changes needed?
This is a start work for introduce `MERGE INTO` support for the builtin datasource, and the design work for the `MERGE INTO` support in DSV2.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
New test cases.

Closes apache#26167 from xianyinxin/SPARK-28893.

Authored-by: xy_xin <xianyin.xxy@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
xy_xin authored and cloud-fan committed Nov 9, 2019
1 parent 8152a87 commit 7cfd589
Show file tree
Hide file tree
Showing 10 changed files with 742 additions and 41 deletions.
2 changes: 2 additions & 0 deletions docs/sql-keywords.md
Expand Up @@ -169,6 +169,8 @@ Below is a list of all the keywords in Spark SQL.
<tr><td>LOGICAL</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MACRO</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MAP</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MATCHED</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MERGE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MICROSECOND</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MICROSECONDS</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>MILLISECOND</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
Expand Down
Expand Up @@ -219,6 +219,12 @@ statement
| RESET #resetConfiguration
| DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable
| UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable
| MERGE INTO target=multipartIdentifier targetAlias=tableAlias
USING (source=multipartIdentifier |
'(' sourceQuery=query')') sourceAlias=tableAlias
ON mergeCondition=booleanExpression
matchedClause*
notMatchedClause* #mergeIntoTable
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

Expand Down Expand Up @@ -479,10 +485,33 @@ selectClause
;

setClause
: SET assign (',' assign)*
: SET assignmentList
;

assign
matchedClause
: WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction
;
notMatchedClause
: WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction
;

matchedAction
: DELETE
| UPDATE SET ASTERISK
| UPDATE SET assignmentList
;

notMatchedAction
: INSERT ASTERISK
| INSERT '(' columns=multipartIdentifierList ')'
VALUES '(' expression (',' expression)* ')'
;

assignmentList
: assignment (',' assignment)*
;

assignment
: key=multipartIdentifier EQ value=expression
;

Expand Down Expand Up @@ -632,6 +661,10 @@ rowFormat
(NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
;

multipartIdentifierList
: multipartIdentifier (',' multipartIdentifier)*
;

multipartIdentifier
: parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)*
;
Expand Down Expand Up @@ -1027,6 +1060,8 @@ ansiNonReserved
| LOGICAL
| MACRO
| MAP
| MATCHED
| MERGE
| MICROSECOND
| MICROSECONDS
| MILLISECOND
Expand Down Expand Up @@ -1278,6 +1313,8 @@ nonReserved
| LOGICAL
| MACRO
| MAP
| MATCHED
| MERGE
| MICROSECOND
| MICROSECONDS
| MILLISECOND
Expand Down Expand Up @@ -1542,6 +1579,8 @@ LOCKS: 'LOCKS';
LOGICAL: 'LOGICAL';
MACRO: 'MACRO';
MAP: 'MAP';
MATCHED: 'MATCHED';
MERGE: 'MERGE';
MICROSECOND: 'MICROSECOND';
MICROSECONDS: 'MICROSECONDS';
MILLISECOND: 'MILLISECOND';
Expand Down
Expand Up @@ -1178,11 +1178,56 @@ class Analyzer(
// table by ResolveOutputRelation. that rule will alias the attributes to the table's names.
o

case m @ MergeIntoTable(targetTable, sourceTable, _, _, _)
if !m.resolved && targetTable.resolved && sourceTable.resolved =>
val newMatchedActions = m.matchedActions.map {
case DeleteAction(deleteCondition) =>
val resolvedDeleteCondition = deleteCondition.map(resolveExpressionTopDown(_, m))
DeleteAction(resolvedDeleteCondition)
case UpdateAction(updateCondition, assignments) =>
val resolvedUpdateCondition = updateCondition.map(resolveExpressionTopDown(_, m))
UpdateAction(resolvedUpdateCondition, resolveAssignments(assignments, m))
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
case InsertAction(insertCondition, assignments) =>
val resolvedInsertCondition = insertCondition.map(resolveExpressionTopDown(_, m))
InsertAction(resolvedInsertCondition, resolveAssignments(assignments, m))
case o => o
}
val resolvedMergeCondition = resolveExpressionTopDown(m.mergeCondition, m)
m.copy(mergeCondition = resolvedMergeCondition,
matchedActions = newMatchedActions,
notMatchedActions = newNotMatchedActions)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
}

def resolveAssignments(
assignments: Seq[Assignment],
mergeInto: MergeIntoTable): Seq[Assignment] = {
if (assignments.isEmpty) {
val expandedColumns = mergeInto.targetTable.output
val expandedValues = mergeInto.sourceTable.output
expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2))
} else {
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved => resolveExpressionTopDown(c, mergeInto.targetTable)
case o => o
}
val resolvedValue = assign.value match {
// The update values may contain target and/or source references.
case c if !c.resolved => resolveExpressionTopDown(c, mergeInto)
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}
}

def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)()
Expand Down
Expand Up @@ -341,21 +341,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
}

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {

val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
val tableAlias = if (ctx.tableAlias() != null) {
val ident = ctx.tableAlias().strictIdentifier()
// We do not allow columns aliases after table alias.
if (ctx.tableAlias().identifierList() != null) {
throw new ParseException("Columns aliases is not allowed in DELETE.",
ctx.tableAlias().identifierList())
private def getTableAliasWithoutColumnAlias(
ctx: TableAliasContext, op: String): Option[String] = {
if (ctx == null) {
None
} else {
val ident = ctx.strictIdentifier()
if (ctx.identifierList() != null) {
throw new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList())
}
if (ident != null) Some(ident.getText) else None
} else {
None
}
}

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val predicate = if (ctx.whereClause() != null) {
Some(expression(ctx.whereClause().booleanExpression()))
} else {
Expand All @@ -367,18 +369,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
val tableAlias = if (ctx.tableAlias() != null) {
val ident = ctx.tableAlias().strictIdentifier()
// We do not allow columns aliases after table alias.
if (ctx.tableAlias().identifierList() != null) {
throw new ParseException("Columns aliases is not allowed in UPDATE.",
ctx.tableAlias().identifierList())
}
if (ident != null) Some(ident.getText) else None
} else {
None
}
val (attrs, values) = ctx.setClause().assign().asScala.map {
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val (attrs, values) = ctx.setClause().assignmentList().assignment().asScala.map {
kv => visitMultipartIdentifier(kv.key) -> expression(kv.value)
}.unzip
val predicate = if (ctx.whereClause() != null) {
Expand All @@ -395,6 +387,95 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
predicate)
}

private def withAssignments(assignCtx: SqlBaseParser.AssignmentListContext): Seq[Assignment] =
withOrigin(assignCtx) {
assignCtx.assignment().asScala.map { assign =>
Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)),
expression(assign.value))
}
}

override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val targetTable = UnresolvedRelation(visitMultipartIdentifier(ctx.target))
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)

val sourceTableOrQuery = if (ctx.source != null) {
UnresolvedRelation(visitMultipartIdentifier(ctx.source))
} else if (ctx.sourceQuery != null) {
visitQuery(ctx.sourceQuery)
} else {
throw new ParseException("Empty source for merge: you should specify a source" +
" table/subquery in merge.", ctx.source)
}
val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE")
val aliasedSource =
sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery)

val mergeCondition = expression(ctx.mergeCondition)

val matchedClauses = ctx.matchedClause()
if (matchedClauses.size() > 2) {
throw new ParseException("There should be at most 2 'WHEN MATCHED' clauses.",
matchedClauses.get(2))
}
val matchedActions = matchedClauses.asScala.map {
clause => {
if (clause.matchedAction().DELETE() != null) {
DeleteAction(Option(clause.matchedCond).map(expression))
} else if (clause.matchedAction().UPDATE() != null) {
val condition = Option(clause.matchedCond).map(expression)
if (clause.matchedAction().ASTERISK() != null) {
UpdateAction(condition, Seq())
} else {
UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList()))
}
} else {
// It should not be here.
throw new ParseException(
s"Unrecognized matched action: ${clause.matchedAction().getText}",
clause.matchedAction())
}
}
}
val notMatchedClauses = ctx.notMatchedClause()
if (notMatchedClauses.size() > 1) {
throw new ParseException("There should be at most 1 'WHEN NOT MATCHED' clause.",
notMatchedClauses.get(1))
}
val notMatchedActions = notMatchedClauses.asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
val condition = Option(clause.notMatchedCond).map(expression)
if (clause.notMatchedAction().ASTERISK() != null) {
InsertAction(condition, Seq())
} else {
val columns = clause.notMatchedAction().columns.multipartIdentifier()
.asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr)))
val values = clause.notMatchedAction().expression().asScala.map(expression)
if (columns.size != values.size) {
throw new ParseException("The number of inserted values cannot match the fields.",
clause.notMatchedAction())
}
InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)))
}
} else {
// It should not be here.
throw new ParseException(
s"Unrecognized not matched action: ${clause.notMatchedAction().getText}",
clause.notMatchedAction())
}
}
}

MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions,
notMatchedActions)
}

/**
* Create a partition specification map.
*/
Expand Down
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, Star, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.DescribeTableSchema
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, SupportsNamespaces, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange}
Expand Down Expand Up @@ -296,6 +296,47 @@ case class UpdateTable(
override def children: Seq[LogicalPlan] = table :: Nil
}

/**
* The logical plan of the MERGE INTO command that works for v2 tables.
*/
case class MergeIntoTable(
targetTable: LogicalPlan,
sourceTable: LogicalPlan,
mergeCondition: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction]) extends Command with SupportsSubquery {
override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable)
}

sealed abstract class MergeAction(
condition: Option[Expression]) extends Expression with Unevaluable {
override def foldable: Boolean = false
override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException(this, "nullable")
override def children: Seq[Expression] = condition.toSeq
}

case class DeleteAction(condition: Option[Expression]) extends MergeAction(condition)

case class UpdateAction(
condition: Option[Expression],
assignments: Seq[Assignment]) extends MergeAction(condition) {
override def children: Seq[Expression] = condition.toSeq ++ assignments
}

case class InsertAction(
condition: Option[Expression],
assignments: Seq[Assignment]) extends MergeAction(condition) {
override def children: Seq[Expression] = condition.toSeq ++ assignments
}

case class Assignment(key: Expression, value: Expression) extends Expression with Unevaluable {
override def foldable: Boolean = false
override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException(this, "nullable")
override def children: Seq[Expression] = key :: value :: Nil
}

/**
* The logical plan of the DROP TABLE command that works for v2 tables.
*/
Expand Down

0 comments on commit 7cfd589

Please sign in to comment.