Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Task Socket.ConnectAsync methods to use AwaitableSocketAsyncEventArgs #787

Merged
merged 2 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,66 +119,70 @@ private Task<Socket> AcceptAsyncApm(Socket? acceptSocket)

internal Task ConnectAsync(EndPoint remoteEP)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(remoteEP, iar =>
// Use ValueTaskReceive so the AwaitableSocketAsyncEventArgs can be re-used later.
AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref EventArgs.ValueTaskReceive, () => new AwaitableSocketAsyncEventArgs());

if (!saea.Reserve())
tmds marked this conversation as resolved.
Show resolved Hide resolved
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
saea = new AwaitableSocketAsyncEventArgs();
saea.Reserve();
}

saea.RemoteEndPoint = remoteEP;
return saea.ConnectAsync(this).AsTask();
}

internal Task ConnectAsync(IPAddress address, int port)
=> ConnectAsync(new IPEndPoint(address, port));

internal Task ConnectAsync(IPAddress[] addresses, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(address, port, iar =>
if (addresses == null)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
throw new ArgumentNullException(nameof(addresses));
}
if (addresses.Length == 0)
{
throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses));
}
tmds marked this conversation as resolved.
Show resolved Hide resolved

return DoConnectAsync(addresses, port);
}

internal Task ConnectAsync(IPAddress[] addresses, int port)
private async Task DoConnectAsync(IPAddress[] addresses, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(addresses, port, iar =>
Exception? lastException = null;
foreach (IPAddress address in addresses)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
await ConnectAsync(address, port).ConfigureAwait(false);
return;
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
catch (Exception ex)
{
lastException = ex;
}
}
Debug.Assert(lastException != null);
ExceptionDispatchInfo.Throw(lastException);
}

internal Task ConnectAsync(string host, int port)
{
var tcs = new TaskCompletionSource<bool>(this);
BeginConnect(host, port, iar =>
if (host == null)
{
var innerTcs = (TaskCompletionSource<bool>)iar.AsyncState!;
try
{
((Socket)innerTcs.Task.AsyncState!).EndConnect(iar);
innerTcs.TrySetResult(true);
}
catch (Exception e) { innerTcs.TrySetException(e); }
}, tcs);
return tcs.Task;
throw new ArgumentNullException(nameof(host));
}

if (IPAddress.TryParse(host, out IPAddress? parsedAddress))
{
return ConnectAsync(new IPEndPoint(parsedAddress, port));
}
else
{
tmds marked this conversation as resolved.
Show resolved Hide resolved
return ConnectAsync(new DnsEndPoint(host, port));
}
}

internal Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags, bool fromNetworkStream)
Expand Down Expand Up @@ -946,6 +950,32 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc
new ValueTask(Task.FromException(CreateException(error)));
}

public ValueTask ConnectAsync(Socket socket)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

try
{
if (socket.ConnectAsync(this))
{
return new ValueTask(this, _token);
}
}
catch
{
Release();
throw;
}

SocketError error = SocketError;

Release();

return error == SocketError.Success ?
default :
new ValueTask(Task.FromException(CreateException(error)));
}

/// <summary>Gets the status of the operation.</summary>
public ValueTaskSourceStatus GetStatus(short token)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,14 @@ public void BeginConnect_IPAddresses_EmptyIPAddresses_Throws_Argument()
public void BeginConnect_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port)
{
Assert.Throws<ArgumentOutOfRangeException>(() => GetSocket().BeginConnect(new[] { IPAddress.Loopback }, port, TheAsyncCallback, null));
Assert.Throws<ArgumentOutOfRangeException>(() => { GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port); });
}

[Theory]
[InlineData(-1)]
[InlineData(65536)]
public async Task ConnectAsync_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port)
{
await Assert.ThrowsAsync<ArgumentOutOfRangeException>(() => GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires await to throw. The check could be added in the Task ConnectAsync directly.

}

[Fact]
Expand All @@ -1126,12 +1133,16 @@ public void BeginConnect_IPAddresses_ListeningSocket_Throws_InvalidOperation()
socket.Listen(1);
Assert.Throws<InvalidOperationException>(() => socket.BeginConnect(new[] { IPAddress.Loopback }, 1, TheAsyncCallback, null));
}
}

[Fact]
public async Task ConnectAsync_IPAddresses_ListeningSocket_Throws_InvalidOperation()
{
using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
socket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
socket.Listen(1);
Assert.Throws<InvalidOperationException>(() => { socket.ConnectAsync(new[] { IPAddress.Loopback }, 1); });
await Assert.ThrowsAsync<InvalidOperationException>(() => socket.ConnectAsync(new[] { IPAddress.Loopback }, 1));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires await to throw. The check could be added in the Task ConnectAsync directly.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ public async Task Connect_AfterDisconnect_Fails()
[PlatformSpecific(~(TestPlatforms.OSX | TestPlatforms.FreeBSD))] // Not supported on BSD like OSes.
public async Task ConnectGetsCanceledByDispose()
{
bool usesApm = UsesApm ||
(this is ConnectTask); // .NET Core ConnectAsync Task API is implemented using Apm

// We try this a couple of times to deal with a timing race: if the Dispose happens
// before the operation is started, we won't see a SocketException.
int msDelay = 100;
Expand Down Expand Up @@ -167,7 +164,7 @@ public async Task ConnectGetsCanceledByDispose()
disposedException = true;
}

if (usesApm)
if (UsesApm)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The exception thrown by the Task ConnectAsync methods has changed. The exceptions are now consistent with other Task-returning Socket methods.

{
Assert.Null(localSocketError);
Assert.True(disposedException);
Expand Down