Skip to content

Commit

Permalink
Support for non-global aggregations with concurrent segment search.
Browse files Browse the repository at this point in the history
This PR does not include the support for the profile option with aggregations to work with the concurrent model.

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>
  • Loading branch information
sohami committed May 11, 2023
1 parent 4ec6abd commit ee87f1e
Show file tree
Hide file tree
Showing 19 changed files with 694 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,9 @@ public void testMap() {
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries));
// with script based aggregation, if it does not support reduce then aggregationList size
// will be numShards * slicesCount
assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries));
int numShardsRun = 0;
for (Object object : aggregationList) {
assertThat(object, notNullValue());
Expand Down Expand Up @@ -483,7 +485,9 @@ public void testMapWithParams() {
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries));
// with script based aggregation, if it does not support reduce then aggregationList size
// will be numShards * slicesCount
assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries));
int numShardsRun = 0;
for (Object object : aggregationList) {
assertThat(object, notNullValue());
Expand Down Expand Up @@ -535,7 +539,9 @@ public void testInitMutatesParams() {
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries));
// with script based aggregation, if it does not support reduce then aggregationList size
// will be numShards * slicesCount
assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries));
long totalCount = 0;
for (Object object : aggregationList) {
assertThat(object, notNullValue());
Expand Down Expand Up @@ -588,7 +594,9 @@ public void testMapCombineWithParams() {
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries));
// with script based aggregation, if it does not support reduce then aggregationList size
// will be numShards * slicesCount
assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries));
long totalCount = 0;
for (Object object : aggregationList) {
assertThat(object, notNullValue());
Expand Down Expand Up @@ -651,7 +659,9 @@ public void testInitMapCombineWithParams() {
assertThat(scriptedMetricAggregation.aggregation(), notNullValue());
assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class));
List<?> aggregationList = (List<?>) scriptedMetricAggregation.aggregation();
assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries));
// with script based aggregation, if it does not support reduce then aggregationList size
// will be numShards * slicesCount
assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries));
long totalCount = 0;
for (Object object : aggregationList) {
assertThat(object, notNullValue());
Expand Down
13 changes: 12 additions & 1 deletion server/src/main/java/org/opensearch/search/SearchModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
import org.opensearch.plugins.SearchPlugin.SuggesterSpec;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.BaseAggregationBuilder;
import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.PipelineAggregationBuilder;
import org.opensearch.search.aggregations.bucket.adjacency.AdjacencyMatrixAggregationBuilder;
Expand Down Expand Up @@ -1290,7 +1291,17 @@ public FetchPhase getFetchPhase() {
}

/**
* Creates the {@link QueryPhase} matching the configured {@code queryPhaseSearcher}:
* the default phase when no searcher is set, a phase paired with a {@link ConcurrentAggregationProcessor} when the
* searcher is a {@code ConcurrentQueryPhaseSearcher}, and a phase wrapping the searcher alone otherwise.
*
* @return a newly constructed {@link QueryPhase}
*/
public QueryPhase getQueryPhase() {
// NOTE(review): the single-line return below appears to be the pre-change implementation shown by the diff view
// (its +/- markers were lost); the statements that follow it are the replacement added by this commit.
return (queryPhaseSearcher == null) ? new QueryPhase() : new QueryPhase(queryPhaseSearcher);
QueryPhase queryPhase;
if (queryPhaseSearcher == null) {
// use the defaults
queryPhase = new QueryPhase();
} else if (queryPhaseSearcher instanceof ConcurrentQueryPhaseSearcher) {
// use ConcurrentAggregationProcessor only with ConcurrentQueryPhaseSearcher
queryPhase = new QueryPhase(queryPhaseSearcher, new ConcurrentAggregationProcessor());
} else {
// any other custom searcher: no aggregation processor is passed explicitly
queryPhase = new QueryPhase(queryPhaseSearcher);
}
return queryPhase;
}

public @Nullable ExecutorService getIndexSearcherExecutor(ThreadPool pool) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import static org.opensearch.search.aggregations.DefaultAggregationProcessor.createCollector;

/**
 * {@link CollectorManager} that supplies the collectors for aggregation operators, for both concurrent and
 * non-concurrent segment search.
 * <p>
 * Each {@link #newCollector()} call creates a fresh list of top level non-global aggregators, registers that list on
 * the search context's aggregations and wraps it into one {@link Collector}.
 */
public class AggregationCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
    private final SearchContext context;

    /**
     * @param context {@link SearchContext} of the request whose aggregations are collected
     */
    public AggregationCollectorManager(SearchContext context) {
        this.context = context;
    }

    /**
     * Creates the non-global aggregators for one collection pass and returns the collector built from them.
     *
     * @return collector created by {@code DefaultAggregationProcessor.createCollector} for the new aggregators
     * @throws IOException if creating the aggregators or the collector fails
     */
    @Override
    public Collector newCollector() throws IOException {
        final List<Aggregator> aggregators = context.aggregations().factories().createTopLevelNonGlobalAggregators(context);
        assert aggregators.isEmpty() == false : "Expected atleast one non global aggregator to be present";
        // register the aggregators on the context before handing them to the collector
        context.aggregations().addNonGlobalAggregators(aggregators);
        return createCollector(context, aggregators);
    }

    /**
     * Returns a result whose reduce step does nothing; the supplied collectors are not inspected here.
     *
     * @param collectors collectors previously returned by {@link #newCollector()}
     * @return a no-op {@link ReduceableSearchResult}
     */
    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
        final ReduceableSearchResult noOp = result -> {};
        return noOp;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations;

import org.opensearch.search.internal.SearchContext;

/**
* Defines the stages of aggregation processing that run before and after document collection, plus a separate
* stage for global aggregators.
*/
public interface AggregationProcessor {

/**
* Callback invoked before the collection of documents is done.
* @param context {@link SearchContext} for the request
*/
void preProcess(SearchContext context);

/**
* Callback invoked after the collection of documents is done.
* @param context {@link SearchContext} for the request
*/
void postProcess(SearchContext context);

/**
* Callback to process the {@link org.opensearch.search.aggregations.bucket.global.GlobalAggregator}s present
* in the search request.
* @param context {@link SearchContext} for the request
*/
void processGlobalAggregators(SearchContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
import org.opensearch.search.aggregations.bucket.global.GlobalAggregatorFactory;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
Expand All @@ -59,6 +60,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -70,6 +72,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -237,6 +240,13 @@ private static AggregatorFactories.Builder parseAggregators(XContentParser parse

public static final AggregatorFactories EMPTY = new AggregatorFactories(new AggregatorFactory[0]);

// Matches factories that are instances of GlobalAggregatorFactory; negate() it to select the non-global ones.
private static final Predicate<AggregatorFactory> GLOBAL_AGGREGATOR_FACTORY_PREDICATE =
    candidate -> candidate instanceof GlobalAggregatorFactory;

private AggregatorFactory[] factories;

public static Builder builder() {
Expand Down Expand Up @@ -268,24 +278,44 @@ public Aggregator[] createSubAggregators(SearchContext searchContext, Aggregator
return aggregators;
}

/**
* Creates an aggregator for every top level factory by delegating to the filtering overload with a filter that
* accepts all factories.
*
* @param searchContext context passed to each factory when creating its aggregator
* @return the created top level aggregators
* @throws IOException if the delegate fails to create an aggregator
*/
// NOTE(review): the array-returning signature on the next line appears to be the pre-change line shown by the diff
// view (its +/- markers were lost); the List-returning signature after it is the replacement.
public Aggregator[] createTopLevelAggregators(SearchContext searchContext) throws IOException {
public List<Aggregator> createTopLevelAggregators(SearchContext searchContext) throws IOException {
return createTopLevelAggregators(searchContext, (aggregatorFactory) -> true);
}

/**
 * Creates top level aggregators only for the factories that are {@link GlobalAggregatorFactory} instances.
 *
 * @param searchContext context passed to each selected factory when creating its aggregator
 * @return the created global aggregators
 * @throws IOException if creating an aggregator fails
 */
public List<Aggregator> createTopLevelGlobalAggregators(SearchContext searchContext) throws IOException {
    final Predicate<AggregatorFactory> globalOnly = GLOBAL_AGGREGATOR_FACTORY_PREDICATE;
    return createTopLevelAggregators(searchContext, globalOnly);
}

/**
 * Creates top level aggregators only for the factories that are not {@link GlobalAggregatorFactory} instances.
 *
 * @param searchContext context passed to each selected factory when creating its aggregator
 * @return the created non-global aggregators
 * @throws IOException if creating an aggregator fails
 */
public List<Aggregator> createTopLevelNonGlobalAggregators(SearchContext searchContext) throws IOException {
    final Predicate<AggregatorFactory> nonGlobalOnly = GLOBAL_AGGREGATOR_FACTORY_PREDICATE.negate();
    return createTopLevelAggregators(searchContext, nonGlobalOnly);
}

/**
* Creates top level aggregators for the factories accepted by {@code factoryFilter}, wrapping each one in a
* {@link ProfilingAggregator} when the aggregator's context exposes profilers.
*
* NOTE(review): this block comes from a diff view whose +/- markers were lost, so pre-change lines sit next to
* their replacements; the pre-change lines are flagged below.
*
* @param searchContext context passed to each selected factory when creating its aggregator
* @param factoryFilter selects which top level factories get an aggregator
* @return the aggregators created for the selected factories
* @throws IOException if creating an aggregator fails
*/
private List<Aggregator> createTopLevelAggregators(SearchContext searchContext, Predicate<AggregatorFactory> factoryFilter)
throws IOException {
// These aggregators are going to be used with a single bucket ordinal, no need to wrap the PER_BUCKET ones
// NOTE(review): the array declaration on the next line appears to be the pre-change line; the List replaces it.
Aggregator[] aggregators = new Aggregator[factories.length];
List<Aggregator> aggregators = new ArrayList<>();
for (int i = 0; i < factories.length; i++) {
/*
* Top level aggs only collect from owningBucketOrd 0 which is
* *exactly* what CardinalityUpperBound.ONE *means*.
*/
// NOTE(review): the next four lines appear to be the pre-change, unfiltered creation code.
Aggregator factory = factories[i].create(searchContext, null, CardinalityUpperBound.ONE);
Profilers profilers = factory.context().getProfilers();
if (profilers != null) {
factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler());
// Replacement: an aggregator is created only for factories accepted by the filter.
Aggregator factory;
if (factoryFilter.test(factories[i])) {
factory = factories[i].create(searchContext, null, CardinalityUpperBound.ONE);
Profilers profilers = factory.context().getProfilers();
if (profilers != null) {
factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler());
}
aggregators.add(factory);
}
// NOTE(review): the array store on the next line appears to be a pre-change line; the add(...) above replaces it.
aggregators[i] = factory;
}
return aggregators;
}

/**
 * @return {@code true} if at least one top level factory is not a {@link GlobalAggregatorFactory},
 *         {@code false} otherwise (including when there are no factories)
 */
public boolean hasNonGlobalAggregator() {
    for (AggregatorFactory candidate : factories) {
        if (GLOBAL_AGGREGATOR_FACTORY_PREDICATE.test(candidate) == false) {
            return true;
        }
    }
    return false;
}

/**
* @return the number of sub-aggregator factories
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations;

import org.opensearch.search.internal.SearchContext;

import java.util.Collections;
import java.util.List;

/**
 * {@link AggregationProcessor} implementation meant to be used with
 * {@link org.opensearch.search.query.ConcurrentQueryPhaseSearcher}. When a request was executed over several slices it
 * reduces the per-slice aggregation results on the shard, so that the result set each shard returns to the
 * coordinator (where the final reduce across shards happens) does not grow with the number of slices.
 */
public class ConcurrentAggregationProcessor extends DefaultAggregationProcessor {

    /**
     * Stores the collected aggregations on the query result, reducing them first when the searcher reports more than
     * one slice.
     *
     * @param context      {@link SearchContext} for the request
     * @param aggregations aggregations built from the collected documents
     */
    @Override
    public void populateResult(SearchContext context, List<InternalAggregation> aggregations) {
        final InternalAggregations collected = InternalAggregations.from(aggregations);
        // Shard level reduce is performed iff multiple slices were created to execute this request and it used the
        // concurrent segment search path.
        // TODO: Add the check for flag that the request was executed using concurrent search
        final boolean multipleSlices = context.searcher().getSlices().length > 1;
        // reduce (rather than topLevelReduce) is fine here: pipeline aggregations are evaluated on the coordinator after
        // all documents are collected across shards for an aggregation
        final InternalAggregations shardLevelResult = multipleSlices
            ? InternalAggregations.reduce(Collections.singletonList(collected), context.aggregationReduceContext())
            : collected;
        context.queryResult().aggregations(shardLevelResult);
    }
}

0 comments on commit ee87f1e

Please sign in to comment.