Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@

package org.elasticsearch.xpack.esql.expression.function.scalar.conditional;

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.junit.After;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase.field;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -166,4 +184,129 @@ public void testPartialFoldLastAfterKeepingUnknown() {
)
);
}

public void testEvalCase() {
testCase(caseExpr -> {
DriverContext driverContext = driverContext();
Page page = new Page(driverContext.blockFactory().newConstantIntBlockWith(0, 1));
try (
EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase::evaluator).get(driverContext);
Block block = eval.eval(page)
) {
return toJavaObject(block, 0);
} finally {
page.releaseBlocks();
}
});
}

public void testFoldCase() {
testCase(caseExpr -> {
assertTrue(caseExpr.foldable());
return caseExpr.fold();
});
}

public void testCase(Function<Case, Object> toValue) {
assertEquals(1, toValue.apply(caseExpr(true, 1)));
assertNull(toValue.apply(caseExpr(false, 1)));
assertEquals(2, toValue.apply(caseExpr(false, 1, 2)));
assertEquals(1, toValue.apply(caseExpr(true, 1, true, 2)));
assertEquals(2, toValue.apply(caseExpr(false, 1, true, 2)));
assertNull(toValue.apply(caseExpr(false, 1, false, 2)));
assertEquals(3, toValue.apply(caseExpr(false, 1, false, 2, 3)));
assertNull(toValue.apply(caseExpr(true, null, 1)));
assertEquals(1, toValue.apply(caseExpr(false, null, 1)));
assertEquals(1, toValue.apply(caseExpr(false, field("ignored", DataType.INTEGER), 1)));
assertEquals(1, toValue.apply(caseExpr(true, 1, field("ignored", DataType.INTEGER))));
}

public void testIgnoreLeadingNulls() {
assertEquals(DataType.INTEGER, resolveType(false, null, 1));
assertEquals(DataType.INTEGER, resolveType(false, null, false, null, false, 2, null));
assertEquals(DataType.NULL, resolveType(false, null, null));
assertEquals(DataType.BOOLEAN, resolveType(false, null, field("bool", DataType.BOOLEAN)));
}

public void testCaseWithInvalidCondition() {
assertEquals("expected at least two arguments in [<case>] but got 1", resolveCase(1).message());
assertEquals("first argument of [<case>] must be [boolean], found value [1] type [integer]", resolveCase(1, 2).message());
assertEquals(
"third argument of [<case>] must be [boolean], found value [3] type [integer]",
resolveCase(true, 2, 3, 4, 5).message()
);
}

public void testCaseWithIncompatibleTypes() {
assertEquals("third argument of [<case>] must be [integer], found value [hi] type [keyword]", resolveCase(true, 1, "hi").message());
assertEquals(
"fourth argument of [<case>] must be [integer], found value [hi] type [keyword]",
resolveCase(true, 1, false, "hi", 5).message()
);
assertEquals(
"argument of [<case>] must be [integer], found value [hi] type [keyword]",
resolveCase(true, 1, false, 2, true, 5, "hi").message()
);
}

public void testCaseIsLazy() {
Case caseExpr = caseExpr(true, 1, true, 2);
DriverContext driveContext = driverContext();
EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(child -> {
Object value = child.fold();
if (value != null && value.equals(2)) {
return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
@Override
public Block eval(Page page) {
fail("Unexpected evaluation of 4th argument");
return null;
}

@Override
public void close() {}
};
}
return AbstractFunctionTestCase.evaluator(child);
}).get(driveContext);
Page page = new Page(driveContext.blockFactory().newConstantIntBlockWith(0, 1));
try (Block block = evaluator.eval(page)) {
assertEquals(1, toJavaObject(block, 0));
} finally {
page.releaseBlocks();
}
}

private static Case caseExpr(Object... args) {
List<Expression> exps = Stream.of(args).<Expression>map(arg -> {
if (arg instanceof Expression e) {
return e;
}
return new Literal(Source.synthetic(arg == null ? "null" : arg.toString()), arg, DataType.fromJava(arg));
}).toList();
return new Case(Source.synthetic("<case>"), exps.get(0), exps.subList(1, exps.size()));
}

private static Expression.TypeResolution resolveCase(Object... args) {
return caseExpr(args).resolveType();
}

private static DataType resolveType(Object... args) {
return caseExpr(args).dataType();
}

private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>());

protected final DriverContext driverContext() {
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
breakers.add(breaker);
return new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
}

