diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
index 43c8d492ed870..a84b0e369e223 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
@@ -9,6 +9,7 @@
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.Aggregator;
@@ -16,12 +17,12 @@
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketAggregatorsReducer;
 import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;
 import org.elasticsearch.search.aggregations.support.SamplingContext;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -201,50 +202,44 @@ protected AggregatorReducer getLeaderReducer(AggregationReduceContext reduceCont
         return new AggregatorReducer() {
 
             long globalSubsetSize = 0;
             long globalSupersetSize = 0;
+            final Map<String, ReducerAndProto<B>> buckets = new HashMap<>();
-            final List<InternalSignificantTerms<A, B>> aggregations = new ArrayList<>(size);
-
-            // Compute the overall result set size and the corpus size using the
-            // top-level Aggregations from each shard
             @Override
             public void accept(InternalAggregation aggregation) {
                 @SuppressWarnings("unchecked")
-                InternalSignificantTerms<A, B> terms = (InternalSignificantTerms<A, B>) aggregation;
+                final InternalSignificantTerms<A, B> terms = (InternalSignificantTerms<A, B>) aggregation;
+                // Compute the overall result set size and the corpus size using the
+                // top-level Aggregations from each shard
                 globalSubsetSize += terms.getSubsetSize();
                 globalSupersetSize += terms.getSupersetSize();
-                aggregations.add(terms);
+                for (B bucket : terms.getBuckets()) {
+                    final ReducerAndProto<B> reducerAndProto = buckets.computeIfAbsent(
+                        bucket.getKeyAsString(),
+                        k -> new ReducerAndProto<>(new MultiBucketAggregatorsReducer(reduceContext, size), bucket)
+                    );
+                    reducerAndProto.reducer.accept(bucket);
+                    reducerAndProto.subsetDf[0] += bucket.subsetDf;
+                    reducerAndProto.supersetDf[0] += bucket.supersetDf;
+                }
             }
 
             @Override
             public InternalAggregation get() {
-                final Map<String, List<B>> buckets = new HashMap<>();
-                for (InternalSignificantTerms<A, B> terms : aggregations) {
-                    for (B bucket : terms.getBuckets()) {
-                        List<B> existingBuckets = buckets.computeIfAbsent(bucket.getKeyAsString(), k -> new ArrayList<>(size));
-                        // Adjust the buckets with the global stats representing the
-                        // total size of the pots from which the stats are drawn
-                        existingBuckets.add(
-                            createBucket(
-                                bucket.getSubsetDf(),
-                                globalSubsetSize,
-                                bucket.getSupersetDf(),
-                                globalSupersetSize,
-                                bucket.aggregations,
-                                bucket
-                            )
-                        );
-                    }
-                }
-
                 final SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
                 final int size = reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size());
                 final BucketSignificancePriorityQueue<B> ordered = new BucketSignificancePriorityQueue<>(size);
-                for (Map.Entry<String, List<B>> entry : buckets.entrySet()) {
-                    List<B> sameTermBuckets = entry.getValue();
-                    final B b = reduceBucket(sameTermBuckets, reduceContext);
+                for (ReducerAndProto<B> reducerAndProto : buckets.values()) {
+                    final B b = createBucket(
+                        reducerAndProto.subsetDf[0],
+                        globalSubsetSize,
+                        reducerAndProto.supersetDf[0],
+                        globalSupersetSize,
+                        reducerAndProto.reducer.get(),
+                        reducerAndProto.proto
+                    );
                     b.updateScore(heuristic);
                     if (((b.score > 0) && (b.subsetDf >= minDocCount)) || reduceContext.isFinalReduce() == false) {
-                        B removed = ordered.insertWithOverflow(b);
+                        final B removed = ordered.insertWithOverflow(b);
                         if (removed == null) {
                             reduceContext.consumeBucketsAndMaybeBreak(1);
                         } else {
@@ -254,15 +249,28 @@ public InternalAggregation get() {
                         reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(b));
                     }
                 }
-                B[] list = createBucketsArray(ordered.size());
+                final B[] list = createBucketsArray(ordered.size());
                 for (int i = ordered.size() - 1; i >= 0; i--) {
                     list[i] = ordered.pop();
                 }
                 return create(globalSubsetSize, globalSupersetSize, Arrays.asList(list));
             }
+
+            @Override
+            public void close() {
+                for (ReducerAndProto<B> reducerAndProto : buckets.values()) {
+                    Releasables.close(reducerAndProto.reducer);
+                }
+            }
         };
     }
 
+    private record ReducerAndProto<B>(MultiBucketAggregatorsReducer reducer, B proto, long[] subsetDf, long[] supersetDf) {
+        private ReducerAndProto(MultiBucketAggregatorsReducer reducer, B proto) {
+            this(reducer, proto, new long[] { 0 }, new long[] { 0 });
+        }
+    }
+
     @Override
     public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
         long supersetSize = samplingContext.scaleUp(getSupersetSize());
@@ -285,19 +293,6 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
         );
     }
 
-    private B reduceBucket(List<B> buckets, AggregationReduceContext context) {
-        assert buckets.isEmpty() == false;
-        long subsetDf = 0;
-        long supersetDf = 0;
-        for (B bucket : buckets) {
-            subsetDf += bucket.subsetDf;
-            supersetDf += bucket.supersetDf;
-        }
-        final List<InternalAggregations> aggregations = new BucketAggregationList<>(buckets);
-        final InternalAggregations aggs = InternalAggregations.reduce(aggregations, context);
-        return createBucket(subsetDf, buckets.get(0).subsetSize, supersetDf, buckets.get(0).supersetSize, aggs, buckets.get(0));
-    }
-
     abstract B createBucket(
         long subsetDf,
         long subsetSize,