Skip to content

Commit

Permalink
Async Search: correct shards counting (#56272)
Browse files Browse the repository at this point in the history
Async search allows users to retrieve partial results for a running search. For partial results, the number of successful shards reported internally does not include the skipped shards, whereas the response returned to users should include them.

Also, we recently had a bug where async search would miss tracking shard failures. That bug would have been caught if we had assertions in place verifying that, whenever we get the last response, the number of failures included in it is the same as the number of failures tracked through the listener notifications.
  • Loading branch information
javanna committed May 7, 2020
1 parent 9a4fbbe commit 896422b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna

@Override
public void onResponse(SearchResponse response) {
searchResponse.get().updateFinalResponse(response.getSuccessfulShards(), response.getInternalResponse());
searchResponse.get().updateFinalResponse(response);
executeCompletionListeners();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ synchronized void updatePartialResponse(int successfulShards, SearchResponseSect
throw new IllegalStateException("received partial response out of order: "
+ newSections.getNumReducePhases() + " < " + sections.getNumReducePhases());
}
this.successfulShards = successfulShards;
//when we get partial results skipped shards are not included in the provided number of successful shards
this.successfulShards = successfulShards + skippedShards;
this.sections = newSections;
this.isPartial = true;
this.isFinalReduce = isFinalReduce;
Expand All @@ -101,12 +102,20 @@ synchronized void updatePartialResponse(int successfulShards, SearchResponseSect
* Updates the response with the final {@link SearchResponseSections} merged from #<code>successfulShards</code>
* shards.
*/
synchronized void updateFinalResponse(int successfulShards, SearchResponseSections newSections) {
synchronized void updateFinalResponse(SearchResponse searchResponse) {
failIfFrozen();
assert searchResponse.getTotalShards() == totalShards : "received number of total shards differs from the one " +
"notified through onListShards";
assert searchResponse.getSkippedShards() == skippedShards : "received number of skipped shards differs from the one " +
"notified through onListShards";
assert searchResponse.getFailedShards() == buildShardFailures().length : "number of tracked failures differs from failed shards";
// copy the response headers from the current context
this.responseHeaders = threadContext.getResponseHeaders();
this.successfulShards = successfulShards;
this.sections = newSections;
//we take successful from the final response, which overrides whatever value we set when we received the last partial results.
//This is important for cases where e.g. aggs work fine and then fetch fails on some of the shards but not all.
//The shards where fetch has failed should not be counted as successful.
this.successfulShards = searchResponse.getSuccessfulShards();
this.sections = searchResponse.getInternalResponse();
this.isPartial = false;
this.isFinalReduce = true;
this.frozen = true;
Expand All @@ -121,6 +130,8 @@ synchronized void updateWithFailure(Exception exc) {
// copy the response headers from the current context
this.responseHeaders = threadContext.getResponseHeaders();
this.isPartial = true;
//note that when search fails, we may have gotten partial results before the failure. In that case async
// search will return an error plus the last partial results that were collected.
this.failure = ElasticsearchException.guessRootCauses(exc)[0];
this.frozen = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static SearchResponse randomSearchResponse() {
long tookInMillis = randomNonNegativeLong();
int totalShards = randomIntBetween(1, Integer.MAX_VALUE);
int successfulShards = randomIntBetween(0, totalShards);
int skippedShards = totalShards - successfulShards;
int skippedShards = randomIntBetween(0, successfulShards);
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
return new SearchResponse(internalSearchResponse, null, totalShards,
successfulShards, skippedShards, tookInMillis, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,24 @@ public void testWaitForCompletion() throws InterruptedException {
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}

int numShardFailures = 0;
int totalShards = numShards + numSkippedShards;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true);
assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true);
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numShardFailures, false);
newSearchResponse(totalShards, totalShards, numSkippedShards));
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false);
}

public void testWithFetchFailures() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(0, 10);
int numShards = randomIntBetween(2, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
Expand All @@ -162,38 +160,72 @@ public void testWithFetchFailures() throws InterruptedException {
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}

int totalShards = numShards + numSkippedShards;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, 0, true);
assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(0, numShards);
int numFetchFailures = randomIntBetween(1, numShards - 1);
ShardSearchFailure[] shardSearchFailures = new ShardSearchFailure[numFetchFailures];
for (int i = 0; i < numFetchFailures; i++) {
IOException failure = new IOException("boum");
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
new IOException("boum"));

failure);
shardSearchFailures[i] = new ShardSearchFailure(failure);
}
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true);
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numFetchFailures, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numFetchFailures, false);
newSearchResponse(totalShards, totalShards - numFetchFailures, numSkippedShards, shardSearchFailures));
assertCompletionListeners(task, totalShards, totalShards - numFetchFailures, numSkippedShards, numFetchFailures, false);
}

/**
 * Checks the shard counts reported to completion listeners when a fetch failure is notified for every
 * shard and the search then ends with a fatal failure delivered through {@code onFailure}.
 * <p>
 * Notification sequence exercised here: {@code onListShards}, one {@code onPartialReduce} per shard,
 * {@code onFinalReduce}, one {@code onFetchFailure} per shard, and finally {@code onFailure}.
 * The arguments passed to {@code assertCompletionListeners} are, in order: expected total shards,
 * expected successful shards, expected skipped shards, expected shard failures, and whether the
 * response is expected to be partial.
 */
public void testFatalFailureDuringFetch() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
// numShards may be 0, in which case the partial-reduce and fetch-failure loops below do not run
int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
List<SearchShard> skippedShards = new ArrayList<>();
int numSkippedShards = randomIntBetween(0, 10);
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
// total = shards that are searched + shards that are skipped
int totalShards = numShards + numSkippedShards;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
// each partial reduce covers the first i + 1 shards (cumulative sub-list)
task.getSearchProgressActionListener().onPartialReduce(shards.subList(0, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
// the expected successful count is the reduced shards plus the skipped shards; still partial, no failures yet
assertCompletionListeners(task, totalShards, i + 1 + numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
// notify a fetch failure for every shard index
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
new IOException("boum"));
}
// before the final outcome arrives: numShards failures are expected to be reported, while the successful
// count is still expected to equal totalShards and the response to be partial
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
// fatal failure of the whole search
((AsyncSearchTask.Listener)task.getProgressListener()).onFailure(new IOException("boum"));
// the same counts are expected after the fatal failure, and the response is still expected to be partial
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
}

private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) {
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards,
ShardSearchFailure... shardFailures) {
InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(),
InternalAggregations.EMPTY, null, null, false, null, 1);
return new SearchResponse(response, null, totalShards, successfulShards, skippedShards,
100, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);
100, shardFailures, SearchResponse.Clusters.EMPTY);
}

private void assertCompletionListeners(AsyncSearchTask task,
int expectedTotalShards,
int expectedSuccessfulShards,
int expectedSkippedShards,
int expectedShardFailures,
boolean isPartial) throws InterruptedException {
Expand All @@ -204,6 +236,7 @@ private void assertCompletionListeners(AsyncSearchTask task,
@Override
public void onResponse(AsyncSearchResponse resp) {
assertThat(resp.getSearchResponse().getTotalShards(), equalTo(expectedTotalShards));
assertThat(resp.getSearchResponse().getSuccessfulShards(), equalTo(expectedSuccessfulShards));
assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards));
assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures));
assertThat(resp.isPartial(), equalTo(isPartial));
Expand Down

0 comments on commit 896422b

Please sign in to comment.