Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API | AccessTokenCallback support #1260

Merged
merged 41 commits into from Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b974ffb
POC for TokenCredential support
christothes Sep 10, 2021
8f76542
POC 2 - callback abstraction
christothes Oct 18, 2021
7ffc10c
merge
christothes Oct 18, 2021
9ed36d7
POC 3
christothes Oct 22, 2021
8321090
fix
christothes Oct 25, 2021
2aee2b2
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
christothes Oct 27, 2021
ea47be2
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
christothes Nov 8, 2021
10f5a95
merge
christothes Nov 8, 2021
d281983
cleanups
christothes Nov 12, 2021
fcfb66e
formatting
christothes Nov 12, 2021
4c718b5
netfx consistency
christothes Nov 12, 2021
4ccc412
fix
christothes Nov 12, 2021
68bf511
cleanup
christothes Nov 15, 2021
0671d58
merge
christothes Mar 8, 2023
d9570b3
nuget
christothes Mar 8, 2023
aa517b6
source ref
christothes Mar 8, 2023
c5def10
revert nuget.config change
christothes Mar 10, 2023
727406f
Update src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlCli…
christothes Mar 10, 2023
5c1b5ef
cast timeout to int
christothes Mar 13, 2023
d82f3bb
PR feedback
christothes Apr 14, 2023
9a9cd87
tests
christothes Apr 18, 2023
677f729
fix resources
christothes Apr 19, 2023
afc5fff
fix error messages for fx
christothes Apr 19, 2023
2673bd4
fixes
christothes Apr 19, 2023
b08a251
rename AzureADTokenRequestContext
christothes Apr 20, 2023
be1263e
use SqlAuthenticationParameters in callback
christothes May 2, 2023
52a84f9
docs and simple sample
christothes May 2, 2023
9be9391
tests for password and userId with callback
christothes May 3, 2023
a2233f5
add to SqlClient.cs API listing
christothes May 3, 2023
c1df1f4
Allow credential with callback and pass to the callback, if available
David-Engel May 12, 2023
a4aca8f
Merge pull request #1 from David-Engel/ADCreds
christothes Jun 5, 2023
fc57e98
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
DavoudEshtehari Jun 5, 2023
2ab55b8
fb
christothes Jun 6, 2023
ada78c2
Apply suggestions from code review
christothes Jun 6, 2023
1e531dc
Merge branch 'chriss/ADCreds' of https://github.com/christothes/SqlCl…
christothes Jun 6, 2023
caa6c3e
fb
christothes Jun 6, 2023
36a59e8
Apply suggestions from code review
christothes Jun 6, 2023
cd922db
Merge branch 'chriss/ADCreds' of https://github.com/christothes/SqlCl…
christothes Jun 6, 2023
16693f6
fb
christothes Jun 6, 2023
78b9e4a
fb
christothes Jun 13, 2023
67bfd01
Add tests
DavoudEshtehari Jun 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -361,3 +361,5 @@ MigrationBackup/

# Config Json file
**/config.json

.idea/
christothes marked this conversation as resolved.
Show resolved Hide resolved
Expand Up @@ -971,6 +971,9 @@
<PackageReference Condition="$(TargetGroup) == 'netstandard'" Include="System.Security.Cryptography.Cng" Version="$(SystemSecurityCryptographyCngVersion)" />
<PackageReference Condition="$(BuildForRelease) == 'true'" Include="Microsoft.SourceLink.GitHub" Version="$(MicrosoftSourceLinkGitHubVersion)" PrivateAssets="All" />
</ItemGroup>
<ItemGroup>
<Compile Include="Microsoft\Data\SqlClient\AadTokenRequestContext.cs" />
</ItemGroup>
<Import Project="$(ToolsDir)targets\GenerateThisAssemblyCs.targets" />
<Import Project="$(ToolsDir)targets\ResolveContract.targets" Condition="'$(OSGroup)' == 'AnyOS'" />
<Import Project="$(ToolsDir)targets\NotSupported.targets" Condition="'$(OSGroup)' == 'AnyOS'" />
Expand Down
@@ -0,0 +1,25 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;

