Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Merge branch 'develop' into support-aggregate-window-functions-2
Browse files Browse the repository at this point in the history
  • Loading branch information
dai-chen committed Jan 7, 2021
2 parents 1ec79c1 + 280482c commit 012ea5e
Show file tree
Hide file tree
Showing 30 changed files with 458 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,12 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
if (builtinFunctionName.isPresent()) {
Expression arg = node.getField().accept(this, context);
return (Aggregator)
repository.compile(
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public static UnresolvedExpression aggregate(
return new AggregateFunction(func, field, Arrays.asList(args));
}

public static UnresolvedExpression filteredAggregate(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, condition);
}

public static Function function(String funcName, UnresolvedExpression... funcArgs) {
return new Function(funcName, Arrays.asList(funcArgs));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
private UnresolvedExpression condition;

/**
* Constructor.
Expand All @@ -46,6 +47,20 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
this.argList = Collections.emptyList();
}

/**
* Constructor.
* @param funcName function name.
* @param field {@link UnresolvedExpression}.
* @param condition condition in aggregation filter.
*/
public AggregateFunction(String funcName, UnresolvedExpression field,
UnresolvedExpression condition) {
this.funcName = funcName;
this.field = field;
this.argList = Collections.emptyList();
this.condition = condition;
}

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.amazon.opendistroforelasticsearch.sql.analysis.ExpressionAnalyzer;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
Expand All @@ -30,6 +31,8 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;

/**
* Aggregator which will iterate on the {@link BindingTuple}s to aggregate the result.
Expand All @@ -46,20 +49,40 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
private final List<Expression> arguments;
protected final ExprCoreType returnType;
@Setter
@Getter
@Accessors(fluent = true)
protected Expression condition;

/**
* Create an {@link AggregationState} which will be used for aggregation.
*/
public abstract S create();

/**
* Iterate on the {@link BindingTuple}.
* Iterate on {@link ExprValue}.
* @param value {@link ExprValue}
* @param state {@link AggregationState}
* @return {@link AggregationState}
*/
protected abstract S iterate(ExprValue value, S state);

/**
* Let the aggregator iterate on the {@link BindingTuple}
* To filter out ExprValues that are missing, null or cannot satisfy {@link #condition}
* Before the specific aggregator iterating ExprValue in the tuple.
*
* @param tuple {@link BindingTuple}
* @param state {@link AggregationState}
* @return {@link AggregationState}
*/
public abstract S iterate(BindingTuple tuple, S state);
public S iterate(BindingTuple tuple, S state) {
ExprValue value = getArguments().get(0).valueOf(tuple);
if (value.isNull() || value.isMissing() || !conditionValue(tuple)) {
return state;
}
return iterate(value, state);
}

@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
Expand All @@ -77,4 +100,14 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitAggregator(this, context);
}

/**
* Util method to get value of condition in aggregation filter.
*/
public boolean conditionValue(BindingTuple tuple) {
if (condition == null) {
return true;
}
return ExprValueUtils.getBooleanValue(condition.valueOf(tuple));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,9 @@ public AvgState create() {
}

@Override
public AvgState iterate(BindingTuple tuple, AvgState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
state.total += ExprValueUtils.getDoubleValue(value);
}
protected AvgState iterate(ExprValue value, AvgState state) {
state.count++;
state.total += ExprValueUtils.getDoubleValue(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@ public CountAggregator.CountState create() {
}

@Override
public CountState iterate(BindingTuple tuple, CountState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
}
protected CountState iterate(ExprValue value, CountState state) {
state.count++;
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@ public MaxState create() {
}

@Override
public MaxState iterate(BindingTuple tuple, MaxState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.max(value);
}
protected MaxState iterate(ExprValue value, MaxState state) {
state.max(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,8 @@ public MinState create() {
}

@Override
public MinState iterate(BindingTuple tuple, MinState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.min(value);
}
protected MinState iterate(ExprValue value, MinState state) {
state.min(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple;
import com.google.common.base.Strings;
Expand Down Expand Up @@ -63,8 +64,8 @@ public AggregationState create() {
}

@Override
public AggregationState iterate(BindingTuple tuple, AggregationState state) {
return delegated.iterate(tuple, state);
protected AggregationState iterate(ExprValue value, AggregationState state) {
return delegated.iterate(value, state);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,9 @@ public SumState create() {
}

@Override
public SumState iterate(BindingTuple tuple, SumState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.isEmptyCollection = false;
state.add(value);
}
protected SumState iterate(ExprValue value, SumState state) {
state.isEmptyCollection = false;
state.add(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package com.amazon.opendistroforelasticsearch.sql.analysis;

import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.integerValue;
Expand Down Expand Up @@ -273,6 +275,16 @@ public void undefined_aggregation_function() {
assertEquals("Unsupported aggregation function ESTDC_ERROR", exception.getMessage());
}

@Test
public void aggregation_filter() {
assertAnalyzeEqual(
dsl.avg(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
AstDSL.filteredAggregate("avg", qualifiedName("integer_value"),
function(">", qualifiedName("integer_value"), intLiteral(1)))
);
}

protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ public void avg_arithmetic_expression() {
assertEquals(25.0, result.value());
}

@Test
public void filtered_avg() {
ExprValue result = aggregation(dsl.avg(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(3.0, result.value());
}

@Test
public void avg_with_missing() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ public void count_array_field_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_count() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(3, result.value());
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_max_arithmetic_expression() {
assertEquals(4, result.value());
}

@Test
public void filtered_max() {
ExprValue result = aggregation(dsl.max(DSL.ref("integer_value", INTEGER))
.condition(dsl.less(DSL.ref("integer_value", INTEGER), DSL.literal(4))), tuples);
assertEquals(3, result.value());
}

@Test
public void test_max_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_min_arithmetic_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_min() {
ExprValue result = aggregation(dsl.min(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(2, result.value());
}

@Test
public void test_min_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ public void sum_string_field_expression() {
assertEquals("unexpected type [STRING] in sum aggregation", exception.getMessage());
}

@Test
public void filtered_sum() {
ExprValue result = aggregation(dsl.sum(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(9, result.value());
}

@Test
public void sum_with_missing() {
ExprValue result =
Expand Down
4 changes: 2 additions & 2 deletions docs/developing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Here are the official instructions on how to set ``JAVA_HOME`` for different pla
Elasticsearch & Kibana
----------------------

For convenience, we recommend installing Elasticsearch and Kibana on your local machine. You can download the open source ZIP for each and extract them to a folder.
For convenience, we recommend installing `Elasticsearch <https://www.elastic.co/downloads/past-releases#elasticsearch-oss>`_ and `Kibana <https://www.elastic.co/downloads/past-releases#kibana-oss>`_ on your local machine. You can download the open source ZIP for each and extract them to a folder.

If you just want to have a quick look, you can also get an Elasticsearch running with plugin installed by ``./gradlew :plugin:run``.

Expand Down Expand Up @@ -217,7 +217,7 @@ Most of the time you just need to run ./gradlew build which will make sure you p
- Run all checks according to Checkstyle configuration.
* - ./gradlew test
- Run all unit tests.
* - ./gradlew :integ-test:integTestRunner
* - ./gradlew :integ-test:integTest
- Run all integration test (this takes time).
* - ./gradlew build
- Build plugin by run all tasks above (this takes time).
Expand Down
39 changes: 39 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,42 @@ Additionally, a ``HAVING`` clause can work without ``GROUP BY`` clause. This is
| Total of age > 100 |
+------------------------+


FILTER Clause
=============

Description
-----------

A ``FILTER`` clause can set specific condition for the current aggregation bucket, following the syntax ``aggregation_function(expr) FILTER(WHERE condition_expr)``. If a filter is specified, then only the input rows for which the condition in the filter clause evaluates to true are fed to the aggregate function; other rows are discarded. The aggregation with filter clause can be use in ``SELECT`` clause only.

FILTER with GROUP BY
--------------------

The group by aggregation with ``FILTER`` clause can set different conditions for each aggregation bucket. Here is an example to use ``FILTER`` in group by aggregation::

od> SELECT avg(age) FILTER(WHERE balance > 10000) AS filtered, gender FROM accounts GROUP BY gender
fetched rows / total rows = 2/2
+------------+----------+
| filtered | gender |
|------------+----------|
| 28.0 | F |
| 32.0 | M |
+------------+----------+

FILTER without GROUP BY
-----------------------

The ``FILTER`` clause can be used in aggregation functions without GROUP BY as well. For example::

od> SELECT
... count(*) AS unfiltered,
... count(*) FILTER(WHERE age > 34) AS filtered
... FROM accounts
fetched rows / total rows = 1/1
+--------------+------------+
| unfiltered | filtered |
|--------------+------------|
| 4 | 1 |
+--------------+------------+

0 comments on commit 012ea5e

Please sign in to comment.