Skip to content

Commit

Permalink
Handle post and other methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ikhoon committed Jan 5, 2024
1 parent 9aa3d1b commit 74eb55f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req
case CONNECT:
return upgradeHttp2(ctx, req);
default:
return WebSocketUpgradeResult.ofFailure(HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED));
final HttpResponse httpResponse =
failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED));
return WebSocketUpgradeResult.ofFailure(httpResponse);
}
}

Expand Down Expand Up @@ -147,26 +149,19 @@ public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req
*/
private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpRequest req) throws Exception {
if (!ctx.sessionProtocol().isExplicitHttp1()) {
final HttpResponse fallbackResponse;
if (fallbackService != null) {
fallbackResponse = fallbackService.serve(ctx, req);
} else {
fallbackResponse = HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
}
return WebSocketUpgradeResult.ofFailure(fallbackResponse);
final HttpResponse httpResponse =
failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED));
return WebSocketUpgradeResult.ofFailure(httpResponse);
}
final RequestHeaders headers = req.headers();
if (!isHttp1WebSocketUpgradeRequest(headers)) {
final HttpResponse fallbackResponse;
if (fallbackService != null) {
fallbackResponse = fallbackService.serve(ctx, req);
} else {
fallbackResponse = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
"The upgrade header must contain:\n" +
" Upgrade: websocket\n" +
" Connection: Upgrade");
}
return WebSocketUpgradeResult.ofFailure(fallbackResponse);
final HttpResponse httpResponse =
failOrFallback(ctx, req, () -> HttpResponse.of(
HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
"The upgrade header must contain:\n" +
" Upgrade: websocket\n" +
" Connection: Upgrade"));
return WebSocketUpgradeResult.ofFailure(httpResponse);
}

HttpResponse invalidResponse = checkOrigin(ctx, headers);
Expand All @@ -188,6 +183,15 @@ private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpReque
return WebSocketUpgradeResult.ofSuccess();
}

private HttpResponse failOrFallback(ServiceRequestContext ctx, HttpRequest req,
Supplier<HttpResponse> invalidResponse) throws Exception {
if (fallbackService != null) {
return fallbackService.serve(ctx, req);
} else {
return invalidResponse.get();
}
}

private void maybeAddSubprotocol(RequestHeaders headers,
ResponseHeadersBuilder responseHeadersBuilder) {
final String subprotocols = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "");
Expand Down Expand Up @@ -245,8 +249,8 @@ private WebSocketUpgradeResult upgradeHttp2(ServiceRequestContext ctx, HttpReque
fallbackResponse = fallbackService.serve(ctx, req);

Check warning on line 249 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L249

Added line #L249 was not covered by tests
} else {
fallbackResponse = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,

Check warning on line 251 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L251

Added line #L251 was not covered by tests
"The upgrade header must contain:\n" +
" :protocol = websocket");
"The upgrade header must contain:\n" +
" :protocol = websocket");
}
return WebSocketUpgradeResult.ofFailure(fallbackResponse);

Check warning on line 255 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L255

Added line #L255 was not covered by tests
}
Expand All @@ -273,120 +277,120 @@ private HttpResponse maybeFallbackResponse(ServiceRequestContext ctx, HttpReques
}
}

@Nullable
private HttpResponse checkOrigin (ServiceRequestContext ctx, RequestHeaders headers){
if (allowAnyOrigin) {
return null;
}
final String origin = headers.get(HttpHeaderNames.ORIGIN, "");
if (origin.isEmpty()) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
"missing the origin header");
}
@Nullable
private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) {
if (allowAnyOrigin) {
return null;

Check warning on line 283 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L283

Added line #L283 was not covered by tests
}
final String origin = headers.get(HttpHeaderNames.ORIGIN, "");
if (origin.isEmpty()) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
"missing the origin header");
}

if (allowedOrigins.isEmpty()) {
// Only the same-origin is allowed.
if (!isSameOrigin(ctx, headers, origin)) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
"not allowed origin: " + origin);
}
return null;
}
if (!allowedOrigins.contains(origin)) {
if (allowedOrigins.isEmpty()) {
// Only the same-origin is allowed.
if (!isSameOrigin(ctx, headers, origin)) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,

Check warning on line 294 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L294

Added line #L294 was not covered by tests
"not allowed origin: " + origin + ", allowed: " + allowedOrigins);
"not allowed origin: " + origin);
}
return null;
}
if (!allowedOrigins.contains(origin)) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,

Check warning on line 300 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L300

Added line #L300 was not covered by tests
"not allowed origin: " + origin + ", allowed: " + allowedOrigins);
}
return null;
}

private static boolean isSameOrigin (ServiceRequestContext ctx, RequestHeaders headers, String origin){
final int schemeDelimiter = origin.indexOf("://");
if (schemeDelimiter < 0) {
return false;
}
private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) {
final int schemeDelimiter = origin.indexOf("://");
if (schemeDelimiter < 0) {
return false;

Check warning on line 309 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L309

Added line #L309 was not covered by tests
}

final String scheme = origin.substring(0, schemeDelimiter);
final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme);
if (originSessionProtocol == null) {
return false;
}
final String scheme = origin.substring(0, schemeDelimiter);
final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme);
if (originSessionProtocol == null) {
return false;

Check warning on line 315 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L315

Added line #L315 was not covered by tests
}

