Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix CASE clause pushdown issue (#895)
Browse files Browse the repository at this point in the history
* Support case when pushdown

* Add more comparison test

* Relax type check for null

* Prepare PR

* Prepare PR

* Fix Literal.toString() NPE issue
  • Loading branch information
dai-chen committed Dec 14, 2020
1 parent 8a44305 commit 64c7bd6
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalAggregation;
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalPlan;
Expand Down Expand Up @@ -87,6 +89,33 @@ public Expression visitAggregator(Aggregator<?> node, AnalysisContext context) {
return expressionMap.getOrDefault(node, node);
}

/**
* Implement this because Case/When is not registered in function repository.
*/
@Override
public Expression visitCase(CaseClause node, AnalysisContext context) {
if (expressionMap.containsKey(node)) {
return expressionMap.get(node);
}

List<WhenClause> whenClauses = node.getWhenClauses()
.stream()
.map(expr -> (WhenClause) expr.accept(this, context))
.collect(Collectors.toList());
Expression defaultResult = null;
if (node.getDefaultResult() != null) {
defaultResult = node.getDefaultResult().accept(this, context);
}
return new CaseClause(whenClauses, defaultResult);
}

@Override
public Expression visitWhen(WhenClause node, AnalysisContext context) {
return new WhenClause(
node.getCondition().accept(this, context),
node.getResult().accept(this, context));
}


/**
* Expression Map Builder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

/**
* Expression node of literal type
Expand All @@ -48,6 +47,6 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {

@Override
public String toString() {
return value.toString();
return String.valueOf(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,18 @@ public T visitNamedAggregator(NamedAggregator node, C context) {
return visitChildren(node, context);
}

/**
* Call visitFunction() by default rather than visitChildren().
* This makes CASE/WHEN able to be handled:
* 1) by visitFunction() if not overwritten: ex. FilterQueryBuilder
* 2) by visitCase/When() otherwise if any special logic: ex. ExprReferenceOptimizer
*/
public T visitCase(CaseClause node, C context) {
return visitNode(node, context);
return visitFunction(node, context);
}

public T visitWhen(WhenClause node, C context) {
return visitNode(node, context);
return visitFunction(node, context);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@

package com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases;

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.UNKNOWN;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprNullValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
Expand All @@ -33,11 +37,10 @@
* A CASE clause is very different from a regular function. Functions have well-defined signature,
* though CASE clause is more like a function implementation which requires type check "manually".
*/
@AllArgsConstructor
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
@ToString
public class CaseClause implements Expression {
public class CaseClause extends FunctionExpression {

/**
* List of WHEN clauses.
Expand All @@ -49,6 +52,15 @@ public class CaseClause implements Expression {
*/
private final Expression defaultResult;

/**
* Initialize case clause.
*/
public CaseClause(List<WhenClause> whenClauses, Expression defaultResult) {
super(FunctionName.of("case"), concatArgs(whenClauses, defaultResult));
this.whenClauses = whenClauses;
this.defaultResult = defaultResult;
}

@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
for (WhenClause when : whenClauses) {
Expand All @@ -61,7 +73,10 @@ public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {

@Override
public ExprType type() {
return whenClauses.get(0).type();
List<ExprType> types = allResultTypes();

// Return unknown if all WHEN/ELSE return NULL
return types.isEmpty() ? UNKNOWN : types.get(0);
}

@Override
Expand All @@ -71,7 +86,9 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {

/**
* Get types of each result in WHEN clause and ELSE clause.
* @return all result types
* Exclude UNKNOWN type from NULL literal which means NULL in THEN or ELSE clause
* is not included in result.
* @return all result types. Use list so caller can generate friendly error message.
*/
public List<ExprType> allResultTypes() {
List<ExprType> types = whenClauses.stream()
Expand All @@ -80,7 +97,20 @@ public List<ExprType> allResultTypes() {
if (defaultResult != null) {
types.add(defaultResult.type());
}

types.removeIf(type -> (type == UNKNOWN));
return types;
}

private static List<Expression> concatArgs(List<WhenClause> whenClauses,
Expression defaultResult) {
ImmutableList.Builder<Expression> args = ImmutableList.builder();
whenClauses.forEach(args::add);

if (defaultResult != null) {
args.add(defaultResult);
}
return args.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

/**
* WHEN clause that consists of a condition and a result corresponding.
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
@RequiredArgsConstructor
@ToString
public class WhenClause implements Expression {
public class WhenClause extends FunctionExpression {

/**
* Condition that must be a predicate.
Expand All @@ -45,8 +46,26 @@ public class WhenClause implements Expression {
*/
private final Expression result;

/**
* Initialize when clause.
*/
public WhenClause(Expression condition, Expression result) {
super(FunctionName.of("when"), ImmutableList.of(condition, result));
this.condition = condition;
this.result = result;
}

/**
* Evaluate when condition.
* @param valueEnv value env
* @return is condition satisfied
*/
public boolean isTrue(Environment<Expression, ExprValue> valueEnv) {
return condition.valueOf(valueEnv).booleanValue();
ExprValue result = condition.valueOf(valueEnv);
if (result.isMissing() || result.isNull()) {
return false;
}
return result.booleanValue();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxCheckException;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException;
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.LiteralExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -185,6 +187,23 @@ public void all_fields() {
AllFields.of());
}

@Test
public void case_clause() {
assertAnalyzeEqual(
DSL.cases(
DSL.literal(ExprValueUtils.nullValue()),
DSL.when(
dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(30)),
DSL.literal("test"))),
AstDSL.caseWhen(
AstDSL.nullLiteral(),
AstDSL.when(
AstDSL.function("=",
AstDSL.qualifiedName("integer_value"),
AstDSL.intLiteral(30)),
AstDSL.stringLiteral("test"))));
}

@Test
public void skip_struct_data_type() {
SyntaxCheckException exception =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -72,6 +73,76 @@ void aggregation_in_expression_should_be_replaced() {
);
}

@Test
void case_clause_should_be_replaced() {
Expression caseClause = DSL.cases(
null,
DSL.when(
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
DSL.literal("true")));

LogicalPlan logicalPlan =
LogicalPlanDSL.aggregation(
LogicalPlanDSL.relation("test"),
emptyList(),
ImmutableList.of(DSL.named(
"CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")],"
+ " defaultResult=null)",
caseClause)));

assertEquals(
DSL.ref(
"CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")],"
+ " defaultResult=null)", STRING),
optimize(caseClause, logicalPlan));
}

@Test
void aggregation_in_case_when_clause_should_be_replaced() {
Expression caseClause = DSL.cases(
null,
DSL.when(
dsl.equal(dsl.avg(DSL.ref("age", INTEGER)), DSL.literal(30)),
DSL.literal("true")));

LogicalPlan logicalPlan =
LogicalPlanDSL.aggregation(
LogicalPlanDSL.relation("test"),
ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))),
ImmutableList.of(DSL.named("name", DSL.ref("name", STRING))));

assertEquals(
DSL.cases(
null,
DSL.when(
dsl.equal(DSL.ref("AVG(age)", DOUBLE), DSL.literal(30)),
DSL.literal("true"))),
optimize(caseClause, logicalPlan));
}

@Test
void aggregation_in_case_else_clause_should_be_replaced() {
Expression caseClause = DSL.cases(
dsl.avg(DSL.ref("age", INTEGER)),
DSL.when(
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
DSL.literal("true")));

LogicalPlan logicalPlan =
LogicalPlanDSL.aggregation(
LogicalPlanDSL.relation("test"),
ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))),
ImmutableList.of(DSL.named("name", DSL.ref("name", STRING))));

