Skip to content
Merged
7 changes: 7 additions & 0 deletions docs/changelog/107645.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pr: 107645
summary: Add `_name` support for top level `knn` clauses
area: Search
type: enhancement
issues:
- 106254
- 107448
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,40 @@ setup:
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 33686.29, error: 0.01}}
---
"Knn search with _name":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Rassyan I added a yaml test for you

- skip:
version: ' - 8.14.99'
reason: 'support for _name in knn was added in 8.15'
features: close_to

- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 3
num_candidates: 3
_name: "my_knn_query"
query:
term:
name:
term: cow.jpg
_name: "my_query"


- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0.fields.name.0: "cow.jpg"}
- match: {hits.hits.0.matched_queries.0: "my_knn_query"}
- match: {hits.hits.0.matched_queries.1: "my_query"}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
- match: {hits.hits.1.matched_queries.0: "my_knn_query"}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.2.matched_queries.0: "my_knn_query"}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ENRICH_CACHE_ADDITIONAL_STATS = def(8_638_00_0);
public static final TransportVersion ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED = def(8_639_00_0);
public static final TransportVersion ML_TRAINED_MODEL_CACHE_METADATA_ADDED = def(8_640_00_0);
public static final TransportVersion TOP_LEVEL_KNN_SUPPORT_QUERY_NAME = def(8_641_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
scoreDocs.toArray(new ScoreDoc[0]),
source.knnSearch().get(i).getField(),
source.knnSearch().get(i).getQueryVector()
).boost(source.knnSearch().get(i).boost());
).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

if (nestedPath != null) {
query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD;
public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits");

Expand Down Expand Up @@ -89,6 +90,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
FILTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
PARSER.declareString(KnnSearchBuilder.Builder::queryName, NAME_FIELD);
PARSER.declareFloat(KnnSearchBuilder.Builder::boost, BOOST_FIELD);
PARSER.declareField(
KnnSearchBuilder.Builder::innerHit,
Expand All @@ -110,6 +112,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw
final int numCands;
final Float similarity;
final List<QueryBuilder> filterQueries;
String queryName;
float boost = DEFAULT_BOOST;
InnerHitBuilder innerHitBuilder;

Expand Down Expand Up @@ -171,7 +174,7 @@ public KnnSearchBuilder(
int numCands,
Float similarity
) {
this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, DEFAULT_BOOST);
this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST);
}

private KnnSearchBuilder(
Expand Down Expand Up @@ -201,6 +204,7 @@ private KnnSearchBuilder(
int numCandidates,
Float similarity,
InnerHitBuilder innerHitBuilder,
String queryName,
float boost
) {
if (k < 1) {
Expand Down Expand Up @@ -239,6 +243,7 @@ private KnnSearchBuilder(
this.numCands = numCandidates;
this.innerHitBuilder = innerHitBuilder;
this.similarity = similarity;
this.queryName = queryName;
this.boost = boost;
this.filterQueries = filterQueries;
this.querySupplier = null;
Expand All @@ -255,6 +260,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException {
}
this.filterQueries = in.readNamedWriteableCollectionAsList(QueryBuilder.class);
this.boost = in.readFloat();
if (in.getTransportVersion().onOrAfter(TransportVersions.TOP_LEVEL_KNN_SUPPORT_QUERY_NAME)) {
this.queryName = in.readOptionalString();
} else {
this.queryName = null;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) {
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
} else {
Expand Down Expand Up @@ -300,6 +310,18 @@ public KnnSearchBuilder addFilterQueries(List<QueryBuilder> filterQueries) {
return this;
}

/**
* Sets a query name for the kNN search query.
*/
public KnnSearchBuilder queryName(String queryName) {
this.queryName = queryName;
return this;
}

public String queryName() {
return queryName;
}

/**
* Set a boost to apply to the kNN search scores.
*/
Expand Down Expand Up @@ -328,6 +350,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
return this;
}
return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost)
.queryName(queryName)
.addFilterQueries(filterQueries)
.innerHit(innerHitBuilder);
}
Expand All @@ -349,7 +372,9 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
}
ll.onResponse(null);
})));
return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost).innerHit(innerHitBuilder);
return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost)
.queryName(queryName)
.innerHit(innerHitBuilder);
}
boolean changed = false;
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
Expand All @@ -362,6 +387,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
}
if (changed) {
return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost)
.queryName(queryName)
.addFilterQueries(rewrittenQueries)
.innerHit(innerHitBuilder);
}
Expand All @@ -372,7 +398,9 @@ public KnnVectorQueryBuilder toQueryBuilder() {
if (queryVectorBuilder != null) {
throw new IllegalArgumentException("missing rewrite");
}
return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost).addFilterQueries(filterQueries);
return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost)
.queryName(queryName)
.addFilterQueries(filterQueries);
}

