Skip to content

Commit

Permalink
Add support for lazy loaded attribute fetching with JPA entity graph …
Browse files Browse the repository at this point in the history
…hint
  • Loading branch information
igdianov committed Nov 22, 2023
1 parent 17120b6 commit cdebc1b
Showing 1 changed file with 124 additions and 33 deletions.
Expand Up @@ -22,6 +22,7 @@
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.isLogicalArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.isPageArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.isWhereArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.selections;
import static graphql.introspection.Introspection.SchemaMetaFieldDef;
import static graphql.introspection.Introspection.TypeMetaFieldDef;
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
Expand Down Expand Up @@ -60,7 +61,9 @@
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLType;
import jakarta.persistence.EntityGraph;
import jakarta.persistence.EntityManager;
import jakarta.persistence.Subgraph;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.AbstractQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
Expand Down Expand Up @@ -121,6 +124,7 @@ public final class GraphQLJpaQueryFactory {
private static final String DESC = "DESC";

private static final Logger logger = LoggerFactory.getLogger(GraphQLJpaQueryFactory.class);
public static final String JAKARTA_PERSISTENCE_FETCHGRAPH = "jakarta.persistence.fetchgraph";
private static Function<Object, Object> unproxy;

static {
Expand Down Expand Up @@ -175,11 +179,7 @@ private GraphQLJpaQueryFactory(Builder builder) {

public DataFetchingEnvironment getQueryEnvironment(DataFetchingEnvironment environment, MergedField queryField) {
// Override query environment with associated entity object type and select field
return DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
.fieldType(getEntityObjectType())
.mergedField(queryField)
.build();
return DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment).fieldType(getEntityObjectType()).mergedField(queryField).build();
}

public Optional<List<Object>> getRestrictedKeys(DataFetchingEnvironment environment) {
Expand Down Expand Up @@ -260,15 +260,26 @@ protected Stream<Object> queryResultStream(DataFetchingEnvironment environment,
keys.toArray()
);

// Let's create entity graph from selection
var entityGraph = createEntityGraph(queryEnvironment);

// Let's execute query and get wrap result into stream
return getResultStream(query, fetchSize, isDistinct);
return getResultStream(query, fetchSize, isDistinct, entityGraph);
}

protected <T> Stream<T> getResultStream(TypedQuery<T> query, int fetchSize, boolean isDistinct) {
protected <T> Stream<T> getResultStream(
TypedQuery<T> query,
int fetchSize,
boolean isDistinct,
EntityGraph<?> entityGraph
) {
// Let' try reduce overhead and disable all caching
query.setHint(ORG_HIBERNATE_READ_ONLY, true);
query.setHint(ORG_HIBERNATE_FETCH_SIZE, fetchSize);
query.setHint(ORG_HIBERNATE_CACHEABLE, false);
if (entityGraph != null) {
query.setHint("jakarta.persistence.loadgraph", entityGraph);
}

if (logger.isDebugEnabled()) {
logger.info("\nGraphQL JPQL Fetch Query String:\n {}", getJPQLQueryString(query));
Expand Down Expand Up @@ -344,8 +355,7 @@ protected <T> TypedQuery<T> getQuery(
boolean isDistinct,
Object... keys
) {
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.localContext(Boolean.TRUE) // Fetch mode
.build();

Expand All @@ -359,8 +369,7 @@ protected TypedQuery<Long> getCountQuery(DataFetchingEnvironment environment, Fi
CriteriaQuery<Long> query = cb.createQuery(Long.class);
Root<?> root = query.from(entityType);

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.root(query)
.localContext(Boolean.FALSE) // Join mode
.build();
Expand Down Expand Up @@ -392,8 +401,7 @@ protected TypedQuery<Object> getKeysQuery(DataFetchingEnvironment environment, F

from.alias("root");

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.root(query)
.localContext(Boolean.FALSE)
.build();
Expand Down Expand Up @@ -441,7 +449,9 @@ protected Map<Object, List<Object>> loadOneToMany(DataFetchingEnvironment enviro

TypedQuery<Object[]> query = getBatchQuery(environment, field, isDefaultDistinct(), keys);

List<Object[]> resultList = getResultList(query);
var entityGraph = createEntityGraph(environment);

List<Object[]> resultList = getResultList(query, entityGraph);

if (logger.isTraceEnabled()) {
logger.trace(
Expand Down Expand Up @@ -477,7 +487,9 @@ protected Map<Object, Object> loadManyToOne(DataFetchingEnvironment environment,

TypedQuery<Object[]> query = getBatchQuery(environment, field, isDefaultDistinct(), keys);

List<Object[]> resultList = getResultList(query);
var entityGraph = createEntityGraph(environment);

List<Object[]> resultList = getResultList(query, entityGraph);

Map<Object, Object> resultMap = new LinkedHashMap<>(resultList.size());

Expand All @@ -486,7 +498,7 @@ protected Map<Object, Object> loadManyToOne(DataFetchingEnvironment environment,
return resultMap;
}

protected <T> List<T> getResultList(TypedQuery<T> query) {
protected <T> List<T> getResultList(TypedQuery<T> query, EntityGraph<?> entityGraph) {
if (logger.isDebugEnabled()) {
logger.info("\nGraphQL JPQL Batch Query String:\n {}", getJPQLQueryString(query));
}
Expand All @@ -496,6 +508,10 @@ protected <T> List<T> getResultList(TypedQuery<T> query) {
query.setHint(ORG_HIBERNATE_FETCH_SIZE, defaultFetchSize);
query.setHint(ORG_HIBERNATE_CACHEABLE, false);

if (entityGraph != null) {
query.setHint(JAKARTA_PERSISTENCE_FETCHGRAPH, entityGraph);
}

return query.getResultList();
}

Expand All @@ -513,8 +529,7 @@ protected TypedQuery<Object[]> getBatchQuery(
CriteriaQuery<Object[]> query = cb.createQuery(Object[].class);
Root<?> from = query.from(entityType);

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.root(query)
.localContext(Boolean.TRUE)
.build();
Expand Down Expand Up @@ -551,8 +566,7 @@ protected TypedQuery<Object> getBatchCollectionQuery(
CriteriaQuery<Object> query = cb.createQuery();
Root<?> from = query.from(entityType);

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.root(query)
.localContext(Boolean.TRUE)
.build();
Expand Down Expand Up @@ -585,10 +599,7 @@ protected <T> CriteriaQuery<T> getCriteriaQuery(
CriteriaQuery<T> query = cb.createQuery((Class<T>) entityType.getJavaType());
Root<?> from = query.from(entityType);

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
.root(query)
.build();
DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment).root(query).build();
from.alias(from.getModel().getName().toLowerCase());

// Build predicates from query arguments
Expand Down Expand Up @@ -965,8 +976,7 @@ protected Predicate getWherePredicate(
Map<String, Object> predicateArguments = new LinkedHashMap<>();
predicateArguments.put(logical.name(), environment.getArguments());

DataFetchingEnvironment predicateDataFetchingEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment predicateDataFetchingEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.arguments(predicateArguments)
.build();
Argument predicateArgument = new Argument(logical.name(), whereValue);
Expand Down Expand Up @@ -1064,8 +1074,7 @@ protected Predicate getObjectFieldPredicate(

Join<?, ?> correlationJoin = correlation.join(objectField.getName());

DataFetchingEnvironment existsEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment existsEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.root(subquery)
.build();

Expand Down Expand Up @@ -1361,8 +1370,7 @@ private PredicateFilter getPredicateFilter(
Map<String, Object> valueArguments = new LinkedHashMap<String, Object>();
valueArguments.put(objectField.getName(), environment.getArgument(argument.getName()));

DataFetchingEnvironment dataFetchingEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
DataFetchingEnvironment dataFetchingEnvironment = DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.arguments(valueArguments)
.build();

Expand Down Expand Up @@ -1393,8 +1401,7 @@ protected DataFetchingEnvironment wherePredicateEnvironment(
GraphQLFieldDefinition fieldDefinition,
Map<String, Object> arguments
) {
return DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
return DataFetchingEnvironmentBuilder.newDataFetchingEnvironment(environment)
.arguments(arguments)
.fieldDefinition(fieldDefinition)
.fieldType(fieldDefinition.getType())
Expand Down Expand Up @@ -1693,8 +1700,10 @@ private EmbeddableType<?> computeEmbeddableType(GraphQLObjectType objectType) {
* @return resolved GraphQL object type or null if no output type is provided
*/
private GraphQLObjectType getObjectType(DataFetchingEnvironment environment) {
GraphQLType outputType = environment.getFieldType();
return getObjectType(environment.getFieldType());
}

private GraphQLObjectType getObjectType(GraphQLType outputType) {
if (outputType instanceof GraphQLList) outputType = ((GraphQLList) outputType).getWrappedType();

if (outputType instanceof GraphQLObjectType) return (GraphQLObjectType) outputType;
Expand Down Expand Up @@ -1976,6 +1985,88 @@ private <T> T detach(T entity) {
return entity;
}

EntityGraph<?> createEntityGraph(DataFetchingEnvironment environment) {
Field root = environment.getMergedField().getSingleField();
GraphQLObjectType fieldType = getObjectType(environment);
EntityType<?> entityType = getEntityType(fieldType);

EntityGraph<?> entityGraph = entityManager.createEntityGraph(entityType.getJavaType());

var entityDescriptor = EntityIntrospector.introspect(entityType);

selections(root)
.forEach(selectedField -> {
var propertyDescriptor = entityDescriptor.getPropertyDescriptor(selectedField.getName());

propertyDescriptor
.flatMap(AttributePropertyDescriptor::getAttribute)
.ifPresent(attribute -> {
if (
isManagedType(attribute) && hasSelectionSet(selectedField) && hasNoArguments(selectedField)
) {
var attributeFieldDefinition = fieldType.getFieldDefinition(attribute.getName());
entityGraph.addAttributeNodes(attribute.getName());
addSubgraph(
selectedField,
attributeFieldDefinition,
entityGraph.addSubgraph(attribute.getName())
);
} else if (isBasic(attribute)) {
entityGraph.addAttributeNodes(attribute.getName());
}
});
});

return entityGraph;
}

void addSubgraph(Field field, GraphQLFieldDefinition fieldDefinition, Subgraph<?> subgraph) {
var fieldObjectType = getObjectType(fieldDefinition.getType());
var fieldEntityType = getEntityType(fieldObjectType);
var fieldEntityDescriptor = EntityIntrospector.introspect(fieldEntityType);

selections(field)
.forEach(selectedField -> {
var propertyDescriptor = fieldEntityDescriptor.getPropertyDescriptor(selectedField.getName());

propertyDescriptor
.flatMap(AttributePropertyDescriptor::getAttribute)
.ifPresent(attribute -> {
var selectedName = selectedField.getName();

if (
hasSelectionSet(selectedField) && isManagedType(attribute) && hasNoArguments(selectedField)
) {
var selectedFieldDefinition = fieldObjectType.getFieldDefinition(selectedName);
subgraph.addAttributeNodes(selectedName);
addSubgraph(selectedField, selectedFieldDefinition, subgraph.addSubgraph(selectedName));
} else if (isBasic(attribute)) {
subgraph.addAttributeNodes(selectedName);
}
});
});
}

static boolean isManagedType(Attribute<?, ?> attribute) {
return (
attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.EMBEDDED &&
attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC &&
attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.ELEMENT_COLLECTION
);
}

static boolean isBasic(Attribute<?, ?> attribute) {
return !isManagedType(attribute);
}

static boolean hasNoArguments(Field field) {
return !hasArguments(field);
}

static boolean hasArguments(Field field) {
return field.getArguments() != null && !field.getArguments().isEmpty();
}

/**
* Creates builder to build {@link GraphQLJpaQueryFactory}.
* @return created builder
Expand Down

0 comments on commit cdebc1b

Please sign in to comment.