Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ namespace System.Net.WebSockets
/// </remarks>
internal sealed partial class ManagedWebSocket : WebSocket
{
/// <summary>Thread-safe random number generator used to generate masks for each send.</summary>
private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create();
/// <summary>Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC.</summary>
private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

Expand Down Expand Up @@ -63,8 +61,6 @@ internal sealed partial class ManagedWebSocket : WebSocket
private readonly string? _subprotocol;
/// <summary>Timer used to send periodic pings to the server, at the interval specified</summary>
private readonly Timer? _keepAliveTimer;
/// <summary>CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs.</summary>
private readonly CancellationTokenSource _abortSource = new CancellationTokenSource();
/// <summary>Buffer used for reading data from the network.</summary>
private readonly Memory<byte> _receiveBuffer;
/// <summary>
Expand Down Expand Up @@ -136,7 +132,7 @@ internal sealed partial class ManagedWebSocket : WebSocket
private Task _lastReceiveAsync = Task.CompletedTask;

/// <summary>Lock used to protect update and check-and-update operations on _state.</summary>
private object StateUpdateLock => _abortSource;
private object StateUpdateLock => _sendFrameAsyncLock;
/// <summary>
/// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may
/// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received.
Expand Down Expand Up @@ -173,25 +169,6 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim
const int ReceiveBufferMinLength = MaxControlPayloadLength;
_receiveBuffer = new byte[ReceiveBufferMinLength];

// Set up the abort source so that if it's triggered, we transition the instance appropriately.
// There's no need to store the resulting CancellationTokenRegistration, as this instance owns
// the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration.
_abortSource.Token.UnsafeRegister(static s =>
{
var thisRef = (ManagedWebSocket)s!;

lock (thisRef.StateUpdateLock)
{
WebSocketState state = thisRef._state;
if (state != WebSocketState.Closed && state != WebSocketState.Aborted)
{
thisRef._state = state != WebSocketState.None && state != WebSocketState.Connecting ?
WebSocketState.Aborted :
WebSocketState.Closed;
}
}
}, this);

// Now that we're opened, initiate the keep alive timer to send periodic pings.
// We use a weak reference from the timer to the web socket to avoid a cycle
// that could keep the web socket rooted in erroneous cases.
Expand Down Expand Up @@ -387,10 +364,24 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string

public override void Abort()
{
_abortSource.Cancel();
OnAborted();
Dispose(); // forcibly tear down connection
}

private void OnAborted()
{
lock (StateUpdateLock)
{
WebSocketState state = _state;
if (state != WebSocketState.Closed && state != WebSocketState.Aborted)
{
_state = state != WebSocketState.None && state != WebSocketState.Connecting ?
WebSocketState.Aborted :
WebSocketState.Closed;
}
}
}

public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
Expand Down Expand Up @@ -721,7 +712,7 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly
/// <param name="buffer">The buffer to which to write the mask.</param>
/// <param name="offset">The offset into the buffer at which to write the mask.</param>
private static void WriteRandomMask(byte[] buffer, int offset) =>
s_random.GetBytes(buffer, offset, MaskLength);
RandomNumberGenerator.Fill(buffer.AsSpan(offset, MaskLength));

/// <summary>
/// Receive the next text, binary, continuation, or close message, returning information about it and
Expand Down Expand Up @@ -917,7 +908,7 @@ private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> paylo
{
throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc);
}
_abortSource.Cancel();
OnAborted();

if (exc is WebSocketException)
{
Expand Down