From f2df176ee4ab9584d9b614dcc7f8d520148e4c95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yoann=20Rodi=C3=A8re?= Date: Wed, 15 Apr 2020 11:37:55 +0200 Subject: [PATCH] HSEARCH-3881 Expose the extended search predicate factory in extended search aggregation factories I.e. make sure that this works: .search( ... ).extension( LuceneExtension.get() ) .where( ... ) .aggregation( f -> f.terms()... .filter( pf -> pf.fromLuceneQuery( ... ) ) ) ... and that we don't have to call .extension() again on "pf", because the provided factory is already extended. --- .../elasticsearch/ElasticsearchExtension.java | 5 +- ...ElasticsearchSearchAggregationFactory.java | 6 +- ...ticsearchSearchAggregationFactoryImpl.java | 9 +-- .../backend/lucene/LuceneExtension.java | 5 +- .../dsl/LuceneSearchAggregationFactory.java | 5 +- .../LuceneSearchAggregationFactoryImpl.java | 10 ++-- .../scope/impl/MappedIndexScopeImpl.java | 6 +- .../dsl/AggregationFilterStep.java | 60 +++++++++++++++++++ .../dsl/ExtendedSearchAggregationFactory.java | 30 ++++++++++ .../dsl/RangeAggregationFieldStep.java | 11 +++- .../dsl/RangeAggregationOptionsStep.java | 41 +++---------- .../dsl/RangeAggregationRangeMoreStep.java | 12 ++-- .../dsl/RangeAggregationRangeStep.java | 9 ++- .../dsl/SearchAggregationFactory.java | 4 +- .../SearchAggregationFactoryExtension.java | 2 +- .../dsl/TermsAggregationFieldStep.java | 10 +++- .../dsl/TermsAggregationOptionsStep.java | 41 +++---------- .../impl/DefaultSearchAggregationFactory.java | 13 ++-- .../impl/RangeAggregationFieldStepImpl.java | 9 +-- .../impl/RangeAggregationRangeStepImpl.java | 37 ++++-------- .../impl/SearchAggregationDslContextImpl.java | 39 +++++++++--- .../impl/TermsAggregationFieldStepImpl.java | 9 +-- .../impl/TermsAggregationOptionsStepImpl.java | 42 +++++-------- .../DelegatingSearchAggregationFactory.java | 20 +++++-- .../dsl/spi/SearchAggregationDslContext.java | 21 ++++++- .../spi/AbstractSearchQueryOptionsStep.java | 3 +- .../search/aggregation/AggregationBaseIT.java | 4 +- 27 files changed, 281 insertions(+), 182 deletions(-) create mode 100644 engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AggregationFilterStep.java create mode 100644 engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/ExtendedSearchAggregationFactory.java diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/ElasticsearchExtension.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/ElasticsearchExtension.java index 26621a6a6c8..9f3eed3049d 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/ElasticsearchExtension.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/ElasticsearchExtension.java @@ -203,11 +203,12 @@ public Optional> extendOptional( @Override @SuppressWarnings("unchecked") // If the factory is an instance of ElasticsearchSearchAggregationBuilderFactory, the cast is safe public Optional extendOptional( - SearchAggregationFactory original, SearchAggregationDslContext dslContext) { + SearchAggregationFactory original, SearchAggregationDslContext dslContext) { if ( dslContext.getBuilderFactory() instanceof ElasticsearchSearchAggregationBuilderFactory ) { return Optional.of( new ElasticsearchSearchAggregationFactoryImpl( original, - (SearchAggregationDslContext) dslContext + ((SearchAggregationDslContext) dslContext) + .withExtendedPredicateFactory( this ) ) ); } else { diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/ElasticsearchSearchAggregationFactory.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/ElasticsearchSearchAggregationFactory.java index 0db63fea377..cfeca418b0a 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/ElasticsearchSearchAggregationFactory.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/ElasticsearchSearchAggregationFactory.java @@ -6,12 +6,14 @@ */ package org.hibernate.search.backend.elasticsearch.search.aggregation.dsl; +import org.hibernate.search.backend.elasticsearch.search.predicate.dsl.ElasticsearchSearchPredicateFactory; import org.hibernate.search.engine.search.aggregation.dsl.AggregationFinalStep; -import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactory; +import org.hibernate.search.engine.search.aggregation.dsl.ExtendedSearchAggregationFactory; import com.google.gson.JsonObject; -public interface ElasticsearchSearchAggregationFactory extends SearchAggregationFactory { +public interface ElasticsearchSearchAggregationFactory + extends ExtendedSearchAggregationFactory { /** * Create an aggregation from JSON. diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/impl/ElasticsearchSearchAggregationFactoryImpl.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/impl/ElasticsearchSearchAggregationFactoryImpl.java index aed6789ceac..250f60cfd42 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/impl/ElasticsearchSearchAggregationFactoryImpl.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/aggregation/dsl/impl/ElasticsearchSearchAggregationFactoryImpl.java @@ -8,6 +8,7 @@ import org.hibernate.search.backend.elasticsearch.search.aggregation.impl.ElasticsearchSearchAggregationBuilderFactory; import org.hibernate.search.backend.elasticsearch.search.aggregation.dsl.ElasticsearchSearchAggregationFactory; +import org.hibernate.search.backend.elasticsearch.search.predicate.dsl.ElasticsearchSearchPredicateFactory; import org.hibernate.search.engine.search.aggregation.dsl.AggregationFinalStep; import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactory; import org.hibernate.search.engine.search.aggregation.dsl.spi.DelegatingSearchAggregationFactory; @@ -16,14 +17,14 @@ import com.google.gson.JsonObject; public class ElasticsearchSearchAggregationFactoryImpl - extends DelegatingSearchAggregationFactory + extends DelegatingSearchAggregationFactory implements ElasticsearchSearchAggregationFactory { - private final SearchAggregationDslContext dslContext; + private final SearchAggregationDslContext dslContext; public ElasticsearchSearchAggregationFactoryImpl(SearchAggregationFactory delegate, - SearchAggregationDslContext dslContext) { - super( delegate ); + SearchAggregationDslContext dslContext) { + super( delegate, dslContext ); this.dslContext = dslContext; } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/LuceneExtension.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/LuceneExtension.java index c8767f362af..b877d6f7d0e 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/LuceneExtension.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/LuceneExtension.java @@ -203,11 +203,12 @@ public Optional> extendOptional( @Override @SuppressWarnings("unchecked") // If the factory is an instance of LuceneSearchAggregationBuilderFactory, the cast is safe public Optional extendOptional( - SearchAggregationFactory original, SearchAggregationDslContext dslContext) { + SearchAggregationFactory original, SearchAggregationDslContext dslContext) { if ( dslContext.getBuilderFactory() instanceof LuceneSearchAggregationBuilderFactory ) { return Optional.of( new LuceneSearchAggregationFactoryImpl( original, - (SearchAggregationDslContext) dslContext + ((SearchAggregationDslContext) dslContext) + .withExtendedPredicateFactory( this ) ) ); } else { diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/LuceneSearchAggregationFactory.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/LuceneSearchAggregationFactory.java index 1972a14ee3c..29c53150f2f 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/LuceneSearchAggregationFactory.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/LuceneSearchAggregationFactory.java @@ -6,7 +6,8 @@ */ package org.hibernate.search.backend.lucene.search.aggregation.dsl; -import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactory; +import org.hibernate.search.backend.lucene.search.predicate.dsl.LuceneSearchPredicateFactory; +import org.hibernate.search.engine.search.aggregation.dsl.ExtendedSearchAggregationFactory; -public interface LuceneSearchAggregationFactory extends SearchAggregationFactory { +public interface LuceneSearchAggregationFactory extends ExtendedSearchAggregationFactory { } diff --git a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/impl/LuceneSearchAggregationFactoryImpl.java b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/impl/LuceneSearchAggregationFactoryImpl.java index a2dd68d1bc6..122ea9d76ab 100644 --- a/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/impl/LuceneSearchAggregationFactoryImpl.java +++ b/backend/lucene/src/main/java/org/hibernate/search/backend/lucene/search/aggregation/dsl/impl/LuceneSearchAggregationFactoryImpl.java @@ -8,20 +8,18 @@ import org.hibernate.search.backend.lucene.search.aggregation.impl.LuceneSearchAggregationBuilderFactory; import org.hibernate.search.backend.lucene.search.aggregation.dsl.LuceneSearchAggregationFactory; +import org.hibernate.search.backend.lucene.search.predicate.dsl.LuceneSearchPredicateFactory; import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactory; import org.hibernate.search.engine.search.aggregation.dsl.spi.DelegatingSearchAggregationFactory; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; public class LuceneSearchAggregationFactoryImpl - extends DelegatingSearchAggregationFactory + extends DelegatingSearchAggregationFactory implements LuceneSearchAggregationFactory { - private final SearchAggregationDslContext dslContext; - public LuceneSearchAggregationFactoryImpl(SearchAggregationFactory delegate, - SearchAggregationDslContext dslContext) { - super( delegate ); - this.dslContext = dslContext; + SearchAggregationDslContext dslContext) { + super( delegate, dslContext ); } // Empty: no extension at the moment. diff --git a/engine/src/main/java/org/hibernate/search/engine/mapper/scope/impl/MappedIndexScopeImpl.java b/engine/src/main/java/org/hibernate/search/engine/mapper/scope/impl/MappedIndexScopeImpl.java index 16a580a493d..e95a2f0a5e5 100644 --- a/engine/src/main/java/org/hibernate/search/engine/mapper/scope/impl/MappedIndexScopeImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/mapper/scope/impl/MappedIndexScopeImpl.java @@ -72,7 +72,11 @@ public SearchProjectionFactory projection() { @Override public SearchAggregationFactory aggregation() { return new DefaultSearchAggregationFactory( - SearchAggregationDslContextImpl.root( delegate.getSearchAggregationFactory(), delegate.getSearchPredicateBuilderFactory() ) + SearchAggregationDslContextImpl.root( + delegate.getSearchAggregationFactory(), + predicate(), + delegate.getSearchPredicateBuilderFactory() + ) ); } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AggregationFilterStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AggregationFilterStep.java new file mode 100644 index 00000000000..1c73d78c04c --- /dev/null +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/AggregationFilterStep.java @@ -0,0 +1,60 @@ +/* + * Hibernate Search, full-text search for your domain model + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.search.engine.search.aggregation.dsl; + +import java.util.function.Function; + +import org.hibernate.search.engine.search.predicate.SearchPredicate; +import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; + +/** + * The step in an aggregation definition where a filter can be set + * to select nested objects from which values will be extracted for this aggregation. + * + * @param The "self" type (the actual exposed type of this step) + * @param The type of factory used to create predicates in {@link #filter(Function)}. + */ +public interface AggregationFilterStep { + + /** + * Filter nested objects from which values will be extracted for this aggregation. + *

