Skip to content

Commit

Permalink
[SPARK-35080][SQL] Only allow a subset of correlated equality predica…
Browse files Browse the repository at this point in the history
…tes when a subquery is aggregated

### What changes were proposed in this pull request?
This PR updated the `foundNonEqualCorrelatedPred` logic for correlated subqueries in `CheckAnalysis` to only allow correlated equality predicates that guarantee one-to-one mapping between inner and outer attributes, instead of all equality predicates.

### Why are the changes needed?
To fix correctness bugs. Before this fix Spark can give wrong results for certain correlated subqueries that pass CheckAnalysis:
Example 1:
```sql
create or replace view t1(c) as values ('a'), ('b')
create or replace view t2(c) as values ('ab'), ('abc'), ('bc')

select c, (select count(*) from t2 where t1.c = substring(t2.c, 1, 1)) from t1
```
Correct results: [(a, 2), (b, 1)]
Spark results:
```
+---+-----------------+
|c  |scalarsubquery(c)|
+---+-----------------+
|a  |1                |
|a  |1                |
|b  |1                |
+---+-----------------+
```
Example 2:
```sql
create or replace view t1(a, b) as values (0, 6), (1, 5), (2, 4), (3, 3);
create or replace view t2(c) as values (6);

select c, (select count(*) from t1 where a + b = c) from t2;
```
Correct results: [(6, 4)]
Spark results:
```
+---+-----------------+
|c  |scalarsubquery(c)|
+---+-----------------+
|6  |1                |
|6  |1                |
|6  |1                |
|6  |1                |
+---+-----------------+
```
### Does this PR introduce _any_ user-facing change?
Yes. Users will not be able to run queries that contain unsupported correlated equality predicates.

### How was this patch tested?
Added unit tests.

Closes apache#32179 from allisonwang-db/spark-35080-subquery-bug.

Lead-authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
allisonwang-db and cloud-fan committed Apr 20, 2021
1 parent e55ff83 commit bad4b6f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -893,14 +893,72 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// +- SubqueryAlias t1, `t1`
// +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
// +- LocalRelation [_1#73, _2#74]
def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
if (found) {
// SPARK-35080: The same issue can happen to correlated equality predicates when
// they do not guarantee one-to-one mapping between inner and outer attributes.
// For example:
// Table:
// t1(a, b): [(0, 6), (1, 5), (2, 4)]
// t2(c): [(6)]
//
// Query:
// SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
//
// Original subquery plan:
// Aggregate [count(1)]
// +- Filter ((a + b) = outer(c))
// +- LocalRelation [a, b]
//
// Plan after pulling up correlated predicates:
// Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// Plan after rewrite:
// Project [c1, count(1)]
// +- Join LeftOuter ((a + b) = c)
// :- LocalRelation [c]
// +- Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// The right hand side of the join transformed from the subquery will output
// count(1) | a | b
// 1 | 0 | 6
// 1 | 1 | 5
// 1 | 2 | 4
// and the plan after rewrite will give the original query incorrect results.
def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
if (predicates.nonEmpty) {
// Report a non-supported case as an exception
failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
failAnalysis("Correlated column is not allowed in predicate " +
s"${predicates.map(_.sql).mkString}:\n$p")
}
}

var foundNonEqualCorrelatedPred: Boolean = false
def containsAttribute(e: Expression): Boolean = {
e.find(_.isInstanceOf[Attribute]).isDefined
}

// Given a correlated predicate, check if it is either a non-equality predicate or
// equality predicate that does not guarantee one-on-one mapping between inner and
// outer attributes. When the correlated predicate does not contain any attribute
// (i.e. only has outer references), it is supported and should return false. E.G.:
// (a = outer(c)) -> false
// (outer(c) = outer(d)) -> false
// (a > outer(c)) -> true
// (a + b = outer(c)) -> true
// The last one is true because there can be multiple combinations of (a, b) that
// satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
// and (-1, 1) can make the predicate evaluate to true.
def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
// Only allow equality condition with one side being an attribute and another
// side being an expression without attributes from the inner query. Note
// OuterReference is a leaf node and will not be found here.
case Equality(_: Attribute, b) => containsAttribute(b)
case Equality(a, _: Attribute) => containsAttribute(a)
case e @ Equality(_, _) => containsAttribute(e)
case _ => true
}

val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]

// Simplify the predicates before validating any unsupported correlation patterns in the plan.
AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
Expand Down Expand Up @@ -943,22 +1001,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// The other operator is Join. Filter can be anywhere in a correlated subquery.
case f: Filter =>
val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)

// Find any non-equality correlated predicates
foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
failOnInvalidOuterReference(f)

// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only equality correlated predicates.
// only supported correlated equality predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
case a: Aggregate =>
failOnInvalidOuterReference(a)
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)

// Join can host correlated expressions.
case j @ Join(left, right, joinType, _, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,4 +767,28 @@ class AnalysisErrorSuite extends AnalysisTest {
"using ordinal position or wrap it in first() (or first_value) if you don't care " +
"which value you get." :: Nil)
}

test("SPARK-35080: Unsupported correlated equality predicates in subquery") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", IntegerType)()
val t1 = LocalRelation(a, b)
val t2 = LocalRelation(c)
val conditions = Seq(
(abs($"a") === $"c", "abs(a) = outer(c)"),
(abs($"a") <=> $"c", "abs(a) <=> outer(c)"),
($"a" + 1 === $"c", "(a + 1) = outer(c)"),
($"a" + $"b" === $"c", "(a + b) = outer(c)"),
($"a" + $"c" === $"b", "(a + outer(c)) = b"),
(And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(a AS INT) = outer(c)"))
conditions.foreach { case (cond, msg) =>
val plan = Project(
ScalarSubquery(
Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
Filter(cond, t1))
).as("sub") :: Nil,
t2)
assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v))
FROM t2
WHERE t2.k = t1.k)
-- !query schema
struct<k:string>
struct<>
-- !query output
two
org.apache.spark.sql.AnalysisException
Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)):
Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS udf(max(udf(v)))#x]
+- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))
+- SubqueryAlias t2
+- View (`t2`, [k#x,v#x])
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
+- Project [k#x, v#x]
+- SubqueryAlias t2
+- LocalRelation [k#x, v#x]
11 changes: 10 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
"Correlated column is not allowed in a non-equality predicate:"))
"Correlated column is not allowed in predicate (l2.a < outer(l1.a))"))
}

test("disjunctive correlated scalar subquery") {
Expand Down Expand Up @@ -1827,4 +1827,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Row(0, 1, 1) :: Row(1, 2, null) :: Nil)
}
}

test("SPARK-35080: correlated equality predicates contain only outer references") {
withTempView("t") {
Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t")
checkAnswer(
sql("select c1, c2, (select count(*) from l where c1 = c2) from t"),
Row(0, 1, 0) :: Row(1, 1, 8) :: Nil)
}
}
}

0 comments on commit bad4b6f

Please sign in to comment.