11package graphql .kickstart .servlet ;
22
33import static java .util .Arrays .asList ;
4+ import static java .util .Collections .emptyList ;
45import static java .util .Collections .singletonList ;
56import static java .util .stream .Collectors .toList ;
67
@@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
6566 private final AtomicBoolean isShuttingDown = new AtomicBoolean (false );
6667 private final AtomicBoolean isShutDown = new AtomicBoolean (false );
6768 private final Object cacheLock = new Object ();
69+ private final List <String > allowedOrigins ;
6870
6971 public GraphQLWebsocketServlet (GraphQLConfiguration configuration ) {
7072 this (configuration , null );
@@ -77,21 +79,23 @@ public GraphQLWebsocketServlet(
7779 configuration .getGraphQLInvoker (),
7880 configuration .getInvocationInputFactory (),
7981 configuration .getObjectMapper (),
80- connectionListeners );
82+ connectionListeners ,
83+ configuration .getAllowedOrigins ());
8184 }
8285
8386 public GraphQLWebsocketServlet (
8487 GraphQLInvoker graphQLInvoker ,
8588 GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
8689 GraphQLObjectMapper graphQLObjectMapper ) {
87- this (graphQLInvoker , invocationInputFactory , graphQLObjectMapper , null );
90+ this (graphQLInvoker , invocationInputFactory , graphQLObjectMapper , null , emptyList () );
8891 }
8992
9093 public GraphQLWebsocketServlet (
9194 GraphQLInvoker graphQLInvoker ,
9295 GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
9396 GraphQLObjectMapper graphQLObjectMapper ,
94- Collection <SubscriptionConnectionListener > connectionListeners ) {
97+ Collection <SubscriptionConnectionListener > connectionListeners ,
98+ List <String > allowedOrigins ) {
9599 List <ApolloSubscriptionConnectionListener > listeners = new ArrayList <>();
96100 if (connectionListeners != null ) {
97101 connectionListeners .stream ()
@@ -114,12 +118,10 @@ public GraphQLWebsocketServlet(
114118 Stream .of (fallbackSubscriptionProtocolFactory ))
115119 .map (SubscriptionProtocolFactory ::getProtocol )
116120 .collect (toList ());
121+ this .allowedOrigins = allowedOrigins ;
117122 }
118123
119124 public GraphQLWebsocketServlet (
120- GraphQLInvoker graphQLInvoker ,
121- GraphQLSubscriptionInvocationInputFactory invocationInputFactory ,
122- GraphQLObjectMapper graphQLObjectMapper ,
123125 List <SubscriptionProtocolFactory > subscriptionProtocolFactory ,
124126 SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory ) {
125127
@@ -132,6 +134,8 @@ public GraphQLWebsocketServlet(
132134 Stream .of (fallbackSubscriptionProtocolFactory ))
133135 .map (SubscriptionProtocolFactory ::getProtocol )
134136 .collect (toList ());
137+
138+ this .allowedOrigins = emptyList ();
135139 }
136140
137141 @ Override
@@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) {
202206 }
203207 }
204208
209+ public boolean checkOrigin (String originHeaderValue ) {
210+ if (originHeaderValue == null || originHeaderValue .isBlank ()) {
211+ return allowedOrigins .isEmpty ();
212+ }
213+ String originToCheck = trimTrailingSlash (originHeaderValue );
214+ if (!allowedOrigins .isEmpty ()) {
215+ if (allowedOrigins .contains ("*" )) {
216+ return true ;
217+ }
218+ return allowedOrigins .stream ()
219+ .map (this ::trimTrailingSlash )
220+ .anyMatch (originToCheck ::equalsIgnoreCase );
221+ }
222+ return true ;
223+ }
224+
225+ private String trimTrailingSlash (String origin ) {
226+ return (origin .endsWith ("/" ) ? origin .substring (0 , origin .length () - 1 ) : origin );
227+ }
228+
205229 public void modifyHandshake (
206230 ServerEndpointConfig sec , HandshakeRequest request , HandshakeResponse response ) {
207231 sec .getUserProperties ().put (HANDSHAKE_REQUEST_KEY , request );
0 commit comments