Skip to content

Commit

Permalink
Ensure ManagedWebSocket.CloseOutputAsync closes the underlying stream (
Browse files Browse the repository at this point in the history
…dotnet#33473)

* Ensure CloseOutputAsync closes the underlying stream

When we issue a close, we then wait for the response close message, and after receiving it, dispose of the web socket instance, which then closes the stream.  We don't currently do the same when sending a close frame having already received a close frame, but we should.

* Address PR feedback

And fix break on netfx where ArraySegment doesn't have an implicit cast from byte[].
  • Loading branch information
stephentoub authored and jlennox committed Dec 16, 2018
1 parent 2fc6c81 commit b70ebe3
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 56 deletions.
27 changes: 15 additions & 12 deletions src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,24 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusD
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription);
return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken);
}

try
{
WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseOutputStates);
}
catch (Exception exc)
private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseOutputStates);

await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);

// If we already received a close frame, since we've now also sent one, we're now closed.
lock (StateUpdateLock)
{
return Task.FromException(exc);
Debug.Assert(_sentCloseFrame);
if (_receivedCloseFrame)
{
DisposeCore();
}
}

return SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken);
}

public override void Abort()
Expand Down Expand Up @@ -1102,10 +1109,6 @@ private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string st
lock (StateUpdateLock)
{
DisposeCore();
if (_state < WebSocketState.Closed)
{
_state = WebSocketState.Closed;
}
}
}

Expand Down
160 changes: 121 additions & 39 deletions src/Common/tests/System/Net/WebSockets/WebSocketCreateTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
Expand Down Expand Up @@ -51,49 +52,20 @@ public async Task WebSocketProtocol_CreateFromConnectedStream_CanSendReceiveData
bool secure = echoUri.Scheme == "wss";
client.Connect(echoUri.Host, secure ? 443 : 80);

Stream stream = new NetworkStream(client, ownsSocket: false);
if (secure)
using (Stream stream = await CreateWebSocketStream(echoUri, client, secure))
using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
{
SslStream ssl = new SslStream(stream, leaveInnerStreamOpen: true, delegate { return true; });
await ssl.AuthenticateAsClientAsync(echoUri.Host);
stream = ssl;
}
Assert.NotNull(socket);
Assert.Equal(WebSocketState.Open, socket.State);

using (stream)
{
using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1, leaveOpen: true))
{
await writer.WriteAsync($"GET {echoUri.PathAndQuery} HTTP/1.1\r\n");
await writer.WriteAsync($"Host: {echoUri.Host}\r\n");
await writer.WriteAsync($"Upgrade: websocket\r\n");
await writer.WriteAsync($"Connection: Upgrade\r\n");
await writer.WriteAsync($"Sec-WebSocket-Version: 13\r\n");
await writer.WriteAsync($"Sec-WebSocket-Key: {Convert.ToBase64String(Guid.NewGuid().ToByteArray())}\r\n");
await writer.WriteAsync($"\r\n");
}
string expected = "Hello World!";
ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(expected));
await socket.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None);

using (var reader = new StreamReader(stream, Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1, leaveOpen: true))
{
string statusLine = await reader.ReadLineAsync();
Assert.NotEmpty(statusLine);
Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine);
while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ;
}
buffer = new ArraySegment<byte>(new byte[buffer.Count]);
await socket.ReceiveAsync(buffer, CancellationToken.None);

using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
{
Assert.NotNull(socket);
Assert.Equal(WebSocketState.Open, socket.State);

string expected = "Hello World!";
ArraySegment<byte> buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(expected));
await socket.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None);

buffer = new ArraySegment<byte>(new byte[buffer.Count]);
await socket.ReceiveAsync(buffer, CancellationToken.None);

Assert.Equal(expected, Encoding.UTF8.GetString(buffer.Array));
}
Assert.Equal(expected, Encoding.UTF8.GetString(buffer.Array));
}
}
}
Expand Down Expand Up @@ -178,7 +150,117 @@ public async Task ReceiveAsync_ServerSplitHeader_ValidDataReceived()
}
}

[OuterLoop] // Connects to external server.
[Theory]
[MemberData(nameof(EchoServersAndBoolean))]
public async Task WebSocketProtocol_CreateFromConnectedStream_CloseAsyncClosesStream(Uri echoUri, bool explicitCloseAsync)
{
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
bool secure = echoUri.Scheme == "wss";
client.Connect(echoUri.Host, secure ? 443 : 80);

using (Stream stream = await CreateWebSocketStream(echoUri, client, secure))
{
using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
{
Assert.NotNull(socket);
Assert.Equal(WebSocketState.Open, socket.State);

Assert.True(stream.CanRead);
Assert.True(stream.CanWrite);

if (explicitCloseAsync) // make sure CloseAsync ends up disposing the stream
{
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
Assert.False(stream.CanRead);
Assert.False(stream.CanWrite);
}
}

Assert.False(stream.CanRead);
Assert.False(stream.CanWrite);
}
}
}

