diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 304f10a516d1b..1346aa06cb471 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -119,67 +119,40 @@ private Task AcceptAsyncApm(Socket acceptSocket) internal Task ConnectAsync(EndPoint remoteEP) { - var tcs = new TaskCompletionSource(this); - BeginConnect(remoteEP, iar => + // Use ValueTaskReceive so the AwaitableSocketAsyncEventArgs can be re-used later. + AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref EventArgs.ValueTaskReceive, () => new AwaitableSocketAsyncEventArgs()); + + // We don't expect concurrent users while calling ConnectAsync. + if (!saea.Reserve()) { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new InvalidOperationException(SR.Format(SR.net_socketopinprogress)); + } + + saea.RemoteEndPoint = remoteEP; + return saea.ConnectAsync(this).AsTask(); } internal Task ConnectAsync(IPAddress address, int port) - { - var tcs = new TaskCompletionSource(this); - BeginConnect(address, port, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } + => ConnectAsync(new IPEndPoint(address, port)); - internal Task ConnectAsync(IPAddress[] addresses, int port) + internal async Task ConnectAsync(IPAddress[] addresses, int port) { - var tcs = new TaskCompletionSource(this); - BeginConnect(addresses, port, iar => + if (addresses == null) { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new ArgumentNullException(nameof(addresses)); + } + if (addresses.Length == 0) + { + throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses)); + } + foreach (var address in addresses) + { + await ConnectAsync(address, port).ConfigureAwait(false); + } } internal Task ConnectAsync(string host, int port) - { - var tcs = new TaskCompletionSource(this); - BeginConnect(host, port, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } + => ConnectAsync(new DnsEndPoint(host, port)); internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool fromNetworkStream) { @@ -187,9 +160,6 @@ internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFla return ReceiveAsync((Memory)buffer, socketFlags, fromNetworkStream, default).AsTask(); } - // TODO https://github.com/dotnet/corefx/issues/24430: - // Fully plumb cancellation down into socket operations. - internal ValueTask ReceiveAsync(Memory buffer, SocketFlags socketFlags, bool fromNetworkStream, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) @@ -949,6 +919,24 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc new ValueTask(Task.FromException(CreateException(error))); } + public ValueTask ConnectAsync(Socket socket) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + if (socket.ConnectAsync(this)) + { + return new ValueTask(this, _token); + } + + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + default : + new ValueTask(Task.FromException(CreateException(error))); + } + /// Gets the status of the operation. public ValueTaskSourceStatus GetStatus(short token) { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 4dd838598e3d9..fd8170f11dfda 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -91,8 +91,7 @@ public async Task Connect_AfterDisconnect_Fails() [PlatformSpecific(~(TestPlatforms.OSX | TestPlatforms.FreeBSD))] // Not supported on BSD like OSes. public async Task ConnectGetsCanceledByDispose() { - bool usesApm = UsesApm || - (this is ConnectTask); // .NET Core ConnectAsync Task API is implemented using Apm + bool usesApm = UsesApm; // We try this a couple of times to deal with a timing race: if the Dispose happens // before the operation is started, we won't see a SocketException.