Skip to content

Commit

Permalink
Look up connection using the right cluster alias when releasing conte…
Browse files Browse the repository at this point in the history
…xts (elastic#38570)

Whenever phase failure is raised in AbstractSearchAsyncAction, we go and
release search contexts of shards that successfully returned their
results, prior to notifying the listener of the failure. In case we are
executing a CCS request, it's important to look-up the connection to
send the release context request to.

This commit makes sure that the lookup takes the cluster alias into
account. We used to use `null` at all times instead which is not correct
and was not caught as any exception is caught without re-throwing it.
  • Loading branch information
javanna authored and dimitris-athanasiou committed Feb 12, 2019
1 parent 0d7b0c4 commit 69769da
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 14 deletions.
Expand Up @@ -228,7 +228,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
results.getSuccessfulResults().forEach((entry) -> {
try {
SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
Transport.Connection connection = getConnection(null, searchShardTarget.getNodeId());
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
sendReleaseSearchContext(entry.getRequestId(), connection, searchShardTarget.getOriginalIndices());
} catch (Exception inner) {
inner.addSuppressed(exception);
Expand Down Expand Up @@ -281,14 +281,12 @@ public final SearchRequest getRequest() {

@Override
public final SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) {

ShardSearchFailure[] failures = buildShardFailures();
Boolean allowPartialResults = request.allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
if (allowPartialResults == false && failures.length > 0){
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures));
}

return new SearchResponse(internalSearchResponse, scrollId, getNumShards(), successfulOps.get(),
skippedOps.get(), buildTookInMillis(), failures, clusters);
}
Expand Down
Expand Up @@ -19,32 +19,49 @@

package org.elasticsearch.action.search;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.ShardSearchTransportRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.Transport;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;

public class AbstractSearchAsyncActionTests extends ESTestCase {

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
final boolean controlled,
final AtomicLong expected) {
private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
private final Set<Long> releasedContexts = new CopyOnWriteArraySet<>();

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(SearchRequest request,
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> results,
ActionListener<SearchResponse> listener,
final boolean controlled,
final AtomicLong expected) {
final Runnable runnable;
final TransportSearchAction.SearchTimeProvider timeProvider;
if (controlled) {
Expand All @@ -61,18 +78,20 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
System::nanoTime);
}

final SearchRequest request = new SearchRequest();
request.allowPartialSearchResults(true);
request.preference("_shards:1,3");
return new AbstractSearchAsyncAction<SearchPhaseResult>("test", null, null, null,
BiFunction<String, String, Transport.Connection> nodeIdToConnection = (cluster, node) -> {
resolvedNodes.add(Tuple.tuple(cluster, node));
return null;
};

return new AbstractSearchAsyncAction<SearchPhaseResult>("test", null, null, nodeIdToConnection,
Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())), Collections.singletonMap("foo", 2.0f),
Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),null, request, null,
Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), null, request, listener,
new GroupShardsIterator<>(
Collections.singletonList(
new SearchShardIterator(null, null, Collections.emptyList(), null)
)
), timeProvider, 0, null,
new InitialSearchPhase.ArraySearchPhaseResults<>(10), request.getMaxConcurrentShardRequests(),
results, request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY) {
@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, final SearchPhaseContext context) {
Expand All @@ -89,6 +108,11 @@ long buildTookInMillis() {
runnable.run();
return super.buildTookInMillis();
}

@Override
public void sendReleaseSearchContext(long contextId, Transport.Connection connection, OriginalIndices originalIndices) {
releasedContexts.add(contextId);
}
};
}

Expand All @@ -102,7 +126,8 @@ public void testTookWithRealClock() {

private void runTestTook(final boolean controlled) {
final AtomicLong expected = new AtomicLong();
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(controlled, expected);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(new SearchRequest(),
new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, controlled, expected);
final long actual = action.buildTookInMillis();
if (controlled) {
// with a controlled clock, we can assert the exact took time
Expand All @@ -114,8 +139,10 @@ private void runTestTook(final boolean controlled) {
}

public void testBuildShardSearchTransportRequest() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()).preference("_shards:1,3");
final AtomicLong expected = new AtomicLong();
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(false, expected);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest,
new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, expected);
String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
SearchShardIterator iterator = new SearchShardIterator(clusterAlias, new ShardId(new Index("name", "foo"), 1),
Collections.emptyList(), new OriginalIndices(new String[] {"name", "name1"}, IndicesOptions.strictExpand()));
Expand All @@ -129,4 +156,114 @@ public void testBuildShardSearchTransportRequest() {
assertEquals("_shards:1,3", shardSearchTransportRequest.preference());
assertEquals(clusterAlias, shardSearchTransportRequest.getClusterAlias());
}

public void testBuildSearchResponse() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean());
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest,
new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, new AtomicLong());
String scrollId = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, scrollId);
assertEquals(scrollId, searchResponse.getScrollId());
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
}

