Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137637.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137637
summary: Fix Bug in `RankDocRetrieverBuilder` when `from` is set to Default (-1)
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -135,6 +136,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore == null ? Float.MIN_VALUE : this.minScore);
}

if (searchSourceBuilder.from() < 0) {
searchSourceBuilder.from(SearchService.DEFAULT_FROM);
}

if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) {
searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
Expand Down Expand Up @@ -123,6 +124,9 @@ public void testExtractToSearchSourceBuilder() throws IOException {
}
}
assertNull(source.postFilter());

// the default `from` is -1, when `extractToSearchSourceBuilder` is run, it should modify this to the default
assertEquals(SearchService.DEFAULT_FROM, source.from());
}

public void testTopDocsQuery() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ setup:
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: inference_text_field
field: text
size: 10

- match: { hits.total.value: 1 }
Expand Down Expand Up @@ -452,7 +452,7 @@ setup:
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: inference_text_field
field: text
min_score: 0
size: 10

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
Expand Down Expand Up @@ -161,62 +160,82 @@ public void testRRFPagination() {
for (int i = 0; i < randomIntBetween(1, 5); i++) {
int from = randomIntBetween(0, totalDocs - 1);
int size = randomIntBetween(1, totalDocs - from);
for (int docs_to_fetch = from; docs_to_fetch < totalDocs; docs_to_fetch += size) {
for (int from_value = from; from_value < totalDocs; from_value += size) {
SearchSourceBuilder source = new SearchSourceBuilder();
source.from(docs_to_fetch);
source.from(from_value);
source.size(size);
// this one retrieves docs 1, 2, 4, 6, and 7
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
VECTOR_FIELD,
new float[] { 2.0f },
null,
10,
100,
null,
null
);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize,
rankConstant
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
int fDocs_to_fetch = docs_to_fetch;
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, lessThanOrEqualTo(size));
for (int k = 0; k < Math.min(size, resp.getHits().getHits().length); k++) {
assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + fDocs_to_fetch)));
}
});
assertRRFPagination(source, from_value, size, rankWindowSize, rankConstant, expectedDocIds);
}
}

// test with `from` as the default (-1)
for (int i = 0; i < randomIntBetween(5, 20); i++) {
int size = randomIntBetween(1, totalDocs);
SearchSourceBuilder source = new SearchSourceBuilder();
source.size(size);
assertRRFPagination(source, source.from(), size, rankWindowSize, rankConstant, expectedDocIds);
}

// and finally test with from = default, and size > {total docs} to be sure
SearchSourceBuilder source = new SearchSourceBuilder();
source.size(totalDocs + 2);
assertRRFPagination(source, source.from(), totalDocs, rankWindowSize, rankConstant, expectedDocIds);
}

private void assertRRFPagination(
SearchSourceBuilder source,
int from,
int maxExpectedSize,
int rankWindowSize,
int rankConstant,
List<String> expectedDocIds
) {
// this one retrieves docs 1, 2, 4, 6, and 7
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize,
rankConstant
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);

int innerFrom = Math.max(from, 0);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));

int expectedSize = innerFrom + maxExpectedSize > 6 ? 6 - innerFrom : maxExpectedSize;
assertThat(resp.getHits().getHits().length, equalTo(expectedSize));

for (int k = 0; k < expectedSize; k++) {
assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + innerFrom)));
}
});
}

public void testRRFWithAggs() {
Expand Down