Skip to content

Commit

Permalink
Ref count fetch phase results (#102324)
Browse files Browse the repository at this point in the history
Making the fetch result ref counted as the second to last step to making `SearchHit` ref counted.
* Fix bug in existing ref counting
* Make search results collector ref counted
* Misc. small adjustments to counting
* Closing/decrementing a bunch of things in tests
  • Loading branch information
original-brownbear authored and gmarouli committed Nov 22, 2023
1 parent 34d041c commit 7b64d9a
Show file tree
Hide file tree
Showing 26 changed files with 1,785 additions and 1,496 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,16 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.executor = executor;
this.request = request;
this.task = task;
this.listener = ActionListener.runAfter(listener, this::releaseContext);
this.listener = ActionListener.runAfter(listener, () -> Releasables.close(releasables));
this.nodeIdToConnection = nodeIdToConnection;
this.concreteIndexBoosts = concreteIndexBoosts;
this.clusterStateVersion = clusterState.version();
this.minTransportVersion = clusterState.getMinTransportVersion();
this.aliasFilter = aliasFilter;
this.results = resultConsumer;
// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
addReleasable(resultConsumer::decRef);
this.clusters = clusters;
}

Expand All @@ -189,10 +192,6 @@ public void addReleasable(Releasable releasable) {
releasables.add(releasable);
}

public void releaseContext() {
Releasables.close(releasables);
}

/**
* Builds how long it took to execute the search.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
package org.elasticsearch.action.search;

import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.transport.LeakTracker;

import java.util.stream.Stream;

Expand All @@ -19,6 +22,8 @@
class ArraySearchPhaseResults<Result extends SearchPhaseResult> extends SearchPhaseResults<Result> {
final AtomicArray<Result> results;

private final RefCounted refCounted = LeakTracker.wrap(AbstractRefCounted.of(this::doClose));

ArraySearchPhaseResults(int size) {
super(size);
this.results = new AtomicArray<>(size);
Expand All @@ -32,9 +37,16 @@ Stream<Result> getSuccessfulResults() {
void consumeResult(Result result, Runnable next) {
assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set";
results.set(result.getShardIndex(), result);
result.incRef();
next.run();
}

protected void doClose() {
for (Result result : getAtomicArray().asList()) {
result.decRef();
}
}

boolean hasResult(int shardIndex) {
return results.get(shardIndex) != null;
}
Expand All @@ -43,4 +55,24 @@ boolean hasResult(int shardIndex) {
AtomicArray<Result> getAtomicArray() {
return results;
}

@Override
public void incRef() {
refCounted.incRef();
}

@Override
public boolean tryIncRef() {
return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
return refCounted.decRef();
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,26 @@ synchronized FixedBitSet getPossibleMatches() {
Stream<CanMatchShardResponse> getSuccessfulResults() {
return Stream.empty();
}

@Override
public void incRef() {

}

@Override
public boolean tryIncRef() {
return false;
}

@Override
public boolean decRef() {
return false;
}

@Override
public boolean hasReferences() {
return false;
}
}

private GroupShardsIterator<SearchShardIterator> getIterator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ final class CountedCollector<R extends SearchPhaseResult> {

CountedCollector(ArraySearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
this.resultConsumer = resultConsumer;
resultConsumer.incRef();
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
this.context = context;
Expand All @@ -37,7 +38,11 @@ final class CountedCollector<R extends SearchPhaseResult> {
void countDown() {
assert counter.isCountedDown() == false : "more operations executed than specified";
if (counter.countDown()) {
onFinish.run();
try {
onFinish.run();
} finally {
resultConsumer.decRef();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ final class DfsQueryPhase extends SearchPhase {

// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
context.addReleasable(queryResult);
context.addReleasable(queryResult::decRef);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
);
}
this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
context.addReleasable(fetchResults::decRef);
this.queryResults = resultConsumer.getAtomicArray();
this.aggregatedDfs = aggregatedDfs;
this.nextPhaseFactory = nextPhaseFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
* needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it
* exceeds the maximum memory allowed in this breaker.
*/
public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> implements Releasable {
public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);

private final Executor executor;
Expand Down Expand Up @@ -104,8 +104,12 @@ public QueryPhaseResultConsumer(
}

