From aea374b5339de710975db07a3df18e325b5eafb3 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 20 Nov 2025 17:23:33 +0200 Subject: [PATCH 1/2] Intercept filters to knn queries `knn` queries can have filter queries. Those filters may contain semantic queries (e.g. `knn`, `sparse_vector`, or `match` queries targetting `semantic_text` fields). We also need to intercept those in order to perform inference for their `query` field during the coordinator node rewrite. This commit achieves this with the following changes: - `SemanticKnnVectorQueryRewriteInterceptor` attempts to rewrite filter queries. This means we rely on a rewrite cycle to attempt to intercept every query down the tree of each filter query. - `InterceptedInferenceQueryBuilder` now has a `customCoordinatorNodeRewrite` method that subclasses can implement to implement additional rewriting logic needed, e.g. rewrite inner queries. - `InterceptedInferenceKnnVectorQueryBuilder` implements `customCoordinatorNodeRewrite` so that the filter queries are rewritten. This commit fixes the exceptions throws in #138410. However, searches that contain semantic text queries as filters to a semantic text knn query will return `0` hits due to another issue that is captured in #138184. Closes #138410 --- .../rewriter/QueryRewriteInterceptor.java | 5 +- .../search/vectors/KnnVectorQueryBuilder.java | 7 + ...AbstractKnnVectorQueryBuilderTestCase.java | 8 + ...rceptedInferenceKnnVectorQueryBuilder.java | 32 +++- ...InterceptedInferenceMatchQueryBuilder.java | 2 +- .../InterceptedInferenceQueryBuilder.java | 21 ++- ...ptedInferenceSparseVectorQueryBuilder.java | 2 +- ...anticKnnVectorQueryRewriteInterceptor.java | 24 ++- ...erceptedInferenceQueryBuilderTestCase.java | 9 +- ...edInferenceKnnVectorQueryBuilderTests.java | 154 +++++++++++++++++- ...ceptedInferenceMatchQueryBuilderTests.java | 5 +- ...nferenceSparseVectorQueryBuilderTests.java | 4 +- .../test/inference/47_semantic_text_knn.yml | 93 +++++++++++ 13 files changed, 341 insertions(+), 25 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/QueryRewriteInterceptor.java b/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/QueryRewriteInterceptor.java index 8f4fb2ce7491a..a2e8fb7a55c9f 100644 --- a/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/QueryRewriteInterceptor.java +++ b/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/QueryRewriteInterceptor.java @@ -12,6 +12,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; +import java.io.IOException; import java.util.Map; /** @@ -27,7 +28,7 @@ public interface QueryRewriteInterceptor { * @param queryBuilder the original {@link QueryBuilder} to potentially rewrite * @return the rewritten {@link QueryBuilder}, or the original instance if no rewrite was needed */ - QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder); + QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException; /** * Name of the query to be intercepted and rewritten. @@ -52,7 +53,7 @@ public String getQueryName() { } @Override - public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { + public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException { QueryRewriteInterceptor interceptor = interceptors.get(queryBuilder.getName()); if (interceptor != null) { return interceptor.interceptAndRewrite(context, queryBuilder); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index c85ffcea2c46b..809d111fee168 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -357,6 +357,13 @@ public KnnVectorQueryBuilder addFilterQueries(List filterQueries) return this; } + public KnnVectorQueryBuilder setFilterQueries(List filterQueries) { + Objects.requireNonNull(filterQueries); + this.filterQueries.clear(); + this.filterQueries.addAll(filterQueries); + return this; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { if (queryVectorSupplier != null) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 65644e0dc7f50..f19bf9556aa25 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -32,6 +32,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.TermQueryBuilder; @@ -565,4 +566,11 @@ public void testRewriteWithQueryVectorBuilder() throws Exception { assertThat(rewritten.filterQueries(), hasSize(numFilters)); assertThat(rewritten.filterQueries(), equalTo(filters)); } + + public void testSetFilterQueries() { + KnnVectorQueryBuilder knnQueryBuilder = doCreateTestQueryBuilder(); + List newFilters = randomList(5, () -> RandomQueryBuilder.createQuery(random())); + knnQueryBuilder.setFilterQueries(newFilters); + assertThat(knnQueryBuilder.filterQueries(), equalTo(newFilters)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index 808afeb6b3c33..2279f60f6520f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -32,7 +32,9 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; import java.util.Map; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -91,6 +93,34 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() { return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null; } + @Override + protected QueryBuilder customCoordinatorNodeRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + // knn query may contain filters that are also intercepted. + // We need to rewrite those here so that we can get inference results for them too. + QueryBuilder rewritten = rewriteFilterQueries(queryRewriteContext); + if (rewritten != this) { + return rewritten; + } + return super.customCoordinatorNodeRewrite(queryRewriteContext); + } + + private QueryBuilder rewriteFilterQueries(QueryRewriteContext queryRewriteContext) throws IOException { + boolean filtersChanged = false; + List rewrittenFilters = new ArrayList<>(originalQuery.filterQueries().size()); + for (QueryBuilder filter : originalQuery.filterQueries()) { + QueryBuilder rewrittenFilter = filter.rewrite(queryRewriteContext); + if (rewrittenFilter != filter) { + filtersChanged = true; + } + rewrittenFilters.add(rewrittenFilter); + } + if (filtersChanged) { + originalQuery.setFilterQueries(rewrittenFilters); + return copy(inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); + } + return this; + } + @Override protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) { if (originalQuery.queryVector() == null && originalQuery.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder == false) { @@ -119,7 +149,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) { } @Override - protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { + protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException { QueryBuilder rewritten = this; if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) { rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index 018fdca7fabdb..13ec526d3b4bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -65,7 +65,7 @@ protected String getQuery() { } @Override - protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { + protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException { QueryBuilder rewritten = this; if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) { rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 8267643108bcf..397bcacd8ca6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -152,7 +152,7 @@ protected InterceptedInferenceQueryBuilder( * @param queryRewriteContext The query rewrite context * @return The query builder rewritten to a backwards-compatible form */ - protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext); + protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException; /** * Generate a copy of {@code this}. @@ -209,6 +209,10 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() { */ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {} + protected QueryBuilder customCoordinatorNodeRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + return this; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { if (inferenceResultsMapSupplier != null) { @@ -304,7 +308,7 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex return queryFields(inferenceFieldsToQuery, nonInferenceFieldsToQuery, indexMetadataContext); } - private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { + private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) throws IOException { QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext); if (rewrittenBwC != this) { return rewrittenBwC; @@ -344,6 +348,18 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri ); } + QueryBuilder rewritten = customCoordinatorNodeRewrite(queryRewriteContext); + if (this != rewritten) { + return rewritten; + } + return coordinatorNodeRewrite(queryRewriteContext, inferenceIds, ccsRequest); + } + + private QueryBuilder coordinatorNodeRewrite( + QueryRewriteContext queryRewriteContext, + Set inferenceIds, + boolean ccsRequest + ) { if (inferenceResultsMapSupplier != null) { // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest)); @@ -376,7 +392,6 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri } else { rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest); } - return rewritten; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java index 48a9d3910b01e..7b80478ce57be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java @@ -106,7 +106,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) { } @Override - protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { + protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException { QueryBuilder rewritten = this; if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) { rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index 6943d812eb4f7..c692632fd2e25 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -9,19 +9,39 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import java.io.IOException; + public class SemanticKnnVectorQueryRewriteInterceptor implements QueryRewriteInterceptor { @Override - public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { + public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException { if (queryBuilder instanceof KnnVectorQueryBuilder knnVectorQueryBuilder) { - return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder); + return interceptKnnQuery(context, knnVectorQueryBuilder); } else { throw new IllegalStateException("Unexpected query builder type: " + queryBuilder.getClass()); } } + private static InterceptedInferenceKnnVectorQueryBuilder interceptKnnQuery( + QueryRewriteContext context, + KnnVectorQueryBuilder knnVectorQueryBuilder + ) throws IOException { + boolean changed = false; + for (QueryBuilder filter : knnVectorQueryBuilder.filterQueries()) { + QueryBuilder rewritten = filter.rewrite(context); + if (rewritten != filter) { + changed = true; + } + } + if (changed) { + knnVectorQueryBuilder.setFilterQueries(Rewriteable.rewrite(knnVectorQueryBuilder.filterQueries(), context)); + } + return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder); + } + @Override public String getQueryName() { return KnnVectorQueryBuilder.NAME; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java index 9ae5e3dcc591d..b804033383c07 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java @@ -65,7 +65,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import static org.elasticsearch.TransportVersions.V_8_15_0; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; @@ -335,7 +337,7 @@ protected abstract InterceptedInferenceQueryBuilder createInterceptedQueryBui Map inferenceResultsMap ); - protected abstract QueryRewriteInterceptor createQueryRewriteInterceptor(); + protected abstract List createQueryRewriteInterceptors(); protected abstract TransportVersion getMinimalSupportedVersion(); @@ -427,8 +429,9 @@ protected QueryRewriteContext createQueryRewriteContext( indexMetadata ); - QueryRewriteInterceptor interceptor = createQueryRewriteInterceptor(); - Map interceptorMap = Map.of(interceptor.getQueryName(), interceptor); + List interceptors = createQueryRewriteInterceptors(); + Map interceptorMap = interceptors.stream() + .collect(Collectors.toMap(QueryRewriteInterceptor::getQueryName, Function.identity())); return new QueryRewriteContext( null, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java index de8583132319b..d0fb06e9a048e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java @@ -9,8 +9,11 @@ import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.IndexFieldMapper; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -22,7 +25,11 @@ import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.TokenPruningConfigTests; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -33,6 +40,7 @@ import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; @@ -45,6 +53,7 @@ public class InterceptedInferenceKnnVectorQueryBuilderTests extends AbstractInte protected Collection getPlugins() { List plugins = new ArrayList<>(super.getPlugins()); plugins.add(new FakeMlPlugin()); + plugins.add(new XPackPlugin(Settings.EMPTY)); return plugins; } @@ -66,8 +75,12 @@ protected InterceptedInferenceQueryBuilder createIntercep } @Override - protected QueryRewriteInterceptor createQueryRewriteInterceptor() { - return new SemanticKnnVectorQueryRewriteInterceptor(); + protected List createQueryRewriteInterceptors() { + return List.of( + new SemanticKnnVectorQueryRewriteInterceptor(), + new SemanticMatchQueryRewriteInterceptor(), + new SemanticSparseVectorQueryRewriteInterceptor() + ); } @Override @@ -167,13 +180,8 @@ public void testInterceptAndRewrite() throws Exception { coordinatorRewritten = copyNamedWriteable(coordinatorRewritten, writableRegistry(), QueryBuilder.class); assertThat(coordinatorRewritten, instanceOf(InterceptedInferenceKnnVectorQueryBuilder.class)); InterceptedInferenceKnnVectorQueryBuilder coordinatorIntercepted = (InterceptedInferenceKnnVectorQueryBuilder) coordinatorRewritten; - assertThat(coordinatorIntercepted.originalQuery, equalTo(knnQuery)); - assertThat(coordinatorIntercepted.inferenceResultsMap, notNullValue()); - assertThat(coordinatorIntercepted.inferenceResultsMap.size(), equalTo(1)); - InferenceResults inferenceResults = coordinatorIntercepted.inferenceResultsMap.get( - new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID) - ); + InferenceResults inferenceResults = assertQueryIsInterceptedKnnWithValidResults(coordinatorIntercepted); assertThat(inferenceResults, notNullValue()); assertThat(inferenceResults, instanceOf(MlDenseEmbeddingResults.class)); VectorData queryVector = new VectorData(((MlDenseEmbeddingResults) inferenceResults).getInferenceAsFloat()); @@ -203,6 +211,136 @@ public void testInterceptAndRewrite() throws Exception { assertThat(dataRewrittenTestIndex2, equalTo(expectedDataRewrittenTestIndex2)); } + public void testCoordinatorNodeRewrite_GivenKnnQueryWithSemanticFilters_ShouldInterceptFilters() throws Exception { + final String denseField1 = "dense_field_1"; + final String denseField2 = "dense_field_2"; + final String sparseField = "sparse_field"; + final TestIndex testIndex = new TestIndex( + "test-index", + Map.of(denseField1, DENSE_INFERENCE_ID, denseField2, DENSE_INFERENCE_ID, sparseField, SPARSE_INFERENCE_ID), + Map.of() + ); + final KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder( + denseField1, + new TextEmbeddingQueryVectorBuilder(DENSE_INFERENCE_ID, "foo"), + 50, + 500, + 50f, + null + ).boost(3.0f) + .queryName("bar") + .addFilterQuery( + QueryBuilders.boolQuery() + .filter( + new KnnVectorQueryBuilder( + denseField2, + new TextEmbeddingQueryVectorBuilder(DENSE_INFERENCE_ID, "some query"), + 50, + 500, + 50f, + null + ) + ) + ) + .addFilterQuery( + new KnnVectorQueryBuilder( + denseField2, + new TextEmbeddingQueryVectorBuilder(DENSE_INFERENCE_ID, "some query"), + 50, + 500, + 50f, + null + ).addFilterQuery( + new SparseVectorQueryBuilder( + sparseField, + null, + SPARSE_INFERENCE_ID, + "some other query", + randomBoolean(), + TokenPruningConfigTests.testInstance() + ) + ) + ) + .addFilterQuery(new MatchQueryBuilder(sparseField, "some other query")); + + // Perform coordinator node rewrite + final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( + Map.of(testIndex.name(), testIndex.semanticTextFields()), + Map.of(), + TransportVersion.current(), + null + ); + QueryBuilder coordinatorRewritten = rewriteAndFetch(knnQuery, queryRewriteContext); + + // Use a serialization cycle to strip InterceptedQueryBuilderWrapper + coordinatorRewritten = copyNamedWriteable(coordinatorRewritten, writableRegistry(), QueryBuilder.class); + QueryBuilder serializedKnnQuery = copyNamedWriteable(knnQuery, writableRegistry(), QueryBuilder.class); + + assertQueryIsInterceptedKnnWithValidResults(coordinatorRewritten); + InterceptedInferenceKnnVectorQueryBuilder coordinatorIntercepted = (InterceptedInferenceKnnVectorQueryBuilder) coordinatorRewritten; + assertThat(coordinatorIntercepted.originalQuery, equalTo(serializedKnnQuery)); + assertQueryIsInterceptedKnnWithValidResults(coordinatorIntercepted); + + assertThat(coordinatorIntercepted.originalQuery.filterQueries(), hasSize(3)); + + // Assertions on first filter + { + assertThat(coordinatorIntercepted.originalQuery.filterQueries().get(0), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder filter = (BoolQueryBuilder) coordinatorIntercepted.originalQuery.filterQueries().get(0); + assertThat(filter.filter(), hasSize(1)); + assertQueryIsInterceptedKnnWithValidResults(filter.filter().get(0)); + } + + // Assertions on second filter + { + assertQueryIsInterceptedKnnWithValidResults(coordinatorIntercepted.originalQuery.filterQueries().get(1)); + InterceptedInferenceKnnVectorQueryBuilder filter = + (InterceptedInferenceKnnVectorQueryBuilder) coordinatorIntercepted.originalQuery.filterQueries().get(1); + assertThat(filter.originalQuery.filterQueries(), hasSize(1)); + assertQueryIsInterceptedSparseVectorWithValidResults(filter.originalQuery.filterQueries().get(0)); + } + + // Assertions on third filter + { + assertThat( + coordinatorIntercepted.originalQuery.filterQueries().get(2), + instanceOf(InterceptedInferenceMatchQueryBuilder.class) + ); + InterceptedInferenceMatchQueryBuilder filter = (InterceptedInferenceMatchQueryBuilder) coordinatorIntercepted.originalQuery + .filterQueries() + .get(2); + assertInterceptedQueryHasValidResultsForSparseVector(filter); + } + } + + private static InferenceResults assertQueryIsInterceptedKnnWithValidResults(QueryBuilder query) { + assertThat(query, instanceOf(InterceptedInferenceKnnVectorQueryBuilder.class)); + InterceptedInferenceKnnVectorQueryBuilder interceptedKnn = (InterceptedInferenceKnnVectorQueryBuilder) query; + assertThat(interceptedKnn.inferenceResultsMap, notNullValue()); + assertThat(interceptedKnn.inferenceResultsMap.size(), equalTo(1)); + InferenceResults inferenceResults = interceptedKnn.inferenceResultsMap.get( + new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID) + ); + assertThat(inferenceResults, notNullValue()); + assertThat(inferenceResults, instanceOf(MlDenseEmbeddingResults.class)); + return inferenceResults; + } + + private static void assertQueryIsInterceptedSparseVectorWithValidResults(QueryBuilder query) { + assertThat(query, instanceOf(InterceptedInferenceSparseVectorQueryBuilder.class)); + assertInterceptedQueryHasValidResultsForSparseVector((InterceptedInferenceSparseVectorQueryBuilder) query); + } + + private static void assertInterceptedQueryHasValidResultsForSparseVector(InterceptedInferenceQueryBuilder intercepted) { + assertThat(intercepted.inferenceResultsMap, notNullValue()); + assertThat(intercepted.inferenceResultsMap.size(), equalTo(1)); + InferenceResults inferenceResults = intercepted.inferenceResultsMap.get( + new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, SPARSE_INFERENCE_ID) + ); + assertThat(inferenceResults, notNullValue()); + assertThat(inferenceResults, instanceOf(TextExpansionResults.class)); + } + private static NestedQueryBuilder buildExpectedNestedQuery( KnnVectorQueryBuilder knnQuery, VectorData queryVector, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java index ed87d5adda0b6..5b584ee322a6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import java.util.List; import java.util.Map; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -39,8 +40,8 @@ protected InterceptedInferenceQueryBuilder createInterceptedQ } @Override - protected QueryRewriteInterceptor createQueryRewriteInterceptor() { - return new SemanticMatchQueryRewriteInterceptor(); + protected List createQueryRewriteInterceptors() { + return List.of(new SemanticMatchQueryRewriteInterceptor()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java index 43c2f00b56f40..98adc9b0538f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java @@ -67,8 +67,8 @@ protected InterceptedInferenceQueryBuilder createInter } @Override - protected QueryRewriteInterceptor createQueryRewriteInterceptor() { - return new SemanticSparseVectorQueryRewriteInterceptor(); + protected List createQueryRewriteInterceptors() { + return List.of(new SemanticSparseVectorQueryRewriteInterceptor()); } @Override diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml index 5de05dc2678b2..14b2a61b4d987 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml @@ -63,6 +63,9 @@ setup: inference_field: type: semantic_text inference_id: dense-inference-id + another_dense_inference_field: + type: semantic_text + inference_id: dense-inference-id - do: indices.create: @@ -75,6 +78,9 @@ setup: inference_field: type: semantic_text inference_id: dense-inference-id-2 + another_dense_inference_field: + type: semantic_text + inference_id: dense-inference-id - do: indices.create: @@ -109,6 +115,7 @@ setup: body: keyword_field: "foo" inference_field: [ "inference test", "another inference test" ] + another_dense_inference_field: "some text" refresh: true - do: @@ -118,6 +125,7 @@ setup: body: keyword_field: "bar" inference_field: [ "inference test", "another inference test" ] + another_dense_inference_field: "some text" refresh: true - do: @@ -365,7 +373,92 @@ setup: - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_1" } +--- +"knn query with semantic knn query filter": + - requires: + cluster_features: "search.semantic_knn_filter_fix" + reason: filters fixed in 8.18.0 + + - do: + search: + index: + - test-semantic-text-index + - test-semantic-text-index-2 + body: + query: + knn: + field: inference_field + k: 10 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + filter: + knn: + field: another_dense_inference_field + k: 10 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: some + + - match: { hits.total.value: 0 } + +--- +"knn query with semantic match query filter": + - requires: + cluster_features: "search.semantic_knn_filter_fix" + reason: filters fixed in 8.18.0 + + - do: + search: + index: + - test-semantic-text-index + - test-semantic-text-index-2 + body: + query: + knn: + field: inference_field + k: 10 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + filter: + match: + another_dense_inference_field: + query: some + + - match: { hits.total.value: 0 } + +--- +"knn query with semantic match query filter within bool query": + - requires: + cluster_features: "search.semantic_knn_filter_fix" + reason: filters fixed in 8.18.0 + + - do: + search: + index: + - test-semantic-text-index + - test-semantic-text-index-2 + body: + query: + knn: + field: inference_field + k: 10 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + filter: + bool: + filter: + match: + another_dense_inference_field: + query: some + - match: { hits.total.value: 0 } --- "knn query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields with smaller k returns k for each index": From 0d46908fc2f71a32b2c906f016b86d48637228af Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Sun, 23 Nov 2025 13:17:01 +0200 Subject: [PATCH 2/2] Update docs/changelog/138457.yaml --- docs/changelog/138457.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/138457.yaml diff --git a/docs/changelog/138457.yaml b/docs/changelog/138457.yaml new file mode 100644 index 0000000000000..cad2c0311edd7 --- /dev/null +++ b/docs/changelog/138457.yaml @@ -0,0 +1,6 @@ +pr: 138457 +summary: Intercept filters to knn queries +area: Relevance +type: bug +issues: + - 138410