Skip to content

Commit

Permalink
Add new query_vector_builder option to knn search clause (#93331)
Browse files Browse the repository at this point in the history
This adds a new option to the knn search clause called query_vector_builder. This is a pluggable configuration that allows the query_vector created or retrieved.
  • Loading branch information
benwtrent committed Feb 1, 2023
1 parent 93ecc4d commit 7f9f3bc
Show file tree
Hide file tree
Showing 10 changed files with 675 additions and 11 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/93331.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 93331
summary: Add new `query_vector_builder` option to knn search clause
area: Search
type: enhancement
issues: []
8 changes: 7 additions & 1 deletion docs/reference/search/search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,14 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k]
include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
`query_vector`::
(Required, array of floats)
(Optional, array of floats)
include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector]
`query_vector_builder`::
(Optional, object)
A configuration object indicating how to build a query_vector before executing the request. You must provide
a `query_vector_builder` or `query_vector`, but not both.
====

[[search-api-min-score]]
Expand Down
42 changes: 42 additions & 0 deletions server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.Suggester;
import org.elasticsearch.search.suggest.SuggestionBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.xcontent.ContextParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContent;
Expand Down Expand Up @@ -73,6 +74,14 @@ default List<SignificanceHeuristicSpec<?>> getSignificanceHeuristics() {
return emptyList();
}

/**
* The new {@link QueryVectorBuilder}s defined by this plugin. {@linkplain QueryVectorBuilder}s can be used within a kNN
* search to build the query vector instead of having the user provide the vector directly
*/
default List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
return emptyList();
}

/**
* The new {@link FetchSubPhase}s defined by this plugin.
*/
Expand Down Expand Up @@ -592,4 +601,37 @@ public Map<String, Highlighter> getHighlighters() {
return highlighters;
}
}

/**
* Specification of custom {@link QueryVectorBuilder}.
*/
class QueryVectorBuilderSpec<T extends QueryVectorBuilder> extends SearchExtensionSpec<T, BiFunction<XContentParser, Void, T>> {
/**
* Specification of custom {@link QueryVectorBuilder}.
*
* @param name holds the names by which this query vector builder might be parsed.
* The {@link ParseField#getPreferredName()} is special as it
* is the name by under which the reader is registered. So it is the name that the builder should use as its
* {@link NamedWriteable#getWriteableName()} too.
* @param reader the reader registered for this query vector builder. Typically a reference to a constructor that takes a
* {@link StreamInput}
* @param parser the parser the reads the query vector builder from xcontent
*/
public QueryVectorBuilderSpec(ParseField name, Writeable.Reader<T> reader, BiFunction<XContentParser, Void, T> parser) {
super(name, reader, parser);
}

/**
* Specification of custom {@link QueryVectorBuilder}.
*
* @param name the name by which this query vector builder might be parsed or deserialized.
* Make sure that the query builder returns this name for {@link NamedWriteable#getWriteableName()}.
* @param reader the reader registered for this query vector builder. Typically a reference to a constructor that takes a
* {@link StreamInput}
* @param parser the parser the reads the query vector builder from xcontent
*/
public QueryVectorBuilderSpec(String name, Writeable.Reader<T> reader, BiFunction<XContentParser, Void, T> parser) {
super(name, reader, parser);
}
}
}
14 changes: 14 additions & 0 deletions server/src/main/java/org/elasticsearch/search/SearchModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import org.elasticsearch.plugins.SearchPlugin.FetchPhaseConstructionContext;
import org.elasticsearch.plugins.SearchPlugin.PipelineAggregationSpec;
import org.elasticsearch.plugins.SearchPlugin.QuerySpec;
import org.elasticsearch.plugins.SearchPlugin.QueryVectorBuilderSpec;
import org.elasticsearch.plugins.SearchPlugin.RescorerSpec;
import org.elasticsearch.plugins.SearchPlugin.ScoreFunctionSpec;
import org.elasticsearch.plugins.SearchPlugin.SearchExtSpec;
Expand Down Expand Up @@ -244,6 +245,7 @@
import org.elasticsearch.search.suggest.term.TermSuggestionBuilder;
import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
Expand Down Expand Up @@ -305,6 +307,7 @@ public SearchModule(Settings settings, List<SearchPlugin> plugins) {
registerSorts();
registerValueFormats();
registerSignificanceHeuristics(plugins);
registerQueryVectorBuilders(plugins);
this.valuesSourceRegistry = registerAggregations(plugins);
registerPipelineAggregations(plugins);
registerFetchSubPhases(plugins);
Expand Down Expand Up @@ -980,6 +983,17 @@ private <T extends SignificanceHeuristic> void registerSignificanceHeuristic(Sig
);
}

