Skip to content

Commit

Permalink
SqlClient-826 Missed synchronization (dotnet#1029)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinek authored and kant2002 committed Jun 29, 2023
1 parent 0524837 commit 3e324e9
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 116 deletions.
Expand Up @@ -4,11 +4,20 @@

namespace System.Net
{
[Serializable]
internal class InternalException : Exception
{
internal InternalException()
public InternalException() : this("InternalException thrown.")
{
NetEventSource.Fail(this, "InternalException thrown.");
}

public InternalException(string message) : this(message, null)
{
}

public InternalException(string message, Exception innerException) : base(message, innerException)
{
NetEventSource.Fail(this, message);
}
}
}
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
Expand All @@ -14,6 +15,7 @@
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Common;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -347,149 +349,158 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
availableSocket = connectTask.Result;
return availableSocket;
}

/// <summary>
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">Thrown when ipPreference is not supported</exception>
private static IEnumerable<IPAddress> GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference)
{
IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
AddressFamily? prioritiesFamily = ipPreference switch
{
SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork,
SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6,
SqlConnectionIPAddressPreference.UsePlatformDefault => null,
_ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference))
};

// Return addresses of the preferred family first
if (prioritiesFamily != null)
{
foreach (IPAddress ipAddress in ipAddresses)
{
if (ipAddress.AddressFamily == prioritiesFamily)
{
yield return ipAddress;
}
}
}

// Return addresses of the other family
foreach (IPAddress ipAddress in ipAddresses)
{
if (ipAddress.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)
{
if (prioritiesFamily == null || ipAddress.AddressFamily != prioritiesFamily)
{
yield return ipAddress;
}
}
}
}

// Connect to server with hostName and port.
// The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point.
// Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server.
private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName);
Stopwatch timeTaken = Stopwatch.StartNew();

string IPv4String = null;
string IPv6String = null;
IEnumerable<IPAddress> ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference);

// Returning null socket is handled by the caller function.
if (ipAddresses == null || ipAddresses.Length == 0)
foreach (IPAddress ipAddress in ipAddresses)
{
return null;
}

Socket[] sockets = new Socket[ipAddresses.Length];
AddressFamily[] preferedIPFamilies = new AddressFamily[2];
bool isSocketSelected = false;
Socket socket = null;

if (ipPreference == SqlConnectionIPAddressPreference.IPv4First)
{
preferedIPFamilies[0] = AddressFamily.InterNetwork;
preferedIPFamilies[1] = AddressFamily.InterNetworkV6;
}
else if (ipPreference == SqlConnectionIPAddressPreference.IPv6First)
{
preferedIPFamilies[0] = AddressFamily.InterNetworkV6;
preferedIPFamilies[1] = AddressFamily.InterNetwork;
}
// else -> UsePlatformDefault

CancellationTokenSource cts = null;

if (!isInfiniteTimeout)
{
cts = new CancellationTokenSource(timeout);
cts.Token.Register(Cancel);
}

Socket availableSocket = null;
try
{
// We go through the IP list twice.
// In the first traversal, we only try to connect with the preferedIPFamilies[0].
// In the second traversal, we only try to connect with the preferedIPFamilies[1].
// For UsePlatformDefault preference, we do traversal once.
for (int i = 0; i < preferedIPFamilies.Length; ++i)
try
{
for (int n = 0; n < ipAddresses.Length; n++)
socket = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp)
{
Blocking = isInfiniteTimeout
};

// enable keep-alive on socket
SetKeepAliveValues(ref socket);

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO,
"Connecting to IP address {0} and port {1} using {2} address family. Is infinite timeout: {3}",
ipAddress,
port,
ipAddress.AddressFamily,
isInfiniteTimeout);

