Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Fix unit test for SPN to include port number with Managed SNI #2281

Merged
merged 21 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9c776f3
Remove ActiveIssue to test pipeline.
arellegue Dec 18, 2023
f84884c
Fix SPN port number Unit Test to use TCP and NP connection strings an…
arellegue Dec 20, 2023
36c235f
Try adding retry and delay in creating and using the connection objec…
arellegue Dec 20, 2023
0caa4ce
Added IsUsingManagedSNI annotation since reflection is using ManagedS…
arellegue Dec 20, 2023
c47a16e
Add wrapper for all required annotations for the unit test as it is t…
arellegue Dec 20, 2023
64c0103
Removed all Sleeps as they are not needed.
arellegue Dec 20, 2023
af685d3
Add annotation to skip .net framework for SPN port unit test.
arellegue Dec 20, 2023
6076e8e
Put [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] so …
arellegue Dec 20, 2023
bea5513
Removed SkipOnTargetFramework as it did not stop the unit test from r…
arellegue Dec 20, 2023
fa556f5
Add validation of InstanceName within the ConditionalThreory annotation.
arellegue Dec 21, 2023
6ace809
Removed and sorted using references.
arellegue Dec 21, 2023
660d992
Fix ParseDataSource named pipe data source detection to be case insen…
arellegue Dec 22, 2023
192a1e2
Applied PR review suggesstions.
arellegue Jan 5, 2024
8a93097
Moved named pipe protocol port usage assert message to line 138 and u…
arellegue Jan 11, 2024
00d9a3c
Removed all source codes that handle named pipe protocol.
arellegue Jan 19, 2024
41366f6
Removed unwanted comments as they cause code bloat.
arellegue Jan 24, 2024
baa3ab9
Add SPN pattern validation and use Assert.Equal to compare expected p…
arellegue Jan 24, 2024
53a13e6
Merge branch 'main' into FixUnitTestForSPNPortNumber
arellegue Feb 1, 2024
2bc3f7c
Added validation for DataSource.
arellegue Feb 1, 2024
c131562
Added assertion if DataSource is valid.
arellegue Feb 1, 2024
0c5ed6f
Replace If (port > 0) with Assert( port > 0, ...)
arellegue Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,10 @@ private bool InferConnectionDetails()

Port = port;
}
// Instance Name Handling. Only if we found a '\' and we did not find a port in the Data Source
else if (backSlashIndex > -1)
// Instance Name Handling.
if (backSlashIndex > -1)
{
// This means that there will not be any part separated by comma.
// This means that there is a part separated by '\'
InstanceName = tokensByCommaAndSlash[1].Trim();

if (string.IsNullOrWhiteSpace(InstanceName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,6 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
port = -1;
instanceName = string.Empty;

if (dataSource.Contains(",") && dataSource.Contains("\\"))
return false;

if (dataSource.Contains(":"))
{
dataSource = dataSource.Substring(dataSource.IndexOf(":", StringComparison.Ordinal) + 1);
Expand All @@ -993,7 +990,8 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i
{
return false;
}
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal) - 1);
// IndexOf is zero-based, no need to subtract one
dataSource = dataSource.Substring(0, dataSource.IndexOf(",", StringComparison.Ordinal));
arellegue marked this conversation as resolved.
Show resolved Hide resolved
}

if (dataSource.Contains("\\"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class InstanceNameTest
{
private const char SemicolonSeparator = ';';

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectToSQLWithInstanceNameTest()
{
Expand Down Expand Up @@ -84,138 +86,135 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}
}

// Note: This Unit test was tested in a domain-joined VM connecting to a remote
// SQL Server using Kerberos in the same domain.
[ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
[ConditionalFact(nameof(IsKerberos))]
public static void PortNumberInSPNTest()
#if NETCOREAPP
[ConditionalFact(nameof(IsSPNPortNumberTestForTCP))]
public static void PortNumberInSPNTestForTCP()
{
string connectionString = DataTestUtility.TCPConnectionString;
SqlConnectionStringBuilder builder = new(connectionString);

int port = GetNamedInstancePortNumberFromSqlBrowser(connectionString);
Assert.True(port > 0, "Named instance must have a valid port number.");
builder.DataSource = $"{builder.DataSource},{port}";

PortNumberInSPNTest(builder.ConnectionString, port);
}
#endif

private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber)
{
string connStr = DataTestUtility.TCPConnectionString;
// If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
if (DataTestUtility.IsIntegratedSecuritySetup())
{
string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" };
connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true";
connectionString = DataTestUtility.RemoveKeysInConnStr(connectionString, removeKeys) + $"Integrated Security=true";
}

SqlConnectionStringBuilder builder = new(connStr);
SqlConnectionStringBuilder builder = new(connectionString);

string hostname = "";
string instanceName = "";

Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name");
DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);

bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName);
Assert.True(condition, "Browser service is not running or instance name is invalid");
Assert.False(string.IsNullOrEmpty(hostname), "Hostname must be included in the data source.");
Assert.False(string.IsNullOrEmpty(instanceName), "Instance name must be included in the data source.");

if (condition)
using (SqlConnection connection = new(builder.ConnectionString))
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
using SqlDataReader reader = command.ExecuteReader();
Assert.True(reader.Read(), "Expected to receive one row data");
Assert.Equal("KERBEROS", reader.GetString(0));
int localTcpPort = reader.GetInt32(1);

int spnPort = -1;
string spnInfo = GetSPNInfo(builder.DataSource, out spnPort);

// sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
// the local_tcp_port should be the same as the inferred SPN port from instance name
Assert.Equal(localTcpPort, spnPort);

string spnInfo = GetSPNInfo(builder.DataSource);
Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo);

string[] spnStrs = spnInfo.Split(':');
int portInSPN = 0;
if (spnStrs.Length > 1)
{
int.TryParse(spnStrs[1], out portInSPN);
}
Assert.Equal(expectedPortNumber, portInSPN);
}
}

