Skip to content

Commit

Permalink
Enhance security of the Complete message for GraphQL over WebSocket P…
Browse files Browse the repository at this point in the history
…rotocol

Motivation:
When constructing the `Complete` message in the GraphQL over WebSocket Protocol,
appending a string directly to the ID can lead to malformed messages if the input is manipulated by the user.
This vulnerability could potentially allow users to create arbitrary responses by inputting malformed IDs.

Modification:
- Serialize the `Complete` message using a map instead of concatenating strings directly.

Result:
- The `Complete` message in the GraphQL over WebSocket Protocol is now constructed securely.
  • Loading branch information
minwoox committed Mar 26, 2024
1 parent 9908022 commit 96a9a1e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,14 @@ public ErrorClassification getErrorType() {
}

private static void writeComplete(WebSocketWriter out, String operationId) {
out.tryWrite("{\"type\":\"complete\",\"id\":\"" + operationId + "\"}");
try {
final String json = serializeToJson(ImmutableMap.of("type", "complete", "id", operationId));
out.tryWrite(json);
} catch (JsonProcessingException e) {
logger.warn("Unexpected exception while serializing complete event. operationId: {}",

Check warning on line 408 in graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWSSubProtocol.java

View check run for this annotation

Codecov / codecov/patch

graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWSSubProtocol.java#L407-L408

Added lines #L407 - L408 were not covered by tests
operationId, e);
out.close(e);

Check warning on line 410 in graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWSSubProtocol.java

View check run for this annotation

Codecov / codecov/patch

graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWSSubProtocol.java#L410

Added line #L410 was not covered by tests
}
}