@After
public void allMemoryReleased() {
for (CircuitBreaker breaker : breakers) {
assertThat(breaker.getUsed(), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,16 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expression.TypeResolution;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;

import java.math.BigInteger;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -209,113 +200,4 @@ public static Iterable<Object[]> parameters() {
protected Expression build(Source source, List<Expression> args) {
return new Case(Source.EMPTY, args.get(0), args.subList(1, args.size()));
}

public void testEvalCase() {
testCase(caseExpr -> {
DriverContext driverContext = driverContext();
Page page = new Page(driverContext.blockFactory().newConstantIntBlockWith(0, 1));
try (
EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(child -> evaluator(child)).get(driverContext);
Block block = eval.eval(page)
) {
return toJavaObject(block, 0);
} finally {
page.releaseBlocks();
}
});
}

public void testFoldCase() {
testCase(caseExpr -> {
assertTrue(caseExpr.foldable());
return caseExpr.fold();
});
}

public void testCase(Function<Case, Object> toValue) {
assertEquals(1, toValue.apply(caseExpr(true, 1)));
assertNull(toValue.apply(caseExpr(false, 1)));
assertEquals(2, toValue.apply(caseExpr(false, 1, 2)));
assertEquals(1, toValue.apply(caseExpr(true, 1, true, 2)));
assertEquals(2, toValue.apply(caseExpr(false, 1, true, 2)));
assertNull(toValue.apply(caseExpr(false, 1, false, 2)));
assertEquals(3, toValue.apply(caseExpr(false, 1, false, 2, 3)));
assertNull(toValue.apply(caseExpr(true, null, 1)));
assertEquals(1, toValue.apply(caseExpr(false, null, 1)));
assertEquals(1, toValue.apply(caseExpr(false, field("ignored", DataType.INTEGER), 1)));
assertEquals(1, toValue.apply(caseExpr(true, 1, field("ignored", DataType.INTEGER))));
}

public void testIgnoreLeadingNulls() {
assertEquals(DataType.INTEGER, resolveType(false, null, 1));
assertEquals(DataType.INTEGER, resolveType(false, null, false, null, false, 2, null));
assertEquals(DataType.NULL, resolveType(false, null, null));
assertEquals(DataType.BOOLEAN, resolveType(false, null, field("bool", DataType.BOOLEAN)));
}

public void testCaseWithInvalidCondition() {
assertEquals("expected at least two arguments in [<case>] but got 1", resolveCase(1).message());
assertEquals("first argument of [<case>] must be [boolean], found value [1] type [integer]", resolveCase(1, 2).message());
assertEquals(
"third argument of [<case>] must be [boolean], found value [3] type [integer]",
resolveCase(true, 2, 3, 4, 5).message()
);
}

public void testCaseWithIncompatibleTypes() {
assertEquals("third argument of [<case>] must be [integer], found value [hi] type [keyword]", resolveCase(true, 1, "hi").message());
assertEquals(
"fourth argument of [<case>] must be [integer], found value [hi] type [keyword]",
resolveCase(true, 1, false, "hi", 5).message()
);
assertEquals(
"argument of [<case>] must be [integer], found value [hi] type [keyword]",
resolveCase(true, 1, false, 2, true, 5, "hi").message()
);
}

public void testCaseIsLazy() {
Case caseExpr = caseExpr(true, 1, true, 2);
DriverContext driveContext = driverContext();
EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(child -> {
Object value = child.fold();
if (value != null && value.equals(2)) {
return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
@Override
public Block eval(Page page) {
fail("Unexpected evaluation of 4th argument");
return null;
}

@Override
public void close() {}
};
}
return evaluator(child);
}).get(driveContext);
Page page = new Page(driveContext.blockFactory().newConstantIntBlockWith(0, 1));
try (Block block = evaluator.eval(page)) {
assertEquals(1, toJavaObject(block, 0));
} finally {
page.releaseBlocks();
}
}

private static Case caseExpr(Object... args) {
List<Expression> exps = Stream.of(args).<Expression>map(arg -> {
if (arg instanceof Expression e) {
return e;
}
return new Literal(Source.synthetic(arg == null ? "null" : arg.toString()), arg, DataType.fromJava(arg));
}).toList();
return new Case(Source.synthetic("<case>"), exps.get(0), exps.subList(1, exps.size()));
}

private static TypeResolution resolveCase(Object... args) {
return caseExpr(args).resolveType();
}

private static DataType resolveType(Object... args) {
return caseExpr(args).dataType();
}
}