diff --git a/src/libraries/System.DirectoryServices.Protocols/src/System.DirectoryServices.Protocols.csproj b/src/libraries/System.DirectoryServices.Protocols/src/System.DirectoryServices.Protocols.csproj index 03d310243f3b35..1fdfb6b7abc10d 100644 --- a/src/libraries/System.DirectoryServices.Protocols/src/System.DirectoryServices.Protocols.csproj +++ b/src/libraries/System.DirectoryServices.Protocols/src/System.DirectoryServices.Protocols.csproj @@ -100,6 +100,8 @@ + + diff --git a/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapConnection.cs b/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapConnection.cs index c66383e6d8a97f..9850e342d468eb 100644 --- a/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapConnection.cs +++ b/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapConnection.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Globalization; -using System.Net; using System.Collections; using System.ComponentModel; -using System.Text; using System.Diagnostics; +using System.Net; using System.Runtime.InteropServices; -using System.Xml; -using System.Threading; using System.Security.Cryptography.X509Certificates; -using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Xml; namespace System.DirectoryServices.Protocols { @@ -44,7 +43,6 @@ internal enum LdapResult private bool _needRebind = false; internal static Hashtable s_handleTable = null; internal static object s_objectLock = null; - private readonly GetLdapResponseCallback _fd = null; private static readonly Hashtable s_asyncResultTable = null; private static readonly LdapPartialResultsProcessor s_partialResultsProcessor = null; private static readonly ManualResetEvent s_waitHandle = null; @@ -84,7 +82,6 @@ public LdapConnection(LdapDirectoryIdentifier identifier, NetworkCredential cred public LdapConnection(LdapDirectoryIdentifier identifier, NetworkCredential credential, AuthType authType) { - _fd = new GetLdapResponseCallback(ConstructResponse); _directoryIdentifier = identifier; _directoryCredential = (credential != null) ? new NetworkCredential(credential.UserName, credential.Password, credential.Domain) : null; @@ -288,7 +285,9 @@ public DirectoryResponse SendRequest(DirectoryRequest request, TimeSpan requestT if (error == 0 && messageID != -1) { - return ConstructResponse(messageID, operation, ResultAll.LDAP_MSG_ALL, requestTimeout, true); + ValueTask vt = ConstructResponseAsync(messageID, operation, ResultAll.LDAP_MSG_ALL, requestTimeout, true, sync: true); + Debug.Assert(vt.IsCompleted); + return vt.GetAwaiter().GetResult(); } else { @@ -379,7 +378,30 @@ public IAsyncResult BeginSendRequest(DirectoryRequest request, TimeSpan requestT s_asyncResultTable.Add(asyncResult, messageID); - _fd.BeginInvoke(messageID, operation, ResultAll.LDAP_MSG_ALL, requestTimeout, true, new AsyncCallback(ResponseCallback), requestState); + _ = ResponseCallback(ConstructResponseAsync(messageID, operation, ResultAll.LDAP_MSG_ALL, requestTimeout, true, sync: false), requestState); + + static async Task ResponseCallback(ValueTask vt, LdapRequestState requestState) + { + try + { + DirectoryResponse response = await vt.ConfigureAwait(false); + requestState._response = response; + } + catch (Exception e) + { + requestState._exception = e; + requestState._response = null; + } + + // Signal waitable object, indicate operation completed and fire callback. + requestState._ldapAsync._manualResetEvent.Set(); + requestState._ldapAsync._completed = true; + + if (requestState._ldapAsync._callback != null && !requestState._abortCalled) + { + requestState._ldapAsync._callback(requestState._ldapAsync); + } + } return asyncResult; } @@ -404,31 +426,6 @@ public IAsyncResult BeginSendRequest(DirectoryRequest request, TimeSpan requestT throw ConstructException(error, operation); } - private void ResponseCallback(IAsyncResult asyncResult) - { - LdapRequestState requestState = (LdapRequestState)asyncResult.AsyncState; - - try - { - DirectoryResponse response = _fd.EndInvoke(asyncResult); - requestState._response = response; - } - catch (Exception e) - { - requestState._exception = e; - requestState._response = null; - } - - // Signal waitable object, indicate operation completed and fire callback. - requestState._ldapAsync._manualResetEvent.Set(); - requestState._ldapAsync._completed = true; - - if (requestState._ldapAsync._callback != null && !requestState._abortCalled) - { - requestState._ldapAsync._callback(requestState._ldapAsync); - } - } - public void Abort(IAsyncResult asyncResult) { if (_disposed) @@ -1404,7 +1401,7 @@ internal LdapMod[] BuildAttributes(CollectionBase directoryAttributes, ArrayList return attributes; } - internal DirectoryResponse ConstructResponse(int messageId, LdapOperation operation, ResultAll resultType, TimeSpan requestTimeOut, bool exceptionOnTimeOut) + internal async ValueTask ConstructResponseAsync(int messageId, LdapOperation operation, ResultAll resultType, TimeSpan requestTimeOut, bool exceptionOnTimeOut, bool sync) { var timeout = new LDAP_TIMEVAL() { @@ -1436,7 +1433,35 @@ internal DirectoryResponse ConstructResponse(int messageId, LdapOperation operat needAbandon = false; } - int error = LdapPal.GetResultFromAsyncOperation(_ldapHandle, messageId, (int)resultType, timeout, ref ldapResult); + int error; + if (sync) + { + error = LdapPal.GetResultFromAsyncOperation(_ldapHandle, messageId, (int)resultType, timeout, ref ldapResult); + } + else + { + timeout.tv_sec = 0; + timeout.tv_usec = 0; + int iterationDelay = 1; + // Underlying native libraries don't support callback-based function, so we will instead use polling and + // use a Stopwatch to track the timeout manually. + Stopwatch watch = Stopwatch.StartNew(); + while (true) + { + error = LdapPal.GetResultFromAsyncOperation(_ldapHandle, messageId, (int)resultType, timeout, ref ldapResult); + if (error != 0 || (requestTimeOut != Threading.Timeout.InfiniteTimeSpan && watch.Elapsed > requestTimeOut)) + { + break; + } + await Task.Delay(Math.Min(iterationDelay, 100)).ConfigureAwait(false); + if (iterationDelay < 100) + { + iterationDelay *= 2; + } + } + watch.Stop(); + } + if (error != -1 && error != 0) { // parsing the result diff --git a/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapPartialResultsProcessor.cs b/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapPartialResultsProcessor.cs index 3d6b85838ca013..b80b5995c38c5a 100644 --- a/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapPartialResultsProcessor.cs +++ b/src/libraries/System.DirectoryServices.Protocols/src/System/DirectoryServices/Protocols/ldap/LdapPartialResultsProcessor.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Collections; using System.Diagnostics; +using System.Threading.Tasks; namespace System.DirectoryServices.Protocols { @@ -135,7 +136,9 @@ private void GetResultsHelper(LdapPartialAsyncResult asyncResult) try { - SearchResponse response = (SearchResponse)connection.ConstructResponse(asyncResult._messageID, LdapOperation.LdapSearch, resultType, asyncResult._requestTimeout, false); + ValueTask vt = connection.ConstructResponseAsync(asyncResult._messageID, LdapOperation.LdapSearch, resultType, asyncResult._requestTimeout, false, sync: true); + Debug.Assert(vt.IsCompleted); + SearchResponse response = (SearchResponse)vt.GetAwaiter().GetResult(); // This should only happen in the polling thread case. if (response == null) diff --git a/src/libraries/System.DirectoryServices.Protocols/tests/DirectoryServicesProtocolsTests.cs b/src/libraries/System.DirectoryServices.Protocols/tests/DirectoryServicesProtocolsTests.cs index af065dfd6a393d..d5965e346f9eea 100644 --- a/src/libraries/System.DirectoryServices.Protocols/tests/DirectoryServicesProtocolsTests.cs +++ b/src/libraries/System.DirectoryServices.Protocols/tests/DirectoryServicesProtocolsTests.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.DirectoryServices.Tests; using System.Globalization; using System.Net; +using System.Threading; using Xunit; namespace System.DirectoryServices.Protocols.Tests @@ -534,7 +536,8 @@ private SearchResultEntry SearchOrganizationalUnit(LdapConnection connection, st { string filter = $"(&(objectClass=organizationalUnit)(ou={ouName}))"; SearchRequest searchRequest = new SearchRequest(rootDn, filter, SearchScope.OneLevel, null); - SearchResponse searchResponse = (SearchResponse) connection.SendRequest(searchRequest); + IAsyncResult asyncResult = connection.BeginSendRequest(searchRequest, PartialResultProcessing.NoPartialResultSupport, null, null); + SearchResponse searchResponse = (SearchResponse)connection.EndSendRequest(asyncResult); if (searchResponse.Entries.Count > 0) return searchResponse.Entries[0];