Skip to content

Commit

Permalink
Replaced Commons FileUpload with Servlet API 3.0 multipart support (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrouge authored and apottere committed Jun 19, 2018
1 parent cacdfb2 commit 387ac86
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 98 deletions.
2 changes: 0 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ dependencies {

// Servlet
compile 'javax.servlet:javax.servlet-api:3.0.1'
// Multipart support
compile 'commons-fileupload:commons-fileupload:1.3.1'

// GraphQL
compile 'com.graphql-java:graphql-java:8.0'
Expand Down
9 changes: 4 additions & 5 deletions src/main/java/graphql/servlet/GraphQLContext.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package graphql.servlet;

import org.apache.commons.fileupload.FileItem;

import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.Part;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -14,7 +13,7 @@ public class GraphQLContext {
private Optional<HttpServletResponse> response;

private Optional<Subject> subject = Optional.empty();
private Optional<Map<String, List<FileItem>>> files = Optional.empty();
private Optional<Map<String, List<Part>>> files = Optional.empty();

public GraphQLContext(Optional<HttpServletRequest> request, Optional<HttpServletResponse> response) {
this.request = request;
Expand Down Expand Up @@ -45,11 +44,11 @@ public void setSubject(Optional<Subject> subject) {
this.subject = subject;
}

public Optional<Map<String, List<FileItem>>> getFiles() {
public Optional<Map<String, List<Part>>> getFiles() {
return files;
}

public void setFiles(Optional<Map<String, List<FileItem>>> files) {
public void setFiles(Optional<Map<String, List<Part>>> files) {
this.files = files;
}
}
47 changes: 23 additions & 24 deletions src/main/java/graphql/servlet/GraphQLServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.common.io.ByteStreams;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
Expand All @@ -18,10 +19,6 @@
import graphql.introspection.IntrospectionQuery;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLSchema;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemFactory;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -31,6 +28,7 @@
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.Part;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -39,6 +37,7 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
Expand All @@ -51,6 +50,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* @author Andrew Potter
Expand All @@ -74,19 +74,17 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra

private final LazyObjectMapperBuilder lazyObjectMapperBuilder;
private final List<GraphQLServletListener> listeners;
private final ServletFileUpload fileUpload;

private final HttpRequestHandler getHandler;
private final HttpRequestHandler postHandler;

public GraphQLServlet() {
this(null, null, null);
this(null, null);
}

public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQLServletListener> listeners, FileItemFactory fileItemFactory) {
public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQLServletListener> listeners) {
this.lazyObjectMapperBuilder = new LazyObjectMapperBuilder(objectMapperConfigurer != null ? objectMapperConfigurer : new DefaultObjectMapperConfigurer());
this.listeners = listeners != null ? new ArrayList<>(listeners) : new ArrayList<>();
this.fileUpload = new ServletFileUpload(fileItemFactory != null ? fileItemFactory : new DiskFileItemFactory());

this.getHandler = (request, response) -> {
final GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
Expand Down Expand Up @@ -128,12 +126,17 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQ
final Object rootObject = createRootObject(Optional.of(request), Optional.of(response));

try {
if (ServletFileUpload.isMultipartContent(request)) {
final Map<String, List<FileItem>> fileItems = fileUpload.parseParameterMap(request);
Collection<Part> parts = request.getParts();
if (!parts.isEmpty()) {
final Map<String, List<Part>> fileItems = parts.stream()
.collect(Collectors.toMap(
Part::getName,
Collections::singletonList,
(l1, l2) -> Stream.concat(l1.stream(), l2.stream()).collect(Collectors.toList())));
context.setFiles(Optional.of(fileItems));

if (fileItems.containsKey("graphql")) {
final Optional<FileItem> graphqlItem = getFileItem(fileItems, "graphql");
final Optional<Part> graphqlItem = getFileItem(fileItems, "graphql");
if (graphqlItem.isPresent()) {
InputStream inputStream = graphqlItem.get().getInputStream();

Expand All @@ -150,7 +153,7 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQ
}
}
} else if (fileItems.containsKey("query")) {
final Optional<FileItem> queryItem = getFileItem(fileItems, "query");
final Optional<Part> queryItem = getFileItem(fileItems, "query");
if (queryItem.isPresent()) {
InputStream inputStream = queryItem.get().getInputStream();

Expand All @@ -162,18 +165,19 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQ
doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
return;
} else {
String query = new String(queryItem.get().get());

String query = new String(ByteStreams.toByteArray(inputStream));

Map<String, Object> variables = null;
final Optional<FileItem> variablesItem = getFileItem(fileItems, "variables");
final Optional<Part> variablesItem = getFileItem(fileItems, "variables");
if (variablesItem.isPresent()) {
variables = deserializeVariables(new String(variablesItem.get().get()));
variables = deserializeVariables(new String(ByteStreams.toByteArray(variablesItem.get().getInputStream())));
}

String operationName = null;
final Optional<FileItem> operationNameItem = getFileItem(fileItems, "operationName");
final Optional<Part> operationNameItem = getFileItem(fileItems, "operationName");
if (operationNameItem.isPresent()) {
operationName = new String(operationNameItem.get().get()).trim();
operationName = new String(ByteStreams.toByteArray(operationNameItem.get().getInputStream())).trim();
}

doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response);
Expand Down Expand Up @@ -274,13 +278,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S
doRequest(req, resp, postHandler);
}

private Optional<FileItem> getFileItem(Map<String, List<FileItem>> fileItems, String name) {
List<FileItem> items = fileItems.get(name);
if(items == null || items.isEmpty()) {
return Optional.empty();
}

return items.stream().findFirst();
private Optional<Part> getFileItem(Map<String, List<Part>> fileItems, String name) {
return Optional.ofNullable(fileItems.get(name)).filter(list -> !list.isEmpty()).map(list -> list.get(0));
}

private GraphQL newGraphQL(GraphQLSchema schema) {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/graphql/servlet/SimpleGraphQLServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public SimpleGraphQLServlet(final GraphQLSchema schema, ExecutionStrategyProvide
*/
@Deprecated
public SimpleGraphQLServlet(GraphQLSchemaProvider schemaProvider, ExecutionStrategyProvider executionStrategyProvider, ObjectMapperConfigurer objectMapperConfigurer, List<GraphQLServletListener> listeners, Instrumentation instrumentation, GraphQLErrorHandler errorHandler, GraphQLContextBuilder contextBuilder, GraphQLRootObjectBuilder rootObjectBuilder, PreparsedDocumentProvider preparsedDocumentProvider) {
super(objectMapperConfigurer, listeners, null);
super(objectMapperConfigurer, listeners);

this.schemaProvider = schemaProvider;
this.executionStrategyProvider = executionStrategyProvider;
Expand Down Expand Up @@ -91,7 +91,7 @@ public SimpleGraphQLServlet(GraphQLSchemaProvider schemaProvider, ExecutionStrat
}

protected SimpleGraphQLServlet(Builder builder) {
super(builder.objectMapperConfigurer, builder.listeners, null);
super(builder.objectMapperConfigurer, builder.listeners);

this.schemaProvider = builder.schemaProvider;
this.executionStrategyProvider = builder.executionStrategyProvider;
Expand Down
61 changes: 18 additions & 43 deletions src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,7 @@ class GraphQLServletSpec extends Specification {
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")

request.setContent(new TestMultipartContentBuilder()
.addPart('graphql', mapper.writeValueAsString([query: 'query { echo(arg:"test") }']))
.build())
request.addPart(TestMultipartContentBuilder.createPart('graphql', mapper.writeValueAsString([query: 'query { echo(arg:"test") }'])))

when:
servlet.doPost(request, response)
Expand All @@ -404,9 +402,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', 'query { echo(arg:"test") }')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', 'query { echo(arg:"test") }'))

when:
servlet.doPost(request, response)
Expand All @@ -421,10 +417,8 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', 'query one{ echoOne: echo(arg:"test-one") } query two{ echoTwo: echo(arg:"test-two") }')
.addPart('operationName', 'two')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', 'query one{ echoOne: echo(arg:"test-one") } query two{ echoTwo: echo(arg:"test-two") }'))
request.addPart(TestMultipartContentBuilder.createPart('operationName', 'two'))

when:
servlet.doPost(request, response)
Expand All @@ -440,10 +434,8 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', 'query echo{ echo: echo(arg:"test") }')
.addPart('operationName', '')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', 'query echo{ echo: echo(arg:"test") }'))
request.addPart(TestMultipartContentBuilder.createPart('operationName', ''))

when:
servlet.doPost(request, response)
Expand All @@ -458,10 +450,8 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', 'query Echo($arg: String) { echo(arg:$arg) }')
.addPart('variables', '{"arg": "test"}')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', 'query Echo($arg: String) { echo(arg:$arg) }'))
request.addPart(TestMultipartContentBuilder.createPart('variables', '{"arg": "test"}'))

when:
servlet.doPost(request, response)
Expand All @@ -476,10 +466,8 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', 'query { echo(arg:"test") }')
.addPart('test', 'test')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', 'query { echo(arg:"test") }'))
request.addPart(TestMultipartContentBuilder.createPart('test', 'test'))

when:
servlet.doPost(request, response)
Expand Down Expand Up @@ -567,9 +555,7 @@ class GraphQLServletSpec extends Specification {
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")

request.setContent(new TestMultipartContentBuilder()
.addPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]'))

when:
servlet.doPost(request, response)
Expand All @@ -586,9 +572,7 @@ class GraphQLServletSpec extends Specification {
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")

request.setContent(new TestMultipartContentBuilder()
.addPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }", "test": "test" }, { "query": "query { echo(arg:\\"test\\") }", "test": "test" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }", "test": "test" }, { "query": "query { echo(arg:\\"test\\") }", "test": "test" }]'))

when:
servlet.doPost(request, response)
Expand All @@ -604,9 +588,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]'))

when:
servlet.doPost(request, response)
Expand All @@ -622,9 +604,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]'))

when:
servlet.doPost(request, response)
Expand All @@ -642,9 +622,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]'))

when:
servlet.doPost(request, response)
Expand All @@ -660,9 +638,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', '[{ "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }, { "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', '[{ "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }, { "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }]'))

when:
servlet.doPost(request, response)
Expand All @@ -678,9 +654,7 @@ class GraphQLServletSpec extends Specification {
setup:
request.setContentType("multipart/form-data, boundary=test")
request.setMethod("POST")
request.setContent(new TestMultipartContentBuilder()
.addPart('query', '[{ "query": "query { echo(arg:\\"test\\") }", "test": "test" }, { "query": "query { echo(arg:\\"test\\") }", "test": "test" }]')
.build())
request.addPart(TestMultipartContentBuilder.createPart('query', '[{ "query": "query { echo(arg:\\"test\\") }", "test": "test" }, { "query": "query { echo(arg:\\"test\\") }", "test": "test" }]'))

when:
servlet.doPost(request, response)
Expand Down Expand Up @@ -853,6 +827,7 @@ class GraphQLServletSpec extends Specification {
mockInputStream.markSupported() >> true
mockRequest.getInputStream() >> mockInputStream
mockRequest.getMethod() >> "POST"
mockRequest.getParts() >> Collections.emptyList()

when:
servlet.doPost(mockRequest, response)
Expand Down
Loading

0 comments on commit 387ac86

Please sign in to comment.