bool isConnected;
try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select
{
IPAddress ipAddress = ipAddresses[n];
try
socket.Connect(ipAddress, port);
if (!isInfiniteTimeout)
{
if (ipAddress != null)
{
if (ipAddress.AddressFamily != preferedIPFamilies[i] && ipPreference != SqlConnectionIPAddressPreference.UsePlatformDefault)
{
continue;
}
throw SQL.SocketDidNotThrow();
}

isConnected = true;
}
catch (SocketException socketException) when (!isInfiniteTimeout &&
socketException.SocketErrorCode ==
SocketError.WouldBlock)
{
// https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118
// Socket.Select is used because it supports timeouts, while Socket.Connect does not

sockets[n] = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
List<Socket> checkReadLst; List<Socket> checkWriteLst; List<Socket> checkErrorLst;

// enable keep-alive on socket
SetKeepAliveValues(ref sockets[n]);
// Repeating Socket.Select several times if our timeout is greater
// than int.MaxValue microseconds because of
// https://github.com/dotnet/SqlClient/pull/1029#issuecomment-875364044
// which states that Socket.Select can't handle timeouts greater than int.MaxValue microseconds
do
{
TimeSpan timeLeft = timeout - timeTaken.Elapsed;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connecting to IP address {0} and port {1} using {2} address family.",
args0: ipAddress,
args1: port,
args2: ipAddress.AddressFamily);
sockets[n].Connect(ipAddress, port);
if (sockets[n] != null) // sockets[n] can be null if cancel callback is executed during connect()
{
if (sockets[n].Connected)
{
availableSocket = sockets[n];
if (ipAddress.AddressFamily == AddressFamily.InterNetwork)
{
IPv4String = ipAddress.ToString();
}
else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
IPv6String = ipAddress.ToString();
}
if (timeLeft <= TimeSpan.Zero)
return null;

break;
}
else
{
sockets[n].Dispose();
sockets[n] = null;
}
}
}
}
catch (Exception e)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
SqlClientEventSource.Log.TryAdvancedTraceEvent($"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
}
}
int socketSelectTimeout =
checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000));

// If we have already got a valid Socket, or the platform default was prefered
// we won't do the second traversal.
if (availableSocket is not null || ipPreference == SqlConnectionIPAddressPreference.UsePlatformDefault)
{
break;
}
}
}
finally
{
cts?.Dispose();
}
checkReadLst = new List<Socket>(1) { socket };
checkWriteLst = new List<Socket>(1) { socket };
checkErrorLst = new List<Socket>(1) { socket };

// we only record the ip we can connect with successfully.
if (IPv4String != null || IPv6String != null)
{
pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString());
}
Socket.Select(checkReadLst, checkWriteLst, checkErrorLst, socketSelectTimeout);
// nothing selected means timeout
} while (checkReadLst.Count == 0 && checkWriteLst.Count == 0 && checkErrorLst.Count == 0);

return availableSocket;
// workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538
isConnected = socket.Connected && checkErrorLst.Count == 0;
}

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
{
try
if (isConnected)
{
if (sockets[i] != null && !sockets[i].Connected)
socket.Blocking = true;
string iPv4String = null;
string iPv6String = null;
if (socket.AddressFamily == AddressFamily.InterNetwork)
{
sockets[i].Dispose();
sockets[i] = null;
iPv4String = ipAddress.ToString();
}
else
{
iPv6String = ipAddress.ToString();
}
pendingDNSInfo = new SQLDNSInfo(cachedFQDN, iPv4String, iPv6String, port.ToString());
isSocketSelected = true;
return socket;
}
catch (Exception e)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
}
}
catch (SocketException e)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
SqlClientEventSource.Log.TryAdvancedTraceEvent(
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
}
finally
{
if (!isSocketSelected)
socket?.Dispose();
}
}

return null;
}

private static Task<Socket> ParallelConnectAsync(IPAddress[] serverAddresses, int port)
Expand Down
Expand Up @@ -8,6 +8,8 @@
using System.Data;
using System.Diagnostics;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
Expand Down Expand Up @@ -377,6 +379,10 @@ internal static Exception SynchronousCallMayNotPend()
{
return new Exception(StringsHelper.GetString(Strings.Sql_InternalError));
}
internal static Exception SocketDidNotThrow()
{
return new InternalException(StringsHelper.GetString(Strings.SQL_SocketDidNotThrow, nameof(SocketException), nameof(SocketError.WouldBlock)));
}
internal static Exception ConnectionLockedForBcpEvent()
{
return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ConnectionLockedForBcpEvent));
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -1941,4 +1941,7 @@
<data name="SQL_TDS8_NotSupported_Netstandard2.0" xml:space="preserve">
<value>Encrypt=Strict is not supported when targeting .NET Standard 2.0. Use .NET Standard 2.1, .NET Framework, or .NET.</value>
</data>
<data name="SQL_SocketDidNotThrow" xml:space="preserve">
<value>Socket did not throw expected '{0}' with error code '{1}'.</value>
</data>
</root>

0 comments on commit 3e324e9

Please sign in to comment.