@Override
public void close() {
pendingMerges.close();
protected void doClose() {
try {
super.doClose();
} finally {
pendingMerges.close();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
clusters
);
this.queryPhaseResultConsumer = queryPhaseResultConsumer;
addReleasable(queryPhaseResultConsumer::decRef);
this.progressListener = task.getProgressListener();
// don't build the SearchShard list (can be expensive) if the SearchProgressListener won't use it
if (progressListener != SearchProgressListener.NOOP) {
Expand All @@ -90,7 +91,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> res
final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults);
final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);

queryPhaseResultConsumer.incRef();
return new DfsQueryPhase(
dfsSearchResults,
aggregatedDfs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
package org.elasticsearch.action.search;

import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.search.SearchPhaseResult;

import java.util.stream.Stream;

/**
* This class acts as a basic result collection that can be extended to do on-the-fly reduction or result processing
*/
abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
abstract class SearchPhaseResults<Result extends SearchPhaseResult> implements RefCounted {
private final int numShards;

SearchPhaseResults(int numShards) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
this.progressListener = task.getProgressListener();

// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
addReleasable(resultConsumer);
// don't build the SearchShard list (can be expensive) if the SearchProgressListener won't use it
if (progressListener != SearchProgressListener.NOOP) {
notifyListShards(progressListener, clusters, request.source());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ public void run() {
@Override
protected void innerOnResponse(FetchSearchResult response) {
fetchResults.setOnce(response.getShardIndex(), response);
if (counter.countDown()) {
sendResponse(reducedQueryPhase, fetchResults);
}
response.incRef();
consumeResponse(counter, reducedQueryPhase);
}

@Override
Expand All @@ -124,13 +123,20 @@ public void onFailure(Exception t) {
} else {
// the counter is set to the total size of docIdsToLoad
// which can have null values so we have to count them down too
if (counter.countDown()) {
sendResponse(reducedQueryPhase, fetchResults);
}
consumeResponse(counter, reducedQueryPhase);
}
}
}
};
}

private void consumeResponse(CountDown counter, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) {
if (counter.countDown()) {
sendResponse(reducedQueryPhase, fetchResults);
for (FetchSearchResult fetchSearchResult : fetchResults.asList()) {
fetchSearchResult.decRef();
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ final class DefaultSearchContext extends SearchContext {
@Override
public void addFetchResult() {
this.fetchResult = new FetchSearchResult(this.readerContext.id(), this.shardTarget);
addReleasable(fetchResult::decRef);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon
executor.success();
}
// This will incRef the QuerySearchResult when it gets created
return new QueryFetchSearchResult(context.queryResult(), context.fetchResult());
return QueryFetchSearchResult.of(context.queryResult(), context.fetchResult());
}

public void executeQueryPhase(
Expand Down Expand Up @@ -865,7 +865,9 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
}
executor.success();
}
return searchContext.fetchResult();
var fetchResult = searchContext.fetchResult();
fetchResult.incRef();
return fetchResult;
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
// we handle the failure in the failure listener below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.profile.ProfileResult;
import org.elasticsearch.transport.LeakTracker;

import java.io.IOException;

Expand All @@ -28,6 +31,8 @@ public final class FetchSearchResult extends SearchPhaseResult {

private ProfileResult profileResult;

private final RefCounted refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> hits = null));

public FetchSearchResult() {}

public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) {
Expand Down Expand Up @@ -90,4 +95,24 @@ public int counterGetAndIncrement() {
public ProfileResult profileResult() {
return profileResult;
}

@Override
public void incRef() {
refCounted.incRef();
}

@Override
public boolean tryIncRef() {
return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
return refCounted.decRef();
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,21 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
private final FetchSearchResult fetchResult;
private final RefCounted refCounted;

public static QueryFetchSearchResult of(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
// We're acquiring a copy, we should incRef it
queryResult.incRef();
fetchResult.incRef();
return new QueryFetchSearchResult(queryResult, fetchResult);
}

public QueryFetchSearchResult(StreamInput in) throws IOException {
// These get a ref count of 1 when we create them, so we don't need to incRef here
this(new QuerySearchResult(in), new FetchSearchResult(in));
}

public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
private QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
this.queryResult = queryResult;
this.fetchResult = fetchResult;
// We're acquiring a copy, we should incRef it
this.queryResult.incRef();
this.fetchResult.incRef();
refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> {
queryResult.decRef();
fetchResult.decRef();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public SubSearchContext(SearchContext context) {
super(context);
context.addReleasable(this);
this.fetchSearchResult = new FetchSearchResult();
addReleasable(fetchSearchResult::decRef);
this.querySearchResult = new QuerySearchResult();
}

Expand Down

0 comments on commit 7b64d9a

Please sign in to comment.