|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System.Collections.Generic; |
| 6 | +using System.IO; |
| 7 | +using System.Net.Security; |
| 8 | +using System.Net.Sockets; |
| 9 | +using System.Text; |
| 10 | +using System.Threading; |
| 11 | +using System.Threading.Tasks; |
| 12 | +using Xunit; |
| 13 | + |
| 14 | +namespace System.Net.WebSockets.Tests |
| 15 | +{ |
| 16 | + public abstract class WebSocketCreateTest |
| 17 | + { |
| 18 | + protected abstract WebSocket CreateFromStream(Stream stream, bool isServer, string subProtocol, TimeSpan keepAliveInterval); |
| 19 | + |
| 20 | + [Fact] |
| 21 | + public void CreateFromStream_InvalidArguments_Throws() |
| 22 | + { |
| 23 | + AssertExtensions.Throws<ArgumentNullException>("stream", () => CreateFromStream(null, true, "subProtocol", TimeSpan.FromSeconds(30))); |
| 24 | + AssertExtensions.Throws<ArgumentException>("stream", () => CreateFromStream(new MemoryStream(new byte[100], writable: false), true, "subProtocol", TimeSpan.FromSeconds(30))); |
| 25 | + AssertExtensions.Throws<ArgumentException>("stream", () => CreateFromStream(new UnreadableStream(), true, "subProtocol", TimeSpan.FromSeconds(30))); |
| 26 | + |
| 27 | + AssertExtensions.Throws<ArgumentException>("subProtocol", () => CreateFromStream(new MemoryStream(), true, " ", TimeSpan.FromSeconds(30))); |
| 28 | + AssertExtensions.Throws<ArgumentException>("subProtocol", () => CreateFromStream(new MemoryStream(), true, "\xFF", TimeSpan.FromSeconds(30))); |
| 29 | + |
| 30 | + AssertExtensions.Throws<ArgumentOutOfRangeException>("keepAliveInterval", () => CreateFromStream(new MemoryStream(), true, "subProtocol", TimeSpan.FromSeconds(-2))); |
| 31 | + } |
| 32 | + |
| 33 | + [Theory] |
| 34 | + [InlineData(0)] |
| 35 | + [InlineData(1)] |
| 36 | + [InlineData(14)] |
| 37 | + [InlineData(4096)] |
| 38 | + public void CreateFromStream_ValidBufferSizes_CreatesWebSocket(int bufferSize) |
| 39 | + { |
| 40 | + Assert.NotNull(CreateFromStream(new MemoryStream(), false, null, Timeout.InfiniteTimeSpan)); |
| 41 | + Assert.NotNull(CreateFromStream(new MemoryStream(), true, null, Timeout.InfiniteTimeSpan)); |
| 42 | + } |
| 43 | + |
| 44 | + [OuterLoop] // Connects to external server. |
| 45 | + [Theory] |
| 46 | + [MemberData(nameof(EchoServers))] |
| 47 | + public async Task WebSocketProtocol_CreateFromConnectedStream_CanSendReceiveData(Uri echoUri) |
| 48 | + { |
| 49 | + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) |
| 50 | + { |
| 51 | + bool secure = echoUri.Scheme == "wss"; |
| 52 | + client.Connect(echoUri.Host, secure ? 443 : 80); |
| 53 | + |
| 54 | + Stream stream = new NetworkStream(client, ownsSocket: false); |
| 55 | + if (secure) |
| 56 | + { |
| 57 | + SslStream ssl = new SslStream(stream, leaveInnerStreamOpen: true, delegate { return true; }); |
| 58 | + await ssl.AuthenticateAsClientAsync(echoUri.Host); |
| 59 | + stream = ssl; |
| 60 | + } |
| 61 | + |
| 62 | + using (stream) |
| 63 | + { |
| 64 | + using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1, leaveOpen: true)) |
| 65 | + { |
| 66 | + await writer.WriteAsync($"GET {echoUri.PathAndQuery} HTTP/1.1\r\n"); |
| 67 | + await writer.WriteAsync($"Host: {echoUri.Host}\r\n"); |
| 68 | + await writer.WriteAsync($"Upgrade: websocket\r\n"); |
| 69 | + await writer.WriteAsync($"Connection: Upgrade\r\n"); |
| 70 | + await writer.WriteAsync($"Sec-WebSocket-Version: 13\r\n"); |
| 71 | + await writer.WriteAsync($"Sec-WebSocket-Key: {Convert.ToBase64String(Guid.NewGuid().ToByteArray())}\r\n"); |
| 72 | + await writer.WriteAsync($"\r\n"); |
| 73 | + } |
| 74 | + |
| 75 | + using (var reader = new StreamReader(stream, Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1, leaveOpen: true)) |
| 76 | + { |
| 77 | + string statusLine = await reader.ReadLineAsync(); |
| 78 | + Assert.NotEmpty(statusLine); |
| 79 | + Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine); |
| 80 | + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ; |
| 81 | + } |
| 82 | + |
| 83 | + using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10))) |
| 84 | + { |
| 85 | + Assert.NotNull(socket); |
| 86 | + Assert.Equal(WebSocketState.Open, socket.State); |
| 87 | + |
| 88 | + string expected = "Hello World!"; |
| 89 | + ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(expected)); |
| 90 | + await socket.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None); |
| 91 | + |
| 92 | + buffer = new ArraySegment<byte>(new byte[buffer.Count]); |
| 93 | + await socket.ReceiveAsync(buffer, CancellationToken.None); |
| 94 | + |
| 95 | + Assert.Equal(expected, Encoding.UTF8.GetString(buffer.Array)); |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + [Fact] |
| 102 | + public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived() |
| 103 | + { |
| 104 | + // 1 character - 2 bytes |
| 105 | + byte[] payload = Encoding.UTF8.GetBytes("\u00E6"); |
| 106 | + var frame = new byte[payload.Length + 2]; |
| 107 | + frame[0] = 0x81; // FIN = true, Opcode = Text |
| 108 | + frame[1] = (byte)payload.Length; |
| 109 | + Array.Copy(payload, 0, frame, 2, payload.Length); |
| 110 | + |
| 111 | + using (var stream = new MemoryStream(frame, writable: true)) |
| 112 | + { |
| 113 | + WebSocket websocket = CreateFromStream(stream, false, "null", Timeout.InfiniteTimeSpan); |
| 114 | + |
| 115 | + // read first half of the multi-byte character |
| 116 | + var recvBuffer = new byte[1]; |
| 117 | + WebSocketReceiveResult result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None); |
| 118 | + Assert.False(result.EndOfMessage); |
| 119 | + Assert.Equal(1, result.Count); |
| 120 | + Assert.Equal(0xc3, recvBuffer[0]); |
| 121 | + |
| 122 | + // read second half of the multi-byte character |
| 123 | + result = await websocket.ReceiveAsync(new ArraySegment<byte>(recvBuffer), CancellationToken.None); |
| 124 | + Assert.True(result.EndOfMessage); |
| 125 | + Assert.Equal(1, result.Count); |
| 126 | + Assert.Equal(0xa6, recvBuffer[0]); |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + [Fact] |
| 131 | + public async Task ReceiveAsync_ServerSplitHeader_ValidDataReceived() |
| 132 | + { |
| 133 | + using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) |
| 134 | + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) |
| 135 | + { |
| 136 | + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); |
| 137 | + listener.Listen(1); |
| 138 | + |
| 139 | + await client.ConnectAsync(listener.LocalEndPoint); |
| 140 | + using (Socket server = await listener.AcceptAsync()) |
| 141 | + { |
| 142 | + WebSocket websocket = CreateFromStream(new NetworkStream(server, ownsSocket: false), isServer: true, null, Timeout.InfiniteTimeSpan); |
| 143 | + |
| 144 | + // Send a full packet and a partial packet |
| 145 | + var packets = new byte[7 + 11 + 4]; |
| 146 | + IList<byte> packet0 = new ArraySegment<byte>(packets, 0, 7); |
| 147 | + packet0[0] = 0x82; // fin, binary |
| 148 | + packet0[1] = 0x81; // masked, 1-byte length |
| 149 | + packet0[6] = 42; // content |
| 150 | + |
| 151 | + IList<byte> partialPacket1 = new ArraySegment<byte>(packets, 7, 11); |
| 152 | + partialPacket1[0] = 0x82; // fin, binary |
| 153 | + partialPacket1[1] = 0xFF; // masked, 8-byte length |
| 154 | + partialPacket1[9] = 1; // length == 1 |
| 155 | + |
| 156 | + IList<byte> remainderPacket1 = new ArraySegment<byte>(packets, 7 + 11, 4); |
| 157 | + remainderPacket1[3] = 84; // content |
| 158 | + |
| 159 | + await client.SendAsync(new ArraySegment<byte>(packets, 0, packet0.Count + partialPacket1.Count), SocketFlags.None); |
| 160 | + |
| 161 | + // Read the first packet |
| 162 | + byte[] received = new byte[1]; |
| 163 | + WebSocketReceiveResult r = await websocket.ReceiveAsync(new ArraySegment<byte>(received), default); |
| 164 | + Assert.True(r.EndOfMessage); |
| 165 | + Assert.Equal(1, r.Count); |
| 166 | + Assert.Equal(42, received[0]); |
| 167 | + |
| 168 | + // Read the next packet, which is partial, then complete it. |
| 169 | + // Partial read shouldn't cause a failure. |
| 170 | + Task<WebSocketReceiveResult> tr = websocket.ReceiveAsync(new ArraySegment<byte>(received), default); |
| 171 | + Assert.False(tr.IsCompleted); |
| 172 | + await client.SendAsync((ArraySegment<byte>)remainderPacket1, SocketFlags.None); |
| 173 | + r = await tr; |
| 174 | + Assert.True(r.EndOfMessage); |
| 175 | + Assert.Equal(1, r.Count); |
| 176 | + Assert.Equal(84, received[0]); |
| 177 | + } |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers; |
| 182 | + |
| 183 | + protected sealed class UnreadableStream : Stream |
| 184 | + { |
| 185 | + public override bool CanRead => false; |
| 186 | + public override bool CanSeek => true; |
| 187 | + public override bool CanWrite => true; |
| 188 | + public override long Length => throw new NotImplementedException(); |
| 189 | + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } |
| 190 | + public override void Flush() => throw new NotImplementedException(); |
| 191 | + public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException(); |
| 192 | + public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); |
| 193 | + public override void SetLength(long value) => throw new NotImplementedException(); |
| 194 | + public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException(); |
| 195 | + } |
| 196 | + } |
| 197 | +} |
0 commit comments