Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Improve painless script generated from IN #35055

Merged
merged 10 commits into from Nov 1, 2018
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.Score;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
Expand Down
Expand Up @@ -88,4 +88,12 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.add(new Entry(Processor.class, SubstringFunctionProcessor.NAME, SubstringFunctionProcessor::new));
return entries;
}
}

public static List<Object> process(List<Processor> processors, Object input) {
List<Object> values = new ArrayList<>(processors.size());
for (Processor p : processors) {
values.add(p.process(input));
}
return values;
}
}
Expand Up @@ -27,10 +27,12 @@
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.BinaryArithmeticProcessor.BinaryArithmeticOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.UnaryArithmeticProcessor.UnaryArithmeticOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.regex.RegexProcessor.RegexOperation;
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -113,6 +115,10 @@ public static Boolean notNull(Object expression) {
return IsNotNullProcessor.apply(expression);
}

public static Boolean in(Object value, List<Object> values) {
return InProcessor.apply(value, values);
}

//
// Regex
//
Expand Down Expand Up @@ -375,4 +381,4 @@ public static String substring(String s, Number start, Number length) {
public static String ucase(String s) {
return (String) StringOperation.UCASE.apply(s);
}
}
}
Expand Up @@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left
.build(),
dataType);
}
}
}
Expand Up @@ -3,20 +3,17 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
Expand All @@ -30,7 +27,6 @@
import java.util.StringJoiner;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

