Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit c6f4eb6

Browse files
authored
Fix WebSocket server split header parsing with large payload (#30402) (#30407)
* Refactor WebSocket{Protocol}.CreateFromStream tests to be shared We have the WebSocket.CreateFromStream and WebSocketProtocol.CreateFromStream methods, which are identical, except that the former is netcoreapp-only and the latter is in a separate NuGet package for downlevel use. However, tests for them were separate, with some tests only for one and some tests for the other. This commit centralizes those tests so they're shared by and apply to both methods. There are no code changes to the actual bodies of tests, just moving code around to have it apply to both. * Fix WebSocket server split header parsing with large payload When ReceiveAsync is called, as a fast path it checks to see whether there's already enough data in the buffer to satisfy any possible header, skipping subsequent checks if there is. The max header size differs between client and server, though. The maximum size header a client can send to the server is 14 bytes, which includes a 4-byte masking value; the maximum size header a server can send to a client is 10 bytes, as it doesn't include a masking value. However, the code currently has those values reversed. If the code is running on the client, this means that we end up falling back to the slow-path unnecessarily if there are 10, 11, 12, or 13 bytes already in the buffer when ReceiveAsync is called. However, on the server, this means we end up potentially throwing an exception or misinterpreting the payload if 10, 11, 12, or 13 bytes are in the buffer and the packet contains a large payload (in which case it'll be using an 8-byte length and be the full 14 byte header), as we'll end up erroneously taking the fast path when we should have taken the slow path to read more data from the network. The fix is simply to swap the branches of the conditional. * Address PR feedback
1 parent 64635e4 commit c6f4eb6

File tree

6 files changed

+219
-173
lines changed

6 files changed

+219
-173
lines changed

