Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit 6cd1708

Browse files
authored
Give WebSocket server time to close connection (#27993)
* Avoid unnecessary CancellationToken arguments * Give WebSocket server time to close connection Per RFC 6455, a websocket client should try to let the server close the connection. * Address PR feedback
1 parent d4853e2 commit 6cd1708

File tree

1 file changed

+57
-24
lines changed

1 file changed

+57
-24
lines changed

src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ private async Task SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessag
440440
int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span);
441441
using (cancellationToken.Register(s => ((ManagedWebSocket)s).Abort(), this))
442442
{
443-
await _stream.WriteAsync(new ReadOnlyMemory<byte>(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false);
443+
await _stream.WriteAsync(new ReadOnlyMemory<byte>(_sendBuffer, 0, sendBytes), default).ConfigureAwait(false);
444444
}
445445
}
446446
catch (Exception exc) when (!(exc is OperationCanceledException))
@@ -636,7 +636,7 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
636636
// Make sure we have the first two bytes, which includes the start of the payload length.
637637
if (_receiveBufferCount < 2)
638638
{
639-
await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
639+
await EnsureBufferContainsAsync(2, throwOnPrematureClosure: true).ConfigureAwait(false);
640640
}
641641

642642
// Then make sure we have the full header based on the payload length.
@@ -648,13 +648,13 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
648648
2 +
649649
(_isServer ? MaskLength : 0) +
650650
(payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length
651-
await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false);
651+
await EnsureBufferContainsAsync(minNeeded).ConfigureAwait(false);
652652
}
653653
}
654654

655655
if (!TryParseMessageHeaderFromReceiveBuffer(out header))
656656
{
657-
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
657+
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false);
658658
}
659659
_receivedMaskOffsetOffset = 0;
660660
}
@@ -664,12 +664,12 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
664664
// Alternatively, if it's a close message, handle it and exit.
665665
if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong)
666666
{
667-
await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false);
667+
await HandleReceivedPingPongAsync(header).ConfigureAwait(false);
668668
continue;
669669
}
670670
else if (header.Opcode == MessageOpcode.Close)
671671
{
672-
await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false);
672+
await HandleReceivedCloseAsync(header).ConfigureAwait(false);
673673
return resultGetter.GetResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription);
674674
}
675675

@@ -699,7 +699,7 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
699699
Debug.Assert(bytesToCopy > 0, $"Expected {nameof(bytesToCopy)} > 0");
700700
if (_receiveBufferCount < bytesToCopy)
701701
{
702-
await EnsureBufferContainsAsync(bytesToCopy, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
702+
await EnsureBufferContainsAsync(bytesToCopy, throwOnPrematureClosure: true).ConfigureAwait(false);
703703
}
704704

705705
if (_isServer)
@@ -714,7 +714,7 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
714714
if (header.Opcode == MessageOpcode.Text &&
715715
!TryValidateUtf8(payloadBuffer.Span.Slice(0, bytesToCopy), header.Fin, _utf8TextState))
716716
{
717-
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
717+
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
718718
}
719719

720720
_lastReceiveHeader = header;
@@ -742,10 +742,8 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
742742

743743
/// <summary>Processes a received close message.</summary>
744744
/// <param name="header">The message header.</param>
745-
/// <param name="cancellationToken">The cancellation token to use to cancel the websocket.</param>
746745
/// <returns>The received result message.</returns>
747-
private async Task HandleReceivedCloseAsync(
748-
MessageHeader header, CancellationToken cancellationToken)
746+
private async Task HandleReceivedCloseAsync(MessageHeader header)
749747
{
750748
lock (StateUpdateLock)
751749
{
@@ -763,13 +761,13 @@ private async Task HandleReceivedCloseAsync(
763761
if (header.PayloadLength == 1)
764762
{
765763
// The close payload length can be 0 or >= 2, but not 1.
766-
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
764+
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false);
767765
}
768766
else if (header.PayloadLength >= 2)
769767
{
770768
if (_receiveBufferCount < header.PayloadLength)
771769
{
772-
await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false);
770+
await EnsureBufferContainsAsync((int)header.PayloadLength).ConfigureAwait(false);
773771
}
774772

775773
if (_isServer)
@@ -780,7 +778,7 @@ private async Task HandleReceivedCloseAsync(
780778
closeStatus = (WebSocketCloseStatus)(_receiveBuffer.Span[_receiveBufferOffset] << 8 | _receiveBuffer.Span[_receiveBufferOffset + 1]);
781779
if (!IsValidCloseStatus(closeStatus))
782780
{
783-
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, cancellationToken).ConfigureAwait(false);
781+
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false);
784782
}
785783

786784
if (header.PayloadLength > 2)
@@ -791,7 +789,7 @@ private async Task HandleReceivedCloseAsync(
791789
}
792790
catch (DecoderFallbackException exc)
793791
{
794-
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, cancellationToken, exc).ConfigureAwait(false);
792+
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, exc).ConfigureAwait(false);
795793
}
796794
}
797795
ConsumeFromBuffer((int)header.PayloadLength);
@@ -800,17 +798,48 @@ private async Task HandleReceivedCloseAsync(
800798
// Store the close status and description onto the instance.
801799
_closeStatus = closeStatus;
802800
_closeStatusDescription = closeStatusDescription;
801+
802+
if (!_isServer && _sentCloseFrame)
803+
{
804+
await WaitForServerToCloseConnectionAsync().ConfigureAwait(false);
805+
}
806+
}
807+
808+
/// <summary>Issues a read on the stream to wait for EOF.</summary>
809+
private async Task WaitForServerToCloseConnectionAsync()
810+
{
811+
// Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second.
812+
// We simply issue a read and don't care what we get back; we could validate that we don't get
813+
// additional data, but at this point we're about to close the connection and we're just stalling
814+
// to try to get the server to close first.
815+
ValueTask<int> finalReadTask = _stream.ReadAsync(_receiveBuffer, default);
816+
if (!finalReadTask.IsCompletedSuccessfully)
817+
{
818+
const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same as netfx)
819+
using (var finalCts = new CancellationTokenSource(WaitForCloseTimeoutMs))
820+
using (finalCts.Token.Register(s => ((ManagedWebSocket)s).Abort(), this))
821+
{
822+
try
823+
{
824+
await finalReadTask.ConfigureAwait(false);
825+
}
826+
catch
827+
{
828+
// Eat any resulting exceptions. We were going to close the connection, anyway.
829+
// TODO #24057: Log the exception to NetEventSource.
830+
}
831+
}
832+
}
803833
}
804834

