Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Task-based DisconnectAsync and refactor APM methods on top of it #51213

Merged
merged 3 commits into from
Apr 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ public partial class Socket : System.IDisposable
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public void Disconnect(bool reuseSocket) { }
public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
<Compile Include="System\Net\Sockets\UdpReceiveResult.cs" />
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\UnixDomainSocketEndPoint.cs" />
<!-- Logging -->
<Compile Include="$(CommonPath)System\Net\Logging\NetEventSource.Common.cs"
Expand Down Expand Up @@ -187,7 +186,6 @@
<ItemGroup Condition="'$(TargetsUnix)' == 'true'">
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\SafeSocketHandle.Unix.cs" />
<Compile Include="System\Net\Sockets\Socket.Unix.cs" />
<Compile Include="System\Net\Sockets\SocketAsyncContext.Unix.cs" />
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat
return ConnectAsync(ep, cancellationToken);
}

/// <summary>
/// Disconnects a connected socket from the remote host.
/// </summary>
/// <param name="reuseSocket">Indicates whether the socket should be available for reuse after disconnect.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes when the socket is disconnected.</returns>
public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

saea.DisconnectReuseSocket = reuseSocket;
saea.WrapExceptionsForNetworkStream = false;

return saea.DisconnectAsync(this, cancellationToken);
}

/// <summary>
/// Receives data from a connected socket.
/// </summary>
Expand Down Expand Up @@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket)
ValueTask.FromException(CreateException(error));
}

public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.DisconnectAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask(this, _token);
}

SocketError error = SocketError;

Release();

return error == SocketError.Success ?
ValueTask.CompletedTask :
ValueTask.FromException(CreateException(error));
}

/// <summary>Gets the status of the operation.</summary>
public ValueTaskSourceStatus GetStatus(short token)
{
Expand Down
94 changes: 16 additions & 78 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2116,43 +2116,14 @@ public static void Select(IList? checkRead, IList? checkWrite, IList? checkError
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);

public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state)
public void EndConnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

// Start context-flowing op. No need to lock - we don't use the context till the callback.
DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback);
asyncResult.StartPostingAsyncOp(false);

// Post the disconnect.
DoBeginDisconnect(reuseSocket, asyncResult);

// Finish flowing (or call the callback), and return.
asyncResult.FinishPostingAsyncOp();
return asyncResult;
TaskToApm.End(asyncResult);
}

private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError errorCode = SocketError.Success;

errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult);

if (errorCode == SocketError.Success)
{
SetToDisconnected();
_remoteEndPoint = null;
_localEndPoint = null;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}");

// If the call failed, update our status and throw
if (!CheckErrorAndUpdateStatus(errorCode))
{
throw new SocketException((int)errorCode);
}
}
public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
TaskToApm.Begin(DisconnectAsync(reuseSocket).AsTask(), callback, state);
geoffkizer marked this conversation as resolved.
Show resolved Hide resolved

public void Disconnect(bool reuseSocket)
{
Expand All @@ -2175,47 +2146,12 @@ public void Disconnect(bool reuseSocket)
_localEndPoint = null;
}

public void EndConnect(IAsyncResult asyncResult)
public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}

public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

if (asyncResult == null)
{
throw new ArgumentNullException(nameof(asyncResult));
}

//get async result and check for errors
LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult;
if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
{
throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
}
if (castedAsyncResult.EndCalled)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect)));
}

//wait for completion if it hasn't occurred
castedAsyncResult.InternalWaitForCompletion();
castedAsyncResult.EndCalled = true;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this);

//
// if the asynchronous native call failed asynchronously
// we'll throw a SocketException
//
if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode);
}
}

public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
Expand Down Expand Up @@ -2668,7 +2604,7 @@ public void Shutdown(SocketShutdown how)
InternalSetBlocking(_willBlockInternal);
}

#region Async methods
#region Async methods
public bool AcceptAsync(SocketAsyncEventArgs e)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -2889,7 +2825,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e)
e.CancelConnectAsync();
}

public bool DisconnectAsync(SocketAsyncEventArgs e)
public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default);

private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
// Throw if socket disposed
ThrowIfDisposed();
Expand All @@ -2904,7 +2842,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
SocketError socketError = SocketError.Success;
try
{
socketError = e.DoOperationDisconnect(this, _handle);
socketError = e.DoOperationDisconnect(this, _handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -3155,10 +3093,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT

return socketError == SocketError.IOPending;
}
#endregion
#endregion
#endregion
#endregion

#region Internal and private properties
#region Internal and private properties

private CacheSet Caches
{
Expand All @@ -3174,9 +3112,9 @@ private CacheSet Caches
}

internal bool Disposed => _disposed != 0;
#endregion
#endregion

#region Internal and private methods
#region Internal and private methods

internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6)
{
Expand Down Expand Up @@ -3889,6 +3827,6 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t)
};
}

#endregion
#endregion
geoffkizer marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h
return socketError;
}

internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
FinishOperationSync(socketError, 0, SocketFlags.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,10 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
}
}

internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
// Note: CancellationToken is ignored for now.
geoffkizer marked this conversation as resolved.
Show resolved Hide resolved

NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
Expand Down Expand Up @@ -1188,6 +1190,7 @@ private unsafe SocketError FinishOperationConnect()
private void CompleteCore()
{
_strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially

if (_singleBufferHandleState != SingleBufferHandleState.None)
{
// If the state isn't None, then either it's Set, in which case there's state to cleanup,
Expand All @@ -1213,6 +1216,8 @@ void CompleteCoreSpin()
sw.SpinOnce();
}

Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

// Remove any cancellation registration. First dispose the registration
// to ensure that cancellation will either never fine or will have completed
// firing before we continue. Only then can we safely null out the overlapped.
Expand All @@ -1223,6 +1228,8 @@ void CompleteCoreSpin()
}

// Release any GC handles.
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

if (_singleBufferHandleState == SingleBufferHandleState.Set)
{
_singleBufferHandleState = SingleBufferHandleState.None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa
return socketError;
}

internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError socketError = Disconnect(socket, handle, reuseSocket);
asyncResult.PostCompletion(socketError);
return socketError;
}

internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
handle.SetToDisconnected();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1137,27 +1137,6 @@ public static void CheckDualModeReceiveSupport(Socket socket)
// Dual-mode sockets support received packet info on Windows.
}

internal static unsafe SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
asyncResult.SetUnmanagedStructures(null);
try
{
// This can throw ObjectDisposedException
bool success = socket.DisconnectEx(
handle,
asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures
(int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0),
0);

return asyncResult.ProcessOverlappedResult(success, 0);
}
catch
{
asyncResult.ReleaseUnmanagedStructures();
throw;
}
}

internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
SocketError errorCode = SocketError.Success;
Expand Down
Loading