Skip to content

Commit

Permalink
[SPARK-30814][SQL] ALTER TABLE ... ADD COLUMN position should be able…
Browse files Browse the repository at this point in the history
… to reference columns being added

### What changes were proposed in this pull request?

In ALTER TABLE, a column in ADD COLUMNS can depend on the position of a column that is just being added. For example, for a table with the following schema:
```
root:
  - a: string
  - b: long
```
the following should work:
```
ALTER TABLE t ADD COLUMNS (x int AFTER a, y int AFTER x)
```
Currently, the above statement will throw an exception saying that AFTER x cannot be resolved, because x doesn't exist yet. This PR proposes to fix this issue.

### Why are the changes needed?

To fix the bug described above.

### Does this PR introduce any user-facing change?

Yes, now
```
ALTER TABLE t ADD COLUMNS (x int AFTER a, y int AFTER x)
```
works as expected.

### How was this patch tested?

Added new tests

Closes apache#27584 from imback82/alter_table_pos_fix.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
imback82 authored and cloud-fan committed Feb 18, 2020
1 parent d8c0599 commit 5866bc7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 25 deletions.
Expand Up @@ -3023,9 +3023,29 @@ class Analyzer(
object ResolveAlterTableChanges extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
// 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
// normalized parent name of fields to field names that belong to the parent.
// For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
// Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
val schema = t.schema
val normalizedChanges = changes.flatMap {
case add: AddColumn =>
def addColumn(
parentSchema: StructType,
parentName: String,
normalizedParentName: Seq[String]): TableChange = {
val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil)
val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded)
val field = add.fieldNames().last
colsToAdd(normalizedParentName) = fieldsAdded :+ field
TableChange.addColumn(
(normalizedParentName :+ field).toArray,
add.dataType(),
add.isNullable,
add.comment,
pos)
}
val parent = add.fieldNames().init
if (parent.nonEmpty) {
// Adding a nested field, need to normalize the parent column and position
Expand All @@ -3037,27 +3057,14 @@ class Analyzer(
val (normalizedName, sf) = target.get
sf.dataType match {
case struct: StructType =>
val pos = findColumnPosition(add.position(), parent.quoted, struct)
Some(TableChange.addColumn(
(normalizedName ++ Seq(sf.name, add.fieldNames().last)).toArray,
add.dataType(),
add.isNullable,
add.comment,
pos))

Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name))
case other =>
Some(add)
}
}
} else {
// Adding to the root. Just need to normalize position
val pos = findColumnPosition(add.position(), "root", schema)
Some(TableChange.addColumn(
add.fieldNames(),
add.dataType(),
add.isNullable,
add.comment,
pos))
Some(addColumn(schema, "root", Nil))
}

