Skip to content

Commit

Permalink
Add modelId and modelText to KnnVectorQueryBuilder (#106068)
Browse files Browse the repository at this point in the history
* Add modelId and modelText to KnnVectorQueryBuilder

Use QueryVectorBuilder within KnnVectorQueryBuilder to make it
possible to perform knn queries also when a query vector is not
immediately available. Supplying a text_embedding query_vector_builder
with model_text and model_id instead of the query_vector will result
in the generation of a query_vector by calling inference on the
specified model_id with the supplied model_text (during query
rewrite). This is consistent with the way query vectors are built
from model_id / model_text in KnnSearchBuilder (DFS phase).
  • Loading branch information
tteofili committed Mar 18, 2024
1 parent bc249ef commit 7bff3b3
Show file tree
Hide file tree
Showing 8 changed files with 680 additions and 140 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/106068.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 106068
summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder`
area: Search
type: enhancement
issues: []
11 changes: 10 additions & 1 deletion docs/reference/query-dsl/knn-query.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,19 @@ the top `size` results.
`query_vector`::
+
--
(Required, array of floats or string) Query vector. Must have the same number of dimensions
(Optional, array of floats or string) Query vector. Must have the same number of dimensions
as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector.
Either this or `query_vector_builder` must be provided.
--

`query_vector_builder`::
+
--
(Optional, object) Query vector builder.
include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector-builder]
--


`num_candidates`::
+
--
Expand Down
7 changes: 7 additions & 0 deletions docs/reference/rest-api/common-parms.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,13 @@ Query vector. Must have the same number of dimensions as the vector field you
are searching against. Must be either an array of floats or a hex-encoded byte vector.
end::knn-query-vector[]

tag::knn-query-vector-builder[]
A configuration object indicating how to build a query_vector before executing
the request. You must provide either a `query_vector_builder` or `query_vector`,
but not both. Refer to <<knn-semantic-search>> to learn more.
end::knn-query-vector-builder[]


tag::knn-similarity[]
The minimum similarity required for a document to be considered a match. The similarity
value calculated relates to the raw <<dense-vector-similarity, `similarity`>> used. Not the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ static TransportVersion def(int id) {
public static final TransportVersion AGGS_EXCLUDED_DELETED_DOCS = def(8_609_00_0);
public static final TransportVersion ESQL_SERIALIZE_BIG_ARRAY = def(8_610_00_0);
public static final TransportVersion AUTO_SHARDING_ROLLOVER_CONDITION = def(8_611_00_0);
public static final TransportVersion KNN_QUERY_VECTOR_BUILDER = def(8_612_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -39,15 +40,16 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* A query that performs kNN search using Lucene's {@link org.apache.lucene.search.KnnFloatVectorQuery} or
* {@link org.apache.lucene.search.KnnByteVectorQuery}.
*
*/
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
public static final String NAME = "knn";
Expand All @@ -59,11 +61,19 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
"knn",
args -> new KnnVectorQueryBuilder((String) args[0], (VectorData) args[1], (Integer) args[2], (Float) args[3])
args -> new KnnVectorQueryBuilder(
(String) args[0],
(VectorData) args[1],
(QueryVectorBuilder) args[4],
null,
(Integer) args[2],
(Float) args[3]
)
);

static {
Expand All @@ -76,6 +86,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
);
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
PARSER.declareNamedObject(
optionalConstructorArg(),
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
QUERY_VECTOR_BUILDER_FIELD
);
PARSER.declareFieldArray(
KnnVectorQueryBuilder::addFilterQueries,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
Expand All @@ -94,26 +109,59 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
private Integer numCands;
private final List<QueryBuilder> filterQueries = new ArrayList<>();
private final Float vectorSimilarity;
private final QueryVectorBuilder queryVectorBuilder;
private final Supplier<float[]> queryVectorSupplier;

public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
this(fieldName, VectorData.fromFloats(queryVector), numCands, vectorSimilarity);
this(fieldName, VectorData.fromFloats(queryVector), null, null, numCands, vectorSimilarity);
}

protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) {
this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity);
}

public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) {
this(fieldName, VectorData.fromBytes(queryVector), numCands, vectorSimilarity);
this(fieldName, VectorData.fromBytes(queryVector), null, null, numCands, vectorSimilarity);
}

public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) {
this(fieldName, queryVector, null, null, numCands, vectorSimilarity);
}

private KnnVectorQueryBuilder(
String fieldName,
VectorData queryVector,
QueryVectorBuilder queryVectorBuilder,
Supplier<float[]> queryVectorSupplier,
Integer numCands,
Float vectorSimilarity
) {
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (queryVector == null) {
throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided");
if (queryVector == null && queryVectorBuilder == null) {
throw new IllegalArgumentException(
format(
"either [%s] or [%s] must be provided",
QUERY_VECTOR_FIELD.getPreferredName(),
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
)
);
} else if (queryVector != null && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"only one of [%s] and [%s] must be provided",
QUERY_VECTOR_FIELD.getPreferredName(),
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
)
);
}
this.fieldName = fieldName;
this.queryVector = queryVector;
this.numCands = numCands;
this.vectorSimilarity = vectorSimilarity;
this.queryVectorBuilder = queryVectorBuilder;
this.queryVectorSupplier = queryVectorSupplier;
}

public KnnVectorQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -144,6 +192,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
} else {
this.vectorSimilarity = null;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) {
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
} else {
this.queryVectorBuilder = null;
}
this.queryVectorSupplier = null;
}

public String getFieldName() {
Expand All @@ -168,6 +222,11 @@ public List<QueryBuilder> filterQueries() {
return filterQueries;
}

@Nullable
public QueryVectorBuilder queryVectorBuilder() {
return queryVectorBuilder;
}

public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) {
Objects.requireNonNull(filterQuery);
this.filterQueries.add(filterQuery);
Expand All @@ -182,6 +241,9 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries)

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (queryVectorSupplier != null) {
throw new IllegalStateException("missing a rewriteAndFetch?");
}
out.writeString(fieldName);
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM)) {
out.writeOptionalVInt(numCands);
Expand Down Expand Up @@ -216,19 +278,41 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
out.writeOptionalFloat(vectorSimilarity);
}
if (out.getTransportVersion().before(TransportVersions.KNN_QUERY_VECTOR_BUILDER) && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"cannot serialize [%s] to older node of version [%s]",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
out.getTransportVersion()
)
);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) {
out.writeOptionalNamedWriteable(queryVectorBuilder);
}
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
if (queryVectorSupplier != null) {
throw new IllegalStateException("missing a rewriteAndFetch?");
}
builder.startObject(NAME);
builder.field(FIELD_FIELD.getPreferredName(), fieldName);
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
if (queryVector != null) {
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
}
if (numCands != null) {
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
}
if (vectorSimilarity != null) {
builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity);
}
if (queryVectorBuilder != null) {
builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
builder.endObject();
}
if (filterQueries.isEmpty() == false) {
builder.startArray(FILTER_FIELD.getPreferredName());
for (QueryBuilder filterQuery : filterQueries) {
Expand All @@ -247,6 +331,36 @@ public String getWriteableName() {

@Override
protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
if (queryVectorSupplier != null) {
if (queryVectorSupplier.get() == null) {
return this;
}
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost)
.queryName(queryName)
.addFilterQueries(filterQueries);
}
if (queryVectorBuilder != null) {
SetOnce<float[]> toSet = new SetOnce<>();
ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
toSet.set(v);
if (v == null) {
ll.onFailure(
new IllegalArgumentException(
format(
"[%s] with name [%s] returned null query_vector",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
queryVectorBuilder.getWriteableName()
)
)
);
return;
}
ll.onResponse(null);
})));
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost(
boost
).queryName(queryName).addFilterQueries(filterQueries);
}
if (ctx.convertToInnerHitsRewriteContext() != null) {
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName);
}
Expand All @@ -263,7 +377,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
rewrittenQueries.add(rewrittenQuery);
}
if (changed) {
return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).boost(boost)
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity)
.boost(boost)
.queryName(queryName)
.addFilterQueries(rewrittenQueries);
}
Expand Down Expand Up @@ -338,7 +453,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

@Override
protected int doHashCode() {
return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity);
return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
}

@Override
Expand All @@ -347,7 +462,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
&& Objects.equals(queryVector, other.queryVector)
&& Objects.equals(numCands, other.numCands)
&& Objects.equals(filterQueries, other.filterQueries)
&& Objects.equals(vectorSimilarity, other.vectorSimilarity);
&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
Expand All @@ -23,6 +24,8 @@
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.AbstractBuilderTestCase;
Expand All @@ -38,7 +41,9 @@
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.nullValue;

abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
private static final String VECTOR_FIELD = "vector";
Expand Down Expand Up @@ -248,4 +253,39 @@ private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery
}
}
}

public void testRewriteWithQueryVectorBuilder() throws Exception {
int dims = randomInt(1024);
float[] expectedArray = new float[dims];
for (int i = 0; i < dims; i++) {
expectedArray[i] = randomFloat();
}
KnnVectorQueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder(
"field",
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray),
5,
1f
);
knnVectorQueryBuilder.boost(randomFloat());
List<QueryBuilder> filters = new ArrayList<>();
int numFilters = randomIntBetween(1, 5);
for (int i = 0; i < numFilters; i++) {
String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)));
}
knnVectorQueryBuilder.addFilterQueries(filters);

QueryRewriteContext context = new QueryRewriteContext(null, null, null);
PlainActionFuture<QueryBuilder> knnFuture = new PlainActionFuture<>();
Rewriteable.rewriteAndFetch(knnVectorQueryBuilder, context, knnFuture);
KnnVectorQueryBuilder rewritten = (KnnVectorQueryBuilder) knnFuture.get();

assertThat(rewritten.getFieldName(), equalTo(knnVectorQueryBuilder.getFieldName()));
assertThat(rewritten.boost(), equalTo(knnVectorQueryBuilder.boost()));
assertThat(rewritten.queryVector().asFloatVector(), equalTo(expectedArray));
assertThat(rewritten.queryVectorBuilder(), nullValue());
assertThat(rewritten.getVectorSimilarity(), equalTo(1f));
assertThat(rewritten.filterQueries(), hasSize(numFilters));
assertThat(rewritten.filterQueries(), equalTo(filters));
}
}

0 comments on commit 7bff3b3

Please sign in to comment.