private void registerQueryVectorBuilders(List<SearchPlugin> plugins) {
registerFromPlugin(plugins, SearchPlugin::getQueryVectorBuilders, this::registerQueryVectorBuilder);
}

private <T extends QueryVectorBuilder> void registerQueryVectorBuilder(QueryVectorBuilderSpec<?> spec) {
namedXContents.add(new NamedXContentRegistry.Entry(QueryVectorBuilder.class, spec.getName(), p -> spec.getParser().apply(p, null)));
namedWriteables.add(
new NamedWriteableRegistry.Entry(QueryVectorBuilder.class, spec.getName().getPreferredName(), spec.getReader())
);
}

private void registerFetchSubPhases(List<SearchPlugin> plugins) {
registerFetchSubPhase(new ExplainPhase());
registerFetchSubPhase(new StoredFieldsPhase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -27,8 +30,11 @@
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* Defines a kNN search to run in the search request.
Expand All @@ -39,25 +45,36 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;

private static final ConstructingObjectParser<KnnSearchBuilder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
@SuppressWarnings("unchecked")
// TODO optimize parsing for when BYTE values are provided
List<Float> vector = (List<Float>) args[1];
float[] vectorArray = new float[vector.size()];
for (int i = 0; i < vector.size(); i++) {
vectorArray[i] = vector.get(i);
final float[] vectorArray;
if (vector != null) {
vectorArray = new float[vector.size()];
for (int i = 0; i < vector.size(); i++) {
vectorArray[i] = vector.get(i);
}
} else {
vectorArray = null;
}
return new KnnSearchBuilder((String) args[0], vectorArray, (int) args[2], (int) args[3]);
return new KnnSearchBuilder((String) args[0], vectorArray, (QueryVectorBuilder) args[4], (int) args[2], (int) args[3]);
});

static {
PARSER.declareString(constructorArg(), FIELD_FIELD);
PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD);
PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD);
PARSER.declareInt(constructorArg(), K_FIELD);
PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD);
PARSER.declareNamedObject(
optionalConstructorArg(),
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
QUERY_VECTOR_BUILDER_FIELD
);
PARSER.declareFieldArray(
KnnSearchBuilder::addFilterQueries,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
Expand All @@ -73,6 +90,8 @@ public static KnnSearchBuilder fromXContent(XContentParser parser) throws IOExce

final String field;
final float[] queryVector;
final QueryVectorBuilder queryVectorBuilder;
private final Supplier<float[]> querySupplier;
final int k;
final int numCands;
final List<QueryBuilder> filterQueries;
Expand All @@ -87,6 +106,27 @@ public static KnnSearchBuilder fromXContent(XContentParser parser) throws IOExce
* @param numCands the number of nearest neighbor candidates to consider per shard
*/
public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands) {
this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands);
}

/**
* Defines a kNN search where the query vector will be provided by the queryVectorBuilder
* @param field the name of the vector field to search against
* @param queryVectorBuilder the query vector builder
* @param k the final number of nearest neighbors to return as top hits
* @param numCands the number of nearest neighbor candidates to consider per shard
*/
public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
this(
field,
null,
Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())),
k,
numCands
);
}

private KnnSearchBuilder(String field, float[] queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
if (k < 1) {
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
}
Expand All @@ -98,11 +138,41 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands)
if (numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (queryVector == null && queryVectorBuilder == null) {
throw new IllegalArgumentException(
format(
"either [%s] or [%s] must be provided",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
QUERY_VECTOR_FIELD.getPreferredName()
)
);
}
if (queryVector != null && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"cannot provide both [%s] and [%s]",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
QUERY_VECTOR_FIELD.getPreferredName()
)
);
}
this.field = field;
this.queryVector = queryVector;
this.queryVector = queryVector == null ? new float[0] : queryVector;
this.queryVectorBuilder = queryVectorBuilder;
this.k = k;
this.numCands = numCands;
this.filterQueries = new ArrayList<>();
this.querySupplier = null;
}

