diff --git a/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java index 5944fc3d8df7a..92fa0bb3cceed 100644 --- a/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java @@ -26,17 +26,20 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.common.lucene.search.Queries.fixNegativeQueryIfNeeded; /** * A Query that matches documents matching boolean combinations of other queries. */ -public class BoolQueryBuilder extends AbstractQueryBuilder { +public class BoolQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final String NAME = "bool"; public static final boolean ADJUST_PURE_NEGATIVE_DEFAULT = true; @@ -60,6 +63,8 @@ public class BoolQueryBuilder extends AbstractQueryBuilder { private String minimumShouldMatch; + private List prefilters = List.of(); + /** * Build an empty bool query. */ @@ -76,6 +81,9 @@ public BoolQueryBuilder(StreamInput in) throws IOException { filterClauses.addAll(readQueries(in)); adjustPureNegative = in.readBoolean(); minimumShouldMatch = in.readOptionalString(); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override @@ -86,6 +94,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, filterClauses); out.writeBoolean(adjustPureNegative); out.writeOptionalString(minimumShouldMatch); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -332,7 +343,7 @@ private static void addBooleanClauses( @Override protected int doHashCode() { - return Objects.hash(adjustPureNegative, minimumShouldMatch, mustClauses, shouldClauses, mustNotClauses, filterClauses); + return Objects.hash(adjustPureNegative, minimumShouldMatch, mustClauses, shouldClauses, mustNotClauses, filterClauses, prefilters); } @Override @@ -342,7 +353,8 @@ protected boolean doEquals(BoolQueryBuilder other) { && Objects.equals(mustClauses, other.mustClauses) && Objects.equals(shouldClauses, other.shouldClauses) && Objects.equals(mustNotClauses, other.mustNotClauses) - && Objects.equals(filterClauses, other.filterClauses); + && Objects.equals(filterClauses, other.filterClauses) + && Objects.equals(prefilters, other.prefilters); } @Override @@ -353,6 +365,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws if (clauses == 0) { return new MatchAllQueryBuilder().boost(boost()).queryName(queryName()); } + + propagatePrefilters(); + changed |= rewriteClauses(queryRewriteContext, mustClauses, newBuilder::must); try { @@ -415,6 +430,7 @@ private static boolean rewriteClauses( ) throws IOException { boolean changed = false; for (QueryBuilder builder : builders) { + QueryBuilder result = builder.rewrite(queryRewriteContext); if (result != builder) { changed = true; @@ -452,4 +468,23 @@ public BoolQueryBuilder shallowCopy() { } return copy; } + + @Override + public BoolQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + // We declare as prefilters clauses run in the filter context, namely filter and must_not + return Stream.of(prefilters, filterClauses, mustNotClauses.stream().map(c -> QueryBuilders.boolQuery().mustNot(c)).toList()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + } + + @Override + public List getPrefilteringTargetQueries() { + return Stream.concat(mustClauses.stream(), shouldClauses.stream()).toList(); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/BoostingQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/BoostingQueryBuilder.java index 9e439efd71dc9..0ad2e55e8502b 100644 --- a/server/src/main/java/org/elasticsearch/index/query/BoostingQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/BoostingQueryBuilder.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -35,7 +36,7 @@ * multiplied by the supplied "boost" parameter, so this should be less than 1 to achieve a * demoting effect */ -public class BoostingQueryBuilder extends AbstractQueryBuilder { +public class BoostingQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final String NAME = "boosting"; private static final ParseField POSITIVE_FIELD = new ParseField("positive"); @@ -48,6 +49,8 @@ public class BoostingQueryBuilder extends AbstractQueryBuilder prefilters = List.of(); + /** * Create a new {@link BoostingQueryBuilder} * @@ -73,6 +76,9 @@ public BoostingQueryBuilder(StreamInput in) throws IOException { positiveQuery = in.readNamedWriteable(QueryBuilder.class); negativeQuery = in.readNamedWriteable(QueryBuilder.class); negativeBoost = in.readFloat(); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override @@ -80,6 +86,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeNamedWriteable(positiveQuery); out.writeNamedWriteable(negativeQuery); out.writeFloat(negativeBoost); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -197,18 +206,21 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { @Override protected int doHashCode() { - return Objects.hash(negativeBoost, positiveQuery, negativeQuery); + return Objects.hash(negativeBoost, positiveQuery, negativeQuery, prefilters); } @Override protected boolean doEquals(BoostingQueryBuilder other) { return Objects.equals(negativeBoost, other.negativeBoost) && Objects.equals(positiveQuery, other.positiveQuery) - && Objects.equals(negativeQuery, other.negativeQuery); + && Objects.equals(negativeQuery, other.negativeQuery) + && Objects.equals(prefilters, other.prefilters); } @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + propagatePrefilters(); + QueryBuilder positiveQuery = this.positiveQuery.rewrite(queryRewriteContext); if (positiveQuery instanceof MatchNoneQueryBuilder) { return positiveQuery; @@ -233,4 +245,20 @@ protected void extractInnerHitBuilders(Map inner public TransportVersion getMinimalSupportedVersion() { return TransportVersion.zero(); } + + @Override + public BoostingQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(positiveQuery, negativeQuery); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/ConstantScoreQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/ConstantScoreQueryBuilder.java index f70f095ecd5ea..ecd223b131fe1 100644 --- a/server/src/main/java/org/elasticsearch/index/query/ConstantScoreQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/ConstantScoreQueryBuilder.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -27,13 +28,18 @@ * A query that wraps a filter and simply returns a constant score equal to the * query boost for every document in the filter. */ -public class ConstantScoreQueryBuilder extends AbstractQueryBuilder { +public class ConstantScoreQueryBuilder extends AbstractQueryBuilder + implements + PrefilteredQuery { + public static final String NAME = "constant_score"; private static final ParseField INNER_QUERY_FIELD = new ParseField("filter"); private final QueryBuilder filterBuilder; + private List prefilters = List.of(); + /** * A query that wraps another query and simply returns a constant score equal to the * query boost for every document in the query. @@ -53,11 +59,17 @@ public ConstantScoreQueryBuilder(QueryBuilder filterBuilder) { public ConstantScoreQueryBuilder(StreamInput in) throws IOException { super(in); filterBuilder = in.readNamedWriteable(QueryBuilder.class); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeNamedWriteable(filterBuilder); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -135,20 +147,23 @@ public String getWriteableName() { @Override protected int doHashCode() { - return Objects.hash(filterBuilder); + return Objects.hash(filterBuilder, prefilters); } @Override protected boolean doEquals(ConstantScoreQueryBuilder other) { - return Objects.equals(filterBuilder, other.filterBuilder); + return Objects.equals(filterBuilder, other.filterBuilder) && Objects.equals(prefilters, other.prefilters); } @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + propagatePrefilters(); + QueryBuilder rewrite = filterBuilder.rewrite(queryRewriteContext); if (rewrite instanceof MatchNoneQueryBuilder) { return rewrite; // we won't match anyway } + if (rewrite != filterBuilder) { return new ConstantScoreQueryBuilder(rewrite); } @@ -164,4 +179,20 @@ protected void extractInnerHitBuilders(Map inner public TransportVersion getMinimalSupportedVersion() { return TransportVersion.zero(); } + + @Override + public ConstantScoreQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(filterBuilder); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/DisMaxQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/DisMaxQueryBuilder.java index 4ddb28b76ee6c..e2ad52a142eaf 100644 --- a/server/src/main/java/org/elasticsearch/index/query/DisMaxQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/DisMaxQueryBuilder.java @@ -32,7 +32,7 @@ * with the maximum score for that document as produced by any sub-query, plus a tie breaking increment for any * additional matching sub-queries. */ -public class DisMaxQueryBuilder extends AbstractQueryBuilder { +public class DisMaxQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final String NAME = "dis_max"; /** Default multiplication factor for breaking ties in document scores.*/ @@ -42,6 +42,7 @@ public class DisMaxQueryBuilder extends AbstractQueryBuilder private static final ParseField QUERIES_FIELD = new ParseField("queries"); private final List queries = new ArrayList<>(); + private List prefilters = List.of(); private float tieBreaker = DEFAULT_TIE_BREAKER; @@ -54,12 +55,18 @@ public DisMaxQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); tieBreaker = in.readFloat(); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = readQueries(in); + } } @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); out.writeFloat(tieBreaker); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -182,6 +189,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + propagatePrefilters(); + DisMaxQueryBuilder newBuilder = new DisMaxQueryBuilder(); boolean changed = false; for (QueryBuilder query : queries) { @@ -203,12 +212,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws @Override protected int doHashCode() { - return Objects.hash(queries, tieBreaker); + return Objects.hash(queries, tieBreaker, prefilters); } @Override protected boolean doEquals(DisMaxQueryBuilder other) { - return Objects.equals(queries, other.queries) && Objects.equals(tieBreaker, other.tieBreaker); + return Objects.equals(queries, other.queries) + && Objects.equals(tieBreaker, other.tieBreaker) + && Objects.equals(prefilters, other.prefilters); } @Override @@ -227,4 +238,20 @@ protected void extractInnerHitBuilders(Map inner public TransportVersion getMinimalSupportedVersion() { return TransportVersion.zero(); } + + @Override + public DisMaxQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return queries; + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index 56e002287e1e3..623ab559193bb 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -29,13 +29,14 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.List; import java.util.Objects; /** * Match query is a query that analyzes the text and constructs a query as the * result of the analysis. */ -public class MatchQueryBuilder extends AbstractQueryBuilder { +public class MatchQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final ParseField ZERO_TERMS_QUERY_FIELD = new ParseField("zero_terms_query"); public static final ParseField LENIENT_FIELD = new ParseField("lenient"); @@ -81,6 +82,8 @@ public class MatchQueryBuilder extends AbstractQueryBuilder { private boolean autoGenerateSynonymsPhraseQuery = true; + private List prefilters = List.of(); + /** * Constructs a new match query. */ @@ -118,6 +121,9 @@ public MatchQueryBuilder(StreamInput in) throws IOException { in.readOptionalFloat(); } autoGenerateSynonymsPhraseQuery = in.readBoolean(); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override @@ -140,6 +146,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalFloat(null); } out.writeBoolean(autoGenerateSynonymsPhraseQuery); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** Returns the field name used in this query. */ @@ -430,7 +439,8 @@ protected boolean doEquals(MatchQueryBuilder other) { && Objects.equals(lenient, other.lenient) && Objects.equals(fuzzyTranspositions, other.fuzzyTranspositions) && Objects.equals(zeroTermsQuery, other.zeroTermsQuery) - && Objects.equals(autoGenerateSynonymsPhraseQuery, other.autoGenerateSynonymsPhraseQuery); + && Objects.equals(autoGenerateSynonymsPhraseQuery, other.autoGenerateSynonymsPhraseQuery) + && Objects.equals(prefilters, other.prefilters); } @Override @@ -448,7 +458,8 @@ protected int doHashCode() { lenient, fuzzyTranspositions, zeroTermsQuery, - autoGenerateSynonymsPhraseQuery + autoGenerateSynonymsPhraseQuery, + prefilters ); } @@ -570,4 +581,20 @@ public static MatchQueryBuilder fromXContent(XContentParser parser) throws IOExc public TransportVersion getMinimalSupportedVersion() { return TransportVersion.zero(); } + + @Override + public MatchQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java index 2007f378ed1bd..260af2af09c6e 100644 --- a/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java @@ -46,6 +46,7 @@ import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -53,7 +54,7 @@ import static org.elasticsearch.search.SearchService.ALLOW_EXPENSIVE_QUERIES; import static org.elasticsearch.search.fetch.subphase.InnerHitsContext.intersect; -public class NestedQueryBuilder extends AbstractQueryBuilder { +public class NestedQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final String NAME = "nested"; /** * The default value for ignore_unmapped. @@ -71,6 +72,7 @@ public class NestedQueryBuilder extends AbstractQueryBuilder private final QueryBuilder query; private InnerHitBuilder innerHitBuilder; private boolean ignoreUnmapped = DEFAULT_IGNORE_UNMAPPED; + private List prefilters = List.of(); public NestedQueryBuilder(String path, QueryBuilder query, ScoreMode scoreMode) { this(path, query, scoreMode, null); @@ -93,6 +95,9 @@ public NestedQueryBuilder(StreamInput in) throws IOException { query = in.readNamedWriteable(QueryBuilder.class); innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new); ignoreUnmapped = in.readBoolean(); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override @@ -102,6 +107,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeNamedWriteable(query); out.writeOptionalWriteable(innerHitBuilder); out.writeBoolean(ignoreUnmapped); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -258,12 +266,13 @@ protected boolean doEquals(NestedQueryBuilder that) { && Objects.equals(path, that.path) && Objects.equals(scoreMode, that.scoreMode) && Objects.equals(innerHitBuilder, that.innerHitBuilder) - && Objects.equals(ignoreUnmapped, that.ignoreUnmapped); + && Objects.equals(ignoreUnmapped, that.ignoreUnmapped) + && Objects.equals(prefilters, that.prefilters); } @Override protected int doHashCode() { - return Objects.hash(query, path, scoreMode, innerHitBuilder, ignoreUnmapped); + return Objects.hash(query, path, scoreMode, innerHitBuilder, ignoreUnmapped, prefilters); } @Override @@ -329,6 +338,7 @@ public static Query toQuery( @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + propagatePrefilters(); QueryBuilder rewrittenQuery = query.rewrite(queryRewriteContext); if (rewrittenQuery != query) { NestedQueryBuilder nestedQuery = new NestedQueryBuilder(path, rewrittenQuery, scoreMode, innerHitBuilder); @@ -353,6 +363,22 @@ public void extractInnerHitBuilders(Map innerHit } } + @Override + public NestedQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(query); + } + static class NestedInnerHitContextBuilder extends InnerHitContextBuilder { private final String path; diff --git a/server/src/main/java/org/elasticsearch/index/query/PrefilteredQuery.java b/server/src/main/java/org/elasticsearch/index/query/PrefilteredQuery.java new file mode 100644 index 0000000000000..e3255a6273776 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/PrefilteredQuery.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.elasticsearch.TransportVersion; + +import java.util.List; + +public interface PrefilteredQuery { + + TransportVersion QUERY_PREFILTERING = TransportVersion.fromName("query_prefiltering"); + + T setPrefilters(List prefilters); + + List getPrefilters(); + + List getPrefilteringTargetQueries(); + + default void propagatePrefilters() { + List prefilters = getPrefilters(); + if (prefilters.isEmpty() == false) { + for (QueryBuilder targetQuery : getPrefilteringTargetQueries()) { + if (targetQuery instanceof PrefilteredQuery prefilteredQuery) { + prefilteredQuery.setPrefilters(prefilters.stream().filter(q -> q != targetQuery).toList()); + } + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java index 00553ab535fd3..657e65cac855b 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.query.InnerHitContextBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.PrefilteredQuery; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -44,7 +45,10 @@ * A query that uses a filters with a script associated with them to compute the * score. */ -public class FunctionScoreQueryBuilder extends AbstractQueryBuilder { +public class FunctionScoreQueryBuilder extends AbstractQueryBuilder + implements + PrefilteredQuery { + public static final String NAME = "function_score"; // For better readability of error message @@ -74,6 +78,8 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder prefilters = List.of(); + /** * Creates a function_score query without functions * @@ -144,6 +150,9 @@ public FunctionScoreQueryBuilder(StreamInput in) throws IOException { minScore = in.readOptionalFloat(); boostMode = in.readOptionalWriteable(CombineFunction::readFromStream); scoreMode = FunctionScoreQuery.ScoreMode.readFromStream(in); + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } } @Override @@ -154,6 +163,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalFloat(minScore); out.writeOptionalWriteable(boostMode); scoreMode.writeTo(out); + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } /** @@ -277,7 +289,8 @@ protected boolean doEquals(FunctionScoreQueryBuilder other) { && Objects.equals(this.boostMode, other.boostMode) && Objects.equals(this.scoreMode, other.scoreMode) && Objects.equals(this.minScore, other.minScore) - && Objects.equals(this.maxBoost, other.maxBoost); + && Objects.equals(this.maxBoost, other.maxBoost) + && Objects.equals(this.prefilters, other.prefilters); } @Override @@ -288,7 +301,8 @@ protected int doHashCode() { this.boostMode, this.scoreMode, this.minScore, - this.maxBoost + this.maxBoost, + this.prefilters ); } @@ -322,6 +336,22 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { return new FunctionScoreQuery(query, scoreMode, filterFunctions, boostMode, minScore, maxBoost); } + @Override + public FunctionScoreQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(query); + } + /** * Function to be associated with an optional filter, meaning it will be executed only for the documents * that match the given filter. @@ -405,6 +435,8 @@ public FilterFunctionBuilder rewrite(QueryRewriteContext context) throws IOExcep @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + propagatePrefilters(); + QueryBuilder queryBuilder = this.query.rewrite(queryRewriteContext); if (queryBuilder instanceof MatchNoneQueryBuilder) { return queryBuilder; diff --git a/server/src/main/resources/transport/definitions/referable/query_prefiltering.csv b/server/src/main/resources/transport/definitions/referable/query_prefiltering.csv new file mode 100644 index 0000000000000..f81f2fd8c49c7 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/query_prefiltering.csv @@ -0,0 +1 @@ +9214000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 61602dea24d29..54381ba55a0eb 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -batched_response_might_include_reduction_failure,9213000 +query_prefiltering,9214000 diff --git a/server/src/test/java/org/elasticsearch/index/query/AbstractPrefilteredQueryTestCase.java b/server/src/test/java/org/elasticsearch/index/query/AbstractPrefilteredQueryTestCase.java new file mode 100644 index 0000000000000..368309e48bedd --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/query/AbstractPrefilteredQueryTestCase.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.test.TransportVersionUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public abstract class AbstractPrefilteredQueryTestCase & PrefilteredQuery> extends + AbstractQueryTestCase { + + protected abstract QB createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier); + + protected abstract void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters); + + public void testSerializationPrefiltersBwc() throws Exception { + QB originalQuery = createTestQueryBuilder(); + originalQuery.setPrefilters(randomList(1, 5, () -> RandomQueryBuilder.createQuery(random()))); + + for (int i = 0; i < 100; i++) { + TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( + random(), + originalQuery.getMinimalSupportedVersion().id() == 0 + ? TransportVersions.V_8_0_0 // The first major before introducing prefiltering + : originalQuery.getMinimalSupportedVersion(), + TransportVersionUtils.getPreviousVersion(TransportVersion.current()) + ); + + @SuppressWarnings("unchecked") + QB deserializedQuery = (QB) copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, transportVersion); + + if (transportVersion.supports(PrefilteredQuery.QUERY_PREFILTERING)) { + assertThat(deserializedQuery, equalTo(originalQuery)); + } else { + QB originalQueryWithoutPrefilters = copyQuery(originalQuery).setPrefilters(List.of()); + assertThat(deserializedQuery, equalTo(originalQueryWithoutPrefilters)); + } + } + } + + public void testEqualsAndHashcodeForPrefilters() throws IOException { + QB originalQuery = createTestQueryBuilder(); + originalQuery.setPrefilters(randomList(1, 5, () -> RandomQueryBuilder.createQuery(random()))); + + @SuppressWarnings("unchecked") + QB deserializedQuery = (QB) copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class); + + assertThat(deserializedQuery, equalTo(originalQuery)); + assertThat(deserializedQuery.hashCode(), equalTo(originalQuery.hashCode())); + + deserializedQuery.setPrefilters(List.of()); + assertThat(deserializedQuery, not(equalTo(originalQuery))); + assertThat(deserializedQuery.hashCode(), not(equalTo(originalQuery.hashCode()))); + } + + public void testRewriteWithPrefilters() throws IOException { + QueryRewriteContext queryRewriteContext = createQueryRewriteContext(); + SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); + + for (int i = 0; i < 100; i++) { + QB queryBuilder = createQueryBuilderForPrefilteredRewriteTest(() -> createRandomPrefilteredQuery()); + if (queryBuilder == null) { + return; + } + setRandomPrefilters(queryBuilder); + + QueryBuilder rewritten = rewriteQuery(queryBuilder, queryRewriteContext, searchExecutionContext); + + assertRewrittenHasPropagatedPrefilters(rewritten, queryBuilder.getPrefilters()); + } + } + + private static void setRandomPrefilters(PrefilteredQuery queryBuilder) { + List filters = new ArrayList<>(); + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + filters.add(randomFrom(randomTermQuery(), createRandomPrefilteredQuery())); + } + queryBuilder.setPrefilters(filters); + } + + private static QueryBuilder randomTermQuery() { + String filterFieldName = randomFrom(KEYWORD_FIELD_NAME, TEXT_FIELD_NAME); + return QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)); + } + + private static QueryBuilder createRandomPrefilteredQuery() { + return switch (randomFrom(PrefilteredQueryType.values())) { + case BOOL -> QueryBuilders.boolQuery().must(randomTermQuery()); + case BOOSTING -> QueryBuilders.boostingQuery(randomTermQuery(), randomTermQuery()); + case CONSTANT_SCORE -> QueryBuilders.constantScoreQuery(randomTermQuery()); + case DIS_MAX -> QueryBuilders.disMaxQuery().add(randomTermQuery()).add(randomTermQuery()); + case FUNCTION_SCORE -> QueryBuilders.functionScoreQuery(randomTermQuery()); + case NESTED -> QueryBuilders.nestedQuery(OBJECT_FIELD_NAME, randomTermQuery(), randomFrom(ScoreMode.values())); + }; + } + + private enum PrefilteredQueryType { + // We only include query types that have child queries. + BOOL, + BOOSTING, + CONSTANT_SCORE, + DIS_MAX, + FUNCTION_SCORE, + NESTED + } +} diff --git a/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java index e9ef3ac8ad748..2257e3a137c57 100644 --- a/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java @@ -14,7 +14,6 @@ import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParseException; @@ -25,9 +24,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Stream; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; import static org.elasticsearch.index.query.QueryBuilders.termQuery; @@ -35,9 +38,10 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; -public class BoolQueryBuilderTests extends AbstractQueryTestCase { +public class BoolQueryBuilderTests extends AbstractPrefilteredQueryTestCase { @Override protected BoolQueryBuilder doCreateTestQueryBuilder() { BoolQueryBuilder query = new BoolQueryBuilder(); @@ -507,4 +511,46 @@ public void testShallowCopy() { } } } + + public void testGetPrefilters() { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + randomList(5, () -> RandomQueryBuilder.createQuery(random())).forEach(boolQueryBuilder::must); + randomList(5, () -> RandomQueryBuilder.createQuery(random())).forEach(boolQueryBuilder::should); + randomList(5, () -> RandomQueryBuilder.createQuery(random())).forEach(boolQueryBuilder::filter); + randomList(5, () -> RandomQueryBuilder.createQuery(random())).forEach(boolQueryBuilder::mustNot); + List topLevelPrefilters = randomList(5, () -> RandomQueryBuilder.createQuery(random())); + boolQueryBuilder.setPrefilters(topLevelPrefilters); + + Set expectedPrefilters = new HashSet<>(); + expectedPrefilters.addAll(boolQueryBuilder.filter()); + expectedPrefilters.addAll(boolQueryBuilder.mustNot().stream().map(q -> QueryBuilders.boolQuery().mustNot(q)).toList()); + expectedPrefilters.addAll(topLevelPrefilters); + + Set actualPrefilters = new HashSet<>(boolQueryBuilder.getPrefilters()); + assertThat(actualPrefilters, equalTo(expectedPrefilters)); + } + + @Override + protected BoolQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + BoolQueryBuilder boolQueryBuilder = boolQuery(); + randomList(5, () -> prefilteredQuerySupplier.get()).forEach(boolQueryBuilder::must); + randomList(5, () -> prefilteredQuerySupplier.get()).forEach(boolQueryBuilder::should); + randomList(5, () -> prefilteredQuerySupplier.get()).forEach(boolQueryBuilder::filter); + randomList(5, () -> prefilteredQuerySupplier.get()).forEach(boolQueryBuilder::mustNot); + return boolQueryBuilder; + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten; + for (QueryBuilder query : Stream.concat(boolQueryBuilder.must().stream(), boolQueryBuilder.should().stream()).toList()) { + assertThat(query, instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) query).getPrefilters(), equalTo(prefilters)); + } + for (QueryBuilder query : Stream.concat(boolQueryBuilder.filter().stream(), boolQueryBuilder.mustNot().stream()).toList()) { + assertThat(query, instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) query).getPrefilters().isEmpty(), is(true)); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/BoostingQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/BoostingQueryBuilderTests.java index 763c9b585256e..2ce49e393ac0d 100644 --- a/server/src/test/java/org/elasticsearch/index/query/BoostingQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/BoostingQueryBuilderTests.java @@ -12,14 +12,16 @@ import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.elasticsearch.test.AbstractQueryTestCase; import java.io.IOException; +import java.util.List; +import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.Matchers.equalTo; -public class BoostingQueryBuilderTests extends AbstractQueryTestCase { +public class BoostingQueryBuilderTests extends AbstractPrefilteredQueryTestCase { @Override protected BoostingQueryBuilder doCreateTestQueryBuilder() { @@ -148,4 +150,21 @@ public void testMustRewrite() throws IOException { e = expectThrows(IllegalStateException.class, () -> queryBuilder2.toQuery(context)); assertEquals("Rewrite first", e.getMessage()); } + + @Override + protected BoostingQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return QueryBuilders.boostingQuery(prefilteredQuerySupplier.get(), prefilteredQuerySupplier.get()); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(BoostingQueryBuilder.class)); + BoostingQueryBuilder boostingQueryBuilder = (BoostingQueryBuilder) rewritten; + QueryBuilder positiveQuery = boostingQueryBuilder.positiveQuery(); + assertThat(positiveQuery, instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) positiveQuery).getPrefilters(), equalTo(prefilters)); + QueryBuilder negativeQuery = boostingQueryBuilder.negativeQuery(); + assertThat(negativeQuery, instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) negativeQuery).getPrefilters(), equalTo(prefilters)); + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/ConstantScoreQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/ConstantScoreQueryBuilderTests.java index ce7480f643c08..e99b464642bf8 100644 --- a/server/src/test/java/org/elasticsearch/index/query/ConstantScoreQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/ConstantScoreQueryBuilderTests.java @@ -14,15 +14,16 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.ParsingException; import org.elasticsearch.core.Strings; -import org.elasticsearch.test.AbstractQueryTestCase; import java.io.IOException; +import java.util.List; +import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.containsString; -public class ConstantScoreQueryBuilderTests extends AbstractQueryTestCase { +public class ConstantScoreQueryBuilderTests extends AbstractPrefilteredQueryTestCase { /** * @return a {@link ConstantScoreQueryBuilder} with random boost between 0.1f and 2.0f */ @@ -118,4 +119,18 @@ public void testMustRewrite() throws IOException { IllegalStateException e = expectThrows(IllegalStateException.class, () -> queryBuilder.toQuery(context)); assertEquals("Rewrite first", e.getMessage()); } + + @Override + protected ConstantScoreQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return QueryBuilders.constantScoreQuery(prefilteredQuerySupplier.get()); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(ConstantScoreQueryBuilder.class)); + QueryBuilder innerQuery = ((ConstantScoreQueryBuilder) rewritten).innerQuery(); + assertThat(innerQuery, instanceOf(PrefilteredQuery.class)); + assertEquals(prefilters, ((PrefilteredQuery) innerQuery).getPrefilters()); + } + } diff --git a/server/src/test/java/org/elasticsearch/index/query/DisMaxQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/DisMaxQueryBuilderTests.java index 0b7893312b997..24c7507240e3b 100644 --- a/server/src/test/java/org/elasticsearch/index/query/DisMaxQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/DisMaxQueryBuilderTests.java @@ -15,15 +15,18 @@ import org.apache.lucene.search.PrefixQuery; import org.apache.lucene.search.Query; import org.elasticsearch.core.Strings; -import org.elasticsearch.test.AbstractQueryTestCase; import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; -public class DisMaxQueryBuilderTests extends AbstractQueryTestCase { +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class DisMaxQueryBuilderTests extends AbstractPrefilteredQueryTestCase { /** * @return a {@link DisMaxQueryBuilder} with random inner queries */ @@ -144,4 +147,19 @@ public void testRewriteMultipleTimes() throws IOException { assertEquals(rewrittenAgain, expected); assertEquals(Rewriteable.rewrite(dismax, createSearchExecutionContext()), expected); } + + @Override + protected DisMaxQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return QueryBuilders.disMaxQuery().add(prefilteredQuerySupplier.get()).add(prefilteredQuerySupplier.get()); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(DisMaxQueryBuilder.class)); + DisMaxQueryBuilder innerQueries = (DisMaxQueryBuilder) rewritten; + for (QueryBuilder prefilter : innerQueries.innerQueries()) { + assertThat(prefilter, instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) prefilter).getPrefilters(), equalTo(prefilters)); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/MatchQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/MatchQueryBuilderTests.java index ba46bf76efbfe..e95692709c34d 100644 --- a/server/src/test/java/org/elasticsearch/index/query/MatchQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/MatchQueryBuilderTests.java @@ -40,7 +40,6 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.search.MatchQueryParser; import org.elasticsearch.index.search.MatchQueryParser.Type; -import org.elasticsearch.test.AbstractQueryTestCase; import org.hamcrest.Matcher; import org.hamcrest.Matchers; @@ -50,6 +49,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.either; import static org.hamcrest.CoreMatchers.instanceOf; @@ -57,7 +57,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; -public class MatchQueryBuilderTests extends AbstractQueryTestCase { +public class MatchQueryBuilderTests extends AbstractPrefilteredQueryTestCase { @Override protected MatchQueryBuilder doCreateTestQueryBuilder() { @@ -533,6 +533,16 @@ public void testMaxBooleanClause() { expectThrows(IndexSearcher.TooManyClauses.class, () -> query.parse(Type.PHRASE, TEXT_FIELD_NAME, "")); } + @Override + protected MatchQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return null; + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + // Do nothing, match prefiltering is tested via the inference interceptor + } + private static class MockGraphAnalyzer extends Analyzer { CannedBinaryTokenStream tokenStream; diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 39520db299f65..f6b75bd48a9a1 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -40,7 +39,9 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettingsTests.newIndexMeta; import static org.elasticsearch.index.query.InnerHitBuilderTests.randomNestedInnerHits; @@ -51,7 +52,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class NestedQueryBuilderTests extends AbstractQueryTestCase { +public class NestedQueryBuilderTests extends AbstractPrefilteredQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final int VECTOR_DIMENSION = 3; @@ -469,4 +470,17 @@ public void testDisallowExpensiveQueries() { ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> queryBuilder.toQuery(searchExecutionContext)); assertEquals("[joining] queries cannot be executed when 'search.allow_expensive_queries' is set to false.", e.getMessage()); } + + @Override + protected NestedQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return QueryBuilders.nestedQuery(OBJECT_FIELD_NAME, prefilteredQuerySupplier.get(), ScoreMode.None); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(NestedQueryBuilder.class)); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) rewritten; + assertThat(nestedQueryBuilder.query(), instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) nestedQueryBuilder.query()).getPrefilters(), equalTo(prefilters)); + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilderTests.java index 108ac8101122b..81ae05c58c543 100644 --- a/server/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilderTests.java @@ -33,9 +33,12 @@ import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.SeqNoFieldMapper; +import org.elasticsearch.index.query.AbstractPrefilteredQueryTestCase; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.PrefilteredQuery; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.TermQueryBuilder; @@ -47,7 +50,6 @@ import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.MultiValueMode; -import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentType; @@ -64,6 +66,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static java.util.Collections.singletonList; import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery; @@ -80,7 +83,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.nullValue; -public class FunctionScoreQueryBuilderTests extends AbstractQueryTestCase { +public class FunctionScoreQueryBuilderTests extends AbstractPrefilteredQueryTestCase { private static final String[] SHUFFLE_PROTECTED_FIELDS = new String[] { Script.PARAMS_PARSE_FIELD.getPreferredName(), @@ -844,6 +847,19 @@ private void expectParsingException(String json, String message) { expectParsingException(json, equalTo("failed to parse [function_score] query. " + message)); } + @Override + protected FunctionScoreQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return QueryBuilders.functionScoreQuery(prefilteredQuerySupplier.get()); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(FunctionScoreQueryBuilder.class)); + FunctionScoreQueryBuilder functionScoreQueryBuilder = (FunctionScoreQueryBuilder) rewritten; + assertThat(functionScoreQueryBuilder.query(), instanceOf(PrefilteredQuery.class)); + assertThat(((PrefilteredQuery) functionScoreQueryBuilder.query()).getPrefilters(), equalTo(prefilters)); + } + /** * A hack on top of the normal random score function that fixed toQuery to work properly in this unit testing environment. */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 390c32bb773f8..5fdc84a0527a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1004,7 +1004,13 @@ public boolean fieldHasValue(FieldInfos fieldInfos) { return fieldInfos.fieldInfo(getEmbeddingsFieldName(name())) != null; } - public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer requestSize, float boost, String queryName) { + public QueryBuilder semanticQuery( + InferenceResults inferenceResults, + Integer requestSize, + float boost, + String queryName, + List filters + ) { String nestedFieldPath = getChunksFieldName(name()); String inferenceResultsFieldName = getEmbeddingsFieldName(name()); QueryBuilder childQueryBuilder; @@ -1055,7 +1061,9 @@ yield new SparseVectorQueryBuilder( k = Math.max(k, DEFAULT_SIZE); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null).addFilterQueries( + filters + ); } default -> throw new IllegalStateException( "Field [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index 018fdca7fabdb..80004eb78e63d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -95,7 +95,8 @@ protected QueryBuilder queryFields( rewritten = new MatchNoneQueryBuilder(); } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) { rewritten = new SemanticQueryBuilder(getField(), getQuery(), null, inferenceResultsMap).boost(originalQuery.boost()) - .queryName(originalQuery.queryName()); + .queryName(originalQuery.queryName()) + .setPrefilters(originalQuery.getPrefilters()); } else { rewritten = originalQuery; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 4060d1c6bc4a9..68b455a7e70ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.PrefilteredQuery; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -62,7 +63,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -public class SemanticQueryBuilder extends AbstractQueryBuilder { +public class SemanticQueryBuilder extends AbstractQueryBuilder implements PrefilteredQuery { public static final String NAME = "semantic"; public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids"); @@ -101,6 +102,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap; private final SetOnce> inferenceResultsMapSupplier; private final Boolean lenient; + private List prefilters = List.of(); // ccsRequest is only used on the local cluster coordinator node to detect when: // - The request references a remote index @@ -178,6 +180,10 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { this.ccsRequest = false; } + if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + this.prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class); + } + this.inferenceResultsMapSupplier = null; } @@ -230,6 +236,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { + "." ); } + + if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) { + out.writeNamedWriteableCollection(prefilters); + } } private SemanticQueryBuilder( @@ -247,6 +257,7 @@ private SemanticQueryBuilder( this.inferenceResultsMapSupplier = inferenceResultsMapSupplier; this.lenient = other.lenient; this.ccsRequest = ccsRequest; + this.prefilters = other.prefilters; } @Override @@ -493,7 +504,13 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx ); } - return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); + return semanticTextFieldType.semanticQuery( + inferenceResults, + searchExecutionContext.requestSize(), + boost(), + queryName(), + prefilters + ); } else if (lenient != null && lenient) { return new MatchNoneQueryBuilder(); } else { @@ -646,11 +663,28 @@ protected boolean doEquals(SemanticQueryBuilder other) { && Objects.equals(query, other.query) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) && Objects.equals(inferenceResultsMapSupplier, other.inferenceResultsMapSupplier) - && Objects.equals(ccsRequest, other.ccsRequest); + && Objects.equals(ccsRequest, other.ccsRequest) + && Objects.equals(prefilters, other.prefilters); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); + return Objects.hash(fieldName, query, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest, prefilters); + } + + @Override + public SemanticQueryBuilder setPrefilters(List prefilters) { + this.prefilters = prefilters; + return this; + } + + @Override + public List getPrefilters() { + return prefilters; + } + + @Override + public List getPrefilteringTargetQueries() { + return List.of(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java index ed87d5adda0b6..9eca85c7d9b8e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java @@ -11,9 +11,11 @@ import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import java.util.List; import java.util.Map; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -97,6 +99,8 @@ public void testInterceptAndRewrite() throws Exception { final TestIndex testIndex2 = new TestIndex("test-index-2", Map.of(field, SPARSE_INFERENCE_ID), Map.of()); final TestIndex testIndex3 = new TestIndex("test-index-3", Map.of(), Map.of(field, Map.of("type", "text"))); final MatchQueryBuilder matchQuery = new MatchQueryBuilder(field, queryText).boost(3.0f).queryName("bar"); + List prefilters = randomList(5, () -> RandomQueryBuilder.createQuery(random())); + matchQuery.setPrefilters(prefilters); // Perform coordinator node rewrite final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( @@ -137,7 +141,7 @@ public void testInterceptAndRewrite() throws Exception { queryText, null, coordinatorIntercepted.inferenceResultsMap - ).boost(matchQuery.boost()).queryName(matchQuery.queryName()); + ).boost(matchQuery.boost()).queryName(matchQuery.queryName()).setPrefilters(prefilters); // Perform data node rewrite on test index 1 final QueryRewriteContext indexMetadataContextTestIndex1 = createIndexMetadataContext( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index cb5d1d40e2c2a..44602f3e6520b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -43,7 +43,9 @@ import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; +import org.elasticsearch.index.query.AbstractPrefilteredQueryTestCase; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -55,8 +57,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.SparseVectorQueryWrapper; -import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.test.client.NoOpClient; @@ -70,6 +72,7 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -100,7 +103,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; -public class SemanticQueryBuilderTests extends AbstractQueryTestCase { +public class SemanticQueryBuilderTests extends AbstractPrefilteredQueryTestCase { private static final String SEMANTIC_TEXT_FIELD = "semantic"; private static final float TOKEN_WEIGHT = 0.5f; private static final int QUERY_TOKEN_LENGTH = 4; @@ -556,6 +559,29 @@ public void testSerializingQueryWhenNoInferenceId() throws IOException { assertThat(rewritten, instanceOf(MatchNoneQueryBuilder.class)); } + @Override + protected SemanticQueryBuilder createQueryBuilderForPrefilteredRewriteTest(Supplier prefilteredQuerySupplier) { + return doCreateTestQueryBuilder(); + } + + @Override + protected void assertRewrittenHasPropagatedPrefilters(QueryBuilder rewritten, List prefilters) { + assertThat(rewritten, instanceOf(NestedQueryBuilder.class)); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) rewritten; + switch (inferenceResultType) { + case NONE -> assertThat(nestedQueryBuilder.query(), instanceOf(MatchNoneQueryBuilder.class)); + case SPARSE_EMBEDDING -> assertThat(nestedQueryBuilder.query(), instanceOf(SparseVectorQueryBuilder.class)); + case TEXT_EMBEDDING -> assertVectorQueryBuilderWithPrefilters(nestedQueryBuilder.query(), prefilters); + default -> fail("Unexpected inference result type [" + inferenceResultType + "]"); + } + } + + private static void assertVectorQueryBuilderWithPrefilters(QueryBuilder queryBuilder, List prefilters) { + assertThat(queryBuilder, instanceOf(KnnVectorQueryBuilder.class)); + KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; + assertThat(knnVectorQueryBuilder.filterQueries(), equalTo(prefilters)); + } + private static SourceToParse buildSemanticTextFieldWithInferenceResults( InferenceResultType inferenceResultType, DenseVectorFieldMapper.ElementType denseVectorElementType,