diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 5304925959d84..5dc340ca92811 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -2090,11 +2090,17 @@ public IAsyncResult BeginConnect(EndPoint remoteEP, AsyncCallback? callback, obj private bool CanUseConnectEx(EndPoint remoteEP) { - return (_socketType == SocketType.Stream) && - (_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)); - } + Debug.Assert(remoteEP.GetType() != typeof(DnsEndPoint)); + // ConnectEx supports connection-oriented sockets. + // The socket must be bound before calling ConnectEx. + // In case of IPEndPoint, the Socket will be bound using WildcardBindForConnectIfNecessary. + // Unix sockets are not supported by ConnectEx. + return (_socketType == SocketType.Stream) && + (_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) && + (remoteEP.AddressFamily != AddressFamily.Unix); + } internal IAsyncResult UnsafeBeginConnect(EndPoint remoteEP, AsyncCallback? callback, object? state, bool flowContext = false) { @@ -3817,7 +3823,15 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) SocketError socketError = SocketError.Success; try { - socketError = e.DoOperationConnect(this, _handle); + if (CanUseConnectEx(endPointSnapshot)) + { + socketError = e.DoOperationConnectEx(this, _handle); + } + else + { + // For connectionless protocols, Connect is not an I/O call. + socketError = e.DoOperationConnect(this, _handle); + } } catch { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index b46cda4ea3065..d2ca491a52ae9 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -38,20 +38,6 @@ private void SetupMultipleBuffers() private void CompleteCore() { } - private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags) - { - Debug.Assert(socketError != SocketError.IOPending); - - if (socketError == SocketError.Success) - { - FinishOperationSyncSuccess(bytesTransferred, flags); - } - else - { - FinishOperationSyncFailure(socketError, bytesTransferred, flags); - } - } - private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError) { CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize, socketError); @@ -95,6 +81,9 @@ private void ConnectCompletionCallback(SocketError socketError) CompletionCallback(0, SocketFlags.None, socketError); } + internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle) + => DoOperationConnect(socket, handle); + internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle) { SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, _socketAddress.Size, ConnectCompletionCallback); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index e1f2fbb487133..ef5be1880d2d1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -272,6 +272,14 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha } internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle) + { + // Called for connectionless protocols. + SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer, _socketAddress.Size); + FinishOperationSync(socketError, 0, SocketFlags.None); + return socketError; + } + + internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle) { // ConnectEx uses a sockaddr buffer containing the remote address to which to connect. // It can also optionally take a single buffer of data to send after the connection is complete. @@ -1160,6 +1168,13 @@ private unsafe SocketError FinishOperationConnect() { try { + if (_currentSocket!.SocketType != SocketType.Stream) + { + // With connectionless sockets, regular connect is used instead of ConnectEx, + // attempting to set SO_UPDATE_CONNECT_CONTEXT will result in an error. + return SocketError.Success; + } + // Update the socket context. SocketError socketError = Interop.Winsock.setsockopt( _currentSocket!.SafeHandle, diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index 9a38815d8b15b..dd3f8953e988b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -780,5 +780,19 @@ internal void FinishOperationAsyncSuccess(int bytesTransferred, SocketFlags flag ExecutionContext.Run(context, s_executionCallback, this); } } + + private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags) + { + Debug.Assert(socketError != SocketError.IOPending); + + if (socketError == SocketError.Success) + { + FinishOperationSyncSuccess(bytesTransferred, flags); + } + else + { + FinishOperationSyncFailure(socketError, bytesTransferred, flags); + } + } } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 5fcd6f75f70b8..6f01d22ca57bb 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -30,6 +30,40 @@ public async Task Connect_Success(IPAddress listenAt) } } + [Theory] + [MemberData(nameof(Loopbacks))] + public async Task Connect_Udp_Success(IPAddress listenAt) + { + using Socket listener = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket client = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + listener.Bind(new IPEndPoint(listenAt, 0)); + + await ConnectAsync(client, new IPEndPoint(listenAt, ((IPEndPoint)listener.LocalEndPoint).Port)); + Assert.True(client.Connected); + } + + [Theory] + [MemberData(nameof(Loopbacks))] + public async Task Connect_Dns_Success(IPAddress listenAt) + { + // On some systems (like Ubuntu 16.04 and Ubuntu 18.04) "localhost" doesn't resolve to '::1'. + if (Array.IndexOf(Dns.GetHostAddresses("localhost"), listenAt) == -1) + { + return; + } + + int port; + using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, listenAt, out port)) + { + using (Socket client = new Socket(listenAt.AddressFamily, SocketType.Stream, ProtocolType.Tcp)) + { + Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port)); + await connectTask; + Assert.True(client.Connected); + } + } + } + [OuterLoop] [Theory] [MemberData(nameof(Loopbacks))] diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs index 3fe045582c610..ddead9a754b4a 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -21,7 +21,6 @@ public void OSSupportsUnixDomainSockets_ReturnsCorrectValue() Assert.Equal(PlatformSupportsUnixDomainSockets, Socket.OSSupportsUnixDomainSockets); } - [PlatformSpecific(~TestPlatforms.Windows)] // Windows doesn't currently support ConnectEx with domain sockets [ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))] public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_Success() { @@ -100,7 +99,7 @@ public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_NotServer() } Assert.Equal( - RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.InvalidArgument : SocketError.AddressNotAvailable, + RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable, args.SocketError); } }