Add modelId and modelText to KnnVectorQueryBuilder #106068
Merged
Changes from 25 commits
36 commits
b53f756 Add modelId and modelText to KnnVectorQueryBuilder (tteofili)
d81c4e1 Added missing files (tteofili)
32e701e Fix query in yaml test (tteofili)
1eb1286 Update docs/changelog/106068.yaml (tteofili)
98e9b21 Fixed docs, bumped TransportVersion (tteofili)
b58d1d2 Merge branch 'knn_dsl_modeltext_modelid' of github.com:tteofili/elast… (tteofili)
dca4e99 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
840ff1c Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
2e35f29 correct assertion for yml test (tteofili)
1f9ac72 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
f2c998e Fix ctor visibility (tteofili)
b164d25 Improved input validation (tteofili)
905bc31 validation should not check for query supplier (tteofili)
d359256 dropped unused ctor (tteofili)
9fa135c Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
1e58587 formatting fix (tteofili)
301705d revert deprecated api doc change, use include for qvb (tteofili)
303431e validation in KNNVQB#doXContent (tteofili)
7f09dc1 IT refactoring to avoid duplicate tests (tteofili)
0f9d59f fixed docs error (tteofili)
28e2136 skip yml test for versions <= 8.14 (tteofili)
3c22e9c Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
ce232be serialization validation for supplier (tteofili)
f359b89 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
3013a1b rewrite test expected vector should be float[] (tteofili)
45b5a00 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
0de580a nested vectors with inner hits test (tteofili)
9193cf7 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
8f97ea1 add required privileges for indexing nested vectors (tteofili)
73f300d Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
983c7e9 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
f682fb7 add required privileges for indexing nested vectors (tteofili)
61671b6 Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
50e882d index refresh needs token too (tteofili)
822a32d Merge branch 'main' of github.com:elastic/elasticsearch into knn_dsl_… (tteofili)
166e3c5 added skip version for nested vectors test (tteofili)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 106068 | ||
summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder` | ||
area: Search | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.join.BitSetProducer; | ||
import org.apache.lucene.search.join.ToChildBlockJoinQuery; | ||
import org.apache.lucene.util.SetOnce; | ||
import org.elasticsearch.TransportVersion; | ||
import org.elasticsearch.TransportVersions; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
|
@@ -39,15 +40,16 @@ | |
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
import static org.elasticsearch.common.Strings.format; | ||
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
|
||
/** | ||
* A query that performs kNN search using Lucene's {@link org.apache.lucene.search.KnnFloatVectorQuery} or | ||
* {@link org.apache.lucene.search.KnnByteVectorQuery}. | ||
* | ||
*/ | ||
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> { | ||
public static final String NAME = "knn"; | ||
|
@@ -59,11 +61,19 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu | |
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); | ||
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity"); | ||
public static final ParseField FILTER_FIELD = new ParseField("filter"); | ||
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); | ||
|
||
@SuppressWarnings("unchecked") | ||
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>( | ||
"knn", | ||
args -> new KnnVectorQueryBuilder((String) args[0], (VectorData) args[1], (Integer) args[2], (Float) args[3]) | ||
args -> new KnnVectorQueryBuilder( | ||
(String) args[0], | ||
(VectorData) args[1], | ||
(QueryVectorBuilder) args[4], | ||
null, | ||
(Integer) args[2], | ||
(Float) args[3] | ||
) | ||
); | ||
|
||
static { | ||
|
@@ -76,6 +86,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu | |
); | ||
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); | ||
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD); | ||
PARSER.declareNamedObject( | ||
optionalConstructorArg(), | ||
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), | ||
QUERY_VECTOR_BUILDER_FIELD | ||
); | ||
PARSER.declareFieldArray( | ||
KnnVectorQueryBuilder::addFilterQueries, | ||
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), | ||
|
@@ -94,26 +109,59 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { | |
private Integer numCands; | ||
private final List<QueryBuilder> filterQueries = new ArrayList<>(); | ||
private final Float vectorSimilarity; | ||
private final QueryVectorBuilder queryVectorBuilder; | ||
private final Supplier<float[]> queryVectorSupplier; | ||
|
||
public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, VectorData.fromFloats(queryVector), numCands, vectorSimilarity); | ||
this(fieldName, VectorData.fromFloats(queryVector), null, null, numCands, vectorSimilarity); | ||
} | ||
|
||
protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity); | ||
} | ||
|
||
public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, VectorData.fromBytes(queryVector), numCands, vectorSimilarity); | ||
this(fieldName, VectorData.fromBytes(queryVector), null, null, numCands, vectorSimilarity); | ||
} | ||
|
||
public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) { | ||
this(fieldName, queryVector, null, null, numCands, vectorSimilarity); | ||
} | ||
|
||
private KnnVectorQueryBuilder( | ||
String fieldName, | ||
VectorData queryVector, | ||
QueryVectorBuilder queryVectorBuilder, | ||
Supplier<float[]> queryVectorSupplier, | ||
Integer numCands, | ||
Float vectorSimilarity | ||
) { | ||
if (numCands != null && numCands > NUM_CANDS_LIMIT) { | ||
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); | ||
} | ||
if (queryVector == null) { | ||
throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided"); | ||
if (queryVector == null && queryVectorBuilder == null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"either [%s] or [%s] must be provided", | ||
QUERY_VECTOR_FIELD.getPreferredName(), | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName() | ||
) | ||
); | ||
} else if (queryVector != null && queryVectorBuilder != null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"only one of [%s] and [%s] must be provided", | ||
QUERY_VECTOR_FIELD.getPreferredName(), | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName() | ||
) | ||
); | ||
} | ||
this.fieldName = fieldName; | ||
this.queryVector = queryVector; | ||
this.numCands = numCands; | ||
this.vectorSimilarity = vectorSimilarity; | ||
this.queryVectorBuilder = queryVectorBuilder; | ||
this.queryVectorSupplier = queryVectorSupplier; | ||
} | ||
|
||
public KnnVectorQueryBuilder(StreamInput in) throws IOException { | ||
|
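The constructor check in the hunk above accepts exactly one of `query_vector` and `query_vector_builder`. A minimal standalone sketch of that rule follows; the class `ExactlyOneOfSketch` and the helper `requireExactlyOne` are hypothetical names for illustration and are not part of Elasticsearch. Only the two error message templates are taken from the diff.

```java
import java.util.Locale;

/** Hypothetical stand-in for the constructor validation; not an Elasticsearch class. */
final class ExactlyOneOfSketch {

    /** Throws unless exactly one of the two values is non-null. */
    static void requireExactlyOne(String firstName, Object first, String secondName, Object second) {
        if (first == null && second == null) {
            // Neither input was supplied.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "either [%s] or [%s] must be provided", firstName, secondName)
            );
        } else if (first != null && second != null) {
            // Both inputs were supplied, which is ambiguous.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "only one of [%s] and [%s] must be provided", firstName, secondName)
            );
        }
    }

    public static void main(String[] args) {
        // Accepted: only the explicit vector is present.
        requireExactlyOne("query_vector", new float[] { 1f, 2f }, "query_vector_builder", null);
        try {
            // Rejected: neither value is present.
            requireExactlyOne("query_vector", null, "query_vector_builder", null);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Note that the supplier argument is deliberately left out of the check, in line with the commit "validation should not check for query supplier".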
@@ -144,6 +192,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { | |
} else { | ||
this.vectorSimilarity = null; | ||
} | ||
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) { | ||
this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class); | ||
} else { | ||
this.queryVectorBuilder = null; | ||
} | ||
this.queryVectorSupplier = null; | ||
} | ||
|
||
public String getFieldName() { | ||
|
@@ -168,6 +222,11 @@ public List<QueryBuilder> filterQueries() { | |
return filterQueries; | ||
} | ||
|
||
@Nullable | ||
public QueryVectorBuilder queryVectorBuilder() { | ||
return queryVectorBuilder; | ||
} | ||
|
||
public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { | ||
Objects.requireNonNull(filterQuery); | ||
this.filterQueries.add(filterQuery); | ||
|
@@ -182,6 +241,9 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries) | |
|
||
@Override | ||
protected void doWriteTo(StreamOutput out) throws IOException { | ||
if (queryVectorSupplier != null) { | ||
throw new IllegalStateException("missing a rewriteAndFetch?"); | ||
} | ||
out.writeString(fieldName); | ||
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM)) { | ||
out.writeOptionalVInt(numCands); | ||
|
@@ -216,19 +278,41 @@ protected void doWriteTo(StreamOutput out) throws IOException { | |
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { | ||
out.writeOptionalFloat(vectorSimilarity); | ||
} | ||
if (out.getTransportVersion().before(TransportVersions.KNN_QUERY_VECTOR_BUILDER) && queryVectorBuilder != null) { | ||
throw new IllegalArgumentException( | ||
format( | ||
"cannot serialize [%s] to older node of version [%s]", | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), | ||
out.getTransportVersion() | ||
) | ||
); | ||
} | ||
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_VECTOR_BUILDER)) { | ||
out.writeOptionalNamedWriteable(queryVectorBuilder); | ||
} | ||
} | ||
|
||
@Override | ||
protected void doXContent(XContentBuilder builder, Params params) throws IOException { | ||
if (queryVectorSupplier != null) { | ||
throw new IllegalStateException("missing a rewriteAndFetch?"); | ||
Review comment on lines 296 to +298: This needs to be done in wire serialization (StreamOutput) as well.
} | ||
builder.startObject(NAME); | ||
builder.field(FIELD_FIELD.getPreferredName(), fieldName); | ||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); | ||
if (queryVector != null) { | ||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); | ||
} | ||
if (numCands != null) { | ||
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); | ||
} | ||
if (vectorSimilarity != null) { | ||
builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity); | ||
} | ||
if (queryVectorBuilder != null) { | ||
builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName()); | ||
builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder); | ||
builder.endObject(); | ||
} | ||
if (filterQueries.isEmpty() == false) { | ||
builder.startArray(FILTER_FIELD.getPreferredName()); | ||
for (QueryBuilder filterQuery : filterQueries) { | ||
|
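The wire changes above follow a version-gating pattern: the new optional field is written and read only when the peer's transport version is recent enough, and writing a non-null builder to an older peer fails fast. The sketch below illustrates that pattern with plain `java.io` streams. The class `VersionGatedWireSketch`, the integer version ids, and the use of a string in place of a `QueryVectorBuilder` are all assumptions made for illustration; it does not use `StreamInput`, `StreamOutput`, or `TransportVersions`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical illustration of version-gated serialization; not Elasticsearch code. */
final class VersionGatedWireSketch {

    /** Assumed version id at which the optional builder field was introduced. */
    static final int BUILDER_ADDED_VERSION = 2;

    static byte[] write(int peerVersion, String fieldName, String builderName) {
        // Refuse to silently drop the builder when talking to an older peer.
        if (peerVersion < BUILDER_ADDED_VERSION && builderName != null) {
            throw new IllegalArgumentException(
                "cannot serialize [query_vector_builder] to older node of version [" + peerVersion + "]"
            );
        }
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(fieldName);
            if (peerVersion >= BUILDER_ADDED_VERSION) {
                // Optional value: a presence flag followed by the payload.
                out.writeBoolean(builderName != null);
                if (builderName != null) {
                    out.writeUTF(builderName);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Returns {fieldName, builderName}; builderName is null for older peers or when absent. */
    static String[] read(int peerVersion, byte[] payload) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            String fieldName = in.readUTF();
            String builderName = null;
            if (peerVersion >= BUILDER_ADDED_VERSION && in.readBoolean()) {
                builderName = in.readUTF();
            }
            return new String[] { fieldName, builderName };
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        String[] decoded = read(2, write(2, "vector_field", "text_embedding"));
        System.out.println(decoded[0] + " / " + decoded[1]);
    }
}
```

Reader and writer must gate on the same version, otherwise the byte layout diverges between nodes during a rolling upgrade.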
@@ -247,6 +331,36 @@ public String getWriteableName() { | |
|
||
@Override | ||
protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { | ||
if (queryVectorSupplier != null) { | ||
if (queryVectorSupplier.get() == null) { | ||
return this; | ||
} | ||
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost) | ||
.queryName(queryName) | ||
.addFilterQueries(filterQueries); | ||
} | ||
if (queryVectorBuilder != null) { | ||
SetOnce<float[]> toSet = new SetOnce<>(); | ||
ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { | ||
toSet.set(v); | ||
if (v == null) { | ||
ll.onFailure( | ||
new IllegalArgumentException( | ||
format( | ||
"[%s] with name [%s] returned null query_vector", | ||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), | ||
queryVectorBuilder.getWriteableName() | ||
) | ||
) | ||
); | ||
return; | ||
} | ||
ll.onResponse(null); | ||
}))); | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost( | ||
boost | ||
).queryName(queryName).addFilterQueries(filterQueries); | ||
} | ||
if (ctx.convertToInnerHitsRewriteContext() != null) { | ||
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName); | ||
} | ||
|
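The `doRewrite` change above resolves the vector in two passes: the first pass registers an async action and returns a builder carrying a supplier, and a later pass swaps in the concrete vector once the supplier yields a value. The following is a simplified, single-threaded sketch of that flow. Every name in it (`TwoPhaseRewriteSketch`, `RewriteContext`, `KnnQuery`, `rewriteAndFetch`) is hypothetical, and an `AtomicReference` stands in for Lucene's `SetOnce`; it is not the Elasticsearch implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

/** Hypothetical model of a two-phase rewrite; not Elasticsearch code. */
final class TwoPhaseRewriteSketch {

    /** Queues deferred work instead of running it inline. */
    static final class RewriteContext {
        private final List<Runnable> pending = new ArrayList<>();

        void registerAsyncAction(Runnable action) {
            pending.add(action);
        }

        boolean hasAsyncActions() {
            return pending.isEmpty() == false;
        }

        void executeAsyncActions() {
            List<Runnable> toRun = new ArrayList<>(pending);
            pending.clear();
            toRun.forEach(Runnable::run);
        }
    }

    static final class KnnQuery {
        final float[] vector;                  // concrete vector, once known
        final Supplier<float[]> vectorBuilder; // stand-in for a query vector builder
        final Supplier<float[]> vectorSupplier; // yields null until the async action ran

        KnnQuery(float[] vector, Supplier<float[]> vectorBuilder, Supplier<float[]> vectorSupplier) {
            this.vector = vector;
            this.vectorBuilder = vectorBuilder;
            this.vectorSupplier = vectorSupplier;
        }

        KnnQuery rewrite(RewriteContext ctx) {
            if (vectorSupplier != null) {
                float[] resolved = vectorSupplier.get();
                // Not resolved yet: report "no change" so the caller runs pending actions.
                return resolved == null ? this : new KnnQuery(resolved, null, null);
            }
            if (vectorBuilder != null) {
                AtomicReference<float[]> slot = new AtomicReference<>();
                ctx.registerAsyncAction(() -> slot.set(vectorBuilder.get()));
                return new KnnQuery(null, vectorBuilder, slot::get);
            }
            return this;
        }
    }

    /** Rewrites until stable, executing queued actions whenever a pass makes no progress. */
    static KnnQuery rewriteAndFetch(KnnQuery query, RewriteContext ctx) {
        KnnQuery current = query;
        while (true) {
            KnnQuery next = current.rewrite(ctx);
            if (next != current) {
                current = next;
                continue;
            }
            if (ctx.hasAsyncActions() == false) {
                return current;
            }
            ctx.executeAsyncActions();
        }
    }

    public static void main(String[] args) {
        KnnQuery start = new KnnQuery(null, () -> new float[] { 0.1f, 0.2f }, null);
        KnnQuery done = rewriteAndFetch(start, new RewriteContext());
        System.out.println(done.vector.length);
    }
}
```

A query that still holds a supplier is an intermediate state only, which is why the diff throws `IllegalStateException("missing a rewriteAndFetch?")` when such a builder reaches `doWriteTo` or `doXContent`.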
@@ -263,7 +377,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { | |
rewrittenQueries.add(rewrittenQuery); | ||
} | ||
if (changed) { | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).boost(boost) | ||
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity) | ||
.boost(boost) | ||
.queryName(queryName) | ||
.addFilterQueries(rewrittenQueries); | ||
} | ||
|
@@ -338,7 +453,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { | |
|
||
@Override | ||
protected int doHashCode() { | ||
return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity); | ||
return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder); | ||
} | ||
|
||
@Override | ||
|
@@ -347,7 +462,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { | |
&& Objects.equals(queryVector, other.queryVector) | ||
&& Objects.equals(numCands, other.numCands) | ||
&& Objects.equals(filterQueries, other.filterQueries) | ||
&& Objects.equals(vectorSimilarity, other.vectorSimilarity); | ||
&& Objects.equals(vectorSimilarity, other.vectorSimilarity) | ||
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder); | ||
} | ||
|
||
@Override | ||
|
I think you need to update `docs/reference/query-dsl/knn-query.asciidoc` as well to include `knn-query-vector-builder`.
Can I use the `include` directive inside `docs/reference/query-dsl/knn-query.asciidoc`?
You should be able to use
include