Skip to content

Commit

Permalink
SQL: "Polish" painless script generated from IN
Browse files Browse the repository at this point in the history
Replace standard `||` and `==` painless operators with
`or` and `eq` null safe alternatives from `InternalSqlScriptUtils`.

Follow up to #34750
  • Loading branch information
matriv committed Oct 30, 2018
1 parent 7ef65de commit 782b6fe
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 48 deletions.
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.string.ReplaceFunctionProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.StringProcessor.StringOperation;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFunctionProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNullProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.logical.BinaryLogicProcessor.BinaryLogicOperation;
import org.elasticsearch.xpack.sql.expression.predicate.logical.NotProcessor;
Expand All @@ -31,6 +32,7 @@
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -113,6 +115,10 @@ public static Boolean notNull(Object expression) {
return IsNotNullProcessor.apply(expression);
}

public static Boolean in(Object value, List<Object> values) {
return In.doFold(value, values);
}

//
// Regex
//
Expand Down Expand Up @@ -375,4 +381,4 @@ public static String substring(String s, Number start, Number length) {
public static String ucase(String s) {
return (String) StringOperation.UCASE.apply(s);
}
}
}
Expand Up @@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left
.build(),
dataType);
}
}
}
Expand Up @@ -8,11 +8,10 @@
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
Expand All @@ -30,7 +29,6 @@
import java.util.StringJoiner;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

