From 3bce74bd59609297ed3a4b4ea4b511c0909bb5f6 Mon Sep 17 00:00:00 2001 From: Davoud Date: Thu, 20 Jul 2023 15:54:13 -0700 Subject: [PATCH 1/5] improve tests --- .../SqlConnectionBasicTests.cs | 51 +++++++++---------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 66dc223c4b..71f2de778c 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -5,6 +5,7 @@ using System; using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Globalization; using System.Reflection; using System.Security; @@ -269,28 +270,26 @@ public void ConnectionTimeoutTest(int timeout) server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout - var start = DateTime.Now; - var end = start; + Stopwatch timer = new(); + Exception ex = null; // Open a connection with the server disposed. try { + timer.Start(); connection.Open(); } - catch (Exception) + catch (Exception e) { - end = DateTime.Now; + timer.Stop(); + ex = e; } - // Calculate actual duration of timeout - TimeSpan s = end - start; - // Did not time out? - if (s.TotalSeconds == 0) - Assert.True(s.TotalSeconds == 0); - - // Is actual time out the same as configured timeout or within an additional 3 second threshold because of overhead? - if (s.TotalSeconds > 0) - Assert.True(s.TotalSeconds <= timeout + 3); + Assert.False(timer.IsRunning, "Timer must stopped."); + Assert.NotNull(ex); + Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, + $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." + + $"{Environment.NewLine}{ex}"); } [Theory] @@ -310,28 +309,26 @@ public async void ConnectionTimeoutTestAsync(int timeout) server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout - var start = DateTime.Now; - var end = start; + Stopwatch timer = new(); + Exception ex = null; // Open a connection with the server disposed. try { - await connection.OpenAsync(); + timer.Start(); + await connection.OpenAsync(); } - catch (Exception) + catch (Exception e) { - end = DateTime.Now; + timer.Stop(); + ex = e; } - // Calculate actual duration of timeout - TimeSpan s = end - start; - // Did not time out? - if (s.TotalSeconds == 0) - Assert.True(s.TotalSeconds == 0); - - // Is actual time out the same as configured timeout or within an additional 3 second threshold because of overhead? - if (s.TotalSeconds > 0) - Assert.True(s.TotalSeconds <= timeout + 3); + Assert.False(timer.IsRunning, "Timer must stopped."); + Assert.NotNull(ex); + Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, + $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." + + $"{Environment.NewLine}{ex}"); } [Fact] From dcccdd06997e4ce3f75fc6e57480e8d2845b76f1 Mon Sep 17 00:00:00 2001 From: Davoud Date: Fri, 28 Jul 2023 15:51:08 -0700 Subject: [PATCH 2/5] update test --- .../tests/FunctionalTests/SqlConnectionBasicTests.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 71f2de778c..03a02fd3fe 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -315,8 +315,10 @@ public async void ConnectionTimeoutTestAsync(int timeout) // Open a connection with the server disposed. try { + //an asyn call with a timeout token to cancel the operation after the specific time + using CancellationTokenSource cts = new CancellationTokenSource(timeout * 1000); timer.Start(); - await connection.OpenAsync(); + await connection.OpenAsync(cts.Token).ConfigureAwait(false); } catch (Exception e) { From 1670faa4bfa41f7b830c016b512869e635d5e682 Mon Sep 17 00:00:00 2001 From: Davoud Date: Fri, 28 Jul 2023 15:56:38 -0700 Subject: [PATCH 3/5] update connect timeout --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 22 ++++- .../Data/SqlClient/SNI/SNITcpHandle.cs | 86 ++++++++++++++----- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 8 +- 3 files changed, 91 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index b92746098f..3349913335 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -3,9 +3,12 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics; using System.Net; using System.Net.Security; using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.Data.SqlClient.SNI { @@ -194,7 +197,7 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C return true; } } - + /// /// We validate the provided certificate provided by the client with the one from the server to see if it matches. /// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method] @@ -239,6 +242,23 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5 } } + internal static IPAddress[] GetDnsIpAddresses(string serverName, ref TimeSpan timeout) + { + using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0} with {1} timeout.", args0: serverName, args1: timeout); + using CancellationTokenSource cts = new CancellationTokenSource(timeout); + Stopwatch stopwatch = Stopwatch.StartNew(); + // using this overload to support netstandard + Task task = Dns.GetHostAddressesAsync(serverName); + task.ConfigureAwait(false); + task.Wait(cts.Token); + timeout -= stopwatch.Elapsed; + stopwatch.Stop(); + return task.Result; + } + } + internal static IPAddress[] GetDnsIpAddresses(string serverName) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 14243e98d3..de7c589cfa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -164,6 +164,8 @@ public SNITCPHandle( ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; } + Stopwatch stopwatch = Stopwatch.StartNew(); + bool reportError = true; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port); @@ -183,6 +185,11 @@ public SNITCPHandle( } catch (Exception ex) { + TimeSpan timeLeft = ts - stopwatch.Elapsed; + if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero) + { + throw; + } // Retry with cached IP address if (ex is SocketException || ex is ArgumentException || ex is AggregateException) { @@ -214,26 +221,31 @@ public SNITCPHandle( { if (parallel) { - _socket = TryConnectParallel(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); } } catch (Exception exRetry) { + timeLeft = ts - stopwatch.Elapsed; + if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero) + { + throw; + } if (exRetry is SocketException || exRetry is ArgumentNullException || exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying exception {1}", args0: _connectionId, args1: exRetry?.Message); if (parallel) { - _socket = TryConnectParallel(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); } } else @@ -249,6 +261,10 @@ public SNITCPHandle( throw; } } + finally + { + stopwatch.Stop(); + } if (_socket == null || !_socket.Connected) { @@ -304,8 +320,11 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i { Socket availableSocket = null; Task connectTask; + TimeSpan timeout = ts; - IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName); + IPAddress[] serverAddresses = isInfiniteTimeOut + ? SNICommon.GetDnsIpAddresses(hostName) + : SNICommon.GetDnsIpAddresses(hostName, ref timeout); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -338,7 +357,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i connectTask = ParallelConnectAsync(serverAddresses, port); - if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) + if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(timeout))) { callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40); @@ -349,7 +368,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i availableSocket = connectTask.Result; return availableSocket; } - + /// /// Returns array of IP addresses for the given server name, sorted according to the given preference. /// @@ -389,7 +408,7 @@ private static IEnumerable GetHostAddressesSortedByPreference(string } } } - + // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. @@ -422,26 +441,44 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo port, ipAddress.AddressFamily, isInfiniteTimeout); - + bool isConnected; try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select { - socket.Connect(ipAddress, port); - if (!isInfiniteTimeout) + if (isInfiniteTimeout) + { + socket.Connect(ipAddress, port); + } + else { + TimeSpan timeLeft = timeout - timeTaken.Elapsed; + if (timeLeft <= TimeSpan.Zero) + { + return null; + } + // Socket.Connect does not support infinite timeouts, so we use Task to simulate it + Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); + socketConnectTask.ConfigureAwait(false); + socketConnectTask.Start(); + if (!socketConnectTask.Wait(timeLeft)) + { + throw ADP.TimeoutException($"The socket couldn't connect during the expected {timeLeft} remaining time to connect."); + } throw SQL.SocketDidNotThrow(); } - + isConnected = true; } - catch (SocketException socketException) when (!isInfiniteTimeout && - socketException.SocketErrorCode == - SocketError.WouldBlock) + catch (AggregateException aggregateException) when (!isInfiniteTimeout + && aggregateException.InnerException is SocketException socketException + && socketException.SocketErrorCode == SocketError.WouldBlock) { // https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118 // Socket.Select is used because it supports timeouts, while Socket.Connect does not - List checkReadLst; List checkWriteLst; List checkErrorLst; + List checkReadLst; + List checkWriteLst; + List checkErrorLst; // Repeating Socket.Select several times if our timeout is greater // than int.MaxValue microseconds because of @@ -450,9 +487,10 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo do { TimeSpan timeLeft = timeout - timeTaken.Elapsed; - - if (timeLeft <= TimeSpan.Zero) + if (!isInfiniteTimeout && timeLeft <= TimeSpan.Zero) + { return null; + } int socketSelectTimeout = checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000)); @@ -487,11 +525,15 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo return socket; } } - catch (SocketException e) + catch (AggregateException aggregateException) when (aggregateException.InnerException is SocketException socketException) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: socketException?.Message); SqlClientEventSource.Log.TryAdvancedTraceEvent( - $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}"); + $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}"); + } + catch (AggregateException aggregateException) when (aggregateException.InnerException is TimeoutException timeoutException) + { + Console.WriteLine(timeoutException); // temporary for testing } finally { @@ -675,7 +717,7 @@ private bool ValidateServerCertificate(object sender, X509Certificate serverCert SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Certificate will not be validated.", args0: _connectionId); return true; } - + string serverNameToValidate; if (!string.IsNullOrEmpty(_hostNameInCertificate)) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 0348c227e0..1a3dd06638 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -189,13 +189,17 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re TimeSpan ts = default; // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count // The infinite Timeout is a function of ConnectionString Timeout=0 - if (long.MaxValue != timerExpire) + bool isInfiniteTimeout = long.MaxValue == timerExpire; + if (!isInfiniteTimeout) { ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; } - IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname); + IPAddress[] ipAddresses = isInfiniteTimeout + ? SNICommon.GetDnsIpAddresses(browserHostname) + : SNICommon.GetDnsIpAddresses(browserHostname, ref ts); + Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); IPAddress[] ipv4Addresses = null; IPAddress[] ipv6Addresses = null; From fe44d0e2bf4fdb1e77b0515f4e61bc0e3b72dbea Mon Sep 17 00:00:00 2001 From: Davoud Date: Wed, 2 Aug 2023 14:58:51 -0700 Subject: [PATCH 4/5] Use TimeoutTimer --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 14 ++-- .../Data/SqlClient/SNI/SNINpHandle.cs | 23 ++++-- .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 25 ++++--- .../Data/SqlClient/SNI/SNITcpHandle.cs | 74 +++++++------------ .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 31 +++----- .../SqlClient/SqlInternalConnectionTds.cs | 2 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 9 ++- .../Data/SqlClient/TdsParserStateObject.cs | 3 +- .../SqlClient/TdsParserStateObjectManaged.cs | 5 +- .../SqlClient/TdsParserStateObjectNative.cs | 25 +------ .../SqlClient/SqlInternalConnectionTds.cs | 2 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 9 ++- .../Data/SqlClient/TdsParserStateObject.cs | 24 +----- .../Data/ProviderBase/TimeoutTimer.cs | 40 +++++++++- 14 files changed, 136 insertions(+), 150 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index 3349913335..ad61422ee2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -9,6 +9,7 @@ using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -242,19 +243,20 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5 } } - internal static IPAddress[] GetDnsIpAddresses(string serverName, ref TimeSpan timeout) + internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0} with {1} timeout.", args0: serverName, args1: timeout); - using CancellationTokenSource cts = new CancellationTokenSource(timeout); - Stopwatch stopwatch = Stopwatch.StartNew(); + int remainingTimeout = timeout.MillisecondsRemainingInt; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, + "Getting DNS host entries for serverName {0} within {1} milliseconds.", + args0: serverName, + args1: remainingTimeout); + using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); // using this overload to support netstandard Task task = Dns.GetHostAddressesAsync(serverName); task.ConfigureAwait(false); task.Wait(cts.Token); - timeout -= stopwatch.Elapsed; - stopwatch.Stop(); return task.Result; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs index b48ea36958..2c3c2aeaf3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs @@ -10,6 +10,7 @@ using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -37,7 +38,7 @@ internal sealed class SNINpHandle : SNIPhysicalHandle private int _bufferSize = TdsEnums.DEFAULT_LOGIN_PACKET_SIZE; private readonly Guid _connectionId = Guid.NewGuid(); - public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tlsFirst) + public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, bool tlsFirst) { using (TrySNIEventScope.Create(nameof(SNINpHandle))) { @@ -54,17 +55,25 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tl PipeDirection.InOut, PipeOptions.Asynchronous | PipeOptions.WriteThrough); - bool isInfiniteTimeOut = long.MaxValue == timerExpire; - if (isInfiniteTimeOut) + if (timeout.IsInfinite) { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, + "Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting with infinite timeout.", + args0: _connectionId, + args1: serverName, + args2: pipeName); _pipeStream.Connect(Timeout.Infinite); } else { - TimeSpan ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - - _pipeStream.Connect((int)ts.TotalMilliseconds); + int timeoutMilliseconds = timeout.MillisecondsRemainingInt; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, + "Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting within the {3} sepecified milliseconds.", + args0: _connectionId, + args1: serverName, + args2: pipeName, + args3: timeoutMilliseconds); + _pipeStream.Connect(timeoutMilliseconds); } } catch (TimeoutException te) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 8f101b7bdf..b4b3a37222 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -9,6 +9,7 @@ using System.Net.Security; using System.Net.Sockets; using System.Text; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -130,7 +131,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// Create a SNI connection handle /// /// Full server name from connection string - /// Timer expiration + /// Timer expiration /// Instance name /// SPN /// pre-defined SPN @@ -147,7 +148,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// SNI handle internal static SNIHandle CreateConnectionHandle( string fullServerName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, string serverSPN, @@ -186,11 +187,11 @@ internal static SNIHandle CreateConnectionHandle( case DataSource.Protocol.Admin: case DataSource.Protocol.None: // default to using tcp if no protocol is provided case DataSource.Protocol.TCP: - sniHandle = CreateTcpHandle(details, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, + sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); break; case DataSource.Protocol.NP: - sniHandle = CreateNpHandle(details, timerExpire, parallel, tlsFirst); + sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst); break; default: Debug.Fail($"Unexpected connection protocol: {details._connectionProtocol}"); @@ -279,7 +280,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr /// Creates an SNITCPHandle object /// /// Data source - /// Timer expiration + /// Timer expiration /// Should MultiSubnetFailover be used /// IP address preference /// Key for DNS Cache @@ -290,7 +291,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr /// SNITCPHandle private static SNITCPHandle CreateTcpHandle( DataSource details, - long timerExpire, + TimeoutTimer timeout, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, @@ -317,8 +318,8 @@ private static SNITCPHandle CreateTcpHandle( try { port = isAdminConnection ? - SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) : - SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference); + SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) : + SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference); } catch (SocketException se) { @@ -335,7 +336,7 @@ private static SNITCPHandle CreateTcpHandle( port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort; } - return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, + return new SNITCPHandle(hostName, port, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); } @@ -343,11 +344,11 @@ private static SNITCPHandle CreateTcpHandle( /// Creates an SNINpHandle object /// /// Data source - /// Timer expiration + /// Timer expiration /// Should MultiSubnetFailover be used. Only returns an error for named pipes. /// /// SNINpHandle - private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel, bool tlsFirst) + private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeout, bool parallel, bool tlsFirst) { if (parallel) { @@ -355,7 +356,7 @@ private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, SNICommon.ReportSNIError(SNIProviders.NP_PROV, 0, SNICommon.MultiSubnetFailoverWithNonTcpProtocol, Strings.SNI_ERROR_49); return null; } - return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire, tlsFirst); + return new SNINpHandle(details.PipeHostName, details.PipeName, timeout, tlsFirst); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index de7c589cfa..d12e91ad62 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -16,6 +16,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -118,7 +119,7 @@ public override int ProtocolVersion /// /// Server name /// TCP port number - /// Connection timer expiration + /// Connection timer expiration /// Parallel executions /// IP address preference /// Key for DNS Cache @@ -129,7 +130,7 @@ public override int ProtocolVersion public SNITCPHandle( string serverName, int port, - long timerExpire, + TimeoutTimer timeout, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, @@ -153,19 +154,6 @@ public SNITCPHandle( try { - TimeSpan ts = default(TimeSpan); - - // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count - // The infinite Timeout is a function of ConnectionString Timeout=0 - bool isInfiniteTimeOut = long.MaxValue == timerExpire; - if (!isInfiniteTimeOut) - { - ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - } - - Stopwatch stopwatch = Stopwatch.StartNew(); - bool reportError = true; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port); @@ -176,17 +164,16 @@ public SNITCPHandle( { if (parallel) { - _socket = TryConnectParallel(serverName, port, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(serverName, port, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(serverName, port, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(serverName, port, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } catch (Exception ex) { - TimeSpan timeLeft = ts - stopwatch.Elapsed; - if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero) + if (timeout.IsExpired) { throw; } @@ -221,17 +208,16 @@ public SNITCPHandle( { if (parallel) { - _socket = TryConnectParallel(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(firstCachedIP, portRetry, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(firstCachedIP, portRetry, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } catch (Exception exRetry) { - timeLeft = ts - stopwatch.Elapsed; - if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero) + if (timeout.IsExpired) { throw; } @@ -241,11 +227,11 @@ public SNITCPHandle( SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying exception {1}", args0: _connectionId, args1: exRetry?.Message); if (parallel) { - _socket = TryConnectParallel(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(secondCachedIP, portRetry, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(secondCachedIP, portRetry, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } else @@ -261,10 +247,6 @@ public SNITCPHandle( throw; } } - finally - { - stopwatch.Stop(); - } if (_socket == null || !_socket.Connected) { @@ -316,15 +298,15 @@ public SNITCPHandle( // Connect to server with hostName and port in parellel mode. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. - private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool isInfiniteTimeOut, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { Socket availableSocket = null; Task connectTask; - TimeSpan timeout = ts; + bool isInfiniteTimeOut = timeout.IsInfinite; IPAddress[] serverAddresses = isInfiniteTimeOut ? SNICommon.GetDnsIpAddresses(hostName) - : SNICommon.GetDnsIpAddresses(hostName, ref timeout); + : SNICommon.GetDnsIpAddresses(hostName, timeout); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -357,7 +339,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i connectTask = ParallelConnectAsync(serverAddresses, port); - if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(timeout))) + if (!(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) { callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40); @@ -412,11 +394,10 @@ private static IEnumerable GetHostAddressesSortedByPreference(string // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. - private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + private static Socket Connect(string serverName, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); - - Stopwatch timeTaken = Stopwatch.StartNew(); + bool isInfiniteTimeout = timeout.IsInfinite; IEnumerable ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); @@ -451,8 +432,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo } else { - TimeSpan timeLeft = timeout - timeTaken.Elapsed; - if (timeLeft <= TimeSpan.Zero) + if (timeout.IsExpired) { return null; } @@ -460,9 +440,10 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); socketConnectTask.ConfigureAwait(false); socketConnectTask.Start(); - if (!socketConnectTask.Wait(timeLeft)) + int remainingTimeout = timeout.MillisecondsRemainingInt; + if (!socketConnectTask.Wait(remainingTimeout)) { - throw ADP.TimeoutException($"The socket couldn't connect during the expected {timeLeft} remaining time to connect."); + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); } throw SQL.SocketDidNotThrow(); } @@ -486,19 +467,22 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo // which states that Socket.Select can't handle timeouts greater than int.MaxValue microseconds do { - TimeSpan timeLeft = timeout - timeTaken.Elapsed; - if (!isInfiniteTimeout && timeLeft <= TimeSpan.Zero) + if (timeout.IsExpired) { return null; } int socketSelectTimeout = - checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000)); + checked((int)(Math.Min(timeout.MillisecondsRemainingInt, int.MaxValue / 1000) * 1000)); checkReadLst = new List(1) { socket }; checkWriteLst = new List(1) { socket }; checkErrorLst = new List(1) { socket }; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, + "Determining the status of the socket during the remaining timeout of {0} microseconds.", + socketSelectTimeout); + Socket.Select(checkReadLst, checkWriteLst, checkErrorLst, socketSelectTimeout); // nothing selected means timeout } while (checkReadLst.Count == 0 && checkWriteLst.Count == 0 && checkErrorLst.Count == 0); @@ -531,10 +515,6 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo SqlClientEventSource.Log.TryAdvancedTraceEvent( $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}"); } - catch (AggregateException aggregateException) when (aggregateException.InnerException is TimeoutException timeoutException) - { - Console.WriteLine(timeoutException); // temporary for testing - } finally { if (!isSocketSelected) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 1a3dd06638..3cad605caa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -10,6 +10,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -29,11 +30,11 @@ internal sealed class SSRP /// /// SQL Sever Browser hostname /// instance name to find port number - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// port number for given instance name - internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static int GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); @@ -43,7 +44,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc byte[] responsePacket = null; try { - responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference); + responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference); } catch (SocketException se) { @@ -104,17 +105,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// /// SQL Sever Browser hostname /// instance name to lookup DAC port - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name - internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference); + byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference); const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -163,11 +164,11 @@ private class SsrpResult /// UDP server hostname /// UDP server port /// request packet - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// response packet from UDP server - private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -186,19 +187,9 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re return null; } - TimeSpan ts = default; - // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count - // The infinite Timeout is a function of ConnectionString Timeout=0 - bool isInfiniteTimeout = long.MaxValue == timerExpire; - if (!isInfiniteTimeout) - { - ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - } - - IPAddress[] ipAddresses = isInfiniteTimeout + IPAddress[] ipAddresses = timeout.IsInfinite ? SNICommon.GetDnsIpAddresses(browserHostname) - : SNICommon.GetDnsIpAddresses(browserHostname, ref ts); + : SNICommon.GetDnsIpAddresses(browserHostname, timeout); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); IPAddress[] ipv4Addresses = null; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 0b745ae01e..6a0ee2e0e0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1918,7 +1918,7 @@ private void AttemptOneLogin( _parser.Connect(serverInfo, this, - timeout.LegacyTimerExpire, + timeout, ConnectionOptions, withFailover); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 8c16690149..ca2e9bc88a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -16,6 +16,7 @@ using System.Threading.Tasks; using System.Xml; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; using Microsoft.Data.Sql; using Microsoft.Data.SqlClient.DataClassification; using Microsoft.Data.SqlClient.Server; @@ -361,7 +362,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) internal void Connect( ServerInfo serverInfo, SqlInternalConnectionTds connHandler, - long timerExpire, + TimeoutTimer timeout, SqlConnectionString connectionOptions, bool withFailover) { @@ -444,7 +445,7 @@ internal void Connect( // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, ref _sniSpnBuffer, false, @@ -487,7 +488,7 @@ internal void Connect( } _state = TdsParserState.OpenNotLoggedIn; _physicalStateObj.SniContext = SniContext.Snix_PreLoginBeforeSuccessfulWrite; - _physicalStateObj.TimeoutTime = timerExpire; + _physicalStateObj.TimeoutTime = timeout.LegacyTimerExpire; bool marsCapable = false; @@ -542,7 +543,7 @@ internal void Connect( _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, out instanceName, + timeout, out instanceName, ref _sniSpnBuffer, true, true, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 1f15345167..303ee41134 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -198,7 +199,7 @@ private void ResetCancelAndProcessAttention() internal abstract void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 9665d8f188..1e0141dd58 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -13,6 +13,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -82,7 +83,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref internal override void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, @@ -97,7 +98,7 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timerExpire, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index bf8337cacb..59776956a1 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -11,6 +11,7 @@ using Microsoft.Data.Common; using System.Net; using System.Text; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -140,7 +141,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) internal override void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, @@ -175,30 +176,10 @@ internal override void CreatePhysicalSNIHandle( } SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - - // Translate to SNI timeout values (Int32 milliseconds) - long timeout; - if (long.MaxValue == timerExpire) - { - timeout = int.MaxValue; - } - else - { - timeout = ADP.TimerRemainingMilliseconds(timerExpire); - if (timeout > int.MaxValue) - { - timeout = int.MaxValue; - } - else if (0 > timeout) - { - timeout = 0; - } - } - SQLDNSInfo cachedDNSInfo; bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], checked((int)timeout), out instanceName, + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index aa6a670021..068b37dc71 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -2300,7 +2300,7 @@ private void AttemptOneLogin(ServerInfo serverInfo, string newPassword, SecureSt _parser.Connect(serverInfo, this, - timeout.LegacyTimerExpire, + timeout, ConnectionOptions, withFailover, isFirstTransparentAttempt, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 21abbab757..2bdca1a9e0 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -24,6 +24,7 @@ using Microsoft.Data.SqlClient.Server; using Microsoft.Data.SqlTypes; using Microsoft.SqlServer.Server; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -494,7 +495,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) internal void Connect(ServerInfo serverInfo, SqlInternalConnectionTds connHandler, - long timerExpire, + TimeoutTimer timeout, SqlConnectionString connectionOptions, bool withFailover, bool isFirstTransparentAttempt, @@ -639,7 +640,7 @@ internal void Connect(ServerInfo serverInfo, _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, _sniSpnBuffer, false, @@ -679,7 +680,7 @@ internal void Connect(ServerInfo serverInfo, } _state = TdsParserState.OpenNotLoggedIn; _physicalStateObj.SniContext = SniContext.Snix_PreLoginBeforeSuccessfulWrite; // SQL BU DT 376766 - _physicalStateObj.TimeoutTime = timerExpire; + _physicalStateObj.TimeoutTime = timeout.LegacyTimerExpire; bool marsCapable = false; @@ -744,7 +745,7 @@ internal void Connect(ServerInfo serverInfo, _physicalStateObj.SniContext = SniContext.Snix_Connect; _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, _sniSpnBuffer, true, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 5d5ade91f4..1c68ad8a54 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -279,7 +280,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) internal void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, byte[] spnBuffer, bool flushCache, @@ -293,31 +294,12 @@ internal void CreatePhysicalSNIHandle( { SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - // Translate to SNI timeout values (Int32 milliseconds) - long timeout; - if (long.MaxValue == timerExpire) - { - timeout = int.MaxValue; - } - else - { - timeout = ADP.TimerRemainingMilliseconds(timerExpire); - if (timeout > int.MaxValue) - { - timeout = int.MaxValue; - } - else if (0 > timeout) - { - timeout = 0; - } - } - // serverName : serverInfo.ExtendedServerName // may not use this serverName as key _ = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, checked((int)timeout), + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs index 9948b223d1..37c94fe355 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs @@ -138,7 +138,7 @@ internal bool IsInfinite } // Special accessor for TimerExpire for use when thunking to legacy timeout methods. - internal long LegacyTimerExpire + public long LegacyTimerExpire { get { @@ -180,6 +180,42 @@ internal long MillisecondsRemaining return milliseconds; } } + + // Returns milliseconds remaining trimmed to zero for none remaining + internal int MillisecondsRemainingInt + { + get + { + //------------------- + // Method Body + int milliseconds; + if (_isInfiniteTimeout) + { + milliseconds = int.MaxValue; + } + else + { + long longMilliseconds = ADP.TimerRemainingMilliseconds(_timerExpire); + if (0 > longMilliseconds) + { + milliseconds = 0; + } + else if (longMilliseconds > int.MaxValue) + { + milliseconds = int.MaxValue; + } + else + { + milliseconds = checked((int)longMilliseconds); + } + } + + //-------------------- + // Postconditions + Debug.Assert(0 <= milliseconds); + + return milliseconds; + } + } } } - From 71a627887fc9b0f67df81040d74f33169105b1e5 Mon Sep 17 00:00:00 2001 From: DavoudEshtehari <61173489+DavoudEshtehari@users.noreply.github.com> Date: Wed, 2 Aug 2023 22:14:00 -0700 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: David Engel --- .../tests/FunctionalTests/SqlConnectionBasicTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 03a02fd3fe..01c2f2c050 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -285,7 +285,7 @@ public void ConnectionTimeoutTest(int timeout) ex = e; } - Assert.False(timer.IsRunning, "Timer must stopped."); + Assert.False(timer.IsRunning, "Timer must be stopped."); Assert.NotNull(ex); Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." + @@ -326,7 +326,7 @@ public async void ConnectionTimeoutTestAsync(int timeout) ex = e; } - Assert.False(timer.IsRunning, "Timer must stopped."); + Assert.False(timer.IsRunning, "Timer must be stopped."); Assert.NotNull(ex); Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." +