Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Breaking out query complexity into its own class #3254

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static graphql.Assert.assertNotNull;
import static graphql.execution.instrumentation.InstrumentationState.ofState;
import static graphql.execution.instrumentation.SimpleInstrumentationContext.noOp;
import static java.util.Optional.ofNullable;

/**
* Prevents execution if the query complexity is greater than the specified maxComplexity.
Expand Down Expand Up @@ -101,21 +98,8 @@ public InstrumentationState createState(InstrumentationCreateStateParameters par
@Override
public @Nullable InstrumentationContext<ExecutionResult> beginExecuteOperation(InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters, InstrumentationState rawState) {
State state = ofState(rawState);
QueryTraverser queryTraverser = newQueryTraverser(instrumentationExecuteOperationParameters.getExecutionContext());

Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = new LinkedHashMap<>();
queryTraverser.visitPostOrder(new QueryVisitorStub() {
@Override
public void visitField(QueryVisitorFieldEnvironment env) {
int childComplexity = valuesByParent.getOrDefault(env, 0);
int value = calculateComplexity(env, childComplexity);

valuesByParent.compute(env.getParentEnvironment(), (key, oldValue) ->
ofNullable(oldValue).orElse(0) + value
);
}
});
int totalComplexity = valuesByParent.getOrDefault(null, 0);
QueryComplexityCalculator queryComplexityCalculator = newQueryComplexityCalculator(instrumentationExecuteOperationParameters.getExecutionContext());
int totalComplexity = queryComplexityCalculator.calculate();
if (log.isDebugEnabled()) {
log.debug("Query complexity: {}", totalComplexity);
}
Expand All @@ -133,6 +117,16 @@ public void visitField(QueryVisitorFieldEnvironment env) {
return noOp();
}

private QueryComplexityCalculator newQueryComplexityCalculator(ExecutionContext executionContext) {
return QueryComplexityCalculator.newCalculator()
.fieldComplexityCalculator(fieldComplexityCalculator)
.schema(executionContext.getGraphQLSchema())
.document(executionContext.getDocument())
.operationName(executionContext.getExecutionInput().getOperationName())
.variables(executionContext.getCoercedVariables())
.build();
}

/**
* Called to generate your own error message or custom exception class
*
Expand All @@ -145,37 +139,6 @@ protected AbortExecutionException mkAbortException(int totalComplexity, int maxC
return new AbortExecutionException("maximum query complexity exceeded " + totalComplexity + " > " + maxComplexity);
}

QueryTraverser newQueryTraverser(ExecutionContext executionContext) {
return QueryTraverser.newQueryTraverser()
.schema(executionContext.getGraphQLSchema())
.document(executionContext.getDocument())
.operationName(executionContext.getExecutionInput().getOperationName())
.coercedVariables(executionContext.getCoercedVariables())
.build();
}

private int calculateComplexity(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment, int childComplexity) {
if (queryVisitorFieldEnvironment.isTypeNameIntrospectionField()) {
return 0;
}
FieldComplexityEnvironment fieldComplexityEnvironment = convertEnv(queryVisitorFieldEnvironment);
return fieldComplexityCalculator.calculate(fieldComplexityEnvironment, childComplexity);
}

private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment) {
FieldComplexityEnvironment parentEnv = null;
if (queryVisitorFieldEnvironment.getParentEnvironment() != null) {
parentEnv = convertEnv(queryVisitorFieldEnvironment.getParentEnvironment());
}
return new FieldComplexityEnvironment(
queryVisitorFieldEnvironment.getField(),
queryVisitorFieldEnvironment.getFieldDefinition(),
queryVisitorFieldEnvironment.getFieldsContainer(),
queryVisitorFieldEnvironment.getArguments(),
parentEnv
);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to the new class


private static class State implements InstrumentationState {
AtomicReference<InstrumentationValidationParameters> instrumentationValidationParameters = new AtomicReference<>();
}
Expand Down
134 changes: 134 additions & 0 deletions src/main/java/graphql/analysis/QueryComplexityCalculator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package graphql.analysis;

import graphql.PublicApi;
import graphql.execution.CoercedVariables;
import graphql.language.Document;
import graphql.schema.GraphQLSchema;

import java.util.LinkedHashMap;
import java.util.Map;

import static graphql.Assert.assertNotNull;
import static java.util.Optional.ofNullable;

/**
* This can calculate the complexity of an operation using the specified {@link FieldComplexityCalculator} you pass
* into it.
*/
@PublicApi
public class QueryComplexityCalculator {

private final FieldComplexityCalculator fieldComplexityCalculator;
private final GraphQLSchema schema;
private final Document document;
private final String operationName;
private final CoercedVariables variables;

public QueryComplexityCalculator(Builder builder) {
this.fieldComplexityCalculator = assertNotNull(builder.fieldComplexityCalculator, () -> "fieldComplexityCalculator can't be null");
this.schema = assertNotNull(builder.schema, () -> "schema can't be null");
this.document = assertNotNull(builder.document, () -> "document can't be null");
this.variables = assertNotNull(builder.variables, () -> "variables can't be null");
this.operationName = builder.operationName;
}


public int calculate() {
Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = calculateByParents();
return valuesByParent.getOrDefault(null, 0);
}

/**
* @return a map that shows the field complexity for each field level in the operation
*/
public Map<QueryVisitorFieldEnvironment, Integer> calculateByParents() {
QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser()
.schema(this.schema)
.document(this.document)
.operationName(this.operationName)
.coercedVariables(this.variables)
.build();


Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = new LinkedHashMap<>();
queryTraverser.visitPostOrder(new QueryVisitorStub() {
@Override
public void visitField(QueryVisitorFieldEnvironment env) {
int childComplexity = valuesByParent.getOrDefault(env, 0);
int value = calculateComplexity(env, childComplexity);

QueryVisitorFieldEnvironment parentEnvironment = env.getParentEnvironment();
valuesByParent.compute(parentEnvironment, (key, oldValue) -> {
Integer currentValue = ofNullable(oldValue).orElse(0);
return currentValue + value;
}
);
}
});

return valuesByParent;
}

private int calculateComplexity(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment, int childComplexity) {
if (queryVisitorFieldEnvironment.isTypeNameIntrospectionField()) {
return 0;
}
FieldComplexityEnvironment fieldComplexityEnvironment = convertEnv(queryVisitorFieldEnvironment);
return fieldComplexityCalculator.calculate(fieldComplexityEnvironment, childComplexity);
}

private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment) {
FieldComplexityEnvironment parentEnv = null;
if (queryVisitorFieldEnvironment.getParentEnvironment() != null) {
parentEnv = convertEnv(queryVisitorFieldEnvironment.getParentEnvironment());
}
return new FieldComplexityEnvironment(
queryVisitorFieldEnvironment.getField(),
queryVisitorFieldEnvironment.getFieldDefinition(),
queryVisitorFieldEnvironment.getFieldsContainer(),
queryVisitorFieldEnvironment.getArguments(),
parentEnv
);
}