case typeChange: UpdateColumnType =>
Expand Down Expand Up @@ -3156,17 +3163,18 @@ class Analyzer(

private def findColumnPosition(
position: ColumnPosition,
field: String,
struct: StructType): ColumnPosition = {
parentName: String,
struct: StructType,
fieldsAdded: Seq[String]): ColumnPosition = {
position match {
case null => null
case after: After =>
struct.fieldNames.find(n => conf.resolver(n, after.column())) match {
(struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match {
case Some(colName) =>
ColumnPosition.after(colName)
case None =>
throw new AnalysisException("Couldn't find the reference column for " +
s"$after at $field")
s"$after at $parentName")
}
case other => other
}
Expand Down
Expand Up @@ -440,12 +440,16 @@ trait CheckAnalysis extends PredicateHelper {
}
field.get._2
}
def positionArgumentExists(position: ColumnPosition, struct: StructType): Unit = {
def positionArgumentExists(
position: ColumnPosition,
struct: StructType,
fieldsAdded: Seq[String]): Unit = {
position match {
case after: After =>
if (!struct.fieldNames.contains(after.column())) {
val allFields = struct.fieldNames ++ fieldsAdded
if (!allFields.contains(after.column())) {
alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
s"${struct.fieldNames.mkString("[", ", ", "]")}")
s"${allFields.mkString("[", ", ", "]")}")
}
case _ =>
}
Expand Down Expand Up @@ -474,6 +478,11 @@ trait CheckAnalysis extends PredicateHelper {
}

val colsToDelete = mutable.Set.empty[Seq[String]]
// 'colsToAdd' keeps track of new columns being added. It stores a mapping from a parent
// name of fields to field names that belong to the parent. For example, if we add
// columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
// Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]

alter.changes.foreach {
case add: AddColumn =>
Expand All @@ -483,8 +492,11 @@ trait CheckAnalysis extends PredicateHelper {
checkColumnNotExists("add", add.fieldNames(), table.schema)
}
val parent = findParentStruct("add", add.fieldNames())
positionArgumentExists(add.position(), parent)
val parentName = add.fieldNames().init
val fieldsAdded = colsToAdd.getOrElse(parentName, Nil)
positionArgumentExists(add.position(), parent, fieldsAdded)
TypeUtils.failWithIntervalType(add.dataType())
colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last
case update: UpdateColumnType =>
val field = findField("update", update.fieldNames)
val fieldName = update.fieldNames.quoted
Expand Down Expand Up @@ -523,7 +535,11 @@ trait CheckAnalysis extends PredicateHelper {
case updatePos: UpdateColumnPosition =>
findField("update", updatePos.fieldNames)
val parent = findParentStruct("update", updatePos.fieldNames())
positionArgumentExists(updatePos.position(), parent)
val parentName = updatePos.fieldNames().init
positionArgumentExists(
updatePos.position(),
parent,
colsToAdd.getOrElse(parentName, Nil))
case rename: RenameColumn =>
findField("rename", rename.fieldNames)
checkColumnNotExists(
Expand Down
Expand Up @@ -173,6 +173,42 @@ trait AlterTableTests extends SharedSparkSession {
}
}

test("SPARK-30814: add column with position referencing new columns being added") {
  val tableName = s"${catalogAndNamespace}table_name"

  // Expected top-level schema once x, y and z have been added after "a". Only the layout of
  // the nested "point" struct differs between the two successful checks below.
  def expectedSchema(point: StructType): StructType = {
    new StructType()
      .add("a", StringType)
      .add("x", IntegerType)
      .add("y", IntegerType)
      .add("z", IntegerType)
      .add("b", IntegerType)
      .add("point", point)
  }

  withTable(tableName) {
    sql(s"CREATE TABLE $tableName (a string, b int, point struct<x: double, y: double>) " +
      s"USING $v2Format")

    // Root level: y is positioned after x, and z after y, although x and y are themselves
    // being added by the same statement.
    sql(s"ALTER TABLE $tableName ADD COLUMNS (x int AFTER a, y int AFTER x, z int AFTER y)")
    assert(getTableMetadata(tableName).schema === expectedSchema(
      new StructType()
        .add("x", DoubleType)
        .add("y", DoubleType)))

    // Nested level: point.zz is positioned after point.z, which the same statement adds.
    sql(s"ALTER TABLE $tableName ADD COLUMNS (point.z double AFTER x, point.zz double AFTER z)")
    assert(getTableMetadata(tableName).schema === expectedSchema(
      new StructType()
        .add("x", DoubleType)
        .add("z", DoubleType)
        .add("zz", DoubleType)
        .add("y", DoubleType)))

    // A new column can only be used as a position reference by changes listed after the one
    // that adds it; here xx is referenced before it is added, so the statement must fail.
    val message = intercept[AnalysisException] {
      sql(s"ALTER TABLE $tableName ADD COLUMNS (yy int AFTER xx, xx int)")
    }.getMessage
    assert(message.contains("Couldn't find the reference column for AFTER xx at root"))
  }
}

test("AlterTable: add multiple columns") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
Expand Down
Expand Up @@ -151,6 +151,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}
}

test("AlterTable: add column resolution - column position referencing new column") {
  // The second change positions "y" after the column added by the first change, but spells
  // it with different casing ("X" instead of "x").
  val addX = TableChange.addColumn(
    Array("x"), LongType, true, null, ColumnPosition.after("id"))
  val addY = TableChange.addColumn(
    Array("y"), LongType, true, null, ColumnPosition.after("X"))

  alterTableTest(
    Seq(addX, addY),
    Seq("Couldn't find the reference column for AFTER X at root"))
}

test("AlterTable: add column resolution - nested positional") {
Seq("X", "Y").foreach { ref =>
alterTableTest(
Expand All @@ -161,6 +172,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}
}

test("AlterTable: add column resolution - column position referencing new nested column") {
  // "point.zz" is positioned after the nested column added by the first change, but spells
  // it with different casing ("Z" instead of "z").
  val addZ = TableChange.addColumn(
    Array("point", "z"), LongType, true, null)
  val addZz = TableChange.addColumn(
    Array("point", "zz"), LongType, true, null, ColumnPosition.after("Z"))

  alterTableTest(
    Seq(addZ, addZz),
    Seq("Couldn't find the reference column for AFTER Z at point"))
}

test("AlterTable: drop column resolution") {
Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref =>
alterTableTest(
Expand Down Expand Up @@ -207,13 +229,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}

/**
 * Convenience overload for a single [[TableChange]]: wraps it in a one-element sequence and
 * delegates to the multi-change variant.
 */
private def alterTableTest(change: TableChange, error: Seq[String]): Unit =
  alterTableTest(change :: Nil, error)

private def alterTableTest(changes: Seq[TableChange], error: Seq[String]): Unit = {
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
val plan = AlterTable(
catalog,
Identifier.of(Array(), "table_name"),
TestRelation2,
Seq(change)
changes
)

if (caseSensitive) {
Expand Down

0 comments on commit 5866bc7

Please sign in to comment.