Fix for from parameter when using sub_searches and rank (#106253)
pmpailis committed Jun 17, 2024
1 parent 60f6ba3 commit 0c14873
Showing 23 changed files with 1,271 additions and 79 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/106253.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 106253
summary: Fix for from parameter when using `sub_searches` and rank
area: Ranking
type: bug
issues:
- 99011
4 changes: 2 additions & 2 deletions docs/reference/search/retriever.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ GET /index/_search
}
],
"rank_constant": ...
"window_size": ...
"rank_window_size": ...
}
}
}
Expand All @@ -207,7 +207,7 @@ The <<search-from-param, `from`>> and <<search-size-param, `size`>>
parameters are provided globally as part of the general
<<search-search, search API>>. They are applied to all retrievers in a
retriever tree unless a specific retriever overrides the `size` parameter
using a different parameter such as `window_size`. Though, the final
using a different parameter such as `rank_window_size`. Though, the final
search hits are always limited to `size`.

==== Using aggregations with a retriever tree
Expand Down
86 changes: 80 additions & 6 deletions docs/reference/search/rrf.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ GET example-index/_search
}
}
],
"window_size": 50,
"rank_window_size": 50,
"rank_constant": 20
}
}
Expand All @@ -94,8 +94,8 @@ its global top 50 results.
the query top documents and rank them based on the RRF formula using parameters from
the `rrf` retriever to get the combined top documents using the default `size` of `10`.

Note that if `k` from a knn search is larger than `window_size`, the results are
truncated to `window_size`. If `k` is smaller than `window_size`, the results are
Note that if `k` from a knn search is larger than `rank_window_size`, the results are
truncated to `rank_window_size`. If `k` is smaller than `rank_window_size`, the results are
`k` size.

[[rrf-supported-features]]
Expand Down Expand Up @@ -160,7 +160,7 @@ GET example-index/_search
}
}
],
"window_size": 50,
"rank_window_size": 50,
"rank_constant": 20
}
}
Expand Down Expand Up @@ -289,7 +289,7 @@ GET example-index/_search
}
}
],
"window_size": 5,
"rank_window_size": 5,
"rank_constant": 1
}
},
Expand Down Expand Up @@ -510,8 +510,82 @@ _id: 5 = 1.0/(1+4) = 0.2000
----
// NOTCONSOLE

We rank the documents based on the RRF formula with a `window_size` of `5`
We rank the documents based on the RRF formula with a `rank_window_size` of `5`
truncating the bottom `2` docs in our RRF result set with a `size` of `3`.
We end with `_id: 3` as `_rank: 1`, `_id: 2` as `_rank: 2`, and
`_id: 4` as `_rank: 3`. This ranking matches the result set from the
original RRF search as expected.


==== Pagination in RRF

When using `rrf` you can paginate through the results using the `from` parameter.
As the final ranking is solely dependent on the original query ranks, to ensure
consistency when paginating, we have to make sure that while `from` changes, the order
of what we have already seen remains intact. To that end, we're using a fixed `rank_window_size`
as the whole available result set upon which we can paginate.
This essentially means that if:

* `from + size` &le; `rank_window_size` : we could get `results[from: from+size]` documents back from
the final `rrf` ranked result set

* `from + size` &gt; `rank_window_size` : we would get only the documents that remain inside the available
`rank_window_size`-sized result set, i.e. `results[from: rank_window_size]`, which is empty once
`from` &ge; `rank_window_size`.

Note that `rank_window_size` bounds the results that we get to see from the individual query components.
Pagination is therefore consistent, i.e. no documents are skipped or duplicated across pages,
if and only if `rank_window_size` remains the same between requests. If `rank_window_size` changes, then the order
of the results might change as well, even for the same ranks.

To illustrate all of the above, let's consider the following simplified example where we have
two queries, `queryA` and `queryB` and their ranked documents:
[source,python]
----
| queryA | queryB |
_id: | 1 | 5 |
_id: | 2 | 4 |
_id: | 3 | 3 |
_id: | 4 | 1 |
_id: | | 2 |
----
// NOTCONSOLE

For `rank_window_size=5` we would get to see all documents from both `queryA` and `queryB`.
Assuming a `rank_constant=1`, the `rrf` scores would be:
[source,python]
----
# doc | queryA | queryB | score
_id: 1 = 1.0/(1+1) + 1.0/(1+4) = 0.7
_id: 2 = 1.0/(1+2) + 1.0/(1+5) = 0.5
_id: 3 = 1.0/(1+3) + 1.0/(1+3) = 0.5
_id: 4 = 1.0/(1+4) + 1.0/(1+2) = 0.533
_id: 5 = 0 + 1.0/(1+1) = 0.5
----
// NOTCONSOLE
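
The score table above can be reproduced with a few lines of Python. This is only a minimal sketch of the RRF formula as described here, not the Elasticsearch implementation: the `rrf_scores` helper and the `query_a` / `query_b` names are hypothetical, and exact fractions are used instead of floats so that equal scores compare as exactly equal.

```python
from fractions import Fraction


def rrf_scores(rankings, rank_constant):
    """Sum 1 / (rank_constant + rank) over every query in which a doc appears.

    `rankings` is a list of ranked doc-id lists, one per query (rank 1 first).
    A doc missing from a query contributes 0 for that query.
    """
    scores = {}
    for ranked_ids in rankings:
        for rank, doc_id in enumerate(ranked_ids, start=1):
            contribution = Fraction(1, rank_constant + rank)
            scores[doc_id] = scores.get(doc_id, Fraction(0)) + contribution
    return scores


# Ranked documents from the example above (hypothetical variable names).
query_a = [1, 2, 3, 4]
query_b = [5, 4, 3, 1, 2]

scores = rrf_scores([query_a, query_b], rank_constant=1)
for doc_id in sorted(scores):
    print(f"_id: {doc_id} = {float(scores[doc_id]):.3f}")
```

Documents `2`, `3` and `5` end up with the same score, so their relative order in the final result set depends on the tiebreaker rather than on the RRF score itself.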

So the final ranked result set would be [`1`, `4`, `2`, `3`, `5`] and we would paginate over that, since
`rank_window_size == len(results)`. In this scenario, we would have:

* `from=0, size=2` would return documents [`1`, `4`] with ranks `[1, 2]`
* `from=2, size=2` would return documents [`2`, `3`] with ranks `[3, 4]`
* `from=4, size=2` would return document [`5`] with rank `[5]`
* `from=6, size=2` would return an empty result set as there are no more results to iterate over
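
The four requests above can be modelled as a slice over the ranked result set. The sketch below assumes the ranked list is already computed and simply mirrors the behaviour described in this section; `paginate` is a hypothetical helper, not an Elasticsearch API, and `from_` stands in for `from`, which is a reserved word in Python.

```python
def paginate(ranked_ids, from_, size, rank_window_size):
    """Return (doc_id, rank) pairs for one page of an RRF-ranked result set.

    Only the top `rank_window_size` documents are available for pagination;
    a page that starts beyond that window is empty.
    """
    window = ranked_ids[:rank_window_size]
    page = window[from_:from_ + size]
    # Ranks are global positions in the ranked result set, not page-local.
    return [(doc_id, from_ + offset + 1) for offset, doc_id in enumerate(page)]


ranked = [1, 4, 2, 3, 5]  # final ranked result set from the example above
for from_ in (0, 2, 4, 6):
    print(f"from={from_}, size=2 ->", paginate(ranked, from_, size=2, rank_window_size=5))
```

Because every page is cut from the same fixed window, consecutive pages neither overlap nor skip documents as long as `rank_window_size` stays the same.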

Now, if we had a `rank_window_size=2`, we would only get to see `[1, 2]` and `[5, 4]` documents
for queries `queryA` and `queryB` respectively. Working out the math, we would see that the results would now
be slightly different, because we would have no knowledge of the documents in positions `[3: end]` for either query.
[source,python]
----
# doc | queryA | queryB | score
_id: 1 = 1.0/(1+1) + 0 = 0.5
_id: 2 = 1.0/(1+2) + 0 = 0.33
_id: 4 = 0 + 1.0/(1+2) = 0.33
_id: 5 = 0 + 1.0/(1+1) = 0.5
----
// NOTCONSOLE

The final ranked result set would be [`1`, `5`, `2`, `4`], and we would be able to paginate
on the top `rank_window_size` results, i.e. [`1`, `5`]. So for the same params as above, we would now have:

* `from=0, size=2` would return [`1`, `5`] with ranks `[1, 2]`
* `from=2, size=2` would return an empty result set as it would fall outside the available `rank_window_size` results.
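
To see how `rank_window_size` changes both the ranking and the pages, the whole flow can be sketched end to end. This is an illustration under stated assumptions, not the actual coordinator code: `rrf_page` is a hypothetical helper, exact fractions replace floats, and ties are broken by the per-query reciprocal-rank contribution in query order and then by the smaller `_id`, which is a simplified stand-in for the real tiebreaker (individual query score, then shard, then doc id).

```python
from fractions import Fraction


def rrf_page(rankings, rank_constant, rank_window_size, from_, size):
    """Rank with RRF over a fixed window, then return one page of doc ids."""
    contributions = {}
    for query_index, ranked_ids in enumerate(rankings):
        # Each query only exposes its top `rank_window_size` documents.
        visible = ranked_ids[:rank_window_size]
        for rank, doc_id in enumerate(visible, start=1):
            per_query = contributions.setdefault(doc_id, [Fraction(0)] * len(rankings))
            per_query[query_index] = Fraction(1, rank_constant + rank)

    def sort_key(doc_id):
        per_query = contributions[doc_id]
        # Higher total first, then higher per-query contribution, then smaller id.
        return (-sum(per_query), [-c for c in per_query], doc_id)

    ordered = sorted(contributions, key=sort_key)
    # Pagination only sees the top `rank_window_size` ranked documents.
    window = ordered[:rank_window_size]
    return window[from_:from_ + size]


query_a = [1, 2, 3, 4]
query_b = [5, 4, 3, 1, 2]

print(rrf_page([query_a, query_b], 1, rank_window_size=5, from_=0, size=5))
print(rrf_page([query_a, query_b], 1, rank_window_size=2, from_=0, size=2))
print(rrf_page([query_a, query_b], 1, rank_window_size=2, from_=2, size=2))
```

With the same `from` and `size`, the two window sizes produce different first pages, which is why changing `rank_window_size` between requests can break pagination consistency.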
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ static ReducedQueryPhase reducedQueryPhase(
);
sortedTopDocs = new SortedTopDocs(rankedDocs, false, null, null, null, 0);
size = sortedTopDocs.scoreDocs.length;
// we need to reset from here as pagination and result trimming has already taken place
// within the `QueryPhaseRankCoordinatorContext#rankQueryPhaseResults` and we don't want
// to apply it again in the `getHits` method.
from = 0;
}
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,10 @@ public ActionRequestValidationException validate() {
if (size == 0) {
validationException = addValidationError("[rank] requires [size] greater than [0]", validationException);
}
if (size > source.rankBuilder().windowSize()) {
if (size > source.rankBuilder().rankWindowSize()) {
validationException = addValidationError(
"[rank] requires [window_size: "
+ source.rankBuilder().windowSize()
"[rank] requires [rank_window_size: "
+ source.rankBuilder().rankWindowSize()
+ "]"
+ " be greater than or equal to [size: "
+ size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ static void executeRank(SearchContext searchContext) throws QueryPhaseExecutionE
RankSearchContext rankSearchContext = new RankSearchContext(
searchContext,
rankQuery,
queryPhaseRankShardContext.windowSize()
queryPhaseRankShardContext.rankWindowSize()
)
) {
QueryPhase.addCollectorsAndSearch(rankSearchContext);
Expand Down
22 changes: 11 additions & 11 deletions server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@
*/
public abstract class RankBuilder implements VersionedNamedWriteable, ToXContentObject {

public static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

public static final int DEFAULT_WINDOW_SIZE = SearchService.DEFAULT_SIZE;

private final int windowSize;
private final int rankWindowSize;

public RankBuilder(int windowSize) {
this.windowSize = windowSize;
public RankBuilder(int rankWindowSize) {
this.rankWindowSize = rankWindowSize;
}

public RankBuilder(StreamInput in) throws IOException {
windowSize = in.readVInt();
rankWindowSize = in.readVInt();
}

public final void writeTo(StreamOutput out) throws IOException {
out.writeVInt(windowSize);
out.writeVInt(rankWindowSize);
doWriteTo(out);
}

Expand All @@ -55,7 +55,7 @@ public final void writeTo(StreamOutput out) throws IOException {
public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(getWriteableName());
builder.field(WINDOW_SIZE_FIELD.getPreferredName(), windowSize);
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
doXContent(builder, params);
builder.endObject();
builder.endObject();
Expand All @@ -64,8 +64,8 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)

protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException;

public int windowSize() {
return windowSize;
public int rankWindowSize() {
return rankWindowSize;
}

/**
Expand All @@ -88,14 +88,14 @@ public final boolean equals(Object obj) {
}
@SuppressWarnings("unchecked")
RankBuilder other = (RankBuilder) obj;
return Objects.equals(windowSize, other.windowSize()) && doEquals(other);
return Objects.equals(rankWindowSize, other.rankWindowSize()) && doEquals(other);
}

protected abstract boolean doEquals(RankBuilder other);

@Override
public final int hashCode() {
return Objects.hash(getClass(), windowSize, doHashCode());
return Objects.hash(getClass(), rankWindowSize, doHashCode());
}

protected abstract int doHashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
*/
public abstract class QueryPhaseRankCoordinatorContext {

protected final int windowSize;
protected final int rankWindowSize;

public QueryPhaseRankCoordinatorContext(int windowSize) {
this.windowSize = windowSize;
public QueryPhaseRankCoordinatorContext(int rankWindowSize) {
this.rankWindowSize = rankWindowSize;
}

/**
* This is used to pull information passed back from the shards as part of {@link QuerySearchResult#getRankShardResult()}
* and return a {@link ScoreDoc[]} of the `window_size` ranked results. Note that {@link TopDocsStats} is included so that
* and return a {@link ScoreDoc[]} of the `rank_window_size` ranked results. Note that {@link TopDocsStats} is included so that
* appropriate stats may be updated based on rank results.
* This is called when reducing query results through {@code SearchPhaseController#reducedQueryPhase()}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@
public abstract class QueryPhaseRankShardContext {

protected final List<Query> queries;
protected final int windowSize;
protected final int rankWindowSize;

public QueryPhaseRankShardContext(List<Query> queries, int windowSize) {
public QueryPhaseRankShardContext(List<Query> queries, int rankWindowSize) {
this.queries = queries;
this.windowSize = windowSize;
this.rankWindowSize = rankWindowSize;
}

public List<Query> queries() {
return queries;
}

public int windowSize() {
return windowSize;
public int rankWindowSize() {
return rankWindowSize;
}

/**
* This is used to reduce the number of required results that are serialized
* to the coordinating node. Normally we would have to serialize {@code queries * window_size}
* to the coordinating node. Normally we would have to serialize {@code queries * rank_window_size}
* results, but we can infer that there will likely be overlap of document results. Given that we
* know any searches that match the same document must be on the same shard, we can sort on the shard
* instead for a top window_size set of results and reduce the amount of data we serialize.
* instead for a top rank_window_size set of results and reduce the amount of data we serialize.
*/
public abstract RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults);
}
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ public void testValidate() throws IOException {
assertNotNull(validationErrors);
assertEquals(1, validationErrors.validationErrors().size());
assertEquals(
"[rank] requires [window_size: 1] be greater than or equal to [size: 2]",
"[rank] requires [rank_window_size: 1] be greater than or equal to [size: 2]",
validationErrors.validationErrors().get(0)
);
}
Expand Down Expand Up @@ -437,10 +437,21 @@ public void testValidate() throws IOException {
assertNotNull(validationErrors);
assertEquals(1, validationErrors.validationErrors().size());
assertEquals(
"[rank] requires [window_size: 9] be greater than or equal to [size: 10]",
"[rank] requires [rank_window_size: 9] be greater than or equal to [size: 10]",
validationErrors.validationErrors().get(0)
);
}
{
SearchRequest searchRequest = new SearchRequest().source(
new SearchSourceBuilder().rankBuilder(new TestRankBuilder(3))
.query(QueryBuilders.termQuery("field", "term"))
.knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null)))
.size(3)
.from(4)
);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertNull(validationErrors);
}
{
SearchRequest searchRequest = new SearchRequest().source(
new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class TestRankBuilder extends RankBuilder {
);

static {
PARSER.declareInt(optionalConstructorArg(), WINDOW_SIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
}

public static TestRankBuilder fromXContent(XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public ScoreDoc[] rankQueryPhaseResults(List<QuerySearchResult> querySearchResul

for (int qi = 0; qi < queryCount; ++qi) {
final int fqi = qi;
queues.add(new PriorityQueue<>(windowSize + from) {
queues.add(new PriorityQueue<>(rankWindowSize) {
@Override
protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) {
float score1 = a.scores[fqi];
Expand Down Expand Up @@ -105,7 +105,7 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) {
// score if we already saw it as part of a previous query's
// doc set, otherwise we make a new doc and calculate the
// initial score
Map<RankKey, RRFRankDoc> results = Maps.newMapWithExpectedSize(queryCount * windowSize);
Map<RankKey, RRFRankDoc> results = Maps.newMapWithExpectedSize(queryCount * rankWindowSize);
final int fqc = queryCount;
for (int qi = 0; qi < queryCount; ++qi) {
PriorityQueue<RRFRankDoc> queue = queues.get(qi);
Expand All @@ -127,6 +127,11 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) {
}
}

// return if pagination requested is outside the results
if (results.values().size() - from <= 0) {
return new ScoreDoc[0];
}

// sort the results based on rrf score, tiebreaker based on
// larger individual query score from 1 to n, smaller shard then smaller doc id
RRFRankDoc[] sortedResults = results.values().toArray(RRFRankDoc[]::new);
Expand All @@ -151,9 +156,10 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) {
}
return rrf1.doc < rrf2.doc ? -1 : 1;
});
// trim results to size
RRFRankDoc[] topResults = new RRFRankDoc[Math.min(size, sortedResults.length - from)];
for (int rank = 0; rank < topResults.length; ++rank) {
topResults[rank] = sortedResults[rank];
topResults[rank] = sortedResults[from + rank];
topResults[rank].rank = rank + 1 + from;
}
// update fetch hits for the fetch phase, so we gather any additional
Expand Down