diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 971c2ceff82be4..01800a7a401330 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -26,8 +26,6 @@ namespace System.Net.WebSockets /// internal sealed partial class ManagedWebSocket : WebSocket { - /// Thread-safe random number generator used to generate masks for each send. - private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create(); /// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC. private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true); @@ -63,8 +61,6 @@ internal sealed partial class ManagedWebSocket : WebSocket private readonly string? _subprotocol; /// Timer used to send periodic pings to the server, at the interval specified private readonly Timer? _keepAliveTimer; - /// CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs. - private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); /// Buffer used for reading data from the network. private readonly Memory _receiveBuffer; /// @@ -136,7 +132,7 @@ internal sealed partial class ManagedWebSocket : WebSocket private Task _lastReceiveAsync = Task.CompletedTask; /// Lock used to protect update and check-and-update operations on _state. - private object StateUpdateLock => _abortSource; + private object StateUpdateLock => _sendFrameAsyncLock; /// /// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received. @@ -173,25 +169,6 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim const int ReceiveBufferMinLength = MaxControlPayloadLength; _receiveBuffer = new byte[ReceiveBufferMinLength]; - // Set up the abort source so that if it's triggered, we transition the instance appropriately. - // There's no need to store the resulting CancellationTokenRegistration, as this instance owns - // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. - _abortSource.Token.UnsafeRegister(static s => - { - var thisRef = (ManagedWebSocket)s!; - - lock (thisRef.StateUpdateLock) - { - WebSocketState state = thisRef._state; - if (state != WebSocketState.Closed && state != WebSocketState.Aborted) - { - thisRef._state = state != WebSocketState.None && state != WebSocketState.Connecting ? - WebSocketState.Aborted : - WebSocketState.Closed; - } - } - }, this); - // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. @@ -387,10 +364,24 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string public override void Abort() { - _abortSource.Cancel(); + OnAborted(); Dispose(); // forcibly tear down connection } + private void OnAborted() + { + lock (StateUpdateLock) + { + WebSocketState state = _state; + if (state != WebSocketState.Closed && state != WebSocketState.Aborted) + { + _state = state != WebSocketState.None && state != WebSocketState.Connecting ? + WebSocketState.Aborted : + WebSocketState.Closed; + } + } + } + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken); @@ -721,7 +712,7 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly /// The buffer to which to write the mask. /// The offset into the buffer at which to write the mask. private static void WriteRandomMask(byte[] buffer, int offset) => - s_random.GetBytes(buffer, offset, MaskLength); + RandomNumberGenerator.Fill(buffer.AsSpan(offset, MaskLength)); /// /// Receive the next text, binary, continuation, or close message, returning information about it and @@ -917,7 +908,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo { throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); } - _abortSource.Cancel(); + OnAborted(); if (exc is WebSocketException) {