805835
/// <summary>Processes a received ping or pong message.</summary>
806836
/// <param name="header">The message header.</param>
807-
/// <param name="cancellationToken">The cancellation token to use to cancel the websocket.</param>
808-
private async Task HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken)
837+
private async Task HandleReceivedPingPongAsync(MessageHeader header)
809838
{
810839
// Consume any (optional) payload associated with the ping/pong.
811840
if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength)
812841
{
813-
await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false);
842+
await EnsureBufferContainsAsync((int)header.PayloadLength).ConfigureAwait(false);
814843
}
815844

816845
// If this was a ping, send back a pong response.
@@ -823,7 +852,7 @@ private async Task HandleReceivedPingPongAsync(MessageHeader header, Cancellatio
823852

824853
await SendFrameAsync(
825854
MessageOpcode.Pong, true,
826-
_receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), cancellationToken).ConfigureAwait(false);
855+
_receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), default).ConfigureAwait(false);
827856
}
828857

829858
// Regardless of whether it was a ping or pong, we no longer need the payload.
@@ -875,15 +904,14 @@ private static bool IsValidCloseStatus(WebSocketCloseStatus closeStatus)
875904
/// <summary>Send a close message to the server and throw an exception, in response to getting bad data from the server.</summary>
876905
/// <param name="closeStatus">The close status code to use.</param>
877906
/// <param name="error">The error reason.</param>
878-
/// <param name="cancellationToken">The CancellationToken used to cancel the websocket.</param>
879907
/// <param name="innerException">An optional inner exception to include in the thrown exception.</param>
880908
private async Task CloseWithReceiveErrorAndThrowAsync(
881-
WebSocketCloseStatus closeStatus, WebSocketError error, CancellationToken cancellationToken, Exception innerException = null)
909+
WebSocketCloseStatus closeStatus, WebSocketError error, Exception innerException = null)
882910
{
883911
// Close the connection if it hasn't already been closed
884912
if (!_sentCloseFrame)
885913
{
886-
await CloseOutputAsync(closeStatus, string.Empty, cancellationToken).ConfigureAwait(false);
914+
await CloseOutputAsync(closeStatus, string.Empty, default).ConfigureAwait(false);
887915
}
888916

889917
// Dump our receive buffer; we're in a bad state to do any further processing
@@ -1100,6 +1128,11 @@ private async Task SendCloseFrameAsync(WebSocketCloseStatus closeStatus, string
11001128
_state = WebSocketState.CloseSent;
11011129
}
11021130
}
1131+
1132+
if (!_isServer && _receivedCloseFrame)
1133+
{
1134+
await WaitForServerToCloseConnectionAsync().ConfigureAwait(false);
1135+
}
11031136
}
11041137

11051138
private void ConsumeFromBuffer(int count)
@@ -1110,7 +1143,7 @@ private void ConsumeFromBuffer(int count)
11101143
_receiveBufferOffset += count;
11111144
}
11121145

1113-
private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true)
1146+
private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, bool throwOnPrematureClosure = true)
11141147
{
11151148
Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}");
11161149

@@ -1127,7 +1160,7 @@ private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, Cancellat
11271160
// While we don't have enough data, read more.
11281161
while (_receiveBufferCount < minimumRequiredBytes)
11291162
{
1130-
int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount, _receiveBuffer.Length - _receiveBufferCount), cancellationToken).ConfigureAwait(false);
1163+
int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount, _receiveBuffer.Length - _receiveBufferCount), default).ConfigureAwait(false);
11311164
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
11321165
_receiveBufferCount += numRead;
11331166
if (numRead == 0)

0 commit comments

Comments
 (0)