[SPARK-35213][SQL] Keep the correct ordering of nested structs in chained withField operations

### What changes were proposed in this pull request?

Modifies the OptimizeUpdateFields rule to fix correctness issues with certain nested and chained withField operations. Examples that reproduce the issue are in the new unit tests and in the JIRA ticket.

### Why are the changes needed?

Certain withField patterns can cause exceptions or even incorrect results. This appears to stem from the additional UpdateFields optimization added in apache#29812: it traverses fieldOps in reverse order to keep only the last write per field, but that can change the order of nested structs, leading to mismatches between the schema and the actual data. This change updates the optimization to maintain the initial ordering of nested structs so that it matches the generated schema.
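For illustration, here is the kind of chained pattern that hits the bug (a sketch distilled from the new ColumnExpressionSuite tests below; it assumes the usual org.apache.spark.sql.functions._ import and an arbitrary DataFrame df):

```scala
// Two sibling nested structs are created, then fields are added to each in
// interleaved order. Before this fix, the optimizer could reorder the nested
// structs relative to the derived schema, yielding wrong results or exceptions.
df.withColumn("data", struct()
    .withField("a", struct().withField("aa", lit("aa1")))
    .withField("b", struct().withField("ba", lit("ba1"))))
  .withColumn("data", col("data").withField("b.bb", lit("bb1")))
  .withColumn("data", col("data").withField("a.ab", lit("ab1")))
```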

### Does this PR introduce _any_ user-facing change?

Yes. It fixes exceptions and incorrect results for valid withField usage in the latest Spark release.

### How was this patch tested?

Added new unit tests for these edge cases.

Closes apache#32338 from Kimahriman/bug/optimize-with-fields.

Authored-by: Adam Binford <adamq43@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 74afc68)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
Kimahriman authored and sunchao committed Apr 26, 2021
1 parent bd38aeb commit eb9f984
Showing 3 changed files with 88 additions and 18 deletions.
@@ -50,28 +50,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] {
       val values = withFields.map(_.valExpr)
 
       val newNames = mutable.ArrayBuffer.empty[String]
-      val newValues = mutable.ArrayBuffer.empty[Expression]
+      val newValues = mutable.HashMap.empty[String, Expression]
+      // Used to remember the casing of the last instance
+      val nameMap = mutable.HashMap.empty[String, String]
 
-      if (caseSensitive) {
-        names.zip(values).reverse.foreach { case (name, value) =>
-          if (!newNames.contains(name)) {
-            newNames += name
-            newValues += value
-          }
-        }
-      } else {
-        val nameSet = mutable.HashSet.empty[String]
-        names.zip(values).reverse.foreach { case (name, value) =>
-          val lowercaseName = name.toLowerCase(Locale.ROOT)
-          if (!nameSet.contains(lowercaseName)) {
-            newNames += name
-            newValues += value
-            nameSet += lowercaseName
-          }
+      names.zip(values).foreach { case (name, value) =>
+        val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+        if (nameMap.contains(normalizedName)) {
+          newValues += normalizedName -> value
+        } else {
+          newNames += normalizedName
+          newValues += normalizedName -> value
         }
+        nameMap += normalizedName -> name
       }
 
-      val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+      val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n)))
       UpdateFields(structExpr, newWithFields.toSeq)

       case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
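For intuition, the rewritten rule keeps each field at the position of its first write, while the last write wins the value (and, in the case-insensitive mode, the casing). A standalone sketch of that de-duplication outside Spark (dedupFields and its names are illustrative, not Spark API):

```scala
import java.util.Locale
import scala.collection.mutable

// Position from the first occurrence of a (normalized) name; value and casing
// from the last occurrence, mirroring the rewritten rule above.
def dedupFields[T](fields: Seq[(String, T)], caseSensitive: Boolean): Seq[(String, T)] = {
  val order = mutable.ArrayBuffer.empty[String]          // first-seen order
  val lastValue = mutable.HashMap.empty[String, T]       // last value per name
  val lastCasing = mutable.HashMap.empty[String, String] // casing of last instance
  fields.foreach { case (name, value) =>
    val key = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
    if (!lastCasing.contains(key)) order += key
    lastValue(key) = value
    lastCasing(key) = name
  }
  order.map(k => lastCasing(k) -> lastValue(k)).toSeq
}

// dedupFields(Seq("a1" -> 3, "b1" -> 4, "a1" -> 5), caseSensitive = true)
//   == Seq(("a1", 5), ("b1", 4))  // a1 keeps first position, takes last value
```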
@@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest {
       comparePlans(optimized, correctAnswer)
     }
   }
+
+  test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") {
+    val originalQuery = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(3)) ::
+          WithField("b1", Literal(4)) ::
+          WithField("a1", Literal(5)) ::
+          Nil), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(5)) ::
+          WithField("b1", Literal(4)) ::
+          Nil), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
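In DataFrame terms, the plan collapse asserted above corresponds to something like the following (a sketch, assuming a DataFrame df with a struct column a and the usual functions import):

```scala
// Three writes to 'a' collapse into one UpdateFields carrying [a1 -> 5, b1 -> 4]:
// a1 keeps its original position in the struct, but the final value wins.
df.select(
  col("a")
    .withField("a1", lit(3))
    .withField("b1", lit(4))
    .withField("a1", lit(5))
    .as("out"))
```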
@@ -1660,6 +1660,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
 
+  test("SPARK-35213: chained withField operations should have correct schema for new columns") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+        .withField("a", struct())
+        .withField("b", struct())
+        .withField("a.aa", lit("aa1"))
+        .withField("b.ba", lit("ba1"))
+        .withField("a.ab", lit("ab1"))),
+      Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil,
+      StructType(Seq(
+        StructField("data", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("aa", StringType, nullable = false),
+            StructField("ab", StringType, nullable = false)
+          )), nullable = false),
+          StructField("b", StructType(Seq(
+            StructField("ba", StringType, nullable = false)
+          )), nullable = false)
+        )), nullable = false)
+      ))
+    )
+  }
+
+  test("SPARK-35213: optimized withField operations should maintain correct nested struct " +
+    "ordering") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+        .withField("a", struct().withField("aa", lit("aa1")))
+        .withField("b", struct().withField("ba", lit("ba1")))
+      )
+        .withColumn("data", col("data").withField("b.bb", lit("bb1")))
+        .withColumn("data", col("data").withField("a.ab", lit("ab1"))),
+      Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil,
+      StructType(Seq(
+        StructField("data", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("aa", StringType, nullable = false),
+            StructField("ab", StringType, nullable = false)
+          )), nullable = false),
+          StructField("b", StructType(Seq(
+            StructField("ba", StringType, nullable = false),
+            StructField("bb", StringType, nullable = false)
+          )), nullable = false)
+        )), nullable = false)
+      ))
+    )
+  }
 
   test("dropFields should throw an exception if called on a non-StructType column") {
     intercept[AnalysisException] {
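To see the preserved ordering end to end, the second scenario above can be replayed interactively (a sketch, assuming a spark-shell session with org.apache.spark.sql.functions._ imported):

```scala
// Struct fields should come out in write order, a (aa, ab) then b (ba, bb),
// matching the schema asserted in the test above.
val fixed = spark.range(1)
  .withColumn("data", struct()
    .withField("a", struct().withField("aa", lit("aa1")))
    .withField("b", struct().withField("ba", lit("ba1"))))
  .withColumn("data", col("data").withField("b.bb", lit("bb1")))
  .withColumn("data", col("data").withField("a.ab", lit("ab1")))
fixed.printSchema()
fixed.select("data.a.ab", "data.b.bb").show()
```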
