Skip to content

Commit

Permalink
Add support for Reciprocal Rank Fusion to the search API (#93396)
Browse files Browse the repository at this point in the history
This change at a high level adds global ranking on the coordinating node at the end of query reduction 
prior to the fetch phase. Individual rank methods are defined in plugins.

The first rank plugin added as part of this change is reciprocal rank fusion (RRF). RRF uses a relatively 
simple formula for merging 1...n result sets together with sum(1/(k+d)) where k is a ranking constant 
and d is a document's scored position within a result set from a query.
  • Loading branch information
jdconrad committed Apr 24, 2023
1 parent 3bd0032 commit 5314e5d
Show file tree
Hide file tree
Showing 59 changed files with 4,141 additions and 102 deletions.
48 changes: 48 additions & 0 deletions docs/changelog/93396.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
pr: 93396
summary: Add support for Reciprocal Rank Fusion to the search API
area: Ranking
type: feature
issues: []
highlight:
title: Add support for Reciprocal Rank Fusion (RRF) to the search API
body: |-
This change adds reciprocal rank fusion (RRF) which follows the basic formula
for merging `1...n` result sets together with `sum(1/(k+d))` where `k`
is a ranking constant and `d` is a document's scored position within a result set
from a query. The main advantage of ranking this way is that the scores for the sets
of results do not have to be normalized relative to each other, because RRF only
relies upon positions within each result set.
The API for this change adds a `rank` top-level element to the search
endpoint. An example:
[source,js]
----
{
"query": {
"match": {
"product": {
"query": "brown shoes"
}
}
},
"knn": {
"field": "product-vector",
"query_vector": [54, 10, -2],
"k": 20,
"num_candidates": 75
},
"rank": {
"rrf": {
"window_size": 100,
"rank_constant": 20
}
}
}
----
The above example will execute the search query and the knn search separately.
It will preserve separate result sets up to the point where the queries are
ranked on the coordinating node using RRF.
notable: true
2 changes: 2 additions & 0 deletions docs/reference/search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ include::search/point-in-time-api.asciidoc[]

include::search/knn-search.asciidoc[]

include::search/rrf.asciidoc[]

include::search/scroll-api.asciidoc[]

include::search/clear-scroll-api.asciidoc[]
Expand Down
6 changes: 6 additions & 0 deletions docs/reference/search/rrf.asciidoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[[rrf]]
=== Reciprocal rank fusion

Reciprocal Rank Fusion (RRF) is a simple method to combine document result sets
from multiple queries where the queries' document scores may be unrelated.

12 changes: 12 additions & 0 deletions docs/reference/search/search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,18 @@ Period of time used to extend the life of the PIT.
(Optional, <<query-dsl,query object>>) Defines the search definition using the
<<query-dsl,Query DSL>>.

[[request-body-rank]]
`rank`::
(Optional, object) Defines a method for combining and ranking `1` standard query
with `1..n` knn searches or no standard query with `2..n` knn searches.
+
.Ranking methods
[%collapsible%open]
====
`rrf`::
(Optional, object) Sets the ranking method to <<rrf, reciprocal rank fusion (RRF)>>.
====

[[search-api-body-runtime]]
// tag::runtime-mappings-def[]
`runtime_mappings`::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
public class BulkByScrollParallelizationHelperTests extends ESTestCase {
public void testSliceIntoSubRequests() throws IOException {
SearchRequest searchRequest = randomSearchRequest(
() -> randomSearchSourceBuilder(() -> null, () -> null, () -> null, () -> emptyList(), () -> null, () -> null)
() -> randomSearchSourceBuilder(() -> null, () -> null, () -> null, () -> null, () -> emptyList(), () -> null, () -> null)
);
if (searchRequest.source() != null) {
// Clear the slice builder if there is one set. We can't call sliceIntoSubRequests if it is.
Expand Down
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@
exports org.elasticsearch.search.profile.dfs;
exports org.elasticsearch.search.profile.query;
exports org.elasticsearch.search.query;
exports org.elasticsearch.search.rank;
exports org.elasticsearch.search.rescore;
exports org.elasticsearch.search.runtime;
exports org.elasticsearch.search.searchafter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -130,46 +131,82 @@ public void onFailure(Exception exception) {
}
}

private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
// package private for testing
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
SearchSourceBuilder source = request.source();
if (source == null || source.knnSearch().isEmpty()) {
return request;
}

List<ScoreDoc> scoreDocs = new ArrayList<>();
for (DfsKnnResults dfsKnnResults : knnResults) {
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
scoreDocs.add(scoreDoc);
if (source.rankBuilder() == null) {
// this path will use linear combination if there are
// multiple knn queries to combine all knn queries into
// a single query per shard

List<ScoreDoc> scoreDocs = new ArrayList<>();
for (DfsKnnResults dfsKnnResults : knnResults) {
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
scoreDocs.add(scoreDoc);
}
}
}
}
scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
// It is possible that the different results refer to the same doc.
for (int i = 0; i < scoreDocs.size() - 1; i++) {
ScoreDoc scoreDoc = scoreDocs.get(i);
int j = i + 1;
for (; j < scoreDocs.size(); j++) {
ScoreDoc otherScoreDoc = scoreDocs.get(j);
if (otherScoreDoc.doc != scoreDoc.doc) {
break;
scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
// It is possible that the different results refer to the same doc.
for (int i = 0; i < scoreDocs.size() - 1; i++) {
ScoreDoc scoreDoc = scoreDocs.get(i);
int j = i + 1;
for (; j < scoreDocs.size(); j++) {
ScoreDoc otherScoreDoc = scoreDocs.get(j);
if (otherScoreDoc.doc != scoreDoc.doc) {
break;
}
scoreDoc.score += otherScoreDoc.score;
}
if (j > i + 1) {
scoreDocs.subList(i + 1, j).clear();
}
scoreDoc.score += otherScoreDoc.score;
}
if (j > i + 1) {
scoreDocs.subList(i + 1, j).clear();
}
}
KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));

SearchSourceBuilder newSource = source.shallowCopy().knnSearch(List.of());
if (source.query() == null) {
newSource.query(knnQuery);
KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));
SearchSourceBuilder newSource = source.shallowCopy().knnSearch(List.of());
if (source.query() == null) {
newSource.query(knnQuery);
} else {
newSource.query(new BoolQueryBuilder().should(knnQuery).should(source.query()));
}
request.source(newSource);
} else {
newSource.query(new BoolQueryBuilder().should(knnQuery).should(source.query()));
// this path will keep knn queries separate for ranking per shard
// if there are multiple knn queries

List<QueryBuilder> rankQueryBuilders = new ArrayList<>();
if (source.query() != null) {
rankQueryBuilders.add(source.query());
}

for (DfsKnnResults dfsKnnResults : knnResults) {
List<ScoreDoc> scoreDocs = new ArrayList<>();
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
scoreDocs.add(scoreDoc);
}
}
scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));
rankQueryBuilders.add(knnQuery);
}

