Skip to content

Commit

Permalink
Refactor Task Socket.ConnectAsync methods to use AwaitableSocketAsync…
Browse files Browse the repository at this point in the history
…EventArgs (#787)

* Refactor Task Socket.ConnectAsync methods to use AwaitableSocketAsyncEventArgs

* Add back BeginConnect_IPAddresses_ListeningSocket_Throws_InvalidOperation
  • Loading branch information
tmds committed Apr 27, 2020
1 parent b4add76 commit eab1d28
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,66 +119,70 @@ private Task<Socket> AcceptAsyncApm(Socket? acceptSocket)

internal Task ConnectAsync(EndPoint remoteEP)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(remoteEP, iar =>
// Use ValueTaskReceive so the AwaitableSocketAsyncEventArgs can be re-used later.
AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref EventArgs.ValueTaskReceive, () => new AwaitableSocketAsyncEventArgs());

if (!saea.Reserve())
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
saea = new AwaitableSocketAsyncEventArgs();
saea.Reserve();
}

saea.RemoteEndPoint = remoteEP;
return saea.ConnectAsync(this).AsTask();
}

internal Task ConnectAsync(IPAddress address, int port)
=> ConnectAsync(new IPEndPoint(address, port));

internal Task ConnectAsync(IPAddress[] addresses, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(address, port, iar =>
if (addresses == null)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
throw new ArgumentNullException(nameof(addresses));
}
if (addresses.Length == 0)
{
throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses));
}

return DoConnectAsync(addresses, port);
}

internal Task ConnectAsync(IPAddress[] addresses, int port)
private async Task DoConnectAsync(IPAddress[] addresses, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(addresses, port, iar =>
Exception? lastException = null;
foreach (IPAddress address in addresses)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
await ConnectAsync(address, port).ConfigureAwait(false);
return;
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
catch (Exception ex)
{
lastException = ex;
}
}
Debug.Assert(lastException != null);
ExceptionDispatchInfo.Throw(lastException);
}

internal Task ConnectAsync(string host, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(host, port, iar =>
if (host == null)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
throw new ArgumentNullException(nameof(host));
}

if (IPAddress.TryParse(host, out IPAddress? parsedAddress))
{
return ConnectAsync(new IPEndPoint(parsedAddress, port));
}
else
{
return ConnectAsync(new DnsEndPoint(host, port));
}
}

internal Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags, bool fromNetworkStream)
Expand Down Expand Up @@ -946,6 +950,32 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc
new ValueTask(Task.FromException(CreateException(error)));
}

public ValueTask ConnectAsync(Socket socket)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

try
{
if (socket.ConnectAsync(this))
{
return new ValueTask(this, _token);
}
}
catch
{
Release();
throw;
}

SocketError error = SocketError;

Release();

return error == SocketError.Success ?
default :
new ValueTask(Task.FromException(CreateException(error)));
}

/// <summary>Gets the status of the operation.</summary>
public ValueTaskSourceStatus GetStatus(short token)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,14 @@ public void BeginConnect_IPAddresses_EmptyIPAddresses_Throws_Argument()
public void BeginConnect_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port)
{
Assert.Throws<ArgumentOutOfRangeException>(() => GetSocket().BeginConnect(new[] { IPAddress.Loopback }, port, TheAsyncCallback, null));
Assert.Throws<ArgumentOutOfRangeException>(() => { GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port); });
}

[Theory]
[InlineData(-1)]
[InlineData(65536)]
public async Task ConnectAsync_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port)
{
await Assert.ThrowsAsync<ArgumentOutOfRangeException>(() => GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port));
}

[Fact]
Expand All @@ -1126,12 +1133,16 @@ public void BeginConnect_IPAddresses_ListeningSocket_Throws_InvalidOperation()
socket.Listen(1);
Assert.Throws<InvalidOperationException>(() => socket.BeginConnect(new[] { IPAddress.Loopback }, 1, TheAsyncCallback, null));
}
}

[Fact]
public async Task ConnectAsync_IPAddresses_ListeningSocket_Throws_InvalidOperation()
{
using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
socket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
socket.Listen(1);
Assert.Throws<InvalidOperationException>(() => { socket.ConnectAsync(new[] { IPAddress.Loopback }, 1); });
await Assert.ThrowsAsync<InvalidOperationException>(() => socket.ConnectAsync(new[] { IPAddress.Loopback }, 1));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ public async Task Connect_AfterDisconnect_Fails()
[PlatformSpecific(~(TestPlatforms.OSX | TestPlatforms.FreeBSD))] // Not supported on BSD like OSes.
public async Task ConnectGetsCanceledByDispose()
{
bool usesApm = UsesApm ||
(this is ConnectTask); // .NET Core ConnectAsync Task API is implemented using Apm

// We try this a couple of times to deal with a timing race: if the Dispose happens
// before the operation is started, we won't see a SocketException.
int msDelay = 100;
Expand Down Expand Up @@ -167,7 +164,7 @@ public async Task ConnectGetsCanceledByDispose()
disposedException = true;
}
if (usesApm)
if (UsesApm)
{
Assert.Null(localSocketError);
Assert.True(disposedException);
Expand Down

0 comments on commit eab1d28

Please sign in to comment.