+ * The filter is based on a previously-built {@link SearchPredicate}. + * + * @param searchPredicate The predicate that must match. + * @return {@code this}, for method chaining. + */ + S filter(SearchPredicate searchPredicate); + + /** + * Filter nested objects from which values will be extracted for this aggregation. + *

+ * The filter is defined by the given function. + *

+ * Best used with lambda expressions. + * + * @param clauseContributor A function that will use the factory passed in parameter to create a predicate, + * returning the final step in the predicate DSL. + * Should generally be a lambda expression. + * @return {@code this}, for method chaining. + */ + S filter(Function clauseContributor); + + /** + * Filter nested objects from which values will be extracted for this aggregation. + *

+ * The filter is based on an almost-built {@link SearchPredicate}. + * + * @param dslFinalStep A final step in the predicate DSL allowing the retrieval of a {@link SearchPredicate}. + * @return {@code this}, for method chaining. + */ + default S filter(PredicateFinalStep dslFinalStep) { + return filter( dslFinalStep.toPredicate() ); + } + +} diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/ExtendedSearchAggregationFactory.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/ExtendedSearchAggregationFactory.java new file mode 100644 index 00000000000..48d57ae4b4d --- /dev/null +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/ExtendedSearchAggregationFactory.java @@ -0,0 +1,30 @@ +/* + * Hibernate Search, full-text search for your domain model + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.search.engine.search.aggregation.dsl; + +import java.util.function.Function; + +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; + +/** + * A base interface for subtypes of {@link SearchAggregationFactory} allowing to + * easily override the predicate factory type for all relevant methods. + *

+ * Warning: Generic parameters of this type are subject to change, + * so this type should not be referenced directtly in user code. + * + * @param The type of factory used to create predicates in {@link AggregationFilterStep#filter(Function)}. + */ +public interface ExtendedSearchAggregationFactory + extends SearchAggregationFactory { + + @Override + RangeAggregationFieldStep range(); + + @Override + TermsAggregationFieldStep terms(); +} diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationFieldStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationFieldStep.java index ac8a95c65f8..b204c803e1d 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationFieldStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationFieldStep.java @@ -6,12 +6,17 @@ */ package org.hibernate.search.engine.search.aggregation.dsl; +import java.util.function.Function; + import org.hibernate.search.engine.search.common.ValueConvert; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; /** * The initial step in a "range" aggregation definition, where the target field can be set. + * + * @param The type of factory used to create predicates in {@link AggregationFilterStep#filter(Function)}. */ -public interface RangeAggregationFieldStep { +public interface RangeAggregationFieldStep { /** * Target the given field in the range aggregation. @@ -21,7 +26,7 @@ public interface RangeAggregationFieldStep { * @param The type of field values. * @return The next step. */ - default RangeAggregationRangeStep field(String absoluteFieldPath, Class type) { + default RangeAggregationRangeStep field(String absoluteFieldPath, Class type) { return field( absoluteFieldPath, type, ValueConvert.YES ); } @@ -35,6 +40,6 @@ default RangeAggregationRangeStep field(String absoluteFieldPath, Clas * See {@link ValueConvert}. * @return The next step. */ - RangeAggregationRangeStep field(String absoluteFieldPath, Class type, ValueConvert convert); + RangeAggregationRangeStep field(String absoluteFieldPath, Class type, ValueConvert convert); } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationOptionsStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationOptionsStep.java index d189f74f0d4..bcf532af991 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationOptionsStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationOptionsStep.java @@ -7,47 +7,22 @@ package org.hibernate.search.engine.search.aggregation.dsl; import java.util.function.Function; -import org.hibernate.search.engine.search.predicate.SearchPredicate; -import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; /** * The final step in a "range" aggregation definition, where optional parameters can be set. * * @param The "self" type (the actual exposed type of this step). + * @param The type of factory used to create predicates in {@link #filter(Function)}. * @param The type of the targeted field. * @param The type of result for this aggregation. */ -public interface RangeAggregationOptionsStep, F, A> - extends AggregationFinalStep { +public interface RangeAggregationOptionsStep< + S extends RangeAggregationOptionsStep, + PDF extends SearchPredicateFactory, + F, + A + > + extends AggregationFinalStep, AggregationFilterStep { - /** - * Add a "filter" clause based on a previously-built {@link SearchPredicate}. - * - * @param searchPredicate The predicate that must match. - * @return {@code this}, for method chaining. - */ - S filter(SearchPredicate searchPredicate); - - /** - * Add a "filter" clause to be defined by the given function. - *

- * Best used with lambda expressions. - * - * @param clauseContributor A function that will use the factory passed in parameter to create a predicate, - * returning the final step in the predicate DSL. - * Should generally be a lambda expression. - * @return {@code this}, for method chaining. - */ - S filter(Function clauseContributor); - - /** - * Add a "filter" clause based on an almost-built {@link SearchPredicate}. - * - * @param dslFinalStep A final step in the predicate DSL allowing the retrieval of a {@link SearchPredicate}. - * @return {@code this}, for method chaining. - */ - default S filter(PredicateFinalStep dslFinalStep) { - return filter( dslFinalStep.toPredicate() ); - } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeMoreStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeMoreStep.java index e34befae9c0..37b95da6e03 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeMoreStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeMoreStep.java @@ -7,7 +7,9 @@ package org.hibernate.search.engine.search.aggregation.dsl; import java.util.Map; +import java.util.function.Function; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; import org.hibernate.search.util.common.data.Range; /** @@ -16,15 +18,17 @@ * or more ranges can be added. * * @param The "self" type (the actual exposed type of this step). + * @param The type of factory used to create predicates in {@link #filter(Function)}. * @param The type of the next step. * @param The type of the targeted field. */ public interface RangeAggregationRangeMoreStep< - S extends RangeAggregationRangeMoreStep, - N extends RangeAggregationOptionsStep, Long>>, + S extends RangeAggregationRangeMoreStep, + N extends RangeAggregationOptionsStep, Long>>, + PDF extends SearchPredicateFactory, F > - extends RangeAggregationOptionsStep, Long>>, - RangeAggregationRangeStep { + extends RangeAggregationOptionsStep, Long>>, + RangeAggregationRangeStep { } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeStep.java index 5c847b6916a..1aeb64620da 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/RangeAggregationRangeStep.java @@ -7,16 +7,23 @@ package org.hibernate.search.engine.search.aggregation.dsl; import java.util.Collection; +import java.util.function.Function; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; import org.hibernate.search.util.common.data.Range; /** * The step in a "range" aggregation definition where the ranges can be set. * * @param The type of the next step. + * @param The type of factory used to create predicates in {@link AggregationFilterStep#filter(Function)}. * @param The type of the targeted field. */ -public interface RangeAggregationRangeStep, F> { +public interface RangeAggregationRangeStep< + N extends RangeAggregationRangeMoreStep, + PDF extends SearchPredicateFactory, + F + > { /** * Add a bucket for the range {@code [lowerBound, upperBound)} (lower bound included, upper bound excluded), diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactory.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactory.java index 8fc3212687a..fdd8a424182 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactory.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactory.java @@ -28,7 +28,7 @@ public interface SearchAggregationFactory { * * @return The next step. */ - RangeAggregationFieldStep range(); + RangeAggregationFieldStep range(); /** * Perform aggregation in term buckets. @@ -43,7 +43,7 @@ public interface SearchAggregationFactory { * * @return The next step. */ - TermsAggregationFieldStep terms(); + TermsAggregationFieldStep terms(); /** * Extend the current factory with the given extension, diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactoryExtension.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactoryExtension.java index 8151649038f..e9d50113696 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactoryExtension.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/SearchAggregationFactoryExtension.java @@ -38,6 +38,6 @@ public interface SearchAggregationFactoryExtension { * @return An optional containing the extended aggregation factory ({@link T}) in case * of success, or an empty optional otherwise. */ - Optional extendOptional(SearchAggregationFactory original, SearchAggregationDslContext dslContext); + Optional extendOptional(SearchAggregationFactory original, SearchAggregationDslContext dslContext); } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationFieldStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationFieldStep.java index 95dd0acab71..3b82991b51e 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationFieldStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationFieldStep.java @@ -7,13 +7,17 @@ package org.hibernate.search.engine.search.aggregation.dsl; import java.util.Map; +import java.util.function.Function; import org.hibernate.search.engine.search.common.ValueConvert; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; /** * The initial step in a "terms" aggregation definition, where the target field can be set. + * + * @param The type of factory used to create predicates in {@link AggregationFilterStep#filter(Function)}. */ -public interface TermsAggregationFieldStep { +public interface TermsAggregationFieldStep { /** * Target the given field in the terms aggregation. @@ -23,7 +27,7 @@ public interface TermsAggregationFieldStep { * @param The type of field values. * @return The next step. */ - default TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type) { + default TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type) { return field( absoluteFieldPath, type, ValueConvert.YES ); } @@ -37,7 +41,7 @@ default TermsAggregationOptionsStep> field(String absolut * See {@link ValueConvert}. * @return The next step. */ - TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type, + TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type, ValueConvert convert); } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationOptionsStep.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationOptionsStep.java index 857894080fb..67602a02bf1 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationOptionsStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/TermsAggregationOptionsStep.java @@ -7,19 +7,23 @@ package org.hibernate.search.engine.search.aggregation.dsl; import java.util.function.Function; -import org.hibernate.search.engine.search.predicate.SearchPredicate; -import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; /** * The final step in a "terms" aggregation definition, where optional parameters can be set. * * @param The "self" type (the actual exposed type of this step). + * @param The type of factory used to create predicates in {@link #filter(Function)}. * @param The type of the targeted field. * @param The type of result for this aggregation. */ -public interface TermsAggregationOptionsStep, F, A> - extends AggregationFinalStep { +public interface TermsAggregationOptionsStep< + S extends TermsAggregationOptionsStep, + PDF extends SearchPredicateFactory, + F, + A + > + extends AggregationFinalStep, AggregationFilterStep { /** * Order buckets by descending document count in the aggregation result. @@ -76,33 +80,4 @@ public interface TermsAggregationOptionsStep"filter" clause based on a previously-built {@link SearchPredicate}. - * - * @param searchPredicate The predicate that must match. - * @return {@code this}, for method chaining. - */ - S filter(SearchPredicate searchPredicate); - - /** - * Add a "filter" clause to be defined by the given function. - *

- * Best used with lambda expressions. - * - * @param clauseContributor A function that will use the factory passed in parameter to create a predicate, - * returning the final step in the predicate DSL. - * Should generally be a lambda expression. - * @return {@code this}, for method chaining. - */ - S filter(Function clauseContributor); - - /** - * Add a "filter" clause based on an almost-built {@link SearchPredicate}. - * - * @param dslFinalStep A final step in the predicate DSL allowing the retrieval of a {@link SearchPredicate}. - * @return {@code this}, for method chaining. - */ - default S filter(PredicateFinalStep dslFinalStep) { - return filter( dslFinalStep.toPredicate() ); - } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/DefaultSearchAggregationFactory.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/DefaultSearchAggregationFactory.java index b83812a20db..b256163ebed 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/DefaultSearchAggregationFactory.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/DefaultSearchAggregationFactory.java @@ -12,23 +12,24 @@ import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactoryExtension; import org.hibernate.search.engine.search.aggregation.dsl.TermsAggregationFieldStep; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; public class DefaultSearchAggregationFactory implements SearchAggregationFactory { - private final SearchAggregationDslContext dslContext; + private final SearchAggregationDslContext dslContext; - public DefaultSearchAggregationFactory(SearchAggregationDslContext dslContext) { + public DefaultSearchAggregationFactory(SearchAggregationDslContext dslContext) { this.dslContext = dslContext; } @Override - public RangeAggregationFieldStep range() { - return new RangeAggregationFieldStepImpl( dslContext ); + public RangeAggregationFieldStep range() { + return new RangeAggregationFieldStepImpl<>( dslContext ); } @Override - public TermsAggregationFieldStep terms() { - return new TermsAggregationFieldStepImpl( dslContext ); + public TermsAggregationFieldStep terms() { + return new TermsAggregationFieldStepImpl<>( dslContext ); } @Override diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationFieldStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationFieldStepImpl.java index 9b43361485f..1486bb890df 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationFieldStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationFieldStepImpl.java @@ -11,17 +11,18 @@ import org.hibernate.search.engine.search.aggregation.dsl.RangeAggregationFieldStep; import org.hibernate.search.engine.search.aggregation.dsl.RangeAggregationRangeStep; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; import org.hibernate.search.util.common.impl.Contracts; -class RangeAggregationFieldStepImpl implements RangeAggregationFieldStep { - private final SearchAggregationDslContext dslContext; +public class RangeAggregationFieldStepImpl implements RangeAggregationFieldStep { + private final SearchAggregationDslContext dslContext; - RangeAggregationFieldStepImpl(SearchAggregationDslContext dslContext) { + public RangeAggregationFieldStepImpl(SearchAggregationDslContext dslContext) { this.dslContext = dslContext; } @Override - public RangeAggregationRangeStep field(String absoluteFieldPath, Class type, ValueConvert convert) { + public RangeAggregationRangeStep field(String absoluteFieldPath, Class type, ValueConvert convert) { Contracts.assertNotNull( absoluteFieldPath, "absoluteFieldPath" ); Contracts.assertNotNull( type, "type" ); RangeAggregationBuilder builder = diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationRangeStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationRangeStepImpl.java index 3f586e19444..e86522cf15f 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationRangeStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/RangeAggregationRangeStepImpl.java @@ -18,31 +18,29 @@ import org.hibernate.search.engine.search.predicate.SearchPredicate; import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; -import org.hibernate.search.engine.search.predicate.dsl.impl.DefaultSearchPredicateFactory; -import org.hibernate.search.engine.search.predicate.spi.SearchPredicateBuilderFactory; import org.hibernate.search.util.common.data.Range; import org.hibernate.search.util.common.impl.Contracts; -class RangeAggregationRangeStepImpl - implements RangeAggregationRangeStep, F>, - RangeAggregationRangeMoreStep, RangeAggregationRangeStepImpl, F> { +class RangeAggregationRangeStepImpl + implements RangeAggregationRangeStep, PDF, F>, + RangeAggregationRangeMoreStep, RangeAggregationRangeStepImpl, PDF, F> { private final RangeAggregationBuilder builder; - private final SearchAggregationDslContext dslContext; + private final SearchAggregationDslContext dslContext; - RangeAggregationRangeStepImpl(RangeAggregationBuilder builder, SearchAggregationDslContext dslContext) { + RangeAggregationRangeStepImpl(RangeAggregationBuilder builder, SearchAggregationDslContext dslContext) { this.builder = builder; this.dslContext = dslContext; } @Override - public RangeAggregationRangeStepImpl range(Range range) { + public RangeAggregationRangeStepImpl range(Range range) { Contracts.assertNotNull( range, "range" ); builder.range( range ); return this; } @Override - public RangeAggregationRangeStepImpl ranges(Collection> ranges) { + public RangeAggregationRangeStepImpl ranges(Collection> ranges) { Contracts.assertNotNull( ranges, "ranges" ); for ( Range range : ranges ) { range( range ); @@ -51,30 +49,19 @@ public RangeAggregationRangeStepImpl ranges(Collection filter( - Function clauseContributor) { + public RangeAggregationRangeStepImpl filter( + Function clauseContributor) { + SearchPredicate predicate = clauseContributor.apply( dslContext.getPredicateFactory() ).toPredicate(); - SearchPredicateBuilderFactory predicateBuilderFactory = dslContext.getPredicateBuilderFactory(); - SearchPredicateFactory factory = new DefaultSearchPredicateFactory<>( predicateBuilderFactory ); - SearchPredicate predicate = clauseContributor.apply( extendPredicateFactory( factory ) ).toPredicate(); - - filter( predicate ); - return this; + return filter( predicate ); } @Override - public RangeAggregationRangeStepImpl filter(SearchPredicate searchPredicate) { - SearchPredicateBuilderFactory predicateBuilderFactory = dslContext.getPredicateBuilderFactory(); - searchPredicate = (SearchPredicate) predicateBuilderFactory.toImplementation( searchPredicate ); - + public RangeAggregationRangeStepImpl filter(SearchPredicate searchPredicate) { builder.filter( searchPredicate ); return this; } - protected SearchPredicateFactory extendPredicateFactory(SearchPredicateFactory predicateFactory) { - return predicateFactory; - } - @Override public SearchAggregation, Long>> toAggregation() { return builder.build(); diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/SearchAggregationDslContextImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/SearchAggregationDslContextImpl.java index d535b06d361..18e051a442e 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/SearchAggregationDslContextImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/SearchAggregationDslContextImpl.java @@ -6,22 +6,30 @@ */ package org.hibernate.search.engine.search.aggregation.dsl.impl; +import org.hibernate.search.engine.common.dsl.spi.DslExtensionState; import org.hibernate.search.engine.search.aggregation.spi.SearchAggregationBuilderFactory; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactoryExtension; import org.hibernate.search.engine.search.predicate.spi.SearchPredicateBuilderFactory; -public class SearchAggregationDslContextImpl> - implements SearchAggregationDslContext { - public static > SearchAggregationDslContextImpl root(F builderFactory, SearchPredicateBuilderFactory predicateFactory) { - return new SearchAggregationDslContextImpl<>( builderFactory, predicateFactory ); +public class SearchAggregationDslContextImpl, PDF extends SearchPredicateFactory> + implements SearchAggregationDslContext { + public static , PDF extends SearchPredicateFactory> + SearchAggregationDslContextImpl root(F builderFactory, PDF predicateFactory, + SearchPredicateBuilderFactory predicateBuilderFactory) { + return new SearchAggregationDslContextImpl<>( builderFactory, predicateFactory, predicateBuilderFactory ); } private final F builderFactory; - private final SearchPredicateBuilderFactory predicateFactory; + private final PDF predicateFactory; + private final SearchPredicateBuilderFactory predicateBuilderFactory; - private SearchAggregationDslContextImpl(F builderFactory, SearchPredicateBuilderFactory predicateFactory) { + private SearchAggregationDslContextImpl(F builderFactory, PDF predicateFactory, + SearchPredicateBuilderFactory predicateBuilderFactory) { this.builderFactory = builderFactory; this.predicateFactory = predicateFactory; + this.predicateBuilderFactory = predicateBuilderFactory; } @Override @@ -30,7 +38,24 @@ public F getBuilderFactory() { } @Override - public SearchPredicateBuilderFactory getPredicateBuilderFactory() { + public PDF getPredicateFactory() { return predicateFactory; } + + @Override + public SearchAggregationDslContext withExtendedPredicateFactory( + SearchPredicateFactoryExtension extension) { + return new SearchAggregationDslContextImpl<>( + builderFactory, + DslExtensionState.returnIfSupported( + extension, extension.extendOptional( predicateFactory, predicateBuilderFactory ) + ), + predicateBuilderFactory + ); + } + + @Override + public SearchPredicateBuilderFactory getPredicateBuilderFactory() { + return predicateBuilderFactory; + } } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationFieldStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationFieldStepImpl.java index 8f6b6d974a5..c70b8217aa6 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationFieldStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationFieldStepImpl.java @@ -13,17 +13,18 @@ import org.hibernate.search.engine.search.aggregation.dsl.TermsAggregationFieldStep; import org.hibernate.search.engine.search.aggregation.dsl.TermsAggregationOptionsStep; import org.hibernate.search.engine.search.aggregation.dsl.spi.SearchAggregationDslContext; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; import org.hibernate.search.util.common.impl.Contracts; -class TermsAggregationFieldStepImpl implements TermsAggregationFieldStep { - private final SearchAggregationDslContext dslContext; +public class TermsAggregationFieldStepImpl implements TermsAggregationFieldStep { + private final SearchAggregationDslContext dslContext; - TermsAggregationFieldStepImpl(SearchAggregationDslContext dslContext) { + public TermsAggregationFieldStepImpl(SearchAggregationDslContext dslContext) { this.dslContext = dslContext; } @Override - public TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type, + public TermsAggregationOptionsStep> field(String absoluteFieldPath, Class type, ValueConvert convert) { Contracts.assertNotNull( absoluteFieldPath, "absoluteFieldPath" ); Contracts.assertNotNull( type, "type" ); diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationOptionsStepImpl.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationOptionsStepImpl.java index 204d93d33fc..cf84a977a2c 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationOptionsStepImpl.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/impl/TermsAggregationOptionsStepImpl.java @@ -16,82 +16,70 @@ import org.hibernate.search.engine.search.predicate.SearchPredicate; import org.hibernate.search.engine.search.predicate.dsl.PredicateFinalStep; import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; -import org.hibernate.search.engine.search.predicate.dsl.impl.DefaultSearchPredicateFactory; -import org.hibernate.search.engine.search.predicate.spi.SearchPredicateBuilderFactory; import org.hibernate.search.util.common.impl.Contracts; -class TermsAggregationOptionsStepImpl - implements TermsAggregationOptionsStep, F, Map> { +class TermsAggregationOptionsStepImpl + implements TermsAggregationOptionsStep, PDF, F, Map> { private final TermsAggregationBuilder builder; - private final SearchAggregationDslContext dslContext; + private final SearchAggregationDslContext dslContext; - TermsAggregationOptionsStepImpl(TermsAggregationBuilder builder, SearchAggregationDslContext dslContext) { + TermsAggregationOptionsStepImpl(TermsAggregationBuilder builder, SearchAggregationDslContext dslContext) { this.builder = builder; this.dslContext = dslContext; } @Override - public TermsAggregationOptionsStepImpl orderByCountDescending() { + public TermsAggregationOptionsStepImpl orderByCountDescending() { builder.orderByCountDescending(); return this; } @Override - public TermsAggregationOptionsStepImpl orderByCountAscending() { + public TermsAggregationOptionsStepImpl orderByCountAscending() { builder.orderByCountAscending(); return this; } @Override - public TermsAggregationOptionsStepImpl orderByTermAscending() { + public TermsAggregationOptionsStepImpl orderByTermAscending() { builder.orderByTermAscending(); return this; } @Override - public TermsAggregationOptionsStepImpl orderByTermDescending() { + public TermsAggregationOptionsStepImpl orderByTermDescending() { builder.orderByTermDescending(); return this; } @Override - public TermsAggregationOptionsStepImpl minDocumentCount(int minDocumentCount) { + public TermsAggregationOptionsStepImpl minDocumentCount(int minDocumentCount) { Contracts.assertPositiveOrZero( minDocumentCount, "minDocumentCount" ); builder.minDocumentCount( minDocumentCount ); return this; } @Override - public TermsAggregationOptionsStepImpl maxTermCount(int maxTermCount) { + public TermsAggregationOptionsStepImpl maxTermCount(int maxTermCount) { Contracts.assertStrictlyPositive( maxTermCount, "maxTermCount" ); builder.maxTermCount( maxTermCount ); return this; } @Override - public TermsAggregationOptionsStepImpl filter( - Function clauseContributor) { - SearchPredicateBuilderFactory predicateBuilderFactory = dslContext.getPredicateBuilderFactory(); - SearchPredicateFactory factory = new DefaultSearchPredicateFactory<>( predicateBuilderFactory ); - SearchPredicate predicate = clauseContributor.apply( extendPredicateFactory( factory ) ).toPredicate(); + public TermsAggregationOptionsStepImpl filter( + Function clauseContributor) { + SearchPredicate predicate = clauseContributor.apply( dslContext.getPredicateFactory() ).toPredicate(); - filter( predicate ); - return this; + return filter( predicate ); } @Override - public TermsAggregationOptionsStepImpl filter(SearchPredicate searchPredicate) { - SearchPredicateBuilderFactory predicateBuilderFactory = dslContext.getPredicateBuilderFactory(); - searchPredicate = (SearchPredicate) predicateBuilderFactory.toImplementation( searchPredicate ); - + public TermsAggregationOptionsStepImpl filter(SearchPredicate searchPredicate) { builder.filter( searchPredicate ); return this; } - protected SearchPredicateFactory extendPredicateFactory(SearchPredicateFactory predicateFactory) { - return predicateFactory; - } - @Override public SearchAggregation> toAggregation() { return builder.build(); diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/DelegatingSearchAggregationFactory.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/DelegatingSearchAggregationFactory.java index 2fdb27d2237..bb00d396ff0 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/DelegatingSearchAggregationFactory.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/DelegatingSearchAggregationFactory.java @@ -6,32 +6,40 @@ */ package org.hibernate.search.engine.search.aggregation.dsl.spi; +import org.hibernate.search.engine.search.aggregation.dsl.ExtendedSearchAggregationFactory; import org.hibernate.search.engine.search.aggregation.dsl.RangeAggregationFieldStep; import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactory; import org.hibernate.search.engine.search.aggregation.dsl.SearchAggregationFactoryExtension; import org.hibernate.search.engine.search.aggregation.dsl.TermsAggregationFieldStep; +import org.hibernate.search.engine.search.aggregation.dsl.impl.RangeAggregationFieldStepImpl; +import org.hibernate.search.engine.search.aggregation.dsl.impl.TermsAggregationFieldStepImpl; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; /** * A delegating {@link SearchAggregationFactory}. *

* Mainly useful when implementing a {@link SearchAggregationFactoryExtension}. */ -public class DelegatingSearchAggregationFactory implements SearchAggregationFactory { +public class DelegatingSearchAggregationFactory + implements ExtendedSearchAggregationFactory { private final SearchAggregationFactory delegate; + private final SearchAggregationDslContext dslContext; - public DelegatingSearchAggregationFactory(SearchAggregationFactory delegate) { + public DelegatingSearchAggregationFactory(SearchAggregationFactory delegate, + SearchAggregationDslContext dslContext) { this.delegate = delegate; + this.dslContext = dslContext; } @Override - public RangeAggregationFieldStep range() { - return delegate.range(); + public RangeAggregationFieldStep range() { + return new RangeAggregationFieldStepImpl<>( dslContext ); } @Override - public TermsAggregationFieldStep terms() { - return delegate.terms(); + public TermsAggregationFieldStep terms() { + return new TermsAggregationFieldStepImpl<>( dslContext ); } @Override diff --git a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/SearchAggregationDslContext.java b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/SearchAggregationDslContext.java index 456ccdeaf58..ce0b5e31ea9 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/SearchAggregationDslContext.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/aggregation/dsl/spi/SearchAggregationDslContext.java @@ -6,16 +6,22 @@ */ package org.hibernate.search.engine.search.aggregation.dsl.spi; +import java.util.function.Function; + import org.hibernate.search.engine.search.aggregation.spi.SearchAggregationBuilderFactory; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory; +import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactoryExtension; import org.hibernate.search.engine.search.predicate.spi.SearchPredicateBuilderFactory; +import org.hibernate.search.engine.search.sort.dsl.FieldSortOptionsStep; /** * Represents the current context in the search DSL, * including in particular the aggregation builder factory. * * @param The type of aggregation factory. + * @param The type of factory used to create predicates in {@link FieldSortOptionsStep#filter(Function)}. */ -public interface SearchAggregationDslContext> { +public interface SearchAggregationDslContext, PDF extends SearchPredicateFactory> { /** * @return The aggregation builder factory. Will always return the exact same instance. @@ -26,4 +32,17 @@ public interface SearchAggregationDslContext getPredicateBuilderFactory(); + + /** + * @return The predicate factory. Will always return the exact same instance. + */ + PDF getPredicateFactory(); + + /** + * @param extension The extension to apply to the predicate factory. + * @param The type of the new predicate factory. + * @return A new context, identical to {@code this} except for the predicate factory which is extended. + */ + SearchAggregationDslContext withExtendedPredicateFactory( + SearchPredicateFactoryExtension extension); } diff --git a/engine/src/main/java/org/hibernate/search/engine/search/query/dsl/spi/AbstractSearchQueryOptionsStep.java b/engine/src/main/java/org/hibernate/search/engine/search/query/dsl/spi/AbstractSearchQueryOptionsStep.java index 034eb9aaa04..6e3a76c1432 100644 --- a/engine/src/main/java/org/hibernate/search/engine/search/query/dsl/spi/AbstractSearchQueryOptionsStep.java +++ b/engine/src/main/java/org/hibernate/search/engine/search/query/dsl/spi/AbstractSearchQueryOptionsStep.java @@ -139,8 +139,9 @@ public S aggregation(AggregationKey key, SearchAggregation aggregation public S aggregation(AggregationKey key, Function> aggregationContributor) { SearchAggregationBuilderFactory builderFactory = indexScope.getSearchAggregationFactory(); SearchPredicateBuilderFactory predicateBuilderFactory = indexScope.getSearchPredicateBuilderFactory(); + SearchPredicateFactory predicateFactory = new DefaultSearchPredicateFactory<>( predicateBuilderFactory ); AF factory = extendAggregationFactory( new DefaultSearchAggregationFactory( - SearchAggregationDslContextImpl.root( builderFactory, predicateBuilderFactory ) + SearchAggregationDslContextImpl.root( builderFactory, predicateFactory, predicateBuilderFactory ) ) ); SearchAggregation aggregation = aggregationContributor.apply( factory ).toAggregation(); contribute( builderFactory, key, aggregation ); diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/AggregationBaseIT.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/AggregationBaseIT.java index 5e7cd940222..1d4db5e2189 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/AggregationBaseIT.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/aggregation/AggregationBaseIT.java @@ -118,7 +118,7 @@ private static class IndexBinding { private static class SupportedExtension implements SearchAggregationFactoryExtension { @Override public Optional extendOptional(SearchAggregationFactory original, - SearchAggregationDslContext dslContext) { + SearchAggregationDslContext dslContext) { Assertions.assertThat( original ).isNotNull(); Assertions.assertThat( dslContext ).isNotNull(); return Optional.of( new MyExtendedFactory( original ) ); @@ -128,7 +128,7 @@ public Optional extendOptional(SearchAggregationFactory origi private static class UnSupportedExtension implements SearchAggregationFactoryExtension { @Override public Optional extendOptional(SearchAggregationFactory original, - SearchAggregationDslContext dslContext) { + SearchAggregationDslContext dslContext) { Assertions.assertThat( original ).isNotNull(); Assertions.assertThat( dslContext ).isNotNull(); return Optional.empty();