Skip to content

Commit

Permalink
Make SearchResponseSections RefCounted (#104060)
Browse files Browse the repository at this point in the history
We want to make `SearchHits` ref-counted so we need this thing referring to the hits
ref-counted.
  • Loading branch information
original-brownbear committed Jan 9, 2024
1 parent f14d87b commit 834d1a8
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchContextMissingException;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.PointInTimeBuilder;
Expand Down Expand Up @@ -212,9 +211,7 @@ public final void start() {
// total hits is null in the response if the tracking of total hits is disabled
boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED;
sendSearchResponse(
withTotalHits
? new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1)
: new SearchResponseSections(SearchHits.EMPTY_WITHOUT_TOTAL_HITS, null, null, false, null, null, 1),
withTotalHits ? SearchResponseSections.EMPTY_WITH_TOTAL_HITS : SearchResponseSections.EMPTY_WITHOUT_TOTAL_HITS,
new AtomicArray<>(0)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,11 @@ final class FetchSearchPhase extends SearchPhase {
private final AggregatedDfs aggregatedDfs;

FetchSearchPhase(SearchPhaseResults<SearchPhaseResult> resultConsumer, AggregatedDfs aggregatedDfs, SearchPhaseContext context) {
this(
resultConsumer,
aggregatedDfs,
context,
(response, queryPhaseResults) -> new ExpandSearchPhase(
context,
response.hits,
() -> new FetchLookupFieldsPhase(context, response, queryPhaseResults)
)
);
this(resultConsumer, aggregatedDfs, context, (response, queryPhaseResults) -> {
response.mustIncRef();
context.addReleasable(response::decRef);
return new ExpandSearchPhase(context, response.hits, () -> new FetchLookupFieldsPhase(context, response, queryPhaseResults));
});
}

FetchSearchPhase(
Expand Down Expand Up @@ -229,12 +224,11 @@ private void moveToNextPhase(
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
AtomicArray<? extends SearchPhaseResult> fetchResultsArr
) {
context.executeNextPhase(
this,
nextPhaseFactory.apply(
SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr),
queryResults
)
);
var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
try {
context.executeNextPhase(this, nextPhaseFactory.apply(resp, queryResults));
} finally {
resp.decRef();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ public static SearchResponseSections merge(
AtomicArray<? extends SearchPhaseResult> fetchResultsArray
) {
if (reducedQueryPhase.isEmptyResult) {
return new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1);
return SearchResponseSections.EMPTY_WITH_TOTAL_HITS;
}
ScoreDoc[] sortedDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
var fetchResults = fetchResultsArray.asList();
Expand Down Expand Up @@ -465,7 +465,7 @@ private static SearchHits getHits(
}
}
return new SearchHits(
hits.toArray(new SearchHit[0]),
hits.toArray(SearchHits.EMPTY),
reducedQueryPhase.totalHits,
reducedQueryPhase.maxScore,
sortedTopDocs.sortFields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@

package org.elasticsearch.action.search;

import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.profile.SearchProfileResults;
import org.elasticsearch.search.profile.SearchProfileShardResult;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.transport.LeakTracker;

import java.util.Collections;
import java.util.Map;
Expand All @@ -21,8 +24,26 @@
* Holds some sections that a search response is composed of (hits, aggs, suggestions etc.) during some steps of the search response
* building.
*/
public class SearchResponseSections {
public class SearchResponseSections implements RefCounted {

public static final SearchResponseSections EMPTY_WITH_TOTAL_HITS = new SearchResponseSections(
SearchHits.EMPTY_WITH_TOTAL_HITS,
null,
null,
false,
null,
null,
1
);
public static final SearchResponseSections EMPTY_WITHOUT_TOTAL_HITS = new SearchResponseSections(
SearchHits.EMPTY_WITHOUT_TOTAL_HITS,
null,
null,
false,
null,
null,
1
);
protected final SearchHits hits;
protected final Aggregations aggregations;
protected final Suggest suggest;
Expand All @@ -31,6 +52,8 @@ public class SearchResponseSections {
protected final Boolean terminatedEarly;
protected final int numReducePhases;

private final RefCounted refCounted;

public SearchResponseSections(
SearchHits hits,
Aggregations aggregations,
Expand All @@ -47,6 +70,12 @@ public SearchResponseSections(
this.timedOut = timedOut;
this.terminatedEarly = terminatedEarly;
this.numReducePhases = numReducePhases;
refCounted = hits.getHits().length > 0 ? LeakTracker.wrap(new AbstractRefCounted() {
@Override
protected void closeInternal() {
// TODO: noop until hits are ref counted
}
}) : ALWAYS_REFERENCED;
}

public final boolean timedOut() {
Expand Down Expand Up @@ -88,4 +117,24 @@ public final Map<String, SearchProfileShardResult> profile() {
}
return profileResults.getShardResults();
}

@Override
public void incRef() {
refCounted.incRef();
}

@Override
public boolean tryIncRef() {
return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
return refCounted.decRef();
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,25 @@ protected final void sendResponse(
if (request.scroll() != null) {
scrollId = request.scrollId();
}
ActionListener.respondAndRelease(
listener,
new SearchResponse(
SearchPhaseController.merge(true, queryPhase, fetchResults),
scrollId,
this.scrollId.getContext().length,
successfulOps.get(),
0,
buildTookInMillis(),
buildShardFailures(),
SearchResponse.Clusters.EMPTY,
null
)
);
var sections = SearchPhaseController.merge(true, queryPhase, fetchResults);
try {
ActionListener.respondAndRelease(
listener,
new SearchResponse(
sections,
scrollId,
this.scrollId.getContext().length,
successfulOps.get(),
0,
buildTookInMillis(),
buildShardFailures(),
SearchResponse.Clusters.EMPTY,
null
)
);
} finally {
sections.decRef();
}
} catch (Exception e) {
listener.onFailure(new ReduceSearchPhaseException("fetch", "inner finish failed", e, buildShardFailures()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
Expand Down Expand Up @@ -252,10 +251,7 @@ public void onFailure(Exception e) {

@Override
protected void doRun() {
sendSearchResponse(
new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1),
results.getAtomicArray()
);
sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, results.getAtomicArray());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
Expand Down Expand Up @@ -194,10 +193,7 @@ public void testSendSearchResponseDisallowPartialFailures() {
new IllegalArgumentException()
);
}
action.sendSearchResponse(
new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1),
phaseResults.results
);
action.sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, phaseResults.results);
assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class));
SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get();
assertEquals(0, searchPhaseExecutionException.getSuppressed().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,12 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL

List<MultiSearchResponse.Item> mSearchResponses = new ArrayList<>(numInnerHits);
for (int innerHitNum = 0; innerHitNum < numInnerHits; innerHitNum++) {
mockSearchPhaseContext.sendSearchResponse(
new SearchResponseSections(collapsedHits.get(innerHitNum), null, null, false, null, null, 1),
null
);
var sections = new SearchResponseSections(collapsedHits.get(innerHitNum), null, null, false, null, null, 1);
try {
mockSearchPhaseContext.sendSearchResponse(sections, null);
} finally {
sections.decRef();
}
mSearchResponses.add(new MultiSearchResponse.Item(mockSearchPhaseContext.searchResponse.get(), null));
}

Expand All @@ -111,7 +113,12 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL
ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") {
@Override
public void run() {
mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null);
var sections = new SearchResponseSections(hits, null, null, false, null, null, 1);
try {
mockSearchPhaseContext.sendSearchResponse(sections, null);
} finally {
sections.decRef();
}
}
});

Expand Down Expand Up @@ -194,7 +201,12 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL
ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") {
@Override
public void run() {
mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null);
var sections = new SearchResponseSections(hits, null, null, false, null, null, 1);
try {
mockSearchPhaseContext.sendSearchResponse(sections, null);
} finally {
sections.decRef();
}
}
});
phase.run();
Expand Down Expand Up @@ -222,7 +234,12 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL
ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") {
@Override
public void run() {
mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null);
var sections = new SearchResponseSections(hits, null, null, false, null, null, 1);
try {
mockSearchPhaseContext.sendSearchResponse(sections, null);
} finally {
sections.decRef();
}
}
});
phase.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL
searchHits[i] = SearchHitTests.createTestItem(randomBoolean(), randomBoolean());
}
SearchHits hits = new SearchHits(searchHits, new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), 1.0f);
FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(
searchPhaseContext,
new SearchResponseSections(hits, null, null, false, null, null, 1),
null
);
phase.run();
var sections = new SearchResponseSections(hits, null, null, false, null, null, 1);
try {
FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(searchPhaseContext, sections, null);
phase.run();
} finally {
sections.decRef();
}
searchPhaseContext.assertNoFailure();
assertNotNull(searchPhaseContext.searchResponse.get());
} finally {
Expand Down Expand Up @@ -185,12 +186,13 @@ void sendExecuteMultiSearch(
new TotalHits(2, TotalHits.Relation.EQUAL_TO),
1.0f
);
FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(
searchPhaseContext,
new SearchResponseSections(searchHits, null, null, false, null, null, 1),
null
);
phase.run();
var sections = new SearchResponseSections(searchHits, null, null, false, null, null, 1);
try {
FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(searchPhaseContext, sections, null);
phase.run();
} finally {
sections.decRef();
}
assertTrue(requestSent.get());
searchPhaseContext.assertNoFailure();
assertNotNull(searchPhaseContext.searchResponse.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ public void run() {
assertTrue(searchPhaseDidRun.get());
assertEquals(shardsIter.size() - numSkipped, numRequests.get());

asyncAction.sendSearchResponse(
new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1),
null
);
asyncAction.sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, null);
assertNotNull(searchResponse.get());
assertEquals(0, searchResponse.get().getFailedShards());
assertEquals(numSkipped, searchResponse.get().getSkippedShards());
Expand Down Expand Up @@ -698,10 +695,7 @@ public void run() {
assertThat(latch.await(4, TimeUnit.SECONDS), equalTo(true));
assertThat(searchPhaseDidRun.get(), equalTo(true));

asyncAction.sendSearchResponse(
new SearchResponseSections(SearchHits.EMPTY_WITH_TOTAL_HITS, null, null, false, null, null, 1),
null
);
asyncAction.sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, null);
assertNotNull(searchResponse.get());
assertThat(searchResponse.get().getSkippedShards(), equalTo(numUnavailableSkippedShards));
assertThat(searchResponse.get().getFailedShards(), equalTo(0));
Expand Down

0 comments on commit 834d1a8

Please sign in to comment.