Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support handshake timeout in websocket handlers #8856

Merged
merged 9 commits into from May 22, 2019
Expand Up @@ -38,9 +38,11 @@
* {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}.
*/
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
private static final long DEFAULT_HANDSHAKE_TIMEOUT = 10000L;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DEFAULT_HANDSHAKE_TIMEOUT_MS ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DONE


private final WebSocketClientHandshaker handshaker;
private final boolean handleCloseFrames;
private final long handshakeTimeoutMillis;

/**
* Returns the used handshaker
Expand All @@ -53,6 +55,11 @@ public WebSocketClientHandshaker handshaker() {
* Events that are fired to notify about handshake status
*/
public enum ClientHandshakeStateEvent {
/**
* The Handshake was timed out
*/
HANDSHAKE_TIMEOUT,

/**
* The Handshake was started but the server did not response yet to the request
*/
Expand Down Expand Up @@ -92,9 +99,45 @@ public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking,
boolean allowMaskMismatch, long handshakeTimeoutMillis) {
this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch), handleCloseFrames);
performMasking, allowMaskMismatch),
handleCloseFrames, handshakeTimeoutMillis);
}

/**
Expand All @@ -118,7 +161,34 @@ public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, true, false);
handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean handleCloseFrames, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, true, false, handshakeTimeoutMillis);
}

/**
Expand All @@ -140,7 +210,32 @@ public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength, true);
allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis);
}

/**
Expand All @@ -153,7 +248,24 @@ public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version
* {@code true} if close frames should not be forwarded and just close the channel
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) {
this(handshaker, handleCloseFrames, true);
this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
long handshakeTimeoutMillis) {
this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis);
}

/**
Expand All @@ -169,9 +281,29 @@ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, bool
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
boolean dropPongFrames) {
this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param dropPongFrames
* {@code true} if pong frames should not be forwarded
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
boolean dropPongFrames, long handshakeTimeoutMillis) {
super(dropPongFrames);
this.handshaker = handshaker;
this.handleCloseFrames = handleCloseFrames;
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check positive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

/**
Expand All @@ -182,7 +314,21 @@ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, bool
* was established to the remote peer.
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) {
this(handshaker, true);
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT);
}

/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
this(handshaker, true, handshakeTimeoutMillis);
}

@Override
Expand All @@ -200,7 +346,7 @@ public void handlerAdded(ChannelHandlerContext ctx) {
if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
// Add the WebSocketClientProtocolHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
new WebSocketClientProtocolHandshakeHandler(handshaker));
new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis));
}
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one.
Expand Down
Expand Up @@ -19,13 +19,40 @@
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ThrowableUtil;

import java.util.concurrent.TimeUnit;

class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace(
new WebSocketHandshakeException("handshake timed out"),
WebSocketClientProtocolHandshakeHandler.class,
"channelActive(...)");

private final WebSocketClientHandshaker handshaker;
private final long handshakeTimeoutMillis;
private volatile ChannelHandlerContext ctx;
private volatile ChannelPromise handshakePromise;

WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) {
this(handshaker, 10000);
}

WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
this.handshaker = handshaker;
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
handshakePromise = ctx.newPromise();
}

@Override
Expand All @@ -35,13 +62,15 @@ public void channelActive(final ChannelHandlerContext ctx) throws Exception {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
handshakePromise.tryFailure(future.cause());
ctx.fireExceptionCaught(future.cause());
} else {
ctx.fireUserEventTriggered(
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_ISSUED);
}
}
});
applyHandshakeTimeout();
}

@Override
Expand All @@ -55,6 +84,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
try {
if (!handshaker.isHandshakeComplete()) {
handshaker.finishHandshake(ctx.channel(), response);
handshakePromise.trySuccess();
ctx.fireUserEventTriggered(
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE);
ctx.pipeline().remove(this);
Expand All @@ -65,4 +95,43 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
response.release();
}
}

private void applyHandshakeTimeout() {
final ChannelPromise localHandshakePromise = handshakePromise;
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we also need to check if localHandshakePromise == null ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handshakePromise is initialized when the handle is added. Is there a case where handshakePromise is empty?

return;
}

final ScheduledFuture<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this could be Future<?>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@Override
public void run() {
if (localHandshakePromise.isDone()) {
return;
}

if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) {
ctx.flush()
.fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT)
.close();
}
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);

// Cancel the handshake timeout when handshake is finished.
localHandshakePromise.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> f) throws Exception {
timeoutFuture.cancel(false);
}
});
}

/**
* This method is visible for testing.
*
* @return current handshake future
*/
ChannelFuture getHandshakeFuture() {
return handshakePromise;
}
}