private static final class GraphqlWebSocketCloseException extends Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
Expand Down Expand Up @@ -92,7 +93,7 @@ private static DataFetcher<Publisher<String>> notCompleting() {

@BeforeEach
void beforeEach() {
streamRef = new AtomicReference<>(StreamMessage.streaming());
streamRef = new AtomicReference<>(StreamMessage.streaming());
}

@Test
Expand All @@ -110,100 +111,47 @@ void testSubscriptionOverHttp() {
.isEqualTo("Use GraphQL over WebSocket for subscription");
}

@Test
void testSubscriptionOverWebSocketHttp1() {
testWebSocket(SessionProtocol.H1C);
}

@Test
void testSubscriptionOverWebSocketHttp2() {
testWebSocket(SessionProtocol.H2C);
}

private void testWebSocket(SessionProtocol sessionProtocol) {
@CsvSource({ "H1C", "H2C" })
@ParameterizedTest
void testSubscriptionOverWebSocketHttp1(SessionProtocol sessionProtocol) {
final WebSocketClient webSocketClient =
WebSocketClient.builder(server.uri(sessionProtocol, SerializationFormat.WS))
.subprotocols("graphql-transport-ws")
.build();
final CompletableFuture<WebSocketSession> future = webSocketClient.connect("/graphql");

final WebSocketSession webSocketSession = future.join();

final WebSocketSession webSocketSession = webSocketClient.connect("/graphql").join();
final WebSocketWriter outbound = webSocketSession.outbound();

final List<String> receivedEvents = new ArrayList<>();
//noinspection ReactiveStreamsSubscriberImplementation
webSocketSession.inbound().subscribe(new Subscriber<WebSocketFrame>() {
@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(WebSocketFrame webSocketFrame) {
if (webSocketFrame.type() == WebSocketFrameType.TEXT) {
receivedEvents.add(webSocketFrame.text());
}
}

@Override
public void onError(Throwable t) {
}

@Override
public void onComplete() {
}
});
webSocketSession.inbound().subscribe(new TestSubscriber(receivedEvents));

outbound.write("{\"type\":\"ping\"}");
outbound.write("{\"type\":\"connection_init\"}");
outbound.write(
"{\"id\":\"1\",\"type\":\"subscribe\",\"payload\":{\"query\":\"subscription {hello}\"}}");

await().until(() -> receivedEvents.size() >= 3);
await().until(() -> receivedEvents.size() >= 4);
assertThatJson(receivedEvents.get(0)).node("type").isEqualTo("pong");
assertThatJson(receivedEvents.get(1)).node("type").isEqualTo("connection_ack");
assertThatJson(receivedEvents.get(2))
.node("type").isEqualTo("next")
.node("id").isEqualTo("\"1\"")
.node("payload.data.hello").isEqualTo("Armeria");
assertThatJson(receivedEvents.get(3))
.node("type").isEqualTo("complete")
.node("id").isEqualTo("\"1\"");
}

@Test
void testSubscriptionCleanedUpWhenClosed() throws Exception {
void testSubscriptionCleanedUpWhenClosed() {
final WebSocketClient webSocketClient =
WebSocketClient.builder(server.uri(SessionProtocol.H1C, SerializationFormat.WS))
.subprotocols("graphql-transport-ws")
.build();
final CompletableFuture<WebSocketSession> future = webSocketClient.connect("/graphql");

final WebSocketSession webSocketSession = future.join();

final WebSocketSession webSocketSession = webSocketClient.connect("/graphql").join();
final WebSocketWriter outbound = webSocketSession.outbound();

final List<String> receivedEvents = new ArrayList<>();
//noinspection ReactiveStreamsSubscriberImplementation
webSocketSession.inbound().subscribe(new Subscriber<WebSocketFrame>() {
@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(WebSocketFrame webSocketFrame) {
if (webSocketFrame.type() == WebSocketFrameType.TEXT) {
receivedEvents.add(webSocketFrame.text());
}
}

@Override
public void onError(Throwable t) {
}

@Override
public void onComplete() {
}
});
webSocketSession.inbound().subscribe(new TestSubscriber(receivedEvents));

outbound.write("{\"type\":\"ping\"}");
outbound.write("{\"type\":\"connection_init\"}");
Expand All @@ -218,4 +166,61 @@ public void onComplete() {
.isInstanceOf(CompletionException.class)
.hasCauseInstanceOf(CancelledSubscriptionException.class);
}

@Test
void completeIdIsnt() {
final WebSocketClient webSocketClient =
WebSocketClient.builder(server.uri(SessionProtocol.H1C, SerializationFormat.WS))
.subprotocols("graphql-transport-ws")
.build();
final WebSocketSession webSocketSession = webSocketClient.connect("/graphql").join();
final WebSocketWriter outbound = webSocketSession.outbound();

final List<String> receivedEvents = new ArrayList<>();
webSocketSession.inbound().subscribe(new TestSubscriber(receivedEvents));

outbound.write("{\"type\":\"connection_init\"}");
outbound.write(
"{\"id\":\"1\\\",\\\"hehe\\\":\\\"hehe\", " +
"\"type\":\"subscribe\",\"payload\":{\"query\":\"subscription {hello}\"}}");

await().until(() -> receivedEvents.size() >= 3);
assertThatJson(receivedEvents.get(0)).node("type").isEqualTo("connection_ack");
assertThatJson(receivedEvents.get(1))
.node("type").isEqualTo("next")
.node("id").isEqualTo("\"1\\\",\\\"hehe\\\":\\\"hehe\"")
.node("payload.data.hello").isEqualTo("Armeria");
assertThatJson(receivedEvents.get(2))
.node("type").isEqualTo("complete")
// Before #5531, "hehe" was set as another property.
.node("id").isEqualTo("\"1\\\",\\\"hehe\\\":\\\"hehe\"");
}

@SuppressWarnings("ReactiveStreamsSubscriberImplementation")
private static class TestSubscriber implements Subscriber<WebSocketFrame> {

private final List<String> receivedEvents;

TestSubscriber(List<String> receivedEvents) {
this.receivedEvents = receivedEvents;
}

@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(WebSocketFrame webSocketFrame) {
if (webSocketFrame.type() == WebSocketFrameType.TEXT) {
receivedEvents.add(webSocketFrame.text());
}
}

@Override
public void onError(Throwable t) {}

@Override
public void onComplete() {}
}
}

0 comments on commit 96a9a1e

Please sign in to comment.