diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/query/impl/ElasticsearchSearchQuery.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/query/impl/ElasticsearchSearchQuery.java index 7877b499dae..97795f1ac7c 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/query/impl/ElasticsearchSearchQuery.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/search/query/impl/ElasticsearchSearchQuery.java @@ -6,8 +6,10 @@ */ package org.hibernate.search.backend.elasticsearch.search.query.impl; +import java.util.Optional; import java.util.Set; +import org.hibernate.search.backend.elasticsearch.gson.impl.JsonAccessor; import org.hibernate.search.backend.elasticsearch.util.impl.URLEncodedString; import org.hibernate.search.backend.elasticsearch.orchestration.impl.ElasticsearchWorkOrchestrator; import org.hibernate.search.backend.elasticsearch.work.impl.ElasticsearchWork; @@ -77,11 +79,13 @@ public SearchResult execute() { @Override public long executeCount() { - ElasticsearchWork> work = workFactory.search( - indexNames, routingKeys, - payload, searchResultExtractor, - firstResultIndex, 0L ); - SearchResult executeNoHits = queryOrchestrator.submit( work ).join(); - return executeNoHits.getHitCount(); + JsonObject filteredPayload = new JsonObject(); + Optional querySubTree = JsonAccessor.root().property( "query" ).asObject().get( payload ); + if ( querySubTree.isPresent() ) { + filteredPayload.add( "query", querySubTree.get() ); + } + + ElasticsearchWork work = workFactory.count( indexNames, routingKeys, filteredPayload ); + return queryOrchestrator.submit( work ).join(); } } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWork.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWork.java index dbb5839da79..429b8953770 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWork.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWork.java @@ -11,6 +11,7 @@ import org.hibernate.search.backend.elasticsearch.client.impl.ElasticsearchRequest; import org.hibernate.search.backend.elasticsearch.client.impl.ElasticsearchResponse; +import org.hibernate.search.backend.elasticsearch.gson.impl.JsonAccessor; import com.google.gson.JsonObject; @@ -32,6 +33,11 @@ public ElasticsearchStubWork(ElasticsearchRequest request, Function accessor) { + this.request = request; + this.resultFunction = response -> accessor.get( response ).get(); + } + @Override public CompletableFuture execute(ElasticsearchWorkExecutionContext context) { CompletableFuture response = context.getClient().submit( request ); diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWorkFactory.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWorkFactory.java index e79fd08a5ce..ef411df7b73 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWorkFactory.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchStubWorkFactory.java @@ -11,6 +11,7 @@ import org.hibernate.search.backend.elasticsearch.client.impl.ElasticsearchRequest; import org.hibernate.search.backend.elasticsearch.client.impl.Paths; +import org.hibernate.search.backend.elasticsearch.gson.impl.JsonAccessor; import org.hibernate.search.backend.elasticsearch.index.settings.impl.esnative.IndexSettings; import org.hibernate.search.backend.elasticsearch.multitenancy.impl.MultiTenancyStrategy; import org.hibernate.search.backend.elasticsearch.util.impl.URLEncodedString; @@ -172,4 +173,18 @@ public ElasticsearchWork> search(Set index return new ElasticsearchStubWork<>( builder.build(), searchResultExtractor::extract ); } + + @Override + public ElasticsearchWork count(Set indexNames, Set routingKeys, JsonObject payload) { + ElasticsearchRequest.Builder builder = ElasticsearchRequest.post() + .multiValuedPathComponent( indexNames ) + .pathComponent( Paths._COUNT ) + .body( payload ); + + if ( !routingKeys.isEmpty() ) { + builder.param( "_routing", routingKeys.stream().collect( Collectors.joining( "," ) ) ); + } + + return new ElasticsearchStubWork<>( builder.build(), JsonAccessor.root().property( "count" ).asLong() ); + } } diff --git a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchWorkFactory.java b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchWorkFactory.java index a5cf4420056..41731066bfd 100644 --- a/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchWorkFactory.java +++ b/backend/elasticsearch/src/main/java/org/hibernate/search/backend/elasticsearch/work/impl/ElasticsearchWorkFactory.java @@ -42,4 +42,6 @@ ElasticsearchWork> search(Set indexNames, JsonObject payload, ElasticsearchSearchResultExtractor searchResultExtractor, Long offset, Long limit); + ElasticsearchWork count(Set indexNames, Set routingKeys, JsonObject payload); + } diff --git a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/SearchResultLoadingOrTransformingIT.java b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/SearchResultLoadingOrTransformingIT.java index a04bc729366..3cab244435c 100644 --- a/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/SearchResultLoadingOrTransformingIT.java +++ b/integrationtest/backend/tck/src/main/java/org/hibernate/search/integrationtest/backend/tck/search/SearchResultLoadingOrTransformingIT.java @@ -327,7 +327,7 @@ public void projections_hitTransformer_referencesTransformer_objectLoading() { } @Test - public void count() { + public void countQuery() { StubMappingSearchTarget searchTarget = indexManager.createSearchTarget(); SearchQuery query = searchTarget.query() @@ -345,6 +345,25 @@ public void count() { assertEquals( 1L, query.executeCount() ); } + @Test + public void countQueryWithProjection() { + StubMappingSearchTarget searchTarget = indexManager.createSearchTarget(); + + SearchQuery query = searchTarget.query() + .asProjection( f -> f.field( "string", String.class ).toProjection() ) + .predicate( f -> f.matchAll().toPredicate() ) + .build(); + + assertEquals( 2L, query.executeCount() ); + + query = searchTarget.query() + .asProjection( f -> f.field( "string", String.class ).toProjection() ) + .predicate( f -> f.match().onField( "string" ).matching( STRING_VALUE ).toPredicate() ) + .build(); + + assertEquals( 1L, query.executeCount() ); + } + private void initData() { IndexWorkPlan workPlan = indexManager.createWorkPlan(); workPlan.add( referenceProvider( MAIN_ID ), document -> {