src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ private async ValueTask<TWebSocketReceiveResult> ReceiveAsyncPrivate<TWebSocketR
625625
MessageHeader header = _lastReceiveHeader;
626626
if (header.PayloadLength == 0)
627627
{
628-
if (_receiveBufferCount < (_isServer ? (MaxMessageHeaderLength - MaskLength) : MaxMessageHeaderLength))
628+
if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
629629
{
630630
// Make sure we have the first two bytes, which includes the start of the payload length.
631631
if (_receiveBufferCount < 2)
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using System.Net.Security;
8+
using System.Net.Sockets;
9+
using System.Text;
10+
using System.Threading;
11+
using System.Threading.Tasks;
12+
using Xunit;
13+
14+
namespace System.Net.WebSockets.Tests
15+
{
16+
public abstract class WebSocketCreateTest
17+
{
18+
protected abstract WebSocket CreateFromStream(Stream stream, bool isServer, string subProtocol, TimeSpan keepAliveInterval);
19+
20+
[Fact]
21+
public void CreateFromStream_InvalidArguments_Throws()
22+
{
23+
AssertExtensions.Throws<ArgumentNullException>("stream", () => CreateFromStream(null, true, "subProtocol", TimeSpan.FromSeconds(30)));
24+
AssertExtensions.Throws<ArgumentException>("stream", () => CreateFromStream(new MemoryStream(new byte[100], writable: false), true, "subProtocol", TimeSpan.FromSeconds(30)));
25+
AssertExtensions.Throws<ArgumentException>("stream", () => CreateFromStream(new UnreadableStream(), true, "subProtocol", TimeSpan.FromSeconds(30)));
26+
27+
AssertExtensions.Throws<ArgumentException>("subProtocol", () => CreateFromStream(new MemoryStream(), true, " ", TimeSpan.FromSeconds(30)));
28+
AssertExtensions.Throws<ArgumentException>("subProtocol", () => CreateFromStream(new MemoryStream(), true, "\xFF", TimeSpan.FromSeconds(30)));
29+
30+
AssertExtensions.Throws<ArgumentOutOfRangeException>("keepAliveInterval", () => CreateFromStream(new MemoryStream(), true, "subProtocol", TimeSpan.FromSeconds(-2)));
31+
}
32+
33+
[Theory]
34+
[InlineData(0)]
35+
[InlineData(1)]
36+
[InlineData(14)]
37+
[InlineData(4096)]
38+
public void CreateFromStream_ValidBufferSizes_CreatesWebSocket(int bufferSize)
39+
{
40+
Assert.NotNull(CreateFromStream(new MemoryStream(), false, null, Timeout.InfiniteTimeSpan));
41+
Assert.NotNull(CreateFromStream(new MemoryStream(), true, null, Timeout.InfiniteTimeSpan));
42+
}
43+
44+
[OuterLoop] // Connects to external server.
45+
[Theory]
46+
[MemberData(nameof(EchoServers))]
47+
public async Task WebSocketProtocol_CreateFromConnectedStream_CanSendReceiveData(Uri echoUri)
48+
{
49+
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
50+
{
51+
bool secure = echoUri.Scheme == "wss";
52+
client.Connect(echoUri.Host, secure ? 443 : 80);
53+
54+
Stream stream = new NetworkStream(client, ownsSocket: false);
55+
if (secure)
56+
{
57+
SslStream ssl = new SslStream(stream, leaveInnerStreamOpen: true, delegate { return true; });
58+
await ssl.AuthenticateAsClientAsync(echoUri.Host);
59+
stream = ssl;
60+
}
61+
62+
using (stream)
63+
{
64+
using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1, leaveOpen: true))
65+
{
66+
await writer.WriteAsync($"GET {echoUri.PathAndQuery} HTTP/1.1\r\n");
67+
await writer.WriteAsync($"Host: {echoUri.Host}\r\n");
68+
await writer.WriteAsync($"Upgrade: websocket\r\n");
69+
await writer.WriteAsync($"Connection: Upgrade\r\n");
70+
await writer.WriteAsync($"Sec-WebSocket-Version: 13\r\n");
71+
await writer.WriteAsync($"Sec-WebSocket-Key: {Convert.ToBase64String(Guid.NewGuid().ToByteArray())}\r\n");
72+
await writer.WriteAsync($"\r\n");
73+
}
74+
75+
using (var reader = new StreamReader(stream, Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1, leaveOpen: true))
76+
{
77+
string statusLine = await reader.ReadLineAsync();
78+
Assert.NotEmpty(statusLine);
79+
Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine);
80+
while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ;
81+
}
82+
83+
using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
84+
{
85+
Assert.NotNull(socket);
86+
Assert.Equal(WebSocketState.Open, socket.State);
87+
88+
string expected = "Hello World!";
89+
ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(expected));
90+
await socket.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None);
91+
92+
buffer = new ArraySegment<byte>(new byte[buffer.Count]);
93+
await socket.ReceiveAsync(buffer, CancellationToken.None);
94+
95+
Assert.Equal(expected, Encoding.UTF8.GetString(buffer.Array));
96+
}
97+
}
98+
}
99+
}
100+
101+
[Fact]
102+
public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived()
103+
{
104+
// 1 character - 2 bytes
105+
byte[] payload = Encoding.UTF8.GetBytes("\u00E6");
106+
var frame = new byte[payload.Length + 2];
107+
frame[0] = 0x81; // FIN = true, Opcode = Text
108+
frame[1] = (byte)payload.Length;
109+
Array.Copy(payload, 0, frame, 2, payload.Length);
110+
111+
using (var stream = new MemoryStream(frame, writable: true))
112+
{
113+
WebSocket websocket = CreateFromStream(stream, false, "null", Timeout.InfiniteTimeSpan);
114+
115+
// read first half of the multi-byte character
116+
var recvBuffer = new byte[1];
117+
WebSocketReceiveResult result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None);
118+
Assert.False(result.EndOfMessage);
119+
Assert.Equal(1, result.Count);
120+
Assert.Equal(0xc3, recvBuffer[0]);
121+
122+
// read second half of the multi-byte character
123+
result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None);
124+
Assert.True(result.EndOfMessage);
125+
Assert.Equal(1, result.Count);
126+
Assert.Equal(0xa6, recvBuffer[0]);
127+
}
128+
}
129+
130+
[Fact]
131+
public async Task ReceiveAsync_ServerSplitHeader_ValidDataReceived()
132+
{
133+
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
134+
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
135+
{
136+
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
137+
listener.Listen(1);
138+
139+
await client.ConnectAsync(listener.LocalEndPoint);
140+
using (Socket server = await listener.AcceptAsync())
141+
{
142+
WebSocket websocket = CreateFromStream(new NetworkStream(server, ownsSocket: false), isServer: true, null, Timeout.InfiniteTimeSpan);
143+
144+
// Send a full packet and a partial packet
145+
var packets = new byte[7 + 11 + 4];
146+
IList<byte> packet0 = new ArraySegment<byte>(packets, 0, 7);
147+
packet0[0] = 0x82; // fin, binary
148+
packet0[1] = 0x81; // masked, 1-byte length
149+
packet0[6] = 42; // content
150+
151+
IList<byte> partialPacket1 = new ArraySegment<byte>(packets, 7, 11);
152+
partialPacket1[0] = 0x82; // fin, binary
153+
partialPacket1[1] = 0xFF; // masked, 8-byte length
154+
partialPacket1[9] = 1; // length == 1
155+
156+
IList<byte> remainderPacket1 = new ArraySegment<byte>(packets, 7 + 11, 4);
157+
remainderPacket1[3] = 84; // content
158+
159+
await client.SendAsync(new ArraySegment<byte>(packets, 0, packet0.Count + partialPacket1.Count), SocketFlags.None);
160+
161+
// Read the first packet
162+
byte[] received = new byte[1];
163+
WebSocketReceiveResult r = await websocket.ReceiveAsync(new ArraySegment<byte>(received), default);
164+
Assert.True(r.EndOfMessage);
165+
Assert.Equal(1, r.Count);
166+
Assert.Equal(42, received[0]);
167+
168+
// Read the next packet, which is partial, then complete it.
169+
// Partial read shouldn't cause a failure.
170+
Task<WebSocketReceiveResult> tr = websocket.ReceiveAsync(new ArraySegment<byte>(received), default);
171+
Assert.False(tr.IsCompleted);
172+
await client.SendAsync((ArraySegment<byte>)remainderPacket1, SocketFlags.None);
173+
r = await tr;
174+
Assert.True(r.EndOfMessage);
175+
Assert.Equal(1, r.Count);
176+
Assert.Equal(84, received[0]);
177+
}
178+
}
179+
}
180+
181+
public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers;
182+
183+
protected sealed class UnreadableStream : Stream
184+
{
185+
public override bool CanRead => false;
186+
public override bool CanSeek => true;
187+
public override bool CanWrite => true;
188+
public override long Length => throw new NotImplementedException();
189+
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
190+
public override void Flush() => throw new NotImplementedException();
191+
public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
192+
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
193+
public override void SetLength(long value) => throw new NotImplementedException();
194+
public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
195+
}
196+
}
197+
}