private static string GetSPNInfo(string datasource, out int out_port)
private static string GetSPNInfo(string dataSource)
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

// Get all required types using reflection
Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");

// Used in Datasource constructor param type array
Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) };

// Used in GetSqlServerSPNs function param types array
Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) };

// GetPortByInstanceName parameters array
Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) };

// TimeoutTimer.StartSecondsTimeout params
Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) };

// Get all types constructors
ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo sniProxyConstructor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPConstructor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo dataSourceConstructor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
ConstructorInfo timeoutTimerConstructor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

// Instantiate SNIProxy
object sniProxy = sniProxyCtor.Invoke(new object[] { });
object sniProxyObj = sniProxyConstructor.Invoke(new object[] { });

// Instantiate datasource
object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource });
object dataSourceObj = dataSourceConstructor.Invoke(new object[] { dataSource });

// Instantiate SSRP
object ssrp = SSRPCtor.Invoke(new object[] { });
object ssrpObj = SSRPConstructor.Invoke(new object[] { });

// Instantiate TimeoutTimer
object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });
object timeoutTimerObj = timeoutTimerConstructor.Invoke(new object[] { });

// Get TimeoutTimer.StartSecondsTimeout Method
MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
// Create a timeoutTimer that expires in 30 seconds
timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 });
MethodInfo startSecondsTimeoutInfo = timeoutTimerObj.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);

// Parse the datasource to separate the server name and instance name
MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource });
timeoutTimerObj = startSecondsTimeoutInfo.Invoke(dataSourceObj, new object[] { 30 });

// Get the GetPortByInstanceName method of SSRP
MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);
MethodInfo parseServerNameInfo = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
object dataSrcInfo = parseServerNameInfo.Invoke(dataSourceObj, new object[] { dataSource });

MethodInfo getPortByInstanceNameInfo = ssrpObj.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);

// Get the server name
PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();

// Get the instance name
PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();

// Get the port number using the GetPortByInstanceName method of SSRP
object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 });
object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 });

// Set the resolved port property of datasource
PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);

// Prepare the GetSqlServerSPNs method
string serverSPN = "";
MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);
MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);

// Finally call GetSqlServerSPNs
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN });

// Example result: MSSQLSvc/machine.domain.tld:port"
string spnInfo = Encoding.Unicode.GetString(result[0]);

out_port = (int)port;

return spnInfo;
}

private static bool IsKerberos()
private static bool IsSPNPortNumberTestForTCP()
{
return (DataTestUtility.AreConnStringsSetup()
&& DataTestUtility.IsNotLocalhost()
&& DataTestUtility.IsKerberosTest
&& DataTestUtility.IsNotAzureServer()
return (IsInstanceNameValid(DataTestUtility.TCPConnectionString)
&& DataTestUtility.IsUsingManagedSNI()
&& DataTestUtility.IsNotAzureServer()
&& DataTestUtility.IsNotAzureSynapse());
}

private static bool IsInstanceNameValid(string connectionString)
{
string instanceName = "";

SqlConnectionStringBuilder builder = new(connectionString);

bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out instanceName);

arellegue marked this conversation as resolved.
Show resolved Hide resolved
return isDataSourceValid && !string.IsNullOrWhiteSpace(instanceName);
}

private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;
Expand All @@ -231,6 +230,43 @@ private static bool IsValidInstance(string browserHostName, string instanceName)
return response != null && response.Length > 0;
}

private static int GetNamedInstancePortNumberFromSqlBrowser(string connectionString)
{
SqlConnectionStringBuilder builder = new(connectionString);

string hostname = "";
string instanceName = "";
int port = 0;

bool isDataSourceValid = DataTestUtility.ParseDataSource(builder.DataSource, out hostname, out _, out instanceName);
Assert.True(isDataSourceValid, "DataSource is invalid");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Optional] - Not a deal breaker for me. But this is missing a period for consistency.


bool isBrowserRunning = IsBrowserAlive(hostname);
Assert.True(isBrowserRunning, "Browser service is not running.");

bool isInstanceExisting = IsValidInstance(hostname, instanceName);
Assert.True(isInstanceExisting, "Instance name is invalid.");

if (isDataSourceValid && isBrowserRunning && isInstanceExisting)
{
byte[] request = CreateInstanceInfoRequest(instanceName);
byte[] response = QueryBrowser(hostname, request);

string serverMessage = Encoding.ASCII.GetString(response, 3, response.Length - 3);

string[] elements = serverMessage.Split(SemicolonSeparator);
int tcpIndex = Array.IndexOf(elements, "tcp");
if (tcpIndex < 0 || tcpIndex == elements.Length - 1)
{
throw new SocketException();
}

port = (int)ushort.Parse(elements[tcpIndex + 1]);
}

return port;
}

private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket)
{
const int DefaultBrowserPort = 1434;
Expand Down
Loading