@Override
Expand All @@ -389,6 +417,7 @@ public boolean equals(Object o) {
&& Objects.equals(filterQueries, that.filterQueries)
&& Objects.equals(similarity, that.similarity)
&& Objects.equals(innerHitBuilder, that.innerHitBuilder)
&& Objects.equals(queryName, that.queryName)
&& boost == that.boost;
}

Expand All @@ -404,6 +433,7 @@ public int hashCode() {
Objects.hashCode(queryVector),
Objects.hashCode(filterQueries),
innerHitBuilder,
queryName,
boost
);
}
Expand Down Expand Up @@ -440,6 +470,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (boost != DEFAULT_BOOST) {
builder.field(BOOST_FIELD.getPreferredName(), boost);
}
if (queryName != null) {
builder.field(NAME_FIELD.getPreferredName(), queryName);
}

return builder;
}
Expand All @@ -459,6 +492,9 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeNamedWriteableCollection(filterQueries);
out.writeFloat(boost);
if (out.getTransportVersion().onOrAfter(TransportVersions.TOP_LEVEL_KNN_SUPPORT_QUERY_NAME)) {
out.writeOptionalString(queryName);
}
if (out.getTransportVersion().before(TransportVersions.V_8_7_0) && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
Expand Down Expand Up @@ -488,6 +524,7 @@ public static class Builder {
private Integer numCandidates;
private Float similarity;
private final List<QueryBuilder> filterQueries = new ArrayList<>();
private String queryName;
private float boost = DEFAULT_BOOST;
private InnerHitBuilder innerHitBuilder;

Expand All @@ -502,6 +539,11 @@ public Builder field(String field) {
return this;
}

public Builder queryName(String queryName) {
this.queryName = queryName;
return this;
}

public Builder boost(float boost) {
this.boost = boost;
return this;
Expand Down Expand Up @@ -552,6 +594,7 @@ public KnnSearchBuilder build(int size) {
adjustedNumCandidates,
similarity,
innerHitBuilder,
queryName,
boost
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ public void testKnnWithQuery() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f).queryName("knn");
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.setQuery(QueryBuilders.matchQuery("text", "goodnight").queryName("query"))
.addFetchField("*")
.setSize(10),
response -> {
Expand All @@ -121,6 +121,8 @@ public void testKnnWithQuery() throws IOException {

// Because of the boost, vector results should appear first
assertNotNull(response.getHits().getAt(0).field("vector"));
assertEquals(response.getHits().getAt(0).getMatchedQueries()[0], "knn");
assertEquals(response.getHits().getAt(9).getMatchedQueries()[0], "query");
}
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public final void testKnnSearchBuilderWireSerialization() throws IOException {
10,
randomBoolean() ? null : randomFloat()
);
searchBuilder.queryName(randomAlphaOfLengthBetween(5, 10));
KnnSearchBuilder serialized = copyWriteable(
searchBuilder,
getNamedWriteableRegistry(),
Expand Down