[OuterLoop] // Connects to external server.
[Theory]
[MemberData(nameof(EchoServersAndBoolean))]
public async Task WebSocketProtocol_CreateFromConnectedStream_CloseAsyncAfterCloseReceivedClosesStream(Uri echoUri, bool useCloseOutputAsync)
{
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
bool secure = echoUri.Scheme == "wss";
client.Connect(echoUri.Host, secure ? 443 : 80);

using (Stream stream = await CreateWebSocketStream(echoUri, client, secure))
using (WebSocket socket = CreateFromStream(stream, false, null, TimeSpan.FromSeconds(10)))
{
Assert.NotNull(socket);
Assert.Equal(WebSocketState.Open, socket.State);

// Ask server to send us a close
await socket.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes(".close")), WebSocketMessageType.Text, true, default);

// Verify received server-initiated close message.
WebSocketReceiveResult recvResult = await socket.ReceiveAsync(new ArraySegment<byte>(new byte[256]), default);
Assert.Equal(WebSocketCloseStatus.NormalClosure, recvResult.CloseStatus);
Assert.Equal(WebSocketCloseStatus.NormalClosure, socket.CloseStatus);
Assert.Equal(WebSocketState.CloseReceived, socket.State);

Assert.True(stream.CanRead);
Assert.True(stream.CanWrite);

await (useCloseOutputAsync ?
socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None) :
socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None));

Assert.False(stream.CanRead);
Assert.False(stream.CanWrite);
}
}
}

private static async Task<Stream> CreateWebSocketStream(Uri echoUri, Socket client, bool secure)
{
Stream stream = new NetworkStream(client, ownsSocket: false);

if (secure)
{
var ssl = new SslStream(stream, leaveInnerStreamOpen: true, delegate { return true; });
await ssl.AuthenticateAsClientAsync(echoUri.Host);
stream = ssl;
}

using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1, leaveOpen: true))
{
await writer.WriteAsync($"GET {echoUri.PathAndQuery} HTTP/1.1\r\n");
await writer.WriteAsync($"Host: {echoUri.Host}\r\n");
await writer.WriteAsync($"Upgrade: websocket\r\n");
await writer.WriteAsync($"Connection: Upgrade\r\n");
await writer.WriteAsync($"Sec-WebSocket-Version: 13\r\n");
await writer.WriteAsync($"Sec-WebSocket-Key: {Convert.ToBase64String(Guid.NewGuid().ToByteArray())}\r\n");
await writer.WriteAsync($"\r\n");
}

using (var reader = new StreamReader(stream, Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1, leaveOpen: true))
{
string statusLine = await reader.ReadLineAsync();
Assert.NotEmpty(statusLine);
Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine);
while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
}

return stream;
}

public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers;
public static readonly object[][] EchoServersAndBoolean = EchoServers.SelectMany(o => new object[][]
{
new object[] { o[0], false },
new object[] { o[0], true }
}).ToArray();

protected sealed class UnreadableStream : Stream
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Collections.Generic;
using System.Net.Test.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -19,6 +20,11 @@ public class ClientWebSocketTestBase
{
public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers;
public static readonly object[][] EchoHeadersServers = System.Net.Test.Common.Configuration.WebSockets.EchoHeadersServers;
public static readonly object[][] EchoServersAndBoolean = EchoServers.SelectMany(o => new object[][]
{
new object[] { o[0], false },
new object[] { o[0], true }
}).ToArray();

public const int TimeOutMilliseconds = 20000;
public const int CloseDescriptionMaxLength = 123;
Expand Down
12 changes: 7 additions & 5 deletions src/System.Net.WebSockets.Client/tests/CloseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class CloseTest : ClientWebSocketTestBase
public CloseTest(ITestOutputHelper output) : base(output) { }

[OuterLoop] // TODO: Issue #11345
[ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))]
public async Task CloseAsync_ServerInitiatedClose_Success(Uri server)
[ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersAndBoolean))]
public async Task CloseAsync_ServerInitiatedClose_Success(Uri server, bool useCloseOutputAsync)
{
const string closeWebSocketMetaCommand = ".close";

Expand Down Expand Up @@ -50,9 +50,11 @@ public async Task CloseAsync_ServerInitiatedClose_Success(Uri server)
Assert.Equal(closeWebSocketMetaCommand, cws.CloseStatusDescription);

// Send back close message to acknowledge server-initiated close.
_output.WriteLine("CloseAsync starting.");
await cws.CloseAsync(WebSocketCloseStatus.InvalidMessageType, string.Empty, cts.Token);
_output.WriteLine("CloseAsync done.");
_output.WriteLine("Close starting.");
await (useCloseOutputAsync ?
cws.CloseOutputAsync(WebSocketCloseStatus.InvalidMessageType, string.Empty, cts.Token) :
cws.CloseAsync(WebSocketCloseStatus.InvalidMessageType, string.Empty, cts.Token));
_output.WriteLine("Close done.");
Assert.Equal(WebSocketState.Closed, cws.State);

// Verify that there is no follow-up echo close message back from the server by
Expand Down

0 comments on commit b70ebe3

Please sign in to comment.