diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc79..ae28292d941bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Transform Condition", FixedPoint(100), + TransformCondition) :: Batch("Operator Reordering", FixedPoint(100), UnionPushdown, CombineFilters, @@ -60,6 +62,80 @@ object DefaultOptimizer extends Optimizer { ConvertToLocalRelation) :: Nil } +/** + * Transform and/or Condition: + * 1. a && a => a + * 2. (a || b) && (a || c) => a || (b && c) + * 3. a || a => a + * 4. (a && b) || (a && c) => a && (b || c) + */ +object TransformCondition extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case and @ And(left, right) => (left, right) match { + + // a && a => a + case (l, r) if l fastEquals r => l + // (a || b) && (a || c) => a || (b && c) + case _ => + // 1. Split left and right to get the disjunctive predicates, + // i.e. lhsSet = (a, b), rhsSet = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) + val lhsSet = splitDisjunctivePredicates(left).toSet + val rhsSet = splitDisjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + if (common.isEmpty) { + // No common factors, return the original predicate + and + } else { + val ldiff = lhsSet.diff(common) + val rdiff = rhsSet.diff(common) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a || b || c || ...) && (a || b) => (a || b) + common.reduce(Or) + } else { + // (a || b || c || ...) && (a || b || d || ...) => + // ((c || ...) && (d || ...)) || a || b + (common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) + } + } + } // end of And(left, right) + + case or @ Or(left, right) => (left, right) match { + + case (l, r) if l fastEquals r => l + // (a && b) || (a && c) => a && (b || c) + case _ => + // 1. Split left and right to get the conjunctive predicates, + // i.e. lhsSet = (a, b), rhsSet = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + if (common.isEmpty) { + // No common factors, return the original predicate + or + } else { + val ldiff = lhsSet.diff(common) + val rdiff = rhsSet.diff(common) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a && b) || (a && b && c && ...) => a && b + common.reduce(And) + } else { + // (a && b && c && ...) || (a && b && d && ...) => + // ((c && ...) || (d && ...)) && a && b + (common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) + } + } + } // end of Or(left, right) + } + } +} + /** * Pushes operations to either side of a Union. */ @@ -347,32 +423,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // l && false => false case (_, Literal(false, BooleanType)) => Literal(false) // a && a => a - case (l, r) if l fastEquals r => l - // (a || b) && (a || c) => a || (b && c) - case _ => - // 1. Split left and right to get the disjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) - val lhsSet = splitDisjunctivePredicates(left).toSet - val rhsSet = splitDisjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) - if (common.isEmpty) { - // No common factors, return the original predicate - and - } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a || b || c || ...) && (a || b) => (a || b) - common.reduce(Or) - } else { - // (a || b || c || ...) && (a || b || d || ...) => - // ((c || ...) && (d || ...)) || a || b - (common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) - } - } + case _ => and } // end of And(left, right) case or @ Or(left, right) => (left, right) match { @@ -384,33 +435,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case (Literal(false, BooleanType), r) => r // l || false => l case (l, Literal(false, BooleanType)) => l - // a || a => a - case (l, r) if l fastEquals r => l - // (a && b) || (a && c) => a && (b || c) - case _ => - // 1. Split left and right to get the conjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) - val lhsSet = splitConjunctivePredicates(left).toSet - val rhsSet = splitConjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) - if (common.isEmpty) { - // No common factors, return the original predicate - or - } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a && b) || (a && b && c && ...) => a && b - common.reduce(And) - } else { - // (a && b && c && ...) || (a && b && d && ...) => - // ((c && ...) || (d && ...)) && a && b - (common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) - } - } + case _ => or } // end of Or(left, right) case not @ Not(exp) => exp match {