namespace Microsoft.Data.SqlClient
{
/// <summary>
///
/// </summary>
public class AadTokenRequestContext
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
///
/// </summary>
/// <param name="resource"></param>
public AadTokenRequestContext(string resource) { Resource = resource; }

/// <summary>
///
/// </summary>
public string Resource { get; }
}
}
Expand Up @@ -89,6 +89,8 @@ private enum CultureCheckState : uint
/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

private Func<AadTokenRequestContext, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;

Expand Down Expand Up @@ -270,7 +272,7 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
}

/// <summary>
/// This function returns a list of the names of the custom providers currently registered. If the
/// This function returns a list of the names of the custom providers currently registered. If the
/// instance-level cache is not empty, that cache is used, else the global cache is used.
/// </summary>
/// <returns>Combined list of provider names</returns>
Expand Down Expand Up @@ -342,7 +344,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<st
new(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}
Expand Down Expand Up @@ -688,6 +690,35 @@ public string AccessToken
}
}

/// <summary>
///
/// </summary>
public Func<AadTokenRequestContext, CancellationToken, Task<SqlAuthenticationToken>> AccessTokenCallback
{
get { return _accessTokenCallback; }
set
{
// If a connection is connecting or is ever opened, AccessToken callback cannot be set
if (!InnerConnection.AllowSetConnectionString)
{
throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
}

if (value != null)
{
// Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
// CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken((SqlConnectionString)ConnectionOptions);
christothes marked this conversation as resolved.
Show resolved Hide resolved
}
if (value == null)
{
throw new ArgumentNullException(nameof(AccessTokenCallback), "Callback cannot be null.");
christothes marked this conversation as resolved.
Show resolved Hide resolved
}

ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, value));
_accessTokenCallback = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down
Expand Up @@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Expand Up @@ -130,6 +130,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func<AadTokenRequestContext,CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
Expand Down Expand Up @@ -434,19 +435,19 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal
// the new Login7 packet will always write out the new password (or a length of zero and no bytes if not present)
//
internal SqlInternalConnectionTds(
DbConnectionPoolIdentity identity,
SqlConnectionString connectionOptions,
SqlCredential credential,
object providerInfo,
string newPassword,
SecureString newSecurePassword,
bool redirectedUserInstance,
SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand)
SessionData reconnectSessionData = null,
bool applyTransientFaultHandling = false,
string accessToken = null,
DbConnectionPool pool = null
) : base(connectionOptions)
DbConnectionPoolIdentity identity,
SqlConnectionString connectionOptions,
SqlCredential credential,
object providerInfo,
string newPassword,
SecureString newSecurePassword,
bool redirectedUserInstance,
SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand)
SessionData reconnectSessionData = null,
bool applyTransientFaultHandling = false,
string accessToken = null,
DbConnectionPool pool = null,
Func<AadTokenRequestContext, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)

{
#if DEBUG
Expand Down Expand Up @@ -479,6 +480,11 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal
_accessTokenInBytes = System.Text.Encoding.Unicode.GetBytes(accessToken);
}

if (accessTokenCallback != null)
{
_accessTokenCallback = accessTokenCallback;
}
christothes marked this conversation as resolved.
Show resolved Hide resolved

_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;

Expand Down Expand Up @@ -1345,6 +1351,18 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword,
_federatedAuthenticationRequested = true;
}

if (_accessTokenCallback != null)
{
requestedFeatures |= TdsEnums.FeatureExtension.FedAuth;
_fedAuthFeatureExtensionData = new FederatedAuthenticationFeatureExtensionData
{
libraryType = TdsEnums.FedAuthLibrary.SecurityTokenCallback,
fedAuthRequiredPreLoginResponse = _fedAuthRequired,
};
// No need any further info from the server for token based authentication. So set _federatedAuthenticationRequested to true
_federatedAuthenticationInfoRequested = true;
}

// The GLOBALTRANSACTIONS, DATACLASSIFICATION, TCE, and UTF8 support features are implicitly requested
requestedFeatures |= TdsEnums.FeatureExtension.GlobalTransactions | TdsEnums.FeatureExtension.DataClassification | TdsEnums.FeatureExtension.Tce | TdsEnums.FeatureExtension.UTF8Support;

