Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
fa34177
POC - Automatic prefiltering for semantic_text queries
dimitris-athanasiou Oct 31, 2025
cf68a9f
[CI] Auto commit changes from spotless
Oct 31, 2025
2f93625
Set prefilters rather than add to existing ones
dimitris-athanasiou Nov 3, 2025
d26b648
boosting query needs not propagate its positive query
dimitris-athanasiou Nov 3, 2025
c7c62db
prefiltering for nested query
dimitris-athanasiou Nov 3, 2025
92f2201
Merge branch 'main' into auto-prefiltering
dimitris-athanasiou Nov 3, 2025
4a4c060
[CI] Auto commit changes from spotless
Nov 3, 2025
3a4c3e4
Add interface method for declaring prefiltering targets
dimitris-athanasiou Nov 4, 2025
2dd0b43
Add bool query `must_not` clauses as prefilters
dimitris-athanasiou Nov 4, 2025
7291d4e
Rename Prefiltering -> PrefilteredQuery
dimitris-athanasiou Nov 5, 2025
2d9c7e6
Adds a comment
dimitris-athanasiou Nov 5, 2025
a53c9f2
Fix prefilters assignment on SemanticQueryBuilder copy
dimitris-athanasiou Nov 5, 2025
74af627
Merge branch 'main' into auto-prefiltering
dimitris-athanasiou Nov 5, 2025
f41f0f1
SemanticQueryBuilderTests
dimitris-athanasiou Nov 5, 2025
72a4861
clean up
dimitris-athanasiou Nov 5, 2025
a672a52
Boosting query should also prefilter negative query
dimitris-athanasiou Nov 5, 2025
45c39eb
Serialization and equality tests
dimitris-athanasiou Nov 5, 2025
e9a7218
Move prefiltering test util in abstract class
dimitris-athanasiou Nov 5, 2025
84bebea
Adds tests
dimitris-athanasiou Nov 6, 2025
22006a8
Merge branch 'main' into auto-prefiltering
dimitris-athanasiou Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,20 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.common.lucene.search.Queries.fixNegativeQueryIfNeeded;

/**
* A Query that matches documents matching boolean combinations of other queries.
*/
public class BoolQueryBuilder extends AbstractQueryBuilder<BoolQueryBuilder> {
public class BoolQueryBuilder extends AbstractQueryBuilder<BoolQueryBuilder> implements PrefilteredQuery<BoolQueryBuilder> {
public static final String NAME = "bool";

public static final boolean ADJUST_PURE_NEGATIVE_DEFAULT = true;
Expand All @@ -60,6 +63,8 @@ public class BoolQueryBuilder extends AbstractQueryBuilder<BoolQueryBuilder> {

private String minimumShouldMatch;

private List<QueryBuilder> prefilters = List.of();

/**
* Build an empty bool query.
*/
Expand All @@ -76,6 +81,9 @@ public BoolQueryBuilder(StreamInput in) throws IOException {
filterClauses.addAll(readQueries(in));
adjustPureNegative = in.readBoolean();
minimumShouldMatch = in.readOptionalString();
if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class);
}
}

@Override
Expand All @@ -86,6 +94,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
writeQueries(out, filterClauses);
out.writeBoolean(adjustPureNegative);
out.writeOptionalString(minimumShouldMatch);
if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
out.writeNamedWriteableCollection(prefilters);
}
}

