From 85e40ab7281886bb5c0e1db0d3960f8839d2bd6b Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Fri, 9 Jan 2015 15:00:46 +0100 Subject: [PATCH] WIP --- .../main/java/io/crate/types/DataTypes.java | 19 ++- .../aggregation/AggregationFunction.java | 17 ++- .../aggregation/AggregationState.java | 67 --------- ...regationCollector.java => Aggregator.java} | 60 ++++---- .../impl/ArbitraryAggregation.java | 101 +++---------- .../aggregation/impl/AverageAggregation.java | 135 +++++++++++------ .../impl/CollectSetAggregation.java | 131 ++++------------ .../aggregation/impl/CountAggregation.java | 76 +++------- .../aggregation/impl/MaximumAggregation.java | 141 ++++-------------- .../aggregation/impl/MinimumAggregation.java | 131 +++------------- .../aggregation/impl/SumAggregation.java | 79 +++------- .../projectors/AggregationProjector.java | 28 ++-- .../projectors/GroupingProjector.java | 104 +++++++------ .../node/AggregationStateStreamer.java | 58 ------- .../planner/node/PlanNodeStreamerVisitor.java | 21 +-- .../executor/task/LocalMergeTaskTest.java | 17 +-- .../task/DistributedMergeTaskTest.java | 25 ++-- .../aggregation/AggregationTest.java | 8 +- ...CollectorTest.java => AggregatorTest.java} | 25 ++-- .../impl/CollectSetAggregationTest.java | 12 +- .../operation/merge/MergeOperationTest.java | 20 +-- .../projectors/GroupingProjectorTest.java | 7 +- .../node/PlanNodeStreamerVisitorTest.java | 4 +- 23 files changed, 407 insertions(+), 879 deletions(-) delete mode 100644 sql/src/main/java/io/crate/operation/aggregation/AggregationState.java rename sql/src/main/java/io/crate/operation/aggregation/{AggregationCollector.java => Aggregator.java} (67%) delete mode 100644 sql/src/main/java/io/crate/planner/node/AggregationStateStreamer.java rename sql/src/test/java/io/crate/operation/aggregation/{AggregationCollectorTest.java => AggregatorTest.java} (80%) diff --git a/core/src/main/java/io/crate/types/DataTypes.java b/core/src/main/java/io/crate/types/DataTypes.java index 108f2e90e8dc..8734b7fa9678 100644 --- a/core/src/main/java/io/crate/types/DataTypes.java +++ b/core/src/main/java/io/crate/types/DataTypes.java @@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableSet; import io.crate.TimestampFormat; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.logging.ESLogger; @@ -85,7 +86,7 @@ public class DataTypes { LONG ); - public static final ImmutableMap typeRegistry = ImmutableMap.builder() + public static final Map TYPE_REGISTRY = new MapBuilder() .put(UndefinedType.ID, UNDEFINED) .put(NotSupportedType.ID, NOT_SUPPORTED) .put(ByteType.ID, BYTE) @@ -122,7 +123,7 @@ public DataType create() { public DataType create(DataType innerType) { return new SetType(innerType); } - }).build(); + }).map(); private static final Set NUMBER_CONVERSIONS = ImmutableSet.builder() .addAll(NUMERIC_PRIMITIVE_TYPES) @@ -156,11 +157,11 @@ public static boolean isCollectionType(DataType type) { public static DataType fromStream(StreamInput in) throws IOException { int i = in.readVInt(); try { - DataType type = typeRegistry.get(i).create(); + DataType type = TYPE_REGISTRY.get(i).create(); type.readFrom(in); return type; } catch (NullPointerException e) { - logger.error(String.format(Locale.ENGLISH, "%d is missing in typeRegistry", i), e); + logger.error(String.format(Locale.ENGLISH, "%d is missing in TYPE_REGISTRY", i), e); throw e; } } @@ -263,9 +264,15 @@ public static DataType ofJsonObject(Object type) { if (type instanceof List) { int idCollectionType = (Integer) ((List) type).get(0); int idInnerType = (Integer) ((List) type).get(1); - return ((CollectionTypeFactory) typeRegistry.get(idCollectionType)).create(ofJsonObject(idInnerType)); + return ((CollectionTypeFactory) TYPE_REGISTRY.get(idCollectionType)).create(ofJsonObject(idInnerType)); } assert type instanceof Integer; - return typeRegistry.get(type).create(); + return TYPE_REGISTRY.get(type).create(); + } + + public static void register(int id, DataTypeFactory dataTypeFactory) { + if (TYPE_REGISTRY.put(id, dataTypeFactory) != null) { + throw new IllegalArgumentException("Already got a dataType with id " + id); + }; } } diff --git a/sql/src/main/java/io/crate/operation/aggregation/AggregationFunction.java b/sql/src/main/java/io/crate/operation/aggregation/AggregationFunction.java index ef2a6c21512b..5163e70a8a0c 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/AggregationFunction.java +++ b/sql/src/main/java/io/crate/operation/aggregation/AggregationFunction.java @@ -26,26 +26,28 @@ import io.crate.operation.Input; import io.crate.planner.symbol.Function; import io.crate.planner.symbol.Symbol; +import io.crate.types.DataType; import org.elasticsearch.common.breaker.CircuitBreakingException; -public abstract class AggregationFunction implements FunctionImplementation { +public abstract class AggregationFunction implements FunctionImplementation { /** * Apply the columnValue to the argument AggState using the logic in this AggFunction * + * @param ramAccountingContext RamAccountingContext to account for additional memory usage if the state grows in size * @param state the aggregation state for the iteration * @param args the arguments according to FunctionInfo.argumentTypes - * @return false if we do not need any further iteration for this state + * @return changed state */ - public abstract boolean iterate(T state, Input... args) throws CircuitBreakingException; - + public abstract TPartial iterate(RamAccountingContext ramAccountingContext, TPartial state, Input... args) + throws CircuitBreakingException; /** * Creates a new state for this aggregation * * @return a new state instance */ - public abstract T newState(RamAccountingContext ramAccountingContext); + public abstract TPartial newState(RamAccountingContext ramAccountingContext); @Override @@ -53,4 +55,9 @@ public Symbol normalizeSymbol(Function symbol) { return symbol; } + public abstract DataType partialType(); + + public abstract TPartial reduce(TPartial state1, TPartial state2); + + public abstract TFinal terminatePartial(TPartial state); } diff --git a/sql/src/main/java/io/crate/operation/aggregation/AggregationState.java b/sql/src/main/java/io/crate/operation/aggregation/AggregationState.java deleted file mode 100644 index d0a2484d4ece..000000000000 --- a/sql/src/main/java/io/crate/operation/aggregation/AggregationState.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor - * license agreements. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. Crate licenses - * this file to you under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. You may - * obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * However, if you have executed another commercial license agreement - * with Crate these terms will supersede the license and you may use the - * software solely pursuant to the terms of the relevant commercial agreement. - */ - -package io.crate.operation.aggregation; - -import io.crate.breaker.RamAccountingContext; -import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.io.stream.Streamable; -import org.elasticsearch.common.logging.ESLogger; -import org.elasticsearch.common.logging.ESLoggerFactory; -import org.elasticsearch.common.unit.ByteSizeValue; - -/** - * State of a aggregation function - * - * Note on serialization: - * In order to read the correct concrete AggState class on the receiver - * the receiver has to get the ParsedStatement beforehand and then use it - * to instantiate the correct concrete AggState instances. - */ -public abstract class AggregationState implements Comparable, Streamable { - - protected final RamAccountingContext ramAccountingContext; - private final ESLogger logger = ESLoggerFactory.getLogger(getClass().getName()); - - public AggregationState(RamAccountingContext ramAccountingContext) { - this.ramAccountingContext = ramAccountingContext; - // plain object size - addEstimatedSize(8); - } - - public abstract Object value(); - public abstract void reduce(T other) throws CircuitBreakingException; - - /** - * called after the rows/state have been merged on the reducer, - * but before the rows are sent to the handler. - */ - public void terminatePartial() { - // noop; - } - - protected void addEstimatedSize(long size) throws CircuitBreakingException { - if (logger.isTraceEnabled()) { - logger.trace("[{}] Adding {} bytes to RAM accounting context", getClass(), new ByteSizeValue(size)); - } - ramAccountingContext.addBytes(size); - } -} diff --git a/sql/src/main/java/io/crate/operation/aggregation/AggregationCollector.java b/sql/src/main/java/io/crate/operation/aggregation/Aggregator.java similarity index 67% rename from sql/src/main/java/io/crate/operation/aggregation/AggregationCollector.java rename to sql/src/main/java/io/crate/operation/aggregation/Aggregator.java index 009a69361b20..a981a0311f98 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/AggregationCollector.java +++ b/sql/src/main/java/io/crate/operation/aggregation/Aggregator.java @@ -23,21 +23,22 @@ import io.crate.breaker.RamAccountingContext; import io.crate.operation.Input; -import io.crate.operation.collect.RowCollector; import io.crate.planner.symbol.Aggregation; import java.util.Locale; -public class AggregationCollector implements RowCollector { +/** + * A wrapper around an AggregationFunction that is aware of the aggregation steps (iter, partial, final) + * and will call the collect functions on the aggregationFunction depending on these steps. + */ +public class Aggregator { private final Input[] inputs; private final AggregationFunction aggregationFunction; private final FromImpl fromImpl; private final ToImpl toImpl; - private AggregationState aggregationState; - - public AggregationCollector(Aggregation a, AggregationFunction aggregationFunction, Input... inputs) { + public Aggregator(Aggregation a, AggregationFunction aggregationFunction, Input... inputs) { if (a.fromStep() == Aggregation.Step.PARTIAL && inputs.length > 1) { throw new UnsupportedOperationException("Aggregation from PARTIAL is only allowed with one input."); } @@ -73,43 +74,33 @@ public AggregationCollector(Aggregation a, AggregationFunction aggregationFuncti } - public boolean startCollect(RamAccountingContext ramAccountingContext) { - aggregationState = fromImpl.startCollect(ramAccountingContext); - return true; + public Object prepareState(RamAccountingContext ramAccountingContext) { + return fromImpl.prepareState(ramAccountingContext); } - public boolean processRow() { - return fromImpl.processRow(); - } - - - public Object finishCollect() { - return toImpl.finishCollect(); + public Object processRow(RamAccountingContext ramAccountingContext, Object value) { + return fromImpl.processRow(ramAccountingContext, value); } - public AggregationState state() { - return aggregationState; - } - - public void state(AggregationState state) { - aggregationState = state; + public Object finishCollect(Object state) { + return toImpl.finishCollect(state); } abstract class FromImpl { - public AggregationState startCollect(RamAccountingContext ramAccountingContext) { + public Object prepareState(RamAccountingContext ramAccountingContext) { return aggregationFunction.newState(ramAccountingContext); } - public abstract boolean processRow(); + public abstract Object processRow(RamAccountingContext ramAccountingContext, Object value); } class FromIter extends FromImpl { @Override @SuppressWarnings("unchecked") - public boolean processRow() { - return aggregationFunction.iterate(aggregationState, inputs); + public Object processRow(RamAccountingContext ramAccountingContext, Object value) { + return aggregationFunction.iterate(ramAccountingContext, value, inputs); } } @@ -117,28 +108,29 @@ class FromPartial extends FromImpl { @Override @SuppressWarnings("unchecked") - public boolean processRow() { - aggregationState.reduce((AggregationState)inputs[0].value()); - return true; + public Object processRow(RamAccountingContext ramAccountingContext, Object value) { + return aggregationFunction.reduce(value, inputs[0].value()); } } static abstract class ToImpl { - public abstract Object finishCollect(); + public abstract Object finishCollect(Object state); } class ToPartial extends ToImpl { + @Override - public Object finishCollect() { - return aggregationState; + public Object finishCollect(Object state) { + return state; } } class ToFinal extends ToImpl { + @Override - public Object finishCollect() { - aggregationState.terminatePartial(); - return aggregationState.value(); + public Object finishCollect(Object state) { + //noinspection unchecked + return aggregationFunction.terminatePartial(state); } } } diff --git a/sql/src/main/java/io/crate/operation/aggregation/impl/ArbitraryAggregation.java b/sql/src/main/java/io/crate/operation/aggregation/impl/ArbitraryAggregation.java index 37e5af1b07c6..3ce3b9e27758 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/impl/ArbitraryAggregation.java +++ b/sql/src/main/java/io/crate/operation/aggregation/impl/ArbitraryAggregation.java @@ -22,23 +22,15 @@ package io.crate.operation.aggregation.impl; import com.google.common.collect.ImmutableList; -import io.crate.Streamer; import io.crate.breaker.RamAccountingContext; -import io.crate.breaker.SizeEstimator; -import io.crate.breaker.SizeEstimatorFactory; import io.crate.metadata.FunctionIdent; import io.crate.metadata.FunctionInfo; import io.crate.operation.Input; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.types.DataType; import io.crate.types.DataTypes; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import java.io.IOException; - -public abstract class ArbitraryAggregation> extends AggregationFunction> { +public class ArbitraryAggregation extends AggregationFunction { public static final String NAME = "arbitrary"; @@ -47,17 +39,8 @@ public abstract class ArbitraryAggregation> extends Aggr public static void register(AggregationImplModule mod) { for (final DataType t : DataTypes.PRIMITIVE_TYPES) { mod.register(new ArbitraryAggregation( - new FunctionInfo( - new FunctionIdent(NAME, ImmutableList.of(t)), - t, FunctionInfo.Type.AGGREGATE) - ) { - @Override - public AggregationState newState(RamAccountingContext ramAccountingContext) { - SizeEstimator sizeEstimator = SizeEstimatorFactory.create(t); - return new ArbitraryAggState(ramAccountingContext, t.streamer(), sizeEstimator); - } - } - ); + new FunctionInfo(new FunctionIdent(NAME, ImmutableList.of(t)), t, + FunctionInfo.Type.AGGREGATE))); } } @@ -71,70 +54,30 @@ public FunctionInfo info() { } @Override - public boolean iterate(ArbitraryAggState state, Input... args) { - state.add((T) args[0].value()); - return false; + public DataType partialType() { + return info.returnType(); } + @Override + public Comparable newState(RamAccountingContext ramAccountingContext) { + return null; + } - static class ArbitraryAggState> extends AggregationState> { - - Streamer streamer; - private final SizeEstimator sizeEstimator; - private Object value = null; - - public ArbitraryAggState(RamAccountingContext ramAccountingContext, - Streamer streamer, - SizeEstimator sizeEstimator) { - super(ramAccountingContext); - this.streamer = streamer; - this.sizeEstimator = sizeEstimator; - } - - @Override - public Object value() { - return value; - } - - @Override - public void reduce(ArbitraryAggState other) { - if (this.value == null){ - setValue(other.value); - } - } - - @Override - public int compareTo(ArbitraryAggState o) { - if (o == null) return 1; - if (value == null) return (o.value == null ? 0 : -1); - if (o.value == null) return 1; - - return 0; // any two object that are not null are considered equal - } - - public void add(T otherValue) { - setValue(otherValue); - } - - public void setValue(Object value) { - // setValue is only called once if value changes from null to something else, size is only estimated once here - ramAccountingContext.addBytes(sizeEstimator.estimateSize(value)); - this.value = value; - } - - @Override - public String toString() { - return ""; - } + @Override + public Object iterate(RamAccountingContext ramAccountingContext, Object state, Input... args) { + return reduce(state, args[0].value()); + } - @Override - public void readFrom(StreamInput in) throws IOException { - setValue(streamer.readValueFrom(in)); + @Override + public Object reduce(Object state1, Object state2) { + if (state1 == null) { + return state2; } + return state1; + } - @Override - public void writeTo(StreamOutput out) throws IOException { - streamer.writeValueTo(out, value); - } + @Override + public Object terminatePartial(Object state) { + return state; } } diff --git a/sql/src/main/java/io/crate/operation/aggregation/impl/AverageAggregation.java b/sql/src/main/java/io/crate/operation/aggregation/impl/AverageAggregation.java index 7d89ebb7f791..8c969653233c 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/impl/AverageAggregation.java +++ b/sql/src/main/java/io/crate/operation/aggregation/impl/AverageAggregation.java @@ -22,20 +22,21 @@ package io.crate.operation.aggregation.impl; import com.google.common.collect.ImmutableList; +import io.crate.Streamer; import io.crate.breaker.RamAccountingContext; import io.crate.metadata.FunctionIdent; import io.crate.metadata.FunctionInfo; import io.crate.operation.Input; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.types.DataType; +import io.crate.types.DataTypeFactory; import io.crate.types.DataTypes; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import java.io.IOException; -public class AverageAggregation extends AggregationFunction { +public class AverageAggregation extends AggregationFunction { public static final String NAME = "avg"; private final FunctionInfo info; @@ -55,64 +56,85 @@ public static void register(AggregationImplModule mod) { this.info = info; } - public static class AverageAggState extends AggregationState { + public static class AverageStateType extends DataType implements Streamer, DataTypeFactory { - private double sum = 0; - private long count = 0; + public static final int ID = 1024; + private static final AverageStateType INSTANCE = new AverageStateType(); - public AverageAggState(RamAccountingContext ramAccountingContext) { - super(ramAccountingContext); - // double sum - ramAccountingContext.addBytes(8); - // long count - ramAccountingContext.addBytes(8); + private AverageStateType() { + DataTypes.register(ID, this); } @Override - public Object value() { - if (count > 0) { - return sum / count; - } else { - return null; - } + public int id() { + return ID; } @Override - public void reduce(AverageAggState other) { - if (other != null) { - sum += other.sum; - count += other.count; - } + public String getName() { + return "average_state"; } - void add(Object otherValue) { - if (otherValue != null) { - sum += ((Number) otherValue).doubleValue(); - count++; - } + @Override + public Streamer streamer() { + return this; + } + + @Override + public AverageState value(Object value) throws IllegalArgumentException, ClassCastException { + return (AverageState) value; } @Override - public void readFrom(StreamInput in) throws IOException { - sum = in.readDouble(); - count = in.readVLong(); + public int compareValueTo(AverageState val1, AverageState val2) { + if (val1 == null) return -1; + return val1.compareTo(val2); } + @Override + public AverageState readValueFrom(StreamInput in) throws IOException { + AverageState averageState = new AverageState(); + averageState.sum = in.readDouble(); + averageState.count = in.readVLong(); + return averageState; + } @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeDouble(sum); - out.writeVLong(count); + public void writeValueTo(StreamOutput out, Object v) throws IOException { + AverageState averageState = (AverageState) v; + out.writeDouble(averageState.sum); + out.writeVLong(averageState.count); } @Override - public int compareTo(AverageAggState o) { + public DataType create() { + return INSTANCE; + } + } + + public static class AverageState implements Comparable { + + private double sum = 0; + private long count = 0; + + public Double value() { + if (count > 0) { + return sum / count; + } else { + return null; + } + } + + @Override + public int compareTo(AverageState o) { if (o == null) { return 1; } else { - Double thisValue = (Double) value(); - Double other = (Double) o.value(); - return thisValue.compareTo(other); + int compare = Double.compare(sum, o.sum); + if (compare == 0) { + return Long.compare(count, o.count); + } + return compare; } } @@ -124,14 +146,43 @@ public String toString() { @Override - public boolean iterate(AverageAggState state, Input... args) { - state.add(args[0].value()); - return true; + public AverageState iterate(RamAccountingContext ramAccountingContext, AverageState state, Input... args) { + if (state != null) { + Number value = (Number) args[0].value(); + if (value != null) { + state.count++; + state.sum += value.doubleValue(); + } + } + return state; + } + + @Override + public AverageState reduce(AverageState state1, AverageState state2) { + if (state1 == null) { + return state2; + } + if (state2 == null) { + return state1; + } + state1.count += state2.count; + state1.sum += state2.sum; + return state1; + } + + @Override + public Double terminatePartial(AverageState state) { + return state.value(); + } + + @Override + public AverageState newState(RamAccountingContext ramAccountingContext) { + return new AverageState(); } @Override - public AverageAggState newState(RamAccountingContext ramAccountingContext) { - return new AverageAggState(ramAccountingContext); + public DataType partialType() { + return AverageStateType.INSTANCE; } @Override diff --git a/sql/src/main/java/io/crate/operation/aggregation/impl/CollectSetAggregation.java b/sql/src/main/java/io/crate/operation/aggregation/impl/CollectSetAggregation.java index f75bad8d2fc6..b490f600c1f9 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/impl/CollectSetAggregation.java +++ b/sql/src/main/java/io/crate/operation/aggregation/impl/CollectSetAggregation.java @@ -22,67 +22,33 @@ package io.crate.operation.aggregation.impl; import com.google.common.collect.ImmutableList; -import io.crate.Streamer; import io.crate.breaker.RamAccountingContext; -import io.crate.breaker.SizeEstimator; -import io.crate.breaker.SizeEstimatorFactory; import io.crate.metadata.FunctionIdent; import io.crate.metadata.FunctionInfo; import io.crate.operation.Input; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.types.DataType; import io.crate.types.DataTypes; import io.crate.types.SetType; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import java.io.IOException; import java.util.HashSet; import java.util.Set; -public abstract class CollectSetAggregation> - extends AggregationFunction { +public class CollectSetAggregation extends AggregationFunction, Set> { public static final String NAME = "collect_set"; - private final FunctionInfo info; + private FunctionInfo info; public static void register(AggregationImplModule mod) { for (final DataType dataType : DataTypes.PRIMITIVE_TYPES) { - final Streamer setStreamer = new SetType(dataType).streamer(); - - mod.register( - new CollectSetAggregation( - new FunctionInfo(new FunctionIdent(NAME, - ImmutableList.of(dataType)), - new SetType(dataType), FunctionInfo.Type.AGGREGATE - ) - ) { - @Override - public CollectSetAggState newState(RamAccountingContext ramAccountingContext) { - return new CollectSetAggState(ramAccountingContext, SizeEstimatorFactory.create(dataType)) { - @Override - public void readFrom(StreamInput in) throws IOException { - valueSize(in.readVLong()); - addEstimatedSize(valueSize()); - setValue(setStreamer.readValueFrom(in)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeVLong(valueSize()); - setStreamer.writeValueTo(out, value()); - } - }; - } - } - ); + mod.register(new CollectSetAggregation(new FunctionInfo(new FunctionIdent(NAME, + ImmutableList.of(dataType)), + new SetType(dataType), FunctionInfo.Type.AGGREGATE))); } } - CollectSetAggregation(FunctionInfo info) { this.info = info; } @@ -93,74 +59,33 @@ public FunctionInfo info() { } @Override - public boolean iterate(CollectSetAggState state, Input... args) throws CircuitBreakingException { - state.add(args[0].value()); - return true; - } - - public static abstract class CollectSetAggState extends AggregationState { - - private final SizeEstimator sizeEstimator; - private Set value = new HashSet<>(); - private long valueSize = 0; - - public CollectSetAggState(RamAccountingContext ramAccountingContext, SizeEstimator sizeEstimator) { - super(ramAccountingContext); - this.sizeEstimator = sizeEstimator; - } - - @Override - public Set value() { - return value; - } - - @Override - public void reduce(CollectSetAggState other) throws CircuitBreakingException { - for (Object otherValue : other.value()) { - if (value.add(otherValue)) { - long otherValueSize = sizeEstimator.estimateSize(otherValue); - addEstimatedSize(otherValueSize); - valueSize += otherValueSize; - } - } - } - - void add(Object otherValue) throws CircuitBreakingException { - // ignore null values? yes - if (otherValue != null) { - if (value.add(otherValue)) { - long otherValueSize = sizeEstimator.estimateSize(otherValue); - addEstimatedSize(otherValueSize); - valueSize += otherValueSize; - } - } - } - - public void setValue(Object value) { - this.value = (Set)value; - } - - public long valueSize() { - return this.valueSize; + public Set iterate(RamAccountingContext ramAccountingContext, Set state, Input... args) throws CircuitBreakingException { + Object value = args[0].value(); + if (value == null) { + return state; } + state.add(value); + return state; + } - public void valueSize(long valueSize) { - this.valueSize = valueSize; - } + @Override + public Set newState(RamAccountingContext ramAccountingContext) { + return new HashSet<>(); + } - @Override - public int compareTo(CollectSetAggState o) { - if (o == null) return -1; - return compareValue(o.value); - } + @Override + public DataType partialType() { + return info.returnType(); + } - public int compareValue(Set otherValue) { - return value.size() < otherValue.size() ? -1 : value.size() == otherValue.size() ? 0 : 1; - } + @Override + public Set reduce(Set state1, Set state2) { + state1.addAll(state2); + return state1; + } - @Override - public String toString() { - return " terminatePartial(Set state) { + return state; } } diff --git a/sql/src/main/java/io/crate/operation/aggregation/impl/CountAggregation.java b/sql/src/main/java/io/crate/operation/aggregation/impl/CountAggregation.java index e647d17eb4b2..d7fe3e5c31a7 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/impl/CountAggregation.java +++ b/sql/src/main/java/io/crate/operation/aggregation/impl/CountAggregation.java @@ -29,20 +29,15 @@ import io.crate.metadata.FunctionInfo; import io.crate.operation.Input; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.planner.symbol.Function; import io.crate.planner.symbol.Literal; import io.crate.planner.symbol.Symbol; import io.crate.types.DataType; import io.crate.types.DataTypes; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import javax.annotation.Nonnull; -import java.io.IOException; import java.util.List; -public class CountAggregation extends AggregationFunction { +public class CountAggregation extends AggregationFunction { public static final String NAME = "count"; private final FunctionInfo info; @@ -74,58 +69,18 @@ public FunctionImplementation getForTypes(List dataTypes) th this.hasArgs = hasArgs; } - public static class CountAggState extends AggregationState { - - public long value = 0; - - public CountAggState(RamAccountingContext ramAccountingContext) { - super(ramAccountingContext); - // long value - ramAccountingContext.addBytes(8); - } - - @Override - public Object value() { - return value; - } - - @Override - public void reduce(CountAggState other) { - value += other.value; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - value = in.readVLong(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeVLong(value); - } - - @Override - public String toString() { - return "CountAggState {" + value + "}"; - } - - @Override - public int compareTo(@Nonnull CountAggState o) { - return Long.compare(value, o.value); - } - } - @Override - public boolean iterate(CountAggState state, Input... args) { + public Long iterate(RamAccountingContext ramAccountingContext, Long state, Input... args) { if (!hasArgs || args[0].value() != null){ - state.value++; + return state + 1; } - return true; + return state; } @Override - public CountAggState newState(RamAccountingContext ramAccountingContext) { - return new CountAggState(ramAccountingContext); + public Long newState(RamAccountingContext ramAccountingContext) { + ramAccountingContext.addBytes(8L); + return 0L; } @Override @@ -139,7 +94,7 @@ public Symbol normalizeSymbol(Function function) { if (function.arguments().size() == 1) { if (function.arguments().get(0).symbolType().isValueSymbol()) { - if (((Literal)function.arguments().get(0)).valueType() == DataTypes.UNDEFINED) { + if ((function.arguments().get(0)).valueType() == DataTypes.UNDEFINED) { return Literal.newLiteral(0L); } else{ return new Function(COUNT_STAR_FUNCTION, ImmutableList.of()); @@ -148,4 +103,19 @@ public Symbol normalizeSymbol(Function function) { } return function; } + + @Override + public DataType partialType() { + return DataTypes.LONG; + } + + @Override + public Long reduce(Long state1, Long state2) { + return state1 + state2; + } + + @Override + public Long terminatePartial(Long state) { + return state; + } } diff --git a/sql/src/main/java/io/crate/operation/aggregation/impl/MaximumAggregation.java b/sql/src/main/java/io/crate/operation/aggregation/impl/MaximumAggregation.java index fa4a4e105376..0f6c68a1ee07 100644 --- a/sql/src/main/java/io/crate/operation/aggregation/impl/MaximumAggregation.java +++ b/sql/src/main/java/io/crate/operation/aggregation/impl/MaximumAggregation.java @@ -22,25 +22,16 @@ package io.crate.operation.aggregation.impl; import com.google.common.collect.ImmutableList; -import io.crate.Streamer; -import io.crate.breaker.ConstSizeEstimator; import io.crate.breaker.RamAccountingContext; -import io.crate.breaker.SizeEstimator; -import io.crate.breaker.SizeEstimatorFactory; import io.crate.metadata.FunctionIdent; import io.crate.metadata.FunctionInfo; import io.crate.operation.Input; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.types.DataType; import io.crate.types.DataTypes; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import java.io.IOException; - -public abstract class MaximumAggregation extends AggregationFunction { +public class MaximumAggregation extends AggregationFunction { public static final String NAME = "max"; @@ -52,19 +43,7 @@ public static void register(AggregationImplModule mod) { new MaximumAggregation( new FunctionInfo(new FunctionIdent(NAME, ImmutableList.of(dataType)), dataType, FunctionInfo.Type.AGGREGATE) - ) { - @Override - public MaximumAggState newState(RamAccountingContext ramAccountingContext) { - SizeEstimator sizeEstimator = SizeEstimatorFactory.create(dataType); - if (sizeEstimator instanceof ConstSizeEstimator) { - return new MaximumAggState( - dataType.streamer(), ramAccountingContext, ((ConstSizeEstimator) sizeEstimator).size()); - } else { - return new VariableSizeMaximumAggState( - dataType.streamer(), ramAccountingContext, sizeEstimator); - } - } - } + ) ); } } @@ -79,107 +58,37 @@ public FunctionInfo info() { } @Override - public boolean iterate(MaximumAggState state, Input... args) throws CircuitBreakingException { - Object value = args[0].value(); - assert value == null || value instanceof Comparable; - state.add((Comparable) value); - return true; + public DataType partialType() { + return info().returnType(); } - public static class VariableSizeMaximumAggState extends MaximumAggState { - - private final SizeEstimator sizeEstimator; - - VariableSizeMaximumAggState(Streamer streamer, - RamAccountingContext ramAccountingContext, - SizeEstimator sizeEstimator) { - super(streamer, ramAccountingContext); - this.sizeEstimator = sizeEstimator; - } - - @Override - public void setValue(Comparable newValue) throws CircuitBreakingException { - ramAccountingContext.addBytes(sizeEstimator.estimateSizeDelta(value(), newValue)); - super.setValue(newValue); - } + @Override + public Comparable newState(RamAccountingContext ramAccountingContext){ + return null; } - public static class MaximumAggState extends AggregationState { - - private final Streamer streamer; - private Comparable value = null; - - private MaximumAggState(Streamer streamer, RamAccountingContext ramAccountingContext) { - super(ramAccountingContext); - this.streamer = streamer; - } - - MaximumAggState(Streamer streamer, RamAccountingContext ramAccountingContext, long constStateSize) { - this(streamer, ramAccountingContext); - ramAccountingContext.addBytes(constStateSize); - } - - @Override - public Object value() { - return value; - } - - @Override - public void reduce(MaximumAggState other) throws CircuitBreakingException { - if (other.value == null) { - return; - } - if (value == null || compareTo(other) < 0) { - setValue(other.value); - } - } - - void add(Comparable otherValue) throws CircuitBreakingException { - if (otherValue == null) { - return; - } - if (value == null || compareValue(otherValue) < 0) { - setValue(otherValue); - } - } - - public void setValue(Comparable newValue) throws CircuitBreakingException { - this.value = newValue; - } - - @Override - public int compareTo(MaximumAggState o) { - if (o == null) return -1; - return compareValue(o.value); - } - - private int compareValue(Comparable otherValue) { - if (value == null) return (otherValue == null ? 0 : -1); - if (otherValue == null) return 1; + @Override + public Comparable iterate(RamAccountingContext ramAccountingContext, Comparable state, Input... args) throws CircuitBreakingException { + Object value = args[0].value(); + return reduce(state, (Comparable) value); + } - //noinspection unchecked - return value.compareTo(otherValue); + @Override + public Comparable reduce(Comparable state1, Comparable state2) { + if (state1 == null) { + return state2; } - - @Override - public String toString() { - return " { +public class MinimumAggregation extends AggregationFunction { public static final String NAME = "min"; @@ -50,16 +41,6 @@ public static void register(AggregationImplModule mod) { for (final DataType dataType : DataTypes.PRIMITIVE_TYPES) { mod.register(new MinimumAggregation( new FunctionInfo(new FunctionIdent(NAME, ImmutableList.of(dataType)), dataType, FunctionInfo.Type.AGGREGATE)) { - - @Override - public MinimumAggregation.MinimumAggState newState(RamAccountingContext ramAccountingContext) { - SizeEstimator sizeEstimator = SizeEstimatorFactory.create(dataType); - if (sizeEstimator instanceof ConstSizeEstimator) { - return new MinimumAggState(dataType.streamer(), ramAccountingContext, ((ConstSizeEstimator) sizeEstimator).size()); - } else { - return new VariableMinimumAggState(dataType.streamer(), ramAccountingContext, sizeEstimator); - } - } }); } } @@ -74,100 +55,36 @@ public FunctionInfo info() { } @Override - public boolean iterate(MinimumAggState state, Input... args) throws CircuitBreakingException { - Object value = args[0].value(); - assert value == null || value instanceof Comparable; - state.add((Comparable) value); - return true; + public Comparable newState(RamAccountingContext ramAccountingContext) { + return null; } - static class VariableMinimumAggState extends MinimumAggState { - - private final SizeEstimator sizeEstimator; - - public VariableMinimumAggState(Streamer streamer, - RamAccountingContext ramAccountingContext, - SizeEstimator sizeEstimator) { - super(streamer, ramAccountingContext); - this.sizeEstimator = sizeEstimator; - } - - @Override - public void setValue(Comparable newValue) throws CircuitBreakingException { - ramAccountingContext.addBytes(sizeEstimator.estimateSizeDelta(this.value, newValue)); - super.setValue(newValue); - } + @Override + public DataType partialType() { + return info().returnType(); } - public static class MinimumAggState extends AggregationState { - - private final Streamer streamer; - protected Comparable value = null; - - private MinimumAggState(Streamer streamer, RamAccountingContext ramAccountingContext) { - super(ramAccountingContext); - this.streamer = streamer; - } - - MinimumAggState(Streamer streamer, RamAccountingContext ramAccountingContext, long constStateSize) { - this(streamer, ramAccountingContext); - ramAccountingContext.addBytes(constStateSize); - } - - @Override - public Object value() { - return value; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - if (!in.readBoolean()) { - setValue((Comparable) streamer.readValueFrom(in)); - } - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - Object value = value(); - out.writeBoolean(value == null); - if (value != null) { - streamer.writeValueTo(out, value); - } - } - - @Override - public void reduce(MinimumAggState other) throws CircuitBreakingException { - add(other.value); + @Override + public Comparable reduce(Comparable state1, Comparable state2) { + if (state1 == null) { + return state2; } - - void add(Comparable otherValue) throws CircuitBreakingException { - if (otherValue == null) { - return; - } - if (value == null || compareValue(otherValue) > 0) { - setValue(otherValue); - } + if (state2 == null) { + return state1; } - - public void setValue(Comparable newValue) throws CircuitBreakingException { - this.value = newValue; - } - - @Override - public int compareTo(MinimumAggState o) { - if (o == null) return -1; - return compareValue(o.value); + if (state1.compareTo(state2) > 0) { + return state2; } + return state1; + } - public int compareValue(Comparable otherValue) { - if (value == null) return (otherValue == null ? 0 : 1); - if (otherValue == null) return -1; - return value.compareTo(otherValue); - } + @Override + public Comparable terminatePartial(Comparable state) { + return state; + } - @Override - public String toString() { - return " { +public class SumAggregation extends AggregationFunction { public static final String NAME = "sum"; + private final FunctionInfo info; public static void register(AggregationImplModule mod) { @@ -52,67 +49,35 @@ public static void register(AggregationImplModule mod) { this.info = info; } - public static class SumAggState extends AggregationState { - - private Double value = null; // sum that aggregates nothing returns null, not 0.0 - - public SumAggState(RamAccountingContext ramAccountingContext) { - super(ramAccountingContext); - ramAccountingContext.addBytes(8); - } - - @Override - public Object value() { - return value; - } - - @Override - public void reduce(SumAggState other) { - add(other.value); - } - - public void add(Object value) { - if (value != null) { - this.value = (this.value == null ? 0.0 : this.value) + ((Number)value).doubleValue(); - } - } - - @Override - public int compareTo(SumAggState o) { - if (o == null) return 1; - if (value == null) return o.value == null ? 0 : -1; - if (o.value == null) return 1; - - return Double.compare(value, o.value); - } + @Override + public Double iterate(RamAccountingContext ramAccountingContext, Double state, Input... args) throws CircuitBreakingException { + return reduce(state, DataTypes.DOUBLE.value(args[0].value())); + } - @Override - @SuppressWarnings("unchecked") - public void readFrom(StreamInput in) throws IOException { - if (!in.readBoolean()) { - value = in.readDouble(); - } + @Override + public Double reduce(Double state1, Double state2) { + if (state1 == null) { + return state2; } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeBoolean(value == null); - if (value != null) { - out.writeDouble(value); - } + if (state2 == null) { + return state1; } + return state1 + state2; } + @Override + public Double terminatePartial(Double state) { + return state; + } @Override - public boolean iterate(SumAggState state, Input... args) { - state.add(args[0].value()); - return true; + public Double newState(RamAccountingContext ramAccountingContext) { + return null; } @Override - public SumAggState newState(RamAccountingContext ramAccountingContext) { - return new SumAggState(ramAccountingContext); + public DataType partialType() { + return info.returnType(); } @Override diff --git a/sql/src/main/java/io/crate/operation/projectors/AggregationProjector.java b/sql/src/main/java/io/crate/operation/projectors/AggregationProjector.java index 95e4323f5e9c..7728146f80af 100644 --- a/sql/src/main/java/io/crate/operation/projectors/AggregationProjector.java +++ b/sql/src/main/java/io/crate/operation/projectors/AggregationProjector.java @@ -24,7 +24,7 @@ import io.crate.breaker.RamAccountingContext; import io.crate.operation.AggregationContext; import io.crate.operation.ProjectorUpstream; -import io.crate.operation.aggregation.AggregationCollector; +import io.crate.operation.aggregation.Aggregator; import io.crate.operation.collect.CollectExpression; import java.util.Set; @@ -33,9 +33,11 @@ public class AggregationProjector implements Projector { - private final AggregationCollector[] aggregationCollectors; + private final Aggregator[] aggregators; private final Set> collectExpressions; private final Object[] row; + private final Object[] states; + private final RamAccountingContext ramAccountingContext; private Projector downstream; private final AtomicInteger remainingUpstreams = new AtomicInteger(0); private final AtomicReference upstreamFailure = new AtomicReference<>(null); @@ -43,19 +45,21 @@ public class AggregationProjector implements Projector { public AggregationProjector(Set> collectExpressions, AggregationContext[] aggregations, RamAccountingContext ramAccountingContext) { + this.ramAccountingContext = ramAccountingContext; row = new Object[aggregations.length]; + states = new Object[aggregations.length]; this.collectExpressions = collectExpressions; - aggregationCollectors = new AggregationCollector[aggregations.length]; - for (int i = 0; i < aggregationCollectors.length; i++) { - aggregationCollectors[i] = new AggregationCollector( + aggregators = new Aggregator[aggregations.length]; + for (int i = 0; i < aggregators.length; i++) { + aggregators[i] = new Aggregator( aggregations[i].symbol(), aggregations[i].function(), aggregations[i].inputs() ); - // startCollect creates the aggregationState. In case of the AggregationProjector + // prepareState creates the aggregationState. In case of the AggregationProjector // we only want to have 1 global state not 1 state per node/shard or even document. - aggregationCollectors[i].startCollect(ramAccountingContext); + states[i] = aggregators[i].prepareState(ramAccountingContext); } } @@ -81,8 +85,10 @@ public synchronized boolean setNextRow(Object... row) { for (CollectExpression collectExpression : collectExpressions) { collectExpression.setNextRow(row); } - for (AggregationCollector aggregationCollector : aggregationCollectors) { - aggregationCollector.processRow(); + for (int i = 0; i < aggregators.length; i++) { + Aggregator aggregator = aggregators[i]; + states[i] = aggregator.processRow(ramAccountingContext, states[i]); + } return upstreamFailure.get() == null; } @@ -97,8 +103,8 @@ public void upstreamFinished() { if (remainingUpstreams.decrementAndGet() > 0) { return; } - for (int i = 0; i < aggregationCollectors.length; i++) { - row[i] = aggregationCollectors[i].finishCollect(); + for (int i = 0; i < aggregators.length; i++) { + row[i] = aggregators[i].finishCollect(states[i]); } if (downstream != null) { downstream.setNextRow(row); diff --git a/sql/src/main/java/io/crate/operation/projectors/GroupingProjector.java b/sql/src/main/java/io/crate/operation/projectors/GroupingProjector.java index 66cd51ab9ea0..e17714d3f30a 100644 --- a/sql/src/main/java/io/crate/operation/projectors/GroupingProjector.java +++ b/sql/src/main/java/io/crate/operation/projectors/GroupingProjector.java @@ -29,8 +29,7 @@ import io.crate.operation.AggregationContext; import io.crate.operation.Input; import io.crate.operation.ProjectorUpstream; -import io.crate.operation.aggregation.AggregationCollector; -import io.crate.operation.aggregation.AggregationState; +import io.crate.operation.aggregation.Aggregator; import io.crate.operation.collect.CollectExpression; import io.crate.types.DataType; import io.crate.types.DataTypes; @@ -40,7 +39,10 @@ import org.elasticsearch.common.unit.ByteSizeValue; import javax.annotation.Nullable; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -66,9 +68,9 @@ public GroupingProjector(List keyTypes, this.collectExpressions = collectExpressions; this.ramAccountingContext = ramAccountingContext; - AggregationCollector[] aggregationCollectors = new AggregationCollector[aggregations.length]; + Aggregator[] aggregators = new Aggregator[aggregations.length]; for (int i = 0; i < aggregations.length; i++) { - aggregationCollectors[i] = new AggregationCollector( + aggregators[i] = new Aggregator( aggregations[i].symbol(), aggregations[i].function(), aggregations[i].inputs() @@ -79,10 +81,10 @@ public GroupingProjector(List keyTypes, ramAccountingContext.addBytes(8); if (keyInputs.size() == 1) { grouper = new SingleKeyGrouper(keyInputs.get(0), keyTypes.get(0), - collectExpressions, aggregationCollectors); + collectExpressions, aggregators); } else { grouper = new ManyKeyGrouper(keyInputs, keyTypes, - collectExpressions, aggregationCollectors); + collectExpressions, aggregators); } } @@ -158,9 +160,9 @@ public void upstreamFailed(Throwable throwable) { /** * transform map entry into pre-allocated object array. */ - private static void transformToRow(Map.Entry, AggregationState[]> entry, + private static void transformToRow(Map.Entry, Object[]> entry, Object[] row, - AggregationCollector[] aggregationCollectors) { + Aggregator[] aggregators) { int c = 0; for (Object o : entry.getKey()) { @@ -168,24 +170,22 @@ private static void transformToRow(Map.Entry, AggregationState[]> e c++; } - AggregationState[] aggregationStates = entry.getValue(); - for (int i = 0; i < aggregationStates.length; i++) { - aggregationCollectors[i].state(aggregationStates[i]); - row[c] = aggregationCollectors[i].finishCollect(); + Object[] states = entry.getValue(); + for (int i = 0; i < states.length; i++) { + row[c] = aggregators[i].finishCollect(states[i]); c++; } } - private static void singleTransformToRow(Map.Entry entry, + private static void singleTransformToRow(Map.Entry entry, Object[] row, - AggregationCollector[] aggregationCollectors) { + Aggregator[] aggregators) { int c = 0; row[c] = entry.getKey(); c++; - AggregationState[] aggregationStates = entry.getValue(); - for (int i = 0; i < aggregationStates.length; i++) { - aggregationCollectors[i].state(aggregationStates[i]); - row[c] = aggregationCollectors[i].finishCollect(); + Object[] states = entry.getValue(); + for (int i = 0; i < states.length; i++) { + row[c] = aggregators[i].finishCollect(states[i]); c++; } } @@ -201,8 +201,8 @@ private interface Grouper { private class SingleKeyGrouper implements Grouper { - private final Map result; - private final AggregationCollector[] aggregationCollectors; + private final Map result; + private final Aggregator[] aggregators; private final Input keyInput; private final CollectExpression[] collectExpressions; private final SizeEstimator sizeEstimator; @@ -210,11 +210,11 @@ private class SingleKeyGrouper implements Grouper { public SingleKeyGrouper(Input keyInput, DataType keyInputType, CollectExpression[] collectExpressions, - AggregationCollector[] aggregationCollectors) { + Aggregator[] aggregators) { this.collectExpressions = collectExpressions; this.result = new HashMap<>(); this.keyInput = keyInput; - this.aggregationCollectors = aggregationCollectors; + this.aggregators = aggregators; sizeEstimator = SizeEstimatorFactory.create(keyInputType); } @@ -228,21 +228,19 @@ public boolean setNextRow(Object... row) { // HashMap.get requires some objects (iterators) and at least 2 integers ramAccountingContext.addBytes(32); - AggregationState[] states = result.get(key); + Object[] states = result.get(key); if (states == null) { - states = new AggregationState[aggregationCollectors.length]; - for (int i = 0; i < aggregationCollectors.length; i++) { - aggregationCollectors[i].startCollect(ramAccountingContext); - aggregationCollectors[i].processRow(); - states[i] = aggregationCollectors[i].state(); + states = new Object[aggregators.length]; + for (int i = 0; i < aggregators.length; i++) { + Object state = aggregators[i].prepareState(ramAccountingContext); + states[i] = aggregators[i].processRow(ramAccountingContext, state); } ramAccountingContext.addBytes( RamAccountingContext.roundUp(sizeEstimator.estimateSize(key)) + 24); // 24 bytes overhead per entry result.put(key, states); } else { - for (int i = 0; i < aggregationCollectors.length; i++) { - aggregationCollectors[i].state(states[i]); - aggregationCollectors[i].processRow(); + for (int i = 0; i < aggregators.length; i++) { + states[i] = aggregators[i].processRow(ramAccountingContext, states[i]); } } @@ -261,13 +259,14 @@ public Object[][] finish() { ramAccountingContext.addBytes(RamAccountingContext.roundUp(12 + result.size() * 4)); // 2nd level ramAccountingContext.addBytes(RamAccountingContext.roundUp( - (1 + aggregationCollectors.length) * 4 + 12)); - Object[][] rows = new Object[result.size()][1 + aggregationCollectors.length]; + (1 + aggregators.length) * 4 + 12)); + Object[][] rows = new Object[result.size()][1 + aggregators.length]; boolean sendToDownStream = downstream != null; int r = 0; - for (Map.Entry entry : result.entrySet()) { + + for (Map.Entry entry : result.entrySet()) { Object[] row = rows[r]; - singleTransformToRow(entry, row, aggregationCollectors); + singleTransformToRow(entry, row, aggregators); if (sendToDownStream) { sendToDownStream = downstream.setNextRow(row); } @@ -282,8 +281,8 @@ public Object[][] finish() { private class ManyKeyGrouper implements Grouper { - private final AggregationCollector[] aggregationCollectors; - private final Map, AggregationState[]> result; + private final Aggregator[] aggregators; + private final Map, Object[]> result; private final List> keyInputs; private final CollectExpression[] collectExpressions; private final List> sizeEstimators; @@ -291,11 +290,11 @@ private class ManyKeyGrouper implements Grouper { public ManyKeyGrouper(List> keyInputs, List keyTypes, CollectExpression[] collectExpressions, - AggregationCollector[] aggregationCollectors) { + Aggregator[] aggregators) { this.collectExpressions = collectExpressions; this.result = new HashMap<>(); this.keyInputs = keyInputs; - this.aggregationCollectors = aggregationCollectors; + this.aggregators = aggregators; sizeEstimators = new ArrayList<>(keyTypes.size()); for (DataType dataType : keyTypes) { sizeEstimators.add(SizeEstimatorFactory.create(dataType)); @@ -324,20 +323,19 @@ public boolean setNextRow(Object... row) { // HashMap.get requires some objects (iterators) and at least 2 integers ramAccountingContext.addBytes(32); - AggregationState[] states = result.get(key); + Object[] states = result.get(key); if (states == null) { - states = new AggregationState[aggregationCollectors.length]; - for (int i = 0; i < aggregationCollectors.length; i++) { - aggregationCollectors[i].startCollect(ramAccountingContext); - aggregationCollectors[i].processRow(); - states[i] = aggregationCollectors[i].state(); + states = new Object[aggregators.length]; + for (int i = 0; i < aggregators.length; i++) { + Object state = aggregators[i].prepareState(ramAccountingContext); + state = aggregators[i].processRow(ramAccountingContext, state); + states[i] = state; } ramAccountingContext.addBytes(24); // 24 bytes overhead per map entry result.put(key, states); } else { - for (int i = 0; i < aggregationCollectors.length; i++) { - aggregationCollectors[i].state(states[i]); - aggregationCollectors[i].processRow(); + for (int i = 0; i < aggregators.length; i++) { + states[i] = aggregators[i].processRow(ramAccountingContext, states[i]); } } @@ -355,13 +353,13 @@ public Object[][] finish() { ramAccountingContext.addBytes(RamAccountingContext.roundUp(12 + result.size() * 4)); // 2nd level ramAccountingContext.addBytes(RamAccountingContext.roundUp(12 + - (keyInputs.size() + aggregationCollectors.length) * 4)); - Object[][] rows = new Object[result.size()][keyInputs.size() + aggregationCollectors.length]; + (keyInputs.size() + aggregators.length) * 4)); + Object[][] rows = new Object[result.size()][keyInputs.size() + aggregators.length]; boolean sendToDownStream = downstream != null; int r = 0; - for (Map.Entry, AggregationState[]> entry : result.entrySet()) { + for (Map.Entry, Object[]> entry : result.entrySet()) { Object[] row = rows[r]; - transformToRow(entry, row, aggregationCollectors); + transformToRow(entry, row, aggregators); if (sendToDownStream) { sendToDownStream = downstream.setNextRow(row); } diff --git a/sql/src/main/java/io/crate/planner/node/AggregationStateStreamer.java b/sql/src/main/java/io/crate/planner/node/AggregationStateStreamer.java deleted file mode 100644 index 818d614253c8..000000000000 --- a/sql/src/main/java/io/crate/planner/node/AggregationStateStreamer.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor - * license agreements. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. Crate licenses - * this file to you under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. You may - * obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * However, if you have executed another commercial license agreement - * with Crate these terms will supersede the license and you may use the - * software solely pursuant to the terms of the relevant commercial agreement. - */ - -package io.crate.planner.node; - -import io.crate.Streamer; -import io.crate.breaker.RamAccountingContext; -import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; - -import java.io.IOException; - -/** - * Streamer used for {@link io.crate.operation.aggregation.AggregationState}s - */ -public class AggregationStateStreamer implements Streamer { - - private final AggregationFunction aggregationFunction; - private final RamAccountingContext ramAccountingContext; - - public AggregationStateStreamer(AggregationFunction aggregationFunction, - RamAccountingContext ramAccountingContext) { - this.aggregationFunction = aggregationFunction; - this.ramAccountingContext = ramAccountingContext; - } - - @Override - public AggregationState readValueFrom(StreamInput in) throws IOException { - AggregationState aggState = this.aggregationFunction.newState(ramAccountingContext); - aggState.readFrom(in); - return aggState; - } - - @Override - public void writeValueTo(StreamOutput out, Object v) throws IOException { - ((AggregationState)v).writeTo(out); - } -} diff --git a/sql/src/main/java/io/crate/planner/node/PlanNodeStreamerVisitor.java b/sql/src/main/java/io/crate/planner/node/PlanNodeStreamerVisitor.java index 5fce446cd60c..4558c31b9f57 100644 --- a/sql/src/main/java/io/crate/planner/node/PlanNodeStreamerVisitor.java +++ b/sql/src/main/java/io/crate/planner/node/PlanNodeStreamerVisitor.java @@ -83,15 +83,9 @@ public Context process(PlanNode node, RamAccountingContext ramAccountingContext) return context; } - private AggregationStateStreamer getStreamer(AggregationFunction aggregationFunction, - RamAccountingContext ramAccountingContext) { - return new AggregationStateStreamer(aggregationFunction, ramAccountingContext); - } - - private Streamer resolveStreamer(Aggregation aggregation, Aggregation.Step step, - RamAccountingContext ramAccountingContext) { + private Streamer resolveStreamer(Aggregation aggregation, Aggregation.Step step) { Streamer streamer; - AggregationFunction aggFunction = (AggregationFunction)functions.get(aggregation.functionIdent()); + AggregationFunction aggFunction = (AggregationFunction)functions.get(aggregation.functionIdent()); if (aggFunction == null) { throw new ResourceUnknownException("unknown aggregation function"); } @@ -101,7 +95,7 @@ private Streamer resolveStreamer(Aggregation aggregation, Aggregation.Step st streamer = aggFunction.info().ident().argumentTypes().get(0).streamer(); break; case PARTIAL: - streamer = getStreamer(aggFunction, ramAccountingContext); + streamer = aggFunction.partialType().streamer(); break; case FINAL: streamer = aggFunction.info().returnType().streamer(); @@ -136,8 +130,7 @@ public Void visitCollectNode(CollectNode node, Context context) { try { aggregation = aggregations.get(aggIdx); if (aggregation != null) { - context.outputStreamers.add(resolveStreamer(aggregation, - aggregation.toStep(), context.ramAccountingContext)); + context.outputStreamers.add(resolveStreamer(aggregation, aggregation.toStep())); } } catch (IndexOutOfBoundsException e) { // assume this is an unknown column @@ -212,8 +205,7 @@ private void resolveStreamer(Streamer[] streamers, streamers[columnIdx] = symbol.valueType().streamer(); } else if (symbol.symbolType() == SymbolType.AGGREGATION) { Aggregation aggregation = (Aggregation)symbol; - streamers[columnIdx] = resolveStreamer(aggregation, aggregation.toStep(), - ramAccountingContext); + streamers[columnIdx] = resolveStreamer(aggregation, aggregation.toStep()); } else if (symbol.symbolType() == SymbolType.INPUT_COLUMN) { columnIdx = ((InputColumn)symbol).index(); if (projectionIdx > 0) { @@ -248,8 +240,7 @@ private void setInputStreamers(List inputTypes, Projection projection, context.inputStreamers.add(inputType.streamer()); } else { Aggregation aggregation = aggregations.get(idx); - context.inputStreamers.add(resolveStreamer(aggregation, aggregation.fromStep(), - context.ramAccountingContext)); + context.inputStreamers.add(resolveStreamer(aggregation, aggregation.fromStep())); idx++; } } diff --git a/sql/src/test/java/io/crate/executor/task/LocalMergeTaskTest.java b/sql/src/test/java/io/crate/executor/task/LocalMergeTaskTest.java index 8ca0a5a72190..3725ef2d3364 100644 --- a/sql/src/test/java/io/crate/executor/task/LocalMergeTaskTest.java +++ b/sql/src/test/java/io/crate/executor/task/LocalMergeTaskTest.java @@ -23,7 +23,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.crate.breaker.RamAccountingContext; import io.crate.executor.QueryResult; import io.crate.executor.TaskResult; import io.crate.executor.transport.TransportActionProvider; @@ -37,7 +36,6 @@ import io.crate.planner.RowGranularity; import io.crate.planner.node.dql.MergeNode; import io.crate.planner.projection.GroupProjection; -import io.crate.planner.projection.Projection; import io.crate.planner.projection.TopNProjection; import io.crate.planner.symbol.Aggregation; import io.crate.planner.symbol.InputColumn; @@ -68,18 +66,13 @@ public class LocalMergeTaskTest { - private static final RamAccountingContext ramAccountingContext = - new RamAccountingContext("dummy", new NoopCircuitBreaker(CircuitBreaker.Name.FIELDDATA)); - private ImplementationSymbolVisitor symbolVisitor; - private AggregationFunction minAggFunction; private GroupProjection groupProjection; - private Injector injector; @Before @SuppressWarnings("unchecked") public void prepare() { - injector = new ModulesBuilder() + Injector injector = new ModulesBuilder() .add(new AggregationImplModule()) .add(new AbstractModule() { @Override @@ -93,7 +86,7 @@ protected void configure() { symbolVisitor = new ImplementationSymbolVisitor(referenceResolver, functions, RowGranularity.CLUSTER); FunctionIdent minAggIdent = new FunctionIdent(MinimumAggregation.NAME, Arrays.asList(DataTypes.DOUBLE)); - minAggFunction = (AggregationFunction) functions.get(minAggIdent); + AggregationFunction minAggFunction = (AggregationFunction) functions.get(minAggIdent); groupProjection = new GroupProjection(); groupProjection.keys(Arrays.asList(new InputColumn(0, DataTypes.INTEGER))); @@ -106,11 +99,9 @@ private ListenableFuture getUpstreamResult(int numRows) { Object[][] rows = new Object[numRows][]; for (int i=0; iasList(new InputColumn(0), new InputColumn(1))); MergeNode mergeNode = new MergeNode("merge", 2); - mergeNode.projections(Arrays.asList( + mergeNode.projections(Arrays.asList( groupProjection, topNProjection )); diff --git a/sql/src/test/java/io/crate/executor/transport/task/DistributedMergeTaskTest.java b/sql/src/test/java/io/crate/executor/transport/task/DistributedMergeTaskTest.java index d62b8e2257b8..98e853b74234 100644 --- a/sql/src/test/java/io/crate/executor/transport/task/DistributedMergeTaskTest.java +++ b/sql/src/test/java/io/crate/executor/transport/task/DistributedMergeTaskTest.java @@ -2,7 +2,6 @@ import com.google.common.collect.ImmutableList; import io.crate.Streamer; -import io.crate.breaker.RamAccountingContext; import io.crate.executor.transport.distributed.DistributedResultRequest; import io.crate.executor.transport.distributed.DistributedResultResponse; import io.crate.executor.transport.merge.TransportMergeNodeAction; @@ -11,7 +10,6 @@ import io.crate.metadata.Functions; import io.crate.operation.aggregation.AggregationFunction; import io.crate.operation.aggregation.impl.CountAggregation; -import io.crate.planner.node.AggregationStateStreamer; import io.crate.planner.node.dql.MergeNode; import io.crate.planner.projection.GroupProjection; import io.crate.planner.projection.TopNProjection; @@ -25,8 +23,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterService; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.junit.Test; import java.util.*; @@ -36,9 +32,6 @@ @CrateIntegrationTest.ClusterScope(scope = CrateIntegrationTest.Scope.GLOBAL) public class DistributedMergeTaskTest extends SQLTransportIntegrationTest { - private static final RamAccountingContext RAM_ACCOUNTING_CONTEXT = - new RamAccountingContext("dummy", new NoopCircuitBreaker(CircuitBreaker.Name.FIELDDATA)); - @Test public void testDistributedMergeTask() throws Exception { ClusterService clusterService = cluster().getInstance(ClusterService.class); @@ -77,7 +70,7 @@ public void testDistributedMergeTask() throws Exception { mergeNode.outputTypes(Arrays.asList(DataTypes.STRING, DataTypes.LONG)); Streamer[] mapperOutputStreamer = new Streamer[] { - new AggregationStateStreamer(countAggregation, RAM_ACCOUNTING_CONTEXT), + DataTypes.LONG.streamer(), DataTypes.STRING.streamer() }; @@ -92,13 +85,13 @@ public void testDistributedMergeTask() throws Exception { DistributedResultRequest request1 = new DistributedResultRequest(mergeNode.contextId(), mapperOutputStreamer); request1.rows(new Object[][] { - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 1; }}, new BytesRef("bar") }, + new Object[] { 1L , new BytesRef("bar") }, }); DistributedResultRequest request2 = new DistributedResultRequest(mergeNode.contextId(), mapperOutputStreamer); request2.rows(new Object[][] { - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 1; }}, new BytesRef("bar") }, - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 3; }}, new BytesRef("bar") }, - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 3; }}, new BytesRef("foobar") }, + new Object[] { 1L, new BytesRef("bar") }, + new Object[] { 3L, new BytesRef("bar") }, + new Object[] { 3L, new BytesRef("foobar") }, }); transportMergeNodeAction.mergeRows(firstNode, request1, noopListener); @@ -106,13 +99,13 @@ public void testDistributedMergeTask() throws Exception { DistributedResultRequest request3 = new DistributedResultRequest(mergeNode.contextId(), mapperOutputStreamer); request3.rows(new Object[][] { - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 10; }}, new BytesRef("foo") }, - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 20; }}, new BytesRef("foo") }, + new Object[] { 10, new BytesRef("foo") }, + new Object[] { 20, new BytesRef("foo") }, }); DistributedResultRequest request4 = new DistributedResultRequest(mergeNode.contextId(), mapperOutputStreamer); request4.rows(new Object[][] { - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 10; }}, new BytesRef("foo") }, - new Object[] { new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 14; }}, new BytesRef("test") }, + new Object[] { 10, new BytesRef("foo") }, + new Object[] { 14, new BytesRef("test") }, }); String secondNode = iterator.next(); diff --git a/sql/src/test/java/io/crate/operation/aggregation/AggregationTest.java b/sql/src/test/java/io/crate/operation/aggregation/AggregationTest.java index 47e3c9f36d42..f630a9c70709 100644 --- a/sql/src/test/java/io/crate/operation/aggregation/AggregationTest.java +++ b/sql/src/test/java/io/crate/operation/aggregation/AggregationTest.java @@ -65,7 +65,6 @@ public void setUp() throws Exception { functions = injector.getInstance(Functions.class); } - public Object[][] executeAggregation(String name, DataType dataType, Object[][] data) throws Exception { FunctionIdent fi; @@ -78,16 +77,17 @@ public Object[][] executeAggregation(String name, DataType dataType, Object[][] inputs = new InputCollectExpression[0]; } AggregationFunction impl = (AggregationFunction) functions.get(fi); - AggregationState state = impl.newState(ramAccountingContext); + Object state = impl.newState(ramAccountingContext); for (Object[] row : data) { for (InputCollectExpression i : inputs) { i.setNextRow(row); } - impl.iterate(state, inputs); + state = impl.iterate(ramAccountingContext, state, inputs); } - return new Object[][]{{state.value()}}; + state = impl.terminatePartial(state); + return new Object[][]{{state}}; } } diff --git a/sql/src/test/java/io/crate/operation/aggregation/AggregationCollectorTest.java b/sql/src/test/java/io/crate/operation/aggregation/AggregatorTest.java similarity index 80% rename from sql/src/test/java/io/crate/operation/aggregation/AggregationCollectorTest.java rename to sql/src/test/java/io/crate/operation/aggregation/AggregatorTest.java index 8fb7869a1f70..10b05437b7ac 100644 --- a/sql/src/test/java/io/crate/operation/aggregation/AggregationCollectorTest.java +++ b/sql/src/test/java/io/crate/operation/aggregation/AggregatorTest.java @@ -44,19 +44,18 @@ import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; -public class AggregationCollectorTest { +public class AggregatorTest { protected static final RamAccountingContext RAM_ACCOUNTING_CONTEXT = new RamAccountingContext("dummy", new NoopCircuitBreaker(CircuitBreaker.Name.FIELDDATA)); - private FunctionIdent countAggIdent; private AggregationFunction countImpl; @Before public void setUpFunctions() { Injector injector = new ModulesBuilder().add(new AggregationImplModule()).createInjector(); Functions functions = injector.getInstance(Functions.class); - countAggIdent = new FunctionIdent(CountAggregation.NAME, Arrays.asList(DataTypes.STRING)); + FunctionIdent countAggIdent = new FunctionIdent(CountAggregation.NAME, Arrays.asList(DataTypes.STRING)); countImpl = (AggregationFunction) functions.get(countAggIdent); } @@ -69,7 +68,7 @@ public void testAggregationFromPartial() { Aggregation.Step.FINAL ); Input dummyInput = new Input() { - CountAggregation.CountAggState state = new CountAggregation.CountAggState(RAM_ACCOUNTING_CONTEXT) {{ value = 10L; }}; + Long state = 10L; @Override @@ -78,11 +77,11 @@ public Object value() { } }; - AggregationCollector collector = new AggregationCollector(aggregation, countImpl, dummyInput); - collector.startCollect(RAM_ACCOUNTING_CONTEXT); - collector.processRow(); - collector.processRow(); - Object result = collector.finishCollect(); + Aggregator aggregator = new Aggregator(aggregation, countImpl, dummyInput); + Object state = aggregator.prepareState(RAM_ACCOUNTING_CONTEXT); + state = aggregator.processRow(RAM_ACCOUNTING_CONTEXT, state); + state = aggregator.processRow(RAM_ACCOUNTING_CONTEXT, state); + Object result = aggregator.finishCollect(state); assertThat((Long)result, is(20L)); } @@ -104,13 +103,13 @@ public Object value() { } }; - AggregationCollector collector = new AggregationCollector(aggregation, countImpl, dummyInput); - collector.startCollect(RAM_ACCOUNTING_CONTEXT); + Aggregator collector = new Aggregator(aggregation, countImpl, dummyInput); + Object state = collector.prepareState(RAM_ACCOUNTING_CONTEXT); for (int i = 0; i < 5; i++) { - collector.processRow(); + state = collector.processRow(RAM_ACCOUNTING_CONTEXT, state); } - long result = (Long)collector.finishCollect(); + long result = (Long)collector.finishCollect(state); assertThat(result, is(5L)); } } diff --git a/sql/src/test/java/io/crate/operation/aggregation/impl/CollectSetAggregationTest.java b/sql/src/test/java/io/crate/operation/aggregation/impl/CollectSetAggregationTest.java index 40ee1932703d..4e4ec41e453c 100644 --- a/sql/src/test/java/io/crate/operation/aggregation/impl/CollectSetAggregationTest.java +++ b/sql/src/test/java/io/crate/operation/aggregation/impl/CollectSetAggregationTest.java @@ -24,7 +24,6 @@ import com.google.common.collect.ImmutableList; import io.crate.metadata.FunctionIdent; import io.crate.operation.aggregation.AggregationFunction; -import io.crate.operation.aggregation.AggregationState; import io.crate.operation.aggregation.AggregationTest; import io.crate.types.DataType; import io.crate.types.DataTypes; @@ -64,14 +63,14 @@ public void testDouble() throws Exception { public void testLongSerialization() throws Exception { FunctionIdent fi = new FunctionIdent("collect_set", ImmutableList.of(DataTypes.LONG)); AggregationFunction impl = (AggregationFunction) functions.get(fi); - AggregationState state = impl.newState(ramAccountingContext); + + Object state = impl.newState(ramAccountingContext); BytesStreamOutput streamOutput = new BytesStreamOutput(); - state.writeTo(streamOutput); + impl.partialType().streamer().writeValueTo(streamOutput, state); - AggregationState newState = impl.newState(ramAccountingContext); - newState.readFrom(new BytesStreamInput(streamOutput.bytes())); - assertEquals(state.value(), newState.value()); + Object newState = impl.partialType().streamer().readValueFrom(new BytesStreamInput(streamOutput.bytes())); + assertEquals(state, newState); } @Test @@ -138,5 +137,4 @@ public void testNullValue() throws Exception { assertEquals(2, ((Set)result[0][0]).size()); assertFalse(((Set)result[0][0]).contains(null)); } - } diff --git a/sql/src/test/java/io/crate/operation/merge/MergeOperationTest.java b/sql/src/test/java/io/crate/operation/merge/MergeOperationTest.java index 6578b1199d26..681aaa56244b 100644 --- a/sql/src/test/java/io/crate/operation/merge/MergeOperationTest.java +++ b/sql/src/test/java/io/crate/operation/merge/MergeOperationTest.java @@ -25,7 +25,6 @@ import io.crate.executor.transport.TransportActionProvider; import io.crate.metadata.*; import io.crate.operation.ImplementationSymbolVisitor; -import io.crate.operation.aggregation.AggregationFunction; import io.crate.operation.aggregation.impl.AggregationImplModule; import io.crate.operation.aggregation.impl.MinimumAggregation; import io.crate.operation.projectors.TopN; @@ -65,14 +64,12 @@ public class MergeOperationTest { new RamAccountingContext("dummy", new NoopCircuitBreaker(CircuitBreaker.Name.FIELDDATA)); private GroupProjection groupProjection; - private AggregationFunction minAggFunction; private ImplementationSymbolVisitor symbolVisitor; - private Injector injector; @Before @SuppressWarnings("unchecked") public void prepare() { - injector = new ModulesBuilder() + Injector injector = new ModulesBuilder() .add(new AggregationImplModule()) .add(new AbstractModule() { @Override @@ -87,7 +84,6 @@ protected void configure() { FunctionIdent minAggIdent = new FunctionIdent(MinimumAggregation.NAME, Arrays.asList(DataTypes.DOUBLE)); FunctionInfo minAggInfo = new FunctionInfo(minAggIdent, DataTypes.DOUBLE); - minAggFunction = (AggregationFunction) functions.get(minAggIdent); groupProjection = new GroupProjection(); groupProjection.keys(Arrays.asList(new InputColumn(0, DataTypes.INTEGER))); @@ -120,9 +116,9 @@ public void testMergeSingleResult() throws Exception { Object[][] rows = new Object[20][]; for (int i=0; i { diff --git a/sql/src/test/java/io/crate/planner/node/PlanNodeStreamerVisitorTest.java b/sql/src/test/java/io/crate/planner/node/PlanNodeStreamerVisitorTest.java index 116955f6d6ff..3f85546b7df2 100644 --- a/sql/src/test/java/io/crate/planner/node/PlanNodeStreamerVisitorTest.java +++ b/sql/src/test/java/io/crate/planner/node/PlanNodeStreamerVisitorTest.java @@ -119,7 +119,7 @@ public void testGetOutputStreamersFromCollectNodeWithAggregations() throws Excep assertThat(streamers.length, is(4)); assertThat(streamers[0], instanceOf(DataTypes.BOOLEAN.streamer().getClass())); assertThat(streamers[1], instanceOf(DataTypes.INTEGER.streamer().getClass())); - assertThat(streamers[2], instanceOf(AggregationStateStreamer.class)); + assertThat(streamers[2], instanceOf(DataTypes.INTEGER.streamer().getClass())); assertThat(streamers[3], instanceOf(DataTypes.DOUBLE.streamer().getClass())); } @@ -173,7 +173,7 @@ public void testGetInputStreamersForMergeNodeWithAggregations() throws Exception PlanNodeStreamerVisitor.Context ctx = visitor.process(mergeNode, ramAccountingContext); Streamer[] streamers = ctx.inputStreamers(); assertThat(streamers.length, is(2)); - assertThat(streamers[0], instanceOf(AggregationStateStreamer.class)); + assertThat(streamers[0], instanceOf(DataTypes.INTEGER.streamer().getClass())); assertThat(streamers[1], instanceOf(DataTypes.TIMESTAMP.streamer().getClass())); }