Adds a consistent shard index to ShardSearchRequest
This change ensures that the shard index that is used to tiebreak documents with identical sort
values remains consistent between two requests that target the same shards. The index is now always
computed from the natural order of the shards in the search request.
This change also adds the consistent shard index to the ShardSearchRequest, which allows the slice
builder to use this information to build more balanced slice queries.

Relates elastic#56828
jimczi committed Dec 1, 2020
1 parent 3c3a432 commit f666fa7
Showing 22 changed files with 175 additions and 269 deletions.
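
The sketch below is a minimal, standalone illustration of the idea described in the commit message. It uses plain JDK collections instead of the SearchShardIterator and Lucene's CollectionUtil.timSort that the actual change relies on, and the ShardId record and consistentShardIndexes helper are hypothetical stand-ins rather than Elasticsearch API. It derives a shard index from the natural (sorted) order of the shards taking part in a request, so two requests that target the same shards agree on the index used for sort tiebreaking and slice partitioning.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ConsistentShardIndexSketch {

    // Hypothetical stand-in for a shard identifier: index name plus shard number,
    // with a natural order that does not depend on routing or node assignment.
    record ShardId(String indexName, int shardNum) implements Comparable<ShardId> {
        @Override
        public int compareTo(ShardId other) {
            int cmp = indexName.compareTo(other.indexName);
            return cmp != 0 ? cmp : Integer.compare(shardNum, other.shardNum);
        }
    }

    // Assign each shard an index based on the natural order of the shards that
    // participate in the request, rather than the order they happened to arrive in.
    static Map<ShardId, Integer> consistentShardIndexes(List<ShardId> shards) {
        List<ShardId> naturalOrder = new ArrayList<>(shards);
        naturalOrder.sort(Comparator.naturalOrder());
        Map<ShardId, Integer> indexes = new HashMap<>();
        for (int i = 0; i < naturalOrder.size(); i++) {
            indexes.put(naturalOrder.get(i), i);
        }
        return indexes;
    }

    public static void main(String[] args) {
        // The same three shards, listed in two different orders ...
        List<ShardId> requestA = List.of(
            new ShardId("logs", 2), new ShardId("logs", 0), new ShardId("metrics", 1));
        List<ShardId> requestB = List.of(
            new ShardId("metrics", 1), new ShardId("logs", 0), new ShardId("logs", 2));
        // ... are assigned identical shard indexes, so sort tiebreaks (and slice
        // assignments derived from the index) stay stable across the two requests.
        System.out.println(consistentShardIndexes(requestA).equals(consistentShardIndexes(requestB))); // true
        System.out.println(consistentShardIndexes(requestA)); // maps logs shard 0 -> 0, logs shard 2 -> 1, metrics shard 1 -> 2
    }
}
```

In the commit itself, this mapping corresponds to the shardItIndexMap built in the AbstractSearchAsyncAction constructor below, and the resulting index is what buildShardSearchRequest now passes into each ShardSearchRequest.
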
AbstractSearchAsyncAction.java
@@ -21,6 +21,7 @@

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
@@ -48,10 +49,9 @@

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -83,7 +83,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final ClusterState clusterState;
private final Map<String, AliasFilter> aliasFilter;
private final Map<String, Float> concreteIndexBoosts;
private final Map<String, Set<String>> indexRoutings;
private final SetOnce<AtomicArray<ShardSearchFailure>> shardFailures = new SetOnce<>();
private final Object shardFailuresMutex = new Object();
private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
@@ -94,6 +93,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten

protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
protected final GroupShardsIterator<SearchShardIterator> shardsIts;
private final Map<SearchShardIterator, Integer> shardItIndexMap;
private final int expectedTotalOps;
private final AtomicInteger totalOps = new AtomicInteger();
private final int maxConcurrentRequestsPerNode;
@@ -106,7 +106,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
Executor executor, SearchRequest request,
ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
SearchTimeProvider timeProvider, ClusterState clusterState,
@@ -124,6 +123,17 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
}
this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
this.shardsIts = new GroupShardsIterator<>(iterators);
this.shardItIndexMap = new HashMap<>();

// we compute the shard index based on the natural order of the shards
// that participate in the search request. This means that this number is
// consistent between two requests that target the same shards.
List<SearchShardIterator> naturalOrder = new ArrayList<>(iterators);
CollectionUtil.timSort(naturalOrder);
for (int i = 0; i < naturalOrder.size(); i++) {
shardItIndexMap.put(naturalOrder.get(i), i);
}

// we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
// it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
// on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
@@ -143,7 +153,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.clusterState = clusterState;
this.concreteIndexBoosts = concreteIndexBoosts;
this.aliasFilter = aliasFilter;
this.indexRoutings = indexRoutings;
this.results = resultConsumer;
this.clusters = clusters;
}
@@ -210,10 +219,13 @@ public final void run() {
throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY);
}
}
for (int index = 0; index < shardsIts.size(); index++) {
final SearchShardIterator shardRoutings = shardsIts.get(index);

for (int i = 0; i < shardsIts.size(); i++) {
final SearchShardIterator shardRoutings = shardsIts.get(i);
assert shardRoutings.skip() == false;
performPhaseOnShard(index, shardRoutings, shardRoutings.nextOrNull());
assert shardItIndexMap.containsKey(shardRoutings);
int shardIndex = shardItIndexMap.get(shardRoutings);
performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
}
}
}
@@ -651,15 +663,12 @@ public final void onFailure(Exception e) {
}

@Override
public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt) {
public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt, int shardIndex) {
AliasFilter filter = aliasFilter.get(shardIt.shardId().getIndex().getUUID());
assert filter != null;
float indexBoost = concreteIndexBoosts.getOrDefault(shardIt.shardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST);
String indexName = shardIt.shardId().getIndex().getName();
final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet())
.toArray(new String[0]);
ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request, shardIt.shardId(), getNumShards(),
filter, indexBoost, timeProvider.getAbsoluteStartMillis(), shardIt.getClusterAlias(), routings,
ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request, shardIt.shardId(), shardIndex,
getNumShards(), filter, indexBoost, timeProvider.getAbsoluteStartMillis(), shardIt.getClusterAlias(),
shardIt.getSearchContextId(), shardIt.getSearchContextKeepAlive());
// if we already received a search result we can inform the shard that it
// can return a null response if the request rewrites to match none rather
CanMatchPreFilterSearchPhase.java
@@ -36,7 +36,6 @@
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Function;
@@ -63,14 +62,13 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
CanMatchPreFilterSearchPhase(Logger logger, SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
Executor executor, SearchRequest request,
ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState,
SearchTask task, Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
SearchResponse.Clusters clusters) {
//We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
super("can_match", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
super("can_match", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
executor, request, listener, shardsIts, timeProvider, clusterState, task,
new CanMatchSearchPhaseResults(shardsIts.size()), shardsIts.size(), clusters);
this.phaseFactory = phaseFactory;
@@ -86,7 +84,7 @@ public void addReleasable(Releasable releasable) {
protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarget shard,
SearchActionListener<CanMatchResponse> listener) {
getSearchTransport().sendCanMatch(getConnection(shard.getClusterAlias(), shard.getNodeId()),
buildShardSearchRequest(shardIt), getTask(), listener);
buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
}

@Override
@@ -149,7 +147,7 @@ private static Comparator<Integer> shardComparator(GroupShardsIterator<SearchSha
MinAndMax<?>[] minAndMaxes,
SortOrder order) {
final Comparator<Integer> comparator = Comparator.comparing(index -> minAndMaxes[index], MinAndMax.getComparator(order));
return comparator.thenComparing(index -> shardsIts.get(index).shardId());
return comparator.thenComparing(index -> shardsIts.get(index));
}

private static final class CanMatchSearchPhaseResults extends SearchPhaseResults<CanMatchResponse> {
FetchSearchPhase.java
@@ -153,7 +153,7 @@ private void innerRun() throws Exception {
ShardFetchSearchRequest fetchSearchRequest = createFetchRequest(queryResult.queryResult().getContextId(), i, entry,
lastEmittedDocPerShard, searchShardTarget.getOriginalIndices(), queryResult.getShardSearchRequest(),
queryResult.getRescoreDocIds());
executeFetch(i, searchShardTarget, counter, fetchSearchRequest, queryResult.queryResult(),
executeFetch(queryResult.getShardIndex(), searchShardTarget, counter, fetchSearchRequest, queryResult.queryResult(),
connection);
}
}
SearchDfsQueryThenFetchAsyncAction.java
@@ -32,7 +32,6 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;

@@ -45,14 +44,14 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
final Map<String, AliasFilter> aliasFilter,
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
final Map<String, Float> concreteIndexBoosts,
final SearchPhaseController searchPhaseController, final Executor executor,
final QueryPhaseResultConsumer queryPhaseResultConsumer,
final SearchRequest request, final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) {
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
executor, request, listener,
shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()),
request.getMaxConcurrentShardRequests(), clusters);
@@ -68,7 +67,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
protected void executePhaseOnShard(final SearchShardIterator shardIt, final SearchShardTarget shard,
final SearchActionListener<DfsSearchResult> listener) {
getSearchTransport().sendExecuteDfs(getConnection(shard.getClusterAlias(), shard.getNodeId()),
buildShardSearchRequest(shardIt) , getTask(), listener);
buildShardSearchRequest(shardIt, listener.requestIndex) , getTask(), listener);
}

@Override
SearchPhaseContext.java
@@ -115,8 +115,12 @@ default void sendReleaseSearchContext(ShardSearchContextId contextId,

/**
* Builds a request for the initial search phase.
*
* @param shardIt the target {@link SearchShardIterator}
* @param shardIndex the index of the shard that is used in the coordinator node to
* tiebreak results with identical sort values
*/
ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt);
ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt, int shardIndex);

/**
* Processes the phase transition from on phase to another. This method handles all errors that happen during the initial run execution
SearchPhaseController.java
@@ -428,21 +428,25 @@ ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> quer
throw new IllegalStateException(errorMsg);
}
validateMergeSortValueFormats(queryResults);
final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
final boolean hasSuggest = firstResult.suggest() != null;
final boolean hasProfileResults = firstResult.hasProfileResults();
final boolean hasSuggest = queryResults.stream().anyMatch(res -> res.queryResult().suggest() != null);
final boolean hasProfileResults = queryResults.stream().anyMatch(res -> res.queryResult().hasProfileResults());

// count the total (we use the query result provider here, since we might not get any hits (we scrolled past them))
final Map<String, List<Suggestion>> groupedSuggestions = hasSuggest ? new HashMap<>() : Collections.emptyMap();
final Map<String, ProfileShardResult> profileResults = hasProfileResults ? new HashMap<>(queryResults.size())
: Collections.emptyMap();
int from = 0;
int size = 0;
DocValueFormat[] sortValueFormats = null;
for (SearchPhaseResult entry : queryResults) {
QuerySearchResult result = entry.queryResult();
from = result.from();
// sorted queries can set the size to 0 if they have enough competitive hits.
size = Math.max(result.size(), size);
if (result.sortValueFormats() != null) {
sortValueFormats = result.sortValueFormats();
}

if (hasSuggest) {
assert result.suggest() != null;
for (Suggestion<? extends Suggestion.Entry<? extends Suggestion.Entry.Option>> suggestion : result.suggest()) {
@@ -477,7 +481,7 @@ ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> quer
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
topDocsStats.timedOut, topDocsStats.terminatedEarly, reducedSuggest, aggregations, shardResults, sortedTopDocs,
firstResult.sortValueFormats(), numReducePhases, size, from, false);
sortValueFormats, numReducePhases, size, from, false);
}

private static InternalAggregations reduceAggs(InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
SearchQueryThenFetchAsyncAction.java
@@ -33,7 +33,6 @@
import org.elasticsearch.transport.Transport;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;

@@ -52,14 +51,14 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
SearchQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
final Map<String, AliasFilter> aliasFilter,
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
final Map<String, Float> concreteIndexBoosts,
final SearchPhaseController searchPhaseController, final Executor executor,
final QueryPhaseResultConsumer resultConsumer, final SearchRequest request,
final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) {
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
executor, request, listener, shardsIts, timeProvider, clusterState, task,
resultConsumer, request.getMaxConcurrentShardRequests(), clusters);
this.topDocsSize = getTopDocsSize(request);
@@ -79,7 +78,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
protected void executePhaseOnShard(final SearchShardIterator shardIt,
final SearchShardTarget shard,
final SearchActionListener<SearchPhaseResult> listener) {
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt));
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
getSearchTransport().sendExecuteQuery(getConnection(shard.getClusterAlias(), shard.getNodeId()), request, getTask(), listener);
}

