Skip to content

Commit

Permalink
Fix access token behavior in connection pool (dotnet#443)
Browse files Browse the repository at this point in the history
* Initial test changes

* Fix access token behavior in connection pool

* Compare ordinals for strings

* Access token only
  • Loading branch information
cheenamalhotra committed Feb 25, 2020
1 parent 363902d commit 44b5867
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ internal string AccessToken
public override bool Equals(object obj)
{
SqlConnectionPoolKey key = obj as SqlConnectionPoolKey;
return (key != null && _credential == key._credential && ConnectionString == key.ConnectionString && Object.ReferenceEquals(_accessToken, key._accessToken));
return (key != null
&& _credential == key._credential
&& ConnectionString == key.ConnectionString
&& string.CompareOrdinal(_accessToken, key._accessToken) == 0);
}

public override int GetHashCode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public override bool Equals(object obj)
return (key != null &&
_credential == key._credential &&
ConnectionString == key.ConnectionString &&
Object.ReferenceEquals(_accessToken, key._accessToken) &&
string.CompareOrdinal(_accessToken, key._accessToken) == 0 &&
_serverCertificateValidationCallback == key._serverCertificateValidationCallback &&
_clientCertificateRetrievalCallback == key._clientCertificateRetrievalCallback &&
_originalNetworkAddressInfo == key._originalNetworkAddressInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
using System.IO;
using System.Linq;
using System.Reflection;
using System.Security;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Newtonsoft.Json;
using Xunit;

Expand All @@ -26,8 +28,9 @@ public static class DataTestUtility
public static readonly string TCPConnectionStringHGSVBS = null;
public static readonly string TCPConnectionStringAASVBS = null;
public static readonly string TCPConnectionStringAASSGX = null;
public static readonly string AADAccessToken = null;
public static readonly string AADAuthorityURL = null;
public static readonly string AADPasswordConnectionString = null;
public static readonly string AADAccessToken = null;
public static readonly string AKVBaseUrl = null;
public static readonly string AKVUrl = null;
public static readonly string AKVClientId = null;
Expand Down Expand Up @@ -60,7 +63,7 @@ private class Config
public string TCPConnectionStringHGSVBS = null;
public string TCPConnectionStringAASVBS = null;
public string TCPConnectionStringAASSGX = null;
public string AADAccessToken = null;
public string AADAuthorityURL = null;
public string AADPasswordConnectionString = null;
public string AzureKeyVaultURL = null;
public string AzureKeyVaultClientId = null;
Expand All @@ -83,13 +86,20 @@ static DataTestUtility()
TCPConnectionStringHGSVBS = c.TCPConnectionStringHGSVBS;
TCPConnectionStringAASVBS = c.TCPConnectionStringAASVBS;
TCPConnectionStringAASSGX = c.TCPConnectionStringAASSGX;
AADAccessToken = c.AADAccessToken;
AADAuthorityURL = c.AADAuthorityURL;
AADPasswordConnectionString = c.AADPasswordConnectionString;
SupportsLocalDb = c.SupportsLocalDb;
SupportsIntegratedSecurity = c.SupportsIntegratedSecurity;
SupportsFileStream = c.SupportsFileStream;
EnclaveEnabled = c.EnclaveEnabled;

if (IsAADPasswordConnStrSetup() && IsAADAuthorityURLSetup())
{
string username = RetrieveValueFromConnStr(AADPasswordConnectionString, new string[] { "User ID", "UID" });
string password = RetrieveValueFromConnStr(AADPasswordConnectionString, new string[] { "Password", "PWD" });
AADAccessToken = GenerateAccessToken(AADAuthorityURL, username, password);
}

string url = c.AzureKeyVaultURL;
Uri AKVBaseUri = null;
if (!string.IsNullOrEmpty(url) && Uri.TryCreate(url, UriKind.Absolute, out AKVBaseUri))
Expand Down Expand Up @@ -134,6 +144,41 @@ static DataTestUtility()
}
}

private static string GenerateAccessToken(string authorityURL, string aADAuthUserID, string aADAuthPassword)
{
return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).Result;
}

private static Task<string> AcquireTokenAsync(string authorityURL, string userID, string password) => Task.Run(() =>
{
// The below properties are set specific to test configurations.
string scope = "https://database.windows.net//.default";
string applicationName = "Microsoft Data SqlClient Manual Tests";
string clientVersion = "1.0.0.0";
string adoClientId = "4d079b4c-cab7-4b7c-a115-8fd51b6f8239";
IPublicClientApplication app = PublicClientApplicationBuilder.Create(adoClientId)
.WithAuthority(authorityURL)
.WithClientName(applicationName)
.WithClientVersion(clientVersion)
.Build();
AuthenticationResult result;
string[] scopes = new string[] { scope };
// Note: CorrelationId, which existed in ADAL, can not be set in MSAL (yet?).
// parameter.ConnectionId was passed as the CorrelationId in ADAL to aid support in troubleshooting.
// If/When MSAL adds CorrelationId support, it should be passed from parameters here, too.
SecureString securePassword = new SecureString();
foreach (char c in password)
securePassword.AppendChar(c);
securePassword.MakeReadOnly();
result = app.AcquireTokenByUsernamePassword(scopes, userID, securePassword).ExecuteAsync().Result;
return result.AccessToken;
});

