From 05ded31ef6d3e1291499d5271c41815abf7fba1b Mon Sep 17 00:00:00 2001 From: kmoore15 Date: Mon, 6 Nov 2023 19:33:41 -0600 Subject: [PATCH] add support for variable references for built in query directives --- .../batch/VariableDefinitionFilter.java | 41 ++++++++++- .../batch/VariableReferenceExtractor.java | 6 +- .../batch/VariableDefinitionFilterSpec.groovy | 73 +++++++++++++++++++ 3 files changed, 116 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java index 30bb2836..14ffcf2d 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java @@ -6,30 +6,38 @@ import graphql.analysis.QueryVisitorInlineFragmentEnvironment; import graphql.analysis.QueryVisitorStub; import graphql.language.Argument; +import graphql.language.AstTransformer; import graphql.language.Document; import graphql.language.Field; import graphql.language.FragmentDefinition; import graphql.language.FragmentSpread; import graphql.language.InlineFragment; import graphql.language.Node; +import graphql.language.NodeVisitorStub; import graphql.language.OperationDefinition; import graphql.language.Value; import graphql.language.VariableReference; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; +import graphql.util.TraversalControl; +import graphql.util.TraverserContext; +import lombok.Getter; + import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import lombok.Getter; /** * This class provides assistance in extracting all VariableReference names used in GraphQL nodes. */ public class VariableDefinitionFilter { + private static AstTransformer astTransformer = new AstTransformer(); + /** * Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph. * @@ -67,8 +75,20 @@ public Set getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr Set additionalReferences = operationDirectiveVariableReferences(operationDefinitions); - return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()) - .map(VariableReference::getName).collect(Collectors.toSet()); + Stream variableReferenceStream; + if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) { + NodeTraverser nodeTraverser = new NodeTraverser(); + astTransformer.transform(rootNode, nodeTraverser); + + variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(), + additionalReferences, + nodeTraverser.getVariableReferenceExtractor().getVariableReferences()) + .flatMap(Collection::stream); + } else { + variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()); + } + return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet()); + } private Set operationDirectiveVariableReferences(List operationDefinitions) { @@ -163,4 +183,19 @@ private void captureVariableReferences(Stream arguments) { variableReferenceExtractor.captureVariableReferences(values); } } + + static class NodeTraverser extends NodeVisitorStub { + + @Getter + private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor(); + + public TraversalControl visitArgument(Argument node, TraverserContext context) { + return this.visitNode(node, context); + } + + public TraversalControl visitVariableReference(VariableReference node, TraverserContext context) { + variableReferenceExtractor.captureVariableReference(node); + return this.visitValue(node, context); + } + } } diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java index 218530ae..ca20637c 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java @@ -19,10 +19,14 @@ public Set getVariableReferences() { public void captureVariableReferences(List values) { for (final Value value : values) { - doSwitch(value); + captureVariableReference(value); } } + public void captureVariableReference(Value value) { + doSwitch(value); + } + private void doSwitch(Value value) { if (value instanceof ArrayValue) { handleArrayValue((ArrayValue) value); diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy index d4258f94..67e11fc5 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy @@ -47,6 +47,17 @@ class VariableDefinitionFilterSpec extends Specification { directive @field_directive_argument(arg: InputObject) on FIELD_DEFINITION ''' + private String schema2 = ''' + type Query { person: Person } + + type Person { + address : Address + id: String + } + + type Address { city: String state: String zip: String } + ''' + private GraphQLSchema graphQLSchema private VariableDefinitionFilter variableDefinitionFilter @@ -63,6 +74,12 @@ class VariableDefinitionFilterSpec extends Specification { RuntimeWiring.newRuntimeWiring().build()) } + private GraphQLSchema getSchema2() { + return new SchemaGenerator() + .makeExecutableSchema(new SchemaParser().parse(schema2), + RuntimeWiring.newRuntimeWiring().build()) + } + private Map getFragmentsByName(Document document) { return document.getDefinitionsOfType(FragmentDefinition.class).stream() .inject([:]) {map, it -> map << [(it.getName()): it]} @@ -179,6 +196,62 @@ class VariableDefinitionFilterSpec extends Specification { results.containsAll("int_arg", "string_arg") } + def "variable References In Built in Query Directive includes"() { + given: + String query = ''' + query($includeContext: Boolean!) { + consumer { + liabilities(arg: 1) @include(if: $includeContext) { + totalDebt(arg: 1) + } + income + } + } + ''' + + Document document = parser.parseDocument(query) + HashMap variables = new HashMap<>() + variables.put("includeContext", false) + + when: + final Set results = variableDefinitionFilter + .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), + variables, document) + + then: + results.size() == 1 + + results.containsAll("includeContext") + } + + def "variable References In Built in Query Directive skip"() { + given: + String query = ''' + query($includeContext: Boolean!) { + consumer { + liabilities(arg: 1) @skip(if: $includeContext) { + totalDebt(arg: 1) + } + income + } + } + ''' + + Document document = parser.parseDocument(query) + HashMap variables = new HashMap<>() + variables.put("includeContext", true) + + when: + final Set results = variableDefinitionFilter + .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), + variables, document) + + then: + results.size() == 1 + + results.containsAll("includeContext") + } + def "test Negative Cases"() { given: final String negativeTestCaseQuery = "query { consumer { liabilities { totalDebt(arg: 1234) } } }"