diff --git a/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java index bd6d964ce6a..dc3aa21b0ee 100644 --- a/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java +++ b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java @@ -21,6 +21,7 @@ import java.time.Duration; import java.util.LinkedList; import java.util.List; +import java.util.Random; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; @@ -50,6 +51,8 @@ class WebSocketTest { private final int port; private final HttpClient client; + private volatile boolean isNormalClose = true; + WebSocketTest(WebServer server) { port = server.port(); client = HttpClient.newBuilder() @@ -69,11 +72,13 @@ void resetClosed() { } @AfterEach - void checkClosed() { - EchoService.CloseInfo closeInfo = service.closeInfo(); - assertThat(closeInfo, notNullValue()); - assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE)); - assertThat(closeInfo.reason(), is("normal")); + void checkNormalClose() { + if (isNormalClose) { + EchoService.CloseInfo closeInfo = service.closeInfo(); + assertThat(closeInfo, notNullValue()); + assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE)); + assertThat(closeInfo.reason(), is("normal")); + } } @Test @@ -91,7 +96,7 @@ void testOnce() throws Exception { ws.sendText("Hello", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - List results = listener.getResults(); + List results = listener.results().received; assertThat(results, contains("Hello")); } @@ -107,7 +112,7 @@ void testMulti() throws Exception { ws.sendText("First", true).get(5, TimeUnit.SECONDS); ws.sendText("Second", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - assertThat(listener.getResults(), contains("First", "Second")); + assertThat(listener.results().received, contains("First", "Second")); } @Test @@ -124,13 +129,63 @@ void testFragmentedAndMulti() throws Exception { ws.sendText("Third", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - assertThat(listener.getResults(), contains("FirstSecond", "Third")); + assertThat(listener.results().received, contains("FirstSecond", "Third")); + } + + /** + * Tests sending long text messages. Note that any message longer than 16K + * will be chunked into 16K pieces by the JDK client. + * + * @throws Exception if an error occurs + */ + @Test + void testLongTextMessages() throws Exception { + TestListener listener = new TestListener(); + + java.net.http.WebSocket ws = client.newWebSocketBuilder() + .buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener) + .get(5, TimeUnit.SECONDS); + ws.request(10); + + String s100 = randomString(100); // less than one byte + ws.sendText(s100, true).get(5, TimeUnit.SECONDS); + String s10000 = randomString(10000); // less than two bytes + ws.sendText(s10000, true).get(5, TimeUnit.SECONDS); + ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); + + assertThat(listener.results().received, contains(s100, s10000)); + } + + /** + * Test sending a single text message that will fit into a single JDK client frame + * of 16K but exceeds max-frame-length set in application.yaml for the server. + * + * @throws Exception if an error occurs + */ + @Test + void testTooLongTextMessage() throws Exception { + TestListener listener = new TestListener(); + + java.net.http.WebSocket ws = client.newWebSocketBuilder() + .buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener) + .get(5, TimeUnit.SECONDS); + ws.request(10); + + String s10001 = randomString(10001); // over the limit of 10000 + ws.sendText(s10001, true).get(5, TimeUnit.SECONDS); + assertThat(listener.results().statusCode, is(1009)); + assertThat(listener.results().reason, is("Payload too large")); + isNormalClose = false; } private static class TestListener implements java.net.http.WebSocket.Listener { + + record Results(int statusCode, String reason, List received) { + } + final List received = new LinkedList<>(); final List buffered = new LinkedList<>(); - private final CompletableFuture> response = new CompletableFuture<>(); + private final CompletableFuture response = new CompletableFuture<>(); @Override public void onOpen(java.net.http.WebSocket webSocket) { @@ -151,12 +206,21 @@ public CompletionStage onText(java.net.http.WebSocket webSocket, CharSequence @Override public CompletionStage onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) { - response.complete(received); + response.complete(new Results(statusCode, reason, received)); return null; } - List getResults() throws ExecutionException, InterruptedException, TimeoutException { + Results results() throws ExecutionException, InterruptedException, TimeoutException { return response.get(10, TimeUnit.SECONDS); } } + + private static String randomString(int length) { + int leftLimit = 97; // letter 'a' + int rightLimit = 122; // letter 'z' + return new Random().ints(leftLimit, rightLimit + 1) + .limit(length) + .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append) + .toString(); + } } diff --git a/webserver/tests/websocket/src/test/resources/application.yaml b/webserver/tests/websocket/src/test/resources/application.yaml new file mode 100644 index 00000000000..72eb09e48f8 --- /dev/null +++ b/webserver/tests/websocket/src/test/resources/application.yaml @@ -0,0 +1,20 @@ +# +# Copyright (c) 2023 Oracle and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +server: + protocols: + websocket: + max-frame-length: 10000 \ No newline at end of file diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java index 02632a4da12..7ba935890a2 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java @@ -56,4 +56,13 @@ default String type() { @ConfiguredOption(WsUpgradeProvider.CONFIG_NAME) @Override String name(); + + /** + * Max WebSocket frame size supported by the server on a read operation. + * Default is 1 MB. + * + * @return max frame size to read + */ + @ConfiguredOption(WsConnection.MAX_FRAME_LENGTH) + int maxFrameLength(); } diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java index e2b69fa8c77..32d3f442d88 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java @@ -45,11 +45,14 @@ public class WsConnection implements ServerConnection, WsSession { private static final System.Logger LOGGER = System.getLogger(WsConnection.class.getName()); + static final String MAX_FRAME_LENGTH = "1048576"; + private final ConnectionContext ctx; private final HttpPrologue prologue; private final Headers upgradeHeaders; private final String wsKey; private final WsListener listener; + private final WsConfig wsConfig; private final BufferData sendBuffer = BufferData.growing(1024); private final DataReader dataReader; @@ -75,6 +78,13 @@ private WsConnection(ConnectionContext ctx, this.listener = wsRoute.listener(); this.dataReader = ctx.dataReader(); this.lastRequestTimestamp = DateTime.timestamp(); + this.wsConfig = (WsConfig) ctx.listenerContext() + .config() + .protocols() + .stream() + .filter(p -> p instanceof WsConfig) + .findFirst() + .orElseThrow(() -> new InternalError("Unable to find WebSocket config")); } /** @@ -243,8 +253,7 @@ private boolean processFrame(ClientWsFrame frame) { private ClientWsFrame readFrame() { try { - // TODO check may payload size, danger of oom - return ClientWsFrame.read(ctx, dataReader, Integer.MAX_VALUE); + return ClientWsFrame.read(ctx, dataReader, wsConfig.maxFrameLength()); } catch (DataReader.InsufficientDataAvailableException e) { throw new CloseConnectionException("Socket closed by the other side", e); } catch (WsCloseException e) { @@ -276,9 +285,18 @@ private WsSession send(ServerWsFrame frame) { opCodeFull |= usedCode.code(); sendBuffer.write(opCodeFull); - if (frame.payloadLength() < 126) { - sendBuffer.write((int) frame.payloadLength()); - // TODO finish other options (payload longer than 126 bytes) + long length = frame.payloadLength(); + if (length < 126) { + sendBuffer.write((int) length); + } else if (length < 1 << 16) { + sendBuffer.write(126); + sendBuffer.write((int) (length >>> 8)); + sendBuffer.write((int) (length & 0xFF)); + } else { + sendBuffer.write(127); + for (int i = 56; i >= 0; i -= 8){ + sendBuffer.write((int) (length >>> i) & 0xFF); + } } sendBuffer.write(frame.payloadData()); ctx.dataWriter().writeNow(sendBuffer); diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java index 4f9573ca07c..ffc7aea7698 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java @@ -25,7 +25,7 @@ */ public class WsUpgradeProvider implements Http1UpgradeProvider { /** - * HTTP/2 server connection provider configuration node name. + * WebSocket server connection provider configuration node name. */ protected static final String CONFIG_NAME = "websocket"; diff --git a/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java b/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java index 6bd28fb53df..35671337f65 100644 --- a/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java +++ b/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java @@ -109,7 +109,7 @@ protected static FrameHeader readFrameHeader(DataReader reader, int maxFrameLeng throw new WsCloseException("Payload too large", WsCloseCodes.TOO_BIG); } - return new FrameHeader(opCode, fin, masked, length); + return new FrameHeader(opCode, fin, masked, (int) frameLength); } protected static BufferData readPayload(DataReader reader, FrameHeader header) {