if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) ||
(ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) {
// The same scheme.
} else {
return false;
}
if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) ||
(ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) {
// The same scheme.
} else {
return false;

Check warning on line 322 in core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/server/websocket/DefaultWebSocketService.java#L322

Added line #L322 was not covered by tests
}

final String authority = headers.authority();
assert authority != null;
final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority);
final String authorityHost = authorityHostAndPort.getHost();
final int authorityPort = authorityHostAndPort.getPortOrDefault(
ctx.sessionProtocol().defaultPort());
final String authority = headers.authority();
assert authority != null;
final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority);
final String authorityHost = authorityHostAndPort.getHost();
final int authorityPort = authorityHostAndPort.getPortOrDefault(
ctx.sessionProtocol().defaultPort());

final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3));
final String originHost = originHostAndPort.getHost();
final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort());
final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3));
final String originHost = originHostAndPort.getHost();
final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort());

return authorityPort == originPort && authorityHost.equals(originHost);
}
return authorityPort == originPort && authorityHost.equals(originHost);
}

@Nullable
private static HttpResponse checkVersion (RequestHeaders headers){
// Currently we only support v13.
final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) {
return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION,
HttpData.ofUtf8("Only 13 version is supported."));
}
return null;
@Nullable
private static HttpResponse checkVersion(RequestHeaders headers) {
// Currently we only support v13.
final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) {
return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION,
HttpData.ofUtf8("Only 13 version is supported."));
}
return null;
}

@Override
public WebSocket decode (ServiceRequestContext ctx, HttpRequest req){
final WebSocketServiceFrameDecoder decoder =
new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch);
ctx.setAttr(DECODER, decoder);
return new WebSocketWrapper(req.decode(decoder, ctx.alloc()));
}
@Override
public WebSocket decode(ServiceRequestContext ctx, HttpRequest req) {
final WebSocketServiceFrameDecoder decoder =
new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch);
ctx.setAttr(DECODER, decoder);
return new WebSocketWrapper(req.decode(decoder, ctx.alloc()));
}

@Override
public HttpResponse encode (ServiceRequestContext ctx, WebSocket out){
final RequestHeaders requestHeaders = ctx.request().headers();
final ResponseHeadersBuilder responseHeadersBuilder;
if (ctx.sessionProtocol().isExplicitHttp1()) {
final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
final String accept = generateSecWebSocketAccept(webSocketKey);
responseHeadersBuilder =
ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS)
.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString())
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString())
.add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
} else {
// As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5,
// HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers.
responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK);
}
maybeAddSubprotocol(requestHeaders, responseHeadersBuilder);

final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER);
assert decoder != null;
decoder.setOutboundWebSocket(out);
final StreamMessage<HttpData> data =
out.recoverAndResume(cause -> {
if (cause instanceof ClosedStreamException) {
return StreamMessage.aborted(cause);
}
ctx.logBuilder().responseCause(cause);
return StreamMessage.of(newCloseWebSocketFrame(cause));
})
.map(frame -> HttpData.wrap(encoder.encode(ctx, frame)));
return HttpResponse.of(responseHeadersBuilder.build(), data);
@Override
public HttpResponse encode(ServiceRequestContext ctx, WebSocket out) {
final RequestHeaders requestHeaders = ctx.request().headers();
final ResponseHeadersBuilder responseHeadersBuilder;
if (ctx.sessionProtocol().isExplicitHttp1()) {
final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
final String accept = generateSecWebSocketAccept(webSocketKey);
responseHeadersBuilder =
ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS)
.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString())
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString())
.add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
} else {
// As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5,
// HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers.
responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK);
}
maybeAddSubprotocol(requestHeaders, responseHeadersBuilder);

final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER);
assert decoder != null;
decoder.setOutboundWebSocket(out);
final StreamMessage<HttpData> data =
out.recoverAndResume(cause -> {
if (cause instanceof ClosedStreamException) {
return StreamMessage.aborted(cause);
}
ctx.logBuilder().responseCause(cause);
return StreamMessage.of(newCloseWebSocketFrame(cause));
})
.map(frame -> HttpData.wrap(encoder.encode(ctx, frame)));
return HttpResponse.of(responseHeadersBuilder.build(), data);
}

@Override
public WebSocketProtocolHandler protocolHandler() {
return this;
}
@Override
public WebSocketProtocolHandler protocolHandler() {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ void shouldReturnMessageInUpperCase() {
@Test
void shouldReturnFallbackResponse() {
final BlockingWebClient client = server.blockingWebClient();
final AggregatedHttpResponse response = client.get("/ws-or-http");
AggregatedHttpResponse response = client.get("/ws-or-http");
assertThat(response.status()).isEqualTo(HttpStatus.OK);
assertThat(response.contentUtf8()).isEqualTo("fallback");
response = client.post("/ws-or-http", "");
assertThat(response.status()).isEqualTo(HttpStatus.OK);
assertThat(response.contentUtf8()).isEqualTo("fallback");
}
Expand Down

0 comments on commit 74eb55f

Please sign in to comment.