4949import java .security .AccessController ;
5050import java .security .PrivilegedAction ;
5151import java .util .ArrayList ;
52+ import java .util .Collections ;
5253import java .util .HashMap ;
5354import java .util .List ;
5455import java .util .Map ;
56+ import java .util .Objects ;
5557import java .util .Optional ;
5658import java .util .function .BiConsumer ;
5759import java .util .function .Consumer ;
60+ import java .util .function .Function ;
5861import java .util .stream .Collectors ;
5962
6063/**
@@ -77,20 +80,18 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra
7780 protected abstract Instrumentation getInstrumentation ();
7881 protected abstract Map <String , Object > transformVariables (GraphQLSchema schema , String query , Map <String , Object > variables );
7982
80- private final List <GraphQLOperationListener > operationListeners ;
81- private final List <GraphQLServletListener > servletListeners ;
83+ private final List <GraphQLServletListener > listeners ;
8284 private final ServletFileUpload fileUpload ;
8385
8486 private final RequestHandler getHandler ;
8587 private final RequestHandler postHandler ;
8688
8789 public GraphQLServlet () {
88- this (null , null , null );
90+ this (null , null );
8991 }
9092
91- public GraphQLServlet (List <GraphQLOperationListener > operationListeners , List <GraphQLServletListener > servletListeners , FileItemFactory fileItemFactory ) {
92- this .operationListeners = operationListeners != null ? new ArrayList <>(operationListeners ) : new ArrayList <>();
93- this .servletListeners = servletListeners != null ? new ArrayList <>(servletListeners ) : new ArrayList <>();
93+ public GraphQLServlet (List <GraphQLServletListener > listeners , FileItemFactory fileItemFactory ) {
94+ this .listeners = listeners != null ? new ArrayList <>(listeners ) : new ArrayList <>();
9495 this .fileUpload = new ServletFileUpload (fileItemFactory != null ? fileItemFactory : new DiskFileItemFactory ());
9596
9697 this .getHandler = (request , response ) -> {
@@ -188,20 +189,12 @@ public GraphQLServlet(List<GraphQLOperationListener> operationListeners, List<Gr
188189 };
189190 }
190191
191- public void addOperationListener ( GraphQLOperationListener operationListener ) {
192- operationListeners .add (operationListener );
192+ public void addListener ( GraphQLServletListener servletListener ) {
193+ listeners .add (servletListener );
193194 }
194195
195- public void removeOperationListener (GraphQLOperationListener operationListener ) {
196- operationListeners .remove (operationListener );
197- }
198-
199- public void addServletListener (GraphQLServletListener servletListener ) {
200- servletListeners .add (servletListener );
201- }
202-
203- public void removeServletListener (GraphQLServletListener servletListener ) {
204- servletListeners .remove (servletListener );
196+ public void removeListener (GraphQLServletListener servletListener ) {
197+ listeners .remove (servletListener );
205198 }
206199
207200 @ Override
@@ -225,16 +218,18 @@ public String executeQuery(String query) {
225218 }
226219
227220 private void doRequest (HttpServletRequest request , HttpServletResponse response , RequestHandler handler ) {
221+
222+ List <GraphQLServletListener .RequestCallback > requestCallbacks = runListeners (l -> l .onRequest (request , response ));
223+
228224 try {
229- runListeners (servletListeners , l -> l .onStart (request , response ));
230225 handler .handle (request , response );
231-
226+ runCallbacks ( requestCallbacks , c -> c . onSuccess ( request , response ));
232227 } catch (Throwable t ) {
233228 response .setStatus (500 );
234229 log .error ("Error executing GraphQL request!" , t );
235- runListeners ( servletListeners , l -> l .onError (request , response , t ));
230+ runCallbacks ( requestCallbacks , c -> c .onError (request , response , t ));
236231 } finally {
237- runListeners ( servletListeners , l -> l .onFinally (request , response ));
232+ runCallbacks ( requestCallbacks , c -> c .onFinally (request , response ));
238233 }
239234 }
240235
@@ -276,7 +271,7 @@ private void query(String query, String operationName, Map<String, Object> varia
276271 return null ;
277272 });
278273 } else {
279- runListeners ( operationListeners , l -> runListener ( l , it -> it . beforeGraphQLOperation (context , operationName , query , variables ) ));
274+ List < GraphQLServletListener . OperationCallback > operationCallbacks = runListeners ( l -> l . onOperation (context , operationName , query , variables ));
280275
281276 final ExecutionResult executionResult = newGraphQL (schema ).execute (query , operationName , context , transformVariables (schema , query , variables ));
282277 final List <GraphQLError > errors = executionResult .getErrors ();
@@ -289,10 +284,12 @@ private void query(String query, String operationName, Map<String, Object> varia
289284 resp .getWriter ().write (response );
290285
291286 if (errorsPresent (errors )) {
292- runListeners ( operationListeners , l -> l . onFailedGraphQLOperation (context , operationName , query , variables , data , errors ));
287+ runCallbacks ( operationCallbacks , c -> c . onError (context , operationName , query , variables , data , errors ));
293288 } else {
294- runListeners ( operationListeners , l -> l . onSuccessfulGraphQLOperation (context , operationName , query , variables , data ));
289+ runCallbacks ( operationCallbacks , c -> c . onSuccess (context , operationName , query , variables , data ));
295290 }
291+
292+ runCallbacks (operationCallbacks , c -> c .onFinally (context , operationName , query , variables , data ));
296293 }
297294 }
298295
@@ -333,21 +330,38 @@ protected boolean isClientError(GraphQLError error) {
333330 return error instanceof InvalidSyntaxError || error instanceof ValidationError ;
334331 }
335332
336- private <T > void runListeners ( List <T > listeners , Consumer <? super T > action ) {
337- if (listeners ! = null ) {
338- listeners . forEach ( l -> runListener ( l , action ) );
333+ private <R > List <R > runListeners ( Function <? super GraphQLServletListener , R > action ) {
334+ if (listeners = = null ) {
335+ return Collections . emptyList ( );
339336 }
337+
338+ return listeners .stream ()
339+ .map (listener -> {
340+ try {
341+ return action .apply (listener );
342+ } catch (Throwable t ) {
343+ log .error ("Error running listener: {}" , listener , t );
344+ return null ;
345+ }
346+ })
347+ .filter (Objects ::nonNull )
348+ .collect (Collectors .toList ());
349+ }
350+
351+ private <T > void runCallbacks (List <T > callbacks , Consumer <T > action ) {
352+ callbacks .forEach (callback -> {
353+ try {
354+ action .accept (callback );
355+ } catch (Throwable t ) {
356+ log .error ("Error running callback: {}" , callback , t );
357+ }
358+ });
340359 }
341360
342361 /**
343362 * Don't let listener errors escape to the client.
344363 */
345- private <T > void runListener (T listener , Consumer <? super T > action ) {
346- try {
347- action .accept (listener );
348- } catch (Throwable t ) {
349- log .error ("Error running listener: {}" , listener .getClass ().getName (), t );
350- }
364+ private <T , R > void runListener (T listener , Function <? super T , ? super R > action ) {
351365 }
352366
353367 protected static class VariablesDeserializer extends JsonDeserializer <Map <String , Object >> {
0 commit comments