diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 30b5c05c4979b..247b02a2431d9 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -52,9 +52,9 @@ public partial class Socket : IDisposable // to poll for the real state until we're done connecting. private bool _nonBlockingConnectInProgress; - // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set - // it to _rightEndPoint when we discover we're connected. - private EndPoint? _nonBlockingConnectRightEndPoint; + // Keep track of the kind of endpoint used to do a connect, so we can set + // it to _rightEndPoint when we're connected. + private EndPoint? _pendingConnectRightEndPoint; // These are constants initialized by constructor. private AddressFamily _addressFamily; @@ -285,11 +285,8 @@ public int Available if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite)) { - // Update the state if we've become connected after a non-blocking connect. - _isConnected = true; - _rightEndPoint ??= _nonBlockingConnectRightEndPoint; - UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; + SetToConnected(); } if (_rightEndPoint == null) @@ -332,11 +329,9 @@ public int Available { if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite)) { - // Update the state if we've become connected after a non-blocking connect. - _isConnected = true; - _rightEndPoint ??= _nonBlockingConnectRightEndPoint; - UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; + // Update the state if we've become connected after a non-blocking connect. + SetToConnected(); } if (_rightEndPoint == null || !_isConnected) @@ -439,11 +434,9 @@ public bool Connected if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite)) { - // Update the state if we've become connected after a non-blocking connect. - _isConnected = true; - _rightEndPoint ??= _nonBlockingConnectRightEndPoint; - UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; + // Update the state if we've become connected after a non-blocking connect. + SetToConnected(); } return _isConnected; @@ -856,12 +849,8 @@ public void Connect(EndPoint remoteEP) ValidateForMultiConnect(isMultiEndpoint: false); Internals.SocketAddress socketAddress = Serialize(ref remoteEP); - - if (!Blocking) - { - _nonBlockingConnectRightEndPoint = remoteEP; - _nonBlockingConnectInProgress = true; - } + _pendingConnectRightEndPoint = remoteEP; + _nonBlockingConnectInProgress = !Blocking; DoConnect(remoteEP, socketAddress); } @@ -2768,13 +2757,11 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan } e._socketAddress = Serialize(ref endPointSnapshot); + _pendingConnectRightEndPoint = endPointSnapshot; + _nonBlockingConnectInProgress = false; WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily); - // Save the old RightEndPoint and prep new RightEndPoint. - EndPoint? oldEndPoint = _rightEndPoint; - _rightEndPoint ??= endPointSnapshot; - if (SocketsTelemetry.Log.IsEnabled()) { SocketsTelemetry.Log.ConnectStart(e._socketAddress!); @@ -2801,7 +2788,6 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message); } - _rightEndPoint = oldEndPoint; _localEndPoint = null; // Clear in-use flag on event args object. @@ -3217,12 +3203,11 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success); - // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint ??= endPointSnapshot; - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}"); // Update state and performance counters. + _pendingConnectRightEndPoint = endPointSnapshot; + _nonBlockingConnectInProgress = false; SetToConnected(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint); } @@ -3659,10 +3644,14 @@ internal void SetToConnected() return; } + Debug.Assert(_nonBlockingConnectInProgress == false); + // Update the status: this socket was indeed connected at // some point in time update the perf counter as well. _isConnected = true; _isDisconnected = false; + _rightEndPoint ??= _pendingConnectRightEndPoint; + _pendingConnectRightEndPoint = null; UpdateLocalEndPointOnConnect(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected"); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index d31cf2d43c478..5f9730e8ab076 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -83,6 +83,60 @@ public async Task Connect_MultipleIPAddresses_Success(IPAddress listenAt) } } + [Fact] + public async Task Connect_DualMode_MultiAddressFamilyConnect_RetrievedEndPoints_Success() + { + if (!SupportsMultiConnect) + return; + + int port; + using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port)) + using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp)) + { + Assert.True(client.DualMode); + + Task connectTask = MultiConnectAsync(client, new IPAddress[] { IPAddress.IPv6Loopback, IPAddress.Loopback }, port); + await connectTask; + + var localEndPoint = client.LocalEndPoint as IPEndPoint; + Assert.NotNull(localEndPoint); + Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address); + + var remoteEndPoint = client.RemoteEndPoint as IPEndPoint; + Assert.NotNull(remoteEndPoint); + Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address); + } + } + + [Fact] + public async Task Connect_DualMode_DnsConnect_RetrievedEndPoints_Success() + { + var localhostAddresses = Dns.GetHostAddresses("localhost"); + if (Array.IndexOf(localhostAddresses, IPAddress.Loopback) == -1 || + Array.IndexOf(localhostAddresses, IPAddress.IPv6Loopback) == -1) + { + return; + } + + int port; + using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port)) + using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp)) + { + Assert.True(client.DualMode); + + Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port)); + await connectTask; + + var localEndPoint = client.LocalEndPoint as IPEndPoint; + Assert.NotNull(localEndPoint); + Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address); + + var remoteEndPoint = client.RemoteEndPoint as IPEndPoint; + Assert.NotNull(remoteEndPoint); + Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address); + } + } + [Fact] public async Task Connect_OnConnectedSocket_Fails() { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTaskExtensionsTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTaskExtensionsTest.cs index bdc7ec7e3ef0f..c53ab9583230f 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTaskExtensionsTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTaskExtensionsTest.cs @@ -23,6 +23,9 @@ public async Task EnsureMethodsAreCallable() await Assert.ThrowsAsync(async () => await SocketTaskExtensions.AcceptAsync(s)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.AcceptAsync(s, null)); + await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment(buffer), SocketFlags.None, badEndPoint)); + await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment(buffer), SocketFlags.None, badEndPoint)); + await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint, CancellationToken.None)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint.Address, badEndPoint.Port)); @@ -35,8 +38,6 @@ public async Task EnsureMethodsAreCallable() await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment(buffer), SocketFlags.None)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveAsync(s, buffer.AsMemory(), SocketFlags.None)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment[] { new ArraySegment(buffer) }, SocketFlags.None)); - await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment(buffer), SocketFlags.None, badEndPoint)); - await Assert.ThrowsAsync(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment(buffer), SocketFlags.None, badEndPoint)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.SendAsync(s, new ArraySegment(buffer), SocketFlags.None)); await Assert.ThrowsAsync(async () => await SocketTaskExtensions.SendAsync(s, buffer.AsMemory(), SocketFlags.None));