diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java index ce23010c8bd98..da7d60b15d4f2 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java @@ -84,6 +84,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; @@ -814,7 +815,7 @@ private Expression simplifyAndOr(BinaryPredicate bc) { } // - // common factor extraction -> (a || b) && (a || c) => a && (b || c) + // common factor extraction -> (a || b) && (a || c) => a || (b && c) // List leftSplit = splitOr(l); List rightSplit = splitOr(r); @@ -852,7 +853,7 @@ private Expression simplifyAndOr(BinaryPredicate bc) { } // - // common factor extraction -> (a && b) || (a && c) => a || (b & c) + // common factor extraction -> (a && b) || (a && c) => a && (b || c) // List leftSplit = splitAnd(l); List rightSplit = splitAnd(r); @@ -958,9 +959,11 @@ private Expression literalToTheRight(BinaryOperator be) { } /** - * Propagate Equals to eliminate conjuncted Ranges. - * When encountering a different Equals or non-containing {@link Range}, the conjunction becomes false. - * When encountering a containing {@link Range}, the range gets eliminated by the equality. + * Propagate Equals to eliminate conjuncted Ranges or BinaryComparisons. + * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false. + * When encountering a containing {@link Range}, {@link BinaryComparison} or {@link NotEquals}, these get eliminated by the equality. + * + * Since this rule can eliminate Ranges and BinaryComparisons, it should be applied before {@link CombineBinaryComparisons}. * * This rule doesn't perform any promotion of {@link BinaryComparison}s, that is handled by * {@link CombineBinaryComparisons} on purpose as the resulting Range might be foldable @@ -976,6 +979,8 @@ static class PropagateEquals extends OptimizerExpressionRule { protected Expression rule(Expression e) { if (e instanceof And) { return propagate((And) e); + } else if (e instanceof Or) { + return propagate((Or) e); } return e; } @@ -983,7 +988,11 @@ protected Expression rule(Expression e) { // combine conjunction private Expression propagate(And and) { List ranges = new ArrayList<>(); + // Only equalities, not-equalities and inequalities with a foldable .right are extracted separately; + // the others go into the general 'exps'. List equals = new ArrayList<>(); + List notEquals = new ArrayList<>(); + List inequalities = new ArrayList<>(); List exps = new ArrayList<>(); boolean changed = false; @@ -996,24 +1005,35 @@ private Expression propagate(And and) { // equals on different values evaluate to FALSE if (otherEq.right().foldable()) { for (BinaryComparison eq : equals) { - // cannot evaluate equals so skip it - if (!eq.right().foldable()) { - continue; - } if (otherEq.left().semanticEquals(eq.left())) { - if (eq.right().foldable() && otherEq.right().foldable()) { - Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold()); - if (comp != null) { - // var cannot be equal to two different values at the same time - if (comp != 0) { - return FALSE; - } + Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold()); + if (comp != null) { + // var cannot be equal to two different values at the same time + if (comp != 0) { + return FALSE; } } } } + equals.add(otherEq); + } else { + exps.add(otherEq); + } + } else if (ex instanceof GreaterThan || ex instanceof GreaterThanOrEqual || + ex instanceof LessThan || ex instanceof LessThanOrEqual) { + BinaryComparison bc = (BinaryComparison) ex; + if (bc.right().foldable()) { + inequalities.add(bc); + } else { + exps.add(ex); + } + } else if (ex instanceof NotEquals) { + NotEquals otherNotEq = (NotEquals) ex; + if (otherNotEq.right().foldable()) { + notEquals.add(otherNotEq); + } else { + exps.add(ex); } - equals.add(otherEq); } else { exps.add(ex); } @@ -1021,10 +1041,6 @@ private Expression propagate(And and) { // check for (BinaryComparison eq : equals) { - // cannot evaluate equals so skip it - if (!eq.right().foldable()) { - continue; - } Object eqValue = eq.right().fold(); for (int i = 0; i < ranges.size(); i++) { @@ -1060,9 +1076,192 @@ private Expression propagate(And and) { changed = true; } } + + // evaluate all NotEquals against the Equal + for (Iterator iter = notEquals.iterator(); iter.hasNext(); ) { + NotEquals neq = iter.next(); + if (eq.left().semanticEquals(neq.left())) { + Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); + if (comp != null) { + if (comp == 0) { + return FALSE; // clashing and conflicting: a = 1 AND a != 1 + } else { + iter.remove(); // clashing and redundant: a = 1 AND a != 2 + changed = true; + } + } + } + } + + // evaluate all inequalities against the Equal + for (Iterator iter = inequalities.iterator(); iter.hasNext(); ) { + BinaryComparison bc = iter.next(); + if (eq.left().semanticEquals(bc.left())) { + Integer compare = BinaryComparison.compare(eqValue, bc.right().fold()); + if (compare != null) { + if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { // a = 2 AND a />= ? + if ((compare == 0 && bc instanceof GreaterThan) || // a = 2 AND a > 2 + compare < 0) { // a = 2 AND a >/>= 3 + return FALSE; + } + } + + iter.remove(); + changed = true; + } + } + } + } + + return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : and; + } + + // combine disjunction: + // a = 2 OR a > 3 -> nop; a = 2 OR a > 1 -> a > 1 + // a = 2 OR a < 3 -> a < 3; a = 2 OR a < 1 -> nop + // a = 2 OR 3 < a < 5 -> nop; a = 2 OR 1 < a < 3 -> 1 < a < 3; a = 2 OR 0 < a < 1 -> nop + // a = 2 OR a != 2 -> TRUE; a = 2 OR a = 5 -> nop; a = 2 OR a != 5 -> a != 5 + private Expression propagate(Or or) { + List exps = new ArrayList<>(); + List equals = new ArrayList<>(); // foldable right term Equals + List notEquals = new ArrayList<>(); // foldable right term NotEquals + List ranges = new ArrayList<>(); + List inequalities = new ArrayList<>(); // foldable right term (=limit) BinaryComparision + + // split expressions by type + for (Expression ex : Predicates.splitOr(or)) { + if (ex instanceof Equals) { + Equals eq = (Equals) ex; + if (eq.right().foldable()) { + equals.add(eq); + } else { + exps.add(ex); + } + } else if (ex instanceof NotEquals) { + NotEquals neq = (NotEquals) ex; + if (neq.right().foldable()) { + notEquals.add(neq); + } else { + exps.add(ex); + } + } else if (ex instanceof Range) { + ranges.add((Range) ex); + } else if (ex instanceof BinaryComparison) { + BinaryComparison bc = (BinaryComparison) ex; + if (bc.right().foldable()) { + inequalities.add(bc); + } else { + exps.add(ex); + } + } else { + exps.add(ex); + } + } + + boolean updated = false; // has the expression been modified? + + // evaluate the impact of each Equal over the different types of Expressions + for (Iterator iterEq = equals.iterator(); iterEq.hasNext(); ) { + Equals eq = iterEq.next(); + Object eqValue = eq.right().fold(); + boolean removeEquals = false; + + // Equals OR NotEquals + for (NotEquals neq : notEquals) { + if (eq.left().semanticEquals(neq.left())) { // a = 2 OR a != ? -> ... + Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); + if (comp != null) { + if (comp == 0) { // a = 2 OR a != 2 -> TRUE + return TRUE; + } else { // a = 2 OR a != 5 -> a != 5 + removeEquals = true; + break; + } + } + } + } + if (removeEquals) { + iterEq.remove(); + updated = true; + continue; + } + + // Equals OR Range + for (int i = 0; i < ranges.size(); i ++) { // might modify list, so use index loop + Range range = ranges.get(i); + if (eq.left().semanticEquals(range.value())) { + Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold()) : null; + Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold()) : null; + + if (lowerComp != null && lowerComp == 0) { + if (!range.includeLower()) { // a = 2 OR 2 < a < ? -> 2 <= a < ? + ranges.set(i, new Range(range.source(), range.value(), range.lower(), true, + range.upper(), range.includeUpper())); + } // else : a = 2 OR 2 <= a < ? -> 2 <= a < ? + removeEquals = true; // update range with lower equality instead or simply superfluous + break; + } else if (upperComp != null && upperComp == 0) { + if (!range.includeUpper()) { // a = 2 OR ? < a < 2 -> ? < a <= 2 + ranges.set(i, new Range(range.source(), range.value(), range.lower(), range.includeLower(), + range.upper(), true)); + } // else : a = 2 OR ? < a <= 2 -> ? < a <= 2 + removeEquals = true; // update range with upper equality instead + break; + } else if (lowerComp != null && upperComp != null) { + if (0 < lowerComp && upperComp < 0) { // a = 2 OR 1 < a < 3 + removeEquals = true; // equality is superfluous + break; + } + } + } + } + if (removeEquals) { + iterEq.remove(); + updated = true; + continue; + } + + // Equals OR Inequality + for (int i = 0; i < inequalities.size(); i ++) { + BinaryComparison bc = inequalities.get(i); + if (eq.left().semanticEquals(bc.left())) { + Integer comp = BinaryComparison.compare(eqValue, bc.right().fold()); + if (comp != null) { + if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { + if (comp < 0) { // a = 1 OR a > 2 -> nop + continue; + } else if (comp == 0 && bc instanceof GreaterThan) { // a = 2 OR a > 2 -> a >= 2 + inequalities.set(i, new GreaterThanOrEqual(bc.source(), bc.left(), bc.right())); + } // else (0 < comp || bc instanceof GreaterThanOrEqual) : + // a = 3 OR a > 2 -> a > 2; a = 2 OR a => 2 -> a => 2 + + removeEquals = true; // update range with equality instead or simply superfluous + break; + } else if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { + if (comp > 0) { // a = 2 OR a < 1 -> nop + continue; + } + if (comp == 0 && bc instanceof LessThan) { // a = 2 OR a < 2 -> a <= 2 + inequalities.set(i, new LessThanOrEqual(bc.source(), bc.left(), bc.right())); + } // else (comp < 0 || bc instanceof LessThanOrEqual) : a = 2 OR a < 3 -> a < 3; a = 2 OR a <= 2 -> a <= 2 + removeEquals = true; // update range with equality instead or simply superfluous + break; + } + } + } + } + if (removeEquals) { + iterEq.remove(); + updated = true; + } } - return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, ranges)) : and; + return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or; } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 2b6378e2ee13c..0ee0fab7f8595 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.sql.expression.function.scalar.string.Repeat; import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator; +import org.elasticsearch.xpack.sql.expression.predicate.Predicates; import org.elasticsearch.xpack.sql.expression.predicate.Range; import org.elasticsearch.xpack.sql.expression.predicate.conditional.ArbitraryConditionalFunction; import org.elasticsearch.xpack.sql.expression.predicate.conditional.Case; @@ -1550,6 +1551,251 @@ public void testEliminateRangeByNullEqualsOutsideInterval() { assertEquals(FALSE, rule.rule(exp)); } + // a != 3 AND a = 3 -> FALSE + public void testPropagateEquals_VarNeq3AndVarEq3() { + FieldAttribute fa = getFieldAttribute(); + NotEquals neq = new NotEquals(EMPTY, fa, THREE); + Equals eq = new Equals(EMPTY, fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, neq, eq)); + assertEquals(FALSE, rule.rule(exp)); + } + + // a != 4 AND a = 3 -> a = 3 + public void testPropagateEquals_VarNeq4AndVarEq3() { + FieldAttribute fa = getFieldAttribute(); + NotEquals neq = new NotEquals(EMPTY, fa, FOUR); + Equals eq = new Equals(EMPTY, fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, neq, eq)); + assertEquals(Equals.class, exp.getClass()); + assertEquals(eq, rule.rule(exp)); + } + + // a = 2 AND a < 2 -> FALSE + public void testPropagateEquals_VarEq2AndVarLt2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + LessThan lt = new LessThan(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a <= 2 -> a = 2 + public void testPropagateEquals_VarEq2AndVarLte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + LessThanOrEqual lt = new LessThanOrEqual(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(eq, exp); + } + + // a = 2 AND a <= 1 -> FALSE + public void testPropagateEquals_VarEq2AndVarLte1() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + LessThanOrEqual lt = new LessThanOrEqual(EMPTY, fa, ONE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a > 2 -> FALSE + public void testPropagateEquals_VarEq2AndVarGt2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + GreaterThan gt = new GreaterThan(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a >= 2 -> a = 2 + public void testPropagateEquals_VarEq2AndVarGte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gte)); + assertEquals(eq, exp); + } + + // a = 2 AND a > 3 -> FALSE + public void testPropagateEquals_VarEq2AndVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + GreaterThan gt = new GreaterThan(EMPTY, fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a < 3 AND a > 1 AND a != 4 -> a = 2 + public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + LessThan lt = new LessThan(EMPTY, fa, THREE); + GreaterThan gt = new GreaterThan(EMPTY, fa, ONE); + NotEquals neq = new NotEquals(EMPTY, fa, FOUR); + + PropagateEquals rule = new PropagateEquals(); + Expression and = Predicates.combineAnd(Arrays.asList(eq, lt, gt, neq)); + Expression exp = rule.rule(and); + assertEquals(eq, exp); + } + + // a = 2 AND 1 < a < 3 AND a > 0 AND a != 4 -> a = 2 + public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + Range range = new Range(EMPTY, fa, ONE, false, THREE, false); + GreaterThan gt = new GreaterThan(EMPTY, fa, L(0)); + NotEquals neq = new NotEquals(EMPTY, fa, FOUR); + + PropagateEquals rule = new PropagateEquals(); + Expression and = Predicates.combineAnd(Arrays.asList(eq, range, gt, neq)); + Expression exp = rule.rule(and); + assertEquals(eq, exp); + } + + // a = 2 OR a > 1 -> a > 1 + public void testPropagateEquals_VarEq2OrVarGt1() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + GreaterThan gt = new GreaterThan(EMPTY, fa, ONE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + assertEquals(gt, exp); + } + + // a = 2 OR a > 2 -> a >= 2 + public void testPropagateEquals_VarEq2OrVarGte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + GreaterThan gt = new GreaterThan(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + assertEquals(GreaterThanOrEqual.class, exp.getClass()); + GreaterThanOrEqual gte = (GreaterThanOrEqual) exp; + assertEquals(TWO, gte.right()); + } + + // a = 2 OR a < 3 -> a < 3 + public void testPropagateEquals_VarEq2OrVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + LessThan lt = new LessThan(EMPTY, fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + assertEquals(lt, exp); + } + + // a = 3 OR a < 3 -> a <= 3 + public void testPropagateEquals_VarEq3OrVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, THREE); + LessThan lt = new LessThan(EMPTY, fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + assertEquals(LessThanOrEqual.class, exp.getClass()); + LessThanOrEqual lte = (LessThanOrEqual) exp; + assertEquals(THREE, lte.right()); + } + + // a = 2 OR 1 < a < 3 -> 1 < a < 3 + public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + Range range = new Range(EMPTY, fa, ONE, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(range, exp); + } + + // a = 2 OR 2 < a < 3 -> 2 <= a < 3 + public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + Range range = new Range(EMPTY, fa, TWO, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(Range.class, exp.getClass()); + Range r = (Range) exp; + assertEquals(TWO, r.lower()); + assertTrue(r.includeLower()); + assertEquals(THREE, r.upper()); + assertFalse(r.includeUpper()); + } + + // a = 3 OR 2 < a < 3 -> 2 < a <= 3 + public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, THREE); + Range range = new Range(EMPTY, fa, TWO, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(Range.class, exp.getClass()); + Range r = (Range) exp; + assertEquals(TWO, r.lower()); + assertFalse(r.includeLower()); + assertEquals(THREE, r.upper()); + assertTrue(r.includeUpper()); + } + + // a = 2 OR a != 2 -> TRUE + public void testPropagateEquals_VarEq2OrVarNeq2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + NotEquals neq = new NotEquals(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + assertEquals(TRUE, exp); + } + + // a = 2 OR a != 5 -> a != 5 + public void testPropagateEquals_VarEq2OrVarNeq5() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + NotEquals neq = new NotEquals(EMPTY, fa, FIVE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + assertEquals(NotEquals.class, exp.getClass()); + NotEquals ne = (NotEquals) exp; + assertEquals(ne.right(), FIVE); + } + + // a = 2 OR 3 < a < 4 OR a > 2 OR a!= 2 -> TRUE + public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = new Equals(EMPTY, fa, TWO); + Range range = new Range(EMPTY, fa, THREE, false, FOUR, false); + GreaterThan gt = new GreaterThan(EMPTY, fa, TWO); + NotEquals neq = new NotEquals(EMPTY, fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(Predicates.combineOr(Arrays.asList(eq, range, neq, gt))); + assertEquals(TRUE, exp); + } + public void testTranslateMinToFirst() { Min min1 = new Min(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", DataType.KEYWORD, emptyMap(), true))); Min min2 = new Min(EMPTY, getFieldAttribute());