[SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated

This PR updated the `foundNonEqualCorrelatedPred` logic for correlated subqueries in `CheckAnalysis` to only allow correlated equality predicates that guarantee one-to-one mapping between inner and outer attributes, instead of all equality predicates.
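
The rule described above can be sketched on a toy expression tree. This is only an illustrative model, not the Catalyst implementation: `Expr`, `Attr`, `Outer`, `Lit`, `Add`, `Eq`, `Gt`, and `CorrelatedPredicateSketch` are hypothetical names introduced here, with `Attr` standing in for an inner-query attribute and `Outer` for an outer reference.

```scala
// Hypothetical stand-ins for Catalyst expressions; none of these are Spark classes.
sealed trait Expr
case class Attr(name: String) extends Expr           // attribute of the inner query
case class Outer(name: String) extends Expr          // reference to an outer column (leaf)
case class Lit(value: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Eq(left: Expr, right: Expr) extends Expr
case class Gt(left: Expr, right: Expr) extends Expr

object CorrelatedPredicateSketch {
  // True if the expression mentions any inner-query attribute.
  def containsAttribute(e: Expr): Boolean = e match {
    case Attr(_)           => true
    case Outer(_) | Lit(_) => false
    case Add(l, r)         => containsAttribute(l) || containsAttribute(r)
    case Eq(l, r)          => containsAttribute(l) || containsAttribute(r)
    case Gt(l, r)          => containsAttribute(l) || containsAttribute(r)
  }

  // An equality is supported only when one side is a bare inner attribute and
  // the other side has no inner attributes, or when neither side has any.
  def isUnsupported(cond: Expr): Boolean = cond match {
    case Eq(Attr(_), other) => containsAttribute(other)
    case Eq(other, Attr(_)) => containsAttribute(other)
    case e @ Eq(_, _)       => containsAttribute(e)
    case _                  => true // every non-equality predicate is rejected
  }
}
```

Under this model, `Eq(Attr("a"), Outer("c"))` and `Eq(Outer("c"), Outer("d"))` are supported, while `Gt(Attr("a"), Outer("c"))` and `Eq(Add(Attr("a"), Attr("b")), Outer("c"))` are not.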

This fixes correctness bugs. Before this fix, Spark could return wrong results for certain correlated subqueries that pass `CheckAnalysis`:
Example 1:
```sql
create or replace view t1(c) as values ('a'), ('b')
create or replace view t2(c) as values ('ab'), ('abc'), ('bc')

select c, (select count(*) from t2 where t1.c = substring(t2.c, 1, 1)) from t1
```
Correct results: [(a, 2), (b, 1)]
Spark results:
```
+---+-----------------+
|c  |scalarsubquery(c)|
+---+-----------------+
|a  |1                |
|a  |1                |
|b  |1                |
+---+-----------------+
```
Example 2:
```sql
create or replace view t1(a, b) as values (0, 6), (1, 5), (2, 4), (3, 3);
create or replace view t2(c) as values (6);

select c, (select count(*) from t1 where a + b = c) from t2;
```
Correct results: [(6, 4)]
Spark results:
```
+---+-----------------+
|c  |scalarsubquery(c)|
+---+-----------------+
|6  |1                |
|6  |1                |
|6  |1                |
|6  |1                |
+---+-----------------+
```
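
The wrong result in Example 2 can be reproduced with plain Scala collections. The sketch below is a simplified model of the rewrite (group by the inner attributes, then left outer join on the pulled-up predicate), not Spark's actual execution; `DecorrelationSketch`, `correct`, and `rewritten` are hypothetical names, and the unmatched-row case is reduced to a count of 0.

```scala
object DecorrelationSketch {
  // Data from Example 2.
  val t1: Seq[(Int, Int)] = Seq((0, 6), (1, 5), (2, 4), (3, 3))
  val t2: Seq[Int] = Seq(6)

  // Intended semantics: evaluate the aggregated subquery once per outer row.
  def correct: Seq[(Int, Long)] =
    t2.map(c => (c, t1.count { case (a, b) => a + b == c }.toLong))

  // Model of the rewrite: aggregate grouped by (a, b), then join on a + b = c.
  def rewritten: Seq[(Int, Long)] = {
    val aggregated: Seq[((Int, Int), Long)] =
      t1.groupBy(identity).toSeq.map { case (key, rows) => (key, rows.size.toLong) }
    t2.flatMap { c =>
      val matches = aggregated.collect { case ((a, b), cnt) if a + b == c => (c, cnt) }
      // Simplification: an outer row with no match gets a count of 0.
      if (matches.isEmpty) Seq((c, 0L)) else matches
    }
  }
}
```

Here `correct` yields the single row `(6, 4)`, while `rewritten` yields four rows of `(6, 1)`, because each distinct `(a, b)` group satisfies the join condition separately.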
This change is user-facing: users will no longer be able to run queries that contain unsupported correlated equality predicates.

Added unit tests.

Closes apache#32179 from allisonwang-db/spark-35080-subquery-bug.

Lead-authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit bad4b6f)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and sunchao committed Apr 26, 2021
1 parent c9ac7eb commit facf2fc
Showing 4 changed files with 109 additions and 15 deletions.
@@ -891,14 +891,72 @@ trait CheckAnalysis extends PredicateHelper {
// +- SubqueryAlias t1, `t1`
// +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
// +- LocalRelation [_1#73, _2#74]
def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
if (found) {
// SPARK-35080: The same issue can happen to correlated equality predicates when
// they do not guarantee one-to-one mapping between inner and outer attributes.
// For example:
// Table:
// t1(a, b): [(0, 6), (1, 5), (2, 4)]
// t2(c): [(6)]
//
// Query:
// SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
//
// Original subquery plan:
// Aggregate [count(1)]
// +- Filter ((a + b) = outer(c))
// +- LocalRelation [a, b]
//
// Plan after pulling up correlated predicates:
// Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// Plan after rewrite:
// Project [c1, count(1)]
// +- Join LeftOuter ((a + b) = c)
// :- LocalRelation [c]
// +- Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// The right hand side of the join transformed from the subquery will output
// count(1) | a | b
// 1 | 0 | 6
// 1 | 1 | 5
// 1 | 2 | 4
// and the plan after rewrite will give the original query incorrect results.
def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
if (predicates.nonEmpty) {
// Report a non-supported case as an exception
failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
failAnalysis("Correlated column is not allowed in predicate " +
s"${predicates.map(_.sql).mkString}:\n$p")
}
}

var foundNonEqualCorrelatedPred: Boolean = false
def containsAttribute(e: Expression): Boolean = {
e.find(_.isInstanceOf[Attribute]).isDefined
}

// Given a correlated predicate, check if it is either a non-equality predicate or
// equality predicate that does not guarantee one-on-one mapping between inner and
// outer attributes. When the correlated predicate does not contain any attribute
// (i.e. only has outer references), it is supported and should return false. E.G.:
// (a = outer(c)) -> false
// (outer(c) = outer(d)) -> false
// (a > outer(c)) -> true
// (a + b = outer(c)) -> true
// The last one is true because there can be multiple combinations of (a, b) that
// satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
// and (-1, 1) can make the predicate evaluate to true.
def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
// Only allow equality condition with one side being an attribute and another
// side being an expression without attributes from the inner query. Note
// OuterReference is a leaf node and will not be found here.
case Equality(_: Attribute, b) => containsAttribute(b)
case Equality(a, _: Attribute) => containsAttribute(a)
case e @ Equality(_, _) => containsAttribute(e)
case _ => true
}

val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]

// Simplify the predicates before validating any unsupported correlation patterns in the plan.
AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
@@ -941,22 +999,17 @@ trait CheckAnalysis extends PredicateHelper {
// The other operator is Join. Filter can be anywhere in a correlated subquery.
case f: Filter =>
val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)

// Find any non-equality correlated predicates
foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
failOnInvalidOuterReference(f)

// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only equality correlated predicates.
// only supported correlated equality predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
case a: Aggregate =>
failOnInvalidOuterReference(a)
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)

// Join can host correlated expressions.
case j @ Join(left, right, joinType, _, _) =>
@@ -700,4 +700,28 @@ class AnalysisErrorSuite extends AnalysisTest {
UnresolvedRelation(TableIdentifier("t", Option("nonexist")))))))
assertAnalysisError(plan, "Table or view not found:" :: Nil)
}

test("SPARK-35080: Unsupported correlated equality predicates in subquery") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", IntegerType)()
val t1 = LocalRelation(a, b)
val t2 = LocalRelation(c)
val conditions = Seq(
(abs($"a") === $"c", "abs(`a`) = outer(`c`)"),
(abs($"a") <=> $"c", "abs(`a`) <=> outer(`c`)"),
($"a" + 1 === $"c", "(`a` + 1) = outer(`c`)"),
($"a" + $"b" === $"c", "(`a` + `b`) = outer(`c`)"),
($"a" + $"c" === $"b", "(`a` + outer(`c`)) = `b`"),
(And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(`a` AS INT) = outer(`c`)"))
conditions.foreach { case (cond, msg) =>
val plan = Project(
ScalarSubquery(
Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
Filter(cond, t1))
).as("sub") :: Nil,
t2)
assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil)
}
}
}
@@ -100,6 +100,14 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v))
FROM t2
WHERE t2.k = t1.k)
-- !query schema
struct<k:string>
struct<>
-- !query output
two
org.apache.spark.sql.AnalysisException
Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)):
Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS CAST(udf(cast(max(cast(udf(cast(v as string)) as int)) as string)) AS INT)#x]
+- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))
+- SubqueryAlias t2
+- View (`t2`, [k#x,v#x])
+- Project [k#x, v#x]
+- SubqueryAlias t2
+- LocalRelation [k#x, v#x]
11 changes: 10 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -542,7 +542,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
"Correlated column is not allowed in a non-equality predicate:"))
"Correlated column is not allowed in predicate (l2.`a` < outer(l1.`a`))"))
}

test("disjunctive correlated scalar subquery") {
@@ -1753,4 +1753,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}

test("SPARK-35080: correlated equality predicates contain only outer references") {
withTempView("t") {
Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t")
checkAnswer(
sql("select c1, c2, (select count(*) from l where c1 = c2) from t"),
Row(0, 1, 0) :: Row(1, 1, 8) :: Nil)
}
}
}