src/System.Net.WebSockets.WebSocketProtocol/tests/System.Net.WebSockets.WebSocketProtocol.Tests.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
<ProjectGuid>{CF73547B-07D2-4290-A14A-CA2A354F4D21}</ProjectGuid>
1010
</PropertyGroup>
1111
<ItemGroup>
12+
<Compile Include="$(CommonTestPath)\System\Net\WebSockets\WebSocketCreateTest.cs">
13+
<Link>Common\System\Net\WebSockets\WebSocketCreateTest.cs</Link>
14+
</Compile>
1215
<Compile Include="$(CommonTestPath)\System\Net\Configuration.cs">
1316
<Link>Common\System\Net\Configuration.cs</Link>
1417
</Compile>

src/System.Net.WebSockets.WebSocketProtocol/tests/WebSocketProtocolTests.cs

Lines changed: 3 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -3,147 +3,12 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.IO;
6-
using System.Net.Security;
7-
using System.Net.Sockets;
8-
using System.Text;
9-
using System.Threading;
10-
using System.Threading.Tasks;
11-
using Xunit;
126

137
namespace System.Net.WebSockets.Tests
148
{
15-
public sealed class WebSocketProtocolTests
9+
public sealed class WebSocketProtocolCreateTests : WebSocketCreateTest
1610
{
17-
[Fact]
18-
public void CreateFromStream_InvalidArguments_Throws()
19-
{
20-
AssertExtensions.Throws<ArgumentNullException>("stream",
21-
() => WebSocketProtocol.CreateFromStream(null, true, "subProtocol", TimeSpan.FromSeconds(30)));
22-
AssertExtensions.Throws<ArgumentException>("stream",
23-
() => WebSocketProtocol.CreateFromStream(new MemoryStream(new byte[100], writable: false), true, "subProtocol", TimeSpan.FromSeconds(30)));
24-
AssertExtensions.Throws<ArgumentException>("stream",
25-
() => WebSocketProtocol.CreateFromStream(new UnreadableStream(), true, "subProtocol", TimeSpan.FromSeconds(30)));
26-
27-
AssertExtensions.Throws<ArgumentException>("subProtocol",
28-
() => WebSocketProtocol.CreateFromStream(new MemoryStream(), true, " ", TimeSpan.FromSeconds(30)));
29-
AssertExtensions.Throws<ArgumentException>("subProtocol",
30-
() => WebSocketProtocol.CreateFromStream(new MemoryStream(), true, "\xFF", TimeSpan.FromSeconds(30)));
31-
32-
AssertExtensions.Throws<ArgumentOutOfRangeException>("keepAliveInterval", () =>
33-
WebSocketProtocol.CreateFromStream(new MemoryStream(), true, "subProtocol", TimeSpan.FromSeconds(-2)));
34-
}
35-
36-
[Theory]
37-
[InlineData(0)]
38-
[InlineData(1)]
39-
[InlineData(14)]
40-
[InlineData(4096)]
41-
public void CreateFromStream_ValidBufferSizes_Succeed(int bufferSize)
42-
{
43-
Assert.NotNull(WebSocketProtocol.CreateFromStream(new MemoryStream(), false, null, Timeout.InfiniteTimeSpan));
44-
Assert.NotNull(WebSocketProtocol.CreateFromStream(new MemoryStream(), true, null, Timeout.InfiniteTimeSpan));
45-
}
46-
47-
[OuterLoop] // Connects to external server.
48-
[Theory]
49-
[MemberData(nameof(EchoServers))]
50-
public async Task WebSocketProtocol_CreateFromConnectedStream_Succeeds(Uri echoUri)
51-
{
52-
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
53-
{
54-
bool secure = echoUri.Scheme == "wss";
55-
client.Connect(echoUri.Host, secure ? 443 : 80);
56-
57-
Stream stream = new NetworkStream(client, ownsSocket: false);
58-
if (secure)
59-
{
60-
SslStream ssl = new SslStream(stream, leaveInnerStreamOpen: true, delegate { return true; });
61-
await ssl.AuthenticateAsClientAsync(echoUri.Host);
62-
stream = ssl;
63-
}
64-
65-
using (stream)
66-
{
67-
using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1, leaveOpen: true))
68-
{
69-
await writer.WriteAsync($"GET {echoUri.PathAndQuery} HTTP/1.1\r\n");
70-
await writer.WriteAsync($"Host: {echoUri.Host}\r\n");
71-
await writer.WriteAsync($"Upgrade: websocket\r\n");
72-
await writer.WriteAsync($"Connection: Upgrade\r\n");
73-
await writer.WriteAsync($"Sec-WebSocket-Version: 13\r\n");
74-
await writer.WriteAsync($"Sec-WebSocket-Key: {Convert.ToBase64String(Guid.NewGuid().ToByteArray())}\r\n");
75-
await writer.WriteAsync($"\r\n");
76-
}
77-
78-
using (var reader = new StreamReader(stream, Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1, leaveOpen: true))
79-
{
80-
string statusLine = await reader.ReadLineAsync();
81-
Assert.NotEmpty(statusLine);
82-
Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine);
83-
while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
84-
}
85-
86-
using (WebSocket socket = WebSocketProtocol.CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
87-
{
88-
Assert.NotNull(socket);
89-
Assert.Equal(WebSocketState.Open, socket.State);
90-
91-
string expected = "Hello World!";
92-
ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(expected));
93-
await socket.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None);
94-
95-
buffer = new ArraySegment<byte>(new byte[buffer.Count]);
96-
await socket.ReceiveAsync(buffer, CancellationToken.None);
97-
98-
Assert.Equal(expected, Encoding.UTF8.GetString(buffer.Array));
99-
}
100-
}
101-
}
102-
}
103-
104-
[Fact]
105-
public static async Task ManagedWebSocket_ReceiveUTF8SplitAcrossMultipleBuffers()
106-
{
107-
// 1 character - 2 bytes
108-
byte[] payload = Encoding.UTF8.GetBytes("\u00E6");
109-
var frame = new byte[payload.Length + 2];
110-
frame[0] = 0x81; // FIN = true, Opcode = Text
111-
frame[1] = (byte)payload.Length;
112-
Array.Copy(payload, 0, frame, 2, payload.Length);
113-
114-
using (var stream = new MemoryStream(frame, writable: true))
115-
{
116-
WebSocket websocket = WebSocketProtocol.CreateFromStream(stream, false, "null", Timeout.InfiniteTimeSpan);
117-
118-
// read first half of the multi-byte character
119-
var recvBuffer = new byte[1];
120-
WebSocketReceiveResult result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None);
121-
Assert.False(result.EndOfMessage);
122-
Assert.Equal(1, result.Count);
123-
Assert.Equal(0xc3, recvBuffer[0]);
124-
125-
// read second half of the multi-byte character
126-
result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None);
127-
Assert.True(result.EndOfMessage);
128-
Assert.Equal(1, result.Count);
129-
Assert.Equal(0xa6, recvBuffer[0]);
130-
}
131-
}
132-
133-
public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers;
134-
135-
private sealed class UnreadableStream : Stream
136-
{
137-
public override bool CanRead => false;
138-
public override bool CanSeek => true;
139-
public override bool CanWrite => true;
140-
public override long Length => throw new NotImplementedException();
141-
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
142-
public override void Flush() => throw new NotImplementedException();
143-
public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
144-
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
145-
public override void SetLength(long value) => throw new NotImplementedException();
146-
public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
147-
}
11+
protected override WebSocket CreateFromStream(Stream stream, bool isServer, string subProtocol, TimeSpan keepAliveInterval) =>
12+
WebSocketProtocol.CreateFromStream(stream, isServer, subProtocol, keepAliveInterval);
14813
}
14914
}

0 commit comments

Comments
 (0)