/**
Expand Down Expand Up @@ -332,7 +343,7 @@ private static void addBooleanClauses(

@Override
protected int doHashCode() {
return Objects.hash(adjustPureNegative, minimumShouldMatch, mustClauses, shouldClauses, mustNotClauses, filterClauses);
return Objects.hash(adjustPureNegative, minimumShouldMatch, mustClauses, shouldClauses, mustNotClauses, filterClauses, prefilters);
}

@Override
Expand All @@ -342,7 +353,8 @@ protected boolean doEquals(BoolQueryBuilder other) {
&& Objects.equals(mustClauses, other.mustClauses)
&& Objects.equals(shouldClauses, other.shouldClauses)
&& Objects.equals(mustNotClauses, other.mustNotClauses)
&& Objects.equals(filterClauses, other.filterClauses);
&& Objects.equals(filterClauses, other.filterClauses)
&& Objects.equals(prefilters, other.prefilters);
}

@Override
Expand All @@ -353,6 +365,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
if (clauses == 0) {
return new MatchAllQueryBuilder().boost(boost()).queryName(queryName());
}

propagatePrefilters();

changed |= rewriteClauses(queryRewriteContext, mustClauses, newBuilder::must);

try {
Expand Down Expand Up @@ -415,6 +430,7 @@ private static boolean rewriteClauses(
) throws IOException {
boolean changed = false;
for (QueryBuilder builder : builders) {

QueryBuilder result = builder.rewrite(queryRewriteContext);
if (result != builder) {
changed = true;
Expand Down Expand Up @@ -452,4 +468,23 @@ public BoolQueryBuilder shallowCopy() {
}
return copy;
}

@Override
public BoolQueryBuilder setPrefilters(List<QueryBuilder> prefilters) {
this.prefilters = prefilters;
return this;
}

@Override
public List<QueryBuilder> getPrefilters() {
// We declare as prefilters clauses run in the filter context, namely filter and must_not
return Stream.of(prefilters, filterClauses, mustNotClauses.stream().map(c -> QueryBuilders.boolQuery().mustNot(c)).toList())
.flatMap(Collection::stream)
.collect(Collectors.toList());
}

@Override
public List<QueryBuilder> getPrefilteringTargetQueries() {
return Stream.concat(mustClauses.stream(), shouldClauses.stream()).toList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand All @@ -35,7 +36,7 @@
* multiplied by the supplied "boost" parameter, so this should be less than 1 to achieve a
* demoting effect
*/
public class BoostingQueryBuilder extends AbstractQueryBuilder<BoostingQueryBuilder> {
public class BoostingQueryBuilder extends AbstractQueryBuilder<BoostingQueryBuilder> implements PrefilteredQuery<BoostingQueryBuilder> {
public static final String NAME = "boosting";

private static final ParseField POSITIVE_FIELD = new ParseField("positive");
Expand All @@ -48,6 +49,8 @@ public class BoostingQueryBuilder extends AbstractQueryBuilder<BoostingQueryBuil

private float negativeBoost = -1;

private List<QueryBuilder> prefilters = List.of();

/**
* Create a new {@link BoostingQueryBuilder}
*
Expand All @@ -73,13 +76,19 @@ public BoostingQueryBuilder(StreamInput in) throws IOException {
positiveQuery = in.readNamedWriteable(QueryBuilder.class);
negativeQuery = in.readNamedWriteable(QueryBuilder.class);
negativeBoost = in.readFloat();
if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class);
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(positiveQuery);
out.writeNamedWriteable(negativeQuery);
out.writeFloat(negativeBoost);
if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
out.writeNamedWriteableCollection(prefilters);
}
}

/**
Expand Down Expand Up @@ -197,18 +206,21 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

@Override
protected int doHashCode() {
return Objects.hash(negativeBoost, positiveQuery, negativeQuery);
return Objects.hash(negativeBoost, positiveQuery, negativeQuery, prefilters);
}

@Override
protected boolean doEquals(BoostingQueryBuilder other) {
return Objects.equals(negativeBoost, other.negativeBoost)
&& Objects.equals(positiveQuery, other.positiveQuery)
&& Objects.equals(negativeQuery, other.negativeQuery);
&& Objects.equals(negativeQuery, other.negativeQuery)
&& Objects.equals(prefilters, other.prefilters);
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
propagatePrefilters();

QueryBuilder positiveQuery = this.positiveQuery.rewrite(queryRewriteContext);
if (positiveQuery instanceof MatchNoneQueryBuilder) {
return positiveQuery;
Expand All @@ -233,4 +245,20 @@ protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> inner
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.zero();
}

@Override
public BoostingQueryBuilder setPrefilters(List<QueryBuilder> prefilters) {
this.prefilters = prefilters;
return this;
}

@Override
public List<QueryBuilder> getPrefilters() {
return prefilters;
}

@Override
public List<QueryBuilder> getPrefilteringTargetQueries() {
return List.of(positiveQuery, negativeQuery);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,26 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* A query that wraps a filter and simply returns a constant score equal to the
* query boost for every document in the filter.
*/
public class ConstantScoreQueryBuilder extends AbstractQueryBuilder<ConstantScoreQueryBuilder> {
public class ConstantScoreQueryBuilder extends AbstractQueryBuilder<ConstantScoreQueryBuilder>
implements
PrefilteredQuery<ConstantScoreQueryBuilder> {

public static final String NAME = "constant_score";

private static final ParseField INNER_QUERY_FIELD = new ParseField("filter");

private final QueryBuilder filterBuilder;

private List<QueryBuilder> prefilters = List.of();

/**
* A query that wraps another query and simply returns a constant score equal to the
* query boost for every document in the query.
Expand All @@ -53,11 +59,17 @@ public ConstantScoreQueryBuilder(QueryBuilder filterBuilder) {
public ConstantScoreQueryBuilder(StreamInput in) throws IOException {
super(in);
filterBuilder = in.readNamedWriteable(QueryBuilder.class);
if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
prefilters = in.readNamedWriteableCollectionAsList(QueryBuilder.class);
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(filterBuilder);
if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
out.writeNamedWriteableCollection(prefilters);
}
}

/**
Expand Down Expand Up @@ -135,20 +147,23 @@ public String getWriteableName() {

@Override
protected int doHashCode() {
return Objects.hash(filterBuilder);
return Objects.hash(filterBuilder, prefilters);
}

@Override
protected boolean doEquals(ConstantScoreQueryBuilder other) {
return Objects.equals(filterBuilder, other.filterBuilder);
return Objects.equals(filterBuilder, other.filterBuilder) && Objects.equals(prefilters, other.prefilters);
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
propagatePrefilters();

QueryBuilder rewrite = filterBuilder.rewrite(queryRewriteContext);
if (rewrite instanceof MatchNoneQueryBuilder) {
return rewrite; // we won't match anyway
}

if (rewrite != filterBuilder) {
return new ConstantScoreQueryBuilder(rewrite);
}
Expand All @@ -164,4 +179,20 @@ protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> inner
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.zero();
}

@Override
public ConstantScoreQueryBuilder setPrefilters(List<QueryBuilder> prefilters) {
this.prefilters = prefilters;
return this;
}

@Override
public List<QueryBuilder> getPrefilters() {
return prefilters;
}

@Override
public List<QueryBuilder> getPrefilteringTargetQueries() {
return List.of(filterBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* with the maximum score for that document as produced by any sub-query, plus a tie breaking increment for any
* additional matching sub-queries.
*/
public class DisMaxQueryBuilder extends AbstractQueryBuilder<DisMaxQueryBuilder> {
public class DisMaxQueryBuilder extends AbstractQueryBuilder<DisMaxQueryBuilder> implements PrefilteredQuery<DisMaxQueryBuilder> {
public static final String NAME = "dis_max";

/** Default multiplication factor for breaking ties in document scores.*/
Expand All @@ -42,6 +42,7 @@ public class DisMaxQueryBuilder extends AbstractQueryBuilder<DisMaxQueryBuilder>
private static final ParseField QUERIES_FIELD = new ParseField("queries");

private final List<QueryBuilder> queries = new ArrayList<>();
private List<QueryBuilder> prefilters = List.of();

private float tieBreaker = DEFAULT_TIE_BREAKER;

Expand All @@ -54,12 +55,18 @@ public DisMaxQueryBuilder(StreamInput in) throws IOException {
super(in);
queries.addAll(readQueries(in));
tieBreaker = in.readFloat();
if (in.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
prefilters = readQueries(in);
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
writeQueries(out, queries);
out.writeFloat(tieBreaker);
if (out.getTransportVersion().supports(PrefilteredQuery.QUERY_PREFILTERING)) {
out.writeNamedWriteableCollection(prefilters);
}
}

/**
Expand Down Expand Up @@ -182,6 +189,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
propagatePrefilters();

DisMaxQueryBuilder newBuilder = new DisMaxQueryBuilder();
boolean changed = false;
for (QueryBuilder query : queries) {
Expand All @@ -203,12 +212,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws

@Override
protected int doHashCode() {
return Objects.hash(queries, tieBreaker);
return Objects.hash(queries, tieBreaker, prefilters);
}

@Override
protected boolean doEquals(DisMaxQueryBuilder other) {
return Objects.equals(queries, other.queries) && Objects.equals(tieBreaker, other.tieBreaker);
return Objects.equals(queries, other.queries)
&& Objects.equals(tieBreaker, other.tieBreaker)
&& Objects.equals(prefilters, other.prefilters);
}

@Override
Expand All @@ -227,4 +238,20 @@ protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> inner
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.zero();
}

@Override
public DisMaxQueryBuilder setPrefilters(List<QueryBuilder> prefilters) {
this.prefilters = prefilters;
return this;
}

@Override
public List<QueryBuilder> getPrefilters() {
return prefilters;
}

@Override
public List<QueryBuilder> getPrefilteringTargetQueries() {
return queries;
}
}
Loading