diff --git a/docs/changelog/107645.yaml b/docs/changelog/107645.yaml new file mode 100644 index 0000000000000..93fc0f2a89b3a --- /dev/null +++ b/docs/changelog/107645.yaml @@ -0,0 +1,7 @@ +pr: 107645 +summary: Add `_name` support for top level `knn` clauses +area: Search +type: enhancement +issues: + - 106254 + - 107448 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index c8cbf499cf8b2..8471bd8cb5a9a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -448,3 +448,40 @@ setup: - length: {hits.hits: 1} - match: {hits.hits.0._id: "2"} - close_to: {hits.hits.0._score: {value: 33686.29, error: 0.01}} +--- +"Knn search with _name": + - skip: + version: ' - 8.14.99' + reason: 'support for _name in knn was added in 8.15' + features: close_to + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + _name: "my_knn_query" + query: + term: + name: + term: cow.jpg + _name: "my_query" + + + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.0.fields.name.0: "cow.jpg"} + - match: {hits.hits.0.matched_queries.0: "my_knn_query"} + - match: {hits.hits.0.matched_queries.1: "my_query"} + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + - match: {hits.hits.1.matched_queries.0: "my_knn_query"} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} + - match: {hits.hits.2.matched_queries.0: "my_knn_query"} diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 456497c167294..d0a091337342e 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -179,6 +179,7 @@ static TransportVersion def(int id) { public static final TransportVersion ENRICH_CACHE_ADDITIONAL_STATS = def(8_638_00_0); public static final TransportVersion ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED = def(8_639_00_0); public static final TransportVersion ML_TRAINED_MODEL_CACHE_METADATA_ADDED = def(8_640_00_0); + public static final TransportVersion TOP_LEVEL_KNN_SUPPORT_QUERY_NAME = def(8_641_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 0c9d6ba12a27a..c5c35b1980a5d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -155,7 +155,7 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { scoreDocs.toArray(new ScoreDoc[0]), source.knnSearch().get(i).getField(), source.knnSearch().get(i).getQueryVector() - ).boost(source.knnSearch().get(i).boost()); + ).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName()); if (nestedPath != null) { query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit()); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 8dfc740e9df74..3c03d3258ebab 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -52,6 +52,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); public static final ParseField FILTER_FIELD = new ParseField("filter"); + public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD; public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD; public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); @@ -89,6 +90,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea FILTER_FIELD, ObjectParser.ValueType.OBJECT_ARRAY ); + PARSER.declareString(KnnSearchBuilder.Builder::queryName, NAME_FIELD); PARSER.declareFloat(KnnSearchBuilder.Builder::boost, BOOST_FIELD); PARSER.declareField( KnnSearchBuilder.Builder::innerHit, @@ -110,6 +112,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw final int numCands; final Float similarity; final List filterQueries; + String queryName; float boost = DEFAULT_BOOST; InnerHitBuilder innerHitBuilder; @@ -171,7 +174,7 @@ public KnnSearchBuilder( int numCands, Float similarity ) { - this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, DEFAULT_BOOST); + this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST); } private KnnSearchBuilder( @@ -201,6 +204,7 @@ private KnnSearchBuilder( int numCandidates, Float similarity, InnerHitBuilder innerHitBuilder, + String queryName, float boost ) { if (k < 1) { @@ -239,6 +243,7 @@ private KnnSearchBuilder( this.numCands = numCandidates; this.innerHitBuilder = innerHitBuilder; this.similarity = similarity; + this.queryName = queryName; this.boost = boost; this.filterQueries = filterQueries; this.querySupplier = null; @@ -255,6 +260,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException { } this.filterQueries = in.readNamedWriteableCollectionAsList(QueryBuilder.class); this.boost = in.readFloat(); + if (in.getTransportVersion().onOrAfter(TransportVersions.TOP_LEVEL_KNN_SUPPORT_QUERY_NAME)) { + this.queryName = in.readOptionalString(); + } else { + this.queryName = null; + } if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) { this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class); } else { @@ -300,6 +310,18 @@ public KnnSearchBuilder addFilterQueries(List filterQueries) { return this; } + /** + * Sets a query name for the kNN search query. + */ + public KnnSearchBuilder queryName(String queryName) { + this.queryName = queryName; + return this; + } + + public String queryName() { + return queryName; + } + /** * Set a boost to apply to the kNN search scores. */ @@ -328,6 +350,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { return this; } return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost) + .queryName(queryName) .addFilterQueries(filterQueries) .innerHit(innerHitBuilder); } @@ -349,7 +372,9 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost).innerHit(innerHitBuilder); + return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost) + .queryName(queryName) + .innerHit(innerHitBuilder); } boolean changed = false; List rewrittenQueries = new ArrayList<>(filterQueries.size()); @@ -362,6 +387,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } if (changed) { return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost) + .queryName(queryName) .addFilterQueries(rewrittenQueries) .innerHit(innerHitBuilder); } @@ -372,7 +398,9 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost).addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost) + .queryName(queryName) + .addFilterQueries(filterQueries); } @Override @@ -389,6 +417,7 @@ public boolean equals(Object o) { && Objects.equals(filterQueries, that.filterQueries) && Objects.equals(similarity, that.similarity) && Objects.equals(innerHitBuilder, that.innerHitBuilder) + && Objects.equals(queryName, that.queryName) && boost == that.boost; } @@ -404,6 +433,7 @@ public int hashCode() { Objects.hashCode(queryVector), Objects.hashCode(filterQueries), innerHitBuilder, + queryName, boost ); } @@ -440,6 +470,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (boost != DEFAULT_BOOST) { builder.field(BOOST_FIELD.getPreferredName(), boost); } + if (queryName != null) { + builder.field(NAME_FIELD.getPreferredName(), queryName); + } return builder; } @@ -459,6 +492,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeNamedWriteableCollection(filterQueries); out.writeFloat(boost); + if (out.getTransportVersion().onOrAfter(TransportVersions.TOP_LEVEL_KNN_SUPPORT_QUERY_NAME)) { + out.writeOptionalString(queryName); + } if (out.getTransportVersion().before(TransportVersions.V_8_7_0) && queryVectorBuilder != null) { throw new IllegalArgumentException( format( @@ -488,6 +524,7 @@ public static class Builder { private Integer numCandidates; private Float similarity; private final List filterQueries = new ArrayList<>(); + private String queryName; private float boost = DEFAULT_BOOST; private InnerHitBuilder innerHitBuilder; @@ -502,6 +539,11 @@ public Builder field(String field) { return this; } + public Builder queryName(String queryName) { + this.queryName = queryName; + return this; + } + public Builder boost(float boost) { this.boost = boost; return this; @@ -552,6 +594,7 @@ public KnnSearchBuilder build(int size) { adjustedNumCandidates, similarity, innerHitBuilder, + queryName, boost ); } diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index a678956b20e59..818f74da5853a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -106,11 +106,11 @@ public void testKnnWithQuery() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f).queryName("knn"); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) - .setQuery(QueryBuilders.matchQuery("text", "goodnight")) + .setQuery(QueryBuilders.matchQuery("text", "goodnight").queryName("query")) .addFetchField("*") .setSize(10), response -> { @@ -121,6 +121,8 @@ public void testKnnWithQuery() throws IOException { // Because of the boost, vector results should appear first assertNotNull(response.getHits().getAt(0).field("vector")); + assertEquals(response.getHits().getAt(0).getMatchedQueries()[0], "knn"); + assertEquals(response.getHits().getAt(9).getMatchedQueries()[0], "query"); } ); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index b327aee0931f9..45f13ed9ef319 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -98,6 +98,7 @@ public final void testKnnSearchBuilderWireSerialization() throws IOException { 10, randomBoolean() ? null : randomFloat() ); + searchBuilder.queryName(randomAlphaOfLengthBetween(5, 10)); KnnSearchBuilder serialized = copyWriteable( searchBuilder, getNamedWriteableRegistry(),