From 7f5e36abf5683050fb389089019c86c27aaab88a Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 24 Sep 2024 12:54:40 -0400 Subject: [PATCH] ESQL: Speed up CASE for some parameters (#112295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This speeds up the `CASE` function when it has two or three arguments and both of the arguments are constants or fields. This works because `CASE` is lazy so it can avoid warnings in cases like ``` CASE(foo != 0, 2 / foo, 1) ``` And, in the case where the function is *very* slow, it can avoid the computations. But if the lhs and rhs of the `CASE` are constant then there isn't any work to avoid. The performance improvment is pretty substantial: ``` (operation) Before Error After Error Units case_1_lazy 97.422 ± 1.048 101.571 ± 0.737 ns/op case_1_eager 79.312 ± 1.190 4.601 ± 0.049 ns/op ``` The top line is a `CASE` that has to be lazy - it shouldn't change. The 4 nanos change here is noise. The eager version improves by about 94%. --- .../compute/operator/EvalBenchmark.java | 48 ++++++ docs/changelog/112295.yaml | 5 + .../compute/operator/EvalOperator.java | 11 ++ .../xpack/esql/evaluator/EvalMapper.java | 10 ++ .../function/scalar/conditional/Case.java | 138 +++++++++++++++--- .../scalar/conditional/CaseTests.java | 37 ++--- 6 files changed, 210 insertions(+), 39 deletions(-) create mode 100644 docs/changelog/112295.yaml diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java index d785cbeaffc60..9aab4a3e3210f 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java @@ -25,6 +25,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; @@ -32,6 +33,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; @@ -53,6 +55,7 @@ import java.time.Duration; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -91,6 +94,8 @@ public class EvalBenchmark { "abs", "add", "add_double", + "case_1_eager", + "case_1_lazy", "date_trunc", "equal_to_const", "long_equal_to_long", @@ -125,6 +130,18 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { layout(doubleField) ).get(driverContext); } + case "case_1_eager", "case_1_lazy" -> { + FieldAttribute f1 = longField(); + FieldAttribute f2 = longField(); + Expression condition = new Equals(Source.EMPTY, f1, new Literal(Source.EMPTY, 1L, DataType.LONG)); + Expression lhs = f1; + Expression rhs = f2; + if (operation.endsWith("lazy")) { + lhs = new Add(Source.EMPTY, lhs, new Literal(Source.EMPTY, 1L, DataType.LONG)); + rhs = new Add(Source.EMPTY, rhs, new Literal(Source.EMPTY, 1L, DataType.LONG)); + } + yield EvalMapper.toEvaluator(new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)).get(driverContext); + } case "date_trunc" -> { FieldAttribute timestamp = new FieldAttribute( Source.EMPTY, @@ -216,6 +233,28 @@ private static void checkExpected(String operation, Page actual) { } } } + case "case_1_eager" -> { + LongVector f1 = actual.getBlock(0).asVector(); + LongVector f2 = actual.getBlock(1).asVector(); + LongVector result = actual.getBlock(2).asVector(); + for (int i = 0; i < BLOCK_LENGTH; i++) { + long expected = f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i); + if (result.getLong(i) != expected) { + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); + } + } + } + case "case_1_lazy" -> { + LongVector f1 = actual.getBlock(0).asVector(); + LongVector f2 = actual.getBlock(1).asVector(); + LongVector result = actual.getBlock(2).asVector(); + for (int i = 0; i < BLOCK_LENGTH; i++) { + long expected = 1 + (f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i)); + if (result.getLong(i) != expected) { + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); + } + } + } case "date_trunc" -> { LongVector v = actual.getBlock(1).asVector(); long oneDay = TimeValue.timeValueHours(24).millis(); @@ -280,6 +319,15 @@ private static Page page(String operation) { } yield new Page(builder.build()); } + case "case_1_eager", "case_1_lazy" -> { + var f1 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); + var f2 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); + for (int i = 0; i < BLOCK_LENGTH; i++) { + f1.appendLong(i); + f2.appendLong(-i); + } + yield new Page(f1.build(), f2.build()); + } case "long_equal_to_long" -> { var lhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); var rhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); diff --git a/docs/changelog/112295.yaml b/docs/changelog/112295.yaml new file mode 100644 index 0000000000000..ecbd365d03918 --- /dev/null +++ b/docs/changelog/112295.yaml @@ -0,0 +1,5 @@ +pr: 112295 +summary: "ESQL: Speed up CASE for some parameters" +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java index 10f23ed29094f..349ce7b00ff10 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java @@ -63,6 +63,17 @@ public interface ExpressionEvaluator extends Releasable { /** A Factory for creating ExpressionEvaluators. */ interface Factory { ExpressionEvaluator get(DriverContext context); + + /** + * {@code true} if it is safe and fast to evaluate this expression eagerly + * in {@link ExpressionEvaluator}s that need to be lazy, like {@code CASE}. + * This defaults to {@code false}, but expressions + * that evaluate quickly and can not produce warnings may override this to + * {@code true} to get a significant speed-up in {@code CASE}-like operations. + */ + default boolean eagerEvalSafeInLazy() { + return false; + } } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java index d36ab3e18f336..9a2e9398f52fd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java @@ -176,6 +176,11 @@ public ExpressionEvaluator get(DriverContext driverContext) { public String toString() { return "Attribute[channel=" + channel + "]"; } + + @Override + public boolean eagerEvalSafeInLazy() { + return true; + } } return new AttributeFactory(layout.get(attr.id()).channel()); } @@ -209,6 +214,11 @@ public ExpressionEvaluator get(DriverContext driverContext) { public String toString() { return "LiteralsEvaluator[lit=" + lit + "]"; } + + @Override + public boolean eagerEvalSafeInLazy() { + return true; + } } return new LiteralsEvaluatorFactory(lit); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java index 979f681a7fbd0..6acb8ea974ed0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java @@ -15,6 +15,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.ToMask; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; @@ -311,25 +312,16 @@ private Expression finishPartialFold(List newChildren) { @Override public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { - ElementType resultType = PlannerUtils.toElementType(dataType()); List conditionsFactories = conditions.stream().map(c -> c.toEvaluator(toEvaluator)).toList(); ExpressionEvaluator.Factory elseValueFactory = toEvaluator.apply(elseValue); - return new ExpressionEvaluator.Factory() { - @Override - public ExpressionEvaluator get(DriverContext context) { - return new CaseEvaluator( - context.blockFactory(), - resultType, - conditionsFactories.stream().map(x -> x.apply(context)).toList(), - elseValueFactory.get(context) - ); - } + ElementType resultType = PlannerUtils.toElementType(dataType()); - @Override - public String toString() { - return "CaseEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']'; - } - }; + if (conditionsFactories.size() == 1 + && conditionsFactories.get(0).value.eagerEvalSafeInLazy() + && elseValueFactory.eagerEvalSafeInLazy()) { + return new CaseEagerEvaluatorFactory(resultType, conditionsFactories.get(0), elseValueFactory); + } + return new CaseLazyEvaluatorFactory(resultType, conditionsFactories, elseValueFactory); } record ConditionEvaluatorSupplier(Source conditionSource, ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value) @@ -375,9 +367,42 @@ public void close() { public String toString() { return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']'; } + + public void registerMultivalue() { + conditionWarnings.registerException(new IllegalArgumentException("CASE expects a single-valued boolean")); + } } - private record CaseEvaluator( + private record CaseLazyEvaluatorFactory( + ElementType resultType, + List conditionsFactories, + ExpressionEvaluator.Factory elseValueFactory + ) implements ExpressionEvaluator.Factory { + @Override + public ExpressionEvaluator get(DriverContext context) { + List conditions = new ArrayList<>(conditionsFactories.size()); + ExpressionEvaluator elseValue = null; + try { + for (ConditionEvaluatorSupplier cond : conditionsFactories) { + conditions.add(cond.apply(context)); + } + elseValue = elseValueFactory.get(context); + ExpressionEvaluator result = new CaseLazyEvaluator(context.blockFactory(), resultType, conditions, elseValue); + conditions = null; + elseValue = null; + return result; + } finally { + Releasables.close(conditions == null ? () -> {} : Releasables.wrap(conditions), elseValue); + } + } + + @Override + public String toString() { + return "CaseLazyEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']'; + } + } + + private record CaseLazyEvaluator( BlockFactory blockFactory, ElementType resultType, List conditions, @@ -409,9 +434,7 @@ public Block eval(Page page) { continue; } if (b.getValueCount(0) > 1) { - condition.conditionWarnings.registerException( - new IllegalArgumentException("CASE expects a single-valued boolean") - ); + condition.registerMultivalue(); continue; } if (false == b.getBoolean(b.getFirstValueIndex(0))) { @@ -439,7 +462,80 @@ public void close() { @Override public String toString() { - return "CaseEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']'; + return "CaseLazyEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']'; + } + } + + private record CaseEagerEvaluatorFactory( + ElementType resultType, + ConditionEvaluatorSupplier conditionFactory, + ExpressionEvaluator.Factory elseValueFactory + ) implements ExpressionEvaluator.Factory { + @Override + public ExpressionEvaluator get(DriverContext context) { + ConditionEvaluator conditionEvaluator = conditionFactory.apply(context); + ExpressionEvaluator elseValue = null; + try { + elseValue = elseValueFactory.get(context); + ExpressionEvaluator result = new CaseEagerEvaluator(resultType, context.blockFactory(), conditionEvaluator, elseValue); + conditionEvaluator = null; + elseValue = null; + return result; + } finally { + Releasables.close(conditionEvaluator, elseValue); + } + } + + @Override + public String toString() { + return "CaseEagerEvaluator[conditions=[" + conditionFactory + "], elseVal=" + elseValueFactory + ']'; + } + } + + private record CaseEagerEvaluator( + ElementType resultType, + BlockFactory blockFactory, + ConditionEvaluator condition, + EvalOperator.ExpressionEvaluator elseVal + ) implements EvalOperator.ExpressionEvaluator { + @Override + public Block eval(Page page) { + try (BooleanBlock lhsOrRhsBlock = (BooleanBlock) condition.condition.eval(page); ToMask lhsOrRhs = lhsOrRhsBlock.toMask()) { + if (lhsOrRhs.hadMultivaluedFields()) { + condition.registerMultivalue(); + } + if (lhsOrRhs.mask().isConstant()) { + if (lhsOrRhs.mask().getBoolean(0)) { + return condition.value.eval(page); + } else { + return elseVal.eval(page); + } + } + try ( + Block lhs = condition.value.eval(page); + Block rhs = elseVal.eval(page); + Block.Builder builder = resultType.newBlockBuilder(lhs.getTotalValueCount(), blockFactory) + ) { + for (int p = 0; p < lhs.getPositionCount(); p++) { + if (lhsOrRhs.mask().getBoolean(p)) { + builder.copyFrom(lhs, p, p + 1); + } else { + builder.copyFrom(rhs, p, p + 1); + } + } + return builder.build(); + } + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(condition, elseVal); + } + + @Override + public String toString() { + return "CaseEagerEvaluator[conditions=[" + condition + "], elseVal=" + elseVal + ']'; } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java index 616e70191ee22..9d0d9c3da30a8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java @@ -123,9 +123,7 @@ public static Iterable parameters() { ) ); } - return - - parameterSuppliersFromTypedData(suppliers); + return parameterSuppliersFromTypedData(suppliers); } private static void twoAndThreeArgs( @@ -191,7 +189,7 @@ private static void twoAndThreeArgs( type, typedData, lhs, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), false, List.of(typedData.get(1)), addBuildEvaluatorWarnings(warnings) @@ -240,7 +238,7 @@ private static void twoAndThreeArgs( type, typedData, lhsOrRhs ? lhs : rhs, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), false, List.of(typedData.get(lhsOrRhs ? 1 : 2)), addWarnings(warnings) @@ -262,7 +260,7 @@ private static void twoAndThreeArgs( type, typedData, lhsOrRhs ? lhs : null, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition="), false, null, addWarnings(warnings) @@ -285,7 +283,7 @@ private static void twoAndThreeArgs( type, typedData, lhsOrRhs ? lhs : rhs, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition="), false, null, addWarnings(warnings) @@ -296,7 +294,7 @@ private static void twoAndThreeArgs( } suppliers.add( new TestCaseSupplier( - TestCaseSupplier.nameFrom(Arrays.asList(DataType.BOOLEAN, DataType.NULL, type)), + TestCaseSupplier.nameFrom(Arrays.asList(cond, DataType.NULL, type)), List.of(DataType.BOOLEAN, DataType.NULL, type), () -> { Object rhs = randomLiteral(type).value(); @@ -309,7 +307,7 @@ private static void twoAndThreeArgs( type, typedData, lhsOrRhs ? null : rhs, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition="), false, null, addWarnings(warnings) @@ -319,7 +317,7 @@ private static void twoAndThreeArgs( ); suppliers.add( new TestCaseSupplier( - TestCaseSupplier.nameFrom(Arrays.asList(DataType.BOOLEAN, type, DataType.NULL)), + TestCaseSupplier.nameFrom(Arrays.asList(cond, type, DataType.NULL)), List.of(DataType.BOOLEAN, type, DataType.NULL), () -> { Object lhs = randomLiteral(type).value(); @@ -332,7 +330,7 @@ private static void twoAndThreeArgs( type, typedData, lhsOrRhs ? lhs : null, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + startsWith("CaseEagerEvaluator[conditions=[ConditionEvaluator[condition="), false, null, addWarnings(warnings) @@ -445,7 +443,7 @@ private static void fourAndFiveArgs( type, typedData, r1, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, List.of(typedData.get(1)), addBuildEvaluatorWarnings(warnings) @@ -501,7 +499,7 @@ private static void fourAndFiveArgs( type, typedData, r2, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, List.of(typedData.get(3)), addWarnings(warnings) @@ -526,7 +524,7 @@ private static void fourAndFiveArgs( type, typedData, r2, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, List.of(typedData.get(3)), addWarnings(warnings) @@ -551,7 +549,7 @@ private static void fourAndFiveArgs( type, typedData, r2, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, typedData.subList(2, 4), addWarnings(warnings) @@ -607,7 +605,7 @@ private static void fourAndFiveArgs( type, typedData, r3, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, List.of(typedData.get(4)), addWarnings(warnings) @@ -634,7 +632,7 @@ private static void fourAndFiveArgs( type, typedData, r3, - startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + startsWith("CaseLazyEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), false, typedData.subList(2, 5), addWarnings(warnings) @@ -648,7 +646,10 @@ private static void fourAndFiveArgs( } private static Matcher toStringMatcher(int conditions, boolean trailingNull) { - StringBuilder result = new StringBuilder("CaseEvaluator[conditions=["); + StringBuilder result = new StringBuilder(); + result.append("Case"); + result.append(conditions == 1 ? "Eager" : "Lazy"); + result.append("Evaluator[conditions=["); int channel = 0; for (int i = 0; i < conditions; i++) { if (i != 0) {