From facf2fce6113d242db8754e9d7dae7f0e42efafc Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Tue, 20 Apr 2021 11:11:40 +0800 Subject: [PATCH] [SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated This PR updated the `foundNonEqualCorrelatedPred` logic for correlated subqueries in `CheckAnalysis` to only allow correlated equality predicates that guarantee one-to-one mapping between inner and outer attributes, instead of all equality predicates. To fix correctness bugs. Before this fix Spark can give wrong results for certain correlated subqueries that pass CheckAnalysis: Example 1: ```sql create or replace view t1(c) as values ('a'), ('b') create or replace view t2(c) as values ('ab'), ('abc'), ('bc') select c, (select count(*) from t2 where t1.c = substring(t2.c, 1, 1)) from t1 ``` Correct results: [(a, 2), (b, 1)] Spark results: ``` +---+-----------------+ |c |scalarsubquery(c)| +---+-----------------+ |a |1 | |a |1 | |b |1 | +---+-----------------+ ``` Example 2: ```sql create or replace view t1(a, b) as values (0, 6), (1, 5), (2, 4), (3, 3); create or replace view t2(c) as values (6); select c, (select count(*) from t1 where a + b = c) from t2; ``` Correct results: [(6, 4)] Spark results: ``` +---+-----------------+ |c |scalarsubquery(c)| +---+-----------------+ |6 |1 | |6 |1 | |6 |1 | |6 |1 | +---+-----------------+ ``` Yes. Users will not be able to run queries that contain unsupported correlated equality predicates. Added unit tests. Closes #32179 from allisonwang-db/spark-35080-subquery-bug. Lead-authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan (cherry picked from commit bad4b6f025de4946112a0897892a97d5ae6822cf) Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 77 ++++++++++++++++--- .../analysis/AnalysisErrorSuite.scala | 24 ++++++ .../sql-tests/results/udf/udf-except.sql.out | 12 ++- .../org/apache/spark/sql/SubquerySuite.scala | 11 ++- 4 files changed, 109 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3e084f0af5121..3dfe7f46d54bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -891,14 +891,72 @@ trait CheckAnalysis extends PredicateHelper { // +- SubqueryAlias t1, `t1` // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { + // SPARK-35080: The same issue can happen to correlated equality predicates when + // they do not guarantee one-to-one mapping between inner and outer attributes. + // For example: + // Table: + // t1(a, b): [(0, 6), (1, 5), (2, 4)] + // t2(c): [(6)] + // + // Query: + // SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2 + // + // Original subquery plan: + // Aggregate [count(1)] + // +- Filter ((a + b) = outer(c)) + // +- LocalRelation [a, b] + // + // Plan after pulling up correlated predicates: + // Aggregate [a, b] [count(1), a, b] + // +- LocalRelation [a, b] + // + // Plan after rewrite: + // Project [c1, count(1)] + // +- Join LeftOuter ((a + b) = c) + // :- LocalRelation [c] + // +- Aggregate [a, b] [count(1), a, b] + // +- LocalRelation [a, b] + // + // The right hand side of the join transformed from the subquery will output + // count(1) | a | b + // 1 | 0 | 6 + // 1 | 1 | 5 + // 1 | 2 | 4 + // and the plan after rewrite will give the original query incorrect results. + def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = { + if (predicates.nonEmpty) { // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + failAnalysis("Correlated column is not allowed in predicate " + + s"${predicates.map(_.sql).mkString}:\n$p") } } - var foundNonEqualCorrelatedPred: Boolean = false + def containsAttribute(e: Expression): Boolean = { + e.find(_.isInstanceOf[Attribute]).isDefined + } + + // Given a correlated predicate, check if it is either a non-equality predicate or + // equality predicate that does not guarantee one-on-one mapping between inner and + // outer attributes. When the correlated predicate does not contain any attribute + // (i.e. only has outer references), it is supported and should return false. E.G.: + // (a = outer(c)) -> false + // (outer(c) = outer(d)) -> false + // (a > outer(c)) -> true + // (a + b = outer(c)) -> true + // The last one is true because there can be multiple combinations of (a, b) that + // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0) + // and (-1, 1) can make the predicate evaluate to true. + def isUnsupportedPredicate(condition: Expression): Boolean = condition match { + // Only allow equality condition with one side being an attribute and another + // side being an expression without attributes from the inner query. Note + // OuterReference is a leaf node and will not be found here. + case Equality(_: Attribute, b) => containsAttribute(b) + case Equality(a, _: Attribute) => containsAttribute(a) + case e @ Equality(_, _) => containsAttribute(e) + case _ => true + } + + val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression] // Simplify the predicates before validating any unsupported correlation patterns in the plan. AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { @@ -941,22 +999,17 @@ trait CheckAnalysis extends PredicateHelper { // The other operator is Join. Filter can be anywhere in a correlated subquery. case f: Filter => val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } + unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate) failOnInvalidOuterReference(f) // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains - // only equality correlated predicates. + // only supported correlated equality predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. case a: Aggregate => failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a) // Join can host correlated expressions. case j @ Join(left, right, joinType, _, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 44128c4419951..20ba9c5f30426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -700,4 +700,28 @@ class AnalysisErrorSuite extends AnalysisTest { UnresolvedRelation(TableIdentifier("t", Option("nonexist"))))))) assertAnalysisError(plan, "Table or view not found:" :: Nil) } + + test("SPARK-35080: Unsupported correlated equality predicates in subquery") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + val t1 = LocalRelation(a, b) + val t2 = LocalRelation(c) + val conditions = Seq( + (abs($"a") === $"c", "abs(`a`) = outer(`c`)"), + (abs($"a") <=> $"c", "abs(`a`) <=> outer(`c`)"), + ($"a" + 1 === $"c", "(`a` + 1) = outer(`c`)"), + ($"a" + $"b" === $"c", "(`a` + `b`) = outer(`c`)"), + ($"a" + $"c" === $"b", "(`a` + outer(`c`)) = `b`"), + (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(`a` AS INT) = outer(`c`)")) + conditions.foreach { case (cond, msg) => + val plan = Project( + ScalarSubquery( + Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil, + Filter(cond, t1)) + ).as("sub") :: Nil, + t2) + assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out index 054ee00ecc2ae..43506b49a6683 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out @@ -100,6 +100,14 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query schema -struct +struct<> -- !query output -two +org.apache.spark.sql.AnalysisException +Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)): +Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS CAST(udf(cast(max(cast(udf(cast(v as string)) as int)) as string)) AS INT)#x] ++- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string)) + +- SubqueryAlias t2 + +- View (`t2`, [k#x,v#x]) + +- Project [k#x, v#x] + +- SubqueryAlias t2 + +- LocalRelation [k#x, v#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 73b23496de515..fafe1bb39336f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -542,7 +542,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") } assert(msg1.getMessage.contains( - "Correlated column is not allowed in a non-equality predicate:")) + "Correlated column is not allowed in predicate (l2.`a` < outer(l1.`a`))")) } test("disjunctive correlated scalar subquery") { @@ -1753,4 +1753,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-35080: correlated equality predicates contain only outer references") { + withTempView("t") { + Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t") + checkAnswer( + sql("select c1, c2, (select count(*) from l where c1 = c2) from t"), + Row(0, 1, 0) :: Row(1, 1, 8) :: Nil) + } + } }