11package graphql .servlet ;
22
3- import graphql .servlet .internal .ApolloSubscriptionProtocolFactory ;
4- import graphql .servlet .internal .FallbackSubscriptionProtocolFactory ;
5- import graphql .servlet .internal .SubscriptionHandlerInput ;
6- import graphql .servlet .internal .SubscriptionProtocolFactory ;
7- import graphql .servlet .internal .SubscriptionProtocolHandler ;
8- import graphql .servlet .internal .WsSessionSubscriptions ;
3+ import graphql .servlet .internal .*;
94import org .slf4j .Logger ;
105import org .slf4j .LoggerFactory ;
116
12- import javax .websocket .CloseReason ;
13- import javax .websocket .Endpoint ;
14- import javax .websocket .EndpointConfig ;
15- import javax .websocket .HandshakeResponse ;
16- import javax .websocket .MessageHandler ;
17- import javax .websocket .Session ;
7+ import javax .websocket .*;
188import javax .websocket .server .HandshakeRequest ;
199import javax .websocket .server .ServerEndpointConfig ;
2010import java .io .IOException ;
@@ -44,8 +34,8 @@ public class GraphQLWebsocketServlet extends Endpoint {
4434
4535 static {
4636 allSubscriptionProtocols = Stream .concat (subscriptionProtocolFactories .stream (), Stream .of (fallbackSubscriptionProtocolFactory ))
47- .map (SubscriptionProtocolFactory ::getProtocol )
48- .collect (Collectors .toList ());
37+ .map (SubscriptionProtocolFactory ::getProtocol )
38+ .collect (Collectors .toList ());
4939 }
5040
5141 private final Map <Session , WsSessionSubscriptions > sessionSubscriptionCache = new HashMap <>();
@@ -57,7 +47,7 @@ public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocati
5747
5848 @ Override
5949 public void onOpen (Session session , EndpointConfig endpointConfig ) {
60-
50+ log . debug ( "Session opened: {}, {}" , session . getId (), endpointConfig );
6151 final WsSessionSubscriptions subscriptions = new WsSessionSubscriptions ();
6252 final HandshakeRequest request = (HandshakeRequest ) session .getUserProperties ().get (HANDSHAKE_REQUEST_KEY );
6353 final SubscriptionProtocolHandler subscriptionProtocolHandler = (SubscriptionProtocolHandler ) session .getUserProperties ().get (PROTOCOL_HANDLER_REQUEST_KEY );
@@ -82,7 +72,7 @@ public void onMessage(String text) {
8272 public void onClose (Session session , CloseReason closeReason ) {
8373 log .debug ("Session closed: {}, {}" , session .getId (), closeReason );
8474 WsSessionSubscriptions subscriptions = sessionSubscriptionCache .remove (session );
85- if (subscriptions != null ) {
75+ if (subscriptions != null ) {
8676 subscriptions .close ();
8777 }
8878 }
@@ -105,23 +95,25 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
10595 sec .getUserProperties ().put (HANDSHAKE_REQUEST_KEY , request );
10696
10797 List <String > protocol = request .getHeaders ().get (HandshakeRequest .SEC_WEBSOCKET_PROTOCOL );
108- if (protocol == null ) {
98+ if (protocol == null ) {
10999 protocol = Collections .emptyList ();
110100 }
111101
112102 SubscriptionProtocolFactory subscriptionProtocolFactory = getSubscriptionProtocolFactory (protocol );
113103 sec .getUserProperties ().put (PROTOCOL_HANDLER_REQUEST_KEY , subscriptionProtocolFactory .createHandler (subscriptionHandlerInput ));
114104
115- if (request .getHeaders ().get (HandshakeResponse .SEC_WEBSOCKET_ACCEPT ) != null ) {
105+ if (request .getHeaders ().get (HandshakeResponse .SEC_WEBSOCKET_ACCEPT ) != null ) {
116106 response .getHeaders ().put (HandshakeResponse .SEC_WEBSOCKET_ACCEPT , allSubscriptionProtocols );
117107 }
118- response .getHeaders ().put (HandshakeRequest .SEC_WEBSOCKET_PROTOCOL , Collections .singletonList (subscriptionProtocolFactory .getProtocol ()));
108+ if (!protocol .isEmpty ()) {
109+ response .getHeaders ().put (HandshakeRequest .SEC_WEBSOCKET_PROTOCOL , Collections .singletonList (subscriptionProtocolFactory .getProtocol ()));
110+ }
119111 }
120112
121113 private static SubscriptionProtocolFactory getSubscriptionProtocolFactory (List <String > accept ) {
122- for (String protocol : accept ) {
123- for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories ) {
124- if (subscriptionProtocolFactory .getProtocol ().equals (protocol )) {
114+ for (String protocol : accept ) {
115+ for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories ) {
116+ if (subscriptionProtocolFactory .getProtocol ().equals (protocol )) {
125117 return subscriptionProtocolFactory ;
126118 }
127119 }
0 commit comments