Skip to content

Commit 75f2e38

Browse files
committed
Rewrite multipart support for servlet <3.0 (fixes #18)
1 parent 877a061 commit 75f2e38

File tree

5 files changed

+179
-125
lines changed

5 files changed

+179
-125
lines changed

src/main/java/graphql/servlet/GraphQLContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import lombok.NonNull;
1919
import lombok.RequiredArgsConstructor;
2020
import lombok.Setter;
21+
import org.apache.commons.fileupload.FileItem;
2122
import org.apache.commons.fileupload.FileItemIterator;
2223
import org.apache.commons.fileupload.FileItemStream;
2324
import org.apache.commons.fileupload.servlet.ServletFileUpload;
@@ -42,5 +43,5 @@ public class GraphQLContext {
4243
private Optional<Subject> subject = Optional.empty();
4344

4445
@Getter @Setter
45-
private Optional<Collection<Part>> parts = Optional.empty();
46+
private Optional<Map<String, List<FileItem>>> files = Optional.empty();
4647
}

src/main/java/graphql/servlet/GraphQLServlet.java

Lines changed: 114 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
import lombok.Setter;
3535
import lombok.SneakyThrows;
3636
import lombok.extern.slf4j.Slf4j;
37+
import org.apache.commons.fileupload.FileItem;
38+
import org.apache.commons.fileupload.FileItemFactory;
39+
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
3740
import org.apache.commons.fileupload.servlet.ServletFileUpload;
3841

3942
import javax.security.auth.Subject;
@@ -42,7 +45,6 @@
4245
import javax.servlet.http.HttpServlet;
4346
import javax.servlet.http.HttpServletRequest;
4447
import javax.servlet.http.HttpServletResponse;
45-
import javax.servlet.http.Part;
4648
import java.io.IOException;
4749
import java.io.InputStream;
4850
import java.io.InputStreamReader;
@@ -75,14 +77,113 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra
7577

7678
private final List<GraphQLOperationListener> operationListeners;
7779
private final List<GraphQLServletListener> servletListeners;
80+
private final ServletFileUpload fileUpload;
81+
82+
private final RequestHandler getHandler;
83+
private final RequestHandler postHandler;
7884

7985
public GraphQLServlet() {
80-
this(null, null);
86+
this(null, null, null);
8187
}
8288

83-
public GraphQLServlet(List<GraphQLOperationListener> operationListeners, List<GraphQLServletListener> servletListeners) {
89+
public GraphQLServlet(List<GraphQLOperationListener> operationListeners, List<GraphQLServletListener> servletListeners, FileItemFactory fileItemFactory) {
8490
this.operationListeners = operationListeners != null ? new ArrayList<>(operationListeners) : new ArrayList<>();
8591
this.servletListeners = servletListeners != null ? new ArrayList<>(servletListeners) : new ArrayList<>();
92+
this.fileUpload = new ServletFileUpload(fileItemFactory != null ? fileItemFactory : new DiskFileItemFactory());
93+
94+
this.getHandler = (request, response) -> {
95+
GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
96+
String path = request.getPathInfo();
97+
if (path == null) {
98+
path = request.getServletPath();
99+
}
100+
if (path.contentEquals("/schema.json")) {
101+
query(CharStreams.toString(new InputStreamReader(GraphQLServlet.class.getResourceAsStream("introspectionQuery"))), null, new HashMap<>(), getSchema(), request, response, context);
102+
} else {
103+
if (request.getParameter("query") != null) {
104+
Map<String, Object> variables = new HashMap<>();
105+
if (request.getParameter("variables") != null) {
106+
variables.putAll(mapper.readValue(request.getParameter("variables"), new TypeReference<Map<String, Object>>() { }));
107+
}
108+
String operationName = null;
109+
if (request.getParameter("operationName") != null) {
110+
operationName = request.getParameter("operationName");
111+
}
112+
query(request.getParameter("query"), operationName, variables, getReadOnlySchema(), request, response, context);
113+
} else {
114+
response.setStatus(STATUS_BAD_REQUEST);
115+
log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given");
116+
}
117+
}
118+
};
119+
120+
this.postHandler = (request, response) -> {
121+
GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
122+
GraphQLRequest graphQLRequest = null;
123+
124+
try {
125+
InputStream inputStream = null;
126+
127+
if (ServletFileUpload.isMultipartContent(request)) {
128+
Map<String, List<FileItem>> fileItems = fileUpload.parseParameterMap(request);
129+
130+
if (fileItems.containsKey("graphql")) {
131+
Optional<FileItem> graphqlItem = getFileItem(fileItems, "graphql");
132+
if(graphqlItem.isPresent()) {
133+
inputStream = graphqlItem.get().getInputStream();
134+
}
135+
136+
} else if(fileItems.containsKey("query")) {
137+
Optional<FileItem> queryItem = getFileItem(fileItems, "query");
138+
if(queryItem.isPresent()) {
139+
graphQLRequest = new GraphQLRequest();
140+
graphQLRequest.setQuery(new String(queryItem.get().get()));
141+
142+
Optional<FileItem> operationNameItem = getFileItem(fileItems, "operationName");
143+
if(operationNameItem.isPresent()) {
144+
graphQLRequest.setOperationName(new String(operationNameItem.get().get()).trim());
145+
}
146+
147+
Optional<FileItem> variablesItem = getFileItem(fileItems, "variables");
148+
if(variablesItem.isPresent()) {
149+
String variables = new String(variablesItem.get().get());
150+
if(!variables.isEmpty()) {
151+
graphQLRequest.setVariables((Map<String, Object>) mapper.readValue(variables, Map.class));
152+
}
153+
}
154+
}
155+
}
156+
157+
if(inputStream == null && graphQLRequest == null) {
158+
response.setStatus(STATUS_BAD_REQUEST);
159+
log.info("Bad POST multipart request: no part named \"graphql\" or \"query\"");
160+
return;
161+
}
162+
163+
context.setFiles(Optional.of(fileItems));
164+
165+
} else {
166+
// this is not a multipart request
167+
inputStream = request.getInputStream();
168+
}
169+
170+
if(graphQLRequest == null) {
171+
graphQLRequest = mapper.readValue(inputStream, GraphQLRequest.class);
172+
}
173+
174+
} catch (Exception e) {
175+
log.info("Bad POST request: parsing failed", e);
176+
response.setStatus(STATUS_BAD_REQUEST);
177+
return;
178+
}
179+
180+
Map<String,Object> variables = graphQLRequest.getVariables();
181+
if (variables == null) {
182+
variables = new HashMap<>();
183+
}
184+
185+
query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), variables, getSchema(), request, response, context);
186+
};
86187
}
87188