assertEquals(
DSL.cases(
DSL.ref("AVG(age)", DOUBLE),
DSL.when(
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
DSL.literal("true"))),
optimize(caseClause, logicalPlan));
}

@Test
void window_expression_should_be_replaced() {
LogicalPlan logicalPlan =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ void should_use_type_of_when_clause() {
assertEquals(ExprCoreType.INTEGER, caseClause.type());
}

@Test
void should_use_type_of_nonnull_when_or_else_clause() {
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
Expression defaultResult = mock(Expression.class);
when(defaultResult.type()).thenReturn(ExprCoreType.STRING);

CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
assertEquals(ExprCoreType.STRING, caseClause.type());
}

@Test
void should_use_unknown_type_of_if_all_when_and_else_return_null() {
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
Expression defaultResult = mock(Expression.class);
when(defaultResult.type()).thenReturn(ExprCoreType.UNKNOWN);

CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
assertEquals(ExprCoreType.UNKNOWN, caseClause.type());
}

@Test
void should_return_all_result_types_including_default() {
when(whenClause.type()).thenReturn(ExprCoreType.INTEGER);
Expand All @@ -87,4 +107,16 @@ void should_return_all_result_types_including_default() {
caseClause.allResultTypes());
}

@Test
void should_return_all_result_types_excluding_null_result() {
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
Expression defaultResult = mock(Expression.class);
when(defaultResult.type()).thenReturn(ExprCoreType.UNKNOWN);

CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
assertEquals(
ImmutableList.of(),
caseClause.allResultTypes());
}

}
Loading

0 comments on commit 64c7bd6

Please sign in to comment.