private KnnSearchBuilder(String field, Supplier<float[]> querySupplier, int k, int numCands, List<QueryBuilder> filterQueries) {
this.field = field;
this.queryVector = new float[0];
this.queryVectorBuilder = null;
this.k = k;
this.numCands = numCands;
this.filterQueries = filterQueries;
this.querySupplier = querySupplier;
}

public KnnSearchBuilder(StreamInput in) throws IOException {
Expand All @@ -112,6 +182,12 @@ public KnnSearchBuilder(StreamInput in) throws IOException {
this.queryVector = in.readFloatArray();
this.filterQueries = in.readNamedWriteableList(QueryBuilder.class);
this.boost = in.readFloat();
if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
} else {
this.queryVectorBuilder = null;
}
this.querySupplier = null;
}

public int k() {
Expand Down Expand Up @@ -140,6 +216,32 @@ public KnnSearchBuilder boost(float boost) {

@Override
public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
if (querySupplier != null) {
if (querySupplier.get() == null) {
return this;
}
return new KnnSearchBuilder(field, querySupplier.get(), k, numCands).boost(boost).addFilterQueries(filterQueries);
}
if (queryVectorBuilder != null) {
SetOnce<float[]> toSet = new SetOnce<>();
ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, ActionListener.wrap(v -> {
toSet.set(v);
if (v == null) {
l.onFailure(
new IllegalArgumentException(
format(
"[%s] with name [%s] returned null query_vector",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
queryVectorBuilder.getWriteableName()
)
)
);
return;
}
l.onResponse(null);
}, l::onFailure)));
return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries).boost(boost);
}
boolean changed = false;
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
for (QueryBuilder query : filterQueries) {
Expand All @@ -156,6 +258,9 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
}

public KnnVectorQueryBuilder toQueryBuilder() {
if (queryVectorBuilder != null) {
throw new IllegalArgumentException("missing rewrite");
}
return new KnnVectorQueryBuilder(field, queryVector, numCands).boost(boost).addFilterQueries(filterQueries);
}

Expand All @@ -168,21 +273,38 @@ public boolean equals(Object o) {
&& numCands == that.numCands
&& Objects.equals(field, that.field)
&& Arrays.equals(queryVector, that.queryVector)
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
&& Objects.equals(querySupplier, that.querySupplier)
&& Objects.equals(filterQueries, that.filterQueries)
&& boost == that.boost;
}

@Override
public int hashCode() {
return Objects.hash(field, k, numCands, Arrays.hashCode(queryVector), Objects.hashCode(filterQueries), boost);
return Objects.hash(
field,
k,
numCands,
querySupplier,
queryVectorBuilder,
Arrays.hashCode(queryVector),
Objects.hashCode(filterQueries),
boost
);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(FIELD_FIELD.getPreferredName(), field)
.field(K_FIELD.getPreferredName(), k)
.field(NUM_CANDS_FIELD.getPreferredName(), numCands)
.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
if (queryVectorBuilder != null) {
builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
builder.endObject();
} else {
builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
}

if (filterQueries.isEmpty() == false) {
builder.startArray(FILTER_FIELD.getPreferredName());
Expand All @@ -201,11 +323,26 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@Override
public void writeTo(StreamOutput out) throws IOException {
if (querySupplier != null) {
throw new IllegalStateException("missing a rewriteAndFetch?");
}
out.writeString(field);
out.writeVInt(k);
out.writeVInt(numCands);
out.writeFloatArray(queryVector);
out.writeNamedWriteableList(filterQueries);
out.writeFloat(boost);
if (out.getTransportVersion().before(TransportVersion.V_8_7_0) && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"cannot serialize [%s] to older node of version [%s]",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
out.getTransportVersion()
)
);
}
if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
out.writeOptionalNamedWriteable(queryVectorBuilder);
}
}
}

0 comments on commit 7f9f3bc

Please sign in to comment.