88189
public void addOperationListener(GraphQLOperationListener operationListener) {
@@ -121,71 +222,6 @@ public String executeQuery(String query) {
121222
}
122223
}
123224

124-
private final RequestHandler getHandler = (request, response) -> {
125-
GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
126-
String path = request.getPathInfo();
127-
if (path == null) {
128-
path = request.getServletPath();
129-
}
130-
if (path.contentEquals("/schema.json")) {
131-
query(CharStreams.toString(new InputStreamReader(GraphQLServlet.class.getResourceAsStream("introspectionQuery"))), null, new HashMap<>(), getSchema(), request, response, context);
132-
} else {
133-
if (request.getParameter("query") != null) {
134-
Map<String, Object> variables = new HashMap<>();
135-
if (request.getParameter("variables") != null) {
136-
variables.putAll(mapper.readValue(request.getParameter("variables"), new TypeReference<Map<String, Object>>() { }));
137-
}
138-
String operationName = null;
139-
if (request.getParameter("operationName") != null) {
140-
operationName = request.getParameter("operationName");
141-
}
142-
query(request.getParameter("query"), operationName, variables, getReadOnlySchema(), request, response, context);
143-
} else {
144-
response.setStatus(STATUS_BAD_REQUEST);
145-
log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given");
146-
}
147-
}
148-
};
149-
150-
private final RequestHandler postHandler = (request, response) -> {
151-
GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
152-
InputStream inputStream = null;
153-
154-
if (ServletFileUpload.isMultipartContent(request)) {
155-
Part part = request.getPart("graphql");
156-
if(part != null) {
157-
inputStream = part.getInputStream();
158-
}
159-
160-
if (inputStream == null) {
161-
response.setStatus(STATUS_BAD_REQUEST);
162-
log.info("Bad POST multipart request: no part named \"graphql\"");
163-
return;
164-
}
165-
166-
context.setParts(Optional.of(request.getParts()));
167-
168-
} else {
169-
// this is not a multipart request
170-
inputStream = request.getInputStream();
171-
}
172-
173-
GraphQLRequest graphQLRequest;
174-
try {
175-
graphQLRequest = mapper.readValue(inputStream, GraphQLRequest.class);
176-
} catch (Exception e) {
177-
log.info("Bad POST request: deserialization failed", e);
178-
response.setStatus(STATUS_BAD_REQUEST);
179-
return;
180-
}
181-
182-
Map<String,Object> variables = graphQLRequest.variables;
183-
if (variables == null) {
184-
variables = new HashMap<>();
185-
}
186-
query(graphQLRequest.query, graphQLRequest.operationName, variables, getSchema(), request, response, context);
187-
};
188-
189225
private void doRequest(HttpServletRequest request, HttpServletResponse response, RequestHandler handler) {
190226
try {
191227
runListeners(servletListeners, l -> l.onStart(request, response));
@@ -210,6 +246,16 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S
210246
doRequest(req, resp, postHandler);
211247
}
212248

249+
private Optional<FileItem> getFileItem(Map<String, List<FileItem>> fileItems, String name) {
250+
List<FileItem> items = fileItems.get(name);
251+
if(items == null || items.isEmpty()) {
252+
return Optional.empty();
253+
}
254+
255+
return items.stream().findFirst();
256+
}
257+
258+
213259
private void query(String query, String operationName, Map<String, Object> variables, GraphQLSchema schema, HttpServletRequest req, HttpServletResponse resp, GraphQLContext context) throws IOException {
214260
if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) {
215261
Subject.doAs(context.getSubject().get(), new PrivilegedAction<Void>() {

src/main/java/graphql/servlet/SimpleGraphQLServlet.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public SimpleGraphQLServlet(GraphQLSchema schema, ExecutionStrategy executionStr
4242
}
4343

4444
public SimpleGraphQLServlet(GraphQLSchema schema, ExecutionStrategy executionStrategy, List<GraphQLOperationListener> operationListeners, List<GraphQLServletListener> servletListeners) {
45-
super(operationListeners, servletListeners);
45+
super(operationListeners, servletListeners, null);
4646

4747
this.schema = schema;
4848
this.readOnlySchema = new GraphQLSchema(schema.getQueryType(), EMPTY_MUTATION_TYPE, schema.getDictionary());

src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import graphql.schema.GraphQLObjectType
2222
import graphql.schema.GraphQLSchema
2323
import org.springframework.mock.web.MockHttpServletRequest
2424
import org.springframework.mock.web.MockHttpServletResponse
25+
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder
26+
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders
2527
import spock.lang.Shared
2628
import spock.lang.Specification
2729

@@ -213,13 +215,14 @@ class GraphQLServletSpec extends Specification {
213215
getResponseContent().data.echoTwo == "test-two"
214216
}
215217

216-
def "query over HTTP POST multipart returns data"() {
218+
def "query over HTTP POST multipart named 'graphql' returns data"() {
217219
setup:
218-
request.setContentType("multipart/graphql, boundary=Test")
220+
request.setContentType("multipart/form-data, boundary=test")
219221
request.setMethod("POST")
220-
request.addPart(new TestMultipartPart(name: 'graphql', content: mapper.writeValueAsString([
221-
query: 'query { echo(arg:"test") }'
222-
])))
222+
223+
request.setContent(new TestMultipartContentBuilder()
224+
.addPart('graphql', mapper.writeValueAsString([query: 'query { echo(arg:"test") }']))
225+
.build())
223226

224227
when:
225228
servlet.doPost(request, response)
@@ -230,14 +233,13 @@ class GraphQLServletSpec extends Specification {
230233
getResponseContent().data.echo == "test"
231234
}
232235

233-
def "query over HTTP POST multipart with variables returns data"() {
236+
def "query over HTTP POST multipart named 'query' returns data"() {
234237
setup:
235-
request.setContentType("multipart/graphql")
238+
request.setContentType("multipart/form-data, boundary=test")
236239
request.setMethod("POST")
237-
request.addPart(new TestMultipartPart(name: 'graphql', content: mapper.writeValueAsString([
238-
query: 'query Echo($arg: String) { echo(arg:$arg) }',
239-
variables: '{"arg": "test"}'
240-
])))
240+
request.setContent(new TestMultipartContentBuilder()
241+
.addPart('query', 'query { echo(arg:"test") }')
242+
.build())
241243

242244
when:
243245
servlet.doPost(request, response)
@@ -248,14 +250,14 @@ class GraphQLServletSpec extends Specification {
248250
getResponseContent().data.echo == "test"
249251
}
250252

251-
def "query over HTTP POST multipart with operationName returns data"() {
253+
def "query over HTTP POST multipart named 'query' with operationName returns data"() {
252254
setup:
253-
request.setContentType("multipart/graphql")
255+
request.setContentType("multipart/form-data, boundary=test")
254256
request.setMethod("POST")
255-
request.addPart(new TestMultipartPart(name: 'graphql', content: mapper.writeValueAsString([
256-
query: 'query one{ echoOne: echo(arg:"test-one") } query two{ echoTwo: echo(arg:"test-two") }',
257-
operationName: 'two'
258-
])))
257+
request.setContent(new TestMultipartContentBuilder()
258+
.addPart('query', 'query one{ echoOne: echo(arg:"test-one") } query two{ echoTwo: echo(arg:"test-two") }')
259+
.addPart('operationName', 'two')
260+
.build())
259261

260262
when:
261263
servlet.doPost(request, response)
@@ -267,6 +269,24 @@ class GraphQLServletSpec extends Specification {
267269
getResponseContent().data.echoTwo == "test-two"
268270
}
269271

272+
def "query over HTTP POST multipart named 'query' with variables returns data"() {
273+
setup:
274+
request.setContentType("multipart/form-data, boundary=test")
275+
request.setMethod("POST")
276+
request.setContent(new TestMultipartContentBuilder()
277+
.addPart('query', 'query Echo($arg: String) { echo(arg:$arg) }')
278+
.addPart('variables', '{"arg": "test"}')
279+
.build())
280+
281+
when:
282+
servlet.doPost(request, response)
283+
284+
then:
285+
response.getStatus() == STATUS_OK
286+
response.getContentType() == CONTENT_TYPE_JSON_UTF8
287+
getResponseContent().data.echo == "test"
288+
}
289+
270290
def "mutation over HTTP POST body returns data"() {
271291
setup:
272292
request.setContent(mapper.writeValueAsBytes([
@@ -332,4 +352,8 @@ class GraphQLServletSpec extends Specification {
332352
resp.data == null
333353
resp.errors != null
334354
}
355+
356+
private byte[] createContent(String data) {
357+
data.split('\\n').collect { it.replaceAll('^\\s+', '') }.join('\n').getBytes()
358+
}
335359
}

0 commit comments

Comments
 (0)