Skip to content

Commit

Permalink
Merge pull request #3536 from graphql-java/20.x-backport-max-result-n…
Browse files Browse the repository at this point in the history
…odes

20.x Backport max result nodes PR 3525
  • Loading branch information
dondonz committed Mar 19, 2024
2 parents b0127e4 + 242f0da commit 67035a2
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/main/java/graphql/execution/Execution.java
Expand Up @@ -96,6 +96,7 @@ public CompletableFuture<ExecutionResult> execute(Document document, GraphQLSche
.executionInput(executionInput)
.build();

executionContext.getGraphQLContext().put(ResultNodesInfo.RESULT_NODES_INFO, executionContext.getResultNodesInfo());

InstrumentationExecutionParameters parameters = new InstrumentationExecutionParameters(
executionInput, graphQLSchema, instrumentationState
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/graphql/execution/ExecutionContext.java
Expand Up @@ -57,6 +57,7 @@ public class ExecutionContext {
private final ValueUnboxer valueUnboxer;
private final ExecutionInput executionInput;
private final Supplier<ExecutableNormalizedOperation> queryTree;
private final ResultNodesInfo resultNodesInfo = new ResultNodesInfo();

ExecutionContext(ExecutionContextBuilder builder) {
this.graphQLSchema = builder.graphQLSchema;
Expand Down Expand Up @@ -291,4 +292,8 @@ public ExecutionContext transform(Consumer<ExecutionContextBuilder> builderConsu
builderConsumer.accept(builder);
return builder.build();
}

public ResultNodesInfo getResultNodesInfo() {
return resultNodesInfo;
}
}
29 changes: 28 additions & 1 deletion src/main/java/graphql/execution/ExecutionStrategy.java
Expand Up @@ -61,6 +61,7 @@
import static graphql.execution.FieldValueInfo.CompleteValueType.NULL;
import static graphql.execution.FieldValueInfo.CompleteValueType.OBJECT;
import static graphql.execution.FieldValueInfo.CompleteValueType.SCALAR;
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES;
import static graphql.execution.instrumentation.SimpleInstrumentationContext.nonNullCtx;
import static graphql.schema.DataFetchingEnvironmentImpl.newDataFetchingEnvironment;
import static graphql.schema.GraphQLTypeUtil.isEnum;
Expand Down Expand Up @@ -238,7 +239,23 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
MergedField field = parameters.getField();
GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType();
GraphQLFieldDefinition fieldDef = getFieldDef(executionContext.getGraphQLSchema(), parentType, field.getSingleField());
GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry();
return fetchField(fieldDef, executionContext, parameters);
}

private CompletableFuture<FetchedValue> fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters) {

int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();

Integer maxNodes;
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
if (resultNodesCount > maxNodes) {
executionContext.getResultNodesInfo().maxResultNodesExceeded();
return CompletableFuture.completedFuture(new FetchedValue(null, null, ImmutableKit.emptyList(), null));
}
}

MergedField field = parameters.getField();
GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType();

// if the DF (like PropertyDataFetcher) does not use the arguments or execution step info then dont build any

Expand Down Expand Up @@ -273,6 +290,7 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
.queryDirectives(queryDirectives)
.build();
});
GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry();
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parentType, fieldDef);

Instrumentation instrumentation = executionContext.getInstrumentation();
Expand Down Expand Up @@ -555,6 +573,15 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext,
List<FieldValueInfo> fieldValueInfos = new ArrayList<>(size.orElse(1));
int index = 0;
for (Object item : iterableValues) {
int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();
Integer maxNodes;
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
if (resultNodesCount > maxNodes) {
executionContext.getResultNodesInfo().maxResultNodesExceeded();
return new FieldValueInfo(NULL, completedFuture(ExecutionResult.newExecutionResult().build()), fieldValueInfos);
}
}

ResultPath indexedPath = parameters.getPath().segment(index);

