Skip to content

Commit

Permalink
[ES|QL] Add function - improve test coverge (#102830)
Browse files Browse the repository at this point in the history
Relates to #100558

Adds testing, docs, etc for the Addition operator. Importantly, this PR pulls addition into the test framework type checking and null validation logic, which is not currently being applied.

This PR also includes some new test infrastructure for binary numeric functions which do not cast their arguments to doubles, an area the test framework currently doesn't cover very well.

I encountered a couple of issues while writing this. One of them is tracked in #103085, around null handling in date math. There's also a problem with how we're doing type checking for mixed type functions, which I haven't opened an issue for yet. That said, I'd rather merge this as partial work now, since it adds functionality we can reuse elsewhere and improves the test coverage for Add. We'll just need more work before we can check it off the list.

(cherry picked from commit 50e59ca)
  • Loading branch information
not-napoleon authored and bpintea committed Jan 3, 2024
1 parent 38ddf39 commit 515ac37
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ private void testEvaluate(boolean readFloating) {
assertFalse("expected resolved", expression.typeResolved().unresolved());
expression = new FoldNull().rule(expression);
assertThat(expression.dataType(), equalTo(testCase.expectedType));
// TODO should we convert unsigned_long into BigDecimal so it's easier to assert?
logger.info("Result type: " + expression.dataType());

Object result;
try (ExpressionEvaluator evaluator = evaluator(expression).get(driverContext())) {
try (Block block = evaluator.eval(row(testCase.getDataValues()))) {
Expand Down Expand Up @@ -722,6 +723,10 @@ protected static List<TestCaseSupplier> failureForCasesWithoutExamples(List<Test
return suppliers;
}

/**
* Validate that we know the types for all the test cases already created
* @param suppliers - list of suppliers before adding in the illegal type combinations
*/
private static void typesRequired(List<TestCaseSupplier> suppliers) {
String bad = suppliers.stream().filter(s -> s.types() == null).map(s -> s.name()).collect(Collectors.joining("\n"));
if (bad.equals("") == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
import org.hamcrest.Matcher;

import java.math.BigInteger;
import java.time.Duration;
import java.time.Instant;
import java.time.Period;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.DoubleFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -165,12 +168,42 @@ public static List<TestCaseSupplier> forBinaryCastingToDouble(
List<String> warnings
) {
List<TestCaseSupplier> suppliers = new ArrayList<>();
casesCrossProduct(
(l, r) -> expected.apply(((Number) l).doubleValue(), ((Number) r).doubleValue()),
lhsSuppliers,
rhsSuppliers,
(lhsType, rhsType) -> name
+ "["
+ lhsName
+ "="
+ castToDoubleEvaluator("Attribute[channel=0]", lhsType)
+ ", "
+ rhsName
+ "="
+ castToDoubleEvaluator("Attribute[channel=1]", rhsType)
+ "]",
warnings,
suppliers,
DataTypes.DOUBLE
);
return suppliers;
}

private static void casesCrossProduct(
BinaryOperator<Object> expected,
List<TypedDataSupplier> lhsSuppliers,
List<TypedDataSupplier> rhsSuppliers,
BiFunction<DataType, DataType, String> evaluatorToString,
List<String> warnings,
List<TestCaseSupplier> suppliers,
DataType expectedType
) {
for (TypedDataSupplier lhsSupplier : lhsSuppliers) {
for (TypedDataSupplier rhsSupplier : rhsSuppliers) {
String caseName = lhsSupplier.name() + ", " + rhsSupplier.name();
suppliers.add(new TestCaseSupplier(caseName, List.of(lhsSupplier.type(), rhsSupplier.type()), () -> {
Number lhs = (Number) lhsSupplier.supplier().get();
Number rhs = (Number) rhsSupplier.supplier().get();
Object lhs = lhsSupplier.supplier().get();
Object rhs = rhsSupplier.supplier().get();
TypedData lhsTyped = new TypedData(
// TODO there has to be a better way to handle unsigned long
lhs instanceof BigInteger b ? NumericUtils.asLongUnsigned(b) : lhs,
Expand All @@ -182,13 +215,11 @@ public static List<TestCaseSupplier> forBinaryCastingToDouble(
rhsSupplier.type(),
"rhs"
);
String lhsEvalName = castToDoubleEvaluator("Attribute[channel=0]", lhsSupplier.type());
String rhsEvalName = castToDoubleEvaluator("Attribute[channel=1]", rhsSupplier.type());
TestCase testCase = new TestCase(
List.of(lhsTyped, rhsTyped),
name + "[" + lhsName + "=" + lhsEvalName + ", " + rhsName + "=" + rhsEvalName + "]",
DataTypes.DOUBLE,
equalTo(expected.apply(lhs.doubleValue(), rhs.doubleValue()))
evaluatorToString.apply(lhsSupplier.type(), rhsSupplier.type()),
expectedType,
equalTo(expected.apply(lhs, rhs))
);
for (String warning : warnings) {
testCase = testCase.withWarning(warning);
Expand All @@ -197,8 +228,6 @@ public static List<TestCaseSupplier> forBinaryCastingToDouble(
}));
}
}

return suppliers;
}

public static List<TypedDataSupplier> castToDoubleSuppliersFromRange(Double Min, Double Max) {
Expand All @@ -210,6 +239,146 @@ public static List<TypedDataSupplier> castToDoubleSuppliersFromRange(Double Min,
return suppliers;
}

public static List<TestCaseSupplier> forBinaryNumericNotCasting(
String name,
String lhsName,
String rhsName,
BinaryOperator<Number> expected,
DataType expectedType,
List<TypedDataSupplier> lhsSuppliers,
List<TypedDataSupplier> rhsSuppliers,
List<String> warnings,
boolean symetric
) {
return forBinaryNotCasting(
name,
lhsName,
rhsName,
(lhs, rhs) -> expected.apply((Number) lhs, (Number) rhs),
expectedType,
lhsSuppliers,
rhsSuppliers,
warnings,
symetric
);
}

public record NumericTypeTestConfig(Number min, Number max, BinaryOperator<Number> expected, String evaluatorName) {}

public record NumericTypeTestConfigs(
NumericTypeTestConfig intStuff,
NumericTypeTestConfig longStuff,
NumericTypeTestConfig doubleStuff
) {
public NumericTypeTestConfig get(DataType type) {
if (type == DataTypes.INTEGER) {
return intStuff;
}
if (type == DataTypes.LONG) {
return longStuff;
}
if (type == DataTypes.DOUBLE) {
return doubleStuff;
}
throw new IllegalArgumentException("bogus numeric type [" + type + "]");
}
}

private static DataType widen(DataType lhs, DataType rhs) {
if (lhs == rhs) {
return lhs;
}
if (lhs == DataTypes.DOUBLE || rhs == DataTypes.DOUBLE) {
return DataTypes.DOUBLE;
}
if (lhs == DataTypes.LONG || rhs == DataTypes.LONG) {
return DataTypes.LONG;
}
throw new IllegalArgumentException("Invalid numeric widening lhs: [" + lhs + "] rhs: [" + rhs + "]");
}

private static List<TypedDataSupplier> getSuppliersForNumericType(DataType type, Number min, Number max) {
if (type == DataTypes.INTEGER) {
return intCases(NumericUtils.saturatingIntValue(min), NumericUtils.saturatingIntValue(max));
}
if (type == DataTypes.LONG) {
return longCases(min.longValue(), max.longValue());
}
if (type == DataTypes.UNSIGNED_LONG) {
return ulongCases(
min instanceof BigInteger ? (BigInteger) min : BigInteger.valueOf(Math.max(min.longValue(), 0L)),
max instanceof BigInteger ? (BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L))
);
}
if (type == DataTypes.DOUBLE) {
return doubleCases(min.doubleValue(), max.doubleValue());
}
throw new IllegalArgumentException("bogus numeric type [" + type + "]");
}

public static List<TestCaseSupplier> forBinaryWithWidening(
NumericTypeTestConfigs typeStuff,
String lhsName,
String rhsName,
List<String> warnings
) {
List<TestCaseSupplier> suppliers = new ArrayList<>();
List<DataType> numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE);

for (DataType lhsType : numericTypes) {
for (DataType rhsType : numericTypes) {
DataType expected = widen(lhsType, rhsType);
NumericTypeTestConfig expectedTypeStuff = typeStuff.get(expected);
String evaluator = expectedTypeStuff.evaluatorName()
+ "["
+ lhsName
+ "="
+ getCastEvaluator("Attribute[channel=0]", lhsType, expected)
+ ", "
+ rhsName
+ "="
+ getCastEvaluator("Attribute[channel=1]", rhsType, expected)
+ "]";
casesCrossProduct(
(l, r) -> expectedTypeStuff.expected().apply((Number) l, (Number) r),
getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max()),
getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max()),
// TODO: This doesn't really need to be a function
(lt, rt) -> evaluator,
warnings,
suppliers,
expected
);
}
}

return suppliers;
}

public static List<TestCaseSupplier> forBinaryNotCasting(
String name,
String lhsName,
String rhsName,
BinaryOperator<Object> expected,
DataType expectedType,
List<TypedDataSupplier> lhsSuppliers,
List<TypedDataSupplier> rhsSuppliers,
List<String> warnings,
boolean symetric
) {
List<TestCaseSupplier> suppliers = new ArrayList<>();
casesCrossProduct(
expected,
lhsSuppliers,
rhsSuppliers,
(lhsType, rhsType) -> name + "[" + lhsName + "=Attribute[channel=0], " + rhsName + "=Attribute[channel=1]]",
warnings,
suppliers,
expectedType
);
return suppliers;
}

/**
* Generate positive test cases for a unary function operating on an {@link DataTypes#INTEGER}.
*/
Expand Down Expand Up @@ -716,7 +885,7 @@ private static List<TypedDataSupplier> booleanCases() {
);
}

private static List<TypedDataSupplier> dateCases() {
public static List<TypedDataSupplier> dateCases() {
return List.of(
new TypedDataSupplier("<1970-01-01T00:00:00Z>", () -> 0L, DataTypes.DATETIME),
new TypedDataSupplier(
Expand All @@ -733,6 +902,32 @@ private static List<TypedDataSupplier> dateCases() {
);
}

public static List<TypedDataSupplier> datePeriodCases() {
return List.of(
new TypedDataSupplier("<zero date period>", () -> Period.ZERO, EsqlDataTypes.DATE_PERIOD),
new TypedDataSupplier(
"<random date period>",
() -> Period.of(
ESTestCase.randomIntBetween(-1000, 1000),
ESTestCase.randomIntBetween(-13, 13),
ESTestCase.randomIntBetween(-32, 32)
),
EsqlDataTypes.DATE_PERIOD
)
);
}

public static List<TypedDataSupplier> timeDurationCases() {
return List.of(
new TypedDataSupplier("<zero time duration>", () -> Duration.ZERO, EsqlDataTypes.TIME_DURATION),
new TypedDataSupplier(
"<up to 7 days duration>",
() -> Duration.ofMillis(ESTestCase.randomLongBetween(-604800000L, 604800000L)), // plus/minus 7 days
EsqlDataTypes.TIME_DURATION
)
);
}

private static List<TypedDataSupplier> geoPointCases() {
return List.of(new TypedDataSupplier("<geo_point>", () -> GEO.pointAsLong(randomGeoPoint()), EsqlDataTypes.GEO_POINT));
}
Expand All @@ -743,7 +938,7 @@ private static List<TypedDataSupplier> cartesianPointCases() {
);
}

private static List<TypedDataSupplier> ipCases() {
public static List<TypedDataSupplier> ipCases() {
return List.of(
new TypedDataSupplier(
"<127.0.0.1 ip>",
Expand Down Expand Up @@ -805,6 +1000,54 @@ public static List<TypedDataSupplier> versionCases(String prefix) {
);
}

private static String getCastEvaluator(String original, DataType current, DataType target) {
if (current == target) {
return original;
}
if (target == DataTypes.LONG) {
return castToLongEvaluator(original, current);
}
if (target == DataTypes.UNSIGNED_LONG) {
return castToUnsignedLongEvaluator(original, current);
}
if (target == DataTypes.DOUBLE) {
return castToDoubleEvaluator(original, current);
}
throw new IllegalArgumentException("Invalid numeric cast to [" + target + "]");
}

private static String castToLongEvaluator(String original, DataType current) {
if (current == DataTypes.LONG) {
return original;
}
if (current == DataTypes.INTEGER) {
return "CastIntToLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.DOUBLE) {
return "CastDoubleToLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.UNSIGNED_LONG) {
return "CastUnsignedLongToLong[v=" + original + "]";
}
throw new UnsupportedOperationException();
}

private static String castToUnsignedLongEvaluator(String original, DataType current) {
if (current == DataTypes.UNSIGNED_LONG) {
return original;
}
if (current == DataTypes.INTEGER) {
return "CastIntToUnsignedLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.LONG) {
return "CastLongToUnsignedLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.DOUBLE) {
return "CastDoubleToUnsignedLongEvaluator[v=" + original + "]";
}
throw new UnsupportedOperationException();
}

private static String castToDoubleEvaluator(String original, DataType current) {
if (current == DataTypes.DOUBLE) {
return original;
Expand Down

0 comments on commit 515ac37

Please sign in to comment.