Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine SearchProgressListener internal API #53373

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
this.searchPhaseController = searchPhaseController;
SearchProgressListener progressListener = task.getProgressListener();
SearchSourceBuilder sourceBuilder = request.source();
progressListener.notifyListShards(progressListener.searchShards(this.shardsIts),
progressListener.searchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,8 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
numReducePhases++;
index = 1;
if (hasAggs || hasTopDocs) {
progressListener.notifyPartialReduce(progressListener.searchShards(processedShards),
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
}
}
final int i = index++;
Expand Down Expand Up @@ -695,7 +695,7 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
public ReducedQueryPhase reduce() {
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce);
progressListener.notifyReduce(progressListener.searchShards(results.asList()),
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase;
}
Expand Down Expand Up @@ -751,8 +751,8 @@ ReducedQueryPhase reduce() {
List<SearchPhaseResult> resultList = results.asList();
final ReducedQueryPhase reducePhase =
reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
listener.notifyReduce(listener.searchShards(resultList), reducePhase.totalHits,
reducePhase.aggregations, reducePhase.numReducePhases);
listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
return reducePhase;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ abstract class SearchProgressListener {
* @param clusters The statistics for remote clusters included in the search.
* @param fetchPhase <code>true</code> if the search needs a fetch phase, <code>false</code> otherwise.
**/
public void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {}
protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {}

/**
* Executed when a shard returns a query result.
*
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards} )}.
*/
public void onQueryResult(int shardIndex) {}
protected void onQueryResult(int shardIndex) {}

/**
* Executed when a shard reports a query failure.
Expand All @@ -69,7 +69,7 @@ public void onQueryResult(int shardIndex) {}
* @param shardTarget The last shard target that thrown an exception.
* @param exc The cause of the failure.
*/
public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}

/**
* Executed when a partial reduce is created. The number of partial reduce can be controlled via
Expand All @@ -80,7 +80,7 @@ public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Except
* @param aggs The partial result for aggregations.
* @param reducePhase The version number for this reduce.
*/
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}

/**
* Executed once when the final reduce is created.
Expand All @@ -90,22 +90,22 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
* @param aggs The final result for aggregations.
* @param reducePhase The version number for this reduce.
*/
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}

/**
* Executed when a shard returns a fetch result.
*
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
*/
public void onFetchResult(int shardIndex) {}
protected void onFetchResult(int shardIndex) {}

/**
* Executed when a shard reports a fetch failure.
*
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
* @param exc The cause of the failure.
*/
public void onFetchFailure(int shardIndex, Exception exc) {}
protected void onFetchFailure(int shardIndex, Exception exc) {}

final void notifyListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {
this.shards = shards;
Expand Down Expand Up @@ -142,9 +142,9 @@ final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, In
}
}

final void notifyReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
protected final void notifyFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
try {
onReduce(shards, totalHits, aggs, reducePhase);
onFinalReduce(shards, totalHits, aggs, reducePhase);
} catch (Exception e) {
logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on reduce"), e);
}
Expand All @@ -168,22 +168,22 @@ final void notifyFetchFailure(int shardIndex, Exception exc) {
}
}

