diff --git a/docs/changelog/137637.yaml b/docs/changelog/137637.yaml new file mode 100644 index 0000000000000..88976ea28b5aa --- /dev/null +++ b/docs/changelog/137637.yaml @@ -0,0 +1,5 @@ +pr: 137637 +summary: Fix Bug in `RankDocRetrieverBuilder` when `from` is set to Default (-1) +area: Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index 0cdd5ab35adcd..12552457fa3b6 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RankDocsQueryBuilder; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; @@ -135,6 +136,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder if (sourceHasMinScore()) { searchSourceBuilder.minScore(this.minScore == null ? Float.MIN_VALUE : this.minScore); } + + if (searchSourceBuilder.from() < 0) { + searchSourceBuilder.from(SearchService.DEFAULT_FROM); + } + if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) { searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from())); } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index 165ad9b2de183..117561f98296e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.index.query.RankDocsQueryBuilder; import org.elasticsearch.index.query.Rewriteable; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; @@ -123,6 +124,9 @@ public void testExtractToSearchSourceBuilder() throws IOException { } } assertNull(source.postFilter()); + + // the default `from` is -1, when `extractToSearchSourceBuilder` is run, it should modify this to the default + assertEquals(SearchService.DEFAULT_FROM, source.from()); } public void testTopDocsQuery() throws IOException { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index 98e392ed1ccee..6c6ad62d0fbff 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -373,7 +373,7 @@ setup: rank_window_size: 10 inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" - field: inference_text_field + field: text size: 10 - match: { hits.total.value: 1 } @@ -452,7 +452,7 @@ setup: rank_window_size: 10 inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" - field: inference_text_field + field: text min_score: 0 size: 10 diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 6854fc436038f..3f13c63bb3b8c 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -52,7 +52,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.lessThanOrEqualTo; @ESIntegTestCase.ClusterScope(minNumDataNodes = 3) public class RRFRetrieverBuilderIT extends ESIntegTestCase { @@ -161,62 +160,82 @@ public void testRRFPagination() { for (int i = 0; i < randomIntBetween(1, 5); i++) { int from = randomIntBetween(0, totalDocs - 1); int size = randomIntBetween(1, totalDocs - from); - for (int docs_to_fetch = from; docs_to_fetch < totalDocs; docs_to_fetch += size) { + for (int from_value = from; from_value < totalDocs; from_value += size) { SearchSourceBuilder source = new SearchSourceBuilder(); - source.from(docs_to_fetch); + source.from(from_value); source.size(size); - // this one retrieves docs 1, 2, 4, 6, and 7 - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( - QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) - ); - // this one retrieves docs 2 and 6 due to prefilter - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( - QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) - ); - standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 2.0f }, - null, - 10, - 100, - null, - null - ); - source.retriever( - new RRFRetrieverBuilder( - Arrays.asList( - new CompoundRetrieverBuilder.RetrieverSource(standard0, null), - new CompoundRetrieverBuilder.RetrieverSource(standard1, null), - new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) - ), - rankWindowSize, - rankConstant - ) - ); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - int fDocs_to_fetch = docs_to_fetch; - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getHits().length, lessThanOrEqualTo(size)); - for (int k = 0; k < Math.min(size, resp.getHits().getHits().length); k++) { - assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + fDocs_to_fetch))); - } - }); + assertRRFPagination(source, from_value, size, rankWindowSize, rankConstant, expectedDocIds); } } + + // test with `from` as the default (-1) + for (int i = 0; i < randomIntBetween(5, 20); i++) { + int size = randomIntBetween(1, totalDocs); + SearchSourceBuilder source = new SearchSourceBuilder(); + source.size(size); + assertRRFPagination(source, source.from(), size, rankWindowSize, rankConstant, expectedDocIds); + } + + // and finally test with from = default, and size > {total docs} to be sure + SearchSourceBuilder source = new SearchSourceBuilder(); + source.size(totalDocs + 2); + assertRRFPagination(source, source.from(), totalDocs, rankWindowSize, rankConstant, expectedDocIds); + } + + private void assertRRFPagination( + SearchSourceBuilder source, + int from, + int maxExpectedSize, + int rankWindowSize, + int rankConstant, + List expectedDocIds + ) { + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + + int innerFrom = Math.max(from, 0); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + + int expectedSize = innerFrom + maxExpectedSize > 6 ? 6 - innerFrom : maxExpectedSize; + assertThat(resp.getHits().getHits().length, equalTo(expectedSize)); + + for (int k = 0; k < expectedSize; k++) { + assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + innerFrom))); + } + }); } public void testRRFWithAggs() {