@@ -3,11 +3,14 @@ package graphql.servlet
33import com.fasterxml.jackson.databind.ObjectMapper
44import graphql.Scalars
55import graphql.execution.ExecutionTypeInfo
6+ import graphql.execution.instrumentation.Instrumentation
7+ import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentation
68import graphql.schema.DataFetcher
79import graphql.schema.GraphQLFieldDefinition
810import graphql.schema.GraphQLNonNull
911import graphql.schema.GraphQLObjectType
1012import graphql.schema.GraphQLSchema
13+ import org.dataloader.DataLoaderRegistry
1114import org.springframework.mock.web.MockHttpServletRequest
1215import org.springframework.mock.web.MockHttpServletResponse
1316import spock.lang.Shared
@@ -38,39 +41,45 @@ class GraphQLServletSpec extends Specification {
3841 response = new MockHttpServletResponse ()
3942 }
4043
41- def createServlet (DataFetcher queryDataFetcher = { env -> env. arguments. arg }, DataFetcher mutationDataFetcher = { env -> env. arguments. arg }) {
44+ def createServlet (DataFetcher queryDataFetcher = { env -> env. arguments. arg },
45+ DataFetcher mutationDataFetcher = { env -> env. arguments. arg }) {
46+ return new SimpleGraphQLServlet (createGraphQlSchema(queryDataFetcher, mutationDataFetcher))
47+ }
48+
49+ def createGraphQlSchema (DataFetcher queryDataFetcher = { env -> env. arguments. arg },
50+ DataFetcher mutationDataFetcher = { env -> env. arguments. arg }) {
4251 GraphQLObjectType query = GraphQLObjectType . newObject()
43- .name(" Query" )
44- .field { GraphQLFieldDefinition.Builder field ->
45- field. name(" echo" )
46- field. type(Scalars.GraphQLString )
47- field. argument { argument ->
48- argument. name(" arg" )
49- argument. type(Scalars.GraphQLString )
50- }
51- field. dataFetcher(queryDataFetcher)
52+ .name(" Query" )
53+ .field { GraphQLFieldDefinition.Builder field ->
54+ field. name(" echo" )
55+ field. type(Scalars.GraphQLString )
56+ field. argument { argument ->
57+ argument. name(" arg" )
58+ argument. type(Scalars.GraphQLString )
5259 }
53- .field { GraphQLFieldDefinition.Builder field ->
54- field. name(" returnsNullIncorrectly" )
55- field. type(new GraphQLNonNull (Scalars.GraphQLString ))
56- field. dataFetcher({env -> null })
57- }
58- .build()
60+ field. dataFetcher(queryDataFetcher)
61+ }
62+ .field { GraphQLFieldDefinition.Builder field ->
63+ field. name(" returnsNullIncorrectly" )
64+ field. type(new GraphQLNonNull (Scalars.GraphQLString ))
65+ field. dataFetcher({env -> null })
66+ }
67+ .build()
5968
6069 GraphQLObjectType mutation = GraphQLObjectType . newObject()
61- .name(" Mutation" )
62- .field { field ->
63- field. name(" echo" )
64- field. type(Scalars.GraphQLString )
65- field. argument { argument ->
66- argument. name(" arg" )
67- argument. type(Scalars.GraphQLString )
68- }
69- field. dataFetcher(mutationDataFetcher)
70+ .name(" Mutation" )
71+ .field { field ->
72+ field. name(" echo" )
73+ field. type(Scalars.GraphQLString )
74+ field. argument { argument ->
75+ argument. name(" arg" )
76+ argument. type(Scalars.GraphQLString )
7077 }
71- .build()
78+ field. dataFetcher(mutationDataFetcher)
79+ }
80+ .build()
7281
73- return new SimpleGraphQLServlet ( new GraphQLSchema (query, mutation, [query, mutation]. toSet() ))
82+ return new GraphQLSchema (query, mutation, [query, mutation]. toSet())
7483 }
7584
7685 Map<String , Object > getResponseContent () {
@@ -855,4 +864,40 @@ class GraphQLServletSpec extends Specification {
855864 then :
856865 1 * mockInputStream. reset()
857866 }
867+
868+ def " getInstrumentation returns the set Instrumentation if none is provided in the context" () {
869+
870+ setup :
871+ Instrumentation expectedInstrumentation = Mock ()
872+ GraphQLContext context = new GraphQLContext (Optional . of(request), Optional . of(response))
873+ SimpleGraphQLServlet simpleGraphQLServlet = SimpleGraphQLServlet
874+ .builder(createGraphQlSchema())
875+ .withInstrumentation(expectedInstrumentation)
876+ .build();
877+ when :
878+ Instrumentation actualInstrumentation = simpleGraphQLServlet. getInstrumentation(context)
879+ then :
880+ actualInstrumentation == expectedInstrumentation;
881+ ! (actualInstrumentation instanceof DataLoaderDispatcherInstrumentation )
882+
883+ }
884+
885+ def " getInstrumentation returns the DataLoaderDispatcherInstrumentation if DataLoader provided in context" () {
886+
887+ setup :
888+ Instrumentation servletInstrumentation = Mock ()
889+ GraphQLContext context = new GraphQLContext (Optional . of(request), Optional . of(response))
890+ DataLoaderRegistry dlr = Mock ()
891+ context. setDataLoaderRegistry(Optional . of(dlr))
892+ SimpleGraphQLServlet simpleGraphQLServlet = SimpleGraphQLServlet
893+ .builder(createGraphQlSchema())
894+ .withInstrumentation(servletInstrumentation)
895+ .build();
896+ when :
897+ Instrumentation actualInstrumentation = simpleGraphQLServlet. getInstrumentation(context)
898+ then :
899+ actualInstrumentation instanceof DataLoaderDispatcherInstrumentation
900+ actualInstrumentation != servletInstrumentation;
901+
902+ }
858903}
0 commit comments