public static Builder newCalculator() {
return new Builder();
}

public static class Builder {
private FieldComplexityCalculator fieldComplexityCalculator;
private GraphQLSchema schema;
private Document document;
private String operationName;
private CoercedVariables variables = CoercedVariables.emptyVariables();

public Builder schema(GraphQLSchema graphQLSchema) {
this.schema = graphQLSchema;
return this;
}

public Builder fieldComplexityCalculator(FieldComplexityCalculator complexityCalculator) {
this.fieldComplexityCalculator = complexityCalculator;
return this;
}

public Builder document(Document document) {
this.document = document;
return this;
}

public Builder operationName(String operationName) {
this.operationName = operationName;
return this;
}

public Builder variables(CoercedVariables variables) {
this.variables = variables;
return this;
}

public QueryComplexityCalculator build() {
return new QueryComplexityCalculator(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package graphql.analysis


import graphql.TestUtil
import graphql.execution.CoercedVariables
import graphql.language.Document
import graphql.parser.Parser
import spock.lang.Specification

class QueryComplexityCalculatorTest extends Specification {

Document createQuery(String query) {
Parser parser = new Parser()
parser.parseDocument(query)
}

def "can calculator complexity"() {
given:
def schema = TestUtil.schema("""
type Query{
foo: Foo
bar: String
}
type Foo {
scalar: String
foo: Foo
}
""")
def query = createQuery("""
query q {
f2: foo {scalar foo{scalar}}
f1: foo { foo {foo {foo {foo{foo{scalar}}}}}} }
""")


when:
FieldComplexityCalculator fieldComplexityCalculator = new FieldComplexityCalculator() {
@Override
int calculate(FieldComplexityEnvironment environment, int childComplexity) {
return environment.getField().name.startsWith("foo") ? 10 : 1
}
}
QueryComplexityCalculator calculator = QueryComplexityCalculator.newCalculator()
.fieldComplexityCalculator(fieldComplexityCalculator).schema(schema).document(query).variables(CoercedVariables.emptyVariables())
.build()
def complexityScore = calculator.calculate()
then:
complexityScore == 20


}
}