public void testBuildSearchResponseAllowPartialFailures() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest,
new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, new AtomicLong());
action.onShardFailure(0, new SearchShardTarget("node", new ShardId("index", "index-uuid", 0), null, OriginalIndices.NONE),
new IllegalArgumentException());
String scrollId = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, scrollId);
assertEquals(scrollId, searchResponse.getScrollId());
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
}

public void testBuildSearchResponseDisallowPartialFailures() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false);
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<SearchResponse> listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set);
Set<Long> requestIds = new HashSet<>();
List<Tuple<String, String>> nodeLookups = new ArrayList<>();
int numFailures = randomIntBetween(1, 5);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> phaseResults = phaseResults(requestIds, nodeLookups, numFailures);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
for (int i = 0; i < numFailures; i++) {
ShardId failureShardId = new ShardId("index", "index-uuid", i);
String failureClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
String failureNodeId = randomAlphaOfLengthBetween(5, 10);
action.onShardFailure(i, new SearchShardTarget(failureNodeId, failureShardId, failureClusterAlias, OriginalIndices.NONE),
new IllegalArgumentException());
}
action.buildSearchResponse(InternalSearchResponse.empty(), randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10));
assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class));
SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException)exception.get();
assertEquals(0, searchPhaseExecutionException.getSuppressed().length);
assertEquals(numFailures, searchPhaseExecutionException.shardFailures().length);
for (ShardSearchFailure shardSearchFailure : searchPhaseExecutionException.shardFailures()) {
assertThat(shardSearchFailure.getCause(), instanceOf(IllegalArgumentException.class));
}
assertEquals(nodeLookups, resolvedNodes);
assertEquals(requestIds, releasedContexts);
}

public void testOnPhaseFailure() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false);
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<SearchResponse> listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set);
Set<Long> requestIds = new HashSet<>();
List<Tuple<String, String>> nodeLookups = new ArrayList<>();
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> phaseResults = phaseResults(requestIds, nodeLookups, 0);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
action.onPhaseFailure(new SearchPhase("test") {
@Override
public void run() {

}
}, "message", null);
assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class));
SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException)exception.get();
assertEquals("message", searchPhaseExecutionException.getMessage());
assertEquals("test", searchPhaseExecutionException.getPhaseName());
assertEquals(0, searchPhaseExecutionException.shardFailures().length);
assertEquals(0, searchPhaseExecutionException.getSuppressed().length);
assertEquals(nodeLookups, resolvedNodes);
assertEquals(requestIds, releasedContexts);
}

private static InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> phaseResults(Set<Long> requestIds,
List<Tuple<String, String>> nodeLookups,
int numFailures) {
int numResults = randomIntBetween(1, 10);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> phaseResults =
new InitialSearchPhase.ArraySearchPhaseResults<>(numResults + numFailures);

for (int i = 0; i < numResults; i++) {
long requestId = randomLong();
requestIds.add(requestId);
SearchPhaseResult phaseResult = new PhaseResult(requestId);
String resultClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
String resultNodeId = randomAlphaOfLengthBetween(5, 10);
ShardId resultShardId = new ShardId("index", "index-uuid", i);
nodeLookups.add(Tuple.tuple(resultClusterAlias, resultNodeId));
phaseResult.setSearchShardTarget(new SearchShardTarget(resultNodeId, resultShardId, resultClusterAlias, OriginalIndices.NONE));
phaseResult.setShardIndex(i);
phaseResults.consumeResult(phaseResult);
}
return phaseResults;
}

private static final class PhaseResult extends SearchPhaseResult {
PhaseResult(long requestId) {
this.requestId = requestId;
}
}
}

0 comments on commit 69769da

Please sign in to comment.