diff --git a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java index 4f080881..5eba03f5 100644 --- a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java @@ -7,13 +7,13 @@ import graphql.schema.GraphQLFieldDefinition; import graphql.servlet.internal.GraphQLRequest; import graphql.servlet.internal.VariableMapper; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.servlet.AsyncContext; -import javax.servlet.Servlet; -import javax.servlet.ServletConfig; -import javax.servlet.ServletException; +import javax.servlet.*; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -30,6 +30,8 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; @@ -43,6 +45,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements public static final Logger log = LoggerFactory.getLogger(AbstractGraphQLHttpServlet.class); public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8"; + public static final String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8"; public static final String APPLICATION_GRAPHQL = "application/graphql"; public static final int STATUS_OK = 200; public static final int STATUS_BAD_REQUEST = 400; @@ -289,7 +292,7 @@ public String executeQuery(String query) { private void doRequestAsync(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) { if (configuration.isAsyncServletModeEnabled()) { - AsyncContext asyncContext = request.startAsync(); + AsyncContext asyncContext = request.startAsync(request, response); HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest(); HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse(); new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start(); @@ -334,9 +337,31 @@ private Optional getFileItem(Map> fileItems, String nam private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLSingleInvocationInput invocationInput, HttpServletResponse resp) throws IOException { ExecutionResult result = queryInvoker.query(invocationInput); - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(STATUS_OK); - resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result)); + if (!(result.getData() instanceof Publisher)) { + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(STATUS_OK); + resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result)); + } else { + resp.setContentType(APPLICATION_EVENT_STREAM_UTF8); + resp.setStatus(STATUS_OK); + + HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().orElseThrow(IllegalStateException::new); + boolean isInAsyncThread = req.isAsyncStarted(); + AsyncContext asyncContext = isInAsyncThread ? req.getAsyncContext() : req.startAsync(req, resp); + asyncContext.setTimeout(configuration.getSubscriptionTimeout()); + AtomicReference subscriptionRef = new AtomicReference<>(); + asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef)); + ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper); + ((Publisher) result.getData()).subscribe(subscriber); + if (isInAsyncThread) { + // We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed. + try { + subscriber.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } } private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception { @@ -437,4 +462,74 @@ default void accept(HttpServletRequest request, HttpServletResponse response) { void handle(HttpServletRequest request, HttpServletResponse response) throws Exception; } + + private static class SubscriptionAsyncListener implements AsyncListener { + private final AtomicReference subscriptionRef; + public SubscriptionAsyncListener(AtomicReference subscriptionRef) { + this.subscriptionRef = subscriptionRef; + } + + @Override public void onComplete(AsyncEvent event) { + subscriptionRef.get().cancel(); + } + + @Override public void onTimeout(AsyncEvent event) { + subscriptionRef.get().cancel(); + } + + @Override public void onError(AsyncEvent event) { + subscriptionRef.get().cancel(); + } + + @Override public void onStartAsync(AsyncEvent event) { + } + } + + + private static class ExecutionResultSubscriber implements Subscriber { + + private final AtomicReference subscriptionRef; + private final AsyncContext asyncContext; + private final GraphQLObjectMapper graphQLObjectMapper; + private final CountDownLatch completedLatch = new CountDownLatch(1); + + public ExecutionResultSubscriber(AtomicReference subscriptionRef, AsyncContext asyncContext, GraphQLObjectMapper graphQLObjectMapper) { + this.subscriptionRef = subscriptionRef; + this.asyncContext = asyncContext; + this.graphQLObjectMapper = graphQLObjectMapper; + } + + @Override + public void onSubscribe(Subscription subscription) { + subscriptionRef.set(subscription); + subscriptionRef.get().request(1); + } + + @Override + public void onNext(ExecutionResult executionResult) { + try { + Writer writer = asyncContext.getResponse().getWriter(); + writer.write("data: " + graphQLObjectMapper.serializeResultAsJson(executionResult) + "\n\n"); + writer.flush(); + subscriptionRef.get().request(1); + } catch (IOException ignored) { + } + } + + @Override + public void onError(Throwable t) { + asyncContext.complete(); + completedLatch.countDown(); + } + + @Override + public void onComplete() { + asyncContext.complete(); + completedLatch.countDown(); + } + + public void await() throws InterruptedException { + completedLatch.await(); + } + } } diff --git a/src/main/java/graphql/servlet/GraphQLConfiguration.java b/src/main/java/graphql/servlet/GraphQLConfiguration.java index facdd33e..47e28d1d 100644 --- a/src/main/java/graphql/servlet/GraphQLConfiguration.java +++ b/src/main/java/graphql/servlet/GraphQLConfiguration.java @@ -13,6 +13,7 @@ public class GraphQLConfiguration { private GraphQLObjectMapper objectMapper; private List listeners; private boolean asyncServletModeEnabled; + private long subscriptionTimeout; public static GraphQLConfiguration.Builder with(GraphQLSchema schema) { return with(new DefaultGraphQLSchemaProvider(schema)); @@ -26,12 +27,13 @@ public static GraphQLConfiguration.Builder with(GraphQLInvocationInputFactory in return new Builder(invocationInputFactory); } - private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List listeners, boolean asyncServletModeEnabled) { + private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List listeners, boolean asyncServletModeEnabled, long subscriptionTimeout) { this.invocationInputFactory = invocationInputFactory; this.queryInvoker = queryInvoker; this.objectMapper = objectMapper; this.listeners = listeners; this.asyncServletModeEnabled = asyncServletModeEnabled; + this.subscriptionTimeout = subscriptionTimeout; } public GraphQLInvocationInputFactory getInvocationInputFactory() { @@ -62,6 +64,10 @@ public boolean remove(GraphQLServletListener listener) { return listeners.remove(listener); } + public long getSubscriptionTimeout() { + return subscriptionTimeout; + } + public static class Builder { private GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder; @@ -70,6 +76,7 @@ public static class Builder { private GraphQLObjectMapper objectMapper = GraphQLObjectMapper.newBuilder().build(); private List listeners = new ArrayList<>(); private boolean asyncServletModeEnabled = false; + private long subscriptionTimeout = 0; private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) { this.invocationInputFactoryBuilder = invocationInputFactoryBuilder; @@ -115,13 +122,19 @@ public Builder with(GraphQLRootObjectBuilder rootObjectBuilder) { return this; } + public Builder with(long subscriptionTimeout) { + this.subscriptionTimeout = subscriptionTimeout; + return this; + } + public GraphQLConfiguration build() { return new GraphQLConfiguration( this.invocationInputFactory != null ? this.invocationInputFactory : invocationInputFactoryBuilder.build(), queryInvoker, objectMapper, listeners, - asyncServletModeEnabled + asyncServletModeEnabled, + subscriptionTimeout ); } diff --git a/src/main/java/graphql/servlet/GraphQLSchemaProvider.java b/src/main/java/graphql/servlet/GraphQLSchemaProvider.java index cd69430e..51fd4723 100644 --- a/src/main/java/graphql/servlet/GraphQLSchemaProvider.java +++ b/src/main/java/graphql/servlet/GraphQLSchemaProvider.java @@ -34,7 +34,7 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) { /** * @param request the http request - * @return a read-only schema based on the request (auth, etc). Should return the same schema (query-only version) as {@link #getSchema(HttpServletRequest)} for a given request. + * @return a read-only schema based on the request (auth, etc). Should return the same schema (query/subscription-only version) as {@link #getSchema(HttpServletRequest)} for a given request. */ GraphQLSchema getReadOnlySchema(HttpServletRequest request); diff --git a/src/main/java/graphql/servlet/GraphQLSubscriptionProvider.java b/src/main/java/graphql/servlet/GraphQLSubscriptionProvider.java new file mode 100644 index 00000000..7c43b8f7 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLSubscriptionProvider.java @@ -0,0 +1,9 @@ +package graphql.servlet; + +import graphql.schema.GraphQLFieldDefinition; + +import java.util.Collection; + +public interface GraphQLSubscriptionProvider extends GraphQLProvider { + Collection getSubscriptions(); +} diff --git a/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java index 2cec68bf..bec5ea99 100644 --- a/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java @@ -33,6 +33,7 @@ public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet { private final List queryProviders = new ArrayList<>(); private final List mutationProviders = new ArrayList<>(); + private final List subscriptionProviders = new ArrayList<>(); private final List typesProviders = new ArrayList<>(); private final GraphQLQueryInvoker queryInvoker; @@ -151,8 +152,23 @@ private void doUpdateSchema() { } } + GraphQLObjectType subscriptionType = null; + + if (!subscriptionProviders.isEmpty()) { + final GraphQLObjectType.Builder subscriptionTypeBuilder = newObject().name("Subscription").description("Root subscription type"); + + for (GraphQLSubscriptionProvider provider : subscriptionProviders) { + provider.getSubscriptions().forEach(subscriptionTypeBuilder::field); + } + + if (!subscriptionTypeBuilder.build().getFieldDefinitions().isEmpty()) { + subscriptionType = subscriptionTypeBuilder.build(); + } + } + this.schemaProvider = new DefaultGraphQLSchemaProvider(newSchema().query(queryTypeBuilder.build()) .mutation(mutationType) + .subscription(subscriptionType) .additionalTypes(types) .build()); } @@ -165,6 +181,9 @@ public void bindProvider(GraphQLProvider provider) { if (provider instanceof GraphQLMutationProvider) { mutationProviders.add((GraphQLMutationProvider) provider); } + if (provider instanceof GraphQLSubscriptionProvider) { + subscriptionProviders.add((GraphQLSubscriptionProvider) provider); + } if (provider instanceof GraphQLTypesProvider) { typesProviders.add((GraphQLTypesProvider) provider); } @@ -177,6 +196,9 @@ public void unbindProvider(GraphQLProvider provider) { if (provider instanceof GraphQLMutationProvider) { mutationProviders.remove(provider); } + if (provider instanceof GraphQLSubscriptionProvider) { + subscriptionProviders.remove(provider); + } if (provider instanceof GraphQLTypesProvider) { typesProviders.remove(provider); } @@ -203,6 +225,16 @@ public void unbindMutationProvider(GraphQLMutationProvider mutationProvider) { updateSchema(); } + @Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC) + public void bindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) { + subscriptionProviders.add(subscriptionProvider); + updateSchema(); + } + public void unbindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) { + subscriptionProviders.remove(subscriptionProvider); + updateSchema(); + } + @Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC) public void bindTypesProvider(GraphQLTypesProvider typesProvider) { typesProviders.add(typesProvider); diff --git a/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java b/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java index e421923c..c07d0156 100644 --- a/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java @@ -30,6 +30,21 @@ public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFac .build(); } + /** + * @deprecated use {@link GraphQLHttpServlet} instead + */ + @Deprecated + public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, List listeners, boolean asyncServletMode, long subscriptionTimeout) { + super(listeners); + this.configuration = GraphQLConfiguration.with(invocationInputFactory) + .with(queryInvoker) + .with(graphQLObjectMapper) + .with(listeners != null ? listeners : new ArrayList<>()) + .with(asyncServletMode) + .with(subscriptionTimeout) + .build(); + } + private SimpleGraphQLHttpServlet(GraphQLConfiguration configuration) { this.configuration = Objects.requireNonNull(configuration, "configuration is required"); } @@ -77,6 +92,7 @@ public static class Builder { private GraphQLObjectMapper graphQLObjectMapper = GraphQLObjectMapper.newBuilder().build(); private List listeners; private boolean asyncServletMode; + private long subscriptionTimeout; Builder(GraphQLInvocationInputFactory invocationInputFactory) { this.invocationInputFactory = invocationInputFactory; @@ -102,6 +118,11 @@ public Builder withListeners(List listeners) { return this; } + public Builder withSubscriptionTimeout(long subscriptionTimeout) { + this.subscriptionTimeout = subscriptionTimeout; + return this; + } + @Deprecated public SimpleGraphQLHttpServlet build() { GraphQLConfiguration configuration = GraphQLConfiguration.with(invocationInputFactory) @@ -109,6 +130,7 @@ public SimpleGraphQLHttpServlet build() { .with(graphQLObjectMapper) .with(listeners != null ? listeners : new ArrayList<>()) .with(asyncServletMode) + .with(subscriptionTimeout) .build(); return new SimpleGraphQLHttpServlet(configuration); } diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index 86534f72..b82140c1 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -2,11 +2,10 @@ package graphql.servlet import com.fasterxml.jackson.databind.ObjectMapper import graphql.Scalars -import graphql.annotations.annotationTypes.GraphQLType import graphql.execution.ExecutionStepInfo import graphql.execution.instrumentation.ChainedInstrumentation - import graphql.execution.instrumentation.Instrumentation +import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.GraphQLNonNull import org.dataloader.DataLoaderRegistry import org.springframework.mock.web.MockHttpServletRequest @@ -17,6 +16,9 @@ import spock.lang.Specification import javax.servlet.ServletInputStream import javax.servlet.http.HttpServletRequest +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference /** * @author Andrew Potter @@ -27,6 +29,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { public static final int STATUS_BAD_REQUEST = 400 public static final int STATUS_ERROR = 500 public static final String CONTENT_TYPE_JSON_UTF8 = 'application/json;charset=UTF-8' + public static final String CONTENT_TYPE_SERVER_SENT_EVENTS = 'text/event-stream;charset=UTF-8' @Shared ObjectMapper mapper = new ObjectMapper() @@ -34,10 +37,24 @@ class AbstractGraphQLHttpServletSpec extends Specification { AbstractGraphQLHttpServlet servlet MockHttpServletRequest request MockHttpServletResponse response + CountDownLatch subscriptionLatch def setup() { - servlet = TestUtils.createServlet() + subscriptionLatch = new CountDownLatch(1) + servlet = TestUtils.createServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env -> + AtomicReference> publisherRef = new AtomicReference<>() + publisherRef.set(new SingleSubscriberPublisher({ + SingleSubscriberPublisher publisher = publisherRef.get() + publisher.offer("First\n\n" + env.arguments.arg) + publisher.offer("Second\n\n" + env.arguments.arg) + publisher.noMoreData() + subscriptionLatch.countDown() + })) + return publisherRef.get() + }) + request = new MockHttpServletRequest() + request.asyncSupported = true response = new MockHttpServletResponse() } @@ -46,6 +63,17 @@ class AbstractGraphQLHttpServletSpec extends Specification { mapper.readValue(response.getContentAsByteArray(), Map) } + List> getSubscriptionResponseContent() { + String[] data = response.getContentAsString().split("\n\n") + return data.collect { dataLine -> + if (dataLine.startsWith("data: ")) { + return mapper.readValue(dataLine.substring(5), Map) + } else { + throw new IllegalStateException("Could not read event stream") + } + } + } + List> getBatchedResponseContent() { mapper.readValue(response.getContentAsByteArray(), List) } @@ -263,6 +291,26 @@ class AbstractGraphQLHttpServletSpec extends Specification { getBatchedResponseContent()[1].errors.size() == 1 } + def "subscription query over HTTP GET with variables as string returns data"() { + setup: + request.addParameter('query', 'subscription Subscription($arg: String!) { echo(arg: $arg) }') + request.addParameter('operationName', 'Subscription') + request.addParameter( 'variables', '{"arg": "test"}') + request.setAsyncSupported(true) + + when: + servlet.doGet(request, response) + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS + + when: + subscriptionLatch.await(1, TimeUnit.SECONDS) + then: + getSubscriptionResponseContent()[0].data.echo == "First\n\ntest" + getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest" + } + def "query over HTTP POST without part or body returns bad request"() { when: servlet.doPost(request, response) @@ -903,6 +951,24 @@ class AbstractGraphQLHttpServletSpec extends Specification { getBatchedResponseContent()[1].data.echo == "test" } + def "subscription query over HTTP POST with variables as string returns data"() { + setup: + request.setContent('{"query": "subscription Subscription($arg: String!) { echo(arg: $arg) }", "operationName": "Subscription", "variables": {"arg": "test"}}'.bytes) + request.setAsyncSupported(true) + + when: + servlet.doPost(request, response) + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS + + when: + subscriptionLatch.await(1, TimeUnit.SECONDS) + then: + getSubscriptionResponseContent()[0].data.echo == "First\n\ntest" + getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest" + } + def "errors before graphql schema execution return internal server error"() { setup: servlet = SimpleGraphQLHttpServlet.newBuilder(GraphQLInvocationInputFactory.newBuilder { diff --git a/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy index 577cdff9..b99064ca 100644 --- a/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy @@ -79,4 +79,41 @@ class OsgiGraphQLHttpServletSpec extends Specification { then: servlet.getSchemaProvider().getSchema().getMutationType() == null } + + static class TestSubscriptionProvider implements GraphQLSubscriptionProvider { + @Override + Collection getSubscriptions() { + return Collections.singletonList(newFieldDefinition().name("subscription").type(GraphQLAnnotations.object(Subscription.class)).build()) + } + + + @GraphQLName("subscription") + static class Subscription { + @GraphQLField + public String field; + } + } + + def "subscription provider adds subscription objects"() { + setup: + OsgiGraphQLHttpServlet servlet = new OsgiGraphQLHttpServlet() + TestSubscriptionProvider subscriptionProvider = new TestSubscriptionProvider() + servlet.bindSubscriptionProvider(subscriptionProvider) + GraphQLFieldDefinition subscription + + when: + subscription = servlet.getSchemaProvider().getSchema().getSubscriptionType().getFieldDefinition("subscription") + then: + subscription.getType().getName() == "subscription" + + when: + subscription = servlet.getSchemaProvider().getReadOnlySchema(null).getSubscriptionType().getFieldDefinition("subscription") + then: + subscription.getType().getName() == "subscription" + + when: + servlet.unbindSubscriptionProvider(subscriptionProvider) + then: + servlet.getSchemaProvider().getSchema().getSubscriptionType() == null + } } diff --git a/src/test/groovy/graphql/servlet/TestUtils.groovy b/src/test/groovy/graphql/servlet/TestUtils.groovy index c91e8d06..defca67a 100644 --- a/src/test/groovy/graphql/servlet/TestUtils.groovy +++ b/src/test/groovy/graphql/servlet/TestUtils.groovy @@ -3,14 +3,26 @@ package graphql.servlet import com.google.common.io.ByteStreams import graphql.Scalars import graphql.execution.instrumentation.Instrumentation +import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.* +import org.reactivestreams.Publisher + +import java.util.concurrent.atomic.AtomicReference class TestUtils { static def createServlet(DataFetcher queryDataFetcher = { env -> env.arguments.arg }, - DataFetcher mutationDataFetcher = { env -> env.arguments.arg }) { + DataFetcher mutationDataFetcher = { env -> env.arguments.arg }, + DataFetcher subscriptionDataFetcher = { env -> + AtomicReference> publisherRef = new AtomicReference<>(); + publisherRef.set(new SingleSubscriberPublisher<>({ subscription -> + publisherRef.get().offer(env.arguments.arg) + publisherRef.get().noMoreData() + })) + return publisherRef.get() + }) { GraphQLHttpServlet servlet = GraphQLHttpServlet.with(GraphQLConfiguration - .with(createGraphQlSchema(queryDataFetcher, mutationDataFetcher)) + .with(createGraphQlSchema(queryDataFetcher, mutationDataFetcher, subscriptionDataFetcher)) .with(createInstrumentedQueryInvoker()).build()) servlet.init(null) return servlet @@ -22,7 +34,15 @@ class TestUtils { } static def createGraphQlSchema(DataFetcher queryDataFetcher = { env -> env.arguments.arg }, - DataFetcher mutationDataFetcher = { env -> env.arguments.arg }) { + DataFetcher mutationDataFetcher = { env -> env.arguments.arg }, + DataFetcher subscriptionDataFetcher = { env -> + AtomicReference> publisherRef = new AtomicReference<>(); + publisherRef.set(new SingleSubscriberPublisher<>({ subscription -> + publisherRef.get().offer(env.arguments.arg) + publisherRef.get().noMoreData() + })) + return publisherRef.get() + }) { GraphQLObjectType query = GraphQLObjectType.newObject() .name("Query") .field { GraphQLFieldDefinition.Builder field -> @@ -72,9 +92,24 @@ class TestUtils { } .build() + GraphQLObjectType subscription = GraphQLObjectType.newObject() + .name("Subscription") + .field { field -> + field.name("echo") + field.type(Scalars.GraphQLString) + field.argument { argument -> + argument.name("arg") + argument.type(Scalars.GraphQLString) + } + field.dataFetcher(subscriptionDataFetcher) + } + .build() + + return GraphQLSchema.newSchema() .query(query) .mutation(mutation) + .subscription(subscription) .additionalType(ApolloScalars.Upload) .build() }