public class In extends NamedExpression implements ScriptWeaver {
Expand Down Expand Up @@ -84,17 +82,21 @@ public boolean foldable() {

@Override
public Boolean fold() {
// Optimization for early return and Query folding to LocalExec
if (value.dataType() == DataType.NULL) {
return null;
}
if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return false;
return null;
}

Object foldedLeftValue = value.fold();
return doFold(value.fold(), Foldables.valuesOf(list, value.dataType()));
}

public static Boolean doFold(Object value, List<Object> values) {
Boolean result = false;
for (Expression rightValue : list) {
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
for (Object v : values) {
Boolean compResult = Comparisons.eq(value, v);
if (compResult == null) {
result = null;
} else if (compResult) {
Expand Down Expand Up @@ -122,34 +124,18 @@ public Attribute toAttribute() {

@Override
public ScriptTemplate asScript() {
StringJoiner sj = new StringJoiner(" || ");
ScriptTemplate leftScript = asScript(value);
List<Params> rightParams = new ArrayList<>();
String scriptPrefix = leftScript + "==";
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
for (Object valueFromList : values) {
// if checked against null => false
if (valueFromList != null) {
if (valueFromList instanceof Expression) {
ScriptTemplate rightScript = asScript((Expression) valueFromList);
sj.add(scriptPrefix + rightScript.template());
rightParams.add(rightScript.params());
} else {
if (valueFromList instanceof String) {
sj.add(scriptPrefix + '"' + valueFromList + '"');
} else {
sj.add(scriptPrefix + valueFromList.toString());
}
}
}
}

ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
for (Params p : rightParams) {
paramsBuilder = paramsBuilder.script(p);
}

return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
// remove duplicates
// TODO: Don't exclude nulls, painless script should handle them
List<Object> values = new ArrayList<>(
list.stream().map(Expression::fold).filter(Objects::nonNull).collect(Collectors.toCollection(LinkedHashSet::new)));

return new ScriptTemplate(String.format(Locale.ROOT, formatTemplate("{sql}.in(%s, {})"), leftScript.template(), "%s"),
paramsBuilder()
.script(leftScript.params())
.variable(values)
.build(),
dataType());
}

@Override
Expand Down
Expand Up @@ -682,8 +682,7 @@ protected LogicalPlan rule(Filter filter) {
if (TRUE.equals(filter.condition())) {
return filter.child();
}
// TODO: add comparison with null as well
if (FALSE.equals(filter.condition())) {
if (FALSE.equals(filter.condition()) || ((Literal) filter.condition()).value() == null) {
return new LocalRelation(filter.location(), new EmptyExecutable(filter.output()));
}
}
Expand Down
Expand Up @@ -9,7 +9,7 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;

import java.util.Collections;
import java.util.LinkedHashSet;
Expand All @@ -27,7 +27,7 @@ public class TermsQuery extends LeafQuery {
public TermsQuery(Location location, String term, List<Expression> values) {
super(location);
this.term = term;
values.removeIf(e -> e.dataType() == DataType.NULL);
values.removeIf(e -> DataTypes.isNull(e.dataType()));
if (values.isEmpty()) {
this.values = Collections.emptySet();
} else {
Expand Down
Expand Up @@ -105,6 +105,9 @@ public static Conversion conversionFor(DataType from, DataType to) {
if (to == DataType.NULL) {
return Conversion.NULL;
}
if (from == DataType.NULL) {
return Conversion.NULL;
}

Conversion conversion = conversion(from, to);
if (conversion == null) {
Expand Down
Expand Up @@ -24,6 +24,7 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
Boolean lte(Object, Object)
Boolean gt(Object, Object)
Boolean gte(Object, Object)
Boolean in(Object, java.util.List)

#
# Logical
Expand Down Expand Up @@ -107,4 +108,4 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
String space(Number)
String substring(String, Number, Number)
String ucase(String)
}
}
Expand Up @@ -340,7 +340,7 @@ public void testConstantFoldingIn_LeftValueNotFoldable() {
public void testConstantFoldingIn_RightValueIsNull() {
In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL));
Literal result= (Literal) new ConstantFolding().rule(in);
assertEquals(false, result.value());
assertNull(result.value());
}

public void testConstantFoldingIn_LeftValueIsNull() {
Expand Down
Expand Up @@ -33,6 +33,7 @@
import java.util.Map;
import java.util.TimeZone;

import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.core.StringStartsWith.startsWith;

public class QueryTranslatorTests extends AbstractBuilderTestCase {
Expand Down Expand Up @@ -161,7 +162,7 @@ public void testLikeConstructsNotSupported() {
}

public void testTranslateInExpression_WhereClause() throws IOException {
LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'lala', 'foo', concat('la', 'la'))");
LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'l''ala', 'foo', concat('l''a', 'la'))");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
Expand All @@ -170,7 +171,7 @@ public void testTranslateInExpression_WhereClause() throws IOException {
Query query = translation.query;
assertTrue(query instanceof TermsQuery);
TermsQuery tq = (TermsQuery) query;
assertEquals("keyword:(bar foo lala)", tq.asBuilder().toQuery(createShardContext()).toString());
assertEquals("keyword:(bar foo l'ala)", tq.asBuilder().toQuery(createShardContext()).toString());
}

public void testTranslateInExpression_WhereClauseAndNullHandling() throws IOException {
Expand Down Expand Up @@ -206,20 +207,39 @@ public void testTranslateInExpression_HavingClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString());
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20]}]"));
}

public void testTranslateInExpression_HavingClauseAndNullHandling_Painless() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, null, 30 - 10)");
public void testTranslateInExpression_HavingClause_PainlessOneArg() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, 30 - 20)");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString());
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10]}]"));

}

public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, 30, null, 30 - 10)");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20, 30]}]"));
}
}
Expand Up @@ -224,6 +224,12 @@ public void testConversionToNull() {
assertNull(conversion.convert(10.0));
}

public void testConversionFromNull() {
Conversion conversion = DataTypeConversion.conversionFor(DataType.NULL, DataType.INTEGER);
assertNull(conversion.convert(null));
assertNull(conversion.convert(10));
}

public void testConversionToIdentity() {
Conversion conversion = DataTypeConversion.conversionFor(DataType.INTEGER, DataType.INTEGER);
assertNull(conversion.convert(null));
Expand Down

0 comments on commit 782b6fe

Please sign in to comment.