Add modelId and modelText to KnnVectorQueryBuilder #106068

Merged on Mar 18, 2024 (36 commits)
Changes from 25 commits
Commits
b53f756
Add modelId and modelText to KnnVectorQueryBuilder
tteofili Mar 7, 2024
d81c4e1
Added missing files
tteofili Mar 7, 2024
32e701e
Fix query in yaml test
tteofili Mar 7, 2024
1eb1286
Update docs/changelog/106068.yaml
tteofili Mar 7, 2024
98e9b21
Fixed docs, bumped TransportVersion
tteofili Mar 7, 2024
b58d1d2
Merge branch 'knn_dsl_modeltext_modelid' of github.com:tteofili/elast…
tteofili Mar 7, 2024
dca4e99
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 8, 2024
840ff1c
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 8, 2024
2e35f29
correct assertion for yml test
tteofili Mar 8, 2024
1f9ac72
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 11, 2024
f2c998e
Fix ctor visibility
tteofili Mar 11, 2024
b164d25
Improved input validation
tteofili Mar 11, 2024
905bc31
validation should not check for query supplier
tteofili Mar 11, 2024
d359256
dropped unused ctor
tteofili Mar 11, 2024
9fa135c
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 11, 2024
1e58587
formatting fix
tteofili Mar 11, 2024
301705d
revert deprecated api doc change, use include for qvb
tteofili Mar 12, 2024
303431e
validation in KNNVQB#doXContent
tteofili Mar 12, 2024
7f09dc1
IT refactoring to avoid duplicate tests
tteofili Mar 12, 2024
0f9d59f
fixed docs error
tteofili Mar 12, 2024
28e2136
skip yml test for versions <= 8.14
tteofili Mar 12, 2024
3c22e9c
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 12, 2024
ce232be
serialization validation for supplier
tteofili Mar 13, 2024
f359b89
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 13, 2024
3013a1b
rewrite test expected vector should be float[]
tteofili Mar 13, 2024
45b5a00
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 14, 2024
0de580a
nested vectors with inner hits test
tteofili Mar 14, 2024
9193cf7
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 14, 2024
8f97ea1
add required privileges for indexing nested vectors
tteofili Mar 14, 2024
73f300d
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 14, 2024
983c7e9
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 15, 2024
f682fb7
add required privileges for indexing nested vectors
tteofili Mar 15, 2024
61671b6
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 18, 2024
50e882d
index refresh needs token too
tteofili Mar 18, 2024
822a32d
Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_…
tteofili Mar 18, 2024
166e3c5
added skip version for nested vectors test
tteofili Mar 18, 2024
5 changes: 5 additions & 0 deletions docs/changelog/106068.yaml
@@ -0,0 +1,5 @@
pr: 106068
summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder`
area: Search
type: enhancement
issues: []
11 changes: 10 additions & 1 deletion docs/reference/query-dsl/knn-query.asciidoc
@@ -87,10 +87,19 @@ the top `size` results.
`query_vector`::
+
--
-(Required, array of floats or string) Query vector. Must have the same number of dimensions
+(Optional, array of floats or string) Query vector. Must have the same number of dimensions
as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector.
Either this or `query_vector_builder` must be provided.
--

`query_vector_builder`::
+
--
(Optional, object) Query vector builder.
include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector-builder]
--


`num_candidates`::
+
--
Expand Down
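Note (illustrative sketch, not part of this PR): the docs change above makes `query_vector` optional as long as `query_vector_builder` is supplied, and the two are mutually exclusive. A minimal, self-contained Java sketch of that "exactly one of" rule follows. The class and method names are hypothetical, and plain `Object` parameters stand in for the real `VectorData` and `QueryVectorBuilder` types; the error messages mirror the ones added to the `KnnVectorQueryBuilder` constructor in this PR.

```java
/** Hypothetical, simplified stand-in for the constructor validation added in this PR. */
final class KnnQueryInputValidation {

    private KnnQueryInputValidation() {}

    /**
     * Returns normally only when exactly one of the two inputs is non-null.
     * Plain Objects are used here; the real builder takes VectorData and QueryVectorBuilder.
     */
    static void requireExactlyOne(Object queryVector, Object queryVectorBuilder) {
        if (queryVector == null && queryVectorBuilder == null) {
            throw new IllegalArgumentException("either [query_vector] or [query_vector_builder] must be provided");
        }
        if (queryVector != null && queryVectorBuilder != null) {
            throw new IllegalArgumentException("only one of [query_vector] and [query_vector_builder] must be provided");
        }
    }

    public static void main(String[] args) {
        // A literal vector alone is accepted.
        requireExactlyOne(new float[] { 0.1f, 0.2f }, null);
        try {
            // Neither input is given, so validation rejects the request.
            requireExactlyOne(null, null);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Checking both failure modes up front keeps the later rewrite logic simple: it can assume that either a vector is already present or a builder is available to produce one.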
7 changes: 7 additions & 0 deletions docs/reference/rest-api/common-parms.asciidoc
@@ -600,6 +600,13 @@ Query vector. Must have the same number of dimensions as the vector field you
are searching against. Must be either an array of floats or a hex-encoded byte vector.
end::knn-query-vector[]

tag::knn-query-vector-builder[]
A configuration object indicating how to build a query_vector before executing
the request. You must provide either a `query_vector_builder` or `query_vector`,
but not both. Refer to <<knn-semantic-search>> to learn more.
end::knn-query-vector-builder[]


Review comment (Member): I think you need to update docs/reference/query-dsl/knn-query.asciidoc as well to include knn-query-vector-builder.

Reply (Contributor Author): Can I use an include directive inside docs/reference/query-dsl/knn-query.asciidoc?

Reply (Member): You should be able to use include.

tag::knn-similarity[]
The minimum similarity required for a document to be considered a match. The similarity
value calculated relates to the raw <<dense-vector-similarity, `similarity`>> used. Not the
Expand Down
Original file line number Diff line number Diff line change
@@ -144,6 +144,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ALLOCATION_STATS = def(8_604_00_0);
public static final TransportVersion ESQL_EXTENDED_ENRICH_TYPES = def(8_605_00_0);
public static final TransportVersion KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING = def(8_606_00_0);
public static final TransportVersion KNN_QUERY_VECTOR_BUILDER = def(8_607_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
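Note (illustrative sketch, not part of this PR): the new `KNN_QUERY_VECTOR_BUILDER` constant is what the query builder later checks before reading or writing the optional `query_vector_builder` on the wire. The sketch below shows the general version-gating pattern with plain `java.io` streams. The class name, the `int` version comparison, and the use of a `String` in place of a named writeable are all assumptions made for illustration; the real code uses `TransportVersion`, `StreamInput` and `StreamOutput`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical, simplified model of version-gated serialization of an optional field. */
final class VersionGatedWireFormat {

    // Numeric id mirroring KNN_QUERY_VECTOR_BUILDER = def(8_607_00_0) from the diff above.
    static final int KNN_QUERY_VECTOR_BUILDER = 8_607_00_0;

    private VersionGatedWireFormat() {}

    /** Writes the field name, plus the optional builder name only when the peer understands it. */
    static byte[] write(int peerVersion, String fieldName, String builderName) {
        if (peerVersion < KNN_QUERY_VECTOR_BUILDER && builderName != null) {
            // An older node cannot decode the new field, so fail loudly instead of dropping it.
            throw new IllegalArgumentException(
                "cannot serialize [query_vector_builder] to older node of version [" + peerVersion + "]"
            );
        }
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(fieldName);
            if (peerVersion >= KNN_QUERY_VECTOR_BUILDER) {
                out.writeBoolean(builderName != null); // presence flag for the optional value
                if (builderName != null) {
                    out.writeUTF(builderName);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads back { fieldName, builderName }; builderName is null for older peers or when absent. */
    static String[] read(int peerVersion, byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            String fieldName = in.readUTF();
            String builderName = null;
            if (peerVersion >= KNN_QUERY_VECTOR_BUILDER && in.readBoolean()) {
                builderName = in.readUTF();
            }
            return new String[] { fieldName, builderName };
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Reader and writer must gate on the same version, otherwise the byte stream goes out of sync. Throwing when a builder would be sent to an older node avoids silently running a query without its vector source.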
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -39,15 +40,16 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* A query that performs kNN search using Lucene's {@link org.apache.lucene.search.KnnFloatVectorQuery} or
* {@link org.apache.lucene.search.KnnByteVectorQuery}.
*
*/
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
public static final String NAME = "knn";
@@ -59,11 +61,19 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
"knn",
-args -> new KnnVectorQueryBuilder((String) args[0], (VectorData) args[1], (Integer) args[2], (Float) args[3])
+args -> new KnnVectorQueryBuilder(
+(String) args[0],
+(VectorData) args[1],
+(QueryVectorBuilder) args[4],
+null,
+(Integer) args[2],
+(Float) args[3]
+)
);

static {
@@ -76,6 +86,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
);
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
PARSER.declareNamedObject(
optionalConstructorArg(),
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
QUERY_VECTOR_BUILDER_FIELD
);
PARSER.declareFieldArray(
KnnVectorQueryBuilder::addFilterQueries,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
@@ -94,26 +109,59 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
private Integer numCands;
private final List<QueryBuilder> filterQueries = new ArrayList<>();
private final Float vectorSimilarity;
private final QueryVectorBuilder queryVectorBuilder;
private final Supplier<float[]> queryVectorSupplier;

public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
-this(fieldName, VectorData.fromFloats(queryVector), numCands, vectorSimilarity);
+this(fieldName, VectorData.fromFloats(queryVector), null, null, numCands, vectorSimilarity);
}

protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) {
this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity);
}

public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) {
-this(fieldName, VectorData.fromBytes(queryVector), numCands, vectorSimilarity);
+this(fieldName, VectorData.fromBytes(queryVector), null, null, numCands, vectorSimilarity);
}

public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) {
this(fieldName, queryVector, null, null, numCands, vectorSimilarity);
}

private KnnVectorQueryBuilder(
String fieldName,
VectorData queryVector,
QueryVectorBuilder queryVectorBuilder,
Supplier<float[]> queryVectorSupplier,
Integer numCands,
Float vectorSimilarity
) {
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
-if (queryVector == null) {
-throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided");
+if (queryVector == null && queryVectorBuilder == null) {
+throw new IllegalArgumentException(
+format(
+"either [%s] or [%s] must be provided",
+QUERY_VECTOR_FIELD.getPreferredName(),
+QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
+)
+);
+} else if (queryVector != null && queryVectorBuilder != null) {
+throw new IllegalArgumentException(
+format(
+"only one of [%s] and [%s] must be provided",
+QUERY_VECTOR_FIELD.getPreferredName(),
+QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
+)
+);
}
this.fieldName = fieldName;
this.queryVector = queryVector;
this.numCands = numCands;
this.vectorSimilarity = vectorSimilarity;
this.queryVectorBuilder = queryVectorBuilder;
this.queryVectorSupplier = queryVectorSupplier;
}

public KnnVectorQueryBuilder(StreamInput in) throws IOException {
@@ -144,6 +192,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
} else {
this.vectorSimilarity = null;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) {
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
} else {
this.queryVectorBuilder = null;
}
this.queryVectorSupplier = null;
}

public String getFieldName() {
@@ -168,6 +222,11 @@ public List<QueryBuilder> filterQueries() {
return filterQueries;
}

@Nullable
public QueryVectorBuilder queryVectorBuilder() {
return queryVectorBuilder;
}

public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) {
Objects.requireNonNull(filterQuery);
this.filterQueries.add(filterQuery);
@@ -182,6 +241,9 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries)

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (queryVectorSupplier != null) {
throw new IllegalStateException("missing a rewriteAndFetch?");
}
out.writeString(fieldName);
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM)) {
out.writeOptionalVInt(numCands);
@@ -216,19 +278,41 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
out.writeOptionalFloat(vectorSimilarity);
}
if (out.getTransportVersion().before(TransportVersions.KNN_QUERY_VECTOR_BUILDER) && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"cannot serialize [%s] to older node of version [%s]",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
out.getTransportVersion()
)
);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) {
out.writeOptionalNamedWriteable(queryVectorBuilder);
}
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
if (queryVectorSupplier != null) {
throw new IllegalStateException("missing a rewriteAndFetch?");
Review comment (Member), on lines 296 to +298: This needs to be done in wire serialization (StreamOutput) as well.
}
builder.startObject(NAME);
builder.field(FIELD_FIELD.getPreferredName(), fieldName);
-builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
+if (queryVector != null) {
+builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
+}
if (numCands != null) {
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
}
if (vectorSimilarity != null) {
builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity);
}
if (queryVectorBuilder != null) {
builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
builder.endObject();
}
if (filterQueries.isEmpty() == false) {
builder.startArray(FILTER_FIELD.getPreferredName());
for (QueryBuilder filterQuery : filterQueries) {
@@ -247,6 +331,36 @@ public String getWriteableName() {

@Override
protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
if (queryVectorSupplier != null) {
if (queryVectorSupplier.get() == null) {
return this;
}
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost)
.queryName(queryName)
.addFilterQueries(filterQueries);
}
if (queryVectorBuilder != null) {
SetOnce<float[]> toSet = new SetOnce<>();
ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
toSet.set(v);
if (v == null) {
ll.onFailure(
new IllegalArgumentException(
format(
"[%s] with name [%s] returned null query_vector",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
queryVectorBuilder.getWriteableName()
)
)
);
return;
}
ll.onResponse(null);
})));
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost(
boost
).queryName(queryName).addFilterQueries(filterQueries);
}
if (ctx.convertToInnerHitsRewriteContext() != null) {
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName);
}
@@ -263,7 +377,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
rewrittenQueries.add(rewrittenQuery);
}
if (changed) {
-return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).boost(boost)
+return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity)
+.boost(boost)
.queryName(queryName)
.addFilterQueries(rewrittenQueries);
}
@@ -338,7 +453,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

@Override
protected int doHashCode() {
-return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity);
+return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
}

@Override
@@ -347,7 +462,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
&& Objects.equals(queryVector, other.queryVector)
&& Objects.equals(numCands, other.numCands)
&& Objects.equals(filterQueries, other.filterQueries)
-&& Objects.equals(vectorSimilarity, other.vectorSimilarity);
+&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
+&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
}

@Override
Expand Down
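Note (illustrative sketch, not part of this PR): the `doRewrite` change above resolves the query vector in two passes. The first pass registers an async action that asks the `QueryVectorBuilder` for a vector and returns a copy of the query holding a supplier. A later pass sees the supplier filled in and returns a plain vector-based query with no builder. The sketch below models that flow with hypothetical stand-ins: `RewriteContext`, `VectorBuilder`, `KnnQuery` and `rewriteAndFetch` are simplified inventions for illustration, and an `AtomicReference` replaces Lucene's `SetOnce`. Error handling for a null vector is left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;

/** Hypothetical, simplified model of the two-pass rewrite used to resolve a query vector. */
final class TwoPhaseRewriteSketch {

    private TwoPhaseRewriteSketch() {}

    /** Stand-in for QueryRewriteContext: collects actions to run between rewrite passes. */
    static final class RewriteContext {
        private final List<Runnable> asyncActions = new ArrayList<>();

        void registerAsyncAction(Runnable action) {
            asyncActions.add(action);
        }

        boolean hasAsyncActions() {
            return !asyncActions.isEmpty();
        }

        void executeAsyncActions() {
            List<Runnable> pending = new ArrayList<>(asyncActions);
            asyncActions.clear();
            pending.forEach(Runnable::run);
        }
    }

    /** Stand-in for QueryVectorBuilder: hands the built vector to a listener. */
    interface VectorBuilder {
        void buildVector(Consumer<float[]> listener);
    }

    /** Stand-in for KnnVectorQueryBuilder, reduced to the three fields relevant to rewriting. */
    static final class KnnQuery {
        final float[] queryVector;
        final VectorBuilder builder;
        final Supplier<float[]> supplier;

        KnnQuery(float[] queryVector, VectorBuilder builder, Supplier<float[]> supplier) {
            this.queryVector = queryVector;
            this.builder = builder;
            this.supplier = supplier;
        }

        KnnQuery rewrite(RewriteContext ctx) {
            if (supplier != null) {
                float[] resolved = supplier.get();
                // Not resolved yet: stay as-is. Resolved: switch to a plain vector query.
                return resolved == null ? this : new KnnQuery(resolved, null, null);
            }
            if (builder != null) {
                AtomicReference<float[]> holder = new AtomicReference<>();
                ctx.registerAsyncAction(() -> builder.buildVector(holder::set));
                return new KnnQuery(null, builder, holder::get);
            }
            return this;
        }
    }

    /** Rewrites until the query stops changing, running registered actions between passes. */
    static KnnQuery rewriteAndFetch(KnnQuery query, RewriteContext ctx) {
        KnnQuery current = query;
        while (true) {
            KnnQuery rewritten = current.rewrite(ctx);
            if (ctx.hasAsyncActions()) {
                ctx.executeAsyncActions();
            } else if (rewritten == current) {
                return current;
            }
            current = rewritten;
        }
    }
}
```

This also suggests why the PR throws `IllegalStateException("missing a rewriteAndFetch?")` in `doWriteTo` and `doXContent`: a query that still holds a supplier has not finished rewriting, so serializing it would lose the pending vector.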
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
@@ -23,6 +24,8 @@
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.AbstractBuilderTestCase;
@@ -38,7 +41,9 @@
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.nullValue;

abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
private static final String VECTOR_FIELD = "vector";
@@ -248,4 +253,39 @@ private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery
}
}
}

public void testRewriteWithQueryVectorBuilder() throws Exception {
int dims = randomInt(1024);
float[] expectedArray = new float[dims];
for (int i = 0; i < dims; i++) {
expectedArray[i] = randomFloat();
}
KnnVectorQueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder(
"field",
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray),
5,
1f
);
knnVectorQueryBuilder.boost(randomFloat());
List<QueryBuilder> filters = new ArrayList<>();
int numFilters = randomIntBetween(1, 5);
for (int i = 0; i < numFilters; i++) {
String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)));
}
knnVectorQueryBuilder.addFilterQueries(filters);

QueryRewriteContext context = new QueryRewriteContext(null, null, null);
PlainActionFuture<QueryBuilder> knnFuture = new PlainActionFuture<>();
Rewriteable.rewriteAndFetch(knnVectorQueryBuilder, context, knnFuture);
KnnVectorQueryBuilder rewritten = (KnnVectorQueryBuilder) knnFuture.get();

assertThat(rewritten.getFieldName(), equalTo(knnVectorQueryBuilder.getFieldName()));
assertThat(rewritten.boost(), equalTo(knnVectorQueryBuilder.boost()));
assertThat(rewritten.queryVector().asFloatVector(), equalTo(expectedArray));
assertThat(rewritten.queryVectorBuilder(), nullValue());
assertThat(rewritten.getVectorSimilarity(), equalTo(1f));
assertThat(rewritten.filterQueries(), hasSize(numFilters));
assertThat(rewritten.filterQueries(), equalTo(filters));
}
}