Skip to content

Commit

Permalink
SQL: Fix function args verification and error msgs (#34926)
Browse files Browse the repository at this point in the history
Extract data type verification for function arguments to a single place
so that NULL type can be treated as RESOLVED for all functions. Moreover
this enables to have consistent verification error messages for all functions.

Fixes: #34752
Fixes: #33469
  • Loading branch information
Marios Trivyzas committed Oct 29, 2018
1 parent 389910f commit c9ae192
Show file tree
Hide file tree
Showing 35 changed files with 260 additions and 154 deletions.
Expand Up @@ -246,4 +246,4 @@ public boolean isCompatibleWith(DataType other) {
(isString() && other.isString()) ||
(isNumeric() && other.isNumeric());
}
}
}
Expand Up @@ -35,7 +35,11 @@ public static class TypeResolution {

public static final TypeResolution TYPE_RESOLVED = new TypeResolution(false, StringUtils.EMPTY);

public TypeResolution(String message, Object... args) {
public TypeResolution(String message) {
this(true, message);
}

TypeResolution(String message, Object... args) {
this(true, format(Locale.ROOT, message, args));
}

Expand Down Expand Up @@ -132,4 +136,4 @@ public boolean resolved() {
public String toString() {
return nodeName() + "[" + propertiesToString(false) + "]";
}
}
}
Expand Up @@ -10,17 +10,27 @@
import org.elasticsearch.xpack.sql.expression.Expression.TypeResolution;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.function.Predicate;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;

public final class Expressions {

public enum ParamOrdinal {
DEFAULT,
FIRST,
SECOND,
THIRD,
FOURTH
}

private Expressions() {}

public static NamedExpression wrapAsNamed(Expression exp) {
Expand Down Expand Up @@ -127,22 +137,51 @@ public static Pipe pipe(Expression e) {
throw new SqlIllegalArgumentException("Cannot create pipe for {}", e);
}

public static TypeResolution typeMustBe(Expression e, Predicate<Expression> predicate, String message) {
return predicate.test(e) ? TypeResolution.TYPE_RESOLVED : new TypeResolution(message);
public static TypeResolution typeMustBeBoolean(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, dt -> dt == DataType.BOOLEAN, operationName, paramOrd, "boolean");
}

public static TypeResolution typeMustBeInteger(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, dt -> dt.isInteger, operationName, paramOrd, "integer");
}

public static TypeResolution typeMustBeNumeric(Expression e) {
return e.dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED : new TypeResolution(incorrectTypeErrorMessage(e, "numeric"));
public static TypeResolution typeMustBeNumeric(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, DataType::isNumeric, operationName, paramOrd, "numeric");
}

public static TypeResolution typeMustBeNumericOrDate(Expression e) {
return e.dataType().isNumeric() || e.dataType() == DataType.DATE ?
public static TypeResolution typeMustBeString(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, DataType::isString, operationName, paramOrd, "string");
}

public static TypeResolution typeMustBeDate(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, dt -> dt == DataType.DATE, operationName, paramOrd, "date");
}

public static TypeResolution typeMustBeNumericOrDate(Expression e, String operationName, ParamOrdinal paramOrd) {
return typeMustBe(e, dt -> dt.isNumeric() || dt == DataType.DATE, operationName, paramOrd, "numeric", "date");
}

private static TypeResolution typeMustBe(Expression e,
Predicate<DataType> predicate,
String operationName,
ParamOrdinal pOrd,
String... acceptedTypes) {

return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
TypeResolution.TYPE_RESOLVED :
new TypeResolution(incorrectTypeErrorMessage(e, "numeric", "date"));
new TypeResolution(incorrectTypeErrorMessage(e, operationName, pOrd, acceptedTypes));

}

private static String incorrectTypeErrorMessage(Expression e, String...acceptedTypes) {
return "Argument required to be " + Strings.arrayToDelimitedString(acceptedTypes, " or ")
+ " ('" + Expressions.name(e) + "' type is '" + e.dataType().esType + "')";

private static String incorrectTypeErrorMessage(Expression e,
String operationName,
ParamOrdinal paramOrd,
String... acceptedTypes) {
return String.format(Locale.ROOT, "[%s]%s argument must be [%s], found value [%s] type [%s]",
operationName,
paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
Strings.arrayToDelimitedString(acceptedTypes, " or "),
Expressions.name(e),
e.dataType().esType);
}
}
}
Expand Up @@ -26,6 +26,7 @@ public class Literal extends NamedExpression {

public static final Literal TRUE = Literal.of(Location.EMPTY, Boolean.TRUE);
public static final Literal FALSE = Literal.of(Location.EMPTY, Boolean.FALSE);
public static final Literal NULL = Literal.of(Location.EMPTY, null);

private final Object value;
private final DataType dataType;
Expand Down Expand Up @@ -163,4 +164,4 @@ public static Literal of(String name, Expression foldable) {

return new Literal(foldable.location(), name, fold, foldable.dataType());
}
}
}
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
Expand Down Expand Up @@ -44,6 +45,6 @@ public String innerName() {

@Override
protected TypeResolution resolveType() {
return Expressions.typeMustBeNumericOrDate(field());
return Expressions.typeMustBeNumericOrDate(field(), functionName(), ParamOrdinal.DEFAULT);
}
}
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
Expand Down Expand Up @@ -47,6 +48,6 @@ public String innerName() {

@Override
protected TypeResolution resolveType() {
return Expressions.typeMustBeNumericOrDate(field());
return Expressions.typeMustBeNumericOrDate(field(), functionName(), ParamOrdinal.DEFAULT);
}
}
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;

