diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java new file mode 100644 index 0000000000000..1b2e5672cfed4 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -0,0 +1,230 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.benchmark.search.aggregations; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.QueryPhaseResultConsumer; +import org.elasticsearch.action.search.SearchPhaseController; +import org.elasticsearch.action.search.SearchProgressListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService; +import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.query.QuerySearchResult; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Collections; +import 
java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.emptyList; + +@Warmup(iterations = 5) +@Measurement(iterations = 7) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(value = 1) +public class TermsReduceBenchmark { + private final SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables()); + private final SearchPhaseController controller = new SearchPhaseController( + namedWriteableRegistry, + req -> new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction(null, null, () -> PipelineAggregator.PipelineTree.EMPTY); + } + + @Override + public InternalAggregation.ReduceContext forFinalReduction() { + final MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + return InternalAggregation.ReduceContext.forFinalReduction( + null, + null, + bucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + } + } + ); + + @State(Scope.Benchmark) + public static class TermsList extends AbstractList<InternalAggregations> { + @Param({ "1600172297" }) + long seed; + + @Param({ "64", "128", "512" }) + int numShards; + + @Param({ "100" }) + int topNSize; + + @Param({ "1", "10", "100" }) + int cardinalityFactor; + + List<InternalAggregations> aggsList; + + @Setup + public void setup() { + this.aggsList = new ArrayList<>(); + Random rand = new Random(seed); + int cardinality = cardinalityFactor * topNSize; + BytesRef[] dict = new BytesRef[cardinality]; + for (int i = 0; i < dict.length; i++) { + dict[i] = new BytesRef(Long.toString(rand.nextLong())); + } + for (int i = 0; i < numShards; i++) { + aggsList.add(InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, true)))); + } + } + + private StringTerms newTerms(Random rand, BytesRef[] dict, boolean withNested) { + Set<BytesRef> randomTerms = new HashSet<>(); + for (int i = 0; i < topNSize; i++) { + randomTerms.add(dict[rand.nextInt(dict.length)]); + } + List<StringTerms.Bucket> buckets = new ArrayList<>(); + for (BytesRef term : randomTerms) { + InternalAggregations subAggs; + if (withNested) { + subAggs = InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, false))); + } else { + subAggs = InternalAggregations.EMPTY; + } + buckets.add(new StringTerms.Bucket(term, rand.nextInt(10000), subAggs, true, 0L, DocValueFormat.RAW)); + } + + Collections.sort(buckets, (a, b) -> a.compareKey(b)); + return new StringTerms( + "terms", + BucketOrder.key(true), + BucketOrder.count(false), + topNSize, + 1, + Collections.emptyMap(), + DocValueFormat.RAW, + numShards, + true, + 0, + buckets, + 0 + ); + } + + @Override + public InternalAggregations get(int index) { + return aggsList.get(index); + } + + @Override + public int size() { + return aggsList.size(); + } + } + + @Param({ "32", "512" }) + private int bufferSize; + + @Benchmark + public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateList) throws Exception { + List<QuerySearchResult> shards = new ArrayList<>(); + for (int 
i = 0; i < candidateList.size(); i++) { + QuerySearchResult result = new QuerySearchResult(); + result.setShardIndex(i); + result.from(0); + result.size(0); + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1000, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]), + Float.NaN + ), + new DocValueFormat[] { DocValueFormat.RAW } + ); + result.aggregations(candidateList.get(i)); + result.setSearchShardTarget( + new SearchShardTarget("node", new ShardId(new Index("index", "index"), i), null, OriginalIndices.NONE) + ); + shards.add(result); + } + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder().size(0).aggregation(AggregationBuilders.terms("test"))); + request.setBatchedReduceSize(bufferSize); + ExecutorService executor = Executors.newFixedThreadPool(1); + QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer( + request, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + SearchProgressListener.NOOP, + namedWriteableRegistry, + shards.size(), + exc -> {} + ); + CountDownLatch latch = new CountDownLatch(shards.size()); + for (int i = 0; i < shards.size(); i++) { + consumer.consumeResult(shards.get(i), () -> latch.countDown()); + } + latch.await(); + SearchPhaseController.ReducedQueryPhase phase = consumer.reduce(); + executor.shutdownNow(); + return phase; + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index de4d4e66f884a..5fe61b9848a45 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -23,7 +23,6 @@ import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.elasticsearch.client.Client; import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -174,8 +173,7 @@ public void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Except } @Override - public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { numReduces.incrementAndGet(); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index 8b707092141e0..e267903b42e97 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -19,22 +19,251 @@ package org.elasticsearch.action.search; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import 
org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorBase; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.bucket.terms.LongTerms; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.InternalMax; +import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.test.ESIntegTestCase; +import java.io.IOException; +import java.util.Collection; import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class TransportSearchIT extends ESIntegTestCase { + public static class TestPlugin extends Plugin implements SearchPlugin { + @Override + public List getAggregations() { + return Collections.singletonList( + new AggregationSpec(TestAggregationBuilder.NAME, TestAggregationBuilder::new, TestAggregationBuilder.PARSER) + .addResultReader(InternalMax::new) + ); + } + + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + /** + * Set up a fetch sub phase that throws an exception on indices whose 
names start with "boom". + */ + return Collections.singletonList(fetchContext -> new FetchSubPhaseProcessor() { + @Override + public void setNextReader(LeafReaderContext readerContext) { + } + + @Override + public void process(FetchSubPhase.HitContext hitContext) { + if (fetchContext.getIndexName().startsWith("boom")) { + throw new RuntimeException("boom"); + } + } + }); + } + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put("indices.breaker.request.type", "memory") + .build(); + } + + @Override + protected Collection<Class<? extends Plugin>> nodePlugins() { + return Collections.singletonList(TestPlugin.class); + } + + public void testLocalClusterAlias() { + long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("1"); + indexRequest.source("field", "value"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + "local", nowInMillis, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + SearchHit[] hits = searchResponse.getHits().getHits(); + assertEquals(1, hits.length); + SearchHit hit = hits[0]; + assertEquals("local", hit.getClusterAlias()); + assertEquals("test", hit.getIndex()); + assertEquals("1", hit.getId()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + "", nowInMillis, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + SearchHit[] hits = searchResponse.getHits().getHits(); + assertEquals(1, hits.length); + SearchHit hit = hits[0]; + assertEquals("", hit.getClusterAlias()); + assertEquals("test", hit.getIndex()); + assertEquals("1", hit.getId()); + } + } + + public void testAbsoluteStartMillis() { + { + IndexRequest indexRequest = new IndexRequest("test-1970.01.01"); + indexRequest.id("1"); + indexRequest.source("date", "1970-01-01"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + IndexRequest indexRequest = new IndexRequest("test-1982.01.01"); + indexRequest.id("1"); + indexRequest.source("date", "1982-01-01"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + SearchRequest searchRequest = new SearchRequest(); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + } + { + SearchRequest searchRequest = new SearchRequest("<test-{now/d}>"); + searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true)); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(0, searchResponse.getTotalShards()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + searchRequest.indices("<test-{now/d}>"); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date"); + rangeQuery.gte("1970-01-01"); + rangeQuery.lt("1982-01-01"); + sourceBuilder.query(rangeQuery); + searchRequest.source(sourceBuilder); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); + } + } + + public void testFinalReduce() { + long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); + { + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("1"); + indexRequest.source("price", 10); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("2"); + indexRequest.source("price", 100); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + client().admin().indices().prepareRefresh("test").get(); + + SearchRequest originalRequest = new SearchRequest(); + SearchSourceBuilder source = new SearchSourceBuilder(); + source.size(0); + originalRequest.source(source); + TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC); + terms.field("price"); + terms.size(1); + source.aggregation(terms); + + { + SearchRequest searchRequest = randomBoolean() ? 
originalRequest : SearchRequest.subSearchRequest(originalRequest, + Strings.EMPTY_ARRAY, "remote", nowInMillis, true); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + Aggregations aggregations = searchResponse.getAggregations(); + LongTerms longTerms = aggregations.get("terms"); + assertEquals(1, longTerms.getBuckets().size()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest, + Strings.EMPTY_ARRAY, "remote", nowInMillis, false); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + Aggregations aggregations = searchResponse.getAggregations(); + LongTerms longTerms = aggregations.get("terms"); + assertEquals(2, longTerms.getBuckets().size()); + } + } public void testShardCountLimit() throws Exception { try { @@ -103,4 +332,276 @@ public void testSearchIdle() throws Exception { assertThat(resp.getHits().getTotalHits().value, equalTo(2L)); }); } + + public void testCircuitBreakerReduceFail() throws Exception { + int numShards = randomIntBetween(1, 10); + indexSomeDocs("test", numShards, numShards * 3); + + { + final AtomicArray<Boolean> responses = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .setBatchedReduceSize(batchReduceSize) + .request(); + final int index = i; + client().search(request, new ActionListener<SearchResponse>() { + @Override + public void onResponse(SearchResponse response) { + responses.set(index, true); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + responses.set(index, false); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(responses.asList().size(), equalTo(10)); + for (boolean resp : responses.asList()) { + assertTrue(resp); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } + + try { + Settings settings = Settings.builder() + .put("indices.breaker.request.limit", "1b") + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings)); + final Client client = client(); + assertBusy(() -> { + SearchPhaseExecutionException exc = expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .get()); + assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("<reduce_aggs>")); + }); + + final AtomicArray<Exception> exceptions = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .setBatchedReduceSize(batchReduceSize) + .request(); + final int index = i; + client().search(request, new ActionListener<SearchResponse>() { + @Override + public void onResponse(SearchResponse response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exc) { + exceptions.set(index, exc); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(exceptions.asList().size(), equalTo(10)); + for (Exception exc : exceptions.asList()) { + 
assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("<reduce_aggs>")); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } finally { + Settings settings = Settings.builder() + .putNull("indices.breaker.request.limit") + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings)); + } + } + + public void testCircuitBreakerFetchFail() throws Exception { + int numShards = randomIntBetween(1, 10); + int numDocs = numShards * 10; + indexSomeDocs("boom", numShards, numDocs); + + final AtomicArray<Exception> exceptions = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("boom") + .setBatchedReduceSize(batchReduceSize) + .setAllowPartialSearchResults(false) + .request(); + final int index = i; + client().search(request, new ActionListener<SearchResponse>() { + @Override + public void onResponse(SearchResponse response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exc) { + exceptions.set(index, exc); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(exceptions.asList().size(), equalTo(10)); + for (Exception exc : exceptions.asList()) { + assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("boom")); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } + + private void indexSomeDocs(String indexName, int numberOfShards, int numberOfDocs) { + createIndex(indexName, Settings.builder().put("index.number_of_shards", numberOfShards).build()); + + for (int i = 0; i < numberOfDocs; i++) { + IndexResponse indexResponse = client().prepareIndex(indexName, "_doc") + .setSource("number", randomInt()) + .get(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + client().admin().indices().prepareRefresh(indexName).get(); + } + + private long requestBreakerUsed() { + NodesStatsResponse stats = client().admin().cluster().prepareNodesStats() + .addMetric(NodesStatsRequest.Metric.BREAKER.metricName()) + .get(); + long estimated = 0; + for (NodeStats nodeStats : stats.getNodes()) { + estimated += nodeStats.getBreaker().getStats(CircuitBreaker.REQUEST).getEstimated(); + } + return estimated; + } + + /** + * A test aggregation that doesn't consume circuit breaker memory when running on shards. + * It is used to test the behavior of the circuit breaker when reducing multiple aggregations + * together on the coordinating node. 
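+ * Each shard returns a single {@link InternalMax} value, so the request breaker is exercised only on the coordinating node, where the serialized partial aggregations are buffered and reduced.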
+ */ + private static class TestAggregationBuilder extends AbstractAggregationBuilder { + static final String NAME = "test"; + + private static final ObjectParser PARSER = + ObjectParser.fromBuilder(NAME, TestAggregationBuilder::new); + + TestAggregationBuilder(String name) { + super(name); + } + + TestAggregationBuilder(StreamInput input) throws IOException { + super(input); + } + + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // noop + } + + @Override + protected AggregatorFactory doBuild(QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + return new AggregatorFactory(name, queryShardContext, parent, subFactoriesBuilder, metadata) { + @Override + protected Aggregator createInternal(SearchContext searchContext, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata) throws IOException { + return new TestAggregator(name, parent, searchContext); + } + }; + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new TestAggregationBuilder(name); + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.NONE; + } + + @Override + public String getType() { + return "test"; + } + } + + /** + * A test aggregator that extends {@link Aggregator} instead of {@link AggregatorBase} + * to avoid tripping the circuit breaker when executing on a shard. + */ + private static class TestAggregator extends Aggregator { + private final String name; + private final Aggregator parent; + private final SearchContext context; + + private TestAggregator(String name, Aggregator parent, SearchContext context) { + this.name = name; + this.parent = parent; + this.context = context; + } + + + @Override + public String name() { + return name; + } + + @Override + public SearchContext context() { + return context; + } + + @Override + public Aggregator parent() { + return parent; + } + + @Override + public Aggregator subAggregator(String name) { + return null; + } + + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + return new InternalAggregation[] { + new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap()) + }; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap()); + } + + @Override + public void close() {} + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { + throw new CollectionTerminatedException(); + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + + @Override + public void preCollection() throws IOException {} + + @Override + public void postCollection() throws IOException {} + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 31c805e0b222f..0caca6976c02f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -33,6 +33,8 @@ import org.elasticsearch.cluster.ClusterState; import 
org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.shard.ShardId; @@ -77,7 +79,7 @@ abstract class AbstractSearchAsyncAction exten **/ private final BiFunction nodeIdToConnection; private final SearchTask task; - final SearchPhaseResults results; + protected final SearchPhaseResults results; private final ClusterState clusterState; private final Map aliasFilter; private final Map concreteIndexBoosts; @@ -98,6 +100,8 @@ abstract class AbstractSearchAsyncAction exten private final Map pendingExecutionsPerNode = new ConcurrentHashMap<>(); private final boolean throttleConcurrentRequests; + private final List releasables = new ArrayList<>(); + AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService, BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, @@ -133,7 +137,7 @@ abstract class AbstractSearchAsyncAction exten this.executor = executor; this.request = request; this.task = task; - this.listener = listener; + this.listener = ActionListener.runAfter(listener, this::releaseContext); this.nodeIdToConnection = nodeIdToConnection; this.clusterState = clusterState; this.concreteIndexBoosts = concreteIndexBoosts; @@ -143,6 +147,15 @@ abstract class AbstractSearchAsyncAction exten this.clusters = clusters; } + @Override + public void addReleasable(Releasable releasable) { + releasables.add(releasable); + } + + public void releaseContext() { + Releasables.close(releasables); + } + /** * Builds how long it took to execute the search. 
*/ @@ -529,7 +542,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At ShardSearchFailure[] failures = buildShardFailures(); Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.pointInTimeBuilder() == null && allowPartialResults == false && failures.length > 0) { + if (allowPartialResults == false && failures.length > 0) { raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); @@ -567,6 +580,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { } }); } + Releasables.close(releasables); listener.onFailure(exception); } diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index db59c39559ed2..663cb861cc047 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.routing.GroupShardsIterator; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.search.SearchService.CanMatchResponse; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -76,6 +77,11 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction listener) { @@ -84,8 +90,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarge } @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, - SearchPhaseContext context) { + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return phaseFactory.apply(getIterator((CanMatchSearchPhaseResults) results, shardsIts)); } diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 980049e99afc4..e0fe285b730ec 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.util.List; -import java.util.function.Consumer; import java.util.function.Function; /** @@ -50,18 +49,21 @@ final class DfsQueryPhase extends SearchPhase { DfsQueryPhase(List searchResults, AggregatedDfs dfs, - SearchPhaseController searchPhaseController, + QueryPhaseResultConsumer queryResult, Function, SearchPhase> nextPhaseFactory, - SearchPhaseContext context, Consumer onPartialMergeFailure) { + SearchPhaseContext context) { super("dfs_query"); this.progressListener = context.getTask().getProgressListener(); - this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener, - context.getRequest(), context.getNumShards(), onPartialMergeFailure); + this.queryResult = queryResult; this.searchResults = searchResults; this.dfs = dfs; this.nextPhaseFactory = nextPhaseFactory; this.context = context; this.searchTransportService = context.getSearchTransport(); + + // register the release of the query consumer to free up the circuit breaker memory + // at 
the end of the search + context.addReleasable(queryResult); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index 860cf46645f2b..24531f90b6fd0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -23,8 +23,11 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TopDocs; import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.search.SearchPhaseResult; @@ -51,13 +54,16 @@ /** * A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results * as shard results are consumed. - * This implementation can be configured to batch up a certain amount of results and reduce - * them asynchronously in the provided {@link Executor} iff the buffer is exhausted. + * This implementation adds the memory that it used to save and reduce the results of shard aggregations + * in the {@link CircuitBreaker#REQUEST} circuit breaker. Before any partial or final reduce, the memory + * needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it + * exceeds the maximum memory allowed in this breaker. */ -public class QueryPhaseResultConsumer extends ArraySearchPhaseResults { +public class QueryPhaseResultConsumer extends ArraySearchPhaseResults implements Releasable { private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class); private final Executor executor; + private final CircuitBreaker circuitBreaker; private final SearchPhaseController controller; private final SearchProgressListener progressListener; private final ReduceContextBuilder aggReduceContextBuilder; @@ -71,15 +77,13 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; - private volatile long aggsMaxBufferSize; - private volatile long aggsCurrentBufferSize; - /** * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results * as shard results are consumed. 
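+ * The memory used to buffer shard results and to run partial and final reduces is reserved in the provided {@link CircuitBreaker}, and released again when this consumer is closed.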
*/ public QueryPhaseResultConsumer(SearchRequest request, Executor executor, + CircuitBreaker circuitBreaker, SearchPhaseController controller, SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, @@ -87,6 +91,7 @@ public QueryPhaseResultConsumer(SearchRequest request, Consumer onPartialMergeFailure) { super(expectedResultSize); this.executor = executor; + this.circuitBreaker = circuitBreaker; this.controller = controller; this.progressListener = progressListener; this.aggReduceContextBuilder = controller.getReduceContext(request); @@ -94,11 +99,17 @@ public QueryPhaseResultConsumer(SearchRequest request, this.topNSize = getTopDocsSize(request); this.performFinalReduce = request.isFinalReduce(); this.onPartialMergeFailure = onPartialMergeFailure; + SearchSourceBuilder source = request.source(); this.hasTopDocs = source == null || source.size() != 0; this.hasAggs = source != null && source.aggregations() != null; - int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; - this.pendingMerges = new PendingMerges(bufferSize, request.resolveTrackTotalHitsUpTo()); + int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; + this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo()); + } + + @Override + public void close() { + Releasables.close(pendingMerges); } @Override @@ -117,28 +128,35 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { throw pendingMerges.getFailure(); } - logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize); // ensure consistent ordering pendingMerges.sortBuffer(); final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); final List topDocsList = pendingMerges.consumeTopDocs(); final List aggsList = pendingMerges.consumeAggs(); + long breakerSize = pendingMerges.circuitBreakerBytes; + if (hasAggs) { + // Add an estimate of the final reduce size + breakerSize = pendingMerges.addEstimateAndMaybeBreak(pendingMerges.estimateRamBytesUsedForReduce(breakerSize)); + } SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList, topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce); + if (hasAggs) { + // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result + long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize; + pendingMerges.addWithoutBreaking(finalSize); + logger.trace("aggs final reduction [{}] max [{}]", + pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize); + } progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); return reducePhase; } - private MergeResult partialReduce(MergeTask task, + private MergeResult partialReduce(QuerySearchResult[] toConsume, + List emptyResults, TopDocsStats topDocsStats, MergeResult lastMerge, int numReducePhases) { - final QuerySearchResult[] toConsume = task.consumeBuffer(); - if (toConsume == null) { - // the task is cancelled - return null; - } // ensure consistent ordering Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex)); @@ -164,27 +182,20 @@ private MergeResult partialReduce(MergeTask task, newTopDocs 
= null; } - final DelayableWriteable.Serialized newAggs; + final InternalAggregations newAggs; if (hasAggs) { List aggsList = new ArrayList<>(); if (lastMerge != null) { - aggsList.add(lastMerge.reducedAggs.expand()); + aggsList.add(lastMerge.reducedAggs); } for (QuerySearchResult result : toConsume) { aggsList.add(result.consumeAggs().expand()); } - InternalAggregations result = InternalAggregations.topLevelReduce(aggsList, - aggReduceContextBuilder.forPartialReduction()); - newAggs = DelayableWriteable.referencing(result).asSerialized(InternalAggregations::readFrom, namedWriteableRegistry); - long previousBufferSize = aggsCurrentBufferSize; - aggsCurrentBufferSize = newAggs.ramBytesUsed(); - aggsMaxBufferSize = Math.max(aggsCurrentBufferSize, aggsMaxBufferSize); - logger.trace("aggs partial reduction [{}->{}] max [{}]", - previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize); + newAggs = InternalAggregations.topLevelReduce(aggsList, aggReduceContextBuilder.forPartialReduction()); } else { newAggs = null; } - List processedShards = new ArrayList<>(task.emptyResults); + List processedShards = new ArrayList<>(emptyResults); if (lastMerge != null) { processedShards.addAll(lastMerge.processedShards); } @@ -193,49 +204,109 @@ private MergeResult partialReduce(MergeTask task, processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId())); } progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases); - return new MergeResult(processedShards, newTopDocs, newAggs); + // we leave the results un-serialized because serializing is slow but we compute the serialized + // size as an estimate of the memory used by the newly reduced aggregations. + long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0; + return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? 
serializedSize : 0); } public int getNumReducePhases() { return pendingMerges.numReducePhases; } - private class PendingMerges { - private final int bufferSize; - - private int index; - private final QuerySearchResult[] buffer; + private class PendingMerges implements Releasable { + private final int batchReduceSize; + private final List buffer = new ArrayList<>(); private final List emptyResults = new ArrayList<>(); + // the memory that is accounted in the circuit breaker for this consumer + private volatile long circuitBreakerBytes; + // the memory that is currently used in the buffer + private volatile long aggsCurrentBufferSize; + private volatile long maxAggsCurrentBufferSize = 0; - private final TopDocsStats topDocsStats; - private MergeResult mergeResult; private final ArrayDeque queue = new ArrayDeque<>(); private final AtomicReference runningTask = new AtomicReference<>(); private final AtomicReference failure = new AtomicReference<>(); - private boolean hasPartialReduce; - private int numReducePhases; + private final TopDocsStats topDocsStats; + private volatile MergeResult mergeResult; + private volatile boolean hasPartialReduce; + private volatile int numReducePhases; - PendingMerges(int bufferSize, int trackTotalHitsUpTo) { - this.bufferSize = bufferSize; + PendingMerges(int batchReduceSize, int trackTotalHitsUpTo) { + this.batchReduceSize = batchReduceSize; this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo); - this.buffer = new QuerySearchResult[bufferSize]; } - public boolean hasFailure() { + @Override + public synchronized void close() { + assert hasPendingMerges() == false : "cannot close with partial reduce in-flight"; + if (hasFailure()) { + assert circuitBreakerBytes == 0; + return; + } + assert circuitBreakerBytes >= 0; + circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); + circuitBreakerBytes = 0; + } + + synchronized Exception getFailure() { + return failure.get(); + } + + boolean hasFailure() { return failure.get() != null; } - public synchronized boolean hasPendingMerges() { + boolean hasPendingMerges() { return queue.isEmpty() == false || runningTask.get() != null; } - public synchronized void sortBuffer() { - if (index > 0) { - Arrays.sort(buffer, 0, index, Comparator.comparingInt(QuerySearchResult::getShardIndex)); + void sortBuffer() { + if (buffer.size() > 0) { + Collections.sort(buffer, Comparator.comparingInt(QuerySearchResult::getShardIndex)); } } + synchronized long addWithoutBreaking(long size) { + circuitBreaker.addWithoutBreaking(size); + circuitBreakerBytes += size; + maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); + return circuitBreakerBytes; + } + + synchronized long addEstimateAndMaybeBreak(long estimatedSize) { + circuitBreaker.addEstimateBytesAndMaybeBreak(estimatedSize, ""); + circuitBreakerBytes += estimatedSize; + maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); + return circuitBreakerBytes; + } + + /** + * Returns the size of the serialized aggregation that is contained in the + * provided {@link QuerySearchResult}. + */ + long ramBytesUsedQueryResult(QuerySearchResult result) { + if (hasAggs == false) { + return 0; + } + return result.aggregations() + .asSerialized(InternalAggregations::readFrom, namedWriteableRegistry) + .ramBytesUsed(); + } + + /** + * Returns an estimation of the size that a reduce of the provided size + * would take on memory. 
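+ * Only the overhead added on top of {@code size} is returned, since {@code size} itself is already reserved in the breaker.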
+ * This size is estimated as roughly 1.5 times the size of the serialized + * aggregations that need to be reduced. This estimation can be completely + * off for some aggregations but it is corrected with the real size after + * the reduce completes. + */ + long estimateRamBytesUsedForReduce(long size) { + return Math.round(1.5d * size - size); + } + public void consume(QuerySearchResult result, Runnable next) { boolean executeNextImmediately = true; synchronized (this) { @@ -247,20 +318,24 @@ public void consume(QuerySearchResult result, Runnable next) { } } else { // add one if a partial merge is pending - int size = index + (hasPartialReduce ? 1 : 0); - if (size >= bufferSize) { + int size = buffer.size() + (hasPartialReduce ? 1 : 0); + if (size >= batchReduceSize) { hasPartialReduce = true; executeNextImmediately = false; - QuerySearchResult[] clone = new QuerySearchResult[index]; - System.arraycopy(buffer, 0, clone, 0, index); - MergeTask task = new MergeTask(clone, new ArrayList<>(emptyResults), next); - Arrays.fill(buffer, null); + QuerySearchResult[] clone = buffer.stream().toArray(QuerySearchResult[]::new); + MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next); + aggsCurrentBufferSize = 0; + buffer.clear(); emptyResults.clear(); - index = 0; queue.add(task); tryExecuteNext(); } - buffer[index++] = result; + if (hasAggs) { + long aggsSize = ramBytesUsedQueryResult(result); + addWithoutBreaking(aggsSize); + aggsCurrentBufferSize += aggsSize; + } + buffer.add(result); } } if (executeNextImmediately) { @@ -268,56 +343,85 @@ public void consume(QuerySearchResult result, Runnable next) { } } - private void onMergeFailure(Exception exc) { - synchronized (this) { - if (failure.get() != null) { - return; - } - failure.compareAndSet(null, exc); - MergeTask task = runningTask.get(); - runningTask.compareAndSet(task, null); - onPartialMergeFailure.accept(exc); - List toCancel = new ArrayList<>(); - if (task != null) { - toCancel.add(task); - } - toCancel.addAll(queue); - queue.clear(); - mergeResult = null; - toCancel.stream().forEach(MergeTask::cancel); + private synchronized void onMergeFailure(Exception exc) { + if (hasFailure()) { + assert circuitBreakerBytes == 0; + return; + } + assert circuitBreakerBytes >= 0; + if (circuitBreakerBytes > 0) { + // make sure that we reset the circuit breaker + circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); + circuitBreakerBytes = 0; + } + failure.compareAndSet(null, exc); + MergeTask task = runningTask.get(); + runningTask.compareAndSet(task, null); + onPartialMergeFailure.accept(exc); + List toCancels = new ArrayList<>(); + if (task != null) { + toCancels.add(task); + } + queue.stream().forEach(toCancels::add); + queue.clear(); + mergeResult = null; + for (MergeTask toCancel : toCancels) { + toCancel.cancel(); } } - private void onAfterMerge(MergeTask task, MergeResult newResult) { + private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) { synchronized (this) { + if (hasFailure()) { + return; + } runningTask.compareAndSet(task, null); mergeResult = newResult; + if (hasAggs) { + // Update the circuit breaker to remove the size of the source aggregations + // and replace the estimation with the serialized size of the newly reduced result. 
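+ // The delta is usually negative because the reduced result is smaller than the sum of the shard results it replaces, so this shrinks the reservation.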
+ long newSize = mergeResult.estimatedSize - estimatedSize; + addWithoutBreaking(newSize); + logger.trace("aggs partial reduction [{}->{}] max [{}]", + estimatedSize, mergeResult.estimatedSize, maxAggsCurrentBufferSize); + } + task.consumeListener(); } - task.consumeListener(); } private void tryExecuteNext() { final MergeTask task; synchronized (this) { if (queue.isEmpty() - || failure.get() != null + || hasFailure() || runningTask.get() != null) { return; } task = queue.poll(); runningTask.compareAndSet(null, task); } + executor.execute(new AbstractRunnable() { @Override protected void doRun() { + final MergeResult thisMergeResult = mergeResult; + long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize; final MergeResult newMerge; try { - newMerge = partialReduce(task, topDocsStats, mergeResult, ++numReducePhases); + final QuerySearchResult[] toConsume = task.consumeBuffer(); + if (toConsume == null) { + return; + } + long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize); + addEstimateAndMaybeBreak(estimatedMergeSize); + estimatedTotalSize += estimatedMergeSize; + ++ numReducePhases; + newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases); } catch (Exception t) { onMergeFailure(t); return; } - onAfterMerge(task, newMerge); + onAfterMerge(task, newMerge, estimatedTotalSize); tryExecuteNext(); } @@ -328,15 +432,14 @@ public void onFailure(Exception exc) { }); } - public TopDocsStats consumeTopDocsStats() { - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + public synchronized TopDocsStats consumeTopDocsStats() { + for (QuerySearchResult result : buffer) { topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); } return topDocsStats; } - public List consumeTopDocs() { + public synchronized List consumeTopDocs() { if (hasTopDocs == false) { return Collections.emptyList(); } @@ -344,8 +447,7 @@ public List consumeTopDocs() { if (mergeResult != null) { topDocsList.add(mergeResult.reducedTopDocs); } - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + for (QuerySearchResult result : buffer) { TopDocsAndMaxScore topDocs = result.consumeTopDocs(); setShardIndex(topDocs.topDocs, result.getShardIndex()); topDocsList.add(topDocs.topDocs); @@ -353,46 +455,45 @@ public List consumeTopDocs() { return topDocsList; } - public List consumeAggs() { + public synchronized List consumeAggs() { if (hasAggs == false) { return Collections.emptyList(); } List aggsList = new ArrayList<>(); if (mergeResult != null) { - aggsList.add(mergeResult.reducedAggs.expand()); + aggsList.add(mergeResult.reducedAggs); } - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + for (QuerySearchResult result : buffer) { aggsList.add(result.consumeAggs().expand()); } return aggsList; } - - public Exception getFailure() { - return failure.get(); - } } private static class MergeResult { private final List processedShards; private final TopDocs reducedTopDocs; - private final DelayableWriteable.Serialized reducedAggs; + private final InternalAggregations reducedAggs; + private final long estimatedSize; private MergeResult(List processedShards, TopDocs reducedTopDocs, - DelayableWriteable.Serialized reducedAggs) { + InternalAggregations reducedAggs, long estimatedSize) { this.processedShards = processedShards; this.reducedTopDocs = reducedTopDocs; this.reducedAggs = reducedAggs; + this.estimatedSize = 
estimatedSize; } } private static class MergeTask { private final List emptyResults; private QuerySearchResult[] buffer; + private long aggsBufferSize; private Runnable next; - private MergeTask(QuerySearchResult[] buffer, List emptyResults, Runnable next) { + private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List emptyResults, Runnable next) { this.buffer = buffer; + this.aggsBufferSize = aggsBufferSize; this.emptyResults = emptyResults; this.next = next; } @@ -403,7 +504,7 @@ public synchronized QuerySearchResult[] consumeBuffer() { return toRet; } - public synchronized void consumeListener() { + public void consumeListener() { if (next != null) { next.run(); next = null; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 0762d70dc5cbf..b53e635d9866a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -35,29 +35,29 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; -import java.util.function.Consumer; final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { private final SearchPhaseController searchPhaseController; - private final Consumer onPartialMergeFailure; + + private final QueryPhaseResultConsumer queryPhaseResultConsumer; SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService, final BiFunction nodeIdToConnection, final Map aliasFilter, final Map concreteIndexBoosts, final Map> indexRoutings, final SearchPhaseController searchPhaseController, final Executor executor, + final QueryPhaseResultConsumer queryPhaseResultConsumer, final SearchRequest request, final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters, - Consumer onPartialMergeFailure) { + final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) { super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), request.getMaxConcurrentShardRequests(), clusters); + this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; - this.onPartialMergeFailure = onPartialMergeFailure; SearchProgressListener progressListener = task.getProgressListener(); SearchSourceBuilder sourceBuilder = request.source(); progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), @@ -72,11 +72,12 @@ protected void executePhaseOnShard(final SearchShardIterator shardIt, final Sear } @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { final List dfsSearchResults = results.getAtomicArray().asList(); final AggregatedDfs aggregatedDfs = searchPhaseController.aggregateDfs(dfsSearchResults); - return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, searchPhaseController, (queryResults) -> - new 
FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context), context, onPartialMergeFailure); + return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, queryPhaseResultConsumer, + (queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context), + context); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java index 75ce64dc264eb..e56100dc5287f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java @@ -21,6 +21,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; @@ -123,4 +124,9 @@ default void sendReleaseSearchContext(ShardSearchContextId contextId, * a response is returned to the user indicating that all shards have failed. */ void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase); + + /** + * Registers a {@link Releasable} that will be closed when the search request finishes or fails. + */ + void addReleasable(Releasable releasable); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index a612e09a549f9..21dc1589c6579 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.HppcMaps; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; @@ -563,14 +564,16 @@ InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) } /** - * Returns a new {@link QueryPhaseResultConsumer} instance. This might return an instance that reduces search responses incrementally. + * Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally. 
*/ QueryPhaseResultConsumer newSearchPhaseResults(Executor executor, + CircuitBreaker circuitBreaker, SearchProgressListener listener, SearchRequest request, int numShards, Consumer onPartialMergeFailure) { - return new QueryPhaseResultConsumer(request, executor, this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure); + return new QueryPhaseResultConsumer(request, executor, circuitBreaker, + this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure); } static final class TopDocsStats { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index bbb9a5ad02388..f6670eb5e2f5c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -25,7 +25,6 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.search.SearchResponse.Clusters; import org.elasticsearch.cluster.routing.GroupShardsIterator; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -78,11 +77,10 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc * * @param shards The list of shards that are part of this reduce. * @param totalHits The total number of hits in this reduce. - * @param aggs The partial result for aggregations stored in serialized form. + * @param aggs The partial result for aggregations. * @param reducePhase The version number for this reduce. */ - protected void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) {} + protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} /** * Executed once when the final reduce is created. 
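The hunk above changes the progress-listener contract: onPartialReduce now receives the partially reduced InternalAggregations directly instead of a serialized DelayableWriteable form, so implementors no longer deserialize or call expand(). A minimal sketch of a listener written against the new signature (the counter field is illustrative only, not part of this patch):

    import org.apache.lucene.search.TotalHits;
    import org.elasticsearch.action.search.SearchProgressListener;
    import org.elasticsearch.action.search.SearchShard;
    import org.elasticsearch.search.aggregations.InternalAggregations;

    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;

    class CountingProgressListener extends SearchProgressListener {
        private final AtomicInteger partialReduces = new AtomicInteger();

        @Override
        protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
                                       InternalAggregations aggs, int reducePhase) {
            // aggs is the live partial result; no expand() step is needed anymore.
            partialReduces.incrementAndGet();
        }
    }
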
@@ -137,8 +135,7 @@ final void notifyQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc } } - final void notifyPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + final void notifyPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { try { onPartialReduce(shards, totalHits, aggs, reducePhase); } catch (Exception e) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index f841c6e55f44b..79f5e5ca9571e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -26,7 +26,6 @@ import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; @@ -37,7 +36,6 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; -import java.util.function.Consumer; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; @@ -56,22 +54,26 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction aliasFilter, final Map concreteIndexBoosts, final Map> indexRoutings, final SearchPhaseController searchPhaseController, final Executor executor, - final SearchRequest request, final ActionListener listener, + final QueryPhaseResultConsumer resultConsumer, final SearchRequest request, + final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, - Consumer onPartialMergeFailure) { + ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) { super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, - searchPhaseController.newSearchPhaseResults(executor, task.getProgressListener(), - request, shardsIts.size(), onPartialMergeFailure), request.getMaxConcurrentShardRequests(), clusters); + resultConsumer, request.getMaxConcurrentShardRequests(), clusters); this.topDocsSize = getTopDocsSize(request); this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo(); this.searchPhaseController = searchPhaseController; this.progressListener = task.getProgressListener(); - final SearchSourceBuilder sourceBuilder = request.source(); + + // register the release of the query consumer to free up the circuit breaker memory + // at the end of the search + addReleasable(resultConsumer); + + boolean hasFetchPhase = request.source() == null ? 
true : request.source().size() > 0; progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), - SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0); + SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, hasFetchPhase); } protected void executePhaseOnShard(final SearchShardIterator shardIt, @@ -108,8 +110,8 @@ && getRequest().scroll() == null } @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { - return new FetchSearchPhase(results, searchPhaseController, null, context); + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { + return new FetchSearchPhase(results, searchPhaseController, null, this); } private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index aab586fa47e65..2c9a5f9e37e53 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -43,6 +43,7 @@ import org.elasticsearch.cluster.routing.ShardIterator; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -55,6 +56,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; @@ -115,10 +117,12 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new); this.client = client; this.threadPool = threadPool; + this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST); this.searchPhaseController = searchPhaseController; this.searchTransportService = searchTransportService; this.remoteClusterService = searchTransportService.getRemoteClusterService(); @@ -796,17 +801,19 @@ public void run() { }; }, clusters); } else { + final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor, + circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(), exc -> cancelTask(task, exc)); AbstractSearchAsyncAction searchAsyncAction; switch (searchRequest.searchType()) { case DFS_QUERY_THEN_FETCH: searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, - aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); + aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, + executor, queryResultConsumer, searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters); break; case QUERY_THEN_FETCH: searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, - aliasFilter, 
concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); + aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, queryResultConsumer, + searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters); break; default: throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]"); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java index b16f5cdd3b1f8..1522a789cd7c1 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java @@ -272,4 +272,47 @@ public static InternalAggregations reduce(List aggregation public static InternalAggregations reduce(List aggregationsList, ReduceContext context) { return reduce(aggregationsList, context, InternalAggregations::from); } + + /** + * Returns the number of bytes required to serialize these aggregations in binary form. + */ + public long getSerializedSize() { + try (CountingStreamOutput out = new CountingStreamOutput()) { + out.setVersion(Version.CURRENT); + writeTo(out); + return out.size; + } catch (IOException exc) { + // should never happen + throw new RuntimeException(exc); + } + } + + private static class CountingStreamOutput extends StreamOutput { + long size = 0; + + @Override + public void writeByte(byte b) throws IOException { + ++ size; + } + + @Override + public void writeBytes(byte[] b, int offset, int length) throws IOException { + size += length; + } + + @Override + public void flush() throws IOException {} + + @Override + public void close() throws IOException {} + + @Override + public void reset() throws IOException { + size = 0; + } + + public long length() { + return size; + } + } } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index 916f9111517a5..9f1199f774a63 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -96,7 +96,7 @@ private AbstractSearchAsyncAction createAction(SearchRequest results, request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY) { @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { return null; } diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 140d28c47fd9a..d71b14f3d12f3 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -25,8 +25,11 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.MockDirectoryWrapper; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import 
org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchPhaseResult; @@ -86,15 +89,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -141,15 +148,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -198,15 +209,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); expectThrows(UncheckedIOException.class, phase::run); 
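The test changes in this file all follow one pattern: DfsQueryPhase no longer takes a SearchPhaseController plus an onPartialMergeFailure callback; the caller builds the QueryPhaseResultConsumer up front and hands it in. Condensed, the new wiring looks roughly like this (names borrowed from the surrounding test; the no-op breaker never trips, so the tests are unaffected by the new memory accounting):

    QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(
        EsExecutors.newDirectExecutorService(),          // merge on the calling thread
        new NoopCircuitBreaker(CircuitBreaker.REQUEST),  // accounting disabled in tests
        SearchProgressListener.NOOP,
        mockSearchPhaseContext.searchRequest,
        results.length(),
        exc -> {});                                      // partial-merge failures ignored
    DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer,
        (response) -> new FetchSearchPhase(response, searchPhaseController, null, mockSearchPhaseContext),
        mockSearchPhaseContext);
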
assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 9dc544091ae85..32d8e0d724686 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -24,6 +24,8 @@ import org.apache.lucene.store.MockDirectoryWrapper; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.index.shard.ShardId; @@ -43,8 +45,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import static org.elasticsearch.action.search.SearchProgressListener.NOOP; - public class FetchSearchPhaseTests extends ESTestCase { public void testShortcutQueryAndFetchOptimization() { @@ -52,7 +52,8 @@ public void testShortcutQueryAndFetchOptimization() { writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 1, exc -> {}); boolean hasHits = randomBoolean(); final int numHits; if (hasHits) { @@ -96,7 +97,8 @@ public void testFetchTwoDocument() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0), @@ -157,7 +159,8 @@ public void testFailFetchOneDoc() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx, @@ -220,7 +223,8 @@ public void testFetchDocsConcurrently() throws InterruptedException { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext 
mockSearchPhaseContext = new MockSearchPhaseContext(numHits); - QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP, + QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), numHits, exc -> {}); for (int i = 0; i < numHits; i++) { QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i), @@ -279,7 +283,8 @@ public void testExceptionFailsPhase() { writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123), new SearchShardTarget("node1", new ShardId("test", "na", 0), @@ -337,7 +342,8 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = 1; final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index cf1a96dde422c..96e9fe7a61c1f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -23,6 +23,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; @@ -131,6 +132,11 @@ public void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) { } } + @Override + public void addReleasable(Releasable releasable) { + // Noop + } + @Override public void execute(Runnable command) { command.run(); diff --git a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java index 2e9f7f7af6f41..f44a0cf292d5e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java @@ -23,7 +23,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import 
org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -93,8 +94,9 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { SearchRequest searchRequest = new SearchRequest("index"); searchRequest.setBatchedReduceSize(2); AtomicReference onPartialMergeFailure = new AtomicReference<>(); - QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController, - searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), searchPhaseController, searchProgressListener, + writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { curr.addSuppressed(prev); return curr; })); @@ -140,7 +142,7 @@ protected void onQueryResult(int shardIndex) { @Override protected void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + InternalAggregations aggs, int reducePhase) { onPartialReduce.incrementAndGet(); throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 99181976fce65..676da3da9e63b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -460,8 +460,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, } @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, - SearchPhaseContext context) { + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override public void run() { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 7971a7d831106..2898e203a13a1 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -33,10 +33,10 @@ import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.settings.Settings; @@ -45,7 +45,6 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; 
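Stepping back to the InternalAggregations change earlier in this patch: getSerializedSize() streams the aggregations into an output that only counts bytes, so estimating the serialized size allocates nothing proportional to the payload. The same counting-output idea in plain Java, for illustration only (not ES code):

    import java.io.OutputStream;

    // Counts the bytes that would be written instead of storing them.
    final class CountingOutputStream extends OutputStream {
        private long size;

        @Override
        public void write(int b) {
            size++;
        }

        @Override
        public void write(byte[] b, int off, int len) {
            size += len;
        }

        long size() {
            return size;
        }
    }
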
-import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; @@ -77,7 +76,6 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -95,7 +93,6 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; -import static org.elasticsearch.action.search.SearchProgressListener.NOOP; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -111,9 +108,9 @@ public class SearchPhaseControllerTests extends ESTestCase { @Override protected NamedWriteableRegistry writableRegistry() { - List entries = - new ArrayList<>(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()); - entries.add(new NamedWriteableRegistry.Entry(InternalAggregation.class, "throwing", InternalThrowing::new)); + List entries = new ArrayList<>( + new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables() + ); return new NamedWriteableRegistry(entries); } @@ -419,7 +416,8 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, 3+numEmptyResponses, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, 3+numEmptyResponses, exc -> {}); if (numEmptyResponses == 0) { assertEquals(0, reductions.size()); } @@ -506,7 +504,8 @@ public void testConsumerConcurrently() throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -556,7 +555,8 @@ public void testConsumerOnlyAggs() throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { @@ -597,7 +597,8 @@ public void testConsumerOnlyHits() throws Exception { } request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); CountDownLatch latch 
= new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { @@ -640,7 +641,8 @@ public void testReduceTopNWithFromOffset() throws Exception { request.source(new SearchSourceBuilder().size(5).from(5)); request.setBatchedReduceSize(randomIntBetween(2, 4)); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, 4, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP + , request, 4, exc -> {}); int score = 100; CountDownLatch latch = new CountDownLatch(4); for (int i = 0; i < 4; i++) { @@ -678,7 +680,8 @@ public void testConsumerSortByField() throws Exception { int size = randomIntBetween(1, 10); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)}; DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; @@ -716,7 +719,8 @@ public void testConsumerFieldCollapsing() throws Exception { int size = randomIntBetween(5, 10); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); SortField[] sortFields = {new SortField("field", SortField.Type.STRING)}; BytesRef a = new BytesRef("a"); BytesRef b = new BytesRef("b"); @@ -757,7 +761,8 @@ public void testConsumerSuggestions() throws Exception { SearchRequest request = randomSearchRequest(); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); int maxScoreTerm = -1; int maxScorePhrase = -1; int maxScoreCompletion = -1; @@ -871,7 +876,7 @@ public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Except @Override public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + InternalAggregations aggs, int reducePhase) { assertEquals(numReduceListener.incrementAndGet(), reducePhase); } @@ -883,7 +888,7 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } }; QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - progressListener, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), progressListener, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -932,7 +937,19 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } } - public void testPartialMergeFailure() throws InterruptedException { + public void testPartialReduce() throws Exception { + for (int i = 0; i < 10; i++) { + testReduceCase(false); + } + } + + public void testPartialReduceWithFailure() throws Exception { + for (int i = 0; i < 10; i++) { + testReduceCase(true); + } + } + + 
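The rewritten testReduceCase below drives failures through the circuit breaker instead of through a throwing aggregation: a random flag decides whether the breaker trips during a buffered partial reduce or only at the final reduce, and the onPartialMergeFailure callback must fire exactly when the partial path broke. Schematically (simplified from the test that follows; breaker and consumer as in the test):

    boolean failPartial = shouldFail && randomBoolean();
    breaker.shouldBreak.set(failPartial);      // trip while partial merges are running
    // ... feed all shard-level QuerySearchResults into the consumer ...
    if (shouldFail) {
        breaker.shouldBreak.set(true);         // otherwise trip at the final reduce
        expectThrows(CircuitBreakingException.class, consumer::reduce);
    } else {
        consumer.reduce();                     // happy path: reduce completes cleanly
    }
    consumer.close();                          // must hand back all reserved bytes
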
private void testReduceCase(boolean shouldFail) throws Exception { int expectedNumResults = randomIntBetween(20, 200); int bufferSize = randomIntBetween(2, expectedNumResults - 1); SearchRequest request = new SearchRequest(); @@ -940,11 +957,16 @@ public void testPartialMergeFailure() throws InterruptedException { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); AtomicBoolean hasConsumedFailure = new AtomicBoolean(); + AssertingCircuitBreaker circuitBreaker = new AssertingCircuitBreaker(CircuitBreaker.REQUEST); + boolean shouldFailPartial = shouldFail && randomBoolean(); + if (shouldFailPartial) { + circuitBreaker.shouldBreak.set(true); + } QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true)); + circuitBreaker, SearchProgressListener.NOOP, + request, expectedNumResults, exc -> hasConsumedFailure.set(true)); CountDownLatch latch = new CountDownLatch(expectedNumResults); Thread[] threads = new Thread[expectedNumResults]; - int failedIndex = randomIntBetween(0, expectedNumResults-1); for (int i = 0; i < expectedNumResults; i++) { final int index = i; threads[index] = new Thread(() -> { @@ -955,7 +977,7 @@ public void testPartialMergeFailure() throws InterruptedException { new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN), new DocValueFormat[0]); InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new InternalThrowing("test", (failedIndex == index), Collections.emptyMap()))); + Collections.singletonList(new InternalMax("test", 0d, DocValueFormat.RAW, Collections.emptyMap()))); result.aggregations(aggs); result.setShardIndex(index); result.size(1); @@ -967,65 +989,44 @@ public void testPartialMergeFailure() throws InterruptedException { threads[i].join(); } latch.await(); - IllegalStateException exc = expectThrows(IllegalStateException.class, () -> consumer.reduce()); - if (exc.getMessage().contains("partial reduce")) { - assertTrue(hasConsumedFailure.get()); + if (shouldFail) { + if (shouldFailPartial == false) { + circuitBreaker.shouldBreak.set(true); + } + CircuitBreakingException exc = expectThrows(CircuitBreakingException.class, () -> consumer.reduce()); + assertEquals(shouldFailPartial, hasConsumedFailure.get()); + assertThat(exc.getMessage(), containsString("<reduce_aggs>")); + circuitBreaker.shouldBreak.set(false); } else { - assertThat(exc.getMessage(), containsString("final reduce")); + SearchPhaseController.ReducedQueryPhase phase = consumer.reduce(); } + consumer.close(); + assertThat(circuitBreaker.allocated, equalTo(0L)); } - private static class InternalThrowing extends InternalAggregation { - private final boolean shouldThrow; - - protected InternalThrowing(String name, boolean shouldThrow, Map<String, Object> metadata) { - super(name, metadata); - this.shouldThrow = shouldThrow; - } + private static class AssertingCircuitBreaker extends NoopCircuitBreaker { + private final AtomicBoolean shouldBreak = new AtomicBoolean(false); - protected InternalThrowing(StreamInput in) throws IOException { - super(in); - this.shouldThrow = in.readBoolean(); - } + private volatile long allocated; - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeBoolean(shouldThrow); + AssertingCircuitBreaker(String name) { + super(name); } @Override - public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext 
reduceContext) { - if (aggregations.stream() - .map(agg -> (InternalThrowing) agg) - .anyMatch(agg -> agg.shouldThrow)) { - if (reduceContext.isFinalReduce()) { - throw new IllegalStateException("final reduce"); - } else { - throw new IllegalStateException("partial reduce"); - } + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + assert bytes >= 0; + if (shouldBreak.get()) { + throw new CircuitBreakingException(label, getDurability()); } - return new InternalThrowing(name, false, metadata); - } - - @Override - protected boolean mustReduceOnSingleInternalAgg() { - return true; + allocated += bytes; + return allocated; } @Override - public Object getProperty(List path) { - return null; - } - - @Override - public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { - throw new IllegalStateException("not implemented"); - } - - @Override - public String getWriteableName() { - return "throwing"; + public long addWithoutBreaking(long bytes) { + allocated += bytes; + return allocated; } } - } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index cbdc5b56c85b8..9c1d4bf3448df 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -29,6 +29,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -144,15 +147,19 @@ public void sendExecuteQuery(Transport.Connection connection, ShardSearchRequest searchRequest.source().collapse(new CollapseBuilder("collapse_field")); } searchRequest.allowPartialSearchResults(false); + Executor executor = EsExecutors.newDirectExecutorService(); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), controller, task.getProgressListener(), writableRegistry(), + shardsIter.size(), exc -> {}); SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), - Collections.emptyMap(), Collections.emptyMap(), controller, EsExecutors.newDirectExecutorService(), searchRequest, - null, shardsIter, timeProvider, null, task, - SearchResponse.Clusters.EMPTY, exc -> {}) { + Collections.emptyMap(), Collections.emptyMap(), controller, 
executor, + resultConsumer, searchRequest, null, shardsIter, timeProvider, null, + task, SearchResponse.Clusters.EMPTY) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java deleted file mode 100644 index 7e27aaac59ecb..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.action.search; - -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.terms.LongTerms; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.aggregations.support.ValueType; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.test.ESSingleNodeTestCase; - -public class TransportSearchActionSingleNodeTests extends ESSingleNodeTestCase { - - public void testLocalClusterAlias() { - long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("1"); - indexRequest.source("field", "value"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, - "local", nowInMillis, randomBoolean()); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - SearchHit[] hits = searchResponse.getHits().getHits(); - assertEquals(1, hits.length); - SearchHit hit = hits[0]; - assertEquals("local", hit.getClusterAlias()); - assertEquals("test", hit.getIndex()); - assertEquals("1", hit.getId()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, - "", nowInMillis, randomBoolean()); - SearchResponse searchResponse = 
client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - SearchHit[] hits = searchResponse.getHits().getHits(); - assertEquals(1, hits.length); - SearchHit hit = hits[0]; - assertEquals("", hit.getClusterAlias()); - assertEquals("test", hit.getIndex()); - assertEquals("1", hit.getId()); - } - } - - public void testAbsoluteStartMillis() { - { - IndexRequest indexRequest = new IndexRequest("test-1970.01.01"); - indexRequest.id("1"); - indexRequest.source("date", "1970-01-01"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - IndexRequest indexRequest = new IndexRequest("test-1982.01.01"); - indexRequest.id("1"); - indexRequest.source("date", "1982-01-01"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - SearchRequest searchRequest = new SearchRequest(); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - } - { - SearchRequest searchRequest = new SearchRequest(""); - searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true)); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(0, searchResponse.getTotalShards()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - searchRequest.indices(""); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date"); - rangeQuery.gte("1970-01-01"); - rangeQuery.lt("1982-01-01"); - sourceBuilder.query(rangeQuery); - searchRequest.source(sourceBuilder); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); - } - } - - public void testFinalReduce() { - long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); - { - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("1"); - indexRequest.source("price", 10); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("2"); - indexRequest.source("price", 100); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - 
client().admin().indices().prepareRefresh("test").get(); - - SearchRequest originalRequest = new SearchRequest(); - SearchSourceBuilder source = new SearchSourceBuilder(); - source.size(0); - originalRequest.source(source); - TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC); - terms.field("price"); - terms.size(1); - source.aggregation(terms); - - { - SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest, - Strings.EMPTY_ARRAY, "remote", nowInMillis, true); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - Aggregations aggregations = searchResponse.getAggregations(); - LongTerms longTerms = aggregations.get("terms"); - assertEquals(1, longTerms.getBuckets().size()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest, - Strings.EMPTY_ARRAY, "remote", nowInMillis, false); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - Aggregations aggregations = searchResponse.getAggregations(); - LongTerms longTerms = aggregations.get("terms"); - assertEquals(2, longTerms.getBuckets().size()); - } - } -} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java index 96072fca36b91..4fcf0255f6203 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.search.aggregations; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -44,6 +46,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; +import static org.hamcrest.Matchers.equalTo; public class InternalAggregationsTests extends ESTestCase { @@ -126,19 +129,32 @@ private static PipelineAggregator.PipelineTree randomPipelineTree() { public void testSerialization() throws Exception { InternalAggregations aggregations = createTestInstance(); - writeToAndReadFrom(aggregations, 0); + writeToAndReadFrom(aggregations, Version.CURRENT, 0); } - private void writeToAndReadFrom(InternalAggregations aggregations, int iteration) throws IOException { - try (BytesStreamOutput out = new BytesStreamOutput()) { - aggregations.writeTo(out); - try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(out.bytes().toBytesRef().bytes), registry)) { - InternalAggregations deserialized = InternalAggregations.readFrom(in); - assertEquals(aggregations.aggregations, deserialized.aggregations); - if (iteration < 2) { - writeToAndReadFrom(deserialized, iteration + 1); - } + public void testSerializedSize() throws Exception { + InternalAggregations aggregations = createTestInstance(); + assertThat(aggregations.getSerializedSize(), + equalTo((long) serialize(aggregations, Version.CURRENT).length)); + } + + private void writeToAndReadFrom(InternalAggregations aggregations, Version version, int iteration) throws IOException { + BytesRef 
serializedAggs = serialize(aggregations, version); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(serializedAggs.bytes), registry)) { + in.setVersion(version); + InternalAggregations deserialized = InternalAggregations.readFrom(in); + assertEquals(aggregations.aggregations, deserialized.aggregations); + if (iteration < 2) { + writeToAndReadFrom(deserialized, version, iteration + 1); } } } + + private BytesRef serialize(InternalAggregations aggs, Version version) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.setVersion(version); + aggs.writeTo(out); + return out.bytes().toBytesRef(); + } + } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index affa437759cb3..5541c679b8f4c 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1616,7 +1616,7 @@ clusterService, indicesService, threadPool, shardStateAction, mappingUpdatedActi SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), searchService::aggReduceContextBuilder); actions.put(SearchAction.INSTANCE, - new TransportSearchAction(client, threadPool, transportService, searchService, + new TransportSearchAction(client, threadPool, new NoneCircuitBreakerService(), transportService, searchService, searchTransportService, searchPhaseController, clusterService, actionFilters, indexNameExpressionResolver, namedWriteableRegistry)); actions.put(RestoreSnapshotAction.INSTANCE, diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 6ae6655ef6b6b..7bd4553776169 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -391,7 +390,7 @@ protected void onListShards(List shards, List skipped, @Override public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggregations, int reducePhase) { + InternalAggregations aggregations, int reducePhase) { // best effort to cancel expired tasks checkCancellation(); // The way that the MutableSearchResponse will build the aggs. @@ -401,16 +400,15 @@ public void onPartialReduce(List shards, TotalHits totalHits, reducedAggs = () -> null; } else { /* - * Keep a reference to the serialized form of the partially - * reduced aggs and reduce it on the fly when someone asks + * Keep a reference to the partially reduced aggs and reduce it on the fly when someone asks * for it. It's important that we wait until someone needs * the result so we don't perform the final reduce only to * throw it away. 
diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java
index 6ae6655ef6b6b..7bd4553776169 100644
--- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java
+++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java
@@ -20,7 +20,6 @@
 import org.elasticsearch.action.search.SearchTask;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.InternalAggregation;
@@ -391,7 +390,7 @@ protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped,
 
         @Override
         public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
-                DelayableWriteable.Serialized<InternalAggregations> aggregations, int reducePhase) {
+                InternalAggregations aggregations, int reducePhase) {
             // best effort to cancel expired tasks
             checkCancellation();
             // The way that the MutableSearchResponse will build the aggs.
@@ -401,16 +400,15 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
                 reducedAggs = () -> null;
             } else {
                 /*
-                 * Keep a reference to the serialized form of the partially
-                 * reduced aggs and reduce it on the fly when someone asks
+                 * Keep a reference to the partially reduced aggs and reduce it on the fly when someone asks
                  * for it. It's important that we wait until someone needs
                  * the result so we don't perform the final reduce only to
                  * throw it away. And it is important that we keep the reference
-                 * to the serialized aggregations because SearchPhaseController
+                 * to the aggregations because SearchPhaseController
                  * *already* has that reference so we're not creating more garbage.
                  */
                 reducedAggs = () ->
-                    InternalAggregations.topLevelReduce(singletonList(aggregations.expand()), aggReduceContextSupplier.get());
+                    InternalAggregations.topLevelReduce(singletonList(aggregations), aggReduceContextSupplier.get());
             }
             searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
         }
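The supplier in the hunk above is the heart of the change: the partially reduced aggs stay live instead of being kept in serialized form, and the final reduce is deferred until a caller actually asks for the response. The same pattern in isolation (sketch; `partialAggs` and `reduceContextSupplier` are stand-in names, not from the patch):

    // Nothing is reduced when the supplier is created ...
    Supplier<InternalAggregations> reducedAggs = () ->
        InternalAggregations.topLevelReduce(
            Collections.singletonList(partialAggs),  // the already partially reduced aggs
            reduceContextSupplier.get());            // final-reduce context (pipelines, bucket limits)
    // ... the expensive final reduce runs only on first use:
    InternalAggregations finalAggs = reducedAggs.get();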
diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
index 01f57a07ee817..d06b47d9cf5d6 100644
--- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
+++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
@@ -16,8 +16,6 @@
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
-import org.elasticsearch.common.io.stream.DelayableWriteable;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.shard.ShardId;
@@ -25,7 +23,6 @@
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.BucketOrder;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -155,56 +152,14 @@ public void onFailure(Exception e) {
         latch.await();
     }
 
-    public void testGetResponseFailureDuringReduction() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
-            SearchResponse.Clusters.EMPTY, false);
-        InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
-            BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
-        //providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
-        //causing an exception when executing getResponse as part of the completion listener callback
-        DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
-            .asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
-        task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
-            serializedAggs, 1);
-        AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
-        CountDownLatch latch = new CountDownLatch(1);
-        task.addCompletionListener(new ActionListener<AsyncSearchResponse>() {
-            @Override
-            public void onResponse(AsyncSearchResponse asyncSearchResponse) {
-                assertTrue(response.compareAndSet(null, asyncSearchResponse));
-                latch.countDown();
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                throw new AssertionError("onFailure should not be called");
-            }
-        }, TimeValue.timeValueMillis(10L));
-        assertTrue(latch.await(1, TimeUnit.SECONDS));
-        assertNotNull(response.get().getSearchResponse());
-        assertEquals(0, response.get().getSearchResponse().getTotalShards());
-        assertEquals(0, response.get().getSearchResponse().getSuccessfulShards());
-        assertEquals(0, response.get().getSearchResponse().getFailedShards());
-        assertThat(response.get().getFailure(), instanceOf(ElasticsearchException.class));
-        assertEquals("Async search: error while reducing partial results", response.get().getFailure().getMessage());
-        assertThat(response.get().getFailure().getCause(), instanceOf(IllegalArgumentException.class));
-        assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() + "]",
-            response.get().getFailure().getCause().getMessage());
-    }
-
     public void testWithFailureAndGetResponseFailureDuringReduction() throws InterruptedException {
         AsyncSearchTask task = createAsyncSearchTask();
         task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
             SearchResponse.Clusters.EMPTY, false);
         InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
             BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
-        //providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
-        //causing an exception when executing getResponse as part of the completion listener callback
-        DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
-            .asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
         task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
-            serializedAggs, 1);
+            aggs, 1);
         task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT));
         AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
         CountDownLatch latch = new CountDownLatch(1);
@@ -229,9 +184,6 @@ public void onFailure(Exception e) {
             Exception failure = asyncSearchResponse.getFailure();
             assertThat(failure, instanceOf(ElasticsearchException.class));
             assertEquals("Async search: error while reducing partial results", failure.getMessage());
-            assertThat(failure.getCause(), instanceOf(IllegalArgumentException.class));
-            assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() +
-                "]", failure.getCause().getMessage());
             assertEquals(1, failure.getSuppressed().length);
             assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class));
             assertEquals("error while executing search", failure.getSuppressed()[0].getMessage());