Skip to content

Commit

Permalink
Fix possible NPE on search phase failure (#57952)
Browse files Browse the repository at this point in the history
When a search phase fails, we release the context of all successful shards.
Successful shards that rewrite the request to match none will not create any context
since #. This change ensures that we don't try to release a `null` context on these
successful shards.

Closes #57945
  • Loading branch information
jimczi committed Jun 11, 2020
1 parent c36df27 commit 4c6bfe3
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -559,13 +559,15 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause)
*/
private void raisePhaseFailure(SearchPhaseExecutionException exception) {
results.getSuccessfulResults().forEach((entry) -> {
try {
SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices());
} catch (Exception inner) {
inner.addSuppressed(exception);
logger.trace("failed to release context", inner);
if (entry.getContextId() != null) {
try {
SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices());
} catch (Exception inner) {
inner.addSuppressed(exception);
logger.trace("failed to release context", inner);
}
}
});
listener.onFailure(exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;

/**
Expand Down Expand Up @@ -199,7 +200,7 @@ static class ScrollFreeContextRequest extends TransportRequest {
private SearchContextId contextId;

ScrollFreeContextRequest(SearchContextId contextId) {
this.contextId = contextId;
this.contextId = Objects.requireNonNull(contextId);
}

ScrollFreeContextRequest(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -52,7 +53,9 @@ protected SearchPhaseResult(StreamInput in) throws IOException {

/**
* Returns the search context ID that is used to reference the search context on the executing node
* or <code>null</code> if no context was created.
*/
@Nullable
public SearchContextId getContextId() {
return contextId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

public class SearchAsyncActionTests extends ESTestCase {
Expand Down Expand Up @@ -376,6 +377,113 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
executor.shutdown();
}

public void testFanOutAndFail() throws InterruptedException {
SearchRequest request = new SearchRequest();
request.allowPartialSearchResults(true);
request.setMaxConcurrentShardRequests(randomIntBetween(1, 100));
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<Exception> failure = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = ActionListener.wrap(
searchResponse -> { throw new AssertionError("unexpected response"); },
exc -> {
failure.set(exc);
latch.countDown();
});
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);

Map<DiscoveryNode, Set<SearchContextId>> nodeToContextMap = newConcurrentMap();
AtomicInteger contextIdGenerator = new AtomicInteger(0);
int numShards = randomIntBetween(2, 10);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
numShards, randomBoolean(), primaryNode, replicaNode);
AtomicInteger numFreedContext = new AtomicInteger();
SearchTransportService transportService = new SearchTransportService(null, null) {
@Override
public void sendFreeContext(Transport.Connection connection, SearchContextId contextId, OriginalIndices originalIndices) {
assertNotNull(contextId);
numFreedContext.incrementAndGet();
assertTrue(nodeToContextMap.containsKey(connection.getNode()));
assertTrue(nodeToContextMap.get(connection.getNode()).remove(contextId));
}
};
Map<String, Transport.Connection> lookup = new HashMap<>();
lookup.put(primaryNode.getId(), new MockConnection(primaryNode));
lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
AbstractSearchAsyncAction<TestSearchPhaseResult> asyncAction =
new AbstractSearchAsyncAction<TestSearchPhaseResult>(
"test",
logger,
transportService,
(cluster, node) -> {
assert cluster == null : "cluster was not null: " + cluster;
return lookup.get(node); },
aliasFilters,
Collections.emptyMap(),
Collections.emptyMap(),
executor,
request,
responseListener,
shardsIter,
new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0),
ClusterState.EMPTY_STATE,
null,
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY) {
TestSearchResponse response = new TestSearchResponse();

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt,
ShardRouting shard,
SearchActionListener<TestSearchPhaseResult> listener) {
assertTrue("shard: " + shard.shardId() + " has been queried twice", response.queried.add(shard.shardId()));
Transport.Connection connection = getConnection(null, shard.currentNodeId());
final TestSearchPhaseResult testSearchPhaseResult;
if (shard.shardId().id() == 0) {
testSearchPhaseResult = new TestSearchPhaseResult(null, connection.getNode());
} else {
testSearchPhaseResult = new TestSearchPhaseResult(new SearchContextId(UUIDs.randomBase64UUID(),
contextIdGenerator.incrementAndGet()), connection.getNode());
Set<SearchContextId> ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet());
ids.add(testSearchPhaseResult.getContextId());
}
if (randomBoolean()) {
listener.onResponse(testSearchPhaseResult);
} else {
new Thread(() -> listener.onResponse(testSearchPhaseResult)).start();
}
}

@Override
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results,
SearchPhaseContext context) {
return new SearchPhase("test") {
@Override
public void run() {
throw new RuntimeException("boom");
}
};
}
};
asyncAction.start();
latch.await();
assertNotNull(failure.get());
assertThat(failure.get().getCause().getMessage(), containsString("boom"));
assertFalse(nodeToContextMap.isEmpty());
assertTrue(nodeToContextMap.toString(), nodeToContextMap.containsKey(primaryNode) || nodeToContextMap.containsKey(replicaNode));
assertEquals(shardsIter.size()-1, numFreedContext.get());
if (nodeToContextMap.containsKey(primaryNode)) {
assertTrue(nodeToContextMap.get(primaryNode).toString(), nodeToContextMap.get(primaryNode).isEmpty());
} else {
assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty());
}
executor.shutdown();
}

public void testAllowPartialResults() throws InterruptedException {
SearchRequest request = new SearchRequest();
request.allowPartialSearchResults(false);
Expand Down

0 comments on commit 4c6bfe3

Please sign in to comment.