Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Fixed interval type null/missing check failure #1011

Merged
merged 3 commits into from Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -24,6 +24,7 @@
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.define;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.impl;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntervalValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
Expand Down Expand Up @@ -54,8 +55,8 @@ public void register(BuiltinFunctionRepository repository) {

private FunctionResolver interval() {
return define(BuiltinFunctionName.INTERVAL.getName(),
impl(IntervalClause::interval, INTERVAL, INTEGER, STRING),
impl(IntervalClause::interval, INTERVAL, LONG, STRING));
impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING),
impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING));
}

private ExprValue interval(ExprValue value, ExprValue unit) {
Expand Down
Expand Up @@ -16,9 +16,13 @@
package com.amazon.opendistroforelasticsearch.sql.expression.datetime;

import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.intervalValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.missingValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.nullValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTERVAL;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
Expand All @@ -39,6 +43,12 @@ public class IntervalClauseTest extends ExpressionTestBase {
@Mock
Environment<Expression, ExprValue> env;

@Mock
Expression nullRef;

@Mock
Expression missingRef;

@Test
public void microsecond() {
FunctionExpression expr = dsl.interval(DSL.literal(1), DSL.literal("microsecond"));
Expand Down Expand Up @@ -114,4 +124,22 @@ public void to_string() {
FunctionExpression expr = dsl.interval(DSL.literal(1), DSL.literal("day"));
assertEquals("interval(1, \"day\")", expr.toString());
}

@Test
public void null_value() {
when(nullRef.type()).thenReturn(INTEGER);
when(nullRef.valueOf(env)).thenReturn(nullValue());
FunctionExpression expr = dsl.interval(nullRef, DSL.literal("day"));
assertEquals(INTERVAL, expr.type());
assertEquals(nullValue(), expr.valueOf(env));
}

@Test
public void missing_value() {
when(missingRef.type()).thenReturn(INTEGER);
when(missingRef.valueOf(env)).thenReturn(missingValue());
FunctionExpression expr = dsl.interval(missingRef, DSL.literal("day"));
assertEquals(INTERVAL, expr.type());
assertEquals(missingValue(), expr.valueOf(env));
}
}
Expand Up @@ -56,6 +56,14 @@ public void testNullLiteralInFunction() {
rows(null, null));
}

@Test
public void testNullLiteralInInterval() {
verifyDataRows(
query("SELECT INTERVAL NULL DAY, INTERVAL 60 * 60 * 24 * (NULL - FLOOR(NULL)) SECOND"),
rows(null, null)
);
}

private JSONObject query(String sql) {
return new JSONObject(executeQuery(sql, "jdbc"));
}
Expand Down
Expand Up @@ -21,8 +21,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Alias;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Aggregation;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan;
Expand Down Expand Up @@ -125,36 +124,26 @@ private List<UnresolvedExpression> replaceGroupByItemIfAliasOrOrdinal() {
*/
private Optional<UnresolvedExpression> findNonAggregatedItemInSelect() {
return querySpec.getSelectItems().stream()
.filter(this::isNonAggregatedExpression)
.filter(this::isNonLiteralFunction)
.filter(this::isNonAggregateOrLiteralExpression)
.findFirst();
}

private boolean isAggregatorNotFoundAnywhere() {
return querySpec.getAggregators().isEmpty();
}

private boolean isNonLiteralFunction(UnresolvedExpression expr) {
// The base case for recursion
if (expr instanceof Literal) {
private boolean isNonAggregateOrLiteralExpression(UnresolvedExpression expr) {
if (expr instanceof AggregateFunction) {
return false;
}
if (expr instanceof Function) {
List<? extends Node> children = expr.getChild();
return children.stream().anyMatch(child ->
isNonLiteralFunction((UnresolvedExpression) child));
}
return true;
}

private boolean isNonAggregatedExpression(UnresolvedExpression expr) {
if (expr instanceof AggregateFunction) {
return false;
if (expr instanceof QualifiedName) {
return true;
}

List<? extends Node> children = expr.getChild();
return children.stream()
.allMatch(child -> isNonAggregatedExpression((UnresolvedExpression) child));
return children.stream().anyMatch(child ->
isNonAggregateOrLiteralExpression((UnresolvedExpression) child));
}

}
Expand Up @@ -133,6 +133,27 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() {
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT INTERVAL 1 DAY FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT CAST(1 AS LONG) FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT CASE WHEN true THEN 1 ELSE 2 END FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));
}

@Test
Expand Down