Skip to content

Commit

Permalink
WebSocket client handshaker to support "force close" after timeout (#…
Browse files Browse the repository at this point in the history
…8896)

Motivation:

RFC 6455 defines that, generally, a WebSocket client should not close a TCP
connection as far as a server is the one who's responsible for doing that.
In practice tho', it's not always possible to control the server. Server's
misbehavior may lead to connections being leaked (if the server does not
comply with the RFC).

RFC 6455 #7.1.1 says

> In abnormal cases (such as not having received a TCP Close from the server
after a reasonable amount of time) a client MAY initiate the TCP Close.

Modifications:

* WebSocket client handshaker additional param `forceCloseAfterMillis`

* Use 10 seconds as default

Result:

WebSocket client handshaker to comply with RFC. Fixes #8883.
  • Loading branch information
kachayev authored and normanmaurer committed Apr 10, 2019
1 parent ac023da commit ee351ef
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Locale;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
* Base class for web socket client handshake implementations
Expand All @@ -50,13 +53,23 @@ public abstract class WebSocketClientHandshaker {

private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;

private final URI uri;

private final WebSocketVersion version;

private volatile boolean handshakeComplete;

private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;

private volatile int forceCloseInit;

private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");

private volatile boolean forceCloseComplete;

private final String expectedSubprotocol;

private volatile String actualSubprotocol;
Expand All @@ -82,11 +95,35 @@ public abstract class WebSocketClientHandshaker {
*/
protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength) {
this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
* Base constructor
*
* @param uri
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified
*/
protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength,
long forceCloseTimeoutMillis) {
this.uri = uri;
this.version = version;
expectedSubprotocol = subprotocol;
this.customHeaders = customHeaders;
this.maxFramePayloadLength = maxFramePayloadLength;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
}

/**
Expand Down Expand Up @@ -140,6 +177,29 @@ private void setActualSubprotocol(String actualSubprotocol) {
this.actualSubprotocol = actualSubprotocol;
}

public long forceCloseTimeoutMillis() {
return forceCloseTimeoutMillis;
}

/**
* Flag to indicate if the closing handshake was initiated because of timeout.
* For testing only.
*/
protected boolean isForceCloseComplete() {
return forceCloseComplete;
}

/**
* Sets timeout to close the connection if it was not closed by the server.
*
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified
*/
public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
return this;
}

/**
* Begins the opening handshake
*
Expand Down Expand Up @@ -431,7 +491,46 @@ public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPr
if (channel == null) {
throw new NullPointerException("channel");
}
return channel.writeAndFlush(frame, promise);
channel.writeAndFlush(frame, promise);
applyForceCloseTimeout(channel, promise);
return promise;
}

private void applyForceCloseTimeout(final Channel channel, ChannelFuture flushFuture) {
final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
final WebSocketClientHandshaker handshaker = this;
if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
return;
}

flushFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// If flush operation failed, there is no reason to expect
// a server to receive CloseFrame. Thus this should be handled
// by the application separately.
// Also, close might be called twice from different threads.
if (future.isSuccess() && channel.isActive() &&
FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
final Future<?> forceCloseFuture = channel.eventLoop().schedule(new Runnable() {
@Override
public void run() {
if (channel.isActive()) {
channel.close();
forceCloseComplete = true;
}
}
}, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);

channel.closeFuture().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
forceCloseFuture.cancel(false);
}
});
}
}
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
private ByteBuf expectedChallengeResponseBytes;

/**
* Constructor specifying the destination web socket location and version to initiate
* Creates a new instance with the specified destination WebSocket location and version to initiate.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
Expand All @@ -64,7 +64,31 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
*/
public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength,
DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
* Creates a new instance with the specified destination WebSocket location and version to initiate.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified
*/
public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength,
long forceCloseTimeoutMillis) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis);
}

/**
Expand Down Expand Up @@ -243,4 +267,11 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
protected WebSocketFrameEncoder newWebSocketEncoder() {
return new WebSocket00FrameEncoder();
}

@Override
public WebSocketClientHandshaker00 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,43 @@ public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, S
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
*/
public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand Down Expand Up @@ -216,4 +249,11 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
protected WebSocketFrameEncoder newWebSocketEncoder() {
return new WebSocket07FrameEncoder(performMasking);
}

@Override
public WebSocketClientHandshaker07 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
*/
public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false);
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true,
false, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
Expand All @@ -93,12 +94,45 @@ public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, S
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
* accepted
*/
public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand Down Expand Up @@ -217,4 +251,11 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
protected WebSocketFrameEncoder newWebSocketEncoder() {
return new WebSocket08FrameEncoder(performMasking);
}

@Override
public WebSocketClientHandshaker08 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false);
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
true, false);
}

/**
Expand Down Expand Up @@ -98,7 +99,41 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand Down Expand Up @@ -217,4 +252,11 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
protected WebSocketFrameEncoder newWebSocketEncoder() {
return new WebSocket13FrameEncoder(performMasking);
}

@Override
public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
return this;
}

}
Loading

0 comments on commit ee351ef

Please sign in to comment.