diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java b/core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java index e639c73a6bd..ea2afab09fd 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java @@ -115,7 +115,9 @@ public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req case CONNECT: return upgradeHttp2(ctx, req); default: - return WebSocketUpgradeResult.ofFailure(HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + return WebSocketUpgradeResult.ofFailure(httpResponse); } } @@ -147,26 +149,19 @@ public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req */ private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpRequest req) throws Exception { if (!ctx.sessionProtocol().isExplicitHttp1()) { - final HttpResponse fallbackResponse; - if (fallbackService != null) { - fallbackResponse = fallbackService.serve(ctx, req); - } else { - fallbackResponse = HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED); - } - return WebSocketUpgradeResult.ofFailure(fallbackResponse); + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + return WebSocketUpgradeResult.ofFailure(httpResponse); } final RequestHeaders headers = req.headers(); if (!isHttp1WebSocketUpgradeRequest(headers)) { - final HttpResponse fallbackResponse; - if (fallbackService != null) { - fallbackResponse = fallbackService.serve(ctx, req); - } else { - fallbackResponse = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, - "The upgrade header must contain:\n" + - " Upgrade: websocket\n" + - " Connection: Upgrade"); - } - return WebSocketUpgradeResult.ofFailure(fallbackResponse); + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of( + HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, + "The upgrade header must contain:\n" + + " Upgrade: websocket\n" + + " Connection: Upgrade")); + return WebSocketUpgradeResult.ofFailure(httpResponse); } HttpResponse invalidResponse = checkOrigin(ctx, headers); @@ -188,6 +183,15 @@ private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpReque return WebSocketUpgradeResult.ofSuccess(); } + private HttpResponse failOrFallback(ServiceRequestContext ctx, HttpRequest req, + Supplier invalidResponse) throws Exception { + if (fallbackService != null) { + return fallbackService.serve(ctx, req); + } else { + return invalidResponse.get(); + } + } + private void maybeAddSubprotocol(RequestHeaders headers, ResponseHeadersBuilder responseHeadersBuilder) { final String subprotocols = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, ""); @@ -245,8 +249,8 @@ private WebSocketUpgradeResult upgradeHttp2(ServiceRequestContext ctx, HttpReque fallbackResponse = fallbackService.serve(ctx, req); } else { fallbackResponse = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, - "The upgrade header must contain:\n" + - " :protocol = websocket"); + "The upgrade header must contain:\n" + + " :protocol = websocket"); } return WebSocketUpgradeResult.ofFailure(fallbackResponse); } @@ -273,120 +277,120 @@ private HttpResponse maybeFallbackResponse(ServiceRequestContext ctx, HttpReques } } - @Nullable - private HttpResponse checkOrigin (ServiceRequestContext ctx, RequestHeaders headers){ - if (allowAnyOrigin) { - return null; - } - final String origin = headers.get(HttpHeaderNames.ORIGIN, ""); - if (origin.isEmpty()) { - return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "missing the origin header"); - } + @Nullable + private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) { + if (allowAnyOrigin) { + return null; + } + final String origin = headers.get(HttpHeaderNames.ORIGIN, ""); + if (origin.isEmpty()) { + return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, + "missing the origin header"); + } - if (allowedOrigins.isEmpty()) { - // Only the same-origin is allowed. - if (!isSameOrigin(ctx, headers, origin)) { - return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "not allowed origin: " + origin); - } - return null; - } - if (!allowedOrigins.contains(origin)) { + if (allowedOrigins.isEmpty()) { + // Only the same-origin is allowed. + if (!isSameOrigin(ctx, headers, origin)) { return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "not allowed origin: " + origin + ", allowed: " + allowedOrigins); + "not allowed origin: " + origin); } return null; } + if (!allowedOrigins.contains(origin)) { + return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, + "not allowed origin: " + origin + ", allowed: " + allowedOrigins); + } + return null; + } - private static boolean isSameOrigin (ServiceRequestContext ctx, RequestHeaders headers, String origin){ - final int schemeDelimiter = origin.indexOf("://"); - if (schemeDelimiter < 0) { - return false; - } + private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) { + final int schemeDelimiter = origin.indexOf("://"); + if (schemeDelimiter < 0) { + return false; + } - final String scheme = origin.substring(0, schemeDelimiter); - final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme); - if (originSessionProtocol == null) { - return false; - } + final String scheme = origin.substring(0, schemeDelimiter); + final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme); + if (originSessionProtocol == null) { + return false; + } - if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) || - (ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) { - // The same scheme. - } else { - return false; - } + if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) || + (ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) { + // The same scheme. + } else { + return false; + } - final String authority = headers.authority(); - assert authority != null; - final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority); - final String authorityHost = authorityHostAndPort.getHost(); - final int authorityPort = authorityHostAndPort.getPortOrDefault( - ctx.sessionProtocol().defaultPort()); + final String authority = headers.authority(); + assert authority != null; + final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority); + final String authorityHost = authorityHostAndPort.getHost(); + final int authorityPort = authorityHostAndPort.getPortOrDefault( + ctx.sessionProtocol().defaultPort()); - final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3)); - final String originHost = originHostAndPort.getHost(); - final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort()); + final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3)); + final String originHost = originHostAndPort.getHost(); + final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort()); - return authorityPort == originPort && authorityHost.equals(originHost); - } + return authorityPort == originPort && authorityHost.equals(originHost); + } - @Nullable - private static HttpResponse checkVersion (RequestHeaders headers){ - // Currently we only support v13. - final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION); - if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) { - return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, - HttpData.ofUtf8("Only 13 version is supported.")); - } - return null; + @Nullable + private static HttpResponse checkVersion(RequestHeaders headers) { + // Currently we only support v13. + final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION); + if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) { + return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, + HttpData.ofUtf8("Only 13 version is supported.")); } + return null; + } - @Override - public WebSocket decode (ServiceRequestContext ctx, HttpRequest req){ - final WebSocketServiceFrameDecoder decoder = - new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); - ctx.setAttr(DECODER, decoder); - return new WebSocketWrapper(req.decode(decoder, ctx.alloc())); - } + @Override + public WebSocket decode(ServiceRequestContext ctx, HttpRequest req) { + final WebSocketServiceFrameDecoder decoder = + new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); + ctx.setAttr(DECODER, decoder); + return new WebSocketWrapper(req.decode(decoder, ctx.alloc())); + } - @Override - public HttpResponse encode (ServiceRequestContext ctx, WebSocket out){ - final RequestHeaders requestHeaders = ctx.request().headers(); - final ResponseHeadersBuilder responseHeadersBuilder; - if (ctx.sessionProtocol().isExplicitHttp1()) { - final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, ""); - final String accept = generateSecWebSocketAccept(webSocketKey); - responseHeadersBuilder = - ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS) - .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()) - .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); - } else { - // As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5, - // HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers. - responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK); - } - maybeAddSubprotocol(requestHeaders, responseHeadersBuilder); - - final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER); - assert decoder != null; - decoder.setOutboundWebSocket(out); - final StreamMessage data = - out.recoverAndResume(cause -> { - if (cause instanceof ClosedStreamException) { - return StreamMessage.aborted(cause); - } - ctx.logBuilder().responseCause(cause); - return StreamMessage.of(newCloseWebSocketFrame(cause)); - }) - .map(frame -> HttpData.wrap(encoder.encode(ctx, frame))); - return HttpResponse.of(responseHeadersBuilder.build(), data); + @Override + public HttpResponse encode(ServiceRequestContext ctx, WebSocket out) { + final RequestHeaders requestHeaders = ctx.request().headers(); + final ResponseHeadersBuilder responseHeadersBuilder; + if (ctx.sessionProtocol().isExplicitHttp1()) { + final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, ""); + final String accept = generateSecWebSocketAccept(webSocketKey); + responseHeadersBuilder = + ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS) + .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()) + .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()) + .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); + } else { + // As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5, + // HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers. + responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK); } + maybeAddSubprotocol(requestHeaders, responseHeadersBuilder); + + final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER); + assert decoder != null; + decoder.setOutboundWebSocket(out); + final StreamMessage data = + out.recoverAndResume(cause -> { + if (cause instanceof ClosedStreamException) { + return StreamMessage.aborted(cause); + } + ctx.logBuilder().responseCause(cause); + return StreamMessage.of(newCloseWebSocketFrame(cause)); + }) + .map(frame -> HttpData.wrap(encoder.encode(ctx, frame))); + return HttpResponse.of(responseHeadersBuilder.build(), data); + } - @Override - public WebSocketProtocolHandler protocolHandler() { - return this; - } + @Override + public WebSocketProtocolHandler protocolHandler() { + return this; } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java index 2d07004ef63..a7dce0ca3a2 100644 --- a/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java @@ -73,7 +73,10 @@ void shouldReturnMessageInUpperCase() { @Test void shouldReturnFallbackResponse() { final BlockingWebClient client = server.blockingWebClient(); - final AggregatedHttpResponse response = client.get("/ws-or-http"); + AggregatedHttpResponse response = client.get("/ws-or-http"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("fallback"); + response = client.post("/ws-or-http", ""); assertThat(response.status()).isEqualTo(HttpStatus.OK); assertThat(response.contentUtf8()).isEqualTo("fallback"); }