Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 103 additions & 8 deletions src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import graphql.schema.GraphQLFieldDefinition;
import graphql.servlet.internal.GraphQLRequest;
import graphql.servlet.internal.VariableMapper;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.AsyncContext;
import javax.servlet.Servlet;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.*;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -30,6 +30,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
Expand All @@ -43,6 +45,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements
public static final Logger log = LoggerFactory.getLogger(AbstractGraphQLHttpServlet.class);

public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8";
public static final String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8";
public static final String APPLICATION_GRAPHQL = "application/graphql";
public static final int STATUS_OK = 200;
public static final int STATUS_BAD_REQUEST = 400;
Expand Down Expand Up @@ -289,7 +292,7 @@ public String executeQuery(String query) {

private void doRequestAsync(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) {
if (configuration.isAsyncServletModeEnabled()) {
AsyncContext asyncContext = request.startAsync();
AsyncContext asyncContext = request.startAsync(request, response);
HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest();
HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse();
new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start();
Expand Down Expand Up @@ -334,9 +337,31 @@ private Optional<Part> getFileItem(Map<String, List<Part>> fileItems, String nam
private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLSingleInvocationInput invocationInput, HttpServletResponse resp) throws IOException {
ExecutionResult result = queryInvoker.query(invocationInput);

resp.setContentType(APPLICATION_JSON_UTF8);
resp.setStatus(STATUS_OK);
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
if (!(result.getData() instanceof Publisher)) {
resp.setContentType(APPLICATION_JSON_UTF8);
resp.setStatus(STATUS_OK);
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
} else {
resp.setContentType(APPLICATION_EVENT_STREAM_UTF8);
resp.setStatus(STATUS_OK);

HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().orElseThrow(IllegalStateException::new);
boolean isInAsyncThread = req.isAsyncStarted();
AsyncContext asyncContext = isInAsyncThread ? req.getAsyncContext() : req.startAsync(req, resp);
asyncContext.setTimeout(configuration.getSubscriptionTimeout());
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef));
ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper);
((Publisher<ExecutionResult>) result.getData()).subscribe(subscriber);
if (isInAsyncThread) {
// We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed.
try {
subscriber.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}

private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception {
Expand Down Expand Up @@ -437,4 +462,74 @@ default void accept(HttpServletRequest request, HttpServletResponse response) {

void handle(HttpServletRequest request, HttpServletResponse response) throws Exception;
}

private static class SubscriptionAsyncListener implements AsyncListener {
private final AtomicReference<Subscription> subscriptionRef;
public SubscriptionAsyncListener(AtomicReference<Subscription> subscriptionRef) {
this.subscriptionRef = subscriptionRef;
}

@Override public void onComplete(AsyncEvent event) {
subscriptionRef.get().cancel();
}

@Override public void onTimeout(AsyncEvent event) {
subscriptionRef.get().cancel();
}

@Override public void onError(AsyncEvent event) {
subscriptionRef.get().cancel();
}

@Override public void onStartAsync(AsyncEvent event) {
}
}


private static class ExecutionResultSubscriber implements Subscriber<ExecutionResult> {

private final AtomicReference<Subscription> subscriptionRef;
private final AsyncContext asyncContext;
private final GraphQLObjectMapper graphQLObjectMapper;
private final CountDownLatch completedLatch = new CountDownLatch(1);

public ExecutionResultSubscriber(AtomicReference<Subscription> subscriptionRef, AsyncContext asyncContext, GraphQLObjectMapper graphQLObjectMapper) {
this.subscriptionRef = subscriptionRef;
this.asyncContext = asyncContext;
this.graphQLObjectMapper = graphQLObjectMapper;
}

@Override
public void onSubscribe(Subscription subscription) {
subscriptionRef.set(subscription);
subscriptionRef.get().request(1);
}

@Override
public void onNext(ExecutionResult executionResult) {
try {
Writer writer = asyncContext.getResponse().getWriter();
writer.write("data: " + graphQLObjectMapper.serializeResultAsJson(executionResult) + "\n\n");
writer.flush();
subscriptionRef.get().request(1);
} catch (IOException ignored) {
}
}

@Override
public void onError(Throwable t) {
asyncContext.complete();
completedLatch.countDown();
}

@Override
public void onComplete() {
asyncContext.complete();
completedLatch.countDown();
}

public void await() throws InterruptedException {
completedLatch.await();
}
}
}
17 changes: 15 additions & 2 deletions src/main/java/graphql/servlet/GraphQLConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class GraphQLConfiguration {
private GraphQLObjectMapper objectMapper;
private List<GraphQLServletListener> listeners;
private boolean asyncServletModeEnabled;
private long subscriptionTimeout;

public static GraphQLConfiguration.Builder with(GraphQLSchema schema) {
return with(new DefaultGraphQLSchemaProvider(schema));
Expand All @@ -26,12 +27,13 @@ public static GraphQLConfiguration.Builder with(GraphQLInvocationInputFactory in
return new Builder(invocationInputFactory);
}

private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List<GraphQLServletListener> listeners, boolean asyncServletModeEnabled) {
private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List<GraphQLServletListener> listeners, boolean asyncServletModeEnabled, long subscriptionTimeout) {
this.invocationInputFactory = invocationInputFactory;
this.queryInvoker = queryInvoker;
this.objectMapper = objectMapper;
this.listeners = listeners;
this.asyncServletModeEnabled = asyncServletModeEnabled;
this.subscriptionTimeout = subscriptionTimeout;
}

public GraphQLInvocationInputFactory getInvocationInputFactory() {
Expand Down Expand Up @@ -62,6 +64,10 @@ public boolean remove(GraphQLServletListener listener) {
return listeners.remove(listener);
}

public long getSubscriptionTimeout() {
return subscriptionTimeout;
}

public static class Builder {

private GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder;
Expand All @@ -70,6 +76,7 @@ public static class Builder {
private GraphQLObjectMapper objectMapper = GraphQLObjectMapper.newBuilder().build();
private List<GraphQLServletListener> listeners = new ArrayList<>();
private boolean asyncServletModeEnabled = false;
private long subscriptionTimeout = 0;

private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) {
this.invocationInputFactoryBuilder = invocationInputFactoryBuilder;
Expand Down Expand Up @@ -115,13 +122,19 @@ public Builder with(GraphQLRootObjectBuilder rootObjectBuilder) {
return this;
}

public Builder with(long subscriptionTimeout) {
this.subscriptionTimeout = subscriptionTimeout;
return this;
}

public GraphQLConfiguration build() {
return new GraphQLConfiguration(
this.invocationInputFactory != null ? this.invocationInputFactory : invocationInputFactoryBuilder.build(),
queryInvoker,
objectMapper,
listeners,
asyncServletModeEnabled
asyncServletModeEnabled,
subscriptionTimeout
);
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/servlet/GraphQLSchemaProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) {

/**
* @param request the http request
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query/subscription-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
*/
GraphQLSchema getReadOnlySchema(HttpServletRequest request);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package graphql.servlet;

import graphql.schema.GraphQLFieldDefinition;

import java.util.Collection;

public interface GraphQLSubscriptionProvider extends GraphQLProvider {
Collection<GraphQLFieldDefinition> getSubscriptions();
}
32 changes: 32 additions & 0 deletions src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet {

private final List<GraphQLQueryProvider> queryProviders = new ArrayList<>();
private final List<GraphQLMutationProvider> mutationProviders = new ArrayList<>();
private final List<GraphQLSubscriptionProvider> subscriptionProviders = new ArrayList<>();
private final List<GraphQLTypesProvider> typesProviders = new ArrayList<>();

private final GraphQLQueryInvoker queryInvoker;
Expand Down Expand Up @@ -151,8 +152,23 @@ private void doUpdateSchema() {
}
}

GraphQLObjectType subscriptionType = null;

if (!subscriptionProviders.isEmpty()) {
final GraphQLObjectType.Builder subscriptionTypeBuilder = newObject().name("Subscription").description("Root subscription type");

for (GraphQLSubscriptionProvider provider : subscriptionProviders) {
provider.getSubscriptions().forEach(subscriptionTypeBuilder::field);
}

if (!subscriptionTypeBuilder.build().getFieldDefinitions().isEmpty()) {
subscriptionType = subscriptionTypeBuilder.build();
}
}

this.schemaProvider = new DefaultGraphQLSchemaProvider(newSchema().query(queryTypeBuilder.build())
.mutation(mutationType)
.subscription(subscriptionType)
.additionalTypes(types)
.build());
}
Expand All @@ -165,6 +181,9 @@ public void bindProvider(GraphQLProvider provider) {
if (provider instanceof GraphQLMutationProvider) {
mutationProviders.add((GraphQLMutationProvider) provider);
}
if (provider instanceof GraphQLSubscriptionProvider) {
subscriptionProviders.add((GraphQLSubscriptionProvider) provider);
}
if (provider instanceof GraphQLTypesProvider) {
typesProviders.add((GraphQLTypesProvider) provider);
}
Expand All @@ -177,6 +196,9 @@ public void unbindProvider(GraphQLProvider provider) {
if (provider instanceof GraphQLMutationProvider) {
mutationProviders.remove(provider);
}
if (provider instanceof GraphQLSubscriptionProvider) {
subscriptionProviders.remove(provider);
}
if (provider instanceof GraphQLTypesProvider) {
typesProviders.remove(provider);
}
Expand All @@ -203,6 +225,16 @@ public void unbindMutationProvider(GraphQLMutationProvider mutationProvider) {
updateSchema();
}

@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
public void bindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
subscriptionProviders.add(subscriptionProvider);
updateSchema();
}
public void unbindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
subscriptionProviders.remove(subscriptionProvider);
updateSchema();
}

@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
public void bindTypesProvider(GraphQLTypesProvider typesProvider) {
typesProviders.add(typesProvider);
Expand Down
22 changes: 22 additions & 0 deletions src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@ public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFac
.build();
}

/**
* @deprecated use {@link GraphQLHttpServlet} instead
*/
@Deprecated
public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, List<GraphQLServletListener> listeners, boolean asyncServletMode, long subscriptionTimeout) {
super(listeners);
this.configuration = GraphQLConfiguration.with(invocationInputFactory)
.with(queryInvoker)
.with(graphQLObjectMapper)
.with(listeners != null ? listeners : new ArrayList<>())
.with(asyncServletMode)
.with(subscriptionTimeout)
.build();
}

private SimpleGraphQLHttpServlet(GraphQLConfiguration configuration) {
this.configuration = Objects.requireNonNull(configuration, "configuration is required");
}
Expand Down Expand Up @@ -77,6 +92,7 @@ public static class Builder {
private GraphQLObjectMapper graphQLObjectMapper = GraphQLObjectMapper.newBuilder().build();
private List<GraphQLServletListener> listeners;
private boolean asyncServletMode;
private long subscriptionTimeout;

Builder(GraphQLInvocationInputFactory invocationInputFactory) {
this.invocationInputFactory = invocationInputFactory;
Expand All @@ -102,13 +118,19 @@ public Builder withListeners(List<GraphQLServletListener> listeners) {
return this;
}

public Builder withSubscriptionTimeout(long subscriptionTimeout) {
this.subscriptionTimeout = subscriptionTimeout;
return this;
}

@Deprecated
public SimpleGraphQLHttpServlet build() {
GraphQLConfiguration configuration = GraphQLConfiguration.with(invocationInputFactory)
.with(queryInvoker)
.with(graphQLObjectMapper)
.with(listeners != null ? listeners : new ArrayList<>())
.with(asyncServletMode)
.with(subscriptionTimeout)
.build();
return new SimpleGraphQLHttpServlet(configuration);
}
Expand Down
Loading