final List<SearchShard> searchShards(List<? extends SearchPhaseResult> results) {
static List<SearchShard> buildSearchShards(List<? extends SearchPhaseResult> results) {
return results.stream()
.filter(Objects::nonNull)
.map(SearchPhaseResult::getSearchShardTarget)
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
.collect(Collectors.toUnmodifiableList());
}

final List<SearchShard> searchShards(SearchShardTarget[] results) {
static List<SearchShard> buildSearchShards(SearchShardTarget[] results) {
return Arrays.stream(results)
.filter(Objects::nonNull)
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
.collect(Collectors.toUnmodifiableList());
}

final List<SearchShard> searchShards(GroupShardsIterator<SearchShardIterator> its) {
static List<SearchShard> buildSearchShards(GroupShardsIterator<SearchShardIterator> its) {
return StreamSupport.stream(its.spliterator(), false)
.map(e -> new SearchShard(e.getClusterAlias(), e.shardId()))
.collect(Collectors.toUnmodifiableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ final class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<Se
this.progressListener = task.getProgressListener();
final SearchProgressListener progressListener = task.getProgressListener();
final SearchSourceBuilder sourceBuilder = request.source();
progressListener.notifyListShards(progressListener.searchShards(this.shardsIts),
progressListener.searchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
}

protected void executePhaseOnShard(final SearchShardIterator shardIt, final ShardRouting shard,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
* A class that encapsulates the {@link ShardId} and the cluster alias
* of a shard used during the search action.
*/
public class SearchShard implements Comparable<SearchShard> {
public final class SearchShard implements Comparable<SearchShard> {
@Nullable
private final String clusterAlias;
private final ShardId shardId;
Expand All @@ -40,8 +40,7 @@ public SearchShard(@Nullable String clusterAlias, ShardId shardId) {
}

/**
* Return the cluster alias if the shard is on a remote cluster and <code>null</code>
* otherwise (local).
* Return the cluster alias if we are executing a cross cluster search request, <code>null</code> otherwise.
*/
@Nullable
public String getClusterAlias() {
Expand All @@ -51,7 +50,6 @@ public String getClusterAlias() {
/**
* Return the {@link ShardId} of this shard.
*/
@Nullable
public ShardId getShardId() {
return shardId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* Task storing information about a currently running search shard request.
* See {@link ShardSearchRequest}, {@link ShardFetchSearchRequest}, ...
*/
public class SearchShardTask extends CancellableTask {
public final class SearchShardTask extends CancellableTask {

public SearchShardTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
Expand All @@ -40,5 +40,4 @@ public SearchShardTask(long id, String type, String action, String description,
public boolean shouldCancelChildrenOnCancellation() {
return false;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ public SearchTask(long id, String type, String action, String description, TaskI
/**
* Attach a {@link SearchProgressListener} to this task.
*/
public void setProgressListener(SearchProgressListener progressListener) {
public final void setProgressListener(SearchProgressListener progressListener) {
this.progressListener = progressListener;
}

/**
* Return the {@link SearchProgressListener} attached to this task.
*/
public SearchProgressListener getProgressListener() {
public final SearchProgressListener getProgressListener() {
return progressListener;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
}

@Override
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
totalHitsListener.set(totalHits);
finalAggsListener.set(aggs);
numReduceListener.incrementAndGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
}

@Override
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
numReduces.incrementAndGet();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.action.search;

import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.EqualsHashCodeTestUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class SearchShardTests extends ESTestCase {

public void testEqualsAndHashcode() {
String index = randomAlphaOfLengthBetween(5, 10);
SearchShard searchShard = new SearchShard(randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10),
new ShardId(index, index + "-uuid", randomIntBetween(0, 1024)));
EqualsHashCodeTestUtils.checkEqualsAndHashCode(searchShard,
s -> new SearchShard(s.getClusterAlias(), s.getShardId()),
s -> {
if (randomBoolean()) {
return new SearchShard(s.getClusterAlias() == null ? randomAlphaOfLengthBetween(3, 10) : null, s.getShardId());
} else {
String indexName = s.getShardId().getIndexName();
int shardId = s.getShardId().getId();
if (randomBoolean()) {
indexName += randomAlphaOfLength(5);
} else {
shardId += randomIntBetween(1, 1024);
}
return new SearchShard(s.getClusterAlias(), new ShardId(indexName, indexName + "-uuid", shardId));
}
});
}

public void testCompareTo() {
List<SearchShard> searchShards = new ArrayList<>();
Index index0 = new Index("index0", "index0-uuid");
Index index1 = new Index("index1", "index1-uuid");
searchShards.add(new SearchShard(null, new ShardId(index0, 0)));
searchShards.add(new SearchShard(null, new ShardId(index1, 0)));
searchShards.add(new SearchShard(null, new ShardId(index0, 1)));
searchShards.add(new SearchShard(null, new ShardId(index1, 1)));
searchShards.add(new SearchShard(null, new ShardId(index0, 2)));
searchShards.add(new SearchShard(null, new ShardId(index1, 2)));
searchShards.add(new SearchShard("", new ShardId(index0, 0)));
searchShards.add(new SearchShard("", new ShardId(index1, 0)));
searchShards.add(new SearchShard("", new ShardId(index0, 1)));
searchShards.add(new SearchShard("", new ShardId(index1, 1)));

searchShards.add(new SearchShard("remote0", new ShardId(index0, 0)));
searchShards.add(new SearchShard("remote0", new ShardId(index1, 0)));
searchShards.add(new SearchShard("remote0", new ShardId(index0, 1)));
searchShards.add(new SearchShard("remote0", new ShardId(index0, 2)));
searchShards.add(new SearchShard("remote1", new ShardId(index0, 0)));
searchShards.add(new SearchShard("remote1", new ShardId(index1, 0)));
searchShards.add(new SearchShard("remote1", new ShardId(index0, 1)));
searchShards.add(new SearchShard("remote1", new ShardId(index1, 1)));

List<SearchShard> sorted = new ArrayList<>(searchShards);
Collections.sort(sorted);
assertEquals(searchShards, sorted);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
/**
* Task that tracks the progress of a currently running {@link SearchRequest}.
*/
class AsyncSearchTask extends SearchTask {
final class AsyncSearchTask extends SearchTask {
private final AsyncSearchId searchId;
private final Client client;
private final ThreadPool threadPool;
Expand Down Expand Up @@ -111,8 +111,7 @@ AsyncSearchId getSearchId() {
return searchId;
}

@Override
public SearchProgressActionListener getProgressListener() {
Listener getSearchProgressActionListener() {
return progressListener;
}

Expand Down Expand Up @@ -193,7 +192,7 @@ public void addCompletionListener(Consumer<AsyncSearchResponse> listener) {
if (hasCompleted) {
executeImmediately = true;
} else {
completionListeners.put(completionId++, resp -> listener.accept(resp));
completionListeners.put(completionId++, listener::accept);
}
}
if (executeImmediately) {
Expand Down Expand Up @@ -300,31 +299,31 @@ private void checkExpiration() {
}
}

private class Listener extends SearchProgressActionListener {
class Listener extends SearchProgressActionListener {
@Override
public void onQueryResult(int shardIndex) {
protected void onQueryResult(int shardIndex) {
checkExpiration();
}

@Override
public void onFetchResult(int shardIndex) {
protected void onFetchResult(int shardIndex) {
checkExpiration();
}

@Override
public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
// best effort to cancel expired tasks
checkExpiration();
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget));
}

@Override
public void onFetchFailure(int shardIndex, Exception exc) {
protected void onFetchFailure(int shardIndex, Exception exc) {
checkExpiration();
}

@Override
public void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
// best effort to cancel expired tasks
checkExpiration();
searchResponse.compareAndSet(null,
Expand All @@ -342,7 +341,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
}

@Override
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks
checkExpiration();
searchResponse.get().updatePartialResponse(shards.size(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionList
CancellableTask submitTask = (CancellableTask) task;
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive());
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
searchAction.execute(searchTask, searchRequest, searchTask.getProgressListener());
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
searchTask.addCompletionListener(
new ActionListener<>() {
@Override
Expand Down
Loading