ExecutionStepInfo stepInfoForListElement = executionStepInfoFactory.newExecutionStepInfoForListElement(executionStepInfo, index);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/execution/FetchedValue.java
Expand Up @@ -19,7 +19,7 @@ public class FetchedValue {
private final Object localContext;
private final ImmutableList<GraphQLError> errors;

private FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList<GraphQLError> errors, Object localContext) {
FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList<GraphQLError> errors, Object localContext) {
this.fetchedValue = fetchedValue;
this.rawFetchedValue = rawFetchedValue;
this.errors = errors;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/execution/FieldValueInfo.java
Expand Up @@ -25,7 +25,7 @@ public enum CompleteValueType {
private final CompletableFuture<ExecutionResult> fieldValue;
private final List<FieldValueInfo> fieldValueInfos;

private FieldValueInfo(CompleteValueType completeValueType, CompletableFuture<ExecutionResult> fieldValue, List<FieldValueInfo> fieldValueInfos) {
FieldValueInfo(CompleteValueType completeValueType, CompletableFuture<ExecutionResult> fieldValue, List<FieldValueInfo> fieldValueInfos) {
assertNotNull(fieldValueInfos, () -> "fieldValueInfos can't be null");
this.completeValueType = completeValueType;
this.fieldValue = fieldValue;
Expand Down
55 changes: 55 additions & 0 deletions src/main/java/graphql/execution/ResultNodesInfo.java
@@ -0,0 +1,55 @@
package graphql.execution;

import graphql.Internal;
import graphql.PublicApi;

import java.util.concurrent.atomic.AtomicInteger;

/**
* This class is used to track the number of result nodes that have been created during execution.
* After each execution the GraphQLContext contains a ResultNodeInfo object under the key {@link ResultNodesInfo#RESULT_NODES_INFO}
* <p>
* The number of result can be limited (and should be for security reasons) by setting the maximum number of result nodes
* in the GraphQLContext under the key {@link ResultNodesInfo#MAX_RESULT_NODES} to an Integer
* </p>
*/
@PublicApi
public class ResultNodesInfo {

public static final String MAX_RESULT_NODES = "__MAX_RESULT_NODES";
public static final String RESULT_NODES_INFO = "__RESULT_NODES_INFO";

private volatile boolean maxResultNodesExceeded = false;
private final AtomicInteger resultNodesCount = new AtomicInteger(0);

@Internal
public int incrementAndGetResultNodesCount() {
return resultNodesCount.incrementAndGet();
}

@Internal
public void maxResultNodesExceeded() {
this.maxResultNodesExceeded = true;
}

/**
* The number of result nodes created.
* Note: this can be higher than max result nodes because
* a each node that exceeds the number of max nodes is set to null,
* but still is a result node (with value null)
*
* @return number of result nodes created
*/
public int getResultNodesCount() {
return resultNodesCount.get();
}

/**
* If the number of result nodes has exceeded the maximum allowed numbers.
*
* @return true if the number of result nodes has exceeded the maximum allowed numbers
*/
public boolean isMaxResultNodesExceeded() {
return maxResultNodesExceeded;
}
}
141 changes: 141 additions & 0 deletions src/test/groovy/graphql/GraphQLTest.groovy
Expand Up @@ -13,6 +13,7 @@ import graphql.execution.ExecutionId
import graphql.execution.ExecutionIdProvider
import graphql.execution.ExecutionStrategyParameters
import graphql.execution.MissingRootTypeException
import graphql.execution.ResultNodesInfo
import graphql.execution.SubscriptionExecutionStrategy
import graphql.execution.ValueUnboxer
import graphql.execution.instrumentation.ChainedInstrumentation
Expand Down Expand Up @@ -49,6 +50,7 @@ import static graphql.ExecutionInput.Builder
import static graphql.ExecutionInput.newExecutionInput
import static graphql.Scalars.GraphQLInt
import static graphql.Scalars.GraphQLString
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES
import static graphql.schema.GraphQLArgument.newArgument
import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition
import static graphql.schema.GraphQLInputObjectField.newInputObjectField
Expand Down Expand Up @@ -1440,4 +1442,143 @@ many lines''']
then:
!er.errors.isEmpty()
}
def "max result nodes not breached"() {
given:
def sdl = '''
type Query {
hello: String
}
'''
def df = { env -> "world" } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello h1: hello h2: hello h3: hello } "
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 4);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
!rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: "world", h1: "world", h2: "world", h3: "world"]
}
def "max result nodes breached"() {
given:
def sdl = '''
type Query {
hello: String
}
'''
def df = { env -> "world" } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello h1: hello h2: hello h3: hello } "
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: "world", h1: "world", h2: "world", h3: null]
}
def "max result nodes breached with list"() {
given:
def sdl = '''
type Query {
hello: [String]
}
'''
def df = { env -> ["w1", "w2", "w3"] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello}"
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: null]
}
def "max result nodes breached with list 2"() {
given:
def sdl = '''
type Query {
hello: [Foo]
}
type Foo {
name: String
}
'''
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello {name}}"
def ei = newExecutionInput(query).build()
// we have 7 result nodes overall
ei.getGraphQLContext().put(MAX_RESULT_NODES, 6);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.resultNodesCount == 7
rni.maxResultNodesExceeded
er.data == [hello: [[name: "w1"], [name: "w2"], [name: null]]]
}
def "max result nodes not breached with list"() {
given:
def sdl = '''
type Query {
hello: [Foo]
}
type Foo {
name: String
}
'''
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello {name}}"
def ei = newExecutionInput(query).build()
// we have 7 result nodes overall
ei.getGraphQLContext().put(MAX_RESULT_NODES, 7);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
!rni.maxResultNodesExceeded
rni.resultNodesCount == 7
er.data == [hello: [[name: "w1"], [name: "w2"], [name: "w3"]]]
}
}

0 comments on commit 67035a2

Please sign in to comment.