Skip to content

Commit

Permalink
Merge pull request #189 from graph-quilt/built-in-query-ref
Browse files Browse the repository at this point in the history
add support for variable references for built in query directives
  • Loading branch information
CNAChino committed Nov 15, 2023
2 parents ec8d92f + 05ded31 commit 60f73bd
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,38 @@
import graphql.analysis.QueryVisitorInlineFragmentEnvironment;
import graphql.analysis.QueryVisitorStub;
import graphql.language.Argument;
import graphql.language.AstTransformer;
import graphql.language.Document;
import graphql.language.Field;
import graphql.language.FragmentDefinition;
import graphql.language.FragmentSpread;
import graphql.language.InlineFragment;
import graphql.language.Node;
import graphql.language.NodeVisitorStub;
import graphql.language.OperationDefinition;
import graphql.language.Value;
import graphql.language.VariableReference;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchema;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import lombok.Getter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Getter;

/**
* This class provides assistance in extracting all VariableReference names used in GraphQL nodes.
*/
public class VariableDefinitionFilter {

private static AstTransformer astTransformer = new AstTransformer();

/**
* Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph.
*
Expand Down Expand Up @@ -67,8 +75,20 @@ public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr

Set<VariableReference> additionalReferences = operationDirectiveVariableReferences(operationDefinitions);

return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream())
.map(VariableReference::getName).collect(Collectors.toSet());
Stream<VariableReference> variableReferenceStream;
if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) {
NodeTraverser nodeTraverser = new NodeTraverser();
astTransformer.transform(rootNode, nodeTraverser);

variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(),
additionalReferences,
nodeTraverser.getVariableReferenceExtractor().getVariableReferences())
.flatMap(Collection::stream);
} else {
variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream());
}
return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet());

}

private Set<VariableReference> operationDirectiveVariableReferences(List<OperationDefinition> operationDefinitions) {
Expand Down Expand Up @@ -163,4 +183,19 @@ private void captureVariableReferences(Stream<Argument> arguments) {
variableReferenceExtractor.captureVariableReferences(values);
}
}

static class NodeTraverser extends NodeVisitorStub {

@Getter
private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor();

public TraversalControl visitArgument(Argument node, TraverserContext<Node> context) {
return this.visitNode(node, context);
}

public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
variableReferenceExtractor.captureVariableReference(node);
return this.visitValue(node, context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ public Set<VariableReference> getVariableReferences() {

public void captureVariableReferences(List<Value> values) {
for (final Value value : values) {
doSwitch(value);
captureVariableReference(value);
}
}

public void captureVariableReference(Value value) {
doSwitch(value);
}

private void doSwitch(Value value) {
if (value instanceof ArrayValue) {
handleArrayValue((ArrayValue) value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ class VariableDefinitionFilterSpec extends Specification {
directive @field_directive_argument(arg: InputObject) on FIELD_DEFINITION
'''

private String schema2 = '''
type Query { person: Person }
type Person {
address : Address
id: String
}
type Address { city: String state: String zip: String }
'''

private GraphQLSchema graphQLSchema

private VariableDefinitionFilter variableDefinitionFilter
Expand All @@ -63,6 +74,12 @@ class VariableDefinitionFilterSpec extends Specification {
RuntimeWiring.newRuntimeWiring().build())
}

private GraphQLSchema getSchema2() {
return new SchemaGenerator()
.makeExecutableSchema(new SchemaParser().parse(schema2),
RuntimeWiring.newRuntimeWiring().build())
}

private Map<String, FragmentDefinition> getFragmentsByName(Document document) {
return document.getDefinitionsOfType(FragmentDefinition.class).stream()
.inject([:]) {map, it -> map << [(it.getName()): it]}
Expand Down Expand Up @@ -179,6 +196,62 @@ class VariableDefinitionFilterSpec extends Specification {
results.containsAll("int_arg", "string_arg")
}

def "variable References In Built in Query Directive includes"() {
given:
String query = '''
query($includeContext: Boolean!) {
consumer {
liabilities(arg: 1) @include(if: $includeContext) {
totalDebt(arg: 1)
}
income
}
}
'''

Document document = parser.parseDocument(query)
HashMap<String, Object> variables = new HashMap<>()
variables.put("includeContext", false)

when:
final Set<String> results = variableDefinitionFilter
.getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(),
variables, document)

then:
results.size() == 1

results.containsAll("includeContext")
}

def "variable References In Built in Query Directive skip"() {
given:
String query = '''
query($includeContext: Boolean!) {
consumer {
liabilities(arg: 1) @skip(if: $includeContext) {
totalDebt(arg: 1)
}
income
}
}
'''

Document document = parser.parseDocument(query)
HashMap<String, Object> variables = new HashMap<>()
variables.put("includeContext", true)

when:
final Set<String> results = variableDefinitionFilter
.getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(),
variables, document)

then:
results.size() == 1

results.containsAll("includeContext")
}

def "test Negative Cases"() {
given:
final String negativeTestCaseQuery = "query { consumer { liabilities { totalDebt(arg: 1234) } } }"
Expand Down

0 comments on commit 60f73bd

Please sign in to comment.