@@ -63,8 +63,6 @@ public static ManagedWebSocket CreateFromConnectedStream(
6363 private const int MaxControlPayloadLength = 125 ;
6464 /// <summary>Length of the mask XOR'd with the payload data.</summary>
6565 private const int MaskLength = 4 ;
66- /// <summary>Default length of a receive buffer to create when an invalid scratch buffer is provided.</summary>
67- private const int DefaultReceiveBufferSize = 0x1000 ;
6866
6967 /// <summary>The stream used to communicate with the remote server.</summary>
7068 private readonly Stream _stream ;
@@ -184,8 +182,11 @@ private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeS
184182 // socket rents a similarly sized buffer from the pool for its duration, we'll end up draining
185183 // the pool, such that other web sockets will allocate anyway, as will anyone else in the process using the
186184 // pool. If someone wants to pool, they can do so by passing in the buffer they want to use, and they can
187- // get it from whatever pool they like.
188- _receiveBuffer = buffer . Length >= MaxMessageHeaderLength ? buffer : new byte [ DefaultReceiveBufferSize ] ;
185+ // get it from whatever pool they like. If we create our own buffer, it's small, large enough for message
186+ // headers and control payloads, but data for other message payloads is read directly into the buffers
187+ // passed into ReceiveAsync.
188+ const int ReceiveBufferMinLength = MaxControlPayloadLength ;
189+ _receiveBuffer = buffer . Length >= ReceiveBufferMinLength ? buffer : new byte [ ReceiveBufferMinLength ] ;
189190
190191 // Set up the abort source so that if it's triggered, we transition the instance appropriately.
191192 _abortSource . Token . Register ( s =>
@@ -697,32 +698,56 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
697698 }
698699
699700 // Otherwise, read as much of the payload as we can efficiently, and upate the header to reflect how much data
700- // remains for future reads.
701- int bytesToCopy = Math . Min ( payloadBuffer . Length , ( int ) Math . Min ( header . PayloadLength , _receiveBuffer . Length ) ) ;
702- Debug . Assert ( bytesToCopy > 0 , $ "Expected { nameof ( bytesToCopy ) } > 0") ;
703- if ( _receiveBufferCount < bytesToCopy )
701+ // remains for future reads. We first need to copy any data that may be lingering in the receive buffer
702+ // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping
703+ // only when we've either read the whole message or when we've filled the payload buffer.
704+
705+ // First copy any data lingering in the receive buffer.
706+ int totalBytesReceived = 0 ;
707+ if ( _receiveBufferCount > 0 )
704708 {
705- await EnsureBufferContainsAsync ( bytesToCopy , throwOnPrematureClosure : true ) . ConfigureAwait ( false ) ;
709+ int receiveBufferBytesToCopy = Math . Min ( payloadBuffer . Length , ( int ) Math . Min ( header . PayloadLength , _receiveBufferCount ) ) ;
710+ Debug . Assert ( receiveBufferBytesToCopy > 0 ) ;
711+ _receiveBuffer . Span . Slice ( _receiveBufferOffset , receiveBufferBytesToCopy ) . CopyTo ( payloadBuffer . Span ) ;
712+ ConsumeFromBuffer ( receiveBufferBytesToCopy ) ;
713+ totalBytesReceived += receiveBufferBytesToCopy ;
714+ Debug . Assert (
715+ _receiveBufferCount == 0 ||
716+ totalBytesReceived == payloadBuffer . Length ||
717+ totalBytesReceived == header . PayloadLength ) ;
718+ }
719+
720+ // Then read directly into the payload buffer until we've hit a limit.
721+ while ( totalBytesReceived < payloadBuffer . Length &&
722+ totalBytesReceived < header . PayloadLength )
723+ {
724+ int numBytesRead = await _stream . ReadAsync ( payloadBuffer . Slice (
725+ totalBytesReceived ,
726+ ( int ) Math . Min ( payloadBuffer . Length , header . PayloadLength ) - totalBytesReceived ) ) . ConfigureAwait ( false ) ;
727+ if ( numBytesRead <= 0 )
728+ {
729+ ThrowIfEOFUnexpected ( throwOnPrematureClosure : true ) ;
730+ break ;
731+ }
732+ totalBytesReceived += numBytesRead ;
706733 }
707734
708735 if ( _isServer )
709736 {
710- _receivedMaskOffsetOffset = ApplyMask ( _receiveBuffer . Span . Slice ( _receiveBufferOffset , bytesToCopy ) , header . Mask , _receivedMaskOffsetOffset ) ;
737+ _receivedMaskOffsetOffset = ApplyMask ( payloadBuffer . Span . Slice ( 0 , totalBytesReceived ) , header . Mask , _receivedMaskOffsetOffset ) ;
711738 }
712- _receiveBuffer . Span . Slice ( _receiveBufferOffset , bytesToCopy ) . CopyTo ( payloadBuffer . Span . Slice ( 0 , bytesToCopy ) ) ;
713- ConsumeFromBuffer ( bytesToCopy ) ;
714- header . PayloadLength -= bytesToCopy ;
739+ header . PayloadLength -= totalBytesReceived ;
715740
716741 // If this a text message, validate that it contains valid UTF8.
717742 if ( header . Opcode == MessageOpcode . Text &&
718- ! TryValidateUtf8 ( payloadBuffer . Span . Slice ( 0 , bytesToCopy ) , header . Fin , _utf8TextState ) )
743+ ! TryValidateUtf8 ( payloadBuffer . Span . Slice ( 0 , totalBytesReceived ) , header . Fin , _utf8TextState ) )
719744 {
720745 await CloseWithReceiveErrorAndThrowAsync ( WebSocketCloseStatus . InvalidPayloadData , WebSocketError . Faulted ) . ConfigureAwait ( false ) ;
721746 }
722747
723748 _lastReceiveHeader = header ;
724749 return resultGetter . GetResult (
725- bytesToCopy ,
750+ totalBytesReceived ,
726751 header . Opcode == MessageOpcode . Text ? WebSocketMessageType . Text : WebSocketMessageType . Binary ,
727752 header . Fin && header . PayloadLength == 0 ,
728753 null , null ) ;
@@ -1165,27 +1190,32 @@ private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, bool thro
11651190 {
11661191 int numRead = await _stream . ReadAsync ( _receiveBuffer . Slice ( _receiveBufferCount , _receiveBuffer . Length - _receiveBufferCount ) , default ) . ConfigureAwait ( false ) ;
11671192 Debug . Assert ( numRead >= 0 , $ "Expected non-negative bytes read, got { numRead } ") ;
1168- _receiveBufferCount += numRead ;
1169- if ( numRead == 0 )
1193+ if ( numRead <= 0 )
11701194 {
1171- // The connection closed before we were able to read everything we needed.
1172- // If it was due to use being disposed, fail. If it was due to the connection
1173- // being closed and it wasn't expected, fail. If it was due to the connection
1174- // being closed and that was expected, exit gracefully.
1175- if ( _disposed )
1176- {
1177- throw new ObjectDisposedException ( nameof ( WebSocket ) ) ;
1178- }
1179- else if ( throwOnPrematureClosure )
1180- {
1181- throw new WebSocketException ( WebSocketError . ConnectionClosedPrematurely ) ;
1182- }
1195+ ThrowIfEOFUnexpected ( throwOnPrematureClosure ) ;
11831196 break ;
11841197 }
1198+ _receiveBufferCount += numRead ;
11851199 }
11861200 }
11871201 }
11881202
1203+ private void ThrowIfEOFUnexpected ( bool throwOnPrematureClosure )
1204+ {
1205+ // The connection closed before we were able to read everything we needed.
1206+ // If it was due to us being disposed, fail. If it was due to the connection
1207+ // being closed and it wasn't expected, fail. If it was due to the connection
1208+ // being closed and that was expected, exit gracefully.
1209+ if ( _disposed )
1210+ {
1211+ throw new ObjectDisposedException ( nameof ( WebSocket ) ) ;
1212+ }
1213+ if ( throwOnPrematureClosure )
1214+ {
1215+ throw new WebSocketException ( WebSocketError . ConnectionClosedPrematurely ) ;
1216+ }
1217+ }
1218+
11891219 /// <summary>Gets a send buffer from the pool.</summary>
11901220 private void AllocateSendBuffer ( int minLength )
11911221 {
0 commit comments