BoolQueryBuilder searchQuery = new BoolQueryBuilder();
for (QueryBuilder queryBuilder : rankQueryBuilders) {
searchQuery.should(queryBuilder);
}

SearchSourceBuilder newSource = source.shallowCopy().query(searchQuery).knnSearch(List.of());
request.source(newSource);
request.rankQueryBuilders(rankQueryBuilders);
}

request.source(newSource);
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.RankCoordinatorContext;

import java.util.ArrayDeque;
import java.util.ArrayList;
Expand Down Expand Up @@ -57,6 +59,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
private final CircuitBreaker circuitBreaker;
private final SearchProgressListener progressListener;
private final AggregationReduceContext.Builder aggReduceContextBuilder;
private final RankCoordinatorContext rankCoordinatorContext;

private final int topNSize;
private final boolean hasTopDocs;
Expand Down Expand Up @@ -90,7 +93,12 @@ public QueryPhaseResultConsumer(
this.onPartialMergeFailure = onPartialMergeFailure;

SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0;
int size = source == null || source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size();
int from = source == null || source.from() == -1 ? SearchService.DEFAULT_FROM : source.from();
this.rankCoordinatorContext = source == null || source.rankBuilder() == null
? null
: source.rankBuilder().buildRankCoordinatorContext(size, from);
this.hasTopDocs = (source == null || size != 0) && rankCoordinatorContext == null;
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
Expand Down Expand Up @@ -135,6 +143,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
pendingMerges.numReducePhases,
false,
aggReduceContextBuilder,
rankCoordinatorContext,
performFinalReduce
);
if (hasAggs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import org.elasticsearch.search.profile.SearchProfileResults;
import org.elasticsearch.search.profile.SearchProfileResultsBuilder;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.RankCoordinatorContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
Expand Down Expand Up @@ -425,7 +427,10 @@ private static SearchHits getHits(
: "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length;
SearchHit searchHit = fetchResult.hits().getHits()[index];
searchHit.shard(fetchResult.getSearchShardTarget());
if (sortedTopDocs.isSortedByField) {
if (reducedQueryPhase.rankCoordinatorContext != null) {
assert shardDoc instanceof RankDoc;
searchHit.setRank(((RankDoc) shardDoc).rank);
} else if (sortedTopDocs.isSortedByField) {
FieldDoc fieldDoc = (FieldDoc) shardDoc;
searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);
if (sortScoreIndex != -1) {
Expand Down Expand Up @@ -476,7 +481,17 @@ public AggregationReduceContext forFinalReduction() {
topDocs.add(td.topDocs);
}
}
return reducedQueryPhase(queryResults, Collections.emptyList(), topDocs, topDocsStats, 0, true, aggReduceContextBuilder, true);
return reducedQueryPhase(
queryResults,
Collections.emptyList(),
topDocs,
topDocsStats,
0,
true,
aggReduceContextBuilder,
null,
true
);
}

/**
Expand All @@ -496,6 +511,7 @@ static ReducedQueryPhase reducedQueryPhase(
int numReducePhases,
boolean isScrollRequest,
AggregationReduceContext.Builder aggReduceContextBuilder,
RankCoordinatorContext rankCoordinatorContext,
boolean performFinalReduce
) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
Expand All @@ -513,6 +529,7 @@ static ReducedQueryPhase reducedQueryPhase(
null,
SortedTopDocs.EMPTY,
null,
null,
numReducePhases,
0,
0,
Expand Down Expand Up @@ -578,7 +595,12 @@ static ReducedQueryPhase reducedQueryPhase(
final SearchProfileResultsBuilder profileBuilder = profileShardResults.isEmpty()
? null
: new SearchProfileResultsBuilder(profileShardResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions);
final SortedTopDocs sortedTopDocs = rankCoordinatorContext == null
? sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions)
: rankCoordinatorContext.rank(queryResults.stream().map(SearchPhaseResult::queryResult).toList(), topDocsStats);
if (rankCoordinatorContext != null) {
size = sortedTopDocs.scoreDocs.length;
}
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(
totalHits,
Expand All @@ -591,6 +613,7 @@ static ReducedQueryPhase reducedQueryPhase(
profileBuilder,
sortedTopDocs,
sortValueFormats,
rankCoordinatorContext,
numReducePhases,
size,
from,
Expand Down Expand Up @@ -677,6 +700,8 @@ public record ReducedQueryPhase(
SortedTopDocs sortedTopDocs,
// sort value formats used to sort / format the result
DocValueFormat[] sortValueFormats,
// the rank context if ranking is used
RankCoordinatorContext rankCoordinatorContext,
// the number of reduces phases
int numReducePhases,
// the size of the top hits to return
Expand Down Expand Up @@ -750,16 +775,16 @@ QueryPhaseResultConsumer newSearchPhaseResults(
);
}

static final class TopDocsStats {
public static final class TopDocsStats {
final int trackTotalHitsUpTo;
long totalHits;
private TotalHits.Relation totalHitsRelation;
long fetchHits;
public long fetchHits;
private float maxScore = Float.NEGATIVE_INFINITY;
boolean timedOut;
Boolean terminatedEarly;
public boolean timedOut;
public Boolean terminatedEarly;

TopDocsStats(int trackTotalHitsUpTo) {
public TopDocsStats(int trackTotalHitsUpTo) {
this.trackTotalHitsUpTo = trackTotalHitsUpTo;
this.totalHits = 0;
this.totalHitsRelation = Relation.EQUAL_TO;
Expand Down Expand Up @@ -814,7 +839,7 @@ void add(TopDocsAndMaxScore topDocs, boolean timedOut, Boolean terminatedEarly)
}
}

record SortedTopDocs(
public record SortedTopDocs(
// the searches merged top docs
ScoreDoc[] scoreDocs,
// <code>true</code> iff the result score docs is sorted by a field (not score), this implies that <code>sortField</code> is set.
Expand All @@ -825,6 +850,6 @@ record SortedTopDocs(
Object[] collapseValues,
int numberOfCompletionsSuggestions
) {
static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null, null, null, 0);
public static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null, null, null, 0);
}
}

0 comments on commit 5314e5d

Please sign in to comment.