Skip to content

Commit

Permalink
Add optional feature for aggregate count query support on plural enti…
Browse files Browse the repository at this point in the history
…ty object types #488

* Add aggregate count query support

* Apply prettier formatting

* Fix regression caused by filtered embeddable attributes

* update javadoc lint configuration to none

* Add enable aggregate feature flag to builder class

* Add more aggregate test coverage

* Update group by field type to GraphQLObject scalar

* Update VariableValue scalar to wrap null values with empty Optional

* Implemented aggregate count for nested entity associations.

* Apply prettier formatting

* Refactor group aggregate count arguments
  • Loading branch information
igdianov committed May 27, 2024
1 parent 4e0a854 commit 5ef0f9f
Show file tree
Hide file tree
Showing 10 changed files with 1,369 additions and 61 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@
<version>3.6.3</version>
<configuration>
<source>${java.version}</source>
<doclint>none</doclint>
</configuration>
<executions>
<execution>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ private static NoSuchElementException noSuchElementException(Class<?> containerC
/**
* Returns a String which capitalizes the first letter of the string.
*/
private static String capitalize(String name) {
public static String capitalize(String name) {
if (name == null || name.length() == 0) {
return name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,28 @@
import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.PAGE_TOTAL_PARAM_NAME;
import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.QUERY_SELECT_PARAM_NAME;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.extractPageArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.findArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getAliasOrName;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getFields;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getPageArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getSelectionField;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.searchByFieldName;

import com.introproventures.graphql.jpa.query.schema.JavaScalars;
import graphql.GraphQLException;
import graphql.language.Argument;
import graphql.language.EnumValue;
import graphql.language.Field;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLScalarType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -65,6 +76,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
Optional<Field> pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME);
Optional<Field> totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME);
Optional<Field> recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME);
Optional<Field> aggregateSelection = getSelectionField(rootNode, "aggregate");

final int firstResult = page.getOffset();
final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM
Expand Down Expand Up @@ -98,9 +110,155 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
pagedResult.withTotal(total);
}

aggregateSelection.ifPresent(aggregateField -> {
Map<String, Object> aggregate = new LinkedHashMap<>();

getFields(aggregateField.getSelectionSet(), "count")
.forEach(countField -> {
getCountOfArgument(countField)
.ifPresentOrElse(
argument ->
aggregate.put(
getAliasOrName(countField),
queryFactory.queryAggregateCount(argument, environment, restrictedKeys)
),
() ->
aggregate.put(
getAliasOrName(countField),
queryFactory.queryTotalCount(environment, restrictedKeys)
)
);
});

getFields(aggregateField.getSelectionSet(), "group")
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

var countOfArgumentValue = getCountOfArgument(countField);

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory
.queryAggregateGroupByCount(
getAliasOrName(countField),
countOfArgumentValue,
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregate.put(getAliasOrName(groupField), resultList);
});

aggregateField
.getSelectionSet()
.getSelections()
.stream()
.filter(Field.class::isInstance)
.map(Field.class::cast)
.filter(it -> !Arrays.asList("count", "group").contains(it.getName()))
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory
.queryAggregateGroupByAssociationCount(
getAliasOrName(countField),
groupField.getName(),
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregate.put(getAliasOrName(groupField), resultList);
});

pagedResult.withAggregate(aggregate);
});

return pagedResult.build();
}

static Map.Entry<String, String> groupByFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = findArgument(selectedField, "field")
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName)
.orElseThrow(() -> new GraphQLException("group by argument is required."));

return Map.entry(key, value);
}

static Map.Entry<String, String> countFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = getCountOfArgument(selectedField).orElse(selectedField.getName());

return Map.entry(key, value);
}

static Optional<String> getCountOfArgument(Field selectedField) {
return findArgument(selectedField, "of")
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName);
}

public int getDefaultMaxResults() {
return defaultMaxResults;
}
Expand Down
Loading

0 comments on commit 5ef0f9f

Please sign in to comment.