Expand Down Expand Up @@ -2144,6 +2162,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
Debug.Assert((ConnectionOptions._hasUserIdKeyword && ConnectionOptions._hasPasswordKeyword)
|| _credential != null
|| _accessTokenCallback != null
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
Expand Down Expand Up @@ -2346,7 +2365,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
string username = null;

var authProvider = _sqlAuthenticationProviderManager.GetProvider(ConnectionOptions.Authentication);
if (authProvider == null)
if (authProvider == null && _accessTokenCallback == null)
throw SQL.CannotFindAuthProvider(ConnectionOptions.Authentication.ToString());

// retry getting access token once if MsalException.error_code is unknown_error.
Expand All @@ -2357,13 +2376,14 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
try
{
var authParamsBuilder = new SqlAuthenticationParameters.Builder(
authenticationMethod: ConnectionOptions.Authentication,
resource: fedAuthInfo.spn,
authority: fedAuthInfo.stsurl,
serverName: ConnectionOptions.DataSource,
databaseName: ConnectionOptions.InitialCatalog)
authenticationMethod: ConnectionOptions.Authentication,
resource: fedAuthInfo.spn,
authority: fedAuthInfo.stsurl,
serverName: ConnectionOptions.DataSource,
databaseName: ConnectionOptions.InitialCatalog)
.WithConnectionId(_clientConnectionId)
.WithConnectionTimeout(ConnectionOptions.ConnectTimeout);

switch (ConnectionOptions.Authentication)
{
case SqlAuthenticationMethod.ActiveDirectoryIntegrated:
Expand Down Expand Up @@ -2431,7 +2451,26 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
break;
default:
throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
if (_accessTokenCallback == null)
{
throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
}

if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
authParamsBuilder.WithUserId(ConnectionOptions.UserID);
SqlAuthenticationParameters parameters = authParamsBuilder;
CancellationTokenSource cts = new();
// Use Connection timeout value to cancel token acquire request after certain period of time.
cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
christothes marked this conversation as resolved.
Show resolved Hide resolved
_fedAuthToken = Task.Run(async () => await _accessTokenCallback(new AadTokenRequestContext(parameters.Resource), cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
}

Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null.");
Expand Down Expand Up @@ -2480,25 +2519,43 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with normal MsalExceptions.
catch (MsalException msalException)
{
if (MsalError.UnknownError != msalException.ErrorCode
|| _timeout.IsExpired
|| _timeout.MillisecondsRemaining <= sleepInterval)
if (MsalError.UnknownError != msalException.ErrorCode || _timeout.IsExpired || _timeout.MillisecondsRemaining <= sleepInterval)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MSALException error:> {0}", msalException.ErrorCode);

throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}

SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, remaining {1}[Milliseconds]", ObjectID, _timeout.MillisecondsRemaining);
SqlClientEventSource.Log.TryAdvancedTraceEvent(
"<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]",
ObjectID,
sleepInterval);
SqlClientEventSource.Log.TryAdvancedTraceEvent(
"<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, remaining {1}[Milliseconds]",
ObjectID,
_timeout.MillisecondsRemaining);

Thread.Sleep(sleepInterval);
sleepInterval *= 2;
}
// All other exceptions from MSAL/Azure Identity APIs
catch (Exception e)
{
throw SqlException.CreateException(new() { new(0, (byte)0x00, (byte)TdsEnums.FATAL_ERROR_CLASS, ConnectionOptions.DataSource, e.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0) }, "", this, e);
throw SqlException.CreateException(
new()
{
new(
0,
(byte)0x00,
(byte)TdsEnums.FATAL_ERROR_CLASS,
ConnectionOptions.DataSource,
e.Message,
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName,
0)
},
"",
this,
e);
}
}

Expand Down Expand Up @@ -2603,6 +2660,7 @@ internal void OnFeatureExtAck(int featureId, byte[] data)

switch (_fedAuthFeatureExtensionData.libraryType)
{
case TdsEnums.FedAuthLibrary.SecurityTokenCallback:
case TdsEnums.FedAuthLibrary.MSAL:
case TdsEnums.FedAuthLibrary.SecurityToken:
// The server shouldn't have sent any additional data with the ack (like a nonce)
Expand Down