Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Socket.Windows: support ConnectAsync(SocketAsyncEventArgs) for UDP, and Unix sockets #33674

Merged
merged 5 commits into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2088,17 +2088,12 @@ public IAsyncResult BeginConnect(EndPoint remoteEP, AsyncCallback? callback, obj
return UnsafeBeginConnect(remoteEP, callback, state, flowContext: true);
}

private bool CanUseConnectEx(EndPoint remoteEP)
{
return (_socketType == SocketType.Stream) &&
(_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint));
}


private static bool CanUseConnectEx(SocketType socketType, EndPoint remoteEP)
=> (socketType == SocketType.Stream) && (remoteEP.GetType() == typeof(IPEndPoint));

internal IAsyncResult UnsafeBeginConnect(EndPoint remoteEP, AsyncCallback? callback, object? state, bool flowContext = false)
{
if (CanUseConnectEx(remoteEP))
if (CanUseConnectEx(_socketType, remoteEP))
{
return BeginConnectEx(remoteEP, flowContext, callback, state);
}
Expand Down Expand Up @@ -3817,7 +3812,15 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket)
SocketError socketError = SocketError.Success;
try
{
socketError = e.DoOperationConnect(this, _handle);
if (CanUseConnectEx(_socketType, endPointSnapshot))
{
socketError = e.DoOperationConnectEx(this, _handle);
}
else
{
// For connectionless protocols, Connect is not an I/O call.
socketError = e.DoOperationConnect(this, _handle);
}
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,6 @@ private void SetupMultipleBuffers()

private void CompleteCore() { }

private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags)
{
Debug.Assert(socketError != SocketError.IOPending);

if (socketError == SocketError.Success)
{
FinishOperationSyncSuccess(bytesTransferred, flags);
}
else
{
FinishOperationSyncFailure(socketError, bytesTransferred, flags);
}
}

private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError)
{
CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize, socketError);
Expand Down Expand Up @@ -95,6 +81,9 @@ private void ConnectCompletionCallback(SocketError socketError)
CompletionCallback(0, SocketFlags.None, socketError);
}

internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle)
=> DoOperationConnect(socket, handle);

internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, _socketAddress.Size, ConnectCompletionCallback);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha
}

internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle)
{
// Called for connectionless protocols.
SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer, _socketAddress.Size);
FinishOperationSync(socketError, 0, SocketFlags.None);
return socketError;
}

internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle)
{
// ConnectEx uses a sockaddr buffer containing the remote address to which to connect.
// It can also optionally take a single buffer of data to send after the connection is complete.
Expand Down Expand Up @@ -1160,6 +1168,13 @@ private unsafe SocketError FinishOperationConnect()
{
try
{
if (_currentSocket!.SocketType != SocketType.Stream)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antonfirsov @scalablecory also give this condition some thought. Personally, I'm fine if CI passes. You may have a better understanding of what SO_UPDATE_CONNECT_CONTEXT does to decide if this is the most appropriate check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should match the logic for CanUseConnectEx. We need to call SO_UPDATE_CONNECT_CONTEXT if ConnectEx is used.

Copy link
Member Author

@tmds tmds Mar 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same thing, though it seems the stream unix socket, which doesn't support ConnectEx, doesn't give an error when we set this socket option.
We added the check because UDP socket gave an error.

{
// With connectionless sockets, regular connect is used instead of ConnectEx,
// attempting to set SO_UPDATE_CONNECT_CONTEXT will result in an error.
return SocketError.Success;
}

// Update the socket context.
SocketError socketError = Interop.Winsock.setsockopt(
_currentSocket!.SafeHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,5 +780,19 @@ internal void FinishOperationAsyncSuccess(int bytesTransferred, SocketFlags flag
ExecutionContext.Run(context, s_executionCallback, this);
}
}

private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags)
{
Debug.Assert(socketError != SocketError.IOPending);

if (socketError == SocketError.Success)
{
FinishOperationSyncSuccess(bytesTransferred, flags);
}
else
{
FinishOperationSyncFailure(socketError, bytesTransferred, flags);
}
}
}
}
28 changes: 28 additions & 0 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,34 @@ public async Task Connect_Success(IPAddress listenAt)
}
}

[Theory]
[MemberData(nameof(Loopbacks))]
public async Task Connect_Udp_Success(IPAddress listenAt)
{
using Socket listener = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket client = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
listener.Bind(new IPEndPoint(listenAt, 0));

await ConnectAsync(client, new IPEndPoint(listenAt, ((IPEndPoint)listener.LocalEndPoint).Port));
Assert.True(client.Connected);
}

[Theory]
[MemberData(nameof(Loopbacks))]
public async Task Connect_Dns_Success(IPAddress listenAt)
{
int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, listenAt, out port))
{
using (Socket client = new Socket(listenAt.AddressFamily, SocketType.Stream, ProtocolType.Tcp))
{
Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
await connectTask;
Assert.True(client.Connected);
}
}
}

[OuterLoop]
[Theory]
[MemberData(nameof(Loopbacks))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ public void OSSupportsUnixDomainSockets_ReturnsCorrectValue()
Assert.Equal(PlatformSupportsUnixDomainSockets, Socket.OSSupportsUnixDomainSockets);
}

[PlatformSpecific(~TestPlatforms.Windows)] // Windows doesn't currently support ConnectEx with domain sockets
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_Success()
{
Expand Down Expand Up @@ -100,7 +99,7 @@ public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_NotServer()
}

Assert.Equal(
RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.InvalidArgument : SocketError.AddressNotAvailable,
RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable,
args.SocketError);
}
}
Expand Down