public static bool IsDatabasePresent(string name)
{
AvailableDatabases = AvailableDatabases ?? new Dictionary<string, bool>();
Expand Down Expand Up @@ -171,6 +216,11 @@ public static bool IsAADPasswordConnStrSetup()
return !string.IsNullOrEmpty(AADPasswordConnectionString);
}

public static bool IsAADAuthorityURLSetup()
{
return !string.IsNullOrEmpty(AADAuthorityURL);
}

public static bool IsNotAzureServer()
{
return AreConnStringsSetup() ? !DataTestUtility.IsAzureSqlServer(new SqlConnectionStringBuilder((DataTestUtility.TCPConnectionString)).DataSource) : true;
Expand Down Expand Up @@ -248,10 +298,11 @@ public static string GetUniqueNameForSqlServer(string prefix)

public static string GetAccessToken()
{
return AADAccessToken;
// Creates a new Object Reference of Access Token - See GitHub Issue 438
return (null != AADAccessToken) ? new string(AADAccessToken.ToCharArray()) : null;
}

public static bool IsAccessTokenSetup() => string.IsNullOrEmpty(GetAccessToken()) ? false : true;
public static bool IsAccessTokenSetup() => !string.IsNullOrEmpty(GetAccessToken());

public static bool IsFileStreamSetup() => SupportsFileStream;

Expand Down Expand Up @@ -519,5 +570,54 @@ public static string GetValueString(object paramValue)

return paramValue.ToString();
}

public static string RemoveKeysInConnStr(string connStr, string[] keysToRemove)
{
// tokenize connection string and remove input keys.
string res = "";
string[] keys = connStr.Split(';');
foreach (var key in keys)
{
if (!string.IsNullOrEmpty(key.Trim()))
{
bool removeKey = false;
foreach (var keyToRemove in keysToRemove)
{
if (key.Trim().ToLower().StartsWith(keyToRemove.Trim().ToLower()))
{
removeKey = true;
break;
}
}
if (!removeKey)
{
res += key + ";";
}
}
}
return res;
}

public static string RetrieveValueFromConnStr(string connStr, string[] keywords)
{
// tokenize connection string and retrieve value for a specific key.
string res = "";
string[] keys = connStr.Split(';');
foreach (var key in keys)
{
foreach (var keyword in keywords)
{
if (!string.IsNullOrEmpty(key.Trim()))
{
if (key.Trim().ToLower().StartsWith(keyword.Trim().ToLower()))
{
res = key.Substring(key.IndexOf('=') + 1).Trim();
break;
}
}
}
}
return res;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,46 @@ private static void BasicConnectionPoolingTest(string connectionString)
connection3.Close();

connectionPool.Cleanup();

SqlConnection connection4 = new SqlConnection(connectionString);
connection4.Open();
Assert.True(internalConnection.IsInternalConnectionOf(connection4), "New connection does not use same internal connection");
Assert.True(connectionPool.ContainsConnection(connection4), "New connection is in a different pool");
connection4.Close();
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup), nameof(DataTestUtility.IsAADAuthorityURLSetup))]
public static void AccessTokenConnectionPoolingTest()
{
// Remove cred info and add invalid token
string[] credKeys = { "User ID", "Password", "UID", "PWD", "Authentication" };
string connectionString = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys);

SqlConnection connection = new SqlConnection(connectionString);
connection.AccessToken = DataTestUtility.GetAccessToken();
connection.Open();
InternalConnectionWrapper internalConnection = new InternalConnectionWrapper(connection);
ConnectionPoolWrapper connectionPool = new ConnectionPoolWrapper(connection);
connection.Close();

SqlConnection connection2 = new SqlConnection(connectionString);
connection2.AccessToken = DataTestUtility.GetAccessToken();
connection2.Open();
Assert.True(internalConnection.IsInternalConnectionOf(connection2), "New connection does not use same internal connection");
Assert.True(connectionPool.ContainsConnection(connection2), "New connection is in a different pool");
connection2.Close();

SqlConnection connection3 = new SqlConnection(connectionString + ";App=SqlConnectionPoolUnitTest;");
connection3.AccessToken = DataTestUtility.GetAccessToken();
connection3.Open();
Assert.False(internalConnection.IsInternalConnectionOf(connection3), "Connection with different connection string uses same internal connection");
Assert.False(connectionPool.ContainsConnection(connection3), "Connection with different connection string uses same connection pool");
connection3.Close();

connectionPool.Cleanup();

SqlConnection connection4 = new SqlConnection(connectionString);
connection4.AccessToken = DataTestUtility.GetAccessToken();
connection4.Open();
Assert.True(internalConnection.IsInternalConnectionOf(connection4), "New connection does not use same internal connection");
Assert.True(connectionPool.ContainsConnection(connection4), "New connection is in a different pool");
Expand Down

0 comments on commit 44b5867

Please sign in to comment.