Expand All @@ -24,7 +25,7 @@ abstract class NumericAggregate extends AggregateFunction {

@Override
protected TypeResolution resolveType() {
return Expressions.typeMustBeNumeric(field());
return Expressions.typeMustBeNumeric(field(), functionName(), ParamOrdinal.DEFAULT);
}

@Override
Expand Down
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
Expand Down Expand Up @@ -43,7 +44,7 @@ protected TypeResolution resolveType() {
TypeResolution resolution = super.resolveType();

if (TypeResolution.TYPE_RESOLVED.equals(resolution)) {
resolution = Expressions.typeMustBeNumeric(percent());
resolution = Expressions.typeMustBeNumeric(percent(), functionName(), ParamOrdinal.DEFAULT);
}

return resolution;
Expand Down
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
Expand Down Expand Up @@ -41,12 +42,11 @@ public Expression replaceChildren(List<Expression> newChildren) {
@Override
protected TypeResolution resolveType() {
TypeResolution resolution = super.resolveType();

if (TypeResolution.TYPE_RESOLVED.equals(resolution)) {
resolution = Expressions.typeMustBeNumeric(value);
if (resolution.unresolved()) {
return resolution;
}

return resolution;
return Expressions.typeMustBeNumeric(value, functionName(), ParamOrdinal.DEFAULT);
}

public Expression value() {
Expand Down
Expand Up @@ -65,7 +65,7 @@ public boolean nullable() {
protected TypeResolution resolveType() {
return DataTypeConversion.canConvert(from(), to()) ?
TypeResolution.TYPE_RESOLVED :
new TypeResolution("Cannot cast %s to %s", from(), to());
new TypeResolution("Cannot cast [" + from() + "] to [" + to()+ "]");
}

@Override
Expand Down Expand Up @@ -102,4 +102,4 @@ public String name() {
sb.insert(sb.length() - 1, " AS " + to().sqlName());
return sb.toString();
}
}
}
Expand Up @@ -8,10 +8,10 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
import org.joda.time.DateTime;

import java.util.Objects;
Expand Down Expand Up @@ -42,11 +42,7 @@ protected final NodeInfo<BaseDateTimeFunction> info() {

@Override
protected TypeResolution resolveType() {
if (field().dataType() == DataType.DATE) {
return TypeResolution.TYPE_RESOLVED;
}
return new TypeResolution("Function [" + functionName() + "] cannot be applied on a non-date expression (["
+ Expressions.name(field()) + "] of type [" + field().dataType().esType + "])");
return Expressions.typeMustBeDate(field(), functionName(), ParamOrdinal.DEFAULT);
}

public TimeZone timeZone() {
Expand Down Expand Up @@ -90,4 +86,4 @@ public boolean equals(Object obj) {
public int hashCode() {
return Objects.hash(field(), timeZone());
}
}
}
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.math.BinaryMathProcessor.BinaryMathOperation;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
Expand All @@ -19,7 +20,7 @@ public abstract class BinaryNumericFunction extends BinaryScalarFunction {

private final BinaryMathOperation operation;

protected BinaryNumericFunction(Location location, Expression left, Expression right, BinaryMathOperation operation) {
BinaryNumericFunction(Location location, Expression left, Expression right, BinaryMathOperation operation) {
super(location, left, right);
this.operation = operation;
}
Expand All @@ -35,18 +36,12 @@ protected TypeResolution resolveType() {
return new TypeResolution("Unresolved children");
}

TypeResolution resolution = resolveInputType(left().dataType());
TypeResolution resolution = Expressions.typeMustBeNumeric(left(), functionName(), ParamOrdinal.FIRST);
if (resolution.unresolved()) {
return resolution;

if (resolution == TypeResolution.TYPE_RESOLVED) {
return resolveInputType(right().dataType());
}
return resolution;
}

protected TypeResolution resolveInputType(DataType inputType) {
return inputType.isNumeric() ?
TypeResolution.TYPE_RESOLVED :
new TypeResolution("'%s' requires a numeric type, received %s", scriptMethodName(), inputType.esType);
return Expressions.typeMustBeNumeric(right(), functionName(), ParamOrdinal.SECOND);
}

@Override
Expand Down Expand Up @@ -74,4 +69,4 @@ public boolean equals(Object obj) {
&& Objects.equals(other.right(), right())
&& Objects.equals(other.operation, operation);
}
}
}
Expand Up @@ -6,6 +6,8 @@
package org.elasticsearch.xpack.sql.expression.function.scalar.math;

import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.sql.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.math.MathProcessor.MathOperation;
import org.elasticsearch.xpack.sql.expression.gen.processor.Processor;
Expand Down Expand Up @@ -57,8 +59,7 @@ protected TypeResolution resolveType() {
return new TypeResolution("Unresolved children");
}

return field().dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED
: new TypeResolution("'%s' requires a numeric type, received %s", operation(), field().dataType().esType);
return Expressions.typeMustBeNumeric(field(), operation().toString(), ParamOrdinal.DEFAULT);
}

@Override
Expand All @@ -81,4 +82,4 @@ public boolean equals(Object obj) {
public int hashCode() {
return Objects.hash(field());
}
}
}
Expand Up @@ -10,12 +10,13 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;

import java.util.Locale;
import java.util.Objects;
import java.util.function.BiFunction;

import static org.elasticsearch.xpack.sql.expression.Expressions.ParamOrdinal;
import static org.elasticsearch.xpack.sql.expression.Expressions.typeMustBeString;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

/**
Expand All @@ -41,14 +42,15 @@ protected TypeResolution resolveType() {
return new TypeResolution("Unresolved children");
}

if (!left().dataType().isString()) {
return new TypeResolution("'%s' requires first parameter to be a string type, received %s", functionName(), left().dataType());
TypeResolution resolution = typeMustBeString(left(), functionName(), ParamOrdinal.FIRST);
if (resolution.unresolved()) {
return resolution;
}
return resolveSecondParameterInputType(right().dataType());

return resolveSecondParameterInputType(right());
}

protected abstract TypeResolution resolveSecondParameterInputType(DataType inputType);
protected abstract TypeResolution resolveSecondParameterInputType(Expression e);

@Override
public Object fold() {
Expand Down Expand Up @@ -83,4 +85,4 @@ public boolean equals(Object obj) {
return Objects.equals(other.left(), left())
&& Objects.equals(other.right(), right());
}
}
}

0 comments on commit c9ae192

Please sign in to comment.