Skip to content

Commit 843fa42

Browse files
committed
Added callbacks for remaining message types
1 parent bb9d1d1 commit 843fa42

File tree

6 files changed

+71
-36
lines changed

6 files changed

+71
-36
lines changed

src/main/java/graphql/servlet/GraphQLWebsocketServlet.java

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,10 @@
11
package graphql.servlet;
22

3-
import graphql.servlet.internal.ApolloSubscriptionProtocolFactory;
4-
import graphql.servlet.internal.FallbackSubscriptionProtocolFactory;
5-
import graphql.servlet.internal.SubscriptionHandlerInput;
6-
import graphql.servlet.internal.SubscriptionProtocolFactory;
7-
import graphql.servlet.internal.SubscriptionProtocolHandler;
8-
import graphql.servlet.internal.WsSessionSubscriptions;
3+
import graphql.servlet.internal.*;
94
import org.slf4j.Logger;
105
import org.slf4j.LoggerFactory;
116

12-
import javax.websocket.CloseReason;
13-
import javax.websocket.Endpoint;
14-
import javax.websocket.EndpointConfig;
15-
import javax.websocket.HandshakeResponse;
16-
import javax.websocket.MessageHandler;
17-
import javax.websocket.Session;
7+
import javax.websocket.*;
188
import javax.websocket.server.HandshakeRequest;
199
import javax.websocket.server.ServerEndpointConfig;
2010
import java.io.IOException;
@@ -44,8 +34,8 @@ public class GraphQLWebsocketServlet extends Endpoint {
4434

4535
static {
4636
allSubscriptionProtocols = Stream.concat(subscriptionProtocolFactories.stream(), Stream.of(fallbackSubscriptionProtocolFactory))
47-
.map(SubscriptionProtocolFactory::getProtocol)
48-
.collect(Collectors.toList());
37+
.map(SubscriptionProtocolFactory::getProtocol)
38+
.collect(Collectors.toList());
4939
}
5040

5141
private final Map<Session, WsSessionSubscriptions> sessionSubscriptionCache = new HashMap<>();
@@ -57,7 +47,7 @@ public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocati
5747

5848
@Override
5949
public void onOpen(Session session, EndpointConfig endpointConfig) {
60-
50+
log.debug("Session opened: {}, {}", session.getId(), endpointConfig);
6151
final WsSessionSubscriptions subscriptions = new WsSessionSubscriptions();
6252
final HandshakeRequest request = (HandshakeRequest) session.getUserProperties().get(HANDSHAKE_REQUEST_KEY);
6353
final SubscriptionProtocolHandler subscriptionProtocolHandler = (SubscriptionProtocolHandler) session.getUserProperties().get(PROTOCOL_HANDLER_REQUEST_KEY);
@@ -82,7 +72,7 @@ public void onMessage(String text) {
8272
public void onClose(Session session, CloseReason closeReason) {
8373
log.debug("Session closed: {}, {}", session.getId(), closeReason);
8474
WsSessionSubscriptions subscriptions = sessionSubscriptionCache.remove(session);
85-
if(subscriptions != null) {
75+
if (subscriptions != null) {
8676
subscriptions.close();
8777
}
8878
}
@@ -105,23 +95,25 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
10595
sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request);
10696

10797
List<String> protocol = request.getHeaders().get(HandshakeRequest.SEC_WEBSOCKET_PROTOCOL);
108-
if(protocol == null) {
98+
if (protocol == null) {
10999
protocol = Collections.emptyList();
110100
}
111101

112102
SubscriptionProtocolFactory subscriptionProtocolFactory = getSubscriptionProtocolFactory(protocol);
113103
sec.getUserProperties().put(PROTOCOL_HANDLER_REQUEST_KEY, subscriptionProtocolFactory.createHandler(subscriptionHandlerInput));
114104

115-
if(request.getHeaders().get(HandshakeResponse.SEC_WEBSOCKET_ACCEPT) != null) {
105+
if (request.getHeaders().get(HandshakeResponse.SEC_WEBSOCKET_ACCEPT) != null) {
116106
response.getHeaders().put(HandshakeResponse.SEC_WEBSOCKET_ACCEPT, allSubscriptionProtocols);
117107
}
118-
response.getHeaders().put(HandshakeRequest.SEC_WEBSOCKET_PROTOCOL, Collections.singletonList(subscriptionProtocolFactory.getProtocol()));
108+
if (!protocol.isEmpty()) {
109+
response.getHeaders().put(HandshakeRequest.SEC_WEBSOCKET_PROTOCOL, Collections.singletonList(subscriptionProtocolFactory.getProtocol()));
110+
}
119111
}
120112

121113
private static SubscriptionProtocolFactory getSubscriptionProtocolFactory(List<String> accept) {
122-
for(String protocol: accept) {
123-
for(SubscriptionProtocolFactory subscriptionProtocolFactory: subscriptionProtocolFactories) {
124-
if(subscriptionProtocolFactory.getProtocol().equals(protocol)) {
114+
for (String protocol : accept) {
115+
for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories) {
116+
if (subscriptionProtocolFactory.getProtocol().equals(protocol)) {
125117
return subscriptionProtocolFactory;
126118
}
127119
}

src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
import com.fasterxml.jackson.annotation.JsonInclude;
55
import com.fasterxml.jackson.annotation.JsonValue;
66
import graphql.ExecutionResult;
7-
import org.reactivestreams.Publisher;
8-
import org.reactivestreams.Subscriber;
9-
import org.reactivestreams.Subscription;
107
import org.slf4j.Logger;
118
import org.slf4j.LoggerFactory;
129

@@ -15,7 +12,10 @@
1512
import java.io.IOException;
1613
import java.util.HashMap;
1714
import java.util.Map;
18-
import java.util.concurrent.atomic.AtomicReference;
15+
16+
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE;
17+
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_DATA;
18+
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_ERROR;
1919

2020
/**
2121
* @author Andrew Potter
@@ -69,7 +69,22 @@ private void handleSubscriptionStart(Session session, WsSessionSubscriptions sub
6969
return;
7070
}
7171

72-
subscribe(executionResult, subscriptions, id);
72+
subscribe(session, executionResult, subscriptions, id);
73+
}
74+
75+
@Override
76+
protected void sendDataMessage(Session session, String id, Object payload) {
77+
sendMessage(session, GQL_DATA, id, payload);
78+
}
79+
80+
@Override
81+
protected void sendErrorMessage(Session session, String id) {
82+
sendMessage(session, GQL_ERROR, id);
83+
}
84+
85+
@Override
86+
protected void sendCompleteMessage(Session session, String id) {
87+
sendMessage(session, GQL_COMPLETE, id);
7388
}
7489

7590
private void sendMessage(Session session, OperationMessage.Type type, String id) {

src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolFactory.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package graphql.servlet.internal;
22

3-
import graphql.servlet.GraphQLInvocationInputFactory;
4-
import graphql.servlet.GraphQLObjectMapper;
5-
import graphql.servlet.GraphQLQueryInvoker;
6-
73
/**
84
* @author Andrew Potter
95
*/

src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import javax.websocket.Session;
44
import javax.websocket.server.HandshakeRequest;
5+
import java.io.IOException;
56

67
/**
78
* @author Andrew Potter
@@ -16,8 +17,26 @@ public FallbackSubscriptionProtocolHandler(SubscriptionHandlerInput subscription
1617

1718
@Override
1819
public void onMessage(HandshakeRequest request, Session session, WsSessionSubscriptions subscriptions, String text) throws Exception {
19-
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().serializeResultAsJson(
20-
input.getQueryInvoker().query(input.getInvocationInputFactory().create(input.getGraphQLObjectMapper().readGraphQLRequest(text), request))
21-
));
20+
subscribe(session, input.getQueryInvoker().query(input.getInvocationInputFactory().create(
21+
input.getGraphQLObjectMapper().readGraphQLRequest(text))), subscriptions, session.getId());
22+
}
23+
24+
@Override
25+
protected void sendDataMessage(Session session, String id, Object payload) {
26+
try {
27+
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(payload));
28+
} catch (IOException e) {
29+
throw new RuntimeException("Error sending subscription response", e);
30+
}
31+
}
32+
33+
@Override
34+
protected void sendErrorMessage(Session session, String id) {
35+
36+
}
37+
38+
@Override
39+
protected void sendCompleteMessage(Session session, String id) {
40+
2241
}
2342
}

src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ public abstract class SubscriptionProtocolHandler {
2222

2323
public abstract void onMessage(HandshakeRequest request, Session session, WsSessionSubscriptions subscriptions, String text) throws Exception;
2424

25-
protected void subscribe(ExecutionResult executionResult, WsSessionSubscriptions subscriptions, String id) {
25+
protected abstract void sendDataMessage(Session session, String id, Object payload);
26+
27+
protected abstract void sendErrorMessage(Session session, String id);
28+
29+
protected abstract void sendCompleteMessage(Session session, String id);
30+
31+
protected void subscribe(Session session, ExecutionResult executionResult, WsSessionSubscriptions subscriptions, String id) {
2632
final Object data = executionResult.getData();
2733

2834
if (data instanceof Publisher) {
@@ -43,19 +49,21 @@ public void onNext(ExecutionResult executionResult) {
4349
subscriptionReference.get().request(1);
4450
Map<String, Object> result = new HashMap<>();
4551
result.put("data", executionResult.getData());
46-
// sendMessage(session, ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_DATA, id, result);
52+
sendDataMessage(session, id, result);
4753
}
4854

4955
@Override
5056
public void onError(Throwable throwable) {
5157
log.error("Subscription error", throwable);
5258
subscriptions.cancel(id);
59+
sendErrorMessage(session, id);
5360
// sendMessage(session, ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_ERROR, id);
5461
}
5562

5663
@Override
5764
public void onComplete() {
5865
subscriptions.cancel(id);
66+
sendCompleteMessage(session, id);
5967
// sendMessage(session, ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE, id);
6068
}
6169
});

src/test/groovy/graphql/servlet/TestMultipartPart.groovy

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class TestMultipartContentBuilder {
3434
return name
3535
}
3636

37+
@Override
38+
String getSubmittedFileName() {
39+
return name
40+
}
41+
3742
@Override
3843
long getSize() {
3944
return content.getBytes().length

0 commit comments

Comments
 (0)