Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/138457.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 138457
summary: Intercept filters to knn queries
area: Relevance
type: bug
issues:
- 138410
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;

import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -27,7 +28,7 @@ public interface QueryRewriteInterceptor {
* @param queryBuilder the original {@link QueryBuilder} to potentially rewrite
* @return the rewritten {@link QueryBuilder}, or the original instance if no rewrite was needed
*/
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder);
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException;

/**
* Name of the query to be intercepted and rewritten.
Expand All @@ -52,7 +53,7 @@ public String getQueryName() {
}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
QueryRewriteInterceptor interceptor = interceptors.get(queryBuilder.getName());
if (interceptor != null) {
return interceptor.interceptAndRewrite(context, queryBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries)
return this;
}

public KnnVectorQueryBuilder setFilterQueries(List<QueryBuilder> filterQueries) {
Objects.requireNonNull(filterQueries);
this.filterQueries.clear();
this.filterQueries.addAll(filterQueries);
return this;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (queryVectorSupplier != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.QueryShardException;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.query.TermQueryBuilder;
Expand Down Expand Up @@ -565,4 +566,11 @@ public void testRewriteWithQueryVectorBuilder() throws Exception {
assertThat(rewritten.filterQueries(), hasSize(numFilters));
assertThat(rewritten.filterQueries(), equalTo(filters));
}

public void testSetFilterQueries() {
KnnVectorQueryBuilder knnQueryBuilder = doCreateTestQueryBuilder();
List<QueryBuilder> newFilters = randomList(5, () -> RandomQueryBuilder.createQuery(random()));
knnQueryBuilder.setFilterQueries(newFilters);
assertThat(knnQueryBuilder.filterQueries(), equalTo(newFilters));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
Expand Down Expand Up @@ -91,6 +93,34 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null;
}

@Override
protected QueryBuilder customCoordinatorNodeRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
// knn query may contain filters that are also intercepted.
// We need to rewrite those here so that we can get inference results for them too.
QueryBuilder rewritten = rewriteFilterQueries(queryRewriteContext);
if (rewritten != this) {
return rewritten;
}
return super.customCoordinatorNodeRewrite(queryRewriteContext);
}

private QueryBuilder rewriteFilterQueries(QueryRewriteContext queryRewriteContext) throws IOException {
boolean filtersChanged = false;
List<QueryBuilder> rewrittenFilters = new ArrayList<>(originalQuery.filterQueries().size());
for (QueryBuilder filter : originalQuery.filterQueries()) {
QueryBuilder rewrittenFilter = filter.rewrite(queryRewriteContext);
if (rewrittenFilter != filter) {
filtersChanged = true;
}
rewrittenFilters.add(rewrittenFilter);
}
if (filtersChanged) {
originalQuery.setFilterQueries(rewrittenFilters);
return copy(inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
}
return this;
}

@Override
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
if (originalQuery.queryVector() == null && originalQuery.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder == false) {
Expand Down Expand Up @@ -119,7 +149,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ protected String getQuery() {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ protected InterceptedInferenceQueryBuilder(
* @param queryRewriteContext The query rewrite context
* @return The query builder rewritten to a backwards-compatible form
*/
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext);
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException;

/**
* Generate a copy of {@code this}.
Expand Down Expand Up @@ -209,6 +209,10 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
*/
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}

protected QueryBuilder customCoordinatorNodeRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
return this;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (inferenceResultsMapSupplier != null) {
Expand Down Expand Up @@ -304,7 +308,7 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
return queryFields(inferenceFieldsToQuery, nonInferenceFieldsToQuery, indexMetadataContext);
}

private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext);
if (rewrittenBwC != this) {
return rewrittenBwC;
Expand Down Expand Up @@ -344,6 +348,18 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
);
}

QueryBuilder rewritten = customCoordinatorNodeRewrite(queryRewriteContext);
if (this != rewritten) {
return rewritten;
}
return coordinatorNodeRewrite(queryRewriteContext, inferenceIds, ccsRequest);
}

private QueryBuilder coordinatorNodeRewrite(
QueryRewriteContext queryRewriteContext,
Set<FullyQualifiedInferenceId> inferenceIds,
boolean ccsRequest
) {
if (inferenceResultsMapSupplier != null) {
// Additional inference results have already been requested, and we are waiting for them to continue the rewrite process
return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest));
Expand Down Expand Up @@ -376,7 +392,6 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
} else {
rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest);
}

return rewritten;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,39 @@

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;

import java.io.IOException;

public class SemanticKnnVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {
@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
if (queryBuilder instanceof KnnVectorQueryBuilder knnVectorQueryBuilder) {
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
return interceptKnnQuery(context, knnVectorQueryBuilder);
} else {
throw new IllegalStateException("Unexpected query builder type: " + queryBuilder.getClass());
}
}

private static InterceptedInferenceKnnVectorQueryBuilder interceptKnnQuery(
QueryRewriteContext context,
KnnVectorQueryBuilder knnVectorQueryBuilder
) throws IOException {
boolean changed = false;
for (QueryBuilder filter : knnVectorQueryBuilder.filterQueries()) {
QueryBuilder rewritten = filter.rewrite(context);
if (rewritten != filter) {
changed = true;
}
}
if (changed) {
knnVectorQueryBuilder.setFilterQueries(Rewriteable.rewrite(knnVectorQueryBuilder.filterQueries(), context));
}
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
}

@Override
public String getQueryName() {
return KnnVectorQueryBuilder.NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.TransportVersions.V_8_15_0;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
Expand Down Expand Up @@ -335,7 +337,7 @@ protected abstract InterceptedInferenceQueryBuilder<T> createInterceptedQueryBui
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
);

protected abstract QueryRewriteInterceptor createQueryRewriteInterceptor();
protected abstract List<QueryRewriteInterceptor> createQueryRewriteInterceptors();

protected abstract TransportVersion getMinimalSupportedVersion();

Expand Down Expand Up @@ -427,8 +429,9 @@ protected QueryRewriteContext createQueryRewriteContext(
indexMetadata
);

QueryRewriteInterceptor interceptor = createQueryRewriteInterceptor();
Map<String, QueryRewriteInterceptor> interceptorMap = Map.of(interceptor.getQueryName(), interceptor);
List<QueryRewriteInterceptor> interceptors = createQueryRewriteInterceptors();
Map<String, QueryRewriteInterceptor> interceptorMap = interceptors.stream()
.collect(Collectors.toMap(QueryRewriteInterceptor::getQueryName, Function.identity()));

return new QueryRewriteContext(
null,
Expand Down
Loading