diff --git a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java index 5eba03f5..45324124 100644 --- a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java @@ -54,7 +54,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements private static final String[] MULTIPART_KEYS = new String[]{"operations", "graphql", "query"}; private GraphQLConfiguration configuration; - + /** * @deprecated override {@link #getConfiguration()} instead */ @@ -295,7 +295,7 @@ private void doRequestAsync(HttpServletRequest request, HttpServletResponse resp AsyncContext asyncContext = request.startAsync(request, response); HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest(); HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse(); - new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start(); + configuration.getAsyncExecutor().execute(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)); } else { doRequest(request, response, handler, null); } diff --git a/src/main/java/graphql/servlet/GraphQLConfiguration.java b/src/main/java/graphql/servlet/GraphQLConfiguration.java index 47e28d1d..1683a328 100644 --- a/src/main/java/graphql/servlet/GraphQLConfiguration.java +++ b/src/main/java/graphql/servlet/GraphQLConfiguration.java @@ -1,10 +1,12 @@ package graphql.servlet; import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.GraphQLThreadFactory; import java.util.ArrayList; import java.util.List; -import java.util.Objects; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; public class GraphQLConfiguration { @@ -13,6 +15,7 @@ public class GraphQLConfiguration { private GraphQLObjectMapper objectMapper; private List listeners; private boolean asyncServletModeEnabled; + private Executor asyncExecutor; private long subscriptionTimeout; public static GraphQLConfiguration.Builder with(GraphQLSchema schema) { @@ -27,12 +30,13 @@ public static GraphQLConfiguration.Builder with(GraphQLInvocationInputFactory in return new Builder(invocationInputFactory); } - private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List listeners, boolean asyncServletModeEnabled, long subscriptionTimeout) { + private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List listeners, boolean asyncServletModeEnabled, Executor asyncExecutor, long subscriptionTimeout) { this.invocationInputFactory = invocationInputFactory; this.queryInvoker = queryInvoker; this.objectMapper = objectMapper; this.listeners = listeners; this.asyncServletModeEnabled = asyncServletModeEnabled; + this.asyncExecutor = asyncExecutor; this.subscriptionTimeout = subscriptionTimeout; } @@ -56,6 +60,10 @@ public boolean isAsyncServletModeEnabled() { return asyncServletModeEnabled; } + public Executor getAsyncExecutor() { + return asyncExecutor; + } + public void add(GraphQLServletListener listener) { listeners.add(listener); } @@ -76,6 +84,7 @@ public static class Builder { private GraphQLObjectMapper objectMapper = GraphQLObjectMapper.newBuilder().build(); private List listeners = new ArrayList<>(); private boolean asyncServletModeEnabled = false; + private Executor asyncExecutor = Executors.newCachedThreadPool(new GraphQLThreadFactory()); private long subscriptionTimeout = 0; private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) { @@ -112,6 +121,13 @@ public Builder with(boolean asyncServletModeEnabled) { return this; } + public Builder with(Executor asyncExecutor) { + if (asyncExecutor != null) { + this.asyncExecutor = asyncExecutor; + } + return this; + } + public Builder with(GraphQLContextBuilder contextBuilder) { this.invocationInputFactoryBuilder.withGraphQLContextBuilder(contextBuilder); return this; @@ -134,6 +150,7 @@ public GraphQLConfiguration build() { objectMapper, listeners, asyncServletModeEnabled, + asyncExecutor, subscriptionTimeout ); } diff --git a/src/main/java/graphql/servlet/internal/GraphQLThreadFactory.java b/src/main/java/graphql/servlet/internal/GraphQLThreadFactory.java new file mode 100644 index 00000000..34fdb709 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/GraphQLThreadFactory.java @@ -0,0 +1,26 @@ +package graphql.servlet.internal; + +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; + +import graphql.servlet.AbstractGraphQLHttpServlet; + +/** + * {@link ThreadFactory} implementation for {@link AbstractGraphQLHttpServlet} async operations + * + * @author John Nutting + */ +public class GraphQLThreadFactory implements ThreadFactory { + + final static String NAME_PREFIX = "GraphQLServlet-"; + final AtomicInteger threadNumber = new AtomicInteger(1); + + @Override + public Thread newThread(final Runnable r) { + Thread t = new Thread(r, NAME_PREFIX + threadNumber.getAndIncrement()); + t.setDaemon(false); + t.setPriority(Thread.NORM_PRIORITY); + return t; + } + +} \ No newline at end of file diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index b82140c1..374a3e4d 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -5,6 +5,7 @@ import graphql.Scalars import graphql.execution.ExecutionStepInfo import graphql.execution.instrumentation.ChainedInstrumentation import graphql.execution.instrumentation.Instrumentation +import graphql.schema.DataFetcher import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.GraphQLNonNull import org.dataloader.DataLoaderRegistry @@ -54,6 +55,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { }) request = new MockHttpServletRequest() + request.setAsyncSupported(true) request.asyncSupported = true response = new MockHttpServletResponse() } @@ -112,6 +114,18 @@ class AbstractGraphQLHttpServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "async query over HTTP GET starts async request"() { + setup: + servlet = TestUtils.createServlet({ env -> env.arguments.arg },{ env -> env.arguments.arg }, true) + request.addParameter('query', 'query { echo(arg:"test") }') + + when: + servlet.doGet(request, response) + + then: + request.asyncStarted == true + } + def "query over HTTP GET with variables returns data"() { setup: request.addParameter('query', 'query Echo($arg: String) { echo(arg:$arg) }') @@ -334,6 +348,20 @@ class AbstractGraphQLHttpServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "async query over HTTP POST starts async request"() { + setup: + servlet = TestUtils.createServlet({ env -> env.arguments.arg },{ env -> env.arguments.arg }, true) + request.setContent(mapper.writeValueAsBytes([ + query: 'query { echo(arg:"test") }' + ])) + + when: + servlet.doPost(request, response) + + then: + request.asyncStarted == true + } + def "query over HTTP POST body with graphql contentType returns data"() { setup: request.addHeader("Content-Type", "application/graphql") diff --git a/src/test/groovy/graphql/servlet/TestUtils.groovy b/src/test/groovy/graphql/servlet/TestUtils.groovy index defca67a..01c64716 100644 --- a/src/test/groovy/graphql/servlet/TestUtils.groovy +++ b/src/test/groovy/graphql/servlet/TestUtils.groovy @@ -5,7 +5,6 @@ import graphql.Scalars import graphql.execution.instrumentation.Instrumentation import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.* -import org.reactivestreams.Publisher import java.util.concurrent.atomic.AtomicReference @@ -13,6 +12,7 @@ class TestUtils { static def createServlet(DataFetcher queryDataFetcher = { env -> env.arguments.arg }, DataFetcher mutationDataFetcher = { env -> env.arguments.arg }, + boolean asyncServletModeEnabled = false, DataFetcher subscriptionDataFetcher = { env -> AtomicReference> publisherRef = new AtomicReference<>(); publisherRef.set(new SingleSubscriberPublisher<>({ subscription -> @@ -23,7 +23,9 @@ class TestUtils { }) { GraphQLHttpServlet servlet = GraphQLHttpServlet.with(GraphQLConfiguration .with(createGraphQlSchema(queryDataFetcher, mutationDataFetcher, subscriptionDataFetcher)) - .with(createInstrumentedQueryInvoker()).build()) + .with(createInstrumentedQueryInvoker()) + .with(asyncServletModeEnabled) + .build()) servlet.init(null) return servlet } @@ -72,23 +74,27 @@ class TestUtils { } field.dataFetcher(mutationDataFetcher) } - .field { field -> + .field { field -> field.name("echoFile") field.type(Scalars.GraphQLString) field.argument { argument -> argument.name("file") argument.type(ApolloScalars.Upload) } - field.dataFetcher( { env -> new String(ByteStreams.toByteArray(env.arguments.file.getInputStream())) } ) + field.dataFetcher({ env -> new String(ByteStreams.toByteArray(env.arguments.file.getInputStream())) }) } - .field { field -> + .field { field -> field.name("echoFiles") field.type(GraphQLList.list(Scalars.GraphQLString)) field.argument { argument -> argument.name("files") argument.type(GraphQLList.list(GraphQLNonNull.nonNull(ApolloScalars.Upload))) } - field.dataFetcher( { env -> env.arguments.files.collect { new String(ByteStreams.toByteArray(it.getInputStream())) } } ) + field.dataFetcher({ env -> + env.arguments.files.collect { + new String(ByteStreams.toByteArray(it.getInputStream())) + } + }) } .build() @@ -107,11 +113,11 @@ class TestUtils { return GraphQLSchema.newSchema() - .query(query) - .mutation(mutation) - .subscription(subscription) - .additionalType(ApolloScalars.Upload) - .build() + .query(query) + .mutation(mutation) + .subscription(subscription) + .additionalType(ApolloScalars.Upload) + .build() } }