-
Notifications
You must be signed in to change notification settings - Fork 24.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add modelId and modelText to KnnVectorQueryBuilder #106068
Changes from 19 commits
b53f756
d81c4e1
32e701e
1eb1286
98e9b21
b58d1d2
dca4e99
840ff1c
2e35f29
1f9ac72
f2c998e
b164d25
905bc31
d359256
9fa135c
1e58587
301705d
303431e
7f09dc1
0f9d59f
28e2136
3c22e9c
ce232be
f359b89
3013a1b
45b5a00
0de580a
9193cf7
8f97ea1
73f300d
983c7e9
f682fb7
61671b6
50e882d
822a32d
166e3c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 106068 | ||
summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder` | ||
area: Search | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -600,6 +600,13 @@ Query vector. Must have the same number of dimensions as the vector field you | |
are searching against. | ||
end::knn-query-vector[] | ||
|
||
tag::knn-query-vector-builder[] | ||
A configuration object indicating how to build a `query_vector` before executing | ||
the request. You must provide either a `query_vector_builder` or `query_vector`, | ||
but not both. Refer to <<knn-semantic-search>> to learn more. | ||
end::knn-query-vector-builder[] | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you need to update There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can I use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should be able to use |
||
tag::knn-similarity[] | ||
The minimum similarity required for a document to be considered a match. The similarity | ||
value calculated relates to the raw <<dense-vector-similarity, `similarity`>> used. Not the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.join.BitSetProducer; | ||
import org.apache.lucene.search.join.ToChildBlockJoinQuery; | ||
import org.apache.lucene.util.SetOnce; | ||
import org.elasticsearch.TransportVersion; | ||
import org.elasticsearch.TransportVersions; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
|
@@ -40,15 +41,16 @@ | |
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
import static org.elasticsearch.common.Strings.format; | ||
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
|
||
/** | ||
* A query that performs kNN search using Lucene's {@link org.apache.lucene.search.KnnFloatVectorQuery} or | ||
* {@link org.apache.lucene.search.KnnByteVectorQuery}. | ||
* | ||
*/ | ||
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> { | ||
public static final String NAME = "knn"; | ||
|
@@ -60,6 +62,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu | |
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); | ||
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity"); | ||
public static final ParseField FILTER_FIELD = new ParseField("filter"); | ||
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); | ||
|
||
@SuppressWarnings("unchecked") | ||
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> { | ||
|
@@ -73,14 +76,23 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu | |
} else { | ||
vectorArray = null; | ||
} | ||
return new KnnVectorQueryBuilder((String) args[0], vectorArray, (Integer) args[2], (Float) args[3]); | ||
String fn = (String) args[0]; | ||
Integer numCands = (Integer) args[2]; | ||
Float vs = (Float) args[3]; | ||
QueryVectorBuilder qvb = (QueryVectorBuilder) args[4]; | ||
return new KnnVectorQueryBuilder(fn, vectorArray, qvb, null, numCands, vs); | ||
}); | ||
|
||
static { | ||
PARSER.declareString(constructorArg(), FIELD_FIELD); | ||
PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD); | ||
PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD); | ||
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); | ||
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD); | ||
PARSER.declareNamedObject( | ||
optionalConstructorArg(), | ||
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), | ||
QUERY_VECTOR_BUILDER_FIELD | ||
); | ||
PARSER.declareFieldArray( | ||
KnnVectorQueryBuilder::addFilterQueries, | ||
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), | ||
|
@@ -99,18 +111,51 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { | |
private Integer numCands; | ||
private final List<QueryBuilder> filterQueries = new ArrayList<>(); | ||
private final Float vectorSimilarity; | ||
private final QueryVectorBuilder queryVectorBuilder; | ||
private final Supplier<float[]> queryVectorSupplier; | ||
|
||
public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, queryVector, null, null, numCands, vectorSimilarity); | ||
} | ||
|
||
protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity); | ||
} | ||
|
||
private KnnVectorQueryBuilder( | ||
String fieldName, | ||
float[] queryVector, | ||
QueryVectorBuilder queryVectorBuilder, | ||
Supplier<float[]> queryVectorSupplier, | ||
Integer numCands, | ||
Float vectorSimilarity | ||
) { | ||
if (numCands != null && numCands > NUM_CANDS_LIMIT) { | ||
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); | ||
} | ||
if (queryVector == null) { | ||
throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided"); | ||
if (queryVector == null && queryVectorBuilder == null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"either [%s] or [%s] must be provided", | ||
QUERY_VECTOR_FIELD.getPreferredName(), | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName() | ||
) | ||
); | ||
} else if (queryVector != null && queryVectorBuilder != null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"only one of [%s] and [%s] must be provided", | ||
QUERY_VECTOR_FIELD.getPreferredName(), | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName() | ||
) | ||
); | ||
tteofili marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
this.fieldName = fieldName; | ||
this.queryVector = queryVector; | ||
this.numCands = numCands; | ||
this.vectorSimilarity = vectorSimilarity; | ||
this.queryVectorBuilder = queryVectorBuilder; | ||
this.queryVectorSupplier = queryVectorSupplier; | ||
} | ||
|
||
public KnnVectorQueryBuilder(StreamInput in) throws IOException { | ||
|
@@ -136,6 +181,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { | |
} else { | ||
this.vectorSimilarity = null; | ||
} | ||
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) { | ||
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class); | ||
} else { | ||
this.queryVectorBuilder = null; | ||
} | ||
this.queryVectorSupplier = null; | ||
} | ||
|
||
public String getFieldName() { | ||
|
@@ -160,6 +211,11 @@ public List<QueryBuilder> filterQueries() { | |
return filterQueries; | ||
} | ||
|
||
@Nullable | ||
public QueryVectorBuilder queryVectorBuilder() { | ||
return queryVectorBuilder; | ||
} | ||
|
||
public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { | ||
Objects.requireNonNull(filterQuery); | ||
this.filterQueries.add(filterQuery); | ||
|
@@ -204,19 +260,41 @@ protected void doWriteTo(StreamOutput out) throws IOException { | |
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { | ||
out.writeOptionalFloat(vectorSimilarity); | ||
} | ||
if (out.getTransportVersion().before(TransportVersions.KNN_QUERY_VECTOR_BUILDER) && queryVectorBuilder != null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"cannot serialize [%s] to older node of version [%s]", | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), | ||
out.getTransportVersion() | ||
) | ||
); | ||
} | ||
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) { | ||
out.writeOptionalNamedWriteable(queryVectorBuilder); | ||
} | ||
benwtrent marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
@Override | ||
protected void doXContent(XContentBuilder builder, Params params) throws IOException { | ||
if (queryVectorSupplier != null) { | ||
throw new IllegalStateException("missing a rewriteAndFetch?"); | ||
Comment on lines
296
to
+298
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be done in wire serialization (StreamOutput) as well. |
||
} | ||
builder.startObject(NAME); | ||
builder.field(FIELD_FIELD.getPreferredName(), fieldName); | ||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); | ||
if (queryVector != null) { | ||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); | ||
} | ||
if (numCands != null) { | ||
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); | ||
} | ||
if (vectorSimilarity != null) { | ||
builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity); | ||
} | ||
if (queryVectorBuilder != null) { | ||
benwtrent marked this conversation as resolved.
Show resolved
Hide resolved
|
||
builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName()); | ||
builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder); | ||
builder.endObject(); | ||
} | ||
if (filterQueries.isEmpty() == false) { | ||
builder.startArray(FILTER_FIELD.getPreferredName()); | ||
for (QueryBuilder filterQuery : filterQueries) { | ||
|
@@ -235,6 +313,36 @@ public String getWriteableName() { | |
|
||
@Override | ||
protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { | ||
if (queryVectorSupplier != null) { | ||
if (queryVectorSupplier.get() == null) { | ||
return this; | ||
} | ||
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost) | ||
.queryName(queryName) | ||
.addFilterQueries(filterQueries); | ||
} | ||
if (queryVectorBuilder != null) { | ||
SetOnce<float[]> toSet = new SetOnce<>(); | ||
ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { | ||
toSet.set(v); | ||
if (v == null) { | ||
ll.onFailure( | ||
new IllegalArgumentException( | ||
format( | ||
"[%s] with name [%s] returned null query_vector", | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), | ||
queryVectorBuilder.getWriteableName() | ||
) | ||
) | ||
); | ||
return; | ||
} | ||
ll.onResponse(null); | ||
}))); | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost( | ||
boost | ||
).queryName(queryName).addFilterQueries(filterQueries); | ||
} | ||
if (ctx.convertToInnerHitsRewriteContext() != null) { | ||
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName); | ||
} | ||
|
@@ -251,7 +359,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { | |
rewrittenQueries.add(rewrittenQuery); | ||
} | ||
if (changed) { | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).boost(boost) | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity) | ||
tteofili marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.boost(boost) | ||
.queryName(queryName) | ||
.addFilterQueries(rewrittenQueries); | ||
} | ||
|
@@ -326,7 +435,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { | |
|
||
@Override | ||
protected int doHashCode() { | ||
return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands, filterQueries, vectorSimilarity); | ||
return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder); | ||
} | ||
|
||
@Override | ||
|
@@ -335,7 +444,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { | |
&& Arrays.equals(queryVector, other.queryVector) | ||
&& Objects.equals(numCands, other.numCands) | ||
&& Objects.equals(filterQueries, other.filterQueries) | ||
&& Objects.equals(vectorSimilarity, other.vectorSimilarity); | ||
&& Objects.equals(vectorSimilarity, other.vectorSimilarity) | ||
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder); | ||
} | ||
|
||
@Override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.