public class In extends NamedExpression implements ScriptWeaver {
Expand Down Expand Up @@ -84,24 +80,12 @@ public boolean foldable() {

@Override
public Boolean fold() {
if (value.dataType() == DataType.NULL) {
// Optimization for early return and Query folding to LocalExec
if (value.dataType() == DataType.NULL ||
list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return null;
}
if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return false;
}

Object foldedLeftValue = value.fold();
Boolean result = false;
for (Expression rightValue : list) {
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
if (compResult == null) {
result = null;
} else if (compResult) {
return true;
}
}
return result;
return InProcessor.apply(value.fold(), Foldables.valuesOf(list, value.dataType()));
}

@Override
Expand All @@ -122,34 +106,18 @@ public Attribute toAttribute() {

@Override
public ScriptTemplate asScript() {
StringJoiner sj = new StringJoiner(" || ");
ScriptTemplate leftScript = asScript(value);
List<Params> rightParams = new ArrayList<>();
String scriptPrefix = leftScript + "==";
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
for (Object valueFromList : values) {
// if checked against null => false
if (valueFromList != null) {
if (valueFromList instanceof Expression) {
ScriptTemplate rightScript = asScript((Expression) valueFromList);
sj.add(scriptPrefix + rightScript.template());
rightParams.add(rightScript.params());
} else {
if (valueFromList instanceof String) {
sj.add(scriptPrefix + '"' + valueFromList + '"');
} else {
sj.add(scriptPrefix + valueFromList.toString());
}
}
}
}

ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
for (Params p : rightParams) {
paramsBuilder = paramsBuilder.script(p);
}

return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
// remove duplicates
List<Object> values = new ArrayList<>(new LinkedHashSet<>(Foldables.valuesOf(list, value.dataType())));
values.removeIf(Objects::isNull);

return new ScriptTemplate(
formatTemplate(String.format(Locale.ROOT, "{sql}.in(%s, {})", leftScript.template())),
paramsBuilder()
.script(leftScript.params())
.variable(values)
.build(),
dataType());
}

@Override
Expand Down
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.expression.function.scalar.Processors;
import org.elasticsearch.xpack.sql.expression.gen.processor.Processor;

import java.io.IOException;
Expand All @@ -19,7 +20,7 @@ public class InProcessor implements Processor {

private final List<Processor> processsors;

public InProcessor(List<Processor> processors) {
InProcessor(List<Processor> processors) {
this.processsors = processors;
}

Expand All @@ -40,14 +41,17 @@ public final void writeTo(StreamOutput out) throws IOException {
@Override
public Object process(Object input) {
Object leftValue = processsors.get(processsors.size() - 1).process(input);
Boolean result = false;
return apply(leftValue, Processors.process(processsors.subList(0, processsors.size() - 1), leftValue));
}

for (int i = 0; i < processsors.size() - 1; i++) {
Boolean compResult = Comparisons.eq(leftValue, processsors.get(i).process(input));
public static Boolean apply(Object input, List<Object> values) {
Boolean result = Boolean.FALSE;
for (Object v : values) {
Boolean compResult = Comparisons.eq(input, v);
if (compResult == null) {
result = null;
} else if (compResult) {
return true;
} else if (compResult == Boolean.TRUE) {
return Boolean.TRUE;
}
}
return result;
Expand Down
Expand Up @@ -39,7 +39,6 @@
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator.Negateable;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Predicates;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
Expand All @@ -50,6 +49,7 @@
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
Expand Down
Expand Up @@ -24,7 +24,7 @@
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
Expand Down
Expand Up @@ -30,7 +30,7 @@
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.UnaryPipe;
import org.elasticsearch.xpack.sql.expression.gen.processor.Processor;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.sql.plan.physical.FilterExec;
Expand Down
Expand Up @@ -31,7 +31,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
Expand Down
Expand Up @@ -9,7 +9,7 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;

import java.util.Collections;
import java.util.LinkedHashSet;
Expand All @@ -27,7 +27,7 @@ public class TermsQuery extends LeafQuery {
public TermsQuery(Location location, String term, List<Expression> values) {
super(location);
this.term = term;
values.removeIf(e -> e.dataType() == DataType.NULL);
values.removeIf(e -> DataTypes.isNull(e.dataType()));
if (values.isEmpty()) {
this.values = Collections.emptySet();
} else {
Expand Down
Expand Up @@ -105,6 +105,9 @@ public static Conversion conversionFor(DataType from, DataType to) {
if (to == DataType.NULL) {
return Conversion.NULL;
}
if (from == DataType.NULL) {
return Conversion.NULL;
}

Conversion conversion = conversion(from, to);
if (conversion == null) {
Expand Down
Expand Up @@ -24,6 +24,7 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
Boolean lte(Object, Object)
Boolean gt(Object, Object)
Boolean gte(Object, Object)
Boolean in(Object, java.util.List)

#
# Logical
Expand Down Expand Up @@ -107,4 +108,4 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
String space(Number)
String substring(String, Number, Number)
String ucase(String)
}
}
Expand Up @@ -3,15 +3,14 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.sql.expression.Literal;
import org.elasticsearch.xpack.sql.expression.function.scalar.Processors;
import org.elasticsearch.xpack.sql.expression.gen.processor.ConstantProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InProcessor;

import java.util.Arrays;

Expand Down
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.expression.Literal;
Expand Down
Expand Up @@ -35,7 +35,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.string.Ascii;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.Repeat;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
Expand Down Expand Up @@ -345,7 +345,7 @@ public void testConstantFoldingIn_LeftValueNotFoldable() {
public void testConstantFoldingIn_RightValueIsNull() {
In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL));
Literal result= (Literal) new ConstantFolding().rule(in);
assertEquals(false, result.value());
assertNull(result.value());
}

public void testConstantFoldingIn_LeftValueIsNull() {
Expand Down
Expand Up @@ -33,6 +33,7 @@
import java.util.Map;
import java.util.TimeZone;

import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.core.StringStartsWith.startsWith;

public class QueryTranslatorTests extends ESTestCase {
Expand Down Expand Up @@ -208,12 +209,11 @@ public void testTranslateInExpression_WhereClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertNull(translation.aggFilter);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(" +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==10 || " +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==20)",
sq.script().toString());
assertEquals("[{v=int}, {v=2}]", sq.script().params().toString());
ScriptQuery sc = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(" +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1), params.v2))",
sc.script().toString());
assertEquals("[{v=int}, {v=2}, {v=[10.0, 20.0]}]", sc.script().params().toString());
}

public void testTranslateInExpression_HavingClause_Painless() {
Expand All @@ -225,9 +225,10 @@ public void testTranslateInExpression_HavingClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20]}]"));
}

public void testTranslateInExpression_HavingClause_PainlessOneArg() {
Expand All @@ -239,9 +240,10 @@ public void testTranslateInExpression_HavingClause_PainlessOneArg() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10]}]"));

}

Expand All @@ -254,8 +256,9 @@ public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20 || params.a0==30)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20, 30]}]"));
}
}