Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Socket: don't assign right endpoint until the connect is successful. #53581

Merged
merged 4 commits into from
Jun 5, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 17 additions & 32 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ public partial class Socket : IDisposable
// Our internal state doesn't automatically get updated after a non-blocking connect
// completes. Keep track of whether we're doing a non-blocking connect, and make sure
// to poll for the real state until we're done connecting.
private bool _nonBlockingConnectInProgress;
private bool _pollPendingConnect;
tmds marked this conversation as resolved.
Show resolved Hide resolved

// Keep track of the kind of endpoint used to do a non-blocking connect, so we can set
// it to _rightEndPoint when we discover we're connected.
private EndPoint? _nonBlockingConnectRightEndPoint;
private EndPoint? _pendingConnectRightEndPoint;
tmds marked this conversation as resolved.
Show resolved Hide resolved

// These are constants initialized by constructor.
private AddressFamily _addressFamily;
Expand Down Expand Up @@ -283,13 +283,9 @@ public int Available
{
ThrowIfDisposed();

if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
if (_pollPendingConnect && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
SetToConnected();
}

if (_rightEndPoint == null)
Expand Down Expand Up @@ -330,13 +326,10 @@ public int Available

if (_remoteEndPoint == null)
{
if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
if (_pollPendingConnect && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
SetToConnected();
}

if (_rightEndPoint == null || !_isConnected)
Expand Down Expand Up @@ -437,13 +430,10 @@ public bool Connected
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"_isConnected:{_isConnected}");

if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
if (_pollPendingConnect && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
SetToConnected();
}

return _isConnected;
Expand Down Expand Up @@ -856,12 +846,8 @@ public void Connect(EndPoint remoteEP)
ValidateForMultiConnect(isMultiEndpoint: false);

Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

if (!Blocking)
{
_nonBlockingConnectRightEndPoint = remoteEP;
_nonBlockingConnectInProgress = true;
}
_pendingConnectRightEndPoint = remoteEP;
_pollPendingConnect = !Blocking;

DoConnect(remoteEP, socketAddress);
}
Expand Down Expand Up @@ -2768,13 +2754,11 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
}

e._socketAddress = Serialize(ref endPointSnapshot);
_pendingConnectRightEndPoint = endPointSnapshot;
_pollPendingConnect = false;

WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);

// Save the old RightEndPoint and prep new RightEndPoint.
EndPoint? oldEndPoint = _rightEndPoint;
_rightEndPoint ??= endPointSnapshot;

if (SocketsTelemetry.Log.IsEnabled())
{
SocketsTelemetry.Log.ConnectStart(e._socketAddress!);
Expand All @@ -2801,7 +2785,6 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
}

_rightEndPoint = oldEndPoint;
_localEndPoint = null;

// Clear in-use flag on event args object.
Expand Down Expand Up @@ -3217,12 +3200,11 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket

if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);

// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");

// Update state and performance counters.
_pendingConnectRightEndPoint = endPointSnapshot;
_pollPendingConnect = false;
SetToConnected();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint);
}
Expand Down Expand Up @@ -3663,6 +3645,9 @@ internal void SetToConnected()
// some point in time update the perf counter as well.
_isConnected = true;
_isDisconnected = false;
_pollPendingConnect = false;
tmds marked this conversation as resolved.
Show resolved Hide resolved
_rightEndPoint ??= _pendingConnectRightEndPoint;
tmds marked this conversation as resolved.
Show resolved Hide resolved
_pendingConnectRightEndPoint = null;
UpdateLocalEndPointOnConnect();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected");
}
Expand Down
54 changes: 54 additions & 0 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,60 @@ public async Task Connect_MultipleIPAddresses_Success(IPAddress listenAt)
}
}

[Fact]
public async Task Connect_DualMode_MultiAddressFamilyConnect_RetrievedEndPoints_Success()
{
if (!SupportsMultiConnect)
return;

int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
{
Assert.True(client.DualMode);

Task connectTask = MultiConnectAsync(client, new IPAddress[] { IPAddress.IPv6Loopback, IPAddress.Loopback }, port);
await connectTask;

var localEndPoint = client.LocalEndPoint as IPEndPoint;
Assert.NotNull(localEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);

var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
Assert.NotNull(remoteEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
}
}

[Fact]
public async Task Connect_DualMode_DnsConnect_RetrievedEndPoints_Success()
{
var localhostAddresses = Dns.GetHostAddresses("localhost");
if (Array.IndexOf(localhostAddresses, IPAddress.Loopback) == -1 ||
Array.IndexOf(localhostAddresses, IPAddress.IPv6Loopback) == -1)
{
return;
}

int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
{
Assert.True(client.DualMode);

Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
await connectTask;

var localEndPoint = client.LocalEndPoint as IPEndPoint;
Assert.NotNull(localEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);

var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
Assert.NotNull(remoteEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
}
}

[Fact]
public async Task Connect_OnConnectedSocket_Fails()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ public async Task EnsureMethodsAreCallable()
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, buffer.AsMemory(), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>[] { new ArraySegment<byte>(buffer) }, SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
tmds marked this conversation as resolved.
Show resolved Hide resolved
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));

await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, buffer.AsMemory(), SocketFlags.None));
Expand Down