diff --git a/build.gradle b/build.gradle index 31643f0d..03029408 100644 --- a/build.gradle +++ b/build.gradle @@ -50,7 +50,8 @@ dependencies { compileOnly 'biz.aQute.bnd:biz.aQute.bndlib:3.1.0' // Servlet - compile 'javax.servlet:javax.servlet-api:3.0.1' + compile 'javax.servlet:javax.servlet-api:4.0.0' + compile 'javax.websocket:javax.websocket-api:1.1' // GraphQL compile 'com.graphql-java:graphql-java:9.2' diff --git a/gradle.properties b/gradle.properties index 97590f05..09b44475 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,2 @@ -version = 5.2.1 +version = 6.0.0 group = com.graphql-java diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index a56cae0c..f6b961fd 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradlew b/gradlew index 4453ccea..cccdd3d5 100755 --- a/gradlew +++ b/gradlew @@ -33,11 +33,11 @@ DEFAULT_JVM_OPTS="" # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" -warn ( ) { +warn () { echo "$*" } -die ( ) { +die () { echo echo "$*" echo @@ -155,7 +155,7 @@ if $cygwin ; then fi # Escape application args -save ( ) { +save () { for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done echo " " } diff --git a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java new file mode 100644 index 00000000..6f425a30 --- /dev/null +++ b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java @@ -0,0 +1,374 @@ +package graphql.servlet; + +import com.google.common.io.ByteStreams; +import com.google.common.io.CharStreams; +import graphql.ExecutionResult; +import graphql.introspection.IntrospectionQuery; +import graphql.schema.GraphQLFieldDefinition; +import graphql.servlet.internal.GraphQLRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.AsyncContext; +import javax.servlet.Servlet; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; +import java.io.*; +import java.util.*; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * @author Andrew Potter + */ +public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements Servlet, GraphQLMBean { + + public static final Logger log = LoggerFactory.getLogger(AbstractGraphQLHttpServlet.class); + + public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8"; + public static final String APPLICATION_GRAPHQL = "application/graphql"; + public static final int STATUS_OK = 200; + public static final int STATUS_BAD_REQUEST = 400; + + private static final GraphQLRequest INTROSPECTION_REQUEST = new GraphQLRequest(IntrospectionQuery.INTROSPECTION_QUERY, new HashMap<>(), null); + + protected abstract GraphQLQueryInvoker getQueryInvoker(); + + protected abstract GraphQLInvocationInputFactory getInvocationInputFactory(); + + protected abstract GraphQLObjectMapper getGraphQLObjectMapper(); + + private final List listeners; + + private final HttpRequestHandler getHandler; + private final HttpRequestHandler postHandler; + + private final boolean asyncServletMode; + + public AbstractGraphQLHttpServlet() { + this(null, false); + } + + public AbstractGraphQLHttpServlet(List listeners, boolean asyncServletMode) { + this.listeners = listeners != null ? new ArrayList<>(listeners) : new ArrayList<>(); + this.asyncServletMode = asyncServletMode; + + this.getHandler = (request, response) -> { + GraphQLInvocationInputFactory invocationInputFactory = getInvocationInputFactory(); + GraphQLObjectMapper graphQLObjectMapper = getGraphQLObjectMapper(); + GraphQLQueryInvoker queryInvoker = getQueryInvoker(); + + String path = request.getPathInfo(); + if (path == null) { + path = request.getServletPath(); + } + if (path.contentEquals("/schema.json")) { + query(queryInvoker, graphQLObjectMapper, invocationInputFactory.create(INTROSPECTION_REQUEST, request), response); + } else { + String query = request.getParameter("query"); + if (query != null) { + + if (isBatchedQuery(query)) { + queryBatched(queryInvoker, graphQLObjectMapper, invocationInputFactory.createReadOnly(graphQLObjectMapper.readBatchedGraphQLRequest(query), request), response); + } else { + final Map variables = new HashMap<>(); + if (request.getParameter("variables") != null) { + variables.putAll(graphQLObjectMapper.deserializeVariables(request.getParameter("variables"))); + } + + String operationName = request.getParameter("operationName"); + + query(queryInvoker, graphQLObjectMapper, invocationInputFactory.createReadOnly(new GraphQLRequest(query, variables, operationName), request), response); + } + } else { + response.setStatus(STATUS_BAD_REQUEST); + log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given"); + } + } + }; + + this.postHandler = (request, response) -> { + GraphQLInvocationInputFactory invocationInputFactory = getInvocationInputFactory(); + GraphQLObjectMapper graphQLObjectMapper = getGraphQLObjectMapper(); + GraphQLQueryInvoker queryInvoker = getQueryInvoker(); + + try { + if (APPLICATION_GRAPHQL.equals(request.getContentType())) { + String query = CharStreams.toString(request.getReader()); + query(queryInvoker, graphQLObjectMapper, invocationInputFactory.create(new GraphQLRequest(query, null, null)), response); + } else if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data") && !request.getParts().isEmpty()) { + final Map> fileItems = request.getParts().stream() + .collect(Collectors.toMap( + Part::getName, + Collections::singletonList, + (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).collect(Collectors.toList()))); + + if (fileItems.containsKey("graphql")) { + final Optional graphqlItem = getFileItem(fileItems, "graphql"); + if (graphqlItem.isPresent()) { + InputStream inputStream = graphqlItem.get().getInputStream(); + + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + + if (isBatchedQuery(inputStream)) { + GraphQLBatchedInvocationInput invocationInput = invocationInputFactory.create(graphQLObjectMapper.readBatchedGraphQLRequest(inputStream), request); + invocationInput.getContext().setFiles(fileItems); + queryBatched(queryInvoker, graphQLObjectMapper, invocationInput, response); + return; + } else { + GraphQLSingleInvocationInput invocationInput = invocationInputFactory.create(graphQLObjectMapper.readGraphQLRequest(inputStream), request); + invocationInput.getContext().setFiles(fileItems); + query(queryInvoker, graphQLObjectMapper, invocationInput, response); + return; + } + } + } else if (fileItems.containsKey("query")) { + final Optional queryItem = getFileItem(fileItems, "query"); + if (queryItem.isPresent()) { + InputStream inputStream = queryItem.get().getInputStream(); + + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + + if (isBatchedQuery(inputStream)) { + GraphQLBatchedInvocationInput invocationInput = invocationInputFactory.create(graphQLObjectMapper.readBatchedGraphQLRequest(inputStream), request); + invocationInput.getContext().setFiles(fileItems); + queryBatched(queryInvoker, graphQLObjectMapper, invocationInput, response); + return; + } else { + String query = new String(ByteStreams.toByteArray(inputStream)); + + Map variables = null; + final Optional variablesItem = getFileItem(fileItems, "variables"); + if (variablesItem.isPresent()) { + variables = graphQLObjectMapper.deserializeVariables(new String(ByteStreams.toByteArray(variablesItem.get().getInputStream()))); + } + + String operationName = null; + final Optional operationNameItem = getFileItem(fileItems, "operationName"); + if (operationNameItem.isPresent()) { + operationName = new String(ByteStreams.toByteArray(operationNameItem.get().getInputStream())).trim(); + } + + GraphQLSingleInvocationInput invocationInput = invocationInputFactory.create(new GraphQLRequest(query, variables, operationName), request); + invocationInput.getContext().setFiles(fileItems); + query(queryInvoker, graphQLObjectMapper, invocationInput, response); + return; + } + } + } + + response.setStatus(STATUS_BAD_REQUEST); + log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); + } else { + // this is not a multipart request + InputStream inputStream = request.getInputStream(); + + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + + if (isBatchedQuery(inputStream)) { + queryBatched(queryInvoker, graphQLObjectMapper, invocationInputFactory.create(graphQLObjectMapper.readBatchedGraphQLRequest(inputStream), request), response); + } else { + query(queryInvoker, graphQLObjectMapper, invocationInputFactory.create(graphQLObjectMapper.readGraphQLRequest(inputStream), request), response); + } + } + } catch (Exception e) { + log.info("Bad POST request: parsing failed", e); + response.setStatus(STATUS_BAD_REQUEST); + } + }; + } + + public void addListener(GraphQLServletListener servletListener) { + listeners.add(servletListener); + } + + public void removeListener(GraphQLServletListener servletListener) { + listeners.remove(servletListener); + } + + @Override + public String[] getQueries() { + return getInvocationInputFactory().getSchemaProvider().getSchema().getQueryType().getFieldDefinitions().stream().map(GraphQLFieldDefinition::getName).toArray(String[]::new); + } + + @Override + public String[] getMutations() { + return getInvocationInputFactory().getSchemaProvider().getSchema().getMutationType().getFieldDefinitions().stream().map(GraphQLFieldDefinition::getName).toArray(String[]::new); + } + + @Override + public String executeQuery(String query) { + try { + return getGraphQLObjectMapper().serializeResultAsJson(getQueryInvoker().query(getInvocationInputFactory().create(new GraphQLRequest(query, new HashMap<>(), null)))); + } catch (Exception e) { + return e.getMessage(); + } + } + + private void doRequestAsync(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) { + if (asyncServletMode) { + AsyncContext asyncContext = request.startAsync(); + HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest(); + HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse(); + new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start(); + } else { + doRequest(request, response, handler, null); + } + } + + private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler, AsyncContext asyncContext) { + + List requestCallbacks = runListeners(l -> l.onRequest(request, response)); + + try { + handler.handle(request, response); + runCallbacks(requestCallbacks, c -> c.onSuccess(request, response)); + } catch (Throwable t) { + response.setStatus(500); + log.error("Error executing GraphQL request!", t); + runCallbacks(requestCallbacks, c -> c.onError(request, response, t)); + } finally { + runCallbacks(requestCallbacks, c -> c.onFinally(request, response)); + if (asyncContext != null) { + asyncContext.complete(); + } + } + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + doRequestAsync(req, resp, getHandler); + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + doRequestAsync(req, resp, postHandler); + } + + private Optional getFileItem(Map> fileItems, String name) { + return Optional.ofNullable(fileItems.get(name)).filter(list -> !list.isEmpty()).map(list -> list.get(0)); + } + + private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLSingleInvocationInput invocationInput, HttpServletResponse resp) throws IOException { + ExecutionResult result = queryInvoker.query(invocationInput); + + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(STATUS_OK); + resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result)); + } + + private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception { + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(STATUS_OK); + + Writer respWriter = resp.getWriter(); + respWriter.write('['); + + queryInvoker.query(invocationInput, (result, hasNext) -> { + respWriter.write(graphQLObjectMapper.serializeResultAsJson(result)); + if (hasNext) { + respWriter.write(','); + } + }); + + respWriter.write(']'); + } + + private List runListeners(Function action) { + if (listeners == null) { + return Collections.emptyList(); + } + + return listeners.stream() + .map(listener -> { + try { + return action.apply(listener); + } catch (Throwable t) { + log.error("Error running listener: {}", listener, t); + return null; + } + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + + private void runCallbacks(List callbacks, Consumer action) { + callbacks.forEach(callback -> { + try { + action.accept(callback); + } catch (Throwable t) { + log.error("Error running callback: {}", callback, t); + } + }); + } + + private boolean isBatchedQuery(InputStream inputStream) throws IOException { + if (inputStream == null) { + return false; + } + + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[128]; + int length; + + inputStream.mark(0); + while ((length = inputStream.read(buffer)) != -1) { + result.write(buffer, 0, length); + String chunk = result.toString(); + Boolean isArrayStart = isArrayStart(chunk); + if (isArrayStart != null) { + inputStream.reset(); + return isArrayStart; + } + } + + inputStream.reset(); + return false; + } + + private boolean isBatchedQuery(String query) { + if (query == null) { + return false; + } + + Boolean isArrayStart = isArrayStart(query); + return isArrayStart != null && isArrayStart; + } + + // return true if the first non whitespace character is the beginning of an array + private Boolean isArrayStart(String s) { + for (int i = 0; i < s.length(); i++) { + char ch = s.charAt(i); + if (!Character.isWhitespace(ch)) { + return ch == '['; + } + } + + return null; + } + + protected interface HttpRequestHandler extends BiConsumer { + @Override + default void accept(HttpServletRequest request, HttpServletResponse response) { + try { + handle(request, response); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void handle(HttpServletRequest request, HttpServletResponse response) throws Exception; + } +} diff --git a/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java b/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java index 34588647..d1a3476d 100644 --- a/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java +++ b/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java @@ -1,14 +1,22 @@ package graphql.servlet; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.Optional; +import javax.websocket.server.HandshakeRequest; public class DefaultGraphQLContextBuilder implements GraphQLContextBuilder { @Override - public GraphQLContext build(Optional req, Optional resp) { - return new GraphQLContext(req, resp); + public GraphQLContext build(HttpServletRequest httpServletRequest) { + return new GraphQLContext(httpServletRequest); } + @Override + public GraphQLContext build(HandshakeRequest handshakeRequest) { + return new GraphQLContext(handshakeRequest); + } + + @Override + public GraphQLContext build() { + return new GraphQLContext(); + } } diff --git a/src/main/java/graphql/servlet/DefaultGraphQLSchemaProvider.java b/src/main/java/graphql/servlet/DefaultGraphQLSchemaProvider.java index 458d6442..f21cf559 100644 --- a/src/main/java/graphql/servlet/DefaultGraphQLSchemaProvider.java +++ b/src/main/java/graphql/servlet/DefaultGraphQLSchemaProvider.java @@ -3,6 +3,7 @@ import graphql.schema.GraphQLSchema; import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; /** * @author Andrew Potter @@ -27,6 +28,11 @@ public GraphQLSchema getSchema(HttpServletRequest request) { return getSchema(); } + @Override + public GraphQLSchema getSchema(HandshakeRequest request) { + return getSchema(); + } + @Override public GraphQLSchema getSchema() { return schema; diff --git a/src/main/java/graphql/servlet/GraphQLBatchedInvocationInput.java b/src/main/java/graphql/servlet/GraphQLBatchedInvocationInput.java new file mode 100644 index 00000000..78aed616 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLBatchedInvocationInput.java @@ -0,0 +1,30 @@ +package graphql.servlet; + +import graphql.ExecutionInput; +import graphql.execution.ExecutionContext; +import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.GraphQLRequest; + +import javax.security.auth.Subject; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * @author Andrew Potter + */ +public class GraphQLBatchedInvocationInput extends GraphQLInvocationInput { + private final List requests; + + public GraphQLBatchedInvocationInput(List requests, GraphQLSchema schema, GraphQLContext context, Object root) { + super(schema, context, root); + this.requests = Collections.unmodifiableList(requests); + } + + public List getExecutionInputs() { + return requests.stream() + .map(this::createExecutionInput) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/graphql/servlet/GraphQLContext.java b/src/main/java/graphql/servlet/GraphQLContext.java index 6c916daa..a3917889 100644 --- a/src/main/java/graphql/servlet/GraphQLContext.java +++ b/src/main/java/graphql/servlet/GraphQLContext.java @@ -2,53 +2,54 @@ import javax.security.auth.Subject; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import javax.servlet.http.Part; +import javax.websocket.server.HandshakeRequest; import java.util.List; import java.util.Map; import java.util.Optional; public class GraphQLContext { - private Optional request; - private Optional response; + private HttpServletRequest httpServletRequest; + private HandshakeRequest handshakeRequest; - private Optional subject = Optional.empty(); - private Optional>> files = Optional.empty(); + private Subject subject; + private Map> files; - public GraphQLContext(Optional request, Optional response) { - this.request = request; - this.response = response; + public GraphQLContext(HttpServletRequest httpServletRequest, HandshakeRequest handshakeRequest, Subject subject) { + this.httpServletRequest = httpServletRequest; + this.handshakeRequest = handshakeRequest; + this.subject = subject; } - public Optional getRequest() { - return request; + public GraphQLContext(HttpServletRequest httpServletRequest) { + this(httpServletRequest, null, null); } - public void setRequest(Optional request) { - this.request = request; + public GraphQLContext(HandshakeRequest handshakeRequest) { + this(null, handshakeRequest, null); } - public Optional getResponse() { - return response; + public GraphQLContext() { + this(null, null, null); } - public void setResponse(Optional response) { - this.response = response; + public Optional getHttpServletRequest() { + return Optional.ofNullable(httpServletRequest); } public Optional getSubject() { - return subject; + return Optional.ofNullable(subject); } - public void setSubject(Optional subject) { - this.subject = subject; + public Optional getHandshakeRequest() { + return Optional.ofNullable(handshakeRequest); } public Optional>> getFiles() { - return files; + return Optional.ofNullable(files); } - public void setFiles(Optional>> files) { + public void setFiles(Map> files) { this.files = files; } } diff --git a/src/main/java/graphql/servlet/GraphQLContextBuilder.java b/src/main/java/graphql/servlet/GraphQLContextBuilder.java index a5756dca..cdd5bcb0 100644 --- a/src/main/java/graphql/servlet/GraphQLContextBuilder.java +++ b/src/main/java/graphql/servlet/GraphQLContextBuilder.java @@ -2,8 +2,16 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.server.HandshakeRequest; import java.util.Optional; public interface GraphQLContextBuilder { - GraphQLContext build(Optional req, Optional resp); + GraphQLContext build(HttpServletRequest httpServletRequest); + GraphQLContext build(HandshakeRequest handshakeRequest); + + /** + * Only used for MBean calls. + * @return the graphql context + */ + GraphQLContext build(); } diff --git a/src/main/java/graphql/servlet/GraphQLInvocationInput.java b/src/main/java/graphql/servlet/GraphQLInvocationInput.java new file mode 100644 index 00000000..5806af53 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLInvocationInput.java @@ -0,0 +1,50 @@ +package graphql.servlet; + +import graphql.ExecutionInput; +import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.GraphQLRequest; + +import javax.security.auth.Subject; +import java.util.List; +import java.util.Optional; + +/** + * @author Andrew Potter + */ +public abstract class GraphQLInvocationInput { + private final GraphQLSchema schema; + private final GraphQLContext context; + private final Object root; + + public GraphQLInvocationInput(GraphQLSchema schema, GraphQLContext context, Object root) { + this.schema = schema; + this.context = context; + this.root = root; + } + + public GraphQLSchema getSchema() { + return schema; + } + + public GraphQLContext getContext() { + return context; + } + + public Object getRoot() { + return root; + } + + public Optional getSubject() { + return context.getSubject(); + } + + protected ExecutionInput createExecutionInput(GraphQLRequest graphQLRequest) { + return new ExecutionInput( + graphQLRequest.getQuery(), + graphQLRequest.getOperationName(), + context, + root, + graphQLRequest.getVariables() + ); + } +} diff --git a/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java b/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java new file mode 100644 index 00000000..e5631f77 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java @@ -0,0 +1,137 @@ +package graphql.servlet; + +import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.GraphQLRequest; + +import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; +import java.util.List; +import java.util.function.Supplier; + +/** + * @author Andrew Potter + */ +public class GraphQLInvocationInputFactory { + private final Supplier schemaProviderSupplier; + private final Supplier contextBuilderSupplier; + private final Supplier rootObjectBuilderSupplier; + + protected GraphQLInvocationInputFactory(Supplier schemaProviderSupplier, Supplier contextBuilderSupplier, Supplier rootObjectBuilderSupplier) { + this.schemaProviderSupplier = schemaProviderSupplier; + this.contextBuilderSupplier = contextBuilderSupplier; + this.rootObjectBuilderSupplier = rootObjectBuilderSupplier; + } + + public GraphQLSchemaProvider getSchemaProvider() { + return schemaProviderSupplier.get(); + } + + public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HttpServletRequest request) { + return create(graphQLRequest, request, false); + } + + public GraphQLBatchedInvocationInput create(List graphQLRequests, HttpServletRequest request) { + return create(graphQLRequests, request, false); + } + + public GraphQLSingleInvocationInput createReadOnly(GraphQLRequest graphQLRequest, HttpServletRequest request) { + return create(graphQLRequest, request, true); + } + + public GraphQLBatchedInvocationInput createReadOnly(List graphQLRequests, HttpServletRequest request) { + return create(graphQLRequests, request, true); + } + + public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest) { + return new GraphQLSingleInvocationInput( + graphQLRequest, + schemaProviderSupplier.get().getSchema(), + contextBuilderSupplier.get().build(), + rootObjectBuilderSupplier.get().build() + ); + } + + private GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HttpServletRequest request, boolean readOnly) { + return new GraphQLSingleInvocationInput( + graphQLRequest, + readOnly ? schemaProviderSupplier.get().getReadOnlySchema(request) : schemaProviderSupplier.get().getSchema(request), + contextBuilderSupplier.get().build(request), + rootObjectBuilderSupplier.get().build(request) + ); + } + + private GraphQLBatchedInvocationInput create(List graphQLRequests, HttpServletRequest request, boolean readOnly) { + return new GraphQLBatchedInvocationInput( + graphQLRequests, + readOnly ? schemaProviderSupplier.get().getReadOnlySchema(request) : schemaProviderSupplier.get().getSchema(request), + contextBuilderSupplier.get().build(request), + rootObjectBuilderSupplier.get().build(request) + ); + } + + public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HandshakeRequest request) { + return new GraphQLSingleInvocationInput( + graphQLRequest, + schemaProviderSupplier.get().getSchema(request), + contextBuilderSupplier.get().build(request), + rootObjectBuilderSupplier.get().build(request) + ); + } + + public GraphQLBatchedInvocationInput create(List graphQLRequest, HandshakeRequest request) { + return new GraphQLBatchedInvocationInput( + graphQLRequest, + schemaProviderSupplier.get().getSchema(request), + contextBuilderSupplier.get().build(request), + rootObjectBuilderSupplier.get().build(request) + ); + } + + public static Builder newBuilder(GraphQLSchema schema) { + return new Builder(new DefaultGraphQLSchemaProvider(schema)); + } + + public static Builder newBuilder(GraphQLSchemaProvider schemaProvider) { + return new Builder(schemaProvider); + } + + public static Builder newBuilder(Supplier schemaProviderSupplier) { + return new Builder(schemaProviderSupplier); + } + + public static class Builder { + private final Supplier schemaProviderSupplier; + private Supplier contextBuilderSupplier = DefaultGraphQLContextBuilder::new; + private Supplier rootObjectBuilderSupplier = DefaultGraphQLRootObjectBuilder::new; + + public Builder(GraphQLSchemaProvider schemaProvider) { + this(() -> schemaProvider); + } + + public Builder(Supplier schemaProviderSupplier) { + this.schemaProviderSupplier = schemaProviderSupplier; + } + + public Builder withGraphQLContextBuilder(GraphQLContextBuilder contextBuilder) { + return withGraphQLContextBuilder(() -> contextBuilder); + } + + public Builder withGraphQLContextBuilder(Supplier contextBuilderSupplier) { + this.contextBuilderSupplier = contextBuilderSupplier; + return this; + } + + public Builder withGraphQLRootObjectBuilder(GraphQLRootObjectBuilder rootObjectBuilder) { + return withGraphQLRootObjectBuilder(() -> rootObjectBuilder); + } + + public Builder withGraphQLRootObjectBuilder(Supplier rootObjectBuilderSupplier) { + this.rootObjectBuilderSupplier = rootObjectBuilderSupplier; + return this; + } + + public GraphQLInvocationInputFactory build() { + return new GraphQLInvocationInputFactory(schemaProviderSupplier, contextBuilderSupplier, rootObjectBuilderSupplier); + } + } +} diff --git a/src/main/java/graphql/servlet/GraphQLObjectMapper.java b/src/main/java/graphql/servlet/GraphQLObjectMapper.java new file mode 100644 index 00000000..ef1174e3 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLObjectMapper.java @@ -0,0 +1,191 @@ +package graphql.servlet; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.InjectableValues; +import com.fasterxml.jackson.databind.MappingIterator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import graphql.ExecutionResult; +import graphql.ExecutionResultImpl; +import graphql.GraphQLError; +import graphql.servlet.internal.GraphQLRequest; +import graphql.servlet.internal.VariablesDeserializer; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * @author Andrew Potter + */ +public class GraphQLObjectMapper { + private final Supplier objectMapperConfigurerSupplier; + private final Supplier graphQLErrorHandlerSupplier; + + private volatile ObjectMapper mapper; + + protected GraphQLObjectMapper(Supplier objectMapperConfigurerSupplier, Supplier graphQLErrorHandlerSupplier) { + this.objectMapperConfigurerSupplier = objectMapperConfigurerSupplier; + this.graphQLErrorHandlerSupplier = graphQLErrorHandlerSupplier; + } + + // Double-check idiom for lazy initialization of instance fields. + public ObjectMapper getJacksonMapper() { + ObjectMapper result = mapper; + if (result == null) { // First check (no locking) + synchronized(this) { + result = mapper; + if (result == null) // Second check (with locking) + mapper = result = createObjectMapper(); + } + } + + return result; + } + + private ObjectMapper createObjectMapper() { + ObjectMapper mapper = new ObjectMapper().disable(SerializationFeature.FAIL_ON_EMPTY_BEANS).registerModule(new Jdk8Module()); + objectMapperConfigurerSupplier.get().configure(mapper); + + InjectableValues.Std injectableValues = new InjectableValues.Std(); + injectableValues.addValue(ObjectMapper.class, mapper); + mapper.setInjectableValues(injectableValues); + + return mapper; + } + + /** + * @return an {@link ObjectReader} for deserializing {@link GraphQLRequest} + */ + public ObjectReader getGraphQLRequestMapper() { + return getJacksonMapper().reader().forType(GraphQLRequest.class); + } + + public GraphQLRequest readGraphQLRequest(InputStream inputStream) throws IOException { + return getGraphQLRequestMapper().readValue(inputStream); + } + + public GraphQLRequest readGraphQLRequest(String text) throws IOException { + return getGraphQLRequestMapper().readValue(text); + } + + public List readBatchedGraphQLRequest(InputStream inputStream) throws IOException { + MappingIterator iterator = getGraphQLRequestMapper().readValues(inputStream); + List requests = new ArrayList<>(); + + while (iterator.hasNext()) { + requests.add(iterator.next()); + } + + return requests; + } + + public List readBatchedGraphQLRequest(String query) throws IOException { + MappingIterator iterator = getGraphQLRequestMapper().readValues(query); + List requests = new ArrayList<>(); + + while (iterator.hasNext()) { + requests.add(iterator.next()); + } + + return requests; + } + + public String serializeResultAsJson(ExecutionResult executionResult) { + try { + return getJacksonMapper().writeValueAsString(createResultFromExecutionResult(executionResult)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public boolean areErrorsPresent(ExecutionResult executionResult) { + return graphQLErrorHandlerSupplier.get().errorsPresent(executionResult.getErrors()); + } + + public ExecutionResult sanitizeErrors(ExecutionResult executionResult) { + Object data = executionResult.getData(); + Map extensions = executionResult.getExtensions(); + List errors = executionResult.getErrors(); + + GraphQLErrorHandler errorHandler = graphQLErrorHandlerSupplier.get(); + if(errorHandler.errorsPresent(errors)) { + errors = errorHandler.processErrors(errors); + } else { + errors = null; + } + + return new ExecutionResultImpl(data, errors, extensions); + } + + public Map createResultFromExecutionResult(ExecutionResult executionResult) { + return convertSanitizedExecutionResult(sanitizeErrors(executionResult)); + } + + public Map convertSanitizedExecutionResult(ExecutionResult executionResult) { + return convertSanitizedExecutionResult(executionResult, true); + } + + public Map convertSanitizedExecutionResult(ExecutionResult executionResult, boolean includeData) { + final Map result = new LinkedHashMap<>(); + + if(includeData) { + result.put("data", executionResult.getData()); + } + + if (areErrorsPresent(executionResult)) { + result.put("errors", executionResult.getErrors()); + } + + if(executionResult.getExtensions() != null){ + result.put("extensions", executionResult.getExtensions()); + } + + return result; + } + + public Map deserializeVariables(String variables) { + try { + return VariablesDeserializer.deserializeVariablesObject(getJacksonMapper().readValue(variables, Object.class), getJacksonMapper()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static Builder newBuilder() { + return new Builder(); + } + + public static class Builder { + private Supplier objectMapperConfigurer = DefaultObjectMapperConfigurer::new; + private Supplier graphQLErrorHandler = DefaultGraphQLErrorHandler::new; + + public Builder withObjectMapperConfigurer(ObjectMapperConfigurer objectMapperConfigurer) { + return withObjectMapperConfigurer(() -> objectMapperConfigurer); + } + + public Builder withObjectMapperConfigurer(Supplier objectMapperConfigurer) { + this.objectMapperConfigurer = objectMapperConfigurer; + return this; + } + + public Builder withGraphQLErrorHandler(GraphQLErrorHandler graphQLErrorHandler) { + return withGraphQLErrorHandler(() -> graphQLErrorHandler); + } + + public Builder withGraphQLErrorHandler(Supplier graphQLErrorHandler) { + this.graphQLErrorHandler = graphQLErrorHandler; + return this; + } + + public GraphQLObjectMapper build() { + return new GraphQLObjectMapper(objectMapperConfigurer, graphQLErrorHandler); + } + } +} diff --git a/src/main/java/graphql/servlet/GraphQLQueryInvoker.java b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java new file mode 100644 index 00000000..fcd491cd --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java @@ -0,0 +1,116 @@ +package graphql.servlet; + +import graphql.ExecutionInput; +import graphql.ExecutionResult; +import graphql.GraphQL; +import graphql.execution.instrumentation.Instrumentation; +import graphql.execution.instrumentation.SimpleInstrumentation; +import graphql.execution.preparsed.NoOpPreparsedDocumentProvider; +import graphql.execution.preparsed.PreparsedDocumentProvider; +import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.ExecutionResultHandler; + +import javax.security.auth.Subject; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Iterator; +import java.util.function.Supplier; + +/** + * @author Andrew Potter + */ +public class GraphQLQueryInvoker { + + private final Supplier getExecutionStrategyProvider; + private final Supplier getInstrumentation; + private final Supplier getPreparsedDocumentProvider; + + protected GraphQLQueryInvoker(Supplier getExecutionStrategyProvider, Supplier getInstrumentation, Supplier getPreparsedDocumentProvider) { + this.getExecutionStrategyProvider = getExecutionStrategyProvider; + this.getInstrumentation = getInstrumentation; + this.getPreparsedDocumentProvider = getPreparsedDocumentProvider; + } + + public ExecutionResult query(GraphQLSingleInvocationInput singleInvocationInput) { + return query(singleInvocationInput, singleInvocationInput.getExecutionInput()); + } + + public void query(GraphQLBatchedInvocationInput batchedInvocationInput, ExecutionResultHandler executionResultHandler) { + Iterator executionInputIterator = batchedInvocationInput.getExecutionInputs().iterator(); + + while (executionInputIterator.hasNext()) { + ExecutionResult result = query(batchedInvocationInput, executionInputIterator.next()); + executionResultHandler.accept(result, executionInputIterator.hasNext()); + } + } + + private GraphQL newGraphQL(GraphQLSchema schema) { + ExecutionStrategyProvider executionStrategyProvider = getExecutionStrategyProvider.get(); + return GraphQL.newGraphQL(schema) + .queryExecutionStrategy(executionStrategyProvider.getQueryExecutionStrategy()) + .mutationExecutionStrategy(executionStrategyProvider.getMutationExecutionStrategy()) + .subscriptionExecutionStrategy(executionStrategyProvider.getSubscriptionExecutionStrategy()) + .instrumentation(getInstrumentation.get()) + .preparsedDocumentProvider(getPreparsedDocumentProvider.get()) + .build(); + } + + private ExecutionResult query(GraphQLInvocationInput invocationInput, ExecutionInput executionInput) { + if (Subject.getSubject(AccessController.getContext()) == null && invocationInput.getSubject().isPresent()) { + return Subject.doAs(invocationInput.getSubject().get(), (PrivilegedAction) () -> { + try { + return query(invocationInput.getSchema(), executionInput); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + return query(invocationInput.getSchema(), executionInput); + } + + private ExecutionResult query(GraphQLSchema schema, ExecutionInput executionInput) { + return newGraphQL(schema).execute(executionInput); + } + + public static Builder newBuilder() { + return new Builder(); + } + + public static class Builder { + private Supplier getExecutionStrategyProvider = DefaultExecutionStrategyProvider::new; + private Supplier getInstrumentation = () -> SimpleInstrumentation.INSTANCE; + private Supplier getPreparsedDocumentProvider = () -> NoOpPreparsedDocumentProvider.INSTANCE; + + public Builder withExecutionStrategyProvider(ExecutionStrategyProvider provider) { + return withExecutionStrategyProvider(() -> provider); + } + + public Builder withExecutionStrategyProvider(Supplier supplier) { + this.getExecutionStrategyProvider = supplier; + return this; + } + + public Builder withInstrumentation(Instrumentation instrumentation) { + return withInstrumentation(() -> instrumentation); + } + + public Builder withInstrumentation(Supplier supplier) { + this.getInstrumentation = supplier; + return this; + } + + public Builder withPreparsedDocumentProvider(PreparsedDocumentProvider provider) { + return withPreparsedDocumentProvider(() -> provider); + } + + public Builder withPreparsedDocumentProvider(Supplier supplier) { + this.getPreparsedDocumentProvider = supplier; + return this; + } + + public GraphQLQueryInvoker build() { + return new GraphQLQueryInvoker(getExecutionStrategyProvider, getInstrumentation, getPreparsedDocumentProvider); + } + } +} diff --git a/src/main/java/graphql/servlet/GraphQLRootObjectBuilder.java b/src/main/java/graphql/servlet/GraphQLRootObjectBuilder.java index 03f90e15..197a1537 100644 --- a/src/main/java/graphql/servlet/GraphQLRootObjectBuilder.java +++ b/src/main/java/graphql/servlet/GraphQLRootObjectBuilder.java @@ -2,8 +2,16 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.server.HandshakeRequest; import java.util.Optional; public interface GraphQLRootObjectBuilder { - Object build(Optional req, Optional resp); + Object build(HttpServletRequest req); + Object build(HandshakeRequest req); + + /** + * Only used for MBean calls. + * @return the graphql root object + */ + Object build(); } diff --git a/src/main/java/graphql/servlet/GraphQLSchemaProvider.java b/src/main/java/graphql/servlet/GraphQLSchemaProvider.java index 1fcbcd9e..6873174d 100644 --- a/src/main/java/graphql/servlet/GraphQLSchemaProvider.java +++ b/src/main/java/graphql/servlet/GraphQLSchemaProvider.java @@ -3,6 +3,7 @@ import graphql.schema.GraphQLSchema; import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; public interface GraphQLSchemaProvider { @@ -12,10 +13,15 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) { /** * @param request the http request - * @return a schema based on the request (auth, etc). Optional is empty when called from an mbean. + * @return a schema based on the request (auth, etc). */ GraphQLSchema getSchema(HttpServletRequest request); + /** + * @param request the http request used to create a websocket + * @return a schema based on the request (auth, etc). + */ + GraphQLSchema getSchema(HandshakeRequest request); /** * @return a schema for handling mbean calls. @@ -24,7 +30,7 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) { /** * @param request the http request - * @return a read-only schema based on the request (auth, etc). Should return the same schema as {@link #getSchema(HttpServletRequest)} for a given request. + * @return a read-only schema based on the request (auth, etc). Should return the same schema (query-only version) as {@link #getSchema(HttpServletRequest)} for a given request. */ GraphQLSchema getReadOnlySchema(HttpServletRequest request); } diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java deleted file mode 100644 index 19faebd8..00000000 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ /dev/null @@ -1,580 +0,0 @@ -package graphql.servlet; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.*; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.google.common.io.ByteStreams; -import com.google.common.io.CharStreams; -import graphql.ExecutionInput; -import graphql.ExecutionResult; -import graphql.GraphQL; -import graphql.GraphQLError; -import graphql.execution.instrumentation.Instrumentation; -import graphql.execution.preparsed.PreparsedDocumentProvider; -import graphql.introspection.IntrospectionQuery; -import graphql.schema.GraphQLFieldDefinition; -import graphql.schema.GraphQLSchema; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.security.auth.Subject; -import javax.servlet.AsyncContext; -import javax.servlet.Servlet; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.Part; -import java.io.*; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.*; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * @author Andrew Potter - */ -public abstract class GraphQLServlet extends HttpServlet implements Servlet, GraphQLMBean { - - public static final Logger log = LoggerFactory.getLogger(GraphQLServlet.class); - - public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8"; - public static final String APPLICATION_GRAPHQL = "application/graphql"; - public static final int STATUS_OK = 200; - public static final int STATUS_BAD_REQUEST = 400; - - protected abstract GraphQLSchemaProvider getSchemaProvider(); - - protected abstract GraphQLContext createContext(Optional request, Optional response); - - protected abstract Object createRootObject(Optional request, Optional response); - - protected abstract ExecutionStrategyProvider getExecutionStrategyProvider(); - - protected abstract Instrumentation getInstrumentation(); - - protected abstract GraphQLErrorHandler getGraphQLErrorHandler(); - - protected abstract PreparsedDocumentProvider getPreparsedDocumentProvider(); - - private final LazyObjectMapperBuilder lazyObjectMapperBuilder; - private final List listeners; - - private final HttpRequestHandler getHandler; - private final HttpRequestHandler postHandler; - - private final boolean asyncServletMode; - - public GraphQLServlet() { - this(null, null, false); - } - - public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List listeners, boolean asyncServletMode) { - this.lazyObjectMapperBuilder = new LazyObjectMapperBuilder(objectMapperConfigurer != null ? objectMapperConfigurer : new DefaultObjectMapperConfigurer()); - this.listeners = listeners != null ? new ArrayList<>(listeners) : new ArrayList<>(); - this.asyncServletMode = asyncServletMode; - - this.getHandler = (request, response) -> { - final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); - final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); - - String path = request.getPathInfo(); - if (path == null) { - path = request.getServletPath(); - } - if (path.contentEquals("/schema.json")) { - doQuery(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), context, rootObject, request, response); - } else { - String query = request.getParameter("query"); - if (query != null) { - if (isBatchedQuery(query)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); - } else { - final Map variables = new HashMap<>(); - if (request.getParameter("variables") != null) { - variables.putAll(deserializeVariables(request.getParameter("variables"))); - } - - String operationName = null; - if (request.getParameter("operationName") != null) { - operationName = request.getParameter("operationName"); - } - - doQuery(query, operationName, variables, getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); - } - } else { - response.setStatus(STATUS_BAD_REQUEST); - log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given"); - } - } - }; - - this.postHandler = (request, response) -> { - final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); - final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); - - try { - if (APPLICATION_GRAPHQL.equals(request.getContentType())) { - String query = CharStreams.toString(request.getReader()); - doQuery(query, null, null, getSchemaProvider().getSchema(request), context, rootObject, request, response); - } else if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data") && !request.getParts().isEmpty()) { - final Map> fileItems = request.getParts().stream() - .collect(Collectors.toMap( - Part::getName, - Collections::singletonList, - (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).collect(Collectors.toList()))); - - context.setFiles(Optional.of(fileItems)); - - if (fileItems.containsKey("graphql")) { - final Optional graphqlItem = getFileItem(fileItems, "graphql"); - if (graphqlItem.isPresent()) { - InputStream inputStream = graphqlItem.get().getInputStream(); - - if (!inputStream.markSupported()) { - inputStream = new BufferedInputStream(inputStream); - } - - if (isBatchedQuery(inputStream)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); - return; - } else { - doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); - return; - } - } - } else if (fileItems.containsKey("query")) { - final Optional queryItem = getFileItem(fileItems, "query"); - if (queryItem.isPresent()) { - InputStream inputStream = queryItem.get().getInputStream(); - - if (!inputStream.markSupported()) { - inputStream = new BufferedInputStream(inputStream); - } - - if (isBatchedQuery(inputStream)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); - return; - } else { - - String query = new String(ByteStreams.toByteArray(inputStream)); - - Map variables = null; - final Optional variablesItem = getFileItem(fileItems, "variables"); - if (variablesItem.isPresent()) { - variables = deserializeVariables(new String(ByteStreams.toByteArray(variablesItem.get().getInputStream()))); - } - - String operationName = null; - final Optional operationNameItem = getFileItem(fileItems, "operationName"); - if (operationNameItem.isPresent()) { - operationName = new String(ByteStreams.toByteArray(operationNameItem.get().getInputStream())).trim(); - } - - doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response); - return; - } - } - } - - response.setStatus(STATUS_BAD_REQUEST); - log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); - } else { - handleNonMultipartRequest(request, response, context, rootObject); - } - } catch (Exception e) { - log.info("Bad POST request: parsing failed", e); - response.setStatus(STATUS_BAD_REQUEST); - } - }; - } - - private void handleNonMultipartRequest(HttpServletRequest request, HttpServletResponse response, GraphQLContext context, Object rootObject) throws Exception { - // this is not a multipart request - InputStream inputStream = request.getInputStream(); - - if (!inputStream.markSupported()) { - inputStream = new BufferedInputStream(inputStream); - } - - if (isBatchedQuery(inputStream)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); - } else { - doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); - } - } - - protected ObjectMapper getMapper() { - return lazyObjectMapperBuilder.getMapper(); - } - - /** - * Creates an {@link ObjectReader} for deserializing {@link GraphQLRequest} - */ - private ObjectReader getGraphQLRequestMapper() { - // Add object mapper to injection so VariablesDeserializer can access it... - InjectableValues.Std injectableValues = new InjectableValues.Std(); - injectableValues.addValue(ObjectMapper.class, getMapper()); - - return getMapper().reader(injectableValues).forType(GraphQLRequest.class); - } - - public void addListener(GraphQLServletListener servletListener) { - listeners.add(servletListener); - } - - public void removeListener(GraphQLServletListener servletListener) { - listeners.remove(servletListener); - } - - @Override - public String[] getQueries() { - return getSchemaProvider().getSchema().getQueryType().getFieldDefinitions().stream().map(GraphQLFieldDefinition::getName).toArray(String[]::new); - } - - @Override - public String[] getMutations() { - return getSchemaProvider().getSchema().getMutationType().getFieldDefinitions().stream().map(GraphQLFieldDefinition::getName).toArray(String[]::new); - } - - @Override - public String executeQuery(String query) { - try { - final ExecutionResult result = newGraphQL(getSchemaProvider().getSchema()).execute(new ExecutionInput(query, null, createContext(Optional.empty(), Optional.empty()), createRootObject(Optional.empty(), Optional.empty()), new HashMap<>())); - return getMapper().writeValueAsString(createResultFromDataErrorsAndExtensions(result.getData(), result.getErrors(), result.getExtensions())); - } catch (Exception e) { - return e.getMessage(); - } - } - - private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler,AsyncContext asyncContext) { - - List requestCallbacks = runListeners(l -> l.onRequest(request, response)); - - try { - handler.handle(request, response); - runCallbacks(requestCallbacks, c -> c.onSuccess(request, response)); - } catch (Throwable t) { - response.setStatus(500); - log.error("Error executing GraphQL request!", t); - runCallbacks(requestCallbacks, c -> c.onError(request, response, t)); - } finally { - runCallbacks(requestCallbacks, c -> c.onFinally(request, response)); - if(asyncContext !=null) - asyncContext.complete(); - } - } - - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - if (asyncServletMode) { - AsyncContext asyncContext = req.startAsync(); - HttpServletRequest request = (HttpServletRequest) asyncContext.getRequest(); - HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); - new Thread(() -> doRequest(request, response, getHandler, asyncContext)).start(); - } else { - doRequest(req, resp, getHandler, null); - } - } - - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - if (asyncServletMode) { - AsyncContext asyncContext = req.startAsync(); - HttpServletRequest request = (HttpServletRequest) asyncContext.getRequest(); - HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); - new Thread(() -> doRequest(request, response, postHandler, asyncContext)).start(); - } else { - doRequest(req, resp, postHandler, null); - } - } - - private Optional getFileItem(Map> fileItems, String name) { - return Optional.ofNullable(fileItems.get(name)).filter(list -> !list.isEmpty()).map(list -> list.get(0)); - } - - private GraphQL newGraphQL(GraphQLSchema schema) { - ExecutionStrategyProvider executionStrategyProvider = getExecutionStrategyProvider(); - return GraphQL.newGraphQL(schema) - .queryExecutionStrategy(executionStrategyProvider.getQueryExecutionStrategy()) - .mutationExecutionStrategy(executionStrategyProvider.getMutationExecutionStrategy()) - .subscriptionExecutionStrategy(executionStrategyProvider.getSubscriptionExecutionStrategy()) - .instrumentation(getInstrumentation()) - .preparsedDocumentProvider(getPreparsedDocumentProvider()) - .build(); - } - - private void doQuery(GraphQLRequest graphQLRequest, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest httpReq, HttpServletResponse httpRes) throws Exception { - doQuery(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, httpReq, httpRes); - } - - private void doQuery(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { - query(query, operationName, variables, schema, context, rootObject, (r) -> { - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(r.getStatus()); - resp.getWriter().write(r.getResponse()); - }); - } - - private void doBatchedQuery(Iterator graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(STATUS_OK); - - Writer respWriter = resp.getWriter(); - respWriter.write('['); - while (graphQLRequests.hasNext()) { - GraphQLRequest graphQLRequest = graphQLRequests.next(); - query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, (r) -> respWriter.write(r.getResponse())); - if (graphQLRequests.hasNext()) { - respWriter.write(','); - } - } - respWriter.write(']'); - } - - private void query(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception { - if (operationName != null && operationName.isEmpty()) { - query(query, null, variables, schema, context, rootObject, responseHandler); - } else if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) { - Subject.doAs(context.getSubject().get(), (PrivilegedAction) () -> { - try { - query(query, operationName, variables, schema, context, rootObject, responseHandler); - } catch (Exception e) { - throw new RuntimeException(e); - } - return null; - }); - } else { - List operationCallbacks = runListeners(l -> l.onOperation(context, operationName, query, variables)); - - final ExecutionResult executionResult = newGraphQL(schema).execute(new ExecutionInput(query, operationName, context, rootObject, variables)); - final List errors = executionResult.getErrors(); - final Object data = executionResult.getData(); - final Object extensions = executionResult.getExtensions(); - - final String response = getMapper().writeValueAsString(createResultFromDataErrorsAndExtensions(data, errors, extensions)); - - GraphQLResponse graphQLResponse = new GraphQLResponse(); - graphQLResponse.setStatus(STATUS_OK); - graphQLResponse.setResponse(response); - responseHandler.handle(graphQLResponse); - - if (getGraphQLErrorHandler().errorsPresent(errors)) { - runCallbacks(operationCallbacks, c -> c.onError(context, operationName, query, variables, data, errors, extensions)); - } else { - runCallbacks(operationCallbacks, c -> c.onSuccess(context, operationName, query, variables, data, extensions)); - } - - runCallbacks(operationCallbacks, c -> c.onFinally(context, operationName, query, variables, data, extensions)); - } - } - - private Map createResultFromDataErrorsAndExtensions(Object data, List errors, Object extensions) { - - final Map result = new LinkedHashMap<>(); - result.put("data", data); - - if (getGraphQLErrorHandler().errorsPresent(errors)) { - result.put("errors", getGraphQLErrorHandler().processErrors(errors)); - } - - if (extensions != null) { - result.put("extensions", extensions); - } - - return result; - } - - private List runListeners(Function action) { - if (listeners == null) { - return Collections.emptyList(); - } - - return listeners.stream() - .map(listener -> { - try { - return action.apply(listener); - } catch (Throwable t) { - log.error("Error running listener: {}", listener, t); - return null; - } - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - } - - private void runCallbacks(List callbacks, Consumer action) { - callbacks.forEach(callback -> { - try { - action.accept(callback); - } catch (Throwable t) { - log.error("Error running callback: {}", callback, t); - } - }); - } - - protected static class VariablesDeserializer extends JsonDeserializer> { - - @Override - public Map deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - return deserializeVariablesObject(p.readValueAs(Object.class), (ObjectMapper) ctxt.findInjectableValue(ObjectMapper.class.getName(), null, null)); - } - } - - private Map deserializeVariables(String variables) { - try { - return deserializeVariablesObject(getMapper().readValue(variables, Object.class), getMapper()); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static Map deserializeVariablesObject(Object variables, ObjectMapper mapper) { - if (variables instanceof Map) { - @SuppressWarnings("unchecked") - Map genericVariables = (Map) variables; - return genericVariables; - } else if (variables instanceof String) { - try { - return mapper.readValue((String) variables, new TypeReference>() { - }); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else { - throw new RuntimeException("variables should be either an object or a string"); - } - } - - private boolean isBatchedQuery(InputStream inputStream) throws IOException { - if (inputStream == null) { - return false; - } - - final int BUFFER_LENGTH = 128; - ByteArrayOutputStream result = new ByteArrayOutputStream(); - byte[] buffer = new byte[BUFFER_LENGTH]; - int length; - - inputStream.mark(BUFFER_LENGTH); - while ((length = inputStream.read(buffer)) != -1) { - result.write(buffer, 0, length); - String chunk = result.toString(); - Boolean isArrayStart = isArrayStart(chunk); - if (isArrayStart != null) { - inputStream.reset(); - return isArrayStart; - } - } - - inputStream.reset(); - return false; - } - - private boolean isBatchedQuery(String query) { - if (query == null) { - return false; - } - - Boolean isArrayStart = isArrayStart(query); - return isArrayStart != null && isArrayStart; - } - - // return true if the first non whitespace character is the beginning of an array - private Boolean isArrayStart(String s) { - for (int i = 0; i < s.length(); i++) { - char ch = s.charAt(i); - if (!Character.isWhitespace(ch)) { - return ch == '['; - } - } - - return null; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - protected static class GraphQLRequest { - private String query; - @JsonDeserialize(using = GraphQLServlet.VariablesDeserializer.class) - private Map variables = new HashMap<>(); - private String operationName; - - public String getQuery() { - return query; - } - - public void setQuery(String query) { - this.query = query; - } - - public Map getVariables() { - return variables; - } - - public void setVariables(Map variables) { - this.variables = variables; - } - - public String getOperationName() { - return operationName; - } - - public void setOperationName(String operationName) { - this.operationName = operationName; - } - } - - @JsonIgnoreProperties(ignoreUnknown = true) - protected static class GraphQLResponse { - private int status; - private String response; - - public int getStatus() { - return status; - } - - public void setStatus(int status) { - this.status = status; - } - - public String getResponse() { - return response; - } - - public void setResponse(String response) { - this.response = response; - } - } - - protected interface HttpRequestHandler extends BiConsumer { - @Override - default void accept(HttpServletRequest request, HttpServletResponse response) { - try { - handle(request, response); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - void handle(HttpServletRequest request, HttpServletResponse response) throws Exception; - } - - protected interface GraphQLResponseHandler extends Consumer { - @Override - default void accept(GraphQLResponse response) { - try { - handle(response); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - void handle(GraphQLResponse r) throws Exception; - } -} diff --git a/src/main/java/graphql/servlet/GraphQLServletListener.java b/src/main/java/graphql/servlet/GraphQLServletListener.java index 60f8c448..3b9f2560 100644 --- a/src/main/java/graphql/servlet/GraphQLServletListener.java +++ b/src/main/java/graphql/servlet/GraphQLServletListener.java @@ -1,11 +1,7 @@ package graphql.servlet; -import graphql.GraphQLError; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.util.List; -import java.util.Map; /** * @author Andrew Potter @@ -14,19 +10,10 @@ public interface GraphQLServletListener { default RequestCallback onRequest(HttpServletRequest request, HttpServletResponse response) { return null; } - default OperationCallback onOperation(GraphQLContext context, String operationName, String query, Map variables) { - return null; - } interface RequestCallback { default void onSuccess(HttpServletRequest request, HttpServletResponse response) {} default void onError(HttpServletRequest request, HttpServletResponse response, Throwable throwable) {} default void onFinally(HttpServletRequest request, HttpServletResponse response) {} } - - interface OperationCallback { - default void onSuccess(GraphQLContext context, String operationName, String query, Map variables, Object data, Object extensions) {} - default void onError(GraphQLContext context, String operationName, String query, Map variables, Object data, List errors, Object extensions) {} - default void onFinally(GraphQLContext context, String operationName, String query, Map variables, Object data, Object extensions) {} - } } diff --git a/src/main/java/graphql/servlet/GraphQLSingleInvocationInput.java b/src/main/java/graphql/servlet/GraphQLSingleInvocationInput.java new file mode 100644 index 00000000..9de802ac --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLSingleInvocationInput.java @@ -0,0 +1,23 @@ +package graphql.servlet; + +import graphql.ExecutionInput; +import graphql.schema.GraphQLSchema; +import graphql.servlet.internal.GraphQLRequest; + +/** + * @author Andrew Potter + */ +public class GraphQLSingleInvocationInput extends GraphQLInvocationInput { + + private final GraphQLRequest request; + + public GraphQLSingleInvocationInput(GraphQLRequest request, GraphQLSchema schema, GraphQLContext context, Object root) { + super(schema, context, root); + + this.request = request; + } + + public ExecutionInput getExecutionInput() { + return createExecutionInput(request); + } +} diff --git a/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java b/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java new file mode 100644 index 00000000..fe7cf5b6 --- /dev/null +++ b/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java @@ -0,0 +1,124 @@ +package graphql.servlet; + +import graphql.servlet.internal.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.websocket.*; +import javax.websocket.server.HandshakeRequest; +import javax.websocket.server.ServerEndpointConfig; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Must be used with {@link #modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)} + * + * @author Andrew Potter + */ +public class GraphQLWebsocketServlet extends Endpoint { + + private static final Logger log = LoggerFactory.getLogger(GraphQLWebsocketServlet.class); + + private static final String HANDSHAKE_REQUEST_KEY = HandshakeRequest.class.getName(); + private static final String PROTOCOL_HANDLER_REQUEST_KEY = SubscriptionProtocolHandler.class.getName(); + private static final CloseReason ERROR_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Internal Server Error"); + + private static final List subscriptionProtocolFactories = Collections.singletonList(new ApolloSubscriptionProtocolFactory()); + private static final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory(); + private static final List allSubscriptionProtocols; + + static { + allSubscriptionProtocols = Stream.concat(subscriptionProtocolFactories.stream(), Stream.of(fallbackSubscriptionProtocolFactory)) + .map(SubscriptionProtocolFactory::getProtocol) + .collect(Collectors.toList()); + } + + private final Map sessionSubscriptionCache = new HashMap<>(); + private final SubscriptionHandlerInput subscriptionHandlerInput; + + public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) { + this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper); + } + + @Override + public void onOpen(Session session, EndpointConfig endpointConfig) { + log.debug("Session opened: {}, {}", session.getId(), endpointConfig); + final WsSessionSubscriptions subscriptions = new WsSessionSubscriptions(); + final HandshakeRequest request = (HandshakeRequest) session.getUserProperties().get(HANDSHAKE_REQUEST_KEY); + final SubscriptionProtocolHandler subscriptionProtocolHandler = (SubscriptionProtocolHandler) session.getUserProperties().get(PROTOCOL_HANDLER_REQUEST_KEY); + + sessionSubscriptionCache.put(session, subscriptions); + + // This *cannot* be a lambda because of the way undertow checks the class... + session.addMessageHandler(new MessageHandler.Whole() { + @Override + public void onMessage(String text) { + try { + subscriptionProtocolHandler.onMessage(request, session, subscriptions, text); + } catch (Throwable t) { + log.error("Error executing websocket query for session: {}", session.getId(), t); + closeUnexpectedly(session, t); + } + } + }); + } + + @Override + public void onClose(Session session, CloseReason closeReason) { + log.debug("Session closed: {}, {}", session.getId(), closeReason); + WsSessionSubscriptions subscriptions = sessionSubscriptionCache.remove(session); + if (subscriptions != null) { + subscriptions.close(); + } + } + + @Override + public void onError(Session session, Throwable thr) { + log.error("Error in websocket session: {}", session.getId(), thr); + closeUnexpectedly(session, thr); + } + + private void closeUnexpectedly(Session session, Throwable t) { + try { + session.close(ERROR_CLOSE_REASON); + } catch (IOException e) { + log.error("Error closing websocket session for session: {}", session.getId(), t); + } + } + + public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { + sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request); + + List protocol = request.getHeaders().get(HandshakeRequest.SEC_WEBSOCKET_PROTOCOL); + if (protocol == null) { + protocol = Collections.emptyList(); + } + + SubscriptionProtocolFactory subscriptionProtocolFactory = getSubscriptionProtocolFactory(protocol); + sec.getUserProperties().put(PROTOCOL_HANDLER_REQUEST_KEY, subscriptionProtocolFactory.createHandler(subscriptionHandlerInput)); + + if (request.getHeaders().get(HandshakeResponse.SEC_WEBSOCKET_ACCEPT) != null) { + response.getHeaders().put(HandshakeResponse.SEC_WEBSOCKET_ACCEPT, allSubscriptionProtocols); + } + if (!protocol.isEmpty()) { + response.getHeaders().put(HandshakeRequest.SEC_WEBSOCKET_PROTOCOL, Collections.singletonList(subscriptionProtocolFactory.getProtocol())); + } + } + + private static SubscriptionProtocolFactory getSubscriptionProtocolFactory(List accept) { + for (String protocol : accept) { + for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories) { + if (subscriptionProtocolFactory.getProtocol().equals(protocol)) { + return subscriptionProtocolFactory; + } + } + } + + return fallbackSubscriptionProtocolFactory; + } +} diff --git a/src/main/java/graphql/servlet/LazyObjectMapperBuilder.java b/src/main/java/graphql/servlet/LazyObjectMapperBuilder.java deleted file mode 100644 index db6ee838..00000000 --- a/src/main/java/graphql/servlet/LazyObjectMapperBuilder.java +++ /dev/null @@ -1,38 +0,0 @@ -package graphql.servlet; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; - -/** - * @author Andrew Potter - */ -public class LazyObjectMapperBuilder { - private final ObjectMapperConfigurer configurer; - private volatile ObjectMapper mapper; - - public LazyObjectMapperBuilder(ObjectMapperConfigurer configurer) { - this.configurer = configurer; - } - - // Double-check idiom for lazy initialization of instance fields. - public ObjectMapper getMapper() { - ObjectMapper result = mapper; - if (result == null) { // First check (no locking) - synchronized(this) { - result = mapper; - if (result == null) // Second check (with locking) - mapper = result = createObjectMapper(); - } - } - - return result; - } - - private ObjectMapper createObjectMapper() { - ObjectMapper mapper = new ObjectMapper().disable(SerializationFeature.FAIL_ON_EMPTY_BEANS).registerModule(new Jdk8Module()); - configurer.configure(mapper); - - return mapper; - } -} diff --git a/src/main/java/graphql/servlet/OsgiGraphQLServlet.java b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java similarity index 78% rename from src/main/java/graphql/servlet/OsgiGraphQLServlet.java rename to src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java index 36e3383c..e8c42976 100644 --- a/src/main/java/graphql/servlet/OsgiGraphQLServlet.java +++ b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java @@ -1,18 +1,18 @@ package graphql.servlet; -import graphql.execution.instrumentation.Instrumentation; import graphql.execution.preparsed.NoOpPreparsedDocumentProvider; import graphql.execution.preparsed.PreparsedDocumentProvider; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLType; -import org.osgi.service.component.annotations.*; +import org.osgi.service.component.annotations.Component; +import org.osgi.service.component.annotations.Reference; +import org.osgi.service.component.annotations.ReferenceCardinality; +import org.osgi.service.component.annotations.ReferencePolicy; +import org.osgi.service.component.annotations.ReferencePolicyOption; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.util.ArrayList; import java.util.HashSet; import java.util.List; -import java.util.Optional; import java.util.Set; import static graphql.schema.GraphQLObjectType.newObject; @@ -22,12 +22,16 @@ service={javax.servlet.http.HttpServlet.class,javax.servlet.Servlet.class}, property = {"alias=/graphql", "jmx.objectname=graphql.servlet:type=graphql"} ) -public class OsgiGraphQLServlet extends GraphQLServlet { +public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet { private final List queryProviders = new ArrayList<>(); private final List mutationProviders = new ArrayList<>(); private final List typesProviders = new ArrayList<>(); + private final GraphQLQueryInvoker queryInvoker; + private final GraphQLInvocationInputFactory invocationInputFactory; + private final GraphQLObjectMapper graphQLObjectMapper; + private GraphQLContextBuilder contextBuilder = new DefaultGraphQLContextBuilder(); private GraphQLRootObjectBuilder rootObjectBuilder = new DefaultGraphQLRootObjectBuilder(); private ExecutionStrategyProvider executionStrategyProvider = new DefaultExecutionStrategyProvider(); @@ -37,6 +41,39 @@ public class OsgiGraphQLServlet extends GraphQLServlet { private GraphQLSchemaProvider schemaProvider; + @Override + protected GraphQLQueryInvoker getQueryInvoker() { + return queryInvoker; + } + + @Override + protected GraphQLInvocationInputFactory getInvocationInputFactory() { + return invocationInputFactory; + } + + @Override + protected GraphQLObjectMapper getGraphQLObjectMapper() { + return graphQLObjectMapper; + } + + public OsgiGraphQLHttpServlet() { + updateSchema(); + + this.queryInvoker = GraphQLQueryInvoker.newBuilder() + .withPreparsedDocumentProvider(this::getPreparsedDocumentProvider) + .withInstrumentation(() -> this.getInstrumentationProvider().getInstrumentation()) + .withExecutionStrategyProvider(this::getExecutionStrategyProvider).build(); + + this.invocationInputFactory = GraphQLInvocationInputFactory.newBuilder(this::getSchemaProvider) + .withGraphQLContextBuilder(this::getContextBuilder) + .withGraphQLRootObjectBuilder(this::getRootObjectBuilder) + .build(); + + this.graphQLObjectMapper = GraphQLObjectMapper.newBuilder() + .withGraphQLErrorHandler(this::getErrorHandler) + .build(); + } + protected void updateSchema() { final GraphQLObjectType.Builder queryTypeBuilder = newObject().name("Query").description("Root query type"); @@ -68,10 +105,6 @@ protected void updateSchema() { this.schemaProvider = new DefaultGraphQLSchemaProvider(newSchema().query(queryTypeBuilder.build()).mutation(mutationType).build(types)); } - public OsgiGraphQLServlet() { - updateSchema(); - } - @Reference(cardinality = ReferenceCardinality.MULTIPLE, policyOption = ReferencePolicyOption.GREEDY) public void bindProvider(GraphQLProvider provider) { if (provider instanceof GraphQLQueryProvider) { @@ -184,37 +217,31 @@ public void unsetPreparsedDocumentProvider(PreparsedDocumentProvider preparsedDo this.preparsedDocumentProvider = NoOpPreparsedDocumentProvider.INSTANCE; } - @Override - protected GraphQLSchemaProvider getSchemaProvider() { - return schemaProvider; - } - - protected GraphQLContext createContext(Optional req, Optional resp) { - return contextBuilder.build(req, resp); + public GraphQLContextBuilder getContextBuilder() { + return contextBuilder; } - @Override - protected Object createRootObject(Optional request, Optional response) { - return rootObjectBuilder.build(request, response); + public GraphQLRootObjectBuilder getRootObjectBuilder() { + return rootObjectBuilder; } - @Override - protected ExecutionStrategyProvider getExecutionStrategyProvider() { + public ExecutionStrategyProvider getExecutionStrategyProvider() { return executionStrategyProvider; } - @Override - protected Instrumentation getInstrumentation() { - return instrumentationProvider.getInstrumentation(); + public InstrumentationProvider getInstrumentationProvider() { + return instrumentationProvider; } - @Override - protected GraphQLErrorHandler getGraphQLErrorHandler() { + public GraphQLErrorHandler getErrorHandler() { return errorHandler; } - @Override - protected PreparsedDocumentProvider getPreparsedDocumentProvider() { + public PreparsedDocumentProvider getPreparsedDocumentProvider() { return preparsedDocumentProvider; } + + public GraphQLSchemaProvider getSchemaProvider() { + return schemaProvider; + } } diff --git a/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java b/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java new file mode 100644 index 00000000..925ae92f --- /dev/null +++ b/src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java @@ -0,0 +1,77 @@ +package graphql.servlet; + +import graphql.schema.GraphQLSchema; + +/** + * @author Andrew Potter + */ +public class SimpleGraphQLHttpServlet extends AbstractGraphQLHttpServlet { + + private final GraphQLInvocationInputFactory invocationInputFactory; + private final GraphQLQueryInvoker queryInvoker; + private final GraphQLObjectMapper graphQLObjectMapper; + + private SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, boolean asyncServletMode) { + super(null, asyncServletMode); + this.invocationInputFactory = invocationInputFactory; + this.queryInvoker = queryInvoker; + this.graphQLObjectMapper = graphQLObjectMapper; + } + + @Override + protected GraphQLQueryInvoker getQueryInvoker() { + return queryInvoker; + } + + @Override + protected GraphQLInvocationInputFactory getInvocationInputFactory() { + return invocationInputFactory; + } + + @Override + protected GraphQLObjectMapper getGraphQLObjectMapper() { + return graphQLObjectMapper; + } + + public static Builder newBuilder(GraphQLSchema schema) { + return new Builder(GraphQLInvocationInputFactory.newBuilder(schema).build()); + } + + public static Builder newBuilder(GraphQLSchemaProvider schemaProvider) { + return new Builder(GraphQLInvocationInputFactory.newBuilder(schemaProvider).build()); + } + + public static Builder newBuilder(GraphQLInvocationInputFactory invocationInputFactory) { + return new Builder(invocationInputFactory); + } + + public static class Builder { + private final GraphQLInvocationInputFactory invocationInputFactory; + private GraphQLQueryInvoker queryInvoker = GraphQLQueryInvoker.newBuilder().build(); + private GraphQLObjectMapper graphQLObjectMapper = GraphQLObjectMapper.newBuilder().build(); + private boolean asyncServletMode; + + Builder(GraphQLInvocationInputFactory invocationInputFactory) { + this.invocationInputFactory = invocationInputFactory; + } + + public Builder withQueryInvoker(GraphQLQueryInvoker queryInvoker) { + this.queryInvoker = queryInvoker; + return this; + } + + public Builder withObjectMapper(GraphQLObjectMapper objectMapper) { + this.graphQLObjectMapper = objectMapper; + return this; + } + + public Builder withAsyncServletMode(boolean asyncServletMode) { + this.asyncServletMode = asyncServletMode; + return this; + } + + public SimpleGraphQLHttpServlet build() { + return new SimpleGraphQLHttpServlet(invocationInputFactory, queryInvoker, graphQLObjectMapper, asyncServletMode); + } + } +} diff --git a/src/main/java/graphql/servlet/SimpleGraphQLServlet.java b/src/main/java/graphql/servlet/SimpleGraphQLServlet.java deleted file mode 100644 index 17b447af..00000000 --- a/src/main/java/graphql/servlet/SimpleGraphQLServlet.java +++ /dev/null @@ -1,236 +0,0 @@ -package graphql.servlet; - -import java.util.List; -import java.util.Optional; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import graphql.execution.ExecutionStrategy; -import graphql.execution.instrumentation.Instrumentation; -import graphql.execution.instrumentation.SimpleInstrumentation; -import graphql.execution.preparsed.NoOpPreparsedDocumentProvider; -import graphql.execution.preparsed.PreparsedDocumentProvider; -import graphql.schema.GraphQLSchema; - -/** - * @author Andrew Potter - */ -public class SimpleGraphQLServlet extends GraphQLServlet { - - - /** - * @deprecated use {@link #builder(GraphQLSchema)} instead. - */ - @Deprecated - public SimpleGraphQLServlet(GraphQLSchema schema) { - this(schema, new DefaultExecutionStrategyProvider()); - } - - /** - * @deprecated use {@link #builder(GraphQLSchema)} instead. - */ - @Deprecated - public SimpleGraphQLServlet(GraphQLSchema schema, ExecutionStrategy executionStrategy) { - this(schema, new DefaultExecutionStrategyProvider(executionStrategy)); - } - - /** - * @deprecated use {@link #builder(GraphQLSchema)} instead. - */ - @Deprecated - public SimpleGraphQLServlet(GraphQLSchema schema, ExecutionStrategyProvider executionStrategyProvider) { - this(schema, executionStrategyProvider, null, null, null, null, null, null, null); - } - - /** - * @deprecated use {@link #builder(GraphQLSchema)} instead. - */ - @Deprecated - public SimpleGraphQLServlet(final GraphQLSchema schema, ExecutionStrategyProvider executionStrategyProvider, ObjectMapperConfigurer objectMapperConfigurer, List listeners, Instrumentation instrumentation, GraphQLErrorHandler errorHandler, GraphQLContextBuilder contextBuilder, GraphQLRootObjectBuilder rootObjectBuilder, PreparsedDocumentProvider preparsedDocumentProvider) { - this(new DefaultGraphQLSchemaProvider(schema), executionStrategyProvider, objectMapperConfigurer, listeners, instrumentation, errorHandler, contextBuilder, rootObjectBuilder, preparsedDocumentProvider,false); - } - - - /** - * @deprecated use {@link #builder(GraphQLSchemaProvider)} instead. - */ - @Deprecated - public SimpleGraphQLServlet(GraphQLSchemaProvider schemaProvider, ExecutionStrategyProvider executionStrategyProvider, ObjectMapperConfigurer objectMapperConfigurer, List listeners, Instrumentation instrumentation, GraphQLErrorHandler errorHandler, GraphQLContextBuilder contextBuilder, GraphQLRootObjectBuilder rootObjectBuilder, PreparsedDocumentProvider preparsedDocumentProvider, boolean asyncServletMode) { - super(objectMapperConfigurer, listeners, asyncServletMode); - - this.schemaProvider = schemaProvider; - this.executionStrategyProvider = executionStrategyProvider; - - if (instrumentation == null) { - this.instrumentation = SimpleInstrumentation.INSTANCE; - } else { - this.instrumentation = instrumentation; - } - - if (errorHandler == null) { - this.errorHandler = new DefaultGraphQLErrorHandler(); - } else { - this.errorHandler = errorHandler; - } - - if (contextBuilder == null) { - this.contextBuilder = new DefaultGraphQLContextBuilder(); - } else { - this.contextBuilder = contextBuilder; - } - - if (rootObjectBuilder == null) { - this.rootObjectBuilder = new DefaultGraphQLRootObjectBuilder(); - } else { - this.rootObjectBuilder = rootObjectBuilder; - } - - if (preparsedDocumentProvider == null) { - this.preparsedDocumentProvider = NoOpPreparsedDocumentProvider.INSTANCE; - } else { - this.preparsedDocumentProvider = preparsedDocumentProvider; - } - } - - protected SimpleGraphQLServlet(Builder builder) { - super(builder.objectMapperConfigurer, builder.listeners, builder.asyncServletMode); - - this.schemaProvider = builder.schemaProvider; - this.executionStrategyProvider = builder.executionStrategyProvider; - this.instrumentation = builder.instrumentation; - this.errorHandler = builder.errorHandler; - this.contextBuilder = builder.contextBuilder; - this.rootObjectBuilder = builder.rootObjectBuilder; - this.preparsedDocumentProvider = builder.preparsedDocumentProvider; - } - - private final GraphQLSchemaProvider schemaProvider; - private final ExecutionStrategyProvider executionStrategyProvider; - private final Instrumentation instrumentation; - private final GraphQLErrorHandler errorHandler; - private final GraphQLContextBuilder contextBuilder; - private final GraphQLRootObjectBuilder rootObjectBuilder; - private final PreparsedDocumentProvider preparsedDocumentProvider; - - public static SimpleGraphQLServlet create(GraphQLSchema schema) { - return new Builder(schema).build(); - } - - public static SimpleGraphQLServlet create(GraphQLSchemaProvider schemaProvider) { - return new Builder(schemaProvider).build(); - } - - public static Builder builder(GraphQLSchema schema) { - return new Builder(schema); - } - - public static Builder builder(GraphQLSchemaProvider schemaProvider) { - return new Builder(schemaProvider); - } - - public static class Builder { - private final GraphQLSchemaProvider schemaProvider; - private ExecutionStrategyProvider executionStrategyProvider = new DefaultExecutionStrategyProvider(); - private ObjectMapperConfigurer objectMapperConfigurer; - private List listeners; - private Instrumentation instrumentation = SimpleInstrumentation.INSTANCE; - private GraphQLErrorHandler errorHandler = new DefaultGraphQLErrorHandler(); - private GraphQLContextBuilder contextBuilder = new DefaultGraphQLContextBuilder(); - private GraphQLRootObjectBuilder rootObjectBuilder = new DefaultGraphQLRootObjectBuilder(); - private PreparsedDocumentProvider preparsedDocumentProvider = NoOpPreparsedDocumentProvider.INSTANCE; - private boolean asyncServletMode; - - public Builder(GraphQLSchema schema) { - this(new DefaultGraphQLSchemaProvider(schema)); - } - - public Builder(GraphQLSchemaProvider schemaProvider) { - this.schemaProvider = schemaProvider; - } - - public Builder withExecutionStrategyProvider(ExecutionStrategyProvider provider) { - this.executionStrategyProvider = provider; - return this; - } - - public Builder withObjectMapperConfigurer(ObjectMapperConfigurer configurer) { - this.objectMapperConfigurer = configurer; - return this; - } - - public Builder withInstrumentation(Instrumentation instrumentation) { - this.instrumentation = instrumentation; - return this; - } - - public Builder withGraphQLErrorHandler(GraphQLErrorHandler handler) { - this.errorHandler = handler; - return this; - } - - public Builder withGraphQLContextBuilder(GraphQLContextBuilder context) { - this.contextBuilder = context; - return this; - } - - public Builder withGraphQLRootObjectBuilder(GraphQLRootObjectBuilder rootObject) { - this.rootObjectBuilder = rootObject; - return this; - } - - public Builder withPreparsedDocumentProvider(PreparsedDocumentProvider provider) { - this.preparsedDocumentProvider = provider; - return this; - } - - public Builder withListeners(List listeners) { - this.listeners = listeners; - return this; - } - - public Builder withAsyncServletMode(boolean value) { - this.asyncServletMode=value; - return this; - } - - public SimpleGraphQLServlet build() { - return new SimpleGraphQLServlet(this); - } - } - - @Override - protected GraphQLSchemaProvider getSchemaProvider() { - return schemaProvider; - } - - @Override - protected GraphQLContext createContext(Optional request, Optional response) { - return this.contextBuilder.build(request, response); - } - - @Override - protected Object createRootObject(Optional request, Optional response) { - return this.rootObjectBuilder.build(request, response); - } - - @Override - protected ExecutionStrategyProvider getExecutionStrategyProvider() { - return executionStrategyProvider; - } - - @Override - protected Instrumentation getInstrumentation() { - return instrumentation; - } - - @Override - protected GraphQLErrorHandler getGraphQLErrorHandler() { - return errorHandler; - } - - @Override - protected PreparsedDocumentProvider getPreparsedDocumentProvider() { - return preparsedDocumentProvider; - } -} diff --git a/src/main/java/graphql/servlet/StaticGraphQLRootObjectBuilder.java b/src/main/java/graphql/servlet/StaticGraphQLRootObjectBuilder.java index 8aa22481..4426ace8 100644 --- a/src/main/java/graphql/servlet/StaticGraphQLRootObjectBuilder.java +++ b/src/main/java/graphql/servlet/StaticGraphQLRootObjectBuilder.java @@ -1,8 +1,7 @@ package graphql.servlet; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.Optional; +import javax.websocket.server.HandshakeRequest; public class StaticGraphQLRootObjectBuilder implements GraphQLRootObjectBuilder { @@ -13,7 +12,17 @@ public StaticGraphQLRootObjectBuilder(Object rootObject) { } @Override - public Object build(Optional req, Optional resp) { + public Object build(HttpServletRequest req) { + return rootObject; + } + + @Override + public Object build(HandshakeRequest req) { + return rootObject; + } + + @Override + public Object build() { return rootObject; } } diff --git a/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolFactory.java b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolFactory.java new file mode 100644 index 00000000..c8a041a2 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolFactory.java @@ -0,0 +1,15 @@ +package graphql.servlet.internal; + +/** + * @author Andrew Potter + */ +public class ApolloSubscriptionProtocolFactory extends SubscriptionProtocolFactory { + public ApolloSubscriptionProtocolFactory() { + super("graphql-ws"); + } + + @Override + public SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput) { + return new ApolloSubscriptionProtocolHandler(subscriptionHandlerInput); + } +} diff --git a/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java new file mode 100644 index 00000000..211dfcdc --- /dev/null +++ b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java @@ -0,0 +1,192 @@ +package graphql.servlet.internal; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonValue; +import graphql.ExecutionResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.websocket.CloseReason; +import javax.websocket.Session; +import javax.websocket.server.HandshakeRequest; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE; +import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_CONNECTION_TERMINATE; +import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_DATA; +import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_ERROR; + +/** + * https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md + * + * @author Andrew Potter + */ +public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandler { + + private static final Logger log = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler.class); + + private final SubscriptionHandlerInput input; + + public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) { + this.input = subscriptionHandlerInput; + } + + @Override + public void onMessage(HandshakeRequest request, Session session, WsSessionSubscriptions subscriptions, String text) { + OperationMessage message; + try { + message = input.getGraphQLObjectMapper().getJacksonMapper().readValue(text, OperationMessage.class); + } catch(Throwable t) { + log.warn("Error parsing message", t); + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, null); + return; + } + + switch(message.getType()) { + case GQL_CONNECTION_INIT: + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId()); + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId()); + break; + + case GQL_START: + handleSubscriptionStart( + session, + subscriptions, + message.id, + input.getQueryInvoker().query(input.getInvocationInputFactory().create( + input.getGraphQLObjectMapper().getJacksonMapper().convertValue(message.payload, GraphQLRequest.class) + )) + ); + break; + + case GQL_STOP: + unsubscribe(subscriptions, message.id); + break; + + case GQL_CONNECTION_TERMINATE: + try { + session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType())); + } catch (IOException e) { + log.error("Unable to close websocket session!", e); + } + break; + + default: + throw new IllegalArgumentException("Unknown message type: " + message.getType()); + } + } + + @SuppressWarnings("unchecked") + private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) { + executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult); + + if(input.getGraphQLObjectMapper().areErrorsPresent(executionResult)) { + sendMessage(session, OperationMessage.Type.GQL_ERROR, id, input.getGraphQLObjectMapper().convertSanitizedExecutionResult(executionResult, false)); + return; + } + + subscribe(session, executionResult, subscriptions, id); + } + + @Override + protected void sendDataMessage(Session session, String id, Object payload) { + sendMessage(session, GQL_DATA, id, payload); + } + + @Override + protected void sendErrorMessage(Session session, String id) { + sendMessage(session, GQL_ERROR, id); + } + + @Override + protected void sendCompleteMessage(Session session, String id) { + sendMessage(session, GQL_COMPLETE, id); + } + + private void sendMessage(Session session, OperationMessage.Type type, String id) { + sendMessage(session, type, id, null); + } + + private void sendMessage(Session session, OperationMessage.Type type, String id, Object payload) { + try { + session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString( + new OperationMessage(type, id, payload) + )); + } catch (IOException e) { + throw new RuntimeException("Error sending subscription response", e); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class OperationMessage { + private Type type; + private String id; + private Object payload; + + public OperationMessage() { + } + + public OperationMessage(Type type, String id, Object payload) { + this.type = type; + this.id = id; + this.payload = payload; + } + + public Type getType() { + return type; + } + + public String getId() { + return id; + } + + public Object getPayload() { + return payload; + } + + public enum Type { + + // Server Messages + GQL_CONNECTION_ACK("connection_ack"), + GQL_CONNECTION_ERROR("connection_error"), + GQL_CONNECTION_KEEP_ALIVE("ka"), + GQL_DATA("data"), + GQL_ERROR("error"), + GQL_COMPLETE("complete"), + + // Client Messages + GQL_CONNECTION_INIT("connection_init"), + GQL_CONNECTION_TERMINATE("connection_terminate"), + GQL_START("start"), + GQL_STOP("stop"); + + private static final Map reverseLookup = new HashMap<>(); + + static { + for(Type type: Type.values()) { + reverseLookup.put(type.getType(), type); + } + } + + private final String type; + + Type(String type) { + this.type = type; + } + + @JsonCreator + public static Type findType(String type) { + return reverseLookup.get(type); + } + + @JsonValue + public String getType() { + return type; + } + } + } + +} diff --git a/src/main/java/graphql/servlet/internal/ExecutionResultHandler.java b/src/main/java/graphql/servlet/internal/ExecutionResultHandler.java new file mode 100644 index 00000000..f721e3d9 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/ExecutionResultHandler.java @@ -0,0 +1,23 @@ +package graphql.servlet.internal; + +import graphql.ExecutionResult; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** + * @author Andrew Potter + */ +public interface ExecutionResultHandler extends BiConsumer { + @Override + default void accept(ExecutionResult executionResult, Boolean hasNext) { + try { + handle(executionResult, hasNext); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void handle(ExecutionResult result, Boolean hasNext) throws Exception; +} + diff --git a/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolFactory.java b/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolFactory.java new file mode 100644 index 00000000..e15fa59b --- /dev/null +++ b/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolFactory.java @@ -0,0 +1,15 @@ +package graphql.servlet.internal; + +/** + * @author Andrew Potter + */ +public class FallbackSubscriptionProtocolFactory extends SubscriptionProtocolFactory { + public FallbackSubscriptionProtocolFactory() { + super(""); + } + + @Override + public SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput) { + return new FallbackSubscriptionProtocolHandler(subscriptionHandlerInput); + } +} diff --git a/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java b/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java new file mode 100644 index 00000000..39d0ebb1 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java @@ -0,0 +1,51 @@ +package graphql.servlet.internal; + +import javax.websocket.Session; +import javax.websocket.server.HandshakeRequest; +import java.io.IOException; +import java.util.UUID; + +/** + * @author Andrew Potter + */ +public class FallbackSubscriptionProtocolHandler extends SubscriptionProtocolHandler { + + private final SubscriptionHandlerInput input; + + public FallbackSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) { + this.input = subscriptionHandlerInput; + } + + @Override + public void onMessage(HandshakeRequest request, Session session, WsSessionSubscriptions subscriptions, String text) throws Exception { + subscribe( + session, + input.getQueryInvoker().query( + input.getInvocationInputFactory().create( + input.getGraphQLObjectMapper().readGraphQLRequest(text) + ) + ), + subscriptions, + UUID.randomUUID().toString() + ); + } + + @Override + protected void sendDataMessage(Session session, String id, Object payload) { + try { + session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(payload)); + } catch (IOException e) { + throw new RuntimeException("Error sending subscription response", e); + } + } + + @Override + protected void sendErrorMessage(Session session, String id) { + + } + + @Override + protected void sendCompleteMessage(Session session, String id) { + + } +} diff --git a/src/main/java/graphql/servlet/internal/GraphQLRequest.java b/src/main/java/graphql/servlet/internal/GraphQLRequest.java new file mode 100644 index 00000000..36eda27f --- /dev/null +++ b/src/main/java/graphql/servlet/internal/GraphQLRequest.java @@ -0,0 +1,57 @@ +package graphql.servlet.internal; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +import java.util.HashMap; +import java.util.Map; + +/** + * @author Andrew Potter + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class GraphQLRequest { + private String query; + @JsonDeserialize(using = VariablesDeserializer.class) + private Map variables = new HashMap<>(); + private String operationName; + + public GraphQLRequest() { + } + + public GraphQLRequest(String query, Map variables, String operationName) { + this.query = query; + this.variables = variables; + this.operationName = operationName; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public Map getVariables() { + return variables; + } + + public void setVariables(Map variables) { + this.variables = variables; + } + + public String getOperationName() { + if (operationName != null && !operationName.isEmpty()) { + return operationName; + } + + return null; + } + + public void setOperationName(String operationName) { + this.operationName = operationName; + } +} + + diff --git a/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java b/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java new file mode 100644 index 00000000..5bc1a3f8 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java @@ -0,0 +1,30 @@ +package graphql.servlet.internal; + +import graphql.servlet.GraphQLInvocationInputFactory; +import graphql.servlet.GraphQLObjectMapper; +import graphql.servlet.GraphQLQueryInvoker; + +public class SubscriptionHandlerInput { + + private final GraphQLInvocationInputFactory invocationInputFactory; + private final GraphQLQueryInvoker queryInvoker; + private final GraphQLObjectMapper graphQLObjectMapper; + + public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper) { + this.invocationInputFactory = invocationInputFactory; + this.queryInvoker = queryInvoker; + this.graphQLObjectMapper = graphQLObjectMapper; + } + + public GraphQLInvocationInputFactory getInvocationInputFactory() { + return invocationInputFactory; + } + + public GraphQLQueryInvoker getQueryInvoker() { + return queryInvoker; + } + + public GraphQLObjectMapper getGraphQLObjectMapper() { + return graphQLObjectMapper; + } +} diff --git a/src/main/java/graphql/servlet/internal/SubscriptionProtocolFactory.java b/src/main/java/graphql/servlet/internal/SubscriptionProtocolFactory.java new file mode 100644 index 00000000..04de3f4b --- /dev/null +++ b/src/main/java/graphql/servlet/internal/SubscriptionProtocolFactory.java @@ -0,0 +1,18 @@ +package graphql.servlet.internal; + +/** + * @author Andrew Potter + */ +public abstract class SubscriptionProtocolFactory { + private final String protocol; + + public SubscriptionProtocolFactory(String protocol) { + this.protocol = protocol; + } + + public String getProtocol() { + return protocol; + } + + public abstract SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput); +} diff --git a/src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java b/src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java new file mode 100644 index 00000000..988bb360 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java @@ -0,0 +1,95 @@ +package graphql.servlet.internal; + +import graphql.ExecutionResult; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.websocket.Session; +import javax.websocket.server.HandshakeRequest; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +/** + * @author Andrew Potter + */ +public abstract class SubscriptionProtocolHandler { + + private static final Logger log = LoggerFactory.getLogger(SubscriptionProtocolHandler.class); + + public abstract void onMessage(HandshakeRequest request, Session session, WsSessionSubscriptions subscriptions, String text) throws Exception; + + protected abstract void sendDataMessage(Session session, String id, Object payload); + + protected abstract void sendErrorMessage(Session session, String id); + + protected abstract void sendCompleteMessage(Session session, String id); + + protected void subscribe(Session session, ExecutionResult executionResult, WsSessionSubscriptions subscriptions, String id) { + final Object data = executionResult.getData(); + + if (data instanceof Publisher) { + @SuppressWarnings("unchecked") final Publisher publisher = (Publisher) data; + final AtomicSubscriptionReference subscriptionReference = new AtomicSubscriptionReference(); + + publisher.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription subscription) { + subscriptionReference.set(subscription); + subscriptionReference.get().request(1); + + subscriptions.add(id, subscriptionReference.get()); + } + + @Override + public void onNext(ExecutionResult executionResult) { + subscriptionReference.get().request(1); + Map result = new HashMap<>(); + result.put("data", executionResult.getData()); + sendDataMessage(session, id, result); + } + + @Override + public void onError(Throwable throwable) { + log.error("Subscription error", throwable); + subscriptions.cancel(id); + sendErrorMessage(session, id); + } + + @Override + public void onComplete() { + subscriptions.cancel(id); + sendCompleteMessage(session, id); + } + }); + } + } + + protected void unsubscribe(WsSessionSubscriptions subscriptions, String id) { + subscriptions.cancel(id); + } + + static class AtomicSubscriptionReference { + private final AtomicReference reference = new AtomicReference<>(null); + + public void set(Subscription subscription) { + if(reference.get() != null) { + throw new IllegalStateException("Cannot overwrite subscription!"); + } + + reference.set(subscription); + } + + public Subscription get() { + Subscription subscription = reference.get(); + if(subscription == null) { + throw new IllegalStateException("Subscription has not been initialized yet!"); + } + + return subscription; + } + } +} diff --git a/src/main/java/graphql/servlet/internal/VariablesDeserializer.java b/src/main/java/graphql/servlet/internal/VariablesDeserializer.java new file mode 100644 index 00000000..da3f532f --- /dev/null +++ b/src/main/java/graphql/servlet/internal/VariablesDeserializer.java @@ -0,0 +1,38 @@ +package graphql.servlet.internal; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.util.Map; + +/** + * @author Andrew Potter + */ +public class VariablesDeserializer extends JsonDeserializer> { + @Override + public Map deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + return deserializeVariablesObject(p.readValueAs(Object.class), (ObjectMapper) ctxt.findInjectableValue(ObjectMapper.class.getName(), null, null)); + } + + public static Map deserializeVariablesObject(Object variables, ObjectMapper mapper) { + if (variables instanceof Map) { + @SuppressWarnings("unchecked") + Map genericVariables = (Map) variables; + return genericVariables; + } else if (variables instanceof String) { + try { + return mapper.readValue((String) variables, new TypeReference>() {}); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else { + throw new RuntimeException("variables should be either an object or a string"); + } + } + +} + diff --git a/src/main/java/graphql/servlet/internal/WsSessionSubscriptions.java b/src/main/java/graphql/servlet/internal/WsSessionSubscriptions.java new file mode 100644 index 00000000..48a0ac79 --- /dev/null +++ b/src/main/java/graphql/servlet/internal/WsSessionSubscriptions.java @@ -0,0 +1,54 @@ +package graphql.servlet.internal; + +import org.reactivestreams.Subscription; + +import java.util.HashMap; +import java.util.Map; + +/** + * @author Andrew Potter + */ +public class WsSessionSubscriptions { + private final Object lock = new Object(); + + private boolean closed = false; + private Map subscriptions = new HashMap<>(); + + public void add(Subscription subscription) { + add(getImplicitId(subscription), subscription); + } + + public void add(String id, Subscription subscription) { + synchronized (lock) { + if(closed) { + throw new IllegalStateException("Websocket was already closed!"); + } + subscriptions.put(id, subscription); + } + } + + public void cancel(Subscription subscription) { + cancel(getImplicitId(subscription)); + } + + public void cancel(String id) { + synchronized (lock) { + Subscription subscription = subscriptions.remove(id); + if(subscription != null) { + subscription.cancel(); + } + } + } + + public void close() { + synchronized (lock) { + closed = true; + subscriptions.forEach((k, v) -> v.cancel()); + subscriptions = new HashMap<>(); + } + } + + private String getImplicitId(Subscription subscription) { + return String.valueOf(subscription.hashCode()); + } +} diff --git a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy similarity index 98% rename from src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy rename to src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index b2b7fd0a..5a11a780 100644 --- a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -10,6 +10,7 @@ import graphql.schema.GraphQLObjectType import graphql.schema.GraphQLSchema import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletResponse +import spock.lang.Ignore import spock.lang.Shared import spock.lang.Specification @@ -18,7 +19,7 @@ import javax.servlet.http.HttpServletRequest /** * @author Andrew Potter */ -class GraphQLServletSpec extends Specification { +class AbstractGraphQLHttpServletSpec extends Specification { public static final int STATUS_OK = 200 public static final int STATUS_BAD_REQUEST = 400 @@ -28,7 +29,7 @@ class GraphQLServletSpec extends Specification { @Shared ObjectMapper mapper = new ObjectMapper() - GraphQLServlet servlet + AbstractGraphQLHttpServlet servlet MockHttpServletRequest request MockHttpServletResponse response @@ -70,7 +71,7 @@ class GraphQLServletSpec extends Specification { } .build() - return new SimpleGraphQLServlet(new GraphQLSchema(query, mutation, [query, mutation].toSet())) + return SimpleGraphQLHttpServlet.newBuilder(new GraphQLSchema(query, mutation, [query, mutation].toSet())).build() } Map getResponseContent() { @@ -725,12 +726,7 @@ class GraphQLServletSpec extends Specification { def "errors before graphql schema execution return internal server error"() { setup: - servlet = new SimpleGraphQLServlet(servlet.getSchemaProvider().getSchema()) { - @Override - GraphQLSchemaProvider getSchemaProvider() { - throw new TestException() - } - } + servlet = SimpleGraphQLHttpServlet.newBuilder(GraphQLInvocationInputFactory.newBuilder { throw new TestException() }.build()).build() request.setPathInfo('/schema.json') @@ -830,9 +826,10 @@ class GraphQLServletSpec extends Specification { def "typeInfo is serialized correctly"() { expect: - servlet.getMapper().writeValueAsString(ExecutionTypeInfo.newTypeInfo().type(new GraphQLNonNull(Scalars.GraphQLString)).build()) != "{}" + servlet.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(ExecutionTypeInfo.newTypeInfo().type(new GraphQLNonNull(Scalars.GraphQLString)).build()) != "{}" } + @Ignore def "isBatchedQuery check uses buffer length as read limit"() { setup: HttpServletRequest mockRequest = Mock() diff --git a/src/test/groovy/graphql/servlet/OsgiGraphQLServletSpec.groovy b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy similarity index 93% rename from src/test/groovy/graphql/servlet/OsgiGraphQLServletSpec.groovy rename to src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy index 559daf23..577cdff9 100644 --- a/src/test/groovy/graphql/servlet/OsgiGraphQLServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy @@ -9,7 +9,7 @@ import spock.lang.Specification import static graphql.Scalars.GraphQLInt import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition -class OsgiGraphQLServletSpec extends Specification { +class OsgiGraphQLHttpServletSpec extends Specification { static class TestQueryProvider implements GraphQLQueryProvider { @@ -34,7 +34,7 @@ class OsgiGraphQLServletSpec extends Specification { def "query provider adds query objects"() { setup: - OsgiGraphQLServlet servlet = new OsgiGraphQLServlet() + OsgiGraphQLHttpServlet servlet = new OsgiGraphQLHttpServlet() TestQueryProvider queryProvider = new TestQueryProvider() servlet.bindQueryProvider(queryProvider) GraphQLFieldDefinition query @@ -65,7 +65,7 @@ class OsgiGraphQLServletSpec extends Specification { def "mutation provider adds mutation objects"() { setup: - OsgiGraphQLServlet servlet = new OsgiGraphQLServlet() + OsgiGraphQLHttpServlet servlet = new OsgiGraphQLHttpServlet() TestMutationProvider mutationProvider = new TestMutationProvider() when: diff --git a/src/test/groovy/graphql/servlet/TestMultipartPart.groovy b/src/test/groovy/graphql/servlet/TestMultipartPart.groovy index 5eacc66f..cc9cbb6d 100644 --- a/src/test/groovy/graphql/servlet/TestMultipartPart.groovy +++ b/src/test/groovy/graphql/servlet/TestMultipartPart.groovy @@ -34,6 +34,11 @@ class TestMultipartContentBuilder { return name } + @Override + String getSubmittedFileName() { + return name + } + @Override long getSize() { return content.getBytes().length