Skip to content

Commit

Permalink
Use the HttpServerResponse to send a response on an handshake failure…
Browse files Browse the repository at this point in the history
… so the HTTP connection remains usable after - fixes #2878
  • Loading branch information
vietj committed Mar 13, 2019
1 parent 14d554f commit 29d1181
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 42 deletions.
8 changes: 5 additions & 3 deletions src/main/java/io/vertx/core/http/HttpServerRequest.java
Expand Up @@ -300,10 +300,12 @@ default HttpServerRequest bodyHandler(@Nullable Handler<Buffer> bodyHandler) {
/** /**
* Upgrade the connection to a WebSocket connection. * Upgrade the connection to a WebSocket connection.
* <p> * <p>
* This is an alternative way of handling WebSockets and can only be used if no websocket handlers are set on the * This is an alternative way of handling WebSockets and can only be used if no WebSocket handler is set on the
* Http server, and can only be used during the upgrade request during the WebSocket handshake. * {@code HttpServer}, and can only be used during the upgrade request during the WebSocket handshake.
* *
* @return the WebSocket * @return the WebSocket
* @throws IllegalStateException if the current request cannot be upgraded, when it happens an appropriate response
* is sent
*/ */
ServerWebSocket upgrade(); ServerWebSocket upgrade();


Expand Down
33 changes: 21 additions & 12 deletions src/main/java/io/vertx/core/http/impl/Http1xServerConnection.java
Expand Up @@ -43,6 +43,7 @@
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE; import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE;
import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
import static io.netty.handler.codec.http.HttpResponseStatus.UPGRADE_REQUIRED;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.vertx.core.spi.metrics.Metrics.METRICS_ENABLED; import static io.vertx.core.spi.metrics.Metrics.METRICS_ENABLED;


Expand Down Expand Up @@ -319,11 +320,10 @@ ServerWebSocketImpl createWebSocket(HttpServerRequestImpl request) {
if (ws != null) { if (ws != null) {
return ws; return ws;
} }
if (!(request.getRequest() instanceof FullHttpRequest)) { if (!(request.nettyRequest() instanceof FullHttpRequest)) {
throw new IllegalStateException(); throw new IllegalStateException();
} }
FullHttpRequest nettyReq = (FullHttpRequest) request.getRequest(); WebSocketServerHandshaker handshaker = createHandshaker(request);
WebSocketServerHandshaker handshaker = createHandshaker(nettyReq);
if (handshaker == null) { if (handshaker == null) {
return null; return null;
} }
Expand All @@ -335,25 +335,31 @@ ServerWebSocketImpl createWebSocket(HttpServerRequestImpl request) {
return ws; return ws;
} }


private WebSocketServerHandshaker createHandshaker(HttpRequest request) { private WebSocketServerHandshaker createHandshaker(HttpServerRequestImpl request) {
// As a fun part, Firefox 6.0.2 supports Websockets protocol '7'. But, // As a fun part, Firefox 6.0.2 supports Websockets protocol '7'. But,
// it doesn't send a normal 'Connection: Upgrade' header. Instead it // it doesn't send a normal 'Connection: Upgrade' header. Instead it
// sends: 'Connection: keep-alive, Upgrade'. Brilliant. // sends: 'Connection: keep-alive, Upgrade'. Brilliant.
Channel ch = channel(); Channel ch = channel();
String connectionHeader = request.headers().get(io.vertx.core.http.HttpHeaders.CONNECTION); String connectionHeader = request.getHeader(io.vertx.core.http.HttpHeaders.CONNECTION);
if (connectionHeader == null || !connectionHeader.toLowerCase().contains("upgrade")) { if (connectionHeader == null || !connectionHeader.toLowerCase().contains("upgrade")) {
HttpUtils.sendError(ch, BAD_REQUEST, "\"Connection\" header must be \"Upgrade\"."); request.response()
.setStatusCode(BAD_REQUEST.code())
.end("\"Connection\" header must be \"Upgrade\".");
return null; return null;
} }
if (request.method() != HttpMethod.GET) { if (request.method() != io.vertx.core.http.HttpMethod.GET) {
HttpUtils.sendError(ch, METHOD_NOT_ALLOWED, null); request.response()
.setStatusCode(METHOD_NOT_ALLOWED.code())
.end();
return null; return null;
} }
String wsURL; String wsURL;
try { try {
wsURL = HttpUtils.getWebSocketLocation(request, isSsl()); wsURL = HttpUtils.getWebSocketLocation(request, isSsl());
} catch (Exception e) { } catch (Exception e) {
HttpUtils.sendError(ch, BAD_REQUEST, "Invalid request URI"); request.response()
.setStatusCode(BAD_REQUEST.code())
.end("Invalid request URI");
return null; return null;
} }


Expand All @@ -362,10 +368,13 @@ private WebSocketServerHandshaker createHandshaker(HttpRequest request) {
options.getWebsocketSubProtocols(), options.getWebsocketSubProtocols(),
options.getPerMessageWebsocketCompressionSupported() || options.getPerFrameWebsocketCompressionSupported(), options.getPerMessageWebsocketCompressionSupported() || options.getPerFrameWebsocketCompressionSupported(),
options.getMaxWebsocketFrameSize(), options.isAcceptUnmaskedFrames()); options.getMaxWebsocketFrameSize(), options.isAcceptUnmaskedFrames());
WebSocketServerHandshaker shake = factory.newHandshaker(request); WebSocketServerHandshaker shake = factory.newHandshaker(request.nettyRequest());
if (shake == null) { if (shake == null) {
//log.error("Unrecognised websockets handshake"); // See WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ch);
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ch); request.response()
.putHeader(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue())
.setStatusCode(UPGRADE_REQUIRED.code())
.end();
} }
return shake; return shake;
} }
Expand Down
12 changes: 8 additions & 4 deletions src/main/java/io/vertx/core/http/impl/HttpServerRequestImpl.java
Expand Up @@ -62,7 +62,7 @@ public class HttpServerRequestImpl implements HttpServerRequest {
private final Http1xServerConnection conn; private final Http1xServerConnection conn;
final ContextInternal context; final ContextInternal context;


private DefaultHttpRequest request; private HttpRequest request;
private io.vertx.core.http.HttpVersion version; private io.vertx.core.http.HttpVersion version;
private io.vertx.core.http.HttpMethod method; private io.vertx.core.http.HttpMethod method;
private String rawMethod; private String rawMethod;
Expand Down Expand Up @@ -92,19 +92,23 @@ public class HttpServerRequestImpl implements HttpServerRequest {


private InboundBuffer<Buffer> pending; private InboundBuffer<Buffer> pending;


HttpServerRequestImpl(Http1xServerConnection conn, DefaultHttpRequest request) { HttpServerRequestImpl(Http1xServerConnection conn, HttpRequest request) {
this.conn = conn; this.conn = conn;
this.context = conn.getContext().duplicate(); this.context = conn.getContext().duplicate();
this.request = request; this.request = request;
} }


DefaultHttpRequest getRequest() { /**
*
* @return
*/
HttpRequest nettyRequest() {
synchronized (conn) { synchronized (conn) {
return request; return request;
} }
} }


void setRequest(DefaultHttpRequest request) { void setRequest(HttpRequest request) {
synchronized (conn) { synchronized (conn) {
this.request = request; this.request = request;
} }
Expand Down
Expand Up @@ -16,7 +16,6 @@
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.*; import io.netty.handler.codec.http.*;
import io.netty.util.concurrent.GenericFutureListener;
import io.vertx.codegen.annotations.Nullable; import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.AsyncResult; import io.vertx.core.AsyncResult;
import io.vertx.core.Future; import io.vertx.core.Future;
Expand Down Expand Up @@ -82,7 +81,7 @@ public class HttpServerResponseImpl implements HttpServerResponse {
private long bytesWritten; private long bytesWritten;
private NetSocket netSocket; private NetSocket netSocket;


HttpServerResponseImpl(final VertxInternal vertx, ContextInternal context, Http1xServerConnection conn, DefaultHttpRequest request, Object requestMetric) { HttpServerResponseImpl(final VertxInternal vertx, ContextInternal context, Http1xServerConnection conn, HttpRequest request, Object requestMetric) {
this.vertx = vertx; this.vertx = vertx;
this.conn = conn; this.conn = conn;
this.context = context; this.context = context;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/io/vertx/core/http/impl/HttpUtils.java
Expand Up @@ -626,7 +626,7 @@ static void sendError(Channel ch, HttpResponseStatus status, CharSequence err) {
ch.writeAndFlush(resp); ch.writeAndFlush(resp);
} }


static String getWebSocketLocation(HttpRequest req, boolean ssl) throws Exception { static String getWebSocketLocation(HttpServerRequest req, boolean ssl) throws Exception {
String prefix; String prefix;
if (ssl) { if (ssl) {
prefix = "ws://"; prefix = "ws://";
Expand Down
Expand Up @@ -14,9 +14,7 @@
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.vertx.core.Future; import io.vertx.core.Future;
import io.vertx.core.MultiMap; import io.vertx.core.MultiMap;
Expand All @@ -29,8 +27,6 @@
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import javax.security.cert.X509Certificate; import javax.security.cert.X509Certificate;


import java.util.function.Function;

import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;


import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS; import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
Expand Down Expand Up @@ -174,10 +170,9 @@ private void handleHandshake(int sc) {
private void doHandshake() { private void doHandshake() {
Channel channel = conn.channel(); Channel channel = conn.channel();
try { try {
handshaker.handshake(channel, request.getRequest()); handshaker.handshake(channel, request.nettyRequest());
} catch (Exception e) { } catch (Exception e) {
status = BAD_REQUEST.code(); request.response().setStatusCode(BAD_REQUEST.code()).end();
HttpUtils.sendError(conn.channel(), BAD_REQUEST, "\"Connection\" header must be \"Upgrade\".");
throw e; throw e;
} finally { } finally {
request = null; request = null;
Expand Down
Expand Up @@ -13,12 +13,12 @@
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.vertx.core.Handler; import io.vertx.core.Handler;
import io.vertx.core.buffer.Buffer; import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.spi.metrics.HttpServerMetrics; import io.vertx.core.spi.metrics.HttpServerMetrics;


import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
import static io.vertx.core.http.HttpHeaders.UPGRADE; import static io.vertx.core.http.HttpHeaders.UPGRADE;
import static io.vertx.core.http.HttpHeaders.WEBSOCKET; import static io.vertx.core.http.HttpHeaders.WEBSOCKET;
import static io.vertx.core.http.impl.HttpUtils.SC_SWITCHING_PROTOCOLS; import static io.vertx.core.http.impl.HttpUtils.SC_SWITCHING_PROTOCOLS;
Expand Down Expand Up @@ -80,7 +80,7 @@ private void handle(HttpServerRequestImpl req) {
* Handle the request once we have the full body. * Handle the request once we have the full body.
*/ */
private void handle(HttpServerRequestImpl req, Buffer body) { private void handle(HttpServerRequestImpl req, Buffer body) {
DefaultHttpRequest nettyReq = req.getRequest(); HttpRequest nettyReq = req.nettyRequest();
nettyReq = new DefaultFullHttpRequest( nettyReq = new DefaultFullHttpRequest(
nettyReq.protocolVersion(), nettyReq.protocolVersion(),
nettyReq.method(), nettyReq.method(),
Expand Down
40 changes: 29 additions & 11 deletions src/test/java/io/vertx/core/http/WebSocketTest.java
Expand Up @@ -60,7 +60,9 @@
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;


import static io.vertx.core.http.HttpTestBase.*; import static io.vertx.core.http.HttpTestBase.DEFAULT_HTTP_HOST;
import static io.vertx.core.http.HttpTestBase.DEFAULT_HTTP_PORT;
import static io.vertx.core.http.HttpTestBase.DEFAULT_TEST_URI;
import static io.vertx.test.core.TestUtils.*; import static io.vertx.test.core.TestUtils.*;


/** /**
Expand Down Expand Up @@ -1202,26 +1204,42 @@ private void testInvalidHandshake(Function<Handler<AsyncResult<HttpClientRespons
boolean expectEvent, boolean expectEvent,
boolean upgradeRequest, boolean upgradeRequest,
int expectedStatus) { int expectedStatus) {
client.close();
client = vertx.createHttpClient(new HttpClientOptions().setMaxPoolSize(1));
if (upgradeRequest) { if (upgradeRequest) {
server = vertx.createHttpServer().websocketHandler(ws -> { server = vertx.createHttpServer()
// Check we can get notified .websocketHandler(ws -> {
// handshake fails after this method returns and does not reject the ws // Check we can get notified
assertTrue(expectEvent); // handshake fails after this method returns and does not reject the ws
}); assertTrue(expectEvent);
})
.requestHandler(req -> {
req.response().end();
});
} else { } else {
AtomicBoolean first = new AtomicBoolean();
server = vertx.createHttpServer().requestHandler(req -> { server = vertx.createHttpServer().requestHandler(req -> {
try { if (first.compareAndSet(false, true)) {
req.upgrade(); try {
} catch (Exception e) { req.upgrade();
// Expected } catch (Exception e) {
// Expected
}
} else {
req.response().end();
} }
}); });
} }
server.listen(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, onSuccess(ar -> { server.listen(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, onSuccess(ar -> {
HttpClientRequest req = requestProvider.apply(onSuccess(resp -> { HttpClientRequest req = requestProvider.apply(onSuccess(resp -> {
assertEquals(expectedStatus, resp.statusCode()); assertEquals(expectedStatus, resp.statusCode());
resp.endHandler(v1 -> { resp.endHandler(v1 -> {
testComplete(); // Make another request to check the connection remains usable
client.getNow(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, DEFAULT_TEST_URI, onSuccess(resp2 -> {
resp2.endHandler(v2 -> {
testComplete();
});
}));
}); });
})); }));
req.end(); req.end();
Expand Down

0 comments on commit 29d1181

Please sign in to comment.