Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Executes incremental reduce in the search thread pool #58461

Merged
merged 19 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

@ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 2)
Expand Down Expand Up @@ -85,17 +87,17 @@ public void onFailure(Exception e) {
if (response instanceof SearchResponse) {
SearchResponse searchResponse = (SearchResponse) response;
for (ShardSearchFailure failure : searchResponse.getShardFailures()) {
assertTrue("got unexpected reason..." + failure.reason(),
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
anyOf(containsString("cancelled"), containsString("rejected")));
}
} else {
Exception t = (Exception) response;
Throwable unwrap = ExceptionsHelper.unwrapCause(t);
if (unwrap instanceof SearchPhaseExecutionException) {
SearchPhaseExecutionException e = (SearchPhaseExecutionException) unwrap;
for (ShardSearchFailure failure : e.shardFailures()) {
assertTrue("got unexpected reason..." + failure.reason(),
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
anyOf(containsString("cancelled"), containsString("rejected")));
}
} else if ((unwrap instanceof EsRejectedExecutionException) == false) {
throw new AssertionError("unexpected failure", (Throwable) response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ public void testSearchProgressWithShardSort() throws Exception {
testCase((NodeClient) client(), request, sortShards, false);
}

private static void testCase(NodeClient client, SearchRequest request,
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
private void testCase(NodeClient client, SearchRequest request,
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
AtomicInteger numQueryResults = new AtomicInteger();
AtomicInteger numQueryFailures = new AtomicInteger();
AtomicInteger numFetchResults = new AtomicInteger();
Expand Down Expand Up @@ -204,7 +204,6 @@ public SearchTask createTask(long id, String type, String action, TaskId parentT
}
}, listener);
latch.await();

assertThat(shardsListener.get(), equalTo(expectedShards));
assertThat(numQueryResults.get(), equalTo(searchResponse.get().getSuccessfulShards()));
assertThat(numQueryFailures.get(), equalTo(searchResponse.get().getFailedShards()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,15 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
protected void onShardResult(Result result, SearchShardIterator shardIt) {
assert result.getShardIndex() != -1 : "shard index is not set";
assert result.getSearchShardTarget() != null : "search shard target must not be null";
successfulOps.incrementAndGet();
results.consumeResult(result);
hasShardResponse.set(true);
if (logger.isTraceEnabled()) {
logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
}
results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
}

private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
successfulOps.incrementAndGet();
// clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
// so it's OK concurrency-wise to miss potentially the shard failures being created because of another failure
// in the #addShardFailure, because by definition, it will happen on *another* shardIndex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ Stream<Result> getSuccessfulResults() {
return results.asList().stream();
}

void consumeResult(Result result) {
@Override
void consumeResult(Result result, Runnable next) {
assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set";
results.set(result.getShardIndex(), result);
next.run();
}

boolean hasResult(int shardIndex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ private static final class CanMatchSearchPhaseResults extends SearchPhaseResults
}

@Override
void consumeResult(CanMatchResponse result) {
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
void consumeResult(CanMatchResponse result, Runnable next) {
try {
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
} finally {
next.run();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,18 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;

import java.util.function.Consumer;

/**
* This is a simple base class to simplify fan out to shards and collect their results. Each result passed to
* {@link #onResult(SearchPhaseResult)} will be set to the provided result array
* where the given index is used to set the result on the array.
*/
final class CountedCollector<R extends SearchPhaseResult> {
private final Consumer<R> resultConsumer;
private final ArraySearchPhaseResults<R> resultConsumer;
private final CountDown counter;
private final Runnable onFinish;
private final SearchPhaseContext context;

CountedCollector(Consumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
CountedCollector(ArraySearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
this.resultConsumer = resultConsumer;
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
Expand All @@ -58,11 +56,7 @@ void countDown() {
* Sets the result to the given array index and then runs {@link #countDown()}
*/
void onResult(R result) {
try {
resultConsumer.accept(result);
} finally {
countDown();
}
resultConsumer.consumeResult(result, this::countDown);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.io.IOException;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

/**
Expand All @@ -51,10 +52,11 @@ final class DfsQueryPhase extends SearchPhase {
DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
SearchPhaseController searchPhaseController,
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context) {
SearchPhaseContext context, Consumer<Exception> onPartialMergeFailure) {
super("dfs_query");
this.progressListener = context.getTask().getProgressListener();
this.queryResult = searchPhaseController.newSearchPhaseResults(progressListener, context.getRequest(), context.getNumShards());
this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener,
context.getRequest(), context.getNumShards(), onPartialMergeFailure);
this.searchPhaseController = searchPhaseController;
this.dfsSearchResults = dfsSearchResults;
this.nextPhaseFactory = nextPhaseFactory;
Expand All @@ -68,7 +70,7 @@ public void run() throws IOException {
// to free up memory early
final List<DfsSearchResult> resultList = dfsSearchResults.asList();
final AggregatedDfs dfs = searchPhaseController.aggregateDfs(resultList);
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult::consumeResult,
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult,
resultList.size(),
() -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), context);
for (final DfsSearchResult dfsResult : resultList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.transport.Transport;

import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;

Expand All @@ -45,7 +44,7 @@
* Then it reaches out to all relevant shards to fetch the topN hits.
*/
final class FetchSearchPhase extends SearchPhase {
private final AtomicArray<FetchSearchResult> fetchResults;
private final ArraySearchPhaseResults<FetchSearchResult> fetchResults;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<SearchPhaseResult> queryResults;
private final BiFunction<InternalSearchResponse, String, SearchPhase> nextPhaseFactory;
Expand Down Expand Up @@ -73,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase {
throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
+ context.getNumShards() + "!=" + resultConsumer.getNumShards());
}
this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
this.searchPhaseController = searchPhaseController;
this.queryResults = resultConsumer.getAtomicArray();
this.nextPhaseFactory = nextPhaseFactory;
Expand Down Expand Up @@ -102,7 +101,7 @@ public void onFailure(Exception e) {
});
}

private void innerRun() throws IOException {
private void innerRun() throws Exception {
final int numShards = context.getNumShards();
final boolean isScrollSearch = context.getRequest().scroll() != null;
final List<SearchPhaseResult> phaseResults = queryResults.asList();
Expand All @@ -117,7 +116,7 @@ private void innerRun() throws IOException {
final boolean queryAndFetchOptimization = queryResults.length() == 1;
final Runnable finishPhase = ()
-> moveToNextPhase(searchPhaseController, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
queryResults : fetchResults);
queryResults : fetchResults.getAtomicArray());
if (queryAndFetchOptimization) {
assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty()
+ "], single result: " + phaseResults.get(0).fetchResult();
Expand All @@ -137,7 +136,7 @@ private void innerRun() throws IOException {
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, numShards)
: null;
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(r -> fetchResults.set(r.getShardIndex(), r),
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
finishPhase, context);
for (int i = 0; i < docIdsToLoad.length; i++) {
Expand Down
Loading