diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 742fa60d098d1..0eecfce9e1e56 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -54,8 +54,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction this.searchPhaseController = searchPhaseController; SearchProgressListener progressListener = task.getProgressListener(); SearchSourceBuilder sourceBuilder = request.source(); - progressListener.notifyListShards(progressListener.searchShards(this.shardsIts), - progressListener.searchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0); + progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), + SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index f26b0fc80ccbe..423f930b78fbe 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -665,8 +665,8 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) { numReducePhases++; index = 1; if (hasAggs || hasTopDocs) { - progressListener.notifyPartialReduce(progressListener.searchShards(processedShards), - topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases); + progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards), + topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases); } } final int i = index++; @@ -695,7 +695,7 @@ private synchronized List getRemainingTopDocs() { public ReducedQueryPhase reduce() { ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce); - progressListener.notifyReduce(progressListener.searchShards(results.asList()), + progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); return reducePhase; } @@ -751,8 +751,8 @@ ReducedQueryPhase reduce() { List resultList = results.asList(); final ReducedQueryPhase reducePhase = reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce()); - listener.notifyReduce(listener.searchShards(resultList), reducePhase.totalHits, - reducePhase.aggregations, reducePhase.numReducePhases); + listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList), + reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); return reducePhase; } }; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index 80eda195ad7e9..997151160f96b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -53,14 +53,14 @@ abstract class SearchProgressListener { * @param clusters The statistics for remote clusters included in the search. * @param fetchPhase true if the search needs a fetch phase, false otherwise. **/ - public void onListShards(List shards, List skippedShards, Clusters clusters, boolean fetchPhase) {} + protected void onListShards(List shards, List skippedShards, Clusters clusters, boolean fetchPhase) {} /** * Executed when a shard returns a query result. * * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards} )}. */ - public void onQueryResult(int shardIndex) {} + protected void onQueryResult(int shardIndex) {} /** * Executed when a shard reports a query failure. @@ -69,7 +69,7 @@ public void onQueryResult(int shardIndex) {} * @param shardTarget The last shard target that thrown an exception. * @param exc The cause of the failure. */ - public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} + protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} /** * Executed when a partial reduce is created. The number of partial reduce can be controlled via @@ -80,7 +80,7 @@ public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Except * @param aggs The partial result for aggregations. * @param reducePhase The version number for this reduce. */ - public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} + protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} /** * Executed once when the final reduce is created. @@ -90,14 +90,14 @@ public void onPartialReduce(List shards, TotalHits totalHits, Inter * @param aggs The final result for aggregations. * @param reducePhase The version number for this reduce. */ - public void onReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} + protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} /** * Executed when a shard returns a fetch result. * * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. */ - public void onFetchResult(int shardIndex) {} + protected void onFetchResult(int shardIndex) {} /** * Executed when a shard reports a fetch failure. @@ -105,7 +105,7 @@ public void onFetchResult(int shardIndex) {} * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. * @param exc The cause of the failure. */ - public void onFetchFailure(int shardIndex, Exception exc) {} + protected void onFetchFailure(int shardIndex, Exception exc) {} final void notifyListShards(List shards, List skippedShards, Clusters clusters, boolean fetchPhase) { this.shards = shards; @@ -142,9 +142,9 @@ final void notifyPartialReduce(List shards, TotalHits totalHits, In } } - final void notifyReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + protected final void notifyFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { try { - onReduce(shards, totalHits, aggs, reducePhase); + onFinalReduce(shards, totalHits, aggs, reducePhase); } catch (Exception e) { logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on reduce"), e); } @@ -168,7 +168,7 @@ final void notifyFetchFailure(int shardIndex, Exception exc) { } } - final List searchShards(List results) { + static List buildSearchShards(List results) { return results.stream() .filter(Objects::nonNull) .map(SearchPhaseResult::getSearchShardTarget) @@ -176,14 +176,14 @@ final List searchShards(List results) .collect(Collectors.toUnmodifiableList()); } - final List searchShards(SearchShardTarget[] results) { + static List buildSearchShards(SearchShardTarget[] results) { return Arrays.stream(results) .filter(Objects::nonNull) .map(e -> new SearchShard(e.getClusterAlias(), e.getShardId())) .collect(Collectors.toUnmodifiableList()); } - final List searchShards(GroupShardsIterator its) { + static List buildSearchShards(GroupShardsIterator its) { return StreamSupport.stream(its.spliterator(), false) .map(e -> new SearchShard(e.getClusterAlias(), e.shardId())) .collect(Collectors.toUnmodifiableList()); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 3573e3ce36e24..24345c606003a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -57,8 +57,8 @@ final class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { +public final class SearchShard implements Comparable { @Nullable private final String clusterAlias; private final ShardId shardId; @@ -40,8 +40,7 @@ public SearchShard(@Nullable String clusterAlias, ShardId shardId) { } /** - * Return the cluster alias if the shard is on a remote cluster and null - * otherwise (local). + * Return the cluster alias if we are executing a cross cluster search request, null otherwise. */ @Nullable public String getClusterAlias() { @@ -51,7 +50,6 @@ public String getClusterAlias() { /** * Return the {@link ShardId} of this shard. */ - @Nullable public ShardId getShardId() { return shardId; } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchShardTask.java b/server/src/main/java/org/elasticsearch/action/search/SearchShardTask.java index 4719c1fda9d53..abfc876ad6000 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchShardTask.java @@ -40,5 +40,4 @@ public SearchShardTask(long id, String type, String action, String description, public boolean shouldCancelChildrenOnCancellation() { return false; } - } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTask.java b/server/src/main/java/org/elasticsearch/action/search/SearchTask.java index 97247e443bb64..c5a918c06f1bb 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTask.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTask.java @@ -37,14 +37,14 @@ public SearchTask(long id, String type, String action, String description, TaskI /** * Attach a {@link SearchProgressListener} to this task. */ - public void setProgressListener(SearchProgressListener progressListener) { + public final void setProgressListener(SearchProgressListener progressListener) { this.progressListener = progressListener; } /** * Return the {@link SearchProgressListener} attached to this task. */ - public SearchProgressListener getProgressListener() { + public final SearchProgressListener getProgressListener() { return progressListener; } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 0d4913896b848..5e65487fe5927 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -829,7 +829,7 @@ public void onPartialReduce(List shards, TotalHits totalHits, Inter } @Override - public void onReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { totalHitsListener.set(totalHits); finalAggsListener.set(aggs); numReduceListener.incrementAndGet(); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index 931fdd506ed97..3f9269afd9e78 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -178,7 +178,7 @@ public void onPartialReduce(List shards, TotalHits totalHits, Inter } @Override - public void onReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { numReduces.incrementAndGet(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchShardTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchShardTests.java new file mode 100644 index 0000000000000..8ee1718023e69 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/SearchShardTests.java @@ -0,0 +1,83 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.action.search; + +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.EqualsHashCodeTestUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class SearchShardTests extends ESTestCase { + + public void testEqualsAndHashcode() { + String index = randomAlphaOfLengthBetween(5, 10); + SearchShard searchShard = new SearchShard(randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10), + new ShardId(index, index + "-uuid", randomIntBetween(0, 1024))); + EqualsHashCodeTestUtils.checkEqualsAndHashCode(searchShard, + s -> new SearchShard(s.getClusterAlias(), s.getShardId()), + s -> { + if (randomBoolean()) { + return new SearchShard(s.getClusterAlias() == null ? randomAlphaOfLengthBetween(3, 10) : null, s.getShardId()); + } else { + String indexName = s.getShardId().getIndexName(); + int shardId = s.getShardId().getId(); + if (randomBoolean()) { + indexName += randomAlphaOfLength(5); + } else { + shardId += randomIntBetween(1, 1024); + } + return new SearchShard(s.getClusterAlias(), new ShardId(indexName, indexName + "-uuid", shardId)); + } + }); + } + + public void testCompareTo() { + List searchShards = new ArrayList<>(); + Index index0 = new Index("index0", "index0-uuid"); + Index index1 = new Index("index1", "index1-uuid"); + searchShards.add(new SearchShard(null, new ShardId(index0, 0))); + searchShards.add(new SearchShard(null, new ShardId(index1, 0))); + searchShards.add(new SearchShard(null, new ShardId(index0, 1))); + searchShards.add(new SearchShard(null, new ShardId(index1, 1))); + searchShards.add(new SearchShard(null, new ShardId(index0, 2))); + searchShards.add(new SearchShard(null, new ShardId(index1, 2))); + searchShards.add(new SearchShard("", new ShardId(index0, 0))); + searchShards.add(new SearchShard("", new ShardId(index1, 0))); + searchShards.add(new SearchShard("", new ShardId(index0, 1))); + searchShards.add(new SearchShard("", new ShardId(index1, 1))); + + searchShards.add(new SearchShard("remote0", new ShardId(index0, 0))); + searchShards.add(new SearchShard("remote0", new ShardId(index1, 0))); + searchShards.add(new SearchShard("remote0", new ShardId(index0, 1))); + searchShards.add(new SearchShard("remote0", new ShardId(index0, 2))); + searchShards.add(new SearchShard("remote1", new ShardId(index0, 0))); + searchShards.add(new SearchShard("remote1", new ShardId(index1, 0))); + searchShards.add(new SearchShard("remote1", new ShardId(index0, 1))); + searchShards.add(new SearchShard("remote1", new ShardId(index1, 1))); + + List sorted = new ArrayList<>(searchShards); + Collections.sort(sorted); + assertEquals(searchShards, sorted); + } +} diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 3a1b4f63f271d..cb2862a087774 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -41,7 +41,7 @@ /** * Task that tracks the progress of a currently running {@link SearchRequest}. */ -class AsyncSearchTask extends SearchTask { +final class AsyncSearchTask extends SearchTask { private final AsyncSearchId searchId; private final Client client; private final ThreadPool threadPool; @@ -111,8 +111,7 @@ AsyncSearchId getSearchId() { return searchId; } - @Override - public SearchProgressActionListener getProgressListener() { + Listener getSearchProgressActionListener() { return progressListener; } @@ -193,7 +192,7 @@ public void addCompletionListener(Consumer listener) { if (hasCompleted) { executeImmediately = true; } else { - completionListeners.put(completionId++, resp -> listener.accept(resp)); + completionListeners.put(completionId++, listener::accept); } } if (executeImmediately) { @@ -300,31 +299,31 @@ private void checkExpiration() { } } - private class Listener extends SearchProgressActionListener { + class Listener extends SearchProgressActionListener { @Override - public void onQueryResult(int shardIndex) { + protected void onQueryResult(int shardIndex) { checkExpiration(); } @Override - public void onFetchResult(int shardIndex) { + protected void onFetchResult(int shardIndex) { checkExpiration(); } @Override - public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { + protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { // best effort to cancel expired tasks checkExpiration(); searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget)); } @Override - public void onFetchFailure(int shardIndex, Exception exc) { + protected void onFetchFailure(int shardIndex, Exception exc) { checkExpiration(); } @Override - public void onListShards(List shards, List skipped, Clusters clusters, boolean fetchPhase) { + protected void onListShards(List shards, List skipped, Clusters clusters, boolean fetchPhase) { // best effort to cancel expired tasks checkExpiration(); searchResponse.compareAndSet(null, @@ -342,7 +341,7 @@ public void onPartialReduce(List shards, TotalHits totalHits, Inter } @Override - public void onReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { // best effort to cancel expired tasks checkExpiration(); searchResponse.get().updatePartialResponse(shards.size(), diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java index 72e93575c977a..fb66e61e0bd11 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java @@ -67,7 +67,7 @@ protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionList CancellableTask submitTask = (CancellableTask) task; final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive()); AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest); - searchAction.execute(searchTask, searchRequest, searchTask.getProgressListener()); + searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener()); searchTask.addCompletionListener( new ActionListener<>() { @Override diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java index 7ee2282b33ced..9ac847828d91e 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java @@ -41,7 +41,6 @@ public class AsyncSearchActionTests extends AsyncSearchIntegTestCase { private String indexName; private int numShards; - private int numDocs; private int numKeywords; private Map keywordFreqs; @@ -52,7 +51,7 @@ public class AsyncSearchActionTests extends AsyncSearchIntegTestCase { public void indexDocuments() throws InterruptedException { indexName = "test-async"; numShards = randomIntBetween(internalCluster().numDataNodes(), internalCluster().numDataNodes()*10); - numDocs = randomIntBetween(numShards, numShards*10); + int numDocs = randomIntBetween(numShards, numShards*10); createIndex(indexName, Settings.builder().put("index.number_of_shards", numShards).build()); numKeywords = randomIntBetween(1, 100); keywordFreqs = new HashMap<>(); @@ -143,7 +142,7 @@ public void testTermsAggregation() throws Exception { StringTerms terms = response.getSearchResponse().getAggregations().get("terms"); assertThat(terms.getBuckets().size(), greaterThanOrEqualTo(0)); assertThat(terms.getBuckets().size(), lessThanOrEqualTo(numKeywords)); - for (InternalTerms.Bucket bucket : terms.getBuckets()) { + for (InternalTerms.Bucket bucket : terms.getBuckets()) { long count = keywordFreqs.getOrDefault(bucket.getKeyAsString(), new AtomicInteger(0)).get(); assertThat(bucket.getDocCount(), lessThanOrEqualTo(count)); } @@ -158,7 +157,7 @@ public void testTermsAggregation() throws Exception { StringTerms terms = response.getSearchResponse().getAggregations().get("terms"); assertThat(terms.getBuckets().size(), greaterThanOrEqualTo(0)); assertThat(terms.getBuckets().size(), lessThanOrEqualTo(numKeywords)); - for (InternalTerms.Bucket bucket : terms.getBuckets()) { + for (InternalTerms.Bucket bucket : terms.getBuckets()) { long count = keywordFreqs.getOrDefault(bucket.getKeyAsString(), new AtomicInteger(0)).get(); if (numFailures > 0) { assertThat(bucket.getDocCount(), lessThanOrEqualTo(count)); @@ -239,14 +238,14 @@ public void testInvalidId() throws Exception { } public void testNoIndex() throws Exception { - SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(new String[] { "invalid-*" }); + SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest("invalid-*"); request.setWaitForCompletion(TimeValue.timeValueMillis(1)); AsyncSearchResponse response = submitAsyncSearch(request); assertNotNull(response.getSearchResponse()); assertFalse(response.isRunning()); assertThat(response.getSearchResponse().getTotalShards(), equalTo(0)); - request = new SubmitAsyncSearchRequest(new String[] { "invalid" }); + request = new SubmitAsyncSearchRequest("invalid"); request.setWaitForCompletion(TimeValue.timeValueMillis(1)); response = submitAsyncSearch(request); assertNull(response.getSearchResponse()); diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java index 5c229a5de1d80..651c3eee5365d 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java @@ -152,7 +152,7 @@ protected SearchResponseIterator assertBlockingIterator(String indexName, .collect( Collectors.toMap( Function.identity(), - id -> new ShardIdLatch(id, new CountDownLatch(1), failures.decrementAndGet() >= 0 ? true : false) + id -> new ShardIdLatch(id, new CountDownLatch(1), failures.decrementAndGet() >= 0) ) ); ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream() @@ -174,7 +174,6 @@ protected SearchResponseIterator assertBlockingIterator(String indexName, private int lastVersion = initial.getVersion(); private int shardIndex = 0; private boolean isFirst = true; - private int shardFailures = 0; @Override public boolean hasNext() { @@ -201,8 +200,6 @@ private AsyncSearchResponse doNext() throws Exception { while (index < step && shardIndex < shardLatchArray.length) { if (shardLatchArray[shardIndex].shouldFail == false) { ++index; - } else { - ++shardFailures; } shardLatchArray[shardIndex++].countDown(); } @@ -219,13 +216,13 @@ private AsyncSearchResponse doNext() throws Exception { if (newResponse.isRunning()) { assertThat(newResponse.status(), equalTo(RestStatus.OK)); assertTrue(newResponse.isPartial()); - assertFalse(newResponse.getFailure() != null); + assertNull(newResponse.getFailure()); assertNotNull(newResponse.getSearchResponse()); assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length)); assertThat(newResponse.getSearchResponse().getShardFailures().length, lessThanOrEqualTo(numFailures)); } else if (numFailures == shardLatchArray.length) { assertThat(newResponse.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); - assertTrue(newResponse.getFailure() != null); + assertNotNull(newResponse.getFailure()); assertTrue(newResponse.isPartial()); assertNotNull(newResponse.getSearchResponse()); assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length)); @@ -306,7 +303,7 @@ private BlockQueryBuilder() { } @Override - protected void doWriteTo(StreamOutput out) throws IOException {} + protected void doWriteTo(StreamOutput out) {} @Override protected void doXContent(XContentBuilder builder, Params params) throws IOException { diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java index 1f5f4c406db61..8c3f57883ec14 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java @@ -60,7 +60,6 @@ public void testWaitForInit() throws InterruptedException { skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1))); } - List threads = new ArrayList<>(); int numThreads = randomIntBetween(1, 10); CountDownLatch latch = new CountDownLatch(numThreads); for (int i = 0; i < numThreads; i++) { @@ -79,11 +78,10 @@ public void onFailure(Exception e) { } }, TimeValue.timeValueMillis(1))); - threads.add(thread); thread.start(); } assertFalse(latch.await(numThreads*2, TimeUnit.MILLISECONDS)); - task.getProgressListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); + task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); latch.await(); } @@ -91,18 +89,6 @@ public void testWithFailure() throws InterruptedException { AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, null); - int numShards = randomIntBetween(0, 10); - List shards = new ArrayList<>(); - for (int i = 0; i < numShards; i++) { - shards.add(new SearchShard(null, new ShardId("0", "0", 1))); - } - List skippedShards = new ArrayList<>(); - int numSkippedShards = randomIntBetween(0, 10); - for (int i = 0; i < numSkippedShards; i++) { - skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1))); - } - - List threads = new ArrayList<>(); int numThreads = randomIntBetween(1, 10); CountDownLatch latch = new CountDownLatch(numThreads); for (int i = 0; i < numThreads; i++) { @@ -120,11 +106,10 @@ public void onFailure(Exception e) { throw new AssertionError(e); } }, TimeValue.timeValueMillis(1))); - threads.add(thread); thread.start(); } assertFalse(latch.await(numThreads*2, TimeUnit.MILLISECONDS)); - task.getProgressListener().onFailure(new Exception("boom")); + task.getSearchProgressActionListener().onFailure(new Exception("boom")); latch.await(); } @@ -144,22 +129,23 @@ public void testWaitForCompletion() throws InterruptedException { } int numShardFailures = 0; - task.getProgressListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); + task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); for (int i = 0; i < numShards; i++) { - task.getProgressListener().onPartialReduce(shards.subList(i, i+1), + task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true); } - task.getProgressListener().onReduce(shards, + task.getSearchProgressActionListener().onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true); - task.getProgressListener().onResponse(newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards)); + ((AsyncSearchTask.Listener)task.getProgressListener()).onResponse( + newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards)); assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, false); threadPool.shutdownNow(); } - private SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) { + private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) { InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(), InternalAggregations.EMPTY, null, null, false, null, 1); return new SearchResponse(response, null, totalShards, successfulShards, skippedShards, @@ -171,7 +157,6 @@ private void assertCompletionListeners(AsyncSearchTask task, int expectedSkippedShards, int expectedShardFailures, boolean isPartial) throws InterruptedException { - List threads = new ArrayList<>(); int numThreads = randomIntBetween(1, 10); CountDownLatch latch = new CountDownLatch(numThreads); for (int i = 0; i < numThreads; i++) { @@ -190,7 +175,6 @@ public void onFailure(Exception e) { throw new AssertionError(e); } }, TimeValue.timeValueMillis(1))); - threads.add(thread); thread.start(); } latch.await();