Skip to content

Commit

Permalink
Move ESQL's LOCATE test cases to cases (#107271)
Browse files Browse the repository at this point in the history
This moves the test cases declared in the tests for ESQL's LOCATE
function to test cases which will cause #106782 to properly generate all
of the available signatures. It also buys us all of testing for
incorrect parameter combinations.
  • Loading branch information
nik9000 committed Apr 9, 2024
1 parent 75228df commit 8852566
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isInteger;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isString;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isType;

/**
* Locate function, given a string 'a' and a substring 'b', it returns the index of the first occurrence of the substring 'b' in 'a'.
Expand Down Expand Up @@ -80,7 +80,7 @@ protected TypeResolution resolveType() {
return resolution;
}

return start == null ? TypeResolution.TYPE_RESOLVED : isInteger(start, sourceText(), THIRD);
return start == null ? TypeResolution.TYPE_RESOLVED : isType(start, dt -> dt == DataTypes.INTEGER, sourceText(), THIRD, "integer");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public TestCaseSupplier(List<DataType> types, Supplier<TestCase> supplier) {
this(nameFromTypes(types), types, supplier);
}

static String nameFromTypes(List<DataType> types) {
public static String nameFromTypes(List<DataType> types) {
return types.stream().map(t -> "<" + t.typeName() + ">").collect(Collectors.joining(", "));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,21 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.hamcrest.Matchers.equalTo;

/**
Expand All @@ -37,192 +36,175 @@ public LocateTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCas
this.testCase = testCaseSupplier.get();
}

private static final DataType[] STRING_TYPES = new DataType[] { DataTypes.KEYWORD, DataTypes.TEXT };

@ParametersFactory
public static Iterable<Object[]> parameters() {
List<TestCaseSupplier> suppliers = new ArrayList<>();
suppliers.add(
supplier(
"keywords",
DataTypes.KEYWORD,
DataTypes.KEYWORD,
() -> randomRealisticUnicodeOfCodepointLength(10),
() -> randomRealisticUnicodeOfCodepointLength(2),
() -> 0
)
);
suppliers.add(
supplier(
"mixed keyword, text",
DataTypes.KEYWORD,
DataTypes.TEXT,
() -> randomRealisticUnicodeOfCodepointLength(10),
() -> randomRealisticUnicodeOfCodepointLength(2),
() -> 0
)
);
suppliers.add(
supplier(
"texts",
DataTypes.TEXT,
DataTypes.TEXT,
() -> randomRealisticUnicodeOfCodepointLength(10),
() -> randomRealisticUnicodeOfCodepointLength(2),
() -> 0
)
);
suppliers.add(
supplier(
"mixed text, keyword",
DataTypes.TEXT,
DataTypes.KEYWORD,
() -> randomRealisticUnicodeOfCodepointLength(10),
() -> randomRealisticUnicodeOfCodepointLength(2),
() -> 0
)
);
return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers)));
}

public void testToString() {
assertThat(
evaluator(
new Locate(
Source.EMPTY,
field("str", DataTypes.KEYWORD),
field("substr", DataTypes.KEYWORD),
field("start", DataTypes.INTEGER)
)
).get(driverContext()).toString(),
equalTo("LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]")
);
}

@Override
protected Expression build(Source source, List<Expression> args) {
return new Locate(source, args.get(0), args.get(1), args.size() < 3 ? null : args.get(2));
}

public void testPrefixString() {
assertThat(process("a tiger", "a t", 0), equalTo(1));
assertThat(process("a tiger", "a", 0), equalTo(1));
assertThat(process("界世", "界", 0), equalTo(1));
}

public void testSuffixString() {
assertThat(process("a tiger", "er", 0), equalTo(6));
assertThat(process("a tiger", "r", 0), equalTo(7));
assertThat(process("世界", "界", 0), equalTo(2));
}

public void testMidString() {
assertThat(process("a tiger", "ti", 0), equalTo(3));
assertThat(process("a tiger", "ige", 0), equalTo(4));
assertThat(process("世界世", "界", 0), equalTo(2));
}

public void testOutOfRange() {
assertThat(process("a tiger", "tigers", 0), equalTo(0));
assertThat(process("a tiger", "ipa", 0), equalTo(0));
assertThat(process("世界世", "\uD83C\uDF0D", 0), equalTo(0));
}

public void testExactString() {
assertThat(process("a tiger", "a tiger", 0), equalTo(1));
assertThat(process("tigers", "tigers", 0), equalTo(1));
assertThat(process("界世", "界世", 0), equalTo(1));
}
for (DataType strType : STRING_TYPES) {
for (DataType substrType : STRING_TYPES) {
suppliers.add(
supplier(
"",
strType,
substrType,
() -> randomRealisticUnicodeOfCodepointLength(10),
str -> randomRealisticUnicodeOfCodepointLength(2),
null,
(str, substr, start) -> 1 + str.indexOf(substr)
)
);
suppliers.add(
supplier(
"exact match ",
strType,
substrType,
() -> randomRealisticUnicodeOfCodepointLength(10),
str -> str,
null,
(str, substr, start) -> 1
)
);
suppliers.add(
supplier(
"",
strType,
substrType,
() -> randomRealisticUnicodeOfCodepointLength(10),
str -> randomRealisticUnicodeOfCodepointLength(2),
() -> between(0, 3),
(str, substr, start) -> 1 + str.indexOf(substr, start)
)
);
}
}

public void testSupplementaryCharacter() {
suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers));

// Here follows some non-randomized examples that we want to cover on every run
suppliers.add(supplier("a tiger", "a t", null, 1));
suppliers.add(supplier("a tiger", "a", null, 1));
suppliers.add(supplier("界世", "界", null, 1));
suppliers.add(supplier("a tiger", "er", null, 6));
suppliers.add(supplier("a tiger", "r", null, 7));
suppliers.add(supplier("界世", "世", null, 2));
suppliers.add(supplier("a tiger", "ti", null, 3));
suppliers.add(supplier("a tiger", "ige", null, 4));
suppliers.add(supplier("世界世", "界", null, 2));
suppliers.add(supplier("a tiger", "tigers", null, 0));
suppliers.add(supplier("a tiger", "ipa", null, 0));
suppliers.add(supplier("世界世", "\uD83C\uDF0D", null, 0));

// Extra assertions about 4-byte characters
// some assertions about the supplementary (4-byte) character we'll use for testing
assert "𠜎".length() == 2;
assert "𠜎".codePointCount(0, 2) == 1;
assert "𠜎".getBytes(UTF_8).length == 4;

assertThat(process("a ti𠜎er", "𠜎er", 0), equalTo(5));
assertThat(process("a ti𠜎er", "i𠜎e", 0), equalTo(4));
assertThat(process("a ti𠜎er", "ti𠜎", 0), equalTo(3));
assertThat(process("a ti𠜎er", "er", 0), equalTo(6));
assertThat(process("a ti𠜎er", "r", 0), equalTo(7));

assertThat(process("𠜎a ti𠜎er", "𠜎er", 0), equalTo(6));
assertThat(process("𠜎a ti𠜎er", "i𠜎e", 0), equalTo(5));
assertThat(process("𠜎a ti𠜎er", "ti𠜎", 0), equalTo(4));
assertThat(process("𠜎a ti𠜎er", "er", 0), equalTo(7));
assertThat(process("𠜎a ti𠜎er", "r", 0), equalTo(8));

// exact
assertThat(process("a ti𠜎er", "a ti𠜎er", 0), equalTo(1));
assertThat(process("𠜎𠜎𠜎abc", "𠜎𠜎𠜎abc", 0), equalTo(1));
assertThat(process(" 𠜎𠜎𠜎abc", " 𠜎𠜎𠜎abc", 0), equalTo(1));
assertThat(process("𠜎𠜎𠜎 abc ", "𠜎𠜎𠜎 abc ", 0), equalTo(1));

assert "𠜎".getBytes(StandardCharsets.UTF_8).length == 4;
suppliers.add(supplier("a ti𠜎er", "𠜎er", null, 5));
suppliers.add(supplier("a ti𠜎er", "i𠜎e", null, 4));
suppliers.add(supplier("a ti𠜎er", "ti𠜎", null, 3));
suppliers.add(supplier("a ti𠜎er", "er", null, 6));
suppliers.add(supplier("a ti𠜎er", "r", null, 7));
suppliers.add(supplier("a ti𠜎er", "a ti𠜎er", null, 1));
// prefix
assertThat(process("𠜎abc", "𠜎", 0), equalTo(1));
assertThat(process("𠜎 abc", "𠜎 ", 0), equalTo(1));
assertThat(process("𠜎𠜎𠜎abc", "𠜎𠜎𠜎", 0), equalTo(1));
assertThat(process("𠜎𠜎𠜎 abc", "𠜎𠜎𠜎 ", 0), equalTo(1));
assertThat(process(" 𠜎𠜎𠜎 abc", " 𠜎𠜎𠜎 ", 0), equalTo(1));
assertThat(process("𠜎 𠜎 𠜎 abc", "𠜎 𠜎 𠜎 ", 0), equalTo(1));

suppliers.add(supplier("𠜎abc", "𠜎", null, 1));
suppliers.add(supplier("𠜎 abc", "𠜎 ", null, 1));
suppliers.add(supplier("𠜎𠜎𠜎abc", "𠜎𠜎𠜎", null, 1));
suppliers.add(supplier("𠜎𠜎𠜎 abc", "𠜎𠜎𠜎 ", null, 1));
suppliers.add(supplier(" 𠜎𠜎𠜎 abc", " 𠜎𠜎𠜎 ", null, 1));
suppliers.add(supplier("𠜎 𠜎 𠜎 abc", "𠜎 𠜎 𠜎 ", null, 1));
// suffix
assertThat(process("abc𠜎", "𠜎", 0), equalTo(4));
assertThat(process("abc 𠜎", " 𠜎", 0), equalTo(4));
assertThat(process("abc𠜎𠜎𠜎", "𠜎𠜎𠜎", 0), equalTo(4));
assertThat(process("abc 𠜎𠜎𠜎", " 𠜎𠜎𠜎", 0), equalTo(4));
assertThat(process("abc𠜎𠜎𠜎 ", "𠜎𠜎𠜎 ", 0), equalTo(4));

suppliers.add(supplier("abc𠜎", "𠜎", null, 4));
suppliers.add(supplier("abc 𠜎", " 𠜎", null, 4));
suppliers.add(supplier("abc𠜎𠜎𠜎", "𠜎𠜎𠜎", null, 4));
suppliers.add(supplier("abc 𠜎𠜎𠜎", " 𠜎𠜎𠜎", null, 4));
suppliers.add(supplier("abc𠜎𠜎𠜎 ", "𠜎𠜎𠜎 ", null, 4));
// out of range
assertThat(process("𠜎a ti𠜎er", "𠜎a ti𠜎ers", 0), equalTo(0));
assertThat(process("a ti𠜎er", "aa ti𠜎er", 0), equalTo(0));
assertThat(process("abc𠜎𠜎", "𠜎𠜎𠜎", 0), equalTo(0));
suppliers.add(supplier("𠜎a ti𠜎er", "𠜎a ti𠜎ers", null, 0));
suppliers.add(supplier("a ti𠜎er", "aa ti𠜎er", null, 0));
suppliers.add(supplier("abc𠜎𠜎", "𠜎𠜎𠜎", null, 0));

assert "🐱".length() == 2 && "🐶".length() == 2;
assert "🐱".codePointCount(0, 2) == 1 && "🐶".codePointCount(0, 2) == 1;
assert "🐱".getBytes(UTF_8).length == 4 && "🐶".getBytes(UTF_8).length == 4;
assertThat(process("🐱Meow!🐶Woof!", "🐱Meow!🐶Woof!", 0), equalTo(1));
assertThat(process("🐱Meow!🐶Woof!", "Meow!🐶Woof!", 0), equalTo(2));
assertThat(process("🐱Meow!🐶Woof!", "eow!🐶Woof!", 0), equalTo(3));
assert "🐱".getBytes(StandardCharsets.UTF_8).length == 4 && "🐶".getBytes(StandardCharsets.UTF_8).length == 4;
suppliers.add(supplier("🐱Meow!🐶Woof!", "🐱Meow!🐶Woof!", null, 1));
suppliers.add(supplier("🐱Meow!🐶Woof!", "Meow!🐶Woof!", 0, 2));
suppliers.add(supplier("🐱Meow!🐶Woof!", "eow!🐶Woof!", 0, 3));

return parameterSuppliersFromTypedData(suppliers);
}

@Override
protected Expression build(Source source, List<Expression> args) {
return new Locate(source, args.get(0), args.get(1), args.size() < 3 ? null : args.get(2));
}

private Integer process(String str, String substr, Integer start) {
try (
EvalOperator.ExpressionEvaluator eval = evaluator(
new Locate(
Source.EMPTY,
field("str", DataTypes.KEYWORD),
field("substr", DataTypes.KEYWORD),
new Literal(Source.EMPTY, start, DataTypes.INTEGER)
)
).get(driverContext());
Block block = eval.eval(row(List.of(new BytesRef(str), new BytesRef(substr))))
) {
return block.isNull(0) ? Integer.valueOf(0) : ((Integer) toJavaObject(block, 0));
private static TestCaseSupplier supplier(String str, String substr, @Nullable Integer start, @Nullable Integer expectedValue) {
String name = String.format(Locale.ROOT, "\"%s\" in \"%s\"", substr, str);
if (start != null) {
name += " starting at " + start;
}

return new TestCaseSupplier(
name,
types(DataTypes.KEYWORD, DataTypes.KEYWORD, start != null),
() -> testCase(DataTypes.KEYWORD, DataTypes.KEYWORD, str, substr, start, expectedValue)
);
}

interface ExpectedValue {
int expectedValue(String str, String substr, Integer start);
}

private static TestCaseSupplier supplier(
String name,
DataType firstType,
DataType secondType,
DataType strType,
DataType substrType,
Supplier<String> strValueSupplier,
Supplier<String> substrValueSupplier,
Supplier<Integer> startSupplier
Function<String, String> substrValueSupplier,
@Nullable Supplier<Integer> startSupplier,
ExpectedValue expectedValue
) {
return new TestCaseSupplier(name, List.of(firstType, secondType), () -> {
List<TestCaseSupplier.TypedData> values = new ArrayList<>();
String expectedToString = "LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]";

String value = strValueSupplier.get();
values.add(new TestCaseSupplier.TypedData(new BytesRef(value), firstType, "0"));
List<DataType> types = types(strType, substrType, startSupplier != null);
return new TestCaseSupplier(name + TestCaseSupplier.nameFromTypes(types), types, () -> {
String str = strValueSupplier.get();
String substr = substrValueSupplier.apply(str);
Integer start = startSupplier == null ? null : startSupplier.get();
return testCase(strType, substrType, str, substr, start, expectedValue.expectedValue(str, substr, start));
});
}

String substrValue = substrValueSupplier.get();
values.add(new TestCaseSupplier.TypedData(new BytesRef(substrValue), secondType, "1"));
private static String expectedToString(boolean hasStart) {
if (hasStart) {
return "LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]";
}
return "LocateNoStartEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1]]";
}

Integer startValue = startSupplier.get();
values.add(new TestCaseSupplier.TypedData(startValue, DataTypes.INTEGER, "2"));
private static List<DataType> types(DataType firstType, DataType secondType, boolean hasStart) {
List<DataType> types = new ArrayList<>();
types.add(firstType);
types.add(secondType);
if (hasStart) {
types.add(DataTypes.INTEGER);
}
return types;
}

int expectedValue = 1 + value.indexOf(substrValue);
return new TestCaseSupplier.TestCase(values, expectedToString, DataTypes.INTEGER, equalTo(expectedValue));
});
private static TestCaseSupplier.TestCase testCase(
DataType strType,
DataType substrType,
String str,
String substr,
Integer start,
Integer expectedValue
) {
List<TestCaseSupplier.TypedData> values = new ArrayList<>();
values.add(new TestCaseSupplier.TypedData(str == null ? null : new BytesRef(str), strType, "str"));
values.add(new TestCaseSupplier.TypedData(substr == null ? null : new BytesRef(substr), substrType, "substr"));
if (start != null) {
values.add(new TestCaseSupplier.TypedData(start, DataTypes.INTEGER, "start"));
}
return new TestCaseSupplier.TestCase(values, expectedToString(start != null), DataTypes.INTEGER, equalTo(expectedValue));
}
}

0 comments on commit 8852566

Please sign in to comment.