[7.15] Possible source of leaked delayable writables (#80166) (#80293)
* Possible source of leaked delayable writables (#80166)

In a couple of error paths, it's possible that QueryPhaseResultConsumer#PendingMerges might not have consumed all the aggregations it's trying to merge. It is nevertheless important to release those aggregations so that we don't leak memory or hold stale references to them. This PR achieves that by using the Releasables.close() mechanism, which executes every close action even if an earlier action threw an exception. This ensures that all of the aggregations are released and that the circuit breaker accounting is cleaned up.

Co-authored-by: Henning Andersen <henning.andersen@elastic.co>
Co-authored-by: David Turner <david.turner@elastic.co>

* fix unused imports

Co-authored-by: Henning Andersen <henning.andersen@elastic.co>
Co-authored-by: David Turner <david.turner@elastic.co>
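
The heart of the fix is the "close everything, even if one close throws" idiom. Below is a minimal, self-contained sketch of that idiom; `ReleaseAllSketch.closeAll` and this `Releasable` interface are stand-ins written for this note, not Elasticsearch's actual `Releasables.close`, which is richer but follows the same principle.

```java
import java.util.Arrays;
import java.util.List;

// Stand-in functional interface; Elasticsearch's Releasable plays this role.
interface Releasable {
    void close();
}

final class ReleaseAllSketch {

    // Runs every close action even when earlier ones throw, then rethrows
    // the first failure with later failures attached as suppressed.
    static void closeAll(List<? extends Releasable> releasables) {
        RuntimeException failure = null;
        for (Releasable releasable : releasables) {
            try {
                releasable.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    public static void main(String[] args) {
        try {
            // The middle action throws, but the final breaker-reset action
            // still runs, which is exactly what the leak fix relies on.
            closeAll(Arrays.asList(
                () -> System.out.println("released aggs for shard 0"),
                () -> { throw new RuntimeException("boom"); },
                () -> System.out.println("circuit breaker reset")
            ));
        } catch (RuntimeException e) {
            System.out.println("rethrown after all closes ran: " + e.getMessage());
        }
    }
}
```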
3 people committed Nov 3, 2021
1 parent e981063 commit ccba167
Showing 3 changed files with 97 additions and 38 deletions.
QueryPhaseResultConsumer.java
@@ -36,6 +36,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize;
import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
@@ -71,14 +72,16 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
* Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
* as shard results are consumed.
*/
public QueryPhaseResultConsumer(SearchRequest request,
Executor executor,
CircuitBreaker circuitBreaker,
SearchPhaseController controller,
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure) {
public QueryPhaseResultConsumer(
SearchRequest request,
Executor executor,
CircuitBreaker circuitBreaker,
SearchPhaseController controller,
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
) {
super(expectedResultSize);
this.executor = executor;
this.circuitBreaker = circuitBreaker;
@@ -93,7 +96,7 @@ public QueryPhaseResultConsumer(SearchRequest request,
SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0;
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
}

@@ -128,28 +131,41 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
// Add an estimate of the final reduce size
breakerSize = pendingMerges.addEstimateAndMaybeBreak(pendingMerges.estimateRamBytesUsedForReduce(breakerSize));
}
SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList,
topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(
results.asList(),
aggsList,
topDocsList,
topDocsStats,
pendingMerges.numReducePhases,
false,
aggReduceContextBuilder,
performFinalReduce
);
if (hasAggs
// reduced aggregations can be null if all shards failed
&& reducePhase.aggregations != null) {
// reduced aggregations can be null if all shards failed
&& reducePhase.aggregations != null) {

// Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
long finalSize = DelayableWriteable.getSerializedSize(reducePhase.aggregations) - breakerSize;
pendingMerges.addWithoutBreaking(finalSize);
logger.trace("aggs final reduction [{}] max [{}]",
pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize);
logger.trace("aggs final reduction [{}] max [{}]", pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize);
}
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
progressListener.notifyFinalReduce(
SearchProgressListener.buildSearchShards(results.asList()),
reducePhase.totalHits,
reducePhase.aggregations,
reducePhase.numReducePhases
);
return reducePhase;
}
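
The breaker bookkeeping in `reduce()` above follows an estimate-then-correct pattern: reserve a pessimistic estimate before the final reduction (`addEstimateAndMaybeBreak`), then replace it with the real serialized size once that is known, by adding the difference without breaking. A rough, self-contained sketch of that accounting, using a hypothetical `Breaker` stand-in for the real circuit breaker:

```java
// Hypothetical stand-in illustrating the estimate-then-correct accounting.
final class Breaker {
    private final long limit;
    private long used;

    Breaker(long limit) {
        this.limit = limit;
    }

    // Reserve bytes up front; trip if the limit would be exceeded.
    void addEstimateAndMaybeBreak(long bytes) {
        if (used + bytes > limit) {
            throw new IllegalStateException("circuit breaker tripped");
        }
        used += bytes;
    }

    // Adjust accounting without tripping; a negative value releases bytes.
    void addWithoutBreaking(long bytes) {
        used += bytes;
    }

    public static void main(String[] args) {
        Breaker breaker = new Breaker(1024);
        long estimate = 512;
        breaker.addEstimateAndMaybeBreak(estimate);  // pessimistic reservation
        long serializedSize = 128;                   // known only after the reduce
        breaker.addWithoutBreaking(serializedSize - estimate);
        System.out.println(breaker.used);            // 128: estimate replaced by actual
    }
}
```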

private MergeResult partialReduce(QuerySearchResult[] toConsume,
List<SearchShard> emptyResults,
TopDocsStats topDocsStats,
MergeResult lastMerge,
int numReducePhases) {
private MergeResult partialReduce(
QuerySearchResult[] toConsume,
List<SearchShard> emptyResults,
TopDocsStats topDocsStats,
MergeResult lastMerge,
int numReducePhases
) {
// ensure consistent ordering
Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex));

@@ -168,9 +184,12 @@ private MergeResult partialReduce(QuerySearchResult[] toConsume,
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
newTopDocs = mergeTopDocs(topDocsList,
newTopDocs = mergeTopDocs(
topDocsList,
// we have to merge here in the same way we collect on a shard
topNSize, 0);
topNSize,
0
);
} else {
newTopDocs = null;
}
@@ -233,14 +252,24 @@ private class PendingMerges implements Releasable {

@Override
public synchronized void close() {
assert hasPendingMerges() == false : "cannot close with partial reduce in-flight";
if (hasFailure()) {
assert circuitBreakerBytes == 0;
return;
} else {
assert circuitBreakerBytes >= 0;
}

List<Releasable> toRelease = new ArrayList<>(buffer.stream().<Releasable>map(b -> b::releaseAggs).collect(Collectors.toList()));
toRelease.add(() -> {
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
});

Releasables.close(toRelease);

if (hasPendingMerges()) {
// This is a theoretically unreachable exception.
throw new IllegalStateException("Attempted to close with partial reduce in-flight");
}
assert circuitBreakerBytes >= 0;
circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
circuitBreakerBytes = 0;
}

synchronized Exception getFailure() {
@@ -378,8 +407,12 @@ private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedS
// and replace the estimation with the serialized size of the newly reduced result.
long newSize = mergeResult.estimatedSize - estimatedSize;
addWithoutBreaking(newSize);
logger.trace("aggs partial reduction [{}->{}] max [{}]",
estimatedSize, mergeResult.estimatedSize, maxAggsCurrentBufferSize);
logger.trace(
"aggs partial reduction [{}->{}] max [{}]",
estimatedSize,
mergeResult.estimatedSize,
maxAggsCurrentBufferSize
);
}
task.consumeListener();
}
@@ -388,9 +421,7 @@ private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedS
private void tryExecuteNext() {
final MergeTask task;
synchronized (this) {
if (queue.isEmpty()
|| hasFailure()
|| runningTask.get() != null) {
if (queue.isEmpty() || hasFailure() || runningTask.get() != null) {
return;
}
task = queue.poll();
@@ -411,7 +442,7 @@ protected void doRun() {
long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
addEstimateAndMaybeBreak(estimatedMergeSize);
estimatedTotalSize += estimatedMergeSize;
++ numReducePhases;
++numReducePhases;
newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
} catch (Exception t) {
for (QuerySearchResult result : toConsume) {
@@ -475,8 +506,12 @@ private static class MergeResult {
private final InternalAggregations reducedAggs;
private final long estimatedSize;

private MergeResult(List<SearchShard> processedShards, TopDocs reducedTopDocs,
InternalAggregations reducedAggs, long estimatedSize) {
private MergeResult(
List<SearchShard> processedShards,
TopDocs reducedTopDocs,
InternalAggregations reducedAggs,
long estimatedSize
) {
this.processedShards = processedShards;
this.reducedTopDocs = reducedTopDocs;
this.reducedAggs = reducedAggs;
QuerySearchResult.java
@@ -388,7 +388,9 @@ private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean del
@Override
public void writeTo(StreamOutput out) throws IOException {
// we do not know that it is being sent over transport, but this at least protects all writes from happening, including sending.
assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak";
if (aggregations != null && aggregations.isSerialized()) {
throw new IllegalStateException("cannot send serialized version since it will leak");
}
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeBoolean(isNull);
}
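
The `writeTo` change above promotes a debug-only `assert` to a hard runtime check. JVM assertions are disabled unless the process runs with `-ea`, so the old `assert` offered no protection in production; the explicit `IllegalStateException` fails fast everywhere. A minimal sketch of the difference (the method names below are illustrative, not the actual Elasticsearch code):

```java
final class GuardSketch {

    // Old style: a no-op unless the JVM runs with assertions enabled (-ea),
    // so a serialized instance could still be sent and leak in production.
    static void writeToWithAssert(boolean serialized) {
        assert serialized == false : "cannot send serialized version since it will leak";
    }

    // New style: always fails fast, with or without -ea.
    static void writeToWithCheck(boolean serialized) {
        if (serialized) {
            throw new IllegalStateException("cannot send serialized version since it will leak");
        }
    }

    public static void main(String[] args) {
        writeToWithAssert(true); // silently passes under the default -da
        try {
            writeToWithCheck(true);
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```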
AsyncSearchActionIT.java
@@ -39,6 +39,7 @@
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.search.SearchService.MAX_ASYNC_SEARCH_RESPONSE_SIZE_SETTING;
import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -462,6 +463,27 @@ public void testSearchPhaseFailure() throws Exception {
ensureTaskNotRunning(response.getId());
}

public void testSearchPhaseFailureLeak() throws Exception {
SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
request.setKeepOnCompletion(true);
request.setWaitForCompletionTimeout(TimeValue.timeValueMinutes(10));
request.getSearchRequest().allowPartialSearchResults(false);
request.getSearchRequest()
.source(
new SearchSourceBuilder().query(
new ThrowingQueryBuilder(randomLong(), new AlreadyClosedException("boom"), between(0, numShards - 1))
)
);
request.getSearchRequest().source().aggregation(terms("f").field("f").size(between(1, 10)));

AsyncSearchResponse response = submitAsyncSearch(request);
assertFalse(response.isRunning());
assertTrue(response.isPartial());
assertThat(response.status(), equalTo(RestStatus.SERVICE_UNAVAILABLE));
assertNotNull(response.getFailure());
ensureTaskNotRunning(response.getId());
}

public void testFinalResponseLargerMaxSize() throws Exception {
SearchSourceBuilder source = new SearchSourceBuilder()
.query(new MatchAllQueryBuilder())
