diff --git a/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props b/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props
index 12a111dec0..762c5f9ed8 100644
--- a/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props
+++ b/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props
@@ -18,7 +18,7 @@
net462
netstandard2.0
- net6.0
+ net6.0
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index 83f3e3a53d..af1b58d64a 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -394,6 +394,9 @@
Microsoft\Data\SqlClient\SqlMetadataFactory.cs
+
+ Microsoft\Data\SqlClient\SqlInternalConnection.cs
+
Microsoft\Data\SqlClient\SqlNotificationEventArgs.cs
@@ -628,7 +631,6 @@
-
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs
index 56e369593a..76710ff980 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs
@@ -17,38 +17,68 @@ namespace Microsoft.Data.SqlClient
// CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching
// DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object
- internal abstract class AAsyncCallContext : IDisposable
+ internal abstract class AAsyncCallContext : AAsyncBaseCallContext
where TOwner : class
+ where TDisposable : IDisposable
{
- protected TOwner _owner;
- protected TaskCompletionSource _source;
- protected IDisposable _disposable;
+ protected TDisposable _disposable;
protected AAsyncCallContext()
{
}
- protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, IDisposable disposable = null)
+ protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, TDisposable disposable = default)
{
Set(owner, source, disposable);
}
- protected void Set(TOwner owner, TaskCompletionSource source, IDisposable disposable = null)
+ protected void Set(TOwner owner, TaskCompletionSource source, TDisposable disposable = default)
+ {
+ base.Set(owner, source);
+ _disposable = disposable;
+ }
+
+ protected override void DisposeCore()
+ {
+ TDisposable copyDisposable = _disposable;
+ _disposable = default;
+ copyDisposable?.Dispose();
+ }
+ }
+
+ internal abstract class AAsyncBaseCallContext
+ {
+ protected TOwner _owner;
+ protected TaskCompletionSource _source;
+ protected bool _isDisposed;
+
+ protected AAsyncBaseCallContext()
+ {
+ }
+
+ protected void Set(TOwner owner, TaskCompletionSource source)
{
_owner = owner ?? throw new ArgumentNullException(nameof(owner));
_source = source ?? throw new ArgumentNullException(nameof(source));
- _disposable = disposable;
+ _isDisposed = false;
}
protected void ClearCore()
{
_source = null;
_owner = default;
- IDisposable copyDisposable = _disposable;
- _disposable = null;
- copyDisposable?.Dispose();
+ try
+ {
+ DisposeCore();
+ }
+ finally
+ {
+ _isDisposed = true;
+ }
}
+ protected abstract void DisposeCore();
+
///
/// override this method to cleanup instance data before ClearCore is called which will blank the base data
///
@@ -65,16 +95,19 @@ protected virtual void AfterCleared(TOwner owner)
public void Dispose()
{
- TOwner owner = _owner;
- try
- {
- Clear();
- }
- finally
+ if (!_isDisposed)
{
- ClearCore();
+ TOwner owner = _owner;
+ try
+ {
+ Clear();
+ }
+ finally
+ {
+ ClearCore();
+ }
+ AfterCleared(owner);
}
- AfterCleared(owner);
}
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
index 6c55dcb541..4cd75ecb4d 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
@@ -46,7 +46,7 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
private static readonly Func s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback;
private static readonly Func s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback;
- internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext
+ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext
{
public Guid OperationID;
public CommandBehavior CommandBehavior;
@@ -54,7 +54,7 @@ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext _owner;
public TaskCompletionSource TaskCompletionSource => _source;
- public void Set(SqlCommand command, TaskCompletionSource source, IDisposable disposable, CommandBehavior behavior, Guid operationID)
+ public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID)
{
base.Set(command, source, disposable);
CommandBehavior = behavior;
@@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
}
}
+ internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext
+ {
+ public Guid OperationID;
+
+ public SqlCommand Command => _owner;
+
+ public TaskCompletionSource TaskCompletionSource => _source;
+
+ public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID)
+ {
+ base.Set(command, source, disposable);
+ OperationID = operationID;
+ }
+
+ protected override void Clear()
+ {
+ OperationID = default;
+ }
+
+ protected override void AfterCleared(SqlCommand owner)
+ {
+
+ }
+ }
+
private CommandType _commandType;
private int? _commandTimeout;
private UpdateRowSource _updatedRowSource = UpdateRowSource.Both;
@@ -2540,23 +2565,37 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok
}
Task returnedTask = source.Task;
+ returnedTask = RegisterForConnectionCloseNotification(returnedTask);
+
+ ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext();
+ context.Set(this, source, registration, operationId);
try
{
- returnedTask = RegisterForConnectionCloseNotification(returnedTask);
+ Task.Factory.FromAsync(
+ static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject),
+ static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result),
+ state: context
+ ).ContinueWith(
+ static (Task task, object state) =>
+ {
+ ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state;
+
+ Guid operationId = context.OperationID;
+ SqlCommand command = context.Command;
+ TaskCompletionSource source = context.TaskCompletionSource;
+
+ context.Dispose();
+ context = null;
- Task.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null)
- .ContinueWith((Task task) =>
- {
- registration.Dispose();
if (task.IsFaulted)
{
Exception e = task.Exception.InnerException;
- s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
+ s_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e);
source.SetException(e);
}
else
{
- s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction);
+ s_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction);
if (task.IsCanceled)
{
source.SetCanceled();
@@ -2567,13 +2606,15 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok
}
}
},
- TaskScheduler.Default
+ state: context,
+ scheduler: TaskScheduler.Default
);
}
catch (Exception e)
{
s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
source.SetException(e);
+ context.Dispose();
}
return returnedTask;
@@ -2648,11 +2689,11 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior,
}
Task returnedTask = source.Task;
+ ExecuteReaderAsyncCallContext context = null;
try
{
returnedTask = RegisterForConnectionCloseNotification(returnedTask);
- ExecuteReaderAsyncCallContext context = null;
if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null);
@@ -2680,6 +2721,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior,
}
source.SetException(e);
+ context.Dispose();
}
return returnedTask;
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
index cc028e44c7..af5e0f1c90 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
@@ -4408,7 +4408,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}
- IDisposable registration = null;
+ CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
@@ -4708,7 +4708,7 @@ out bytesRead
Debug.Assert(context.Source != null, "context._source should not be null when continuing");
// setup for cleanup/completing
retryTask.ContinueWith(
- continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback,
+ continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback,
state: context,
TaskScheduler.Default
);
@@ -4735,6 +4735,13 @@ public override Task ReadAsync(CancellationToken cancellationToken)
return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed()));
}
+ // Register first to catch any already expired tokens to be able to trigger cancellation event.
+ CancellationTokenRegistration registration = default;
+ if (cancellationToken.CanBeCanceled)
+ {
+ registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ }
+
// If user's token is canceled, return a canceled task
if (cancellationToken.IsCancellationRequested)
{
@@ -4833,12 +4840,6 @@ public override Task ReadAsync(CancellationToken cancellationToken)
return source.Task;
}
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
-
ReadAsyncCallContext context = null;
if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
@@ -4849,7 +4850,7 @@ public override Task ReadAsync(CancellationToken cancellationToken)
context = new ReadAsyncCallContext();
}
- Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed");
+ Debug.Assert(context.Reader == default && context.Source == null && context.Disposable == default, "cached ReadAsyncCallContext was not properly disposed");
context.Set(this, source, registration);
context._hasMoreData = more;
@@ -5007,7 +5008,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo
}
// Setup cancellations
- IDisposable registration = null;
+ CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
@@ -5023,7 +5024,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo
context = new IsDBNullAsyncCallContext();
}
- Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed");
+ Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed");
context.Set(this, source, registration);
context._columnIndex = i;
@@ -5154,7 +5155,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat
}
// Setup cancellations
- IDisposable registration = null;
+ CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
@@ -5218,49 +5219,63 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi
}
#endif
-
- internal abstract class SqlDataReaderAsyncCallContext : AAsyncCallContext
+
+ internal abstract class SqlDataReaderBaseAsyncCallContext : AAsyncBaseCallContext
{
internal static readonly Action, object> s_completeCallback = CompleteAsyncCallCallback;
internal static readonly Func> s_executeCallback = ExecuteAsyncCallCallback;
- protected SqlDataReaderAsyncCallContext()
+ protected SqlDataReaderBaseAsyncCallContext()
{
}
- protected SqlDataReaderAsyncCallContext(SqlDataReader owner, TaskCompletionSource source, IDisposable disposable = null)
+ protected SqlDataReaderBaseAsyncCallContext(SqlDataReader owner, TaskCompletionSource source)
{
- Set(owner, source, disposable);
+ Set(owner, source);
}
internal abstract Func> Execute { get; }
internal SqlDataReader Reader { get => _owner; set => _owner = value; }
- public IDisposable Disposable { get => _disposable; set => _disposable = value; }
-
public TaskCompletionSource Source { get => _source; set => _source = value; }
- new public void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable)
- {
- base.Set(reader, source, disposable);
- }
-
private static Task ExecuteAsyncCallCallback(Task task, object state)
{
- SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state;
+ SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state;
return context.Reader.ContinueAsyncCall(task, context);
}
private static void CompleteAsyncCallCallback(Task task, object state)
{
- SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state;
+ SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state;
context.Reader.CompleteAsyncCall(task, context);
}
}
- internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext
+ internal abstract class SqlDataReaderAsyncCallContext : SqlDataReaderBaseAsyncCallContext
+ where TDisposable : IDisposable
+ {
+ private TDisposable _disposable;
+
+ public TDisposable Disposable { get => _disposable; set => _disposable = value; }
+
+ public void Set(SqlDataReader owner, TaskCompletionSource source, TDisposable disposable)
+ {
+ base.Set(owner, source);
+ _disposable = disposable;
+ }
+
+ protected override void DisposeCore()
+ {
+ TDisposable copy = _disposable;
+ _disposable = default;
+ copy.Dispose();
+ }
+ }
+
+ internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext
{
internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute;
@@ -5279,7 +5294,7 @@ protected override void AfterCleared(SqlDataReader owner)
}
}
- internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext
+ internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext
{
internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute;
@@ -5295,19 +5310,19 @@ protected override void AfterCleared(SqlDataReader owner)
}
}
- private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext
+ private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext
{
private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute;
- public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable)
- : base(reader, source, disposable)
+ public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable)
{
+ Set(reader, source, disposable);
}
internal override Func> Execute => s_execute;
}
- private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext
+ private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext
{
internal enum OperationMode
{
@@ -5345,7 +5360,7 @@ protected override void Clear()
}
}
- private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext
+ private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext
{
private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute;
@@ -5353,9 +5368,9 @@ private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallCo
internal GetFieldValueAsyncCallContext() { }
- internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable)
- : base(reader, source, disposable)
+ internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable)
{
+ Set(reader, source, disposable);
}
protected override void Clear()
@@ -5375,7 +5390,7 @@ protected override void Clear()
///
///
///
- private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context)
+ private Task InvokeAsyncCall(SqlDataReaderBaseAsyncCallContext context)
{
TaskCompletionSource source = context.Source;
try
@@ -5397,7 +5412,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context)
else
{
task.ContinueWith(
- continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback,
+ continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback,
state: context,
TaskScheduler.Default
);
@@ -5422,7 +5437,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context)
///
///
///
- private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context)
+ private Task ExecuteAsyncCall(AAsyncBaseCallContext context)
{
// _networkPacketTaskSource could be null if the connection was closed
// while an async invocation was outstanding.
@@ -5435,7 +5450,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context)
else
{
return completionSource.Task.ContinueWith(
- continuationFunction: SqlDataReaderAsyncCallContext.s_executeCallback,
+ continuationFunction: SqlDataReaderBaseAsyncCallContext.s_executeCallback,
state: context,
TaskScheduler.Default
).Unwrap();
@@ -5451,7 +5466,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context)
///
///
///
- private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext context)
+ private Task ContinueAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context)
{
// this function must be an instance function called from the static callback because otherwise a compiler error
// is caused by accessing the _cancelAsyncOnCloseToken field of a MarshalByRefObject derived class
@@ -5511,7 +5526,7 @@ private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext
///
///
///
- private void CompleteAsyncCall(Task task, SqlDataReaderAsyncCallContext context)
+ private void CompleteAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context)
{
TaskCompletionSource source = context.Source;
context.Dispose();
diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
index 5d78fb574e..9cc91a1c38 100644
--- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
@@ -1043,6 +1043,10 @@ public sealed partial class SqlConnectionStringBuilder : System.Data.Common.DbCo
[System.ComponentModel.DisplayNameAttribute("Encrypt")]
[System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
public SqlConnectionEncryptOption Encrypt { get { throw null; } set { } }
+ ///
+ [System.ComponentModel.DisplayNameAttribute("Host Name In Certificate")]
+ [System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
+ public string HostNameInCertificate { get { throw null; } set { } }
///
[System.ComponentModel.DisplayNameAttribute("Enlist")]
[System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
index 12323365bf..bf6b3f3ff2 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
@@ -482,6 +482,9 @@
Microsoft\Data\SqlClient\SqlInfoMessageEventHandler.cs
+
+ Microsoft\Data\SqlClient\SqlInternalConnection.cs
+
Microsoft\Data\SqlClient\SqlInternalTransaction.cs
@@ -554,12 +557,18 @@
Microsoft\Data\SqlClient\TdsParameterSetter.cs
+
+ Microsoft\Data\SqlClient\TdsParserSafeHandles.Windows.cs
+
Microsoft\Data\SqlClient\TdsParserStaticMethods.cs
Microsoft\Data\SqlClient\TdsRecordBufferSetter.cs
+
+ Microsoft\Data\SqlClient\TdsParserSessionPool.cs
+
Microsoft\Data\SqlClient\TdsValueSetter.cs
@@ -639,7 +648,6 @@
-
@@ -648,8 +656,6 @@
-
-
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
index 09061277e0..c733b7fc8a 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
@@ -329,8 +329,8 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData()
if (null != metaData && 0 < metaData.Length)
{
- metaDataReturn = new SmiExtendedMetaData[metaData.visibleColumns];
-
+ metaDataReturn = new SmiExtendedMetaData[metaData.VisibleColumnCount];
+ int returnIndex = 0;
for (int index = 0; index < metaData.Length; index++)
{
_SqlMetaData colMetaData = metaData[index];
@@ -369,7 +369,7 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData()
length /= ADP.CharSize;
}
- metaDataReturn[index] = new SmiQueryMetaData(
+ metaDataReturn[returnIndex] = new SmiQueryMetaData(
colMetaData.type,
length,
colMetaData.precision,
@@ -397,6 +397,7 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData()
colMetaData.IsDifferentName,
colMetaData.IsHidden
);
+ returnIndex += 1;
}
}
}
@@ -458,7 +459,7 @@ override public int VisibleFieldCount
{
return 0;
}
- return (md.visibleColumns);
+ return md.VisibleColumnCount;
}
}
@@ -1352,31 +1353,6 @@ private bool TryConsumeMetaData()
Debug.Assert(!ignored, "Parser read a row token while trying to read metadata");
}
- // we hide hidden columns from the user so build an internal map
- // that compacts all hidden columns from the array
- if (null != _metaData)
- {
-
- if (_snapshot != null && object.ReferenceEquals(_snapshot._metadata, _metaData))
- {
- _metaData = (_SqlMetaDataSet)_metaData.Clone();
- }
-
- _metaData.visibleColumns = 0;
-
- Debug.Assert(null == _metaData.indexMap, "non-null metaData indexmap");
- int[] indexMap = new int[_metaData.Length];
- for (int i = 0; i < indexMap.Length; ++i)
- {
- indexMap[i] = _metaData.visibleColumns;
-
- if (!(_metaData[i].IsHidden))
- {
- _metaData.visibleColumns++;
- }
- }
- _metaData.indexMap = indexMap;
- }
return true;
}
@@ -1690,15 +1666,15 @@ override public DataTable GetSchemaTable()
try
{
statistics = SqlStatistics.StartTimer(Statistics);
- if (null == _metaData || null == _metaData.schemaTable)
+ if (null == _metaData || null == _metaData._schemaTable)
{
if (null != this.MetaData)
{
- _metaData.schemaTable = BuildSchemaTable();
- Debug.Assert(null != _metaData.schemaTable, "No schema information yet!");
+ _metaData._schemaTable = BuildSchemaTable();
+ Debug.Assert(null != _metaData._schemaTable, "No schema information yet!");
}
}
- return _metaData?.schemaTable;
+ return _metaData?._schemaTable;
}
finally
{
@@ -2994,11 +2970,11 @@ virtual public int GetSqlValues(object[] values)
SetTimeout(_defaultTimeoutMilliseconds);
- int copyLen = (values.Length < _metaData.visibleColumns) ? values.Length : _metaData.visibleColumns;
+ int copyLen = (values.Length < _metaData.VisibleColumnCount) ? values.Length : _metaData.VisibleColumnCount;
for (int i = 0; i < copyLen; i++)
{
- values[_metaData.indexMap[i]] = GetSqlValueInternal(i);
+ values[_metaData.GetVisibleColumnIndex(i)] = GetSqlValueInternal(i);
}
return copyLen;
}
@@ -3398,7 +3374,7 @@ override public int GetValues(object[] values)
CheckMetaDataIsReady();
- int copyLen = (values.Length < _metaData.visibleColumns) ? values.Length : _metaData.visibleColumns;
+ int copyLen = (values.Length < _metaData.VisibleColumnCount) ? values.Length : _metaData.VisibleColumnCount;
int maximumColumn = copyLen - 1;
SetTimeout(_defaultTimeoutMilliseconds);
@@ -3414,12 +3390,19 @@ override public int GetValues(object[] values)
for (int i = 0; i < copyLen; i++)
{
// Get the usable, TypeSystem-compatible value from the iternal buffer
- values[_metaData.indexMap[i]] = GetValueFromSqlBufferInternal(_data[i], _metaData[i]);
+ int fieldIndex = _metaData.GetVisibleColumnIndex(i);
+ values[i] = GetValueFromSqlBufferInternal(_data[fieldIndex], _metaData[fieldIndex]);
// If this is sequential access, then we need to wipe the internal buffer
if ((sequentialAccess) && (i < maximumColumn))
{
_data[i].Clear();
+ if (fieldIndex > i && fieldIndex > 0)
+ {
+ // if we jumped an index forward because of a hidden column see if the buffer before the
+ // current one was populated by the seek forward and clear it if it was
+ _data[fieldIndex - 1].Clear();
+ }
}
}
@@ -4767,7 +4750,7 @@ internal bool TrySetMetaData(_SqlMetaDataSet metaData, bool moreInfo)
_tableNames = null;
if (_metaData != null)
{
- _metaData.schemaTable = null;
+ _metaData._schemaTable = null;
_data = SqlBuffer.CreateBufferArray(metaData.Length);
}
@@ -5326,6 +5309,13 @@ public override Task ReadAsync(CancellationToken cancellationToken)
return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("ReadAsync")));
}
+ // Register first to catch any already expired tokens to be able to trigger cancellation event.
+ IDisposable registration = null;
+ if (cancellationToken.CanBeCanceled)
+ {
+ registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ }
+
// If user's token is canceled, return a canceled task
if (cancellationToken.IsCancellationRequested)
{
@@ -5425,12 +5415,6 @@ public override Task ReadAsync(CancellationToken cancellationToken)
return source.Task;
}
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
-
var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext();
Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed");
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
index 6daca4d771..1edad799ae 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -523,7 +523,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
// Clean up IsSQLDNSCachingSupported flag from previous status
_connHandler.IsSQLDNSCachingSupported = false;
- UInt32 sniStatus = SNILoadHandle.SingletonInstance.SNIStatus;
+ UInt32 sniStatus = SNILoadHandle.SingletonInstance.Status;
if (sniStatus != TdsEnums.SNI_SUCCESS)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
@@ -5094,7 +5094,6 @@ internal bool TryProcessAltMetaData(int cColumns, TdsParserStateObject stateObj,
metaData = null;
_SqlMetaDataSet altMetaDataSet = new _SqlMetaDataSet(cColumns, null);
- int[] indexMap = new int[cColumns];
if (!stateObj.TryReadUInt16(out altMetaDataSet.id))
{
@@ -5191,12 +5190,8 @@ internal bool TryProcessAltMetaData(int cColumns, TdsParserStateObject stateObj,
break;
}
}
- indexMap[i] = i;
}
- altMetaDataSet.indexMap = indexMap;
- altMetaDataSet.visibleColumns = cColumns;
-
metaData = altMetaDataSet;
return true;
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs
index 21004f4be2..bf113efe3b 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs
@@ -511,51 +511,63 @@ public object Clone()
}
}
- sealed internal class _SqlMetaDataSet : ICloneable
+ sealed internal class _SqlMetaDataSet
{
internal ushort id; // for altrow-columns only
- internal int[] indexMap;
- internal int visibleColumns;
- internal DataTable schemaTable;
+ internal DataTable _schemaTable;
internal readonly SqlTceCipherInfoTable cekTable; // table of "column encryption keys" used for this metadataset
- internal readonly _SqlMetaData[] metaDataArray;
+ internal readonly _SqlMetaData[] _metaDataArray;
+ private int _hiddenColumnCount;
+ private int[] _visibleColumnMap;
internal _SqlMetaDataSet(int count, SqlTceCipherInfoTable cipherTable)
{
+ _hiddenColumnCount = -1;
cekTable = cipherTable;
- metaDataArray = new _SqlMetaData[count];
- for (int i = 0; i < metaDataArray.Length; ++i)
+ _metaDataArray = new _SqlMetaData[count];
+ for (int i = 0; i < _metaDataArray.Length; ++i)
{
- metaDataArray[i] = new _SqlMetaData(i);
+ _metaDataArray[i] = new _SqlMetaData(i);
}
}
private _SqlMetaDataSet(_SqlMetaDataSet original)
{
- this.id = original.id;
- // although indexMap is not immutable, in practice it is initialized once and then passed around
- this.indexMap = original.indexMap;
- this.visibleColumns = original.visibleColumns;
- this.schemaTable = original.schemaTable;
- if (original.metaDataArray == null)
+ id = original.id;
+ _hiddenColumnCount = original._hiddenColumnCount;
+ _visibleColumnMap = original._visibleColumnMap;
+ _schemaTable = original._schemaTable;
+ if (original._metaDataArray == null)
{
- metaDataArray = null;
+ _metaDataArray = null;
}
else
{
- metaDataArray = new _SqlMetaData[original.metaDataArray.Length];
- for (int idx = 0; idx < metaDataArray.Length; idx++)
+ _metaDataArray = new _SqlMetaData[original._metaDataArray.Length];
+ for (int idx = 0; idx < _metaDataArray.Length; idx++)
{
- metaDataArray[idx] = (_SqlMetaData)original.metaDataArray[idx].Clone();
+ _metaDataArray[idx] = (_SqlMetaData)original._metaDataArray[idx].Clone();
}
}
}
+ internal int VisibleColumnCount
+ {
+ get
+ {
+ if (_hiddenColumnCount == -1)
+ {
+ SetupHiddenColumns();
+ }
+ return Length - _hiddenColumnCount;
+ }
+ }
+
internal int Length
{
get
{
- return metaDataArray.Length;
+ return _metaDataArray.Length;
}
}
@@ -563,21 +575,66 @@ internal int Length
{
get
{
- return metaDataArray[index];
+ return _metaDataArray[index];
}
set
{
Debug.Assert(null == value, "used only by SqlBulkCopy");
- metaDataArray[index] = value;
+ _metaDataArray[index] = value;
}
}
- public object Clone()
+ public int GetVisibleColumnIndex(int index)
+ {
+ if (_hiddenColumnCount == -1)
+ {
+ SetupHiddenColumns();
+ }
+ if (_visibleColumnMap is null)
+ {
+ return index;
+ }
+ else
+ {
+ return _visibleColumnMap[index];
+ }
+ }
+
+ public _SqlMetaDataSet Clone()
{
return new _SqlMetaDataSet(this);
}
+
+ private void SetupHiddenColumns()
+ {
+ int hiddenColumnCount = 0;
+ for (int index = 0; index < Length; index++)
+ {
+ if (_metaDataArray[index].IsHidden)
+ {
+ hiddenColumnCount += 1;
+ }
+ }
+
+ if (hiddenColumnCount > 0)
+ {
+ int[] visibleColumnMap = new int[Length - hiddenColumnCount];
+ int mapIndex = 0;
+ for (int metaDataIndex = 0; metaDataIndex < Length; metaDataIndex++)
+ {
+ if (!_metaDataArray[metaDataIndex].IsHidden)
+ {
+ visibleColumnMap[mapIndex] = metaDataIndex;
+ mapIndex += 1;
+ }
+ }
+ _visibleColumnMap = visibleColumnMap;
+ }
+ _hiddenColumnCount = hiddenColumnCount;
+ }
}
+
sealed internal class _SqlMetaDataSetCollection : ICloneable
{
private readonly List<_SqlMetaDataSet> altMetaDataSetArray;
@@ -622,10 +679,10 @@ internal _SqlMetaDataSet GetAltMetaData(int id)
public object Clone()
{
_SqlMetaDataSetCollection result = new _SqlMetaDataSetCollection();
- result.metaDataSet = metaDataSet == null ? null : (_SqlMetaDataSet)metaDataSet.Clone();
+ result.metaDataSet = metaDataSet == null ? null : metaDataSet.Clone();
foreach (_SqlMetaDataSet set in altMetaDataSetArray)
{
- result.altMetaDataSetArray.Add((_SqlMetaDataSet)set.Clone());
+ result.altMetaDataSetArray.Add(set.Clone());
}
return result;
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
index 6e8afce1ea..8d9057bc02 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
@@ -2401,9 +2401,9 @@ private void OnTimeoutAsync(object state)
}
}
- private bool OnTimeoutSync()
+ private bool OnTimeoutSync(bool asyncClose = false)
{
- return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync);
+ return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync, asyncClose);
}
///
@@ -2412,8 +2412,9 @@ private bool OnTimeoutSync()
///
/// the state that is the expected current state, state will change only if this is correct
/// the state that will be changed to if the expected state is correct
+ /// any close action to be taken by an async task to avoid deadlock.
/// boolean value indicating whether the call changed the timeout state
- private bool OnTimeoutCore(int expectedState, int targetState)
+ private bool OnTimeoutCore(int expectedState, int targetState, bool asyncClose = false)
{
Debug.Assert(targetState == TimeoutState.ExpiredAsync || targetState == TimeoutState.ExpiredSync, "OnTimeoutCore must have an expiry state as the targetState");
@@ -2447,7 +2448,7 @@ private bool OnTimeoutCore(int expectedState, int targetState)
{
try
{
- SendAttention(mustTakeWriteLock: true);
+ SendAttention(mustTakeWriteLock: true, asyncClose);
}
catch (Exception e)
{
@@ -2988,7 +2989,7 @@ public void ReadAsyncCallback(IntPtr key, IntPtr packet, UInt32 error)
// synchrnously and then call OnTimeoutSync to force an atomic change of state.
if (TimeoutHasExpired)
{
- OnTimeoutSync();
+ OnTimeoutSync(asyncClose: true);
}
// try to change to the stopped state but only do so if currently in the running state
@@ -3475,7 +3476,7 @@ private void CancelWritePacket()
#pragma warning disable 420 // a reference to a volatile field will not be treated as volatile
- private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniError, bool canAccumulate, bool callerHasConnectionLock)
+ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniError, bool canAccumulate, bool callerHasConnectionLock, bool asyncClose = false)
{
// Check for a stored exception
var delayedException = Interlocked.Exchange(ref _delayedWriteAsyncCallbackException, null);
@@ -3566,7 +3567,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr
SqlClientEventSource.Log.TryTraceEvent(" write async returned error code {0}", (int)error);
AddError(_parser.ProcessSNIError(this));
- ThrowExceptionAndWarning();
+ ThrowExceptionAndWarning(false, asyncClose);
}
AssertValidState();
completion.SetResult(null);
@@ -3603,7 +3604,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr
{
SqlClientEventSource.Log.TryTraceEvent(" write async returned error code {0}", (int)sniError);
AddError(_parser.ProcessSNIError(this));
- ThrowExceptionAndWarning(callerHasConnectionLock);
+ ThrowExceptionAndWarning(callerHasConnectionLock, false);
}
AssertValidState();
}
@@ -3613,7 +3614,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr
#pragma warning restore 420
// Sends an attention signal - executing thread will consume attn.
- internal void SendAttention(bool mustTakeWriteLock = false)
+ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = false)
{
if (!_attentionSent)
{
@@ -3660,7 +3661,7 @@ internal void SendAttention(bool mustTakeWriteLock = false)
UInt32 sniError;
_parser._asyncWrite = false; // stop async write
- SNIWritePacket(Handle, attnPacket, out sniError, canAccumulate: false, callerHasConnectionLock: false);
+ SNIWritePacket(Handle, attnPacket, out sniError, canAccumulate: false, callerHasConnectionLock: false, asyncClose);
SqlClientEventSource.Log.TryTraceEvent(" Send Attention ASync.", "Info");
}
finally
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs
new file mode 100644
index 0000000000..ef6b9b6cd2
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs
@@ -0,0 +1,66 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Globalization;
+
+namespace Microsoft.Data.Common
+{
+ ///
+ /// This class defines the data structure for ActivityId used for correlated tracing between client (bid trace event) and server (XEvent).
+ /// It also includes all the APIs used to access the ActivityId. Note: ActivityId is thread based which is stored in TLS.
+ ///
+
+ internal static class ActivityCorrelator
+ {
+ internal sealed class ActivityId
+ {
+ internal readonly Guid Id;
+ internal readonly uint Sequence;
+
+ internal ActivityId(uint sequence)
+ {
+ this.Id = Guid.NewGuid();
+ this.Sequence = sequence;
+ }
+
+ public override string ToString()
+ {
+ return string.Format(CultureInfo.InvariantCulture, "{0}:{1}", this.Id, this.Sequence);
+ }
+ }
+
+ // Declare the ActivityId which will be stored in TLS. The Id is unique for each thread.
+ // The Sequence number will be incremented when each event happens.
+ // Correlation along threads is consistent with the current XEvent mechanism at server.
+ [ThreadStatic]
+ private static ActivityId t_tlsActivity;
+
+ ///
+ /// Get the current ActivityId
+ ///
+ internal static ActivityId Current
+ {
+ get
+ {
+ if (t_tlsActivity == null)
+ {
+ t_tlsActivity = new ActivityId(1);
+ }
+ return t_tlsActivity;
+ }
+ }
+
+ ///
+ /// Increment the sequence number and generate the new ActivityId
+ ///
+ /// ActivityId
+ internal static ActivityId Next()
+ {
+ t_tlsActivity = new ActivityId( (t_tlsActivity?.Sequence ?? 0) + 1);
+
+ return t_tlsActivity;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs
new file mode 100644
index 0000000000..8b84feecfc
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Data.Common
+{
+ ///
+ /// The class ADP defines the exceptions that are specific to the Adapters.
+ /// The class contains functions that take the proper informational variables and then construct
+ /// the appropriate exception with an error string obtained from the resource framework.
+ /// The exception is then returned to the caller, so that the caller may then throw from its
+ /// location so that the catcher of the exception will have the appropriate call stack.
+ /// This class is used so that there will be compile time checking of error messages.
+ /// The resource Framework.txt will ensure proper string text based on the appropriate locale.
+ ///
+ internal static partial class ADP
+ {
+ internal static object LocalMachineRegistryValue(string subkey, string queryvalue)
+ {
+ // No registry in non-Windows environments
+ return null;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs
new file mode 100644
index 0000000000..c9d0f8d91a
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.InteropServices;
+using System.Runtime.Versioning;
+using System.Security;
+using System.Security.Permissions;
+using Microsoft.Win32;
+
+namespace Microsoft.Data.Common
+{
+ ///
+ /// The class ADP defines the exceptions that are specific to the Adapters.
+ /// The class contains functions that take the proper informational variables and then construct
+ /// the appropriate exception with an error string obtained from the resource framework.
+ /// The exception is then returned to the caller, so that the caller may then throw from its
+ /// location so that the catcher of the exception will have the appropriate call stack.
+ /// This class is used so that there will be compile time checking of error messages.
+ /// The resource Framework.txt will ensure proper string text based on the appropriate locale.
+ ///
+ internal static partial class ADP
+ {
+ [ResourceExposure(ResourceScope.Machine)]
+ [ResourceConsumption(ResourceScope.Machine)]
+ internal static object LocalMachineRegistryValue(string subkey, string queryvalue)
+ { // MDAC 77697
+ (new RegistryPermission(RegistryPermissionAccess.Read, "HKEY_LOCAL_MACHINE\\" + subkey)).Assert(); // MDAC 62028
+ try
+ {
+ using (RegistryKey key = Registry.LocalMachine.OpenSubKey(subkey, false))
+ {
+ return key?.GetValue(queryvalue);
+ }
+ }
+ catch (SecurityException e)
+ {
+ // Even though we assert permission - it's possible there are
+ // ACL's on registry that cause SecurityException to be thrown.
+ ADP.TraceExceptionWithoutRethrow(e);
+ return null;
+ }
+ finally
+ {
+ RegistryPermission.RevertAssert();
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs
new file mode 100644
index 0000000000..1866aa7fb3
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs
@@ -0,0 +1,1570 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Data;
+using System.Data.Common;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.Globalization;
+using System.IO;
+using System.Runtime.CompilerServices;
+
+using System.Security;
+using System.Security.Permissions;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Transactions;
+using Microsoft.Data.SqlClient;
+using IsolationLevel = System.Data.IsolationLevel;
+using Microsoft.Identity.Client;
+using Microsoft.SqlServer.Server;
+
+#if NETFRAMEWORK
+using Microsoft.Win32;
+using System.Reflection;
+using System.Runtime.ConstrainedExecution;
+using System.Runtime.InteropServices;
+using System.Runtime.Versioning;
+#endif
+
+namespace Microsoft.Data.Common
+{
+ ///
+ /// The class ADP defines the exceptions that are specific to the Adapters.
+ /// The class contains functions that take the proper informational variables and then construct
+ /// the appropriate exception with an error string obtained from the resource framework.
+ /// The exception is then returned to the caller, so that the caller may then throw from its
+ /// location so that the catcher of the exception will have the appropriate call stack.
+ /// This class is used so that there will be compile time checking of error messages.
+ /// The resource Framework.txt will ensure proper string text based on the appropriate locale.
+ ///
+ internal static partial class ADP
+ {
+ // NOTE: Initializing a Task in SQL CLR requires the "UNSAFE" permission set (http://msdn.microsoft.com/en-us/library/ms172338.aspx)
+ // Therefore we are lazily initializing these Tasks to avoid forcing customers to use the "UNSAFE" set when they are actually using no Async features
+ private static Task s_trueTask;
+ internal static Task TrueTask => s_trueTask ??= Task.FromResult(true);
+
+ private static Task s_falseTask;
+ internal static Task FalseTask => s_falseTask ??= Task.FromResult(false);
+
+ internal const CompareOptions DefaultCompareOptions = CompareOptions.IgnoreKanaType | CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase;
+ internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout;
+ ///
+ /// Infinite connection timeout identifier in seconds
+ ///
+ internal const int InfiniteConnectionTimeout = 0;
+ ///
+ /// Max duration for buffer in seconds
+ ///
+ internal const int MaxBufferAccessTokenExpiry = 600;
+
+ #region UDT
+#if NETFRAMEWORK
+ private static readonly MethodInfo s_method = typeof(InvalidUdtException).GetMethod("Create", BindingFlags.NonPublic | BindingFlags.Static);
+#endif
+ ///
+ /// Calls "InvalidUdtException.Create" method when an invalid UDT occurs.
+ ///
+ internal static InvalidUdtException CreateInvalidUdtException(Type udtType, string resourceReasonName)
+ {
+ InvalidUdtException e =
+#if NETFRAMEWORK
+ (InvalidUdtException)s_method.Invoke(null, new object[] { udtType, resourceReasonName });
+ ADP.TraceExceptionAsReturnValue(e);
+#else
+ InvalidUdtException.Create(udtType, resourceReasonName);
+#endif
+ return e;
+ }
+ #endregion
+
+ static private void TraceException(string trace, Exception e)
+ {
+ Debug.Assert(null != e, "TraceException: null Exception");
+ if (e is not null)
+ {
+ SqlClientEventSource.Log.TryTraceEvent(trace, e);
+ }
+ }
+
+ internal static void TraceExceptionAsReturnValue(Exception e)
+ {
+ TraceException(" '{0}'", e);
+ }
+
+ internal static void TraceExceptionWithoutRethrow(Exception e)
+ {
+ Debug.Assert(IsCatchableExceptionType(e), "Invalid exception type, should have been re-thrown!");
+ TraceException(" '{0}'", e);
+ }
+
+ internal static bool IsEmptyArray(string[] array) => (array is null) || (array.Length == 0);
+
+ internal static bool IsNull(object value)
+ {
+ if ((value is null) || (DBNull.Value == value))
+ {
+ return true;
+ }
+ INullable nullable = (value as INullable);
+ return ((nullable is not null) && nullable.IsNull);
+ }
+
+ internal static Exception ExceptionWithStackTrace(Exception e)
+ {
+ try
+ {
+ throw e;
+ }
+ catch (Exception caught)
+ {
+ return caught;
+ }
+ }
+
+#region COM+ exceptions
+ internal static ArgumentException Argument(string error)
+ {
+ ArgumentException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException Argument(string error, Exception inner)
+ {
+ ArgumentException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException Argument(string error, string parameter)
+ {
+ ArgumentException e = new(error, parameter);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentNullException ArgumentNull(string parameter)
+ {
+ ArgumentNullException e = new(parameter);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentNullException ArgumentNull(string parameter, string error)
+ {
+ ArgumentNullException e = new(parameter, error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentOutOfRangeException ArgumentOutOfRange(string parameterName)
+ {
+ ArgumentOutOfRangeException e = new(parameterName);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentOutOfRangeException ArgumentOutOfRange(string message, string parameterName)
+ {
+ ArgumentOutOfRangeException e = new(parameterName, message);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static IndexOutOfRangeException IndexOutOfRange(string error)
+ {
+ IndexOutOfRangeException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static IndexOutOfRangeException IndexOutOfRange(int value)
+ {
+ IndexOutOfRangeException e = new(value.ToString(CultureInfo.InvariantCulture));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static IndexOutOfRangeException IndexOutOfRange()
+ {
+ IndexOutOfRangeException e = new();
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidOperationException InvalidOperation(string error, Exception inner)
+ {
+ InvalidOperationException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static OverflowException Overflow(string error) => Overflow(error, null);
+
+ internal static OverflowException Overflow(string error, Exception inner)
+ {
+ OverflowException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static TimeoutException TimeoutException(string error, Exception inner = null)
+ {
+ TimeoutException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static TypeLoadException TypeLoad(string error)
+ {
+ TypeLoadException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidCastException InvalidCast()
+ {
+ InvalidCastException e = new();
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidCastException InvalidCast(string error)
+ {
+ return InvalidCast(error, null);
+ }
+
+ internal static InvalidCastException InvalidCast(string error, Exception inner)
+ {
+ InvalidCastException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidOperationException InvalidOperation(string error)
+ {
+ InvalidOperationException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static IOException IO(string error)
+ {
+ IOException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+ internal static IOException IO(string error, Exception inner)
+ {
+ IOException e = new(error, inner);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static NotSupportedException NotSupported()
+ {
+ NotSupportedException e = new();
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static NotSupportedException NotSupported(string error)
+ {
+ NotSupportedException e = new(error);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidOperationException DataAdapter(string error) => InvalidOperation(error);
+
+ private static InvalidOperationException Provider(string error) => InvalidOperation(error);
+
+ internal static ArgumentException InvalidMultipartName(string property, string value)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartName, StringsHelper.GetString(property), value));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException InvalidMultipartNameIncorrectUsageOfQuotes(string property, string value)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartNameQuoteUsage, StringsHelper.GetString(property), value));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException InvalidMultipartNameToManyParts(string property, string value, int limit)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartNameToManyParts, StringsHelper.GetString(property), value, limit));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ObjectDisposedException ObjectDisposed(object instance)
+ {
+ ObjectDisposedException e = new(instance.GetType().Name);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static InvalidOperationException MethodCalledTwice(string method)
+ {
+ InvalidOperationException e = new(StringsHelper.GetString(Strings.ADP_CalledTwice, method));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentOutOfRangeException ArgumentOutOfRange(string message, string parameterName, object value)
+ {
+ ArgumentOutOfRangeException e = new(parameterName, value, message);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+#endregion
+
+#region Helper Functions
+ internal static ArgumentOutOfRangeException NotSupportedEnumerationValue(Type type, string value, string method)
+ => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_NotSupportedEnumerationValue, type.Name, value, method), type.Name);
+
+ internal static void CheckArgumentNull(object value, string parameterName)
+ {
+ if (value is null)
+ {
+ throw ArgumentNull(parameterName);
+ }
+ }
+
+ internal static bool IsCatchableExceptionType(Exception e)
+ {
+ // only StackOverflowException & ThreadAbortException are sealed classes
+ // a 'catchable' exception is defined by what it is not.
+ Debug.Assert(e != null, "Unexpected null exception!");
+ Type type = e.GetType();
+
+ return ((type != typeof(StackOverflowException)) &&
+ (type != typeof(OutOfMemoryException)) &&
+ (type != typeof(ThreadAbortException)) &&
+ (type != typeof(NullReferenceException)) &&
+ (type != typeof(AccessViolationException)) &&
+ !typeof(SecurityException).IsAssignableFrom(type));
+ }
+
+ internal static bool IsCatchableOrSecurityExceptionType(Exception e)
+ {
+ // a 'catchable' exception is defined by what it is not.
+ // since IsCatchableExceptionType defined SecurityException as not 'catchable'
+ // this method will return true for SecurityException has being catchable.
+
+ // the other way to write this method is, but then SecurityException is checked twice
+ // return ((e is SecurityException) || IsCatchableExceptionType(e));
+
+ // only StackOverflowException & ThreadAbortException are sealed classes
+ Debug.Assert(e != null, "Unexpected null exception!");
+ Type type = e.GetType();
+
+ return ((type != typeof(StackOverflowException)) &&
+ (type != typeof(OutOfMemoryException)) &&
+ (type != typeof(ThreadAbortException)) &&
+ (type != typeof(NullReferenceException)) &&
+ (type != typeof(AccessViolationException)));
+ }
+
+ // Invalid Enumeration
+ internal static ArgumentOutOfRangeException InvalidEnumerationValue(Type type, int value)
+ => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidEnumerationValue, type.Name, value.ToString(CultureInfo.InvariantCulture)), type.Name);
+
+ internal static ArgumentOutOfRangeException InvalidCommandBehavior(CommandBehavior value)
+ {
+ Debug.Assert((0 > (int)value) || ((int)value > 0x3F), "valid CommandType " + value.ToString());
+
+ return InvalidEnumerationValue(typeof(CommandBehavior), (int)value);
+ }
+
+ internal static void ValidateCommandBehavior(CommandBehavior value)
+ {
+ if (((int)value < 0) || (0x3F < (int)value))
+ {
+ throw InvalidCommandBehavior(value);
+ }
+ }
+
+ internal static ArgumentOutOfRangeException InvalidUserDefinedTypeSerializationFormat(Format value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case Format.Unknown:
+ case Format.Native:
+ case Format.UserDefined:
+ Debug.Assert(false, "valid UserDefinedTypeSerializationFormat " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(Format), (int)value);
+ }
+
+ internal static ArgumentOutOfRangeException NotSupportedUserDefinedTypeSerializationFormat(Format value, string method)
+ => NotSupportedEnumerationValue(typeof(Format), value.ToString(), method);
+
+ internal static ArgumentException InvalidArgumentLength(string argumentName, int limit)
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidArgumentLength, argumentName, limit));
+
+ internal static ArgumentException MustBeReadOnly(string argumentName) => Argument(StringsHelper.GetString(Strings.ADP_MustBeReadOnly, argumentName));
+
+ internal static Exception CreateSqlException(MsalException msalException, SqlConnectionString connectionOptions, SqlInternalConnectionTds sender, string username)
+ {
+ // Error[0]
+ SqlErrorCollection sqlErs = new();
+
+ sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
+ connectionOptions.DataSource,
+ StringsHelper.GetString(Strings.SQL_MSALFailure, username, connectionOptions.Authentication.ToString("G")),
+ ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
+
+ // Error[1]
+ string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
+ sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
+ connectionOptions.DataSource, errorMessage1,
+ ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
+
+ // Error[2]
+ if (!string.IsNullOrEmpty(msalException.Message))
+ {
+ sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
+ connectionOptions.DataSource, msalException.Message,
+ ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
+ }
+ return SqlException.CreateException(sqlErs, "", sender);
+ }
+
+#endregion
+
+#region CommandBuilder, Command, BulkCopy
+ ///
+ /// This allows the caller to determine if it is an error or not for the quotedString to not be quoted
+ ///
+ /// The return value is true if the string was quoted and false if it was not
+ internal static bool RemoveStringQuotes(string quotePrefix, string quoteSuffix, string quotedString, out string unquotedString)
+ {
+ int prefixLength = quotePrefix is null ? 0 : quotePrefix.Length;
+ int suffixLength = quoteSuffix is null ? 0 : quoteSuffix.Length;
+
+ if ((suffixLength + prefixLength) == 0)
+ {
+ unquotedString = quotedString;
+ return true;
+ }
+
+ if (quotedString is null)
+ {
+ unquotedString = quotedString;
+ return false;
+ }
+
+ int quotedStringLength = quotedString.Length;
+
+ // is the source string too short to be quoted
+ if (quotedStringLength < prefixLength + suffixLength)
+ {
+ unquotedString = quotedString;
+ return false;
+ }
+
+ // is the prefix present?
+ if (prefixLength > 0)
+ {
+ if (!quotedString.StartsWith(quotePrefix, StringComparison.Ordinal))
+ {
+ unquotedString = quotedString;
+ return false;
+ }
+ }
+
+ // is the suffix present?
+ if (suffixLength > 0)
+ {
+ if (!quotedString.EndsWith(quoteSuffix, StringComparison.Ordinal))
+ {
+ unquotedString = quotedString;
+ return false;
+ }
+ unquotedString = quotedString.Substring(prefixLength, quotedStringLength - (prefixLength + suffixLength))
+ .Replace(quoteSuffix + quoteSuffix, quoteSuffix);
+ }
+ else
+ {
+ unquotedString = quotedString.Substring(prefixLength, quotedStringLength - prefixLength);
+ }
+ return true;
+ }
+
+ internal static string BuildQuotedString(string quotePrefix, string quoteSuffix, string unQuotedString)
+ {
+ var resultString = new StringBuilder(unQuotedString.Length + quoteSuffix.Length + quoteSuffix.Length);
+ AppendQuotedString(resultString, quotePrefix, quoteSuffix, unQuotedString);
+ return resultString.ToString();
+ }
+
+ internal static string AppendQuotedString(StringBuilder buffer, string quotePrefix, string quoteSuffix, string unQuotedString)
+ {
+ Debug.Assert(buffer is not null, "buffer parameter must be initialized!");
+
+ if (!string.IsNullOrEmpty(quotePrefix))
+ {
+ buffer.Append(quotePrefix);
+ }
+
+ // Assuming that the suffix is escaped by doubling it. i.e. foo"bar becomes "foo""bar".
+ if (!string.IsNullOrEmpty(quoteSuffix))
+ {
+ int start = buffer.Length;
+ buffer.Append(unQuotedString);
+ buffer.Replace(quoteSuffix, quoteSuffix + quoteSuffix, start, unQuotedString.Length);
+ buffer.Append(quoteSuffix);
+ }
+ else
+ {
+ buffer.Append(unQuotedString);
+ }
+
+ return buffer.ToString();
+ }
+
+ internal static string BuildMultiPartName(string[] strings)
+ {
+ StringBuilder bld = new();
+ // Assume we want to build a full multi-part name with all parts except trimming separators for
+ // leading empty names (null or empty strings, but not whitespace). Separators in the middle
+ // should be added, even if the name part is null/empty, to maintain proper location of the parts.
+ for (int i = 0; i < strings.Length; i++)
+ {
+ if (0 < bld.Length)
+ {
+ bld.Append('.');
+ }
+ if (strings[i] is not null && 0 != strings[i].Length)
+ {
+ bld.Append(BuildQuotedString("[", "]", strings[i]));
+ }
+ }
+ return bld.ToString();
+ }
+
+ // global constant strings
+ internal const string ColumnEncryptionSystemProviderNamePrefix = "MSSQL_";
+ internal const string Command = "Command";
+ internal const string Connection = "Connection";
+ internal const string Parameter = "Parameter";
+ internal const string ParameterName = "ParameterName";
+ internal const string ParameterSetPosition = "set_Position";
+
+ internal const int DefaultCommandTimeout = 30;
+ internal const float FailoverTimeoutStep = 0.08F; // fraction of timeout to use for fast failover connections
+
+ internal const int CharSize = UnicodeEncoding.CharSize;
+
+ internal static Delegate FindBuilder(MulticastDelegate mcd)
+ {
+ foreach (Delegate del in mcd?.GetInvocationList())
+ {
+ if (del.Target is DbCommandBuilder)
+ return del;
+ }
+
+ return null;
+ }
+
+ internal static long TimerCurrent() => DateTime.UtcNow.ToFileTimeUtc();
+
+ internal static long TimerFromSeconds(int seconds)
+ {
+ long result = checked((long)seconds * TimeSpan.TicksPerSecond);
+ return result;
+ }
+
+ internal static long TimerFromMilliseconds(long milliseconds)
+ {
+ long result = checked(milliseconds * TimeSpan.TicksPerMillisecond);
+ return result;
+ }
+
+ internal static bool TimerHasExpired(long timerExpire)
+ {
+ bool result = TimerCurrent() > timerExpire;
+ return result;
+ }
+
+ internal static long TimerRemaining(long timerExpire)
+ {
+ long timerNow = TimerCurrent();
+ long result = checked(timerExpire - timerNow);
+ return result;
+ }
+
+ internal static long TimerRemainingMilliseconds(long timerExpire)
+ {
+ long result = TimerToMilliseconds(TimerRemaining(timerExpire));
+ return result;
+ }
+
+ internal static long TimerRemainingSeconds(long timerExpire)
+ {
+ long result = TimerToSeconds(TimerRemaining(timerExpire));
+ return result;
+ }
+
+ internal static long TimerToMilliseconds(long timerValue)
+ {
+ long result = timerValue / TimeSpan.TicksPerMillisecond;
+ return result;
+ }
+
+ private static long TimerToSeconds(long timerValue)
+ {
+ long result = timerValue / TimeSpan.TicksPerSecond;
+ return result;
+ }
+
+ ///
+ /// Note: In Longhorn you'll be able to rename a machine without
+ /// rebooting. Therefore, don't cache this machine name.
+ ///
+ [EnvironmentPermission(SecurityAction.Assert, Read = "COMPUTERNAME")]
+ internal static string MachineName() => Environment.MachineName;
+
+ internal static Transaction GetCurrentTransaction()
+ {
+ Transaction transaction = Transaction.Current;
+ return transaction;
+ }
+
+ internal static bool IsDirection(DbParameter value, ParameterDirection condition)
+ {
+#if DEBUG
+ switch (condition)
+ { // @perfnote: Enum.IsDefined
+ case ParameterDirection.Input:
+ case ParameterDirection.Output:
+ case ParameterDirection.InputOutput:
+ case ParameterDirection.ReturnValue:
+ break;
+ default:
+ throw ADP.InvalidParameterDirection(condition);
+ }
+#endif
+ return (condition == (condition & value.Direction));
+ }
+
+ internal static void IsNullOrSqlType(object value, out bool isNull, out bool isSqlType)
+ {
+ if ((value is null) || (value == DBNull.Value))
+ {
+ isNull = true;
+ isSqlType = false;
+ }
+ else
+ {
+ if (value is INullable nullable)
+ {
+ isNull = nullable.IsNull;
+ // Duplicated from DataStorage.cs
+ // For back-compat, SqlXml is not in this list
+ isSqlType = ((value is SqlBinary) ||
+ (value is SqlBoolean) ||
+ (value is SqlByte) ||
+ (value is SqlBytes) ||
+ (value is SqlChars) ||
+ (value is SqlDateTime) ||
+ (value is SqlDecimal) ||
+ (value is SqlDouble) ||
+ (value is SqlGuid) ||
+ (value is SqlInt16) ||
+ (value is SqlInt32) ||
+ (value is SqlInt64) ||
+ (value is SqlMoney) ||
+ (value is SqlSingle) ||
+ (value is SqlString));
+ }
+ else
+ {
+ isNull = false;
+ isSqlType = false;
+ }
+ }
+ }
+
+ private static Version s_systemDataVersion;
+
+ internal static Version GetAssemblyVersion()
+ {
+ // NOTE: Using lazy thread-safety since we don't care if two threads both happen to update the value at the same time
+ if (s_systemDataVersion is null)
+ {
+ s_systemDataVersion = new Version(ThisAssembly.InformationalVersion);
+ }
+
+ return s_systemDataVersion;
+ }
+
+
+ private const string ONDEMAND_PREFIX = "-ondemand";
+ private const string AZURE_SYNAPSE = "-ondemand.sql.azuresynapse.";
+
+ internal static bool IsAzureSynapseOnDemandEndpoint(string dataSource)
+ {
+ return IsEndpoint(dataSource, ONDEMAND_PREFIX) || dataSource.Contains(AZURE_SYNAPSE);
+ }
+
+ internal static readonly string[] s_azureSqlServerEndpoints = { StringsHelper.GetString(Strings.AZURESQL_GenericEndpoint),
+ StringsHelper.GetString(Strings.AZURESQL_GermanEndpoint),
+ StringsHelper.GetString(Strings.AZURESQL_UsGovEndpoint),
+ StringsHelper.GetString(Strings.AZURESQL_ChinaEndpoint)};
+
+ internal static bool IsAzureSqlServerEndpoint(string dataSource)
+ {
+ return IsEndpoint(dataSource, null);
+ }
+
+ // This method assumes dataSource parameter is in TCP connection string format.
+ private static bool IsEndpoint(string dataSource, string prefix)
+ {
+ int length = dataSource.Length;
+ // remove server port
+ int foundIndex = dataSource.LastIndexOf(',');
+ if (foundIndex >= 0)
+ {
+ length = foundIndex;
+ }
+
+ // check for the instance name
+ foundIndex = dataSource.LastIndexOf('\\', length - 1, length - 1);
+ if (foundIndex > 0)
+ {
+ length = foundIndex;
+ }
+
+ // trim trailing whitespace
+ while (length > 0 && char.IsWhiteSpace(dataSource[length - 1]))
+ {
+ length -= 1;
+ }
+
+ // check if servername ends with any endpoints
+ for (int index = 0; index < s_azureSqlServerEndpoints.Length; index++)
+ {
+ string endpoint = string.IsNullOrEmpty(prefix) ? s_azureSqlServerEndpoints[index] : prefix + s_azureSqlServerEndpoints[index];
+ if (length > endpoint.Length)
+ {
+ if (string.Compare(dataSource, length - endpoint.Length, endpoint, 0, endpoint.Length, StringComparison.OrdinalIgnoreCase) == 0)
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ internal static ArgumentException SingleValuedProperty(string propertyName, string value)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_SingleValuedProperty, propertyName, value));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException DoubleValuedProperty(string propertyName, string value1, string value2)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_DoubleValuedProperty, propertyName, value1, value2));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static ArgumentException InvalidPrefixSuffix()
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidPrefixSuffix));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+#endregion
+
+#region DbConnectionOptions, DataAccess
+ internal static ArgumentException ConnectionStringSyntax(int index) => Argument(StringsHelper.GetString(Strings.ADP_ConnectionStringSyntax, index));
+
+ internal static ArgumentException KeywordNotSupported(string keyword) => Argument(StringsHelper.GetString(Strings.ADP_KeywordNotSupported, keyword));
+
+ internal static Exception InvalidConnectionOptionValue(string key) => InvalidConnectionOptionValue(key, null);
+
+ internal static Exception InvalidConnectionOptionValue(string key, Exception inner)
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidConnectionOptionValue, key), inner);
+
+ internal static Exception InvalidConnectionOptionValueLength(string key, int limit)
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidConnectionOptionValueLength, key, limit));
+
+ internal static Exception MissingConnectionOptionValue(string key, string requiredAdditionalKey)
+ => Argument(StringsHelper.GetString(Strings.ADP_MissingConnectionOptionValue, key, requiredAdditionalKey));
+
+ internal static InvalidOperationException InvalidDataDirectory() => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidDataDirectory));
+
+ internal static ArgumentException CollectionRemoveInvalidObject(Type itemType, ICollection collection)
+ => Argument(StringsHelper.GetString(Strings.ADP_CollectionRemoveInvalidObject, itemType.Name, collection.GetType().Name)); // MDAC 68201
+
+ internal static ArgumentNullException CollectionNullValue(string parameter, Type collection, Type itemType)
+ => ArgumentNull(parameter, StringsHelper.GetString(Strings.ADP_CollectionNullValue, collection.Name, itemType.Name));
+
+ internal static IndexOutOfRangeException CollectionIndexInt32(int index, Type collection, int count)
+ => IndexOutOfRange(StringsHelper.GetString(Strings.ADP_CollectionIndexInt32, index.ToString(CultureInfo.InvariantCulture), collection.Name, count.ToString(CultureInfo.InvariantCulture)));
+
+ internal static IndexOutOfRangeException CollectionIndexString(Type itemType, string propertyName, string propertyValue, Type collection)
+ => IndexOutOfRange(StringsHelper.GetString(Strings.ADP_CollectionIndexString, itemType.Name, propertyName, propertyValue, collection.Name));
+
+ internal static InvalidCastException CollectionInvalidType(Type collection, Type itemType, object invalidValue)
+ => InvalidCast(StringsHelper.GetString(Strings.ADP_CollectionInvalidType, collection.Name, itemType.FullName, invalidValue.GetType().FullName));
+
+ internal static ArgumentException ConvertFailed(Type fromType, Type toType, Exception innerException)
+ => ADP.Argument(StringsHelper.GetString(Strings.SqlConvert_ConvertFailed, fromType.FullName, toType.FullName), innerException);
+
+ internal static ArgumentException InvalidMinMaxPoolSizeValues()
+ => ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMinMaxPoolSizeValues));
+#endregion
+
+#region DbConnection
+ private static string ConnectionStateMsg(ConnectionState state)
+ { // MDAC 82165, if the ConnectionState enum to msg the localization looks weird
+ return state switch
+ {
+ (ConnectionState.Closed) or (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed),
+ (ConnectionState.Connecting) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Connecting),
+ (ConnectionState.Open) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Open),
+ (ConnectionState.Open | ConnectionState.Executing) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenExecuting),
+ (ConnectionState.Open | ConnectionState.Fetching) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenFetching),
+ _ => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg, state.ToString()),
+ };
+ }
+
+ internal static InvalidOperationException NoConnectionString()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_NoConnectionString));
+
+ internal static NotImplementedException MethodNotImplemented([CallerMemberName] string methodName = "")
+ {
+ NotImplementedException e = new(methodName);
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+#endregion
+
+#region Stream
+ internal static Exception StreamClosed([CallerMemberName] string method = "") => InvalidOperation(StringsHelper.GetString(Strings.ADP_StreamClosed, method));
+
+ static internal Exception InvalidSeekOrigin(string parameterName) => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidSeekOrigin), parameterName);
+
+ internal static IOException ErrorReadingFromStream(Exception internalException) => IO(StringsHelper.GetString(Strings.SqlMisc_StreamErrorMessage), internalException);
+#endregion
+
+#region Generic Data Provider Collection
+ internal static ArgumentException ParametersIsNotParent(Type parameterType, ICollection collection)
+ => Argument(StringsHelper.GetString(Strings.ADP_CollectionIsNotParent, parameterType.Name, collection.GetType().Name));
+
+ internal static ArgumentException ParametersIsParent(Type parameterType, ICollection collection)
+ => Argument(StringsHelper.GetString(Strings.ADP_CollectionIsNotParent, parameterType.Name, collection.GetType().Name));
+#endregion
+
+#region ConnectionUtil
+ internal enum InternalErrorCode
+ {
+ UnpooledObjectHasOwner = 0,
+ UnpooledObjectHasWrongOwner = 1,
+ PushingObjectSecondTime = 2,
+ PooledObjectHasOwner = 3,
+ PooledObjectInPoolMoreThanOnce = 4,
+ CreateObjectReturnedNull = 5,
+ NewObjectCannotBePooled = 6,
+ NonPooledObjectUsedMoreThanOnce = 7,
+ AttemptingToPoolOnRestrictedToken = 8,
+ // ConnectionOptionsInUse = 9,
+ ConvertSidToStringSidWReturnedNull = 10,
+ // UnexpectedTransactedObject = 11,
+ AttemptingToConstructReferenceCollectionOnStaticObject = 12,
+ AttemptingToEnlistTwice = 13,
+ CreateReferenceCollectionReturnedNull = 14,
+ PooledObjectWithoutPool = 15,
+ UnexpectedWaitAnyResult = 16,
+ SynchronousConnectReturnedPending = 17,
+ CompletedConnectReturnedPending = 18,
+
+ NameValuePairNext = 20,
+ InvalidParserState1 = 21,
+ InvalidParserState2 = 22,
+ InvalidParserState3 = 23,
+
+ InvalidBuffer = 30,
+
+ UnimplementedSMIMethod = 40,
+ InvalidSmiCall = 41,
+
+ SqlDependencyObtainProcessDispatcherFailureObjectHandle = 50,
+ SqlDependencyProcessDispatcherFailureCreateInstance = 51,
+ SqlDependencyProcessDispatcherFailureAppDomain = 52,
+ SqlDependencyCommandHashIsNotAssociatedWithNotification = 53,
+
+ UnknownTransactionFailure = 60,
+ }
+
+ internal static Exception InternalError(InternalErrorCode internalError)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InternalProviderError, (int)internalError));
+
+ internal static Exception ClosedConnectionError() => InvalidOperation(StringsHelper.GetString(Strings.ADP_ClosedConnectionError));
+ internal static Exception ConnectionAlreadyOpen(ConnectionState state)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_ConnectionAlreadyOpen, ADP.ConnectionStateMsg(state)));
+
+ internal static Exception TransactionPresent() => InvalidOperation(StringsHelper.GetString(Strings.ADP_TransactionPresent));
+
+ internal static Exception LocalTransactionPresent() => InvalidOperation(StringsHelper.GetString(Strings.ADP_LocalTransactionPresent));
+
+ internal static Exception OpenConnectionPropertySet(string property, ConnectionState state)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenConnectionPropertySet, property, ADP.ConnectionStateMsg(state)));
+
+ internal static Exception EmptyDatabaseName() => Argument(StringsHelper.GetString(Strings.ADP_EmptyDatabaseName));
+
+ internal enum ConnectionError
+ {
+ BeginGetConnectionReturnsNull,
+ GetConnectionReturnsNull,
+ ConnectionOptionsMissing,
+ CouldNotSwitchToClosedPreviouslyOpenedState,
+ }
+
+ internal static Exception InternalConnectionError(ConnectionError internalError)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InternalConnectionError, (int)internalError));
+
+ internal static Exception InvalidConnectRetryCountValue() => Argument(StringsHelper.GetString(Strings.SQLCR_InvalidConnectRetryCountValue));
+
+ internal static Exception InvalidConnectRetryIntervalValue() => Argument(StringsHelper.GetString(Strings.SQLCR_InvalidConnectRetryIntervalValue));
+#endregion
+
+#region DbDataReader
+ internal static Exception DataReaderClosed([CallerMemberName] string method = "")
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_DataReaderClosed, method));
+
+ internal static ArgumentOutOfRangeException InvalidSourceBufferIndex(int maxLen, long srcOffset, string parameterName)
+ => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidSourceBufferIndex,
+ maxLen.ToString(CultureInfo.InvariantCulture),
+ srcOffset.ToString(CultureInfo.InvariantCulture)), parameterName);
+
+ internal static ArgumentOutOfRangeException InvalidDestinationBufferIndex(int maxLen, int dstOffset, string parameterName)
+ => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidDestinationBufferIndex,
+ maxLen.ToString(CultureInfo.InvariantCulture),
+ dstOffset.ToString(CultureInfo.InvariantCulture)), parameterName);
+
+ internal static IndexOutOfRangeException InvalidBufferSizeOrIndex(int numBytes, int bufferIndex)
+ => IndexOutOfRange(StringsHelper.GetString(Strings.SQL_InvalidBufferSizeOrIndex,
+ numBytes.ToString(CultureInfo.InvariantCulture),
+ bufferIndex.ToString(CultureInfo.InvariantCulture)));
+
+ internal static Exception InvalidDataLength(long length)
+ => IndexOutOfRange(StringsHelper.GetString(Strings.SQL_InvalidDataLength, length.ToString(CultureInfo.InvariantCulture)));
+
+ internal static bool CompareInsensitiveInvariant(string strvalue, string strconst)
+ => 0 == CultureInfo.InvariantCulture.CompareInfo.Compare(strvalue, strconst, CompareOptions.IgnoreCase);
+
+ internal static int DstCompare(string strA, string strB) // this is null safe
+ => CultureInfo.CurrentCulture.CompareInfo.Compare(strA, strB, ADP.DefaultCompareOptions);
+
+ internal static void SetCurrentTransaction(Transaction transaction) => Transaction.Current = transaction;
+
+ internal static Exception NonSeqByteAccess(long badIndex, long currIndex, string method)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_NonSeqByteAccess,
+ badIndex.ToString(CultureInfo.InvariantCulture),
+ currIndex.ToString(CultureInfo.InvariantCulture),
+ method));
+
+ internal static Exception NegativeParameter(string parameterName) => InvalidOperation(StringsHelper.GetString(Strings.ADP_NegativeParameter, parameterName));
+
+ internal static Exception InvalidXmlMissingColumn(string collectionName, string columnName)
+ => Argument(StringsHelper.GetString(Strings.MDF_InvalidXmlMissingColumn, collectionName, columnName));
+
+ internal static InvalidOperationException AsyncOperationPending() => InvalidOperation(StringsHelper.GetString(Strings.ADP_PendingAsyncOperation));
+#endregion
+
+#region IDbCommand
+ // IDbCommand.CommandType
+ static internal ArgumentOutOfRangeException InvalidCommandType(CommandType value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case CommandType.Text:
+ case CommandType.StoredProcedure:
+ case CommandType.TableDirect:
+ Debug.Assert(false, "valid CommandType " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(CommandType), (int)value);
+ }
+
+ internal static Exception TooManyRestrictions(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_TooManyRestrictions, collectionName));
+
+ internal static Exception CommandTextRequired(string method)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_CommandTextRequired, method));
+
+ internal static Exception UninitializedParameterSize(int index, Type dataType)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_UninitializedParameterSize, index.ToString(CultureInfo.InvariantCulture), dataType.Name));
+
+ internal static Exception PrepareParameterType(DbCommand cmd)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterType, cmd.GetType().Name));
+
+ internal static Exception PrepareParameterSize(DbCommand cmd)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterSize, cmd.GetType().Name));
+
+ internal static Exception PrepareParameterScale(DbCommand cmd, string type)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterScale, cmd.GetType().Name, type));
+
+ internal static Exception MismatchedAsyncResult(string expectedMethod, string gotMethod)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_MismatchedAsyncResult, expectedMethod, gotMethod));
+
+ // IDataParameter.SourceVersion
+ internal static ArgumentOutOfRangeException InvalidDataRowVersion(DataRowVersion value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case DataRowVersion.Default:
+ case DataRowVersion.Current:
+ case DataRowVersion.Original:
+ case DataRowVersion.Proposed:
+ Debug.Fail($"Invalid DataRowVersion {value}");
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(DataRowVersion), (int)value);
+ }
+
+ internal static ArgumentOutOfRangeException NotSupportedCommandBehavior(CommandBehavior value, string method)
+ => NotSupportedEnumerationValue(typeof(CommandBehavior), value.ToString(), method);
+
+ internal static ArgumentException BadParameterName(string parameterName)
+ {
+ ArgumentException e = new(StringsHelper.GetString(Strings.ADP_BadParameterName, parameterName));
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static Exception DeriveParametersNotSupported(IDbCommand value)
+ => DataAdapter(StringsHelper.GetString(Strings.ADP_DeriveParametersNotSupported, value.GetType().Name, value.CommandType.ToString()));
+
+ internal static Exception NoStoredProcedureExists(string sproc) => InvalidOperation(StringsHelper.GetString(Strings.ADP_NoStoredProcedureExists, sproc));
+#endregion
+
+#region DbMetaDataFactory
+ internal static Exception DataTableDoesNotExist(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_DataTableDoesNotExist, collectionName));
+
+ // IDbCommand.UpdateRowSource
+ internal static ArgumentOutOfRangeException InvalidUpdateRowSource(UpdateRowSource value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case UpdateRowSource.None:
+ case UpdateRowSource.OutputParameters:
+ case UpdateRowSource.FirstReturnedRecord:
+ case UpdateRowSource.Both:
+ Debug.Fail("valid UpdateRowSource " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(UpdateRowSource), (int)value);
+ }
+
+ internal static Exception QueryFailed(string collectionName, Exception e)
+ => InvalidOperation(StringsHelper.GetString(Strings.MDF_QueryFailed, collectionName), e);
+
+ internal static Exception NoColumns() => Argument(StringsHelper.GetString(Strings.MDF_NoColumns));
+
+ internal static InvalidOperationException ConnectionRequired(string method)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_ConnectionRequired, method));
+
+ internal static InvalidOperationException OpenConnectionRequired(string method, ConnectionState state)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenConnectionRequired, method, ADP.ConnectionStateMsg(state)));
+
+ internal static Exception OpenReaderExists(bool marsOn) => OpenReaderExists(null, marsOn);
+
+ internal static Exception OpenReaderExists(Exception e, bool marsOn)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenReaderExists, marsOn ? ADP.Command : ADP.Connection), e);
+
+ internal static Exception InvalidXml() => Argument(StringsHelper.GetString(Strings.MDF_InvalidXml));
+
+ internal static Exception InvalidXmlInvalidValue(string collectionName, string columnName)
+ => Argument(StringsHelper.GetString(Strings.MDF_InvalidXmlInvalidValue, collectionName, columnName));
+
+ internal static Exception CollectionNameIsNotUnique(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_CollectionNameISNotUnique, collectionName));
+
+ internal static Exception UnableToBuildCollection(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_UnableToBuildCollection, collectionName));
+
+ internal static Exception UndefinedCollection(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_UndefinedCollection, collectionName));
+
+ internal static Exception UnsupportedVersion(string collectionName) => Argument(StringsHelper.GetString(Strings.MDF_UnsupportedVersion, collectionName));
+
+ internal static Exception AmbiguousCollectionName(string collectionName)
+ => Argument(StringsHelper.GetString(Strings.MDF_AmbiguousCollectionName, collectionName));
+
+ internal static Exception MissingDataSourceInformationColumn() => Argument(StringsHelper.GetString(Strings.MDF_MissingDataSourceInformationColumn));
+
+ internal static Exception IncorrectNumberOfDataSourceInformationRows()
+ => Argument(StringsHelper.GetString(Strings.MDF_IncorrectNumberOfDataSourceInformationRows));
+
+ internal static Exception MissingRestrictionColumn() => Argument(StringsHelper.GetString(Strings.MDF_MissingRestrictionColumn));
+
+ internal static Exception MissingRestrictionRow() => Argument(StringsHelper.GetString(Strings.MDF_MissingRestrictionRow));
+
+ internal static Exception UndefinedPopulationMechanism(string populationMechanism)
+#if NETFRAMEWORK
+ => Argument(StringsHelper.GetString(Strings.MDF_UndefinedPopulationMechanism, populationMechanism));
+#else
+ => throw new NotImplementedException();
+#endif
+#endregion
+
+#region DbConnectionPool and related
+ internal static Exception PooledOpenTimeout()
+ => ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_PooledOpenTimeout));
+
+ internal static Exception NonPooledOpenTimeout()
+ => ADP.TimeoutException(StringsHelper.GetString(Strings.ADP_NonPooledOpenTimeout));
+#endregion
+
+#region DbProviderException
+ internal static InvalidOperationException TransactionConnectionMismatch()
+ => Provider(StringsHelper.GetString(Strings.ADP_TransactionConnectionMismatch));
+
+ internal static InvalidOperationException TransactionRequired(string method)
+ => Provider(StringsHelper.GetString(Strings.ADP_TransactionRequired, method));
+
+ internal static InvalidOperationException TransactionCompletedButNotDisposed() => Provider(StringsHelper.GetString(Strings.ADP_TransactionCompletedButNotDisposed));
+
+#endregion
+
+#region SqlMetaData, SqlTypes
+ internal static Exception InvalidMetaDataValue() => ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMetaDataValue));
+
+ internal static InvalidOperationException NonSequentialColumnAccess(int badCol, int currCol)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_NonSequentialColumnAccess,
+ badCol.ToString(CultureInfo.InvariantCulture),
+ currCol.ToString(CultureInfo.InvariantCulture)));
+#endregion
+
+#region IDataParameter
+ internal static ArgumentException InvalidDataType(TypeCode typecode) => Argument(StringsHelper.GetString(Strings.ADP_InvalidDataType, typecode.ToString()));
+
+ internal static ArgumentException UnknownDataType(Type dataType) => Argument(StringsHelper.GetString(Strings.ADP_UnknownDataType, dataType.FullName));
+
+ internal static ArgumentException DbTypeNotSupported(DbType type, Type enumtype)
+ => Argument(StringsHelper.GetString(Strings.ADP_DbTypeNotSupported, type.ToString(), enumtype.Name));
+
+ internal static ArgumentException UnknownDataTypeCode(Type dataType, TypeCode typeCode)
+ => Argument(StringsHelper.GetString(Strings.ADP_UnknownDataTypeCode, ((int)typeCode).ToString(CultureInfo.InvariantCulture), dataType.FullName));
+
+ internal static ArgumentException InvalidOffsetValue(int value)
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidOffsetValue, value.ToString(CultureInfo.InvariantCulture)));
+
+ internal static ArgumentException InvalidSizeValue(int value)
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidSizeValue, value.ToString(CultureInfo.InvariantCulture)));
+
+ internal static ArgumentException ParameterValueOutOfRange(decimal value)
+ => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value.ToString((IFormatProvider)null)));
+
+ internal static ArgumentException ParameterValueOutOfRange(SqlDecimal value) => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value.ToString()));
+
+ internal static ArgumentException ParameterValueOutOfRange(string value) => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value));
+
+ internal static ArgumentException VersionDoesNotSupportDataType(string typeName) => Argument(StringsHelper.GetString(Strings.ADP_VersionDoesNotSupportDataType, typeName));
+
+ internal static Exception ParameterConversionFailed(object value, Type destType, Exception inner)
+ {
+ Debug.Assert(null != value, "null value on conversion failure");
+ Debug.Assert(null != inner, "null inner on conversion failure");
+
+ Exception e;
+ string message = StringsHelper.GetString(Strings.ADP_ParameterConversionFailed, value.GetType().Name, destType.Name);
+ if (inner is ArgumentException)
+ {
+ e = new ArgumentException(message, inner);
+ }
+ else if (inner is FormatException)
+ {
+ e = new FormatException(message, inner);
+ }
+ else if (inner is InvalidCastException)
+ {
+ e = new InvalidCastException(message, inner);
+ }
+ else if (inner is OverflowException)
+ {
+ e = new OverflowException(message, inner);
+ }
+ else
+ {
+ e = inner;
+ }
+ TraceExceptionAsReturnValue(e);
+ return e;
+ }
+#endregion
+
+#region IDataParameterCollection
+ internal static Exception ParametersMappingIndex(int index, DbParameterCollection collection) => CollectionIndexInt32(index, collection.GetType(), collection.Count);
+
+ internal static Exception ParametersSourceIndex(string parameterName, DbParameterCollection collection, Type parameterType)
+ => CollectionIndexString(parameterType, ADP.ParameterName, parameterName, collection.GetType());
+
+ internal static Exception ParameterNull(string parameter, DbParameterCollection collection, Type parameterType)
+ => CollectionNullValue(parameter, collection.GetType(), parameterType);
+
+ internal static Exception InvalidParameterType(DbParameterCollection collection, Type parameterType, object invalidValue)
+ => CollectionInvalidType(collection.GetType(), parameterType, invalidValue);
+#endregion
+
+#region IDbTransaction
+ internal static Exception ParallelTransactionsNotSupported(DbConnection obj)
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_ParallelTransactionsNotSupported, obj.GetType().Name));
+
+ internal static Exception TransactionZombied(DbTransaction obj) => InvalidOperation(StringsHelper.GetString(Strings.ADP_TransactionZombied, obj.GetType().Name));
+#endregion
+
+#region DbProviderConfigurationHandler
+ internal static InvalidOperationException InvalidMixedUsageOfSecureAndClearCredential()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureAndClearCredential));
+
+ internal static ArgumentException InvalidMixedArgumentOfSecureAndClearCredential()
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureAndClearCredential));
+
+ internal static InvalidOperationException InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity));
+
+ internal static ArgumentException InvalidMixedArgumentOfSecureCredentialAndIntegratedSecurity()
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity));
+
+ internal static InvalidOperationException InvalidMixedUsageOfAccessTokenAndIntegratedSecurity()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity));
+
+ static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndUserIDPassword()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword));
+
+ static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndAuthentication()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndAuthentication));
+
+ static internal Exception InvalidMixedUsageOfCredentialAndAccessToken()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfCredentialAndAccessToken));
+#endregion
+
+ internal static bool IsEmpty(string str) => string.IsNullOrEmpty(str);
+ internal static readonly IntPtr s_ptrZero = IntPtr.Zero;
+#if NETFRAMEWORK
+#region netfx project only
+ internal static Task CreatedTaskWithException(Exception ex)
+ {
+ TaskCompletionSource completion = new();
+ completion.SetException(ex);
+ return completion.Task;
+ }
+
+ internal static Task CreatedTaskWithCancellation()
+ {
+ TaskCompletionSource completion = new();
+ completion.SetCanceled();
+ return completion.Task;
+ }
+
+ internal static void TraceExceptionForCapture(Exception e)
+ {
+ Debug.Assert(ADP.IsCatchableExceptionType(e), "Invalid exception type, should have been re-thrown!");
+ TraceException(" '{0}'", e);
+ }
+
+ //
+ // Helper Functions
+ //
+ internal static void CheckArgumentLength(string value, string parameterName)
+ {
+ CheckArgumentNull(value, parameterName);
+ if (0 == value.Length)
+ {
+ throw Argument(StringsHelper.GetString(Strings.ADP_EmptyString, parameterName)); // MDAC 94859
+ }
+ }
+
+ // IDbConnection.BeginTransaction, OleDbTransaction.Begin
+ internal static ArgumentOutOfRangeException InvalidIsolationLevel(IsolationLevel value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case IsolationLevel.Unspecified:
+ case IsolationLevel.Chaos:
+ case IsolationLevel.ReadUncommitted:
+ case IsolationLevel.ReadCommitted:
+ case IsolationLevel.RepeatableRead:
+ case IsolationLevel.Serializable:
+ case IsolationLevel.Snapshot:
+ Debug.Assert(false, "valid IsolationLevel " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(IsolationLevel), (int)value);
+ }
+
+ // DBDataPermissionAttribute.KeyRestrictionBehavior
+ internal static ArgumentOutOfRangeException InvalidKeyRestrictionBehavior(KeyRestrictionBehavior value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case KeyRestrictionBehavior.PreventUsage:
+ case KeyRestrictionBehavior.AllowOnly:
+ Debug.Assert(false, "valid KeyRestrictionBehavior " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(KeyRestrictionBehavior), (int)value);
+ }
+
+ // IDataParameter.Direction
+ internal static ArgumentOutOfRangeException InvalidParameterDirection(ParameterDirection value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case ParameterDirection.Input:
+ case ParameterDirection.Output:
+ case ParameterDirection.InputOutput:
+ case ParameterDirection.ReturnValue:
+ Debug.Assert(false, "valid ParameterDirection " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(ParameterDirection), (int)value);
+ }
+
+ //
+ // DbConnectionOptions, DataAccess
+ //
+ internal static ArgumentException InvalidKeyname(string parameterName)
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_InvalidKey), parameterName);
+ }
+ internal static ArgumentException InvalidValue(string parameterName)
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_InvalidValue), parameterName);
+ }
+ internal static ArgumentException InvalidMixedArgumentOfSecureCredentialAndContextConnection()
+ {
+ return ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndContextConnection));
+ }
+ internal static InvalidOperationException InvalidMixedUsageOfAccessTokenAndContextConnection()
+ {
+ return ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndContextConnection));
+ }
+ internal static Exception InvalidMixedUsageOfAccessTokenAndCredential()
+ {
+ return ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndCredential));
+ }
+
+ //
+ // DBDataPermission, DataAccess, Odbc
+ //
+ internal static Exception InvalidXMLBadVersion()
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_InvalidXMLBadVersion));
+ }
+ internal static Exception NotAPermissionElement()
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_NotAPermissionElement));
+ }
+ internal static Exception PermissionTypeMismatch()
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_PermissionTypeMismatch));
+ }
+
+ //
+ // DbDataReader
+ //
+ internal static Exception NumericToDecimalOverflow()
+ {
+ return InvalidCast(StringsHelper.GetString(Strings.ADP_NumericToDecimalOverflow));
+ }
+
+ //
+ // : IDbCommand
+ //
+ internal static Exception InvalidCommandTimeout(int value, string name)
+ {
+ return Argument(StringsHelper.GetString(Strings.ADP_InvalidCommandTimeout, value.ToString(CultureInfo.InvariantCulture)), name);
+ }
+
+ //
+ // : DbDataAdapter
+ //
+ internal static InvalidOperationException ComputerNameEx(int lastError)
+ {
+ return InvalidOperation(StringsHelper.GetString(Strings.ADP_ComputerNameEx, lastError));
+ }
+
+ // global constant strings
+ internal const float FailoverTimeoutStepForTnir = 0.125F; // Fraction of timeout to use in case of Transparent Network IP resolution.
+ internal const int MinimumTimeoutForTnirMs = 500; // The first login attempt in Transparent network IP Resolution
+
+ internal static readonly int s_ptrSize = IntPtr.Size;
+ internal static readonly IntPtr s_invalidPtr = new(-1); // use for INVALID_HANDLE
+
+ internal static readonly bool s_isWindowsNT = (PlatformID.Win32NT == Environment.OSVersion.Platform);
+ internal static readonly bool s_isPlatformNT5 = (ADP.s_isWindowsNT && (Environment.OSVersion.Version.Major >= 5));
+
+ [FileIOPermission(SecurityAction.Assert, AllFiles = FileIOPermissionAccess.PathDiscovery)]
+ [ResourceExposure(ResourceScope.Machine)]
+ [ResourceConsumption(ResourceScope.Machine)]
+ internal static string GetFullPath(string filename)
+ { // MDAC 77686
+ return Path.GetFullPath(filename);
+ }
+
+ // TODO: cache machine name and listen to longhorn event to reset it
+ internal static string GetComputerNameDnsFullyQualified()
+ {
+ const int ComputerNameDnsFullyQualified = 3; // winbase.h, enum COMPUTER_NAME_FORMAT
+ const int ERROR_MORE_DATA = 234; // winerror.h
+
+ string value;
+ if (s_isPlatformNT5)
+ {
+ int length = 0; // length parameter must be zero if buffer is null
+ // query for the required length
+ // VSTFDEVDIV 479551 - ensure that GetComputerNameEx does not fail with unexpected values and that the length is positive
+ int getComputerNameExError = 0;
+ if (0 == SafeNativeMethods.GetComputerNameEx(ComputerNameDnsFullyQualified, null, ref length))
+ {
+ getComputerNameExError = Marshal.GetLastWin32Error();
+ }
+ if ((getComputerNameExError != 0 && getComputerNameExError != ERROR_MORE_DATA) || length <= 0)
+ {
+ throw ADP.ComputerNameEx(getComputerNameExError);
+ }
+
+ StringBuilder buffer = new(length);
+ length = buffer.Capacity;
+ if (0 == SafeNativeMethods.GetComputerNameEx(ComputerNameDnsFullyQualified, buffer, ref length))
+ {
+ throw ADP.ComputerNameEx(Marshal.GetLastWin32Error());
+ }
+
+ // Note: In Longhorn you'll be able to rename a machine without
+ // rebooting. Therefore, don't cache this machine name.
+ value = buffer.ToString();
+ }
+ else
+ {
+ value = ADP.MachineName();
+ }
+ return value;
+ }
+
+ [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+ internal static IntPtr IntPtrOffset(IntPtr pbase, int offset)
+ {
+ if (4 == ADP.s_ptrSize)
+ {
+ return (IntPtr)checked(pbase.ToInt32() + offset);
+ }
+ Debug.Assert(8 == ADP.s_ptrSize, "8 != IntPtr.Size"); // MDAC 73747
+ return (IntPtr)checked(pbase.ToInt64() + offset);
+ }
+
+#endregion
+#else
+#region netcore project only
+ internal static Timer UnsafeCreateTimer(TimerCallback callback, object state, int dueTime, int period)
+ {
+ // Don't capture the current ExecutionContext and its AsyncLocals onto
+ // a global timer causing them to live forever
+ bool restoreFlow = false;
+ try
+ {
+ if (!ExecutionContext.IsFlowSuppressed())
+ {
+ ExecutionContext.SuppressFlow();
+ restoreFlow = true;
+ }
+
+ return new Timer(callback, state, dueTime, period);
+ }
+ finally
+ {
+ // Restore the current ExecutionContext
+ if (restoreFlow)
+ ExecutionContext.RestoreFlow();
+ }
+ }
+
+ //
+ // COM+ exceptions
+ //
+ internal static PlatformNotSupportedException DbTypeNotSupported(string dbType) => new(StringsHelper.GetString(Strings.SQL_DbTypeNotSupportedOnThisPlatform, dbType));
+
+ // IDbConnection.BeginTransaction, OleDbTransaction.Begin
+ internal static ArgumentOutOfRangeException InvalidIsolationLevel(IsolationLevel value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case IsolationLevel.Unspecified:
+ case IsolationLevel.Chaos:
+ case IsolationLevel.ReadUncommitted:
+ case IsolationLevel.ReadCommitted:
+ case IsolationLevel.RepeatableRead:
+ case IsolationLevel.Serializable:
+ case IsolationLevel.Snapshot:
+ Debug.Fail("valid IsolationLevel " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(IsolationLevel), (int)value);
+ }
+
+ // ConnectionUtil
+ internal static Exception IncorrectPhysicalConnectionType() => new ArgumentException(StringsHelper.GetString(StringsHelper.SNI_IncorrectPhysicalConnectionType));
+
+ // IDataParameter.Direction
+ internal static ArgumentOutOfRangeException InvalidParameterDirection(ParameterDirection value)
+ {
+#if DEBUG
+ switch (value)
+ {
+ case ParameterDirection.Input:
+ case ParameterDirection.Output:
+ case ParameterDirection.InputOutput:
+ case ParameterDirection.ReturnValue:
+ Debug.Fail("valid ParameterDirection " + value.ToString());
+ break;
+ }
+#endif
+ return InvalidEnumerationValue(typeof(ParameterDirection), (int)value);
+ }
+
+ //
+ // : IDbCommand
+ //
+ internal static Exception InvalidCommandTimeout(int value, [CallerMemberName] string property = "")
+ => Argument(StringsHelper.GetString(Strings.ADP_InvalidCommandTimeout, value.ToString(CultureInfo.InvariantCulture)), property);
+#endregion
+#endif
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs
new file mode 100644
index 0000000000..65e425590e
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs
@@ -0,0 +1,770 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Globalization;
+using System.Text;
+using System.Text.RegularExpressions;
+using Microsoft.Data.SqlClient;
+
+namespace Microsoft.Data.Common
+{
+ partial class DbConnectionOptions
+ {
+ // instances of this class are intended to be immutable, i.e readonly
+ // used by pooling classes so it is much easier to verify correctness
+ // when not worried about the class being modified during execution
+
+ // connection string common keywords
+ private static class KEY
+ {
+ internal const string Integrated_Security = DbConnectionStringKeywords.IntegratedSecurity;
+ internal const string Password = DbConnectionStringKeywords.Password;
+ internal const string Persist_Security_Info = DbConnectionStringKeywords.PersistSecurityInfo;
+ internal const string User_ID = DbConnectionStringKeywords.UserID;
+ internal const string Encrypt = DbConnectionStringKeywords.Encrypt;
+ }
+
+ // known connection string common synonyms
+ private static class SYNONYM
+ {
+ internal const string Pwd = DbConnectionStringSynonyms.Pwd;
+ internal const string UID = DbConnectionStringSynonyms.UID;
+ }
+
+#if DEBUG
+ /*private const string ConnectionStringPatternV1 =
+ "[\\s;]*"
+ +"(?([^=\\s]|\\s+[^=\\s]|\\s+==|==)+)"
+ + "\\s*=(?!=)\\s*"
+ +"(?("
+ + "(" + "\"" + "([^\"]|\"\")*" + "\"" + ")"
+ + "|"
+ + "(" + "'" + "([^']|'')*" + "'" + ")"
+ + "|"
+ + "(" + "(?![\"'])" + "([^\\s;]|\\s+[^\\s;])*" + "(?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}]|\\s+==|==)+)" // allow any visible character for keyname except '=' which must quoted as '=='
+ + "\\s*=(?!=)\\s*" // the equal sign divides the key and value parts
+ + "(?"
+ + "(\"([^\"\u0000]|\"\")*\")" // double quoted string, " must be quoted as ""
+ + "|"
+ + "('([^'\u0000]|'')*')" // single quoted string, ' must be quoted as ''
+ + "|"
+ + "((?![\"'\\s])" // unquoted value must not start with " or ' or space, would also like = but too late to change
+ + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" // control characters must be quoted
+ + "(?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}])+)" // allow any visible character for keyname except '='
+ + "\\s*=\\s*" // the equal sign divides the key and value parts
+ + "(?"
+ + "(\\{([^\\}\u0000]|\\}\\})*\\})" // quoted string, starts with { and ends with }
+ + "|"
+ + "((?![\\{\\s])" // unquoted value must not start with { or space, would also like = but too late to change
+ + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" // control characters must be quoted
+
+ + ")" // although the spec does not allow {}
+ // embedded within a value, the retail code does.
+ + ")(\\s*)(;|[\u0000\\s]*$)" // whitespace after value up to semicolon or end-of-line
+ + ")*" // repeat the key-value pair
+ + "[\\s;]*[\u0000\\s]*" // trailing whitespace/semicolons (DataSourceLocator), embedded nulls are allowed only in the end
+ ;
+
+ private static readonly Regex s_connectionStringRegex = new Regex(ConnectionStringPattern, RegexOptions.ExplicitCapture | RegexOptions.Compiled);
+ private static readonly Regex s_connectionStringRegexOdbc = new Regex(ConnectionStringPatternOdbc, RegexOptions.ExplicitCapture | RegexOptions.Compiled);
+#endif
+ private const string ConnectionStringValidKeyPattern = "^(?![;\\s])[^\\p{Cc}]+(? _parsetable;
+
+ internal Dictionary Parsetable => _parsetable;
+ public bool IsEmpty => _keyChain == null;
+
+ public DbConnectionOptions(string connectionString, Dictionary synonyms)
+ {
+ _parsetable = new Dictionary(StringComparer.InvariantCultureIgnoreCase);
+ _usersConnectionString = ((null != connectionString) ? connectionString : "");
+
+ // first pass on parsing, initial syntax check
+ if (0 < _usersConnectionString.Length)
+ {
+ _keyChain = ParseInternal(_parsetable, _usersConnectionString, true, synonyms, false);
+ _hasPasswordKeyword = (_parsetable.ContainsKey(KEY.Password) || _parsetable.ContainsKey(SYNONYM.Pwd));
+ _hasUserIdKeyword = (_parsetable.ContainsKey(KEY.User_ID) || _parsetable.ContainsKey(SYNONYM.UID));
+ }
+ }
+
+ protected DbConnectionOptions(DbConnectionOptions connectionOptions)
+ { // Clone used by SqlConnectionString
+ _usersConnectionString = connectionOptions._usersConnectionString;
+ _parsetable = connectionOptions._parsetable;
+ _keyChain = connectionOptions._keyChain;
+ _hasPasswordKeyword = connectionOptions._hasPasswordKeyword;
+ _hasUserIdKeyword = connectionOptions._hasUserIdKeyword;
+ }
+
+ internal bool TryGetParsetableValue(string key, out string value) => _parsetable.TryGetValue(key, out value);
+
+ // same as Boolean, but with SSPI thrown in as valid yes
+ public bool ConvertValueToIntegratedSecurity()
+ {
+ return _parsetable.TryGetValue(KEY.Integrated_Security, out string value) && value != null ?
+ ConvertValueToIntegratedSecurityInternal(value) :
+ false;
+ }
+
+ internal bool ConvertValueToIntegratedSecurityInternal(string stringValue)
+ {
+ if (CompareInsensitiveInvariant(stringValue, "sspi") || CompareInsensitiveInvariant(stringValue, "true") || CompareInsensitiveInvariant(stringValue, "yes"))
+ return true;
+ else if (CompareInsensitiveInvariant(stringValue, "false") || CompareInsensitiveInvariant(stringValue, "no"))
+ return false;
+ else
+ {
+ string tmp = stringValue.Trim(); // Remove leading & trailing whitespace.
+ if (CompareInsensitiveInvariant(tmp, "sspi") || CompareInsensitiveInvariant(tmp, "true") || CompareInsensitiveInvariant(tmp, "yes"))
+ return true;
+ else if (CompareInsensitiveInvariant(tmp, "false") || CompareInsensitiveInvariant(tmp, "no"))
+ return false;
+ else
+ {
+ throw ADP.InvalidConnectionOptionValue(KEY.Integrated_Security);
+ }
+ }
+ }
+
+ public int ConvertValueToInt32(string keyName, int defaultValue)
+ {
+ return _parsetable.TryGetValue(keyName, out string value) && value != null ?
+ ConvertToInt32Internal(keyName, value) :
+ defaultValue;
+ }
+
+ internal static int ConvertToInt32Internal(string keyname, string stringValue)
+ {
+ try
+ {
+ return int.Parse(stringValue, System.Globalization.NumberStyles.Integer, CultureInfo.InvariantCulture);
+ }
+ catch (FormatException e)
+ {
+ throw ADP.InvalidConnectionOptionValue(keyname, e);
+ }
+ catch (OverflowException e)
+ {
+ throw ADP.InvalidConnectionOptionValue(keyname, e);
+ }
+ }
+
+ public string ConvertValueToString(string keyName, string defaultValue)
+ => _parsetable.TryGetValue(keyName, out string value) && value != null ? value : defaultValue;
+
+ public bool ContainsKey(string keyword) => _parsetable.ContainsKey(keyword);
+
+ protected internal virtual string Expand() => _usersConnectionString;
+
+ public string UsersConnectionString(bool hidePassword) => UsersConnectionString(hidePassword, false);
+
+ internal string UsersConnectionStringForTrace() => UsersConnectionString(true, true);
+
+ private string UsersConnectionString(bool hidePassword, bool forceHidePassword)
+ {
+ string connectionString = _usersConnectionString;
+ if (_hasPasswordKeyword && (forceHidePassword || (hidePassword && !HasPersistablePassword)))
+ {
+ ReplacePasswordPwd(out connectionString, false);
+ }
+ return connectionString ?? string.Empty;
+ }
+
+ internal bool HasPersistablePassword => _hasPasswordKeyword ?
+ ConvertValueToBoolean(KEY.Persist_Security_Info, DbConnectionStringDefaults.PersistSecurityInfo) :
+ true; // no password means persistable password so we don't have to munge
+
+ public bool ConvertValueToBoolean(string keyName, bool defaultValue)
+ {
+ string value;
+ return _parsetable.TryGetValue(keyName, out value) ?
+ ConvertValueToBooleanInternal(keyName, value) :
+ defaultValue;
+ }
+
+ internal static bool ConvertValueToBooleanInternal(string keyName, string stringValue)
+ {
+ if (CompareInsensitiveInvariant(stringValue, "true") || CompareInsensitiveInvariant(stringValue, "yes"))
+ return true;
+ else if (CompareInsensitiveInvariant(stringValue, "false") || CompareInsensitiveInvariant(stringValue, "no"))
+ return false;
+ else
+ {
+ string tmp = stringValue.Trim(); // Remove leading & trailing whitespace.
+ if (CompareInsensitiveInvariant(tmp, "true") || CompareInsensitiveInvariant(tmp, "yes"))
+ return true;
+ else if (CompareInsensitiveInvariant(tmp, "false") || CompareInsensitiveInvariant(tmp, "no"))
+ return false;
+ else
+ {
+ throw ADP.InvalidConnectionOptionValue(keyName);
+ }
+ }
+ }
+
+ private static bool CompareInsensitiveInvariant(string strvalue, string strconst)
+ => (0 == StringComparer.OrdinalIgnoreCase.Compare(strvalue, strconst));
+
+ [System.Diagnostics.Conditional("DEBUG")]
+ private static void DebugTraceKeyValuePair(string keyname, string keyvalue, Dictionary synonyms)
+ {
+ if (SqlClientEventSource.Log.IsAdvancedTraceOn())
+ {
+ Debug.Assert(string.Equals(keyname, keyname?.ToLower(), StringComparison.InvariantCulture), "missing ToLower");
+ string realkeyname = ((null != synonyms) ? synonyms[keyname] : keyname);
+
+ if (!string.Equals(KEY.Password, realkeyname, StringComparison.InvariantCultureIgnoreCase) &&
+ !string.Equals(SYNONYM.Pwd, realkeyname, StringComparison.InvariantCultureIgnoreCase))
+ {
+ // don't trace passwords ever!
+ if (null != keyvalue)
+ {
+ SqlClientEventSource.Log.AdvancedTraceEvent(" KeyName='{0}', KeyValue='{1}'", keyname, keyvalue);
+ }
+ else
+ {
+ SqlClientEventSource.Log.AdvancedTraceEvent(" KeyName='{0}'", keyname);
+ }
+ }
+ }
+ }
+
+ private static string GetKeyName(StringBuilder buffer)
+ {
+ int count = buffer.Length;
+ while ((0 < count) && char.IsWhiteSpace(buffer[count - 1]))
+ {
+ count--; // trailing whitespace
+ }
+ return buffer.ToString(0, count).ToLower(CultureInfo.InvariantCulture);
+ }
+
+ private static string GetKeyValue(StringBuilder buffer, bool trimWhitespace)
+ {
+ int count = buffer.Length;
+ int index = 0;
+ if (trimWhitespace)
+ {
+ while ((index < count) && char.IsWhiteSpace(buffer[index]))
+ {
+ index++; // leading whitespace
+ }
+ while ((0 < count) && char.IsWhiteSpace(buffer[count - 1]))
+ {
+ count--; // trailing whitespace
+ }
+ }
+ return buffer.ToString(index, count - index);
+ }
+
+ // transition states used for parsing
+ private enum ParserState
+ {
+ NothingYet = 1, //start point
+ Key,
+ KeyEqual,
+ KeyEnd,
+ UnquotedValue,
+ DoubleQuoteValue,
+ DoubleQuoteValueQuote,
+ SingleQuoteValue,
+ SingleQuoteValueQuote,
+ BraceQuoteValue,
+ BraceQuoteValueQuote,
+ QuotedValueEnd,
+ NullTermination,
+ };
+
+ internal static int GetKeyValuePair(string connectionString, int currentPosition, StringBuilder buffer, bool useOdbcRules, out string keyname, out string keyvalue)
+ {
+ int startposition = currentPosition;
+
+ buffer.Length = 0;
+ keyname = null;
+ keyvalue = null;
+
+ char currentChar = '\0';
+
+ ParserState parserState = ParserState.NothingYet;
+ int length = connectionString.Length;
+ for (; currentPosition < length; ++currentPosition)
+ {
+ currentChar = connectionString[currentPosition];
+
+ switch (parserState)
+ {
+ case ParserState.NothingYet: // [\\s;]*
+ if ((';' == currentChar) || char.IsWhiteSpace(currentChar))
+ {
+ continue;
+ }
+ if ('\0' == currentChar)
+ { parserState = ParserState.NullTermination; continue; }
+ if (char.IsControl(currentChar))
+ { throw ADP.ConnectionStringSyntax(startposition); }
+ startposition = currentPosition;
+ if ('=' != currentChar)
+ {
+ parserState = ParserState.Key;
+ break;
+ }
+ else
+ {
+ parserState = ParserState.KeyEqual;
+ continue;
+ }
+
+ case ParserState.Key: // (?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}]|\\s+==|==)+)
+ if ('=' == currentChar)
+ { parserState = ParserState.KeyEqual; continue; }
+ if (char.IsWhiteSpace(currentChar))
+ { break; }
+ if (char.IsControl(currentChar))
+ { throw ADP.ConnectionStringSyntax(startposition); }
+ break;
+
+ case ParserState.KeyEqual: // \\s*=(?!=)\\s*
+ if (!useOdbcRules && '=' == currentChar)
+ { parserState = ParserState.Key; break; }
+ keyname = GetKeyName(buffer);
+ if (string.IsNullOrEmpty(keyname))
+ { throw ADP.ConnectionStringSyntax(startposition); }
+ buffer.Length = 0;
+ parserState = ParserState.KeyEnd;
+ goto case ParserState.KeyEnd;
+
+ case ParserState.KeyEnd:
+ if (char.IsWhiteSpace(currentChar))
+ { continue; }
+ if (useOdbcRules)
+ {
+ if ('{' == currentChar)
+ { parserState = ParserState.BraceQuoteValue; break; }
+ }
+ else
+ {
+ if ('\'' == currentChar)
+ { parserState = ParserState.SingleQuoteValue; continue; }
+ if ('"' == currentChar)
+ { parserState = ParserState.DoubleQuoteValue; continue; }
+ }
+ if (';' == currentChar)
+ { goto ParserExit; }
+ if ('\0' == currentChar)
+ { goto ParserExit; }
+ if (char.IsControl(currentChar))
+ { throw ADP.ConnectionStringSyntax(startposition); }
+ parserState = ParserState.UnquotedValue;
+ break;
+
+ case ParserState.UnquotedValue: // "((?![\"'\\s])" + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" + "(? SplitConnectionString(string connectionString, Dictionary synonyms, bool firstKey)
+ {
+ var parsetable = new Dictionary();
+ Regex parser = (firstKey ? s_connectionStringRegexOdbc : s_connectionStringRegex);
+
+ const int KeyIndex = 1, ValueIndex = 2;
+ Debug.Assert(KeyIndex == parser.GroupNumberFromName("key"), "wrong key index");
+ Debug.Assert(ValueIndex == parser.GroupNumberFromName("value"), "wrong value index");
+
+ if (null != connectionString)
+ {
+ Match match = parser.Match(connectionString);
+ if (!match.Success || (match.Length != connectionString.Length))
+ {
+ throw ADP.ConnectionStringSyntax(match.Length);
+ }
+ int indexValue = 0;
+ CaptureCollection keyvalues = match.Groups[ValueIndex].Captures;
+ foreach (Capture keypair in match.Groups[KeyIndex].Captures)
+ {
+ string keyname = (firstKey ? keypair.Value : keypair.Value.Replace("==", "=")).ToLower(CultureInfo.InvariantCulture);
+ string keyvalue = keyvalues[indexValue++].Value;
+ if (0 < keyvalue.Length)
+ {
+ if (!firstKey)
+ {
+ switch (keyvalue[0])
+ {
+ case '\"':
+ keyvalue = keyvalue.Substring(1, keyvalue.Length - 2).Replace("\"\"", "\"");
+ break;
+ case '\'':
+ keyvalue = keyvalue.Substring(1, keyvalue.Length - 2).Replace("\'\'", "\'");
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ else
+ {
+ keyvalue = null;
+ }
+ DebugTraceKeyValuePair(keyname, keyvalue, synonyms);
+ string synonym;
+ string realkeyname = null != synonyms ?
+ (synonyms.TryGetValue(keyname, out synonym) ? synonym : null) : keyname;
+
+ if (!IsKeyNameValid(realkeyname))
+ {
+ throw ADP.KeywordNotSupported(keyname);
+ }
+ if (!firstKey || !parsetable.ContainsKey(realkeyname))
+ {
+ parsetable[realkeyname] = keyvalue; // last key-value pair wins (or first)
+ }
+ }
+ }
+ return parsetable;
+ }
+
+ private static void ParseComparison(Dictionary parsetable, string connectionString, Dictionary synonyms, bool firstKey, Exception e)
+ {
+ try
+ {
+ var parsedvalues = SplitConnectionString(connectionString, synonyms, firstKey);
+ foreach (var entry in parsedvalues)
+ {
+ string keyname = entry.Key;
+ string value1 = entry.Value;
+ string value2;
+ bool parsetableContainsKey = parsetable.TryGetValue(keyname, out value2);
+ Debug.Assert(parsetableContainsKey, $"{nameof(ParseInternal)} code vs. regex mismatch keyname <{keyname}>");
+ Debug.Assert(value1 == value2, $"{nameof(ParseInternal)} code vs. regex mismatch keyvalue <{value1}> <{value2}>");
+ }
+ }
+ catch (ArgumentException f)
+ {
+ if (null != e)
+ {
+ string msg1 = e.Message;
+ string msg2 = f.Message;
+
+ const string KeywordNotSupportedMessagePrefix = "Keyword not supported:";
+ const string WrongFormatMessagePrefix = "Format of the initialization string";
+ bool isEquivalent = (msg1 == msg2);
+ if (!isEquivalent)
+ {
+ // We also accept cases were Regex parser (debug only) reports "wrong format" and
+ // retail parsing code reports format exception in different location or "keyword not supported"
+ if (msg2.StartsWith(WrongFormatMessagePrefix, StringComparison.Ordinal))
+ {
+ if (msg1.StartsWith(KeywordNotSupportedMessagePrefix, StringComparison.Ordinal) || msg1.StartsWith(WrongFormatMessagePrefix, StringComparison.Ordinal))
+ {
+ isEquivalent = true;
+ }
+ }
+ }
+ Debug.Assert(isEquivalent, "ParseInternal code vs regex message mismatch: <" + msg1 + "> <" + msg2 + ">");
+ }
+ else
+ {
+ Debug.Fail("ParseInternal code vs regex throw mismatch " + f.Message);
+ }
+ e = null;
+ }
+ if (null != e)
+ {
+ Debug.Fail("ParseInternal code threw exception vs regex mismatch");
+ }
+ }
+#endif
+
+ private static NameValuePair ParseInternal(Dictionary parsetable, string connectionString, bool buildChain, Dictionary synonyms, bool firstKey)
+ {
+ Debug.Assert(null != connectionString, "null connectionstring");
+ StringBuilder buffer = new StringBuilder();
+ NameValuePair localKeychain = null, keychain = null;
+#if DEBUG
+ try
+ {
+#endif
+ int nextStartPosition = 0;
+ int endPosition = connectionString.Length;
+ while (nextStartPosition < endPosition)
+ {
+ int startPosition = nextStartPosition;
+
+ string keyname, keyvalue;
+ nextStartPosition = GetKeyValuePair(connectionString, startPosition, buffer, firstKey, out keyname, out keyvalue);
+ if (string.IsNullOrEmpty(keyname))
+ {
+ // if (nextStartPosition != endPosition) { throw; }
+ break;
+ }
+#if DEBUG
+ DebugTraceKeyValuePair(keyname, keyvalue, synonyms);
+#endif
+ Debug.Assert(IsKeyNameValid(keyname), "ParseFailure, invalid keyname");
+ Debug.Assert(IsValueValidInternal(keyvalue), "parse failure, invalid keyvalue");
+
+ string realkeyname = (synonyms is not null) ?
+ (synonyms.TryGetValue(keyname, out string synonym) ? synonym : null) :
+ keyname;
+
+ if (!IsKeyNameValid(realkeyname))
+ {
+ throw ADP.KeywordNotSupported(keyname);
+ }
+ if (!firstKey || !parsetable.ContainsKey(realkeyname))
+ {
+ parsetable[realkeyname] = keyvalue; // last key-value pair wins (or first)
+ }
+
+ if (null != localKeychain)
+ {
+ localKeychain = localKeychain.Next = new NameValuePair(realkeyname, keyvalue, nextStartPosition - startPosition);
+ }
+ else if (buildChain)
+ { // first time only - don't contain modified chain from UDL file
+ keychain = localKeychain = new NameValuePair(realkeyname, keyvalue, nextStartPosition - startPosition);
+ }
+ }
+#if DEBUG
+ }
+ catch (ArgumentException e)
+ {
+ ParseComparison(parsetable, connectionString, synonyms, firstKey, e);
+ throw;
+ }
+ ParseComparison(parsetable, connectionString, synonyms, firstKey, null);
+#endif
+ return keychain;
+ }
+
+ internal NameValuePair ReplacePasswordPwd(out string constr, bool fakePassword)
+ {
+ bool expanded = false;
+ int copyPosition = 0;
+ NameValuePair head = null, tail = null, next = null;
+ StringBuilder builder = new StringBuilder(_usersConnectionString.Length);
+ for (NameValuePair current = _keyChain; null != current; current = current.Next)
+ {
+ if (!string.Equals(KEY.Password, current.Name, StringComparison.InvariantCultureIgnoreCase) &&
+ !string.Equals(SYNONYM.Pwd, current.Name, StringComparison.InvariantCultureIgnoreCase))
+ {
+ builder.Append(_usersConnectionString, copyPosition, current.Length);
+ if (fakePassword)
+ {
+ next = new NameValuePair(current.Name, current.Value, current.Length);
+ }
+ }
+ else if (fakePassword)
+ {
+ // replace user password/pwd value with *
+ const string equalstar = "=*;";
+ builder.Append(current.Name).Append(equalstar);
+ next = new NameValuePair(current.Name, "*", current.Name.Length + equalstar.Length);
+ expanded = true;
+ }
+ else
+ {
+ // drop the password/pwd completely in returning for user
+ expanded = true;
+ }
+
+ if (fakePassword)
+ {
+ if (null != tail)
+ {
+ tail = tail.Next = next;
+ }
+ else
+ {
+ tail = head = next;
+ }
+ }
+ copyPosition += current.Length;
+ }
+ Debug.Assert(expanded, "password/pwd was not removed");
+ constr = builder.ToString();
+ return head;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs
new file mode 100644
index 0000000000..7d2799289f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs
@@ -0,0 +1,58 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.Data.Common
+{
+ // DbConnectionPoolKey: Base class implementation of a key to connection pool groups
+ // Only connection string is used as a key
+ internal class DbConnectionPoolKey : ICloneable
+ {
+ private string _connectionString;
+
+ internal DbConnectionPoolKey(string connectionString)
+ {
+ _connectionString = connectionString;
+ }
+
+ protected DbConnectionPoolKey(DbConnectionPoolKey key)
+ {
+ _connectionString = key.ConnectionString;
+ }
+
+ public virtual object Clone()
+ {
+ return new DbConnectionPoolKey(this);
+ }
+
+ internal virtual string ConnectionString
+ {
+ get
+ {
+ return _connectionString;
+ }
+
+ set
+ {
+ _connectionString = value;
+ }
+ }
+
+ public override bool Equals(object obj)
+ {
+ if (obj == null)
+ {
+ return false;
+ }
+
+ return (obj is DbConnectionPoolKey key && _connectionString == key._connectionString);
+ }
+
+ public override int GetHashCode()
+ {
+ return _connectionString == null ? 0 : _connectionString.GetHashCode();
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs
new file mode 100644
index 0000000000..ee6fafa0f0
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs
@@ -0,0 +1,1153 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Globalization;
+using Microsoft.Data.SqlClient;
+
+namespace Microsoft.Data.Common
+{
+ internal static class DbConnectionStringBuilderUtil
+ {
+ internal static bool ConvertToBoolean(object value)
+ {
+ Debug.Assert(null != value, "ConvertToBoolean(null)");
+ if (value is string svalue)
+ {
+ if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "true") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "yes"))
+ return true;
+ else if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "false") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "no"))
+ return false;
+ else
+ {
+ string tmp = svalue.Trim(); // Remove leading & trailing white space.
+ if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "true") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "yes"))
+ return true;
+ else if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "false") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "no"))
+ return false;
+ }
+ return bool.Parse(svalue);
+ }
+ try
+ {
+ return Convert.ToBoolean(value, CultureInfo.InvariantCulture);
+ }
+ catch (InvalidCastException e)
+ {
+ throw ADP.ConvertFailed(value.GetType(), typeof(bool), e);
+ }
+ }
+
+ internal static bool ConvertToIntegratedSecurity(object value)
+ {
+ Debug.Assert(null != value, "ConvertToIntegratedSecurity(null)");
+ if (value is string svalue)
+ {
+ if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "sspi") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "true") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "yes"))
+ return true;
+ else if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "false") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "no"))
+ return false;
+ else
+ {
+ string tmp = svalue.Trim(); // Remove leading & trailing white space.
+ if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "sspi") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "true") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "yes"))
+ return true;
+ else if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "false") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "no"))
+ return false;
+ }
+ return bool.Parse(svalue);
+ }
+ try
+ {
+ return Convert.ToBoolean(value, CultureInfo.InvariantCulture);
+ }
+ catch (InvalidCastException e)
+ {
+ throw ADP.ConvertFailed(value.GetType(), typeof(bool), e);
+ }
+ }
+
+ internal static int ConvertToInt32(object value)
+ {
+ try
+ {
+ return Convert.ToInt32(value, CultureInfo.InvariantCulture);
+ }
+ catch (InvalidCastException e)
+ {
+ throw ADP.ConvertFailed(value.GetType(), typeof(int), e);
+ }
+ }
+
+ internal static string ConvertToString(object value)
+ {
+ try
+ {
+ return Convert.ToString(value, CultureInfo.InvariantCulture);
+ }
+ catch (InvalidCastException e)
+ {
+ throw ADP.ConvertFailed(value.GetType(), typeof(string), e);
+ }
+ }
+
+ #region <>
+ internal static bool TryConvertToPoolBlockingPeriod(string value, out PoolBlockingPeriod result)
+ {
+ Debug.Assert(Enum.GetNames(typeof(PoolBlockingPeriod)).Length == 3, "PoolBlockingPeriod enum has changed, update needed");
+ Debug.Assert(null != value, "TryConvertToPoolBlockingPeriod(null,...)");
+
+ if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.Auto)))
+ {
+ result = PoolBlockingPeriod.Auto;
+ return true;
+ }
+ else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.AlwaysBlock)))
+ {
+ result = PoolBlockingPeriod.AlwaysBlock;
+ return true;
+ }
+ else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.NeverBlock)))
+ {
+ result = PoolBlockingPeriod.NeverBlock;
+ return true;
+ }
+ else
+ {
+ result = DbConnectionStringDefaults.PoolBlockingPeriod;
+ return false;
+ }
+ }
+
+ internal static bool IsValidPoolBlockingPeriodValue(PoolBlockingPeriod value)
+ {
+ Debug.Assert(Enum.GetNames(typeof(PoolBlockingPeriod)).Length == 3, "PoolBlockingPeriod enum has changed, update needed");
+ return value == PoolBlockingPeriod.Auto || value == PoolBlockingPeriod.AlwaysBlock || value == PoolBlockingPeriod.NeverBlock;
+ }
+
+ internal static string PoolBlockingPeriodToString(PoolBlockingPeriod value)
+ {
+ Debug.Assert(IsValidPoolBlockingPeriodValue(value));
+
+ return value switch
+ {
+ PoolBlockingPeriod.AlwaysBlock => nameof(PoolBlockingPeriod.AlwaysBlock),
+ PoolBlockingPeriod.NeverBlock => nameof(PoolBlockingPeriod.NeverBlock),
+ _ => nameof(PoolBlockingPeriod.Auto),
+ };
+ }
+
+ ///
+ /// This method attempts to convert the given value to a PoolBlockingPeriod enum. The algorithm is:
+ /// * if the value is from type string, it will be matched against PoolBlockingPeriod enum names only, using ordinal, case-insensitive comparer
+ /// * if the value is from type PoolBlockingPeriod, it will be used as is
+ /// * if the value is from integral type (SByte, Int16, Int32, Int64, Byte, UInt16, UInt32, or UInt64), it will be converted to enum
+ /// * if the value is another enum or any other type, it will be blocked with an appropriate ArgumentException
+ ///
+ /// in any case above, if the converted value is out of valid range, the method raises ArgumentOutOfRangeException.
+ ///
+ /// PoolBlockingPeriod value in the valid range
+ internal static PoolBlockingPeriod ConvertToPoolBlockingPeriod(string keyword, object value)
+ {
+ Debug.Assert(null != value, "ConvertToPoolBlockingPeriod(null)");
+ if (value is string sValue)
+ {
+ // We could use Enum.TryParse here, but it accepts value combinations like
+ // "ReadOnly, ReadWrite" which are unwelcome here
+ // Also, Enum.TryParse is 100x slower than plain StringComparer.OrdinalIgnoreCase.Equals method.
+
+ if (TryConvertToPoolBlockingPeriod(sValue, out PoolBlockingPeriod result))
+ {
+ return result;
+ }
+
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToPoolBlockingPeriod(sValue, out result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ PoolBlockingPeriod eValue;
+
+ if (value is PoolBlockingPeriod period)
+ {
+ // quick path for the most common case
+ eValue = period;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["PoolBlockingPeriod"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-PoolBlockingPeriod enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(PoolBlockingPeriod), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (PoolBlockingPeriod)Enum.ToObject(typeof(PoolBlockingPeriod), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(PoolBlockingPeriod), e);
+ }
+ }
+
+ // ensure value is in valid range
+ if (IsValidPoolBlockingPeriodValue(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(ApplicationIntent), (int)eValue);
+ }
+ }
+ }
+ #endregion
+
+ internal static bool TryConvertToApplicationIntent(string value, out ApplicationIntent result)
+ {
+ Debug.Assert(Enum.GetNames(typeof(ApplicationIntent)).Length == 2, "ApplicationIntent enum has changed, update needed");
+ Debug.Assert(null != value, "TryConvertToApplicationIntent(null,...)");
+
+ if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(ApplicationIntent.ReadOnly)))
+ {
+ result = ApplicationIntent.ReadOnly;
+ return true;
+ }
+ else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(ApplicationIntent.ReadWrite)))
+ {
+ result = ApplicationIntent.ReadWrite;
+ return true;
+ }
+ else
+ {
+ result = DbConnectionStringDefaults.ApplicationIntent;
+ return false;
+ }
+ }
+
+ internal static bool IsValidApplicationIntentValue(ApplicationIntent value)
+ {
+ Debug.Assert(Enum.GetNames(typeof(ApplicationIntent)).Length == 2, "ApplicationIntent enum has changed, update needed");
+ return value == ApplicationIntent.ReadOnly || value == ApplicationIntent.ReadWrite;
+ }
+
+ internal static string ApplicationIntentToString(ApplicationIntent value)
+ {
+ Debug.Assert(IsValidApplicationIntentValue(value));
+ if (value == ApplicationIntent.ReadOnly)
+ {
+ return nameof(ApplicationIntent.ReadOnly);
+ }
+ else
+ {
+ return nameof(ApplicationIntent.ReadWrite);
+ }
+ }
+
+ ///
+ /// This method attempts to convert the given value tp ApplicationIntent enum. The algorithm is:
+ /// * if the value is from type string, it will be matched against ApplicationIntent enum names only, using ordinal, case-insensitive comparer
+ /// * if the value is from type ApplicationIntent, it will be used as is
+ /// * if the value is from integral type (SByte, Int16, Int32, Int64, Byte, UInt16, UInt32, or UInt64), it will be converted to enum
+ /// * if the value is another enum or any other type, it will be blocked with an appropriate ArgumentException
+ ///
+ /// in any case above, if the converted value is out of valid range, the method raises ArgumentOutOfRangeException.
+ ///
+ /// application intent value in the valid range
+ internal static ApplicationIntent ConvertToApplicationIntent(string keyword, object value)
+ {
+ Debug.Assert(null != value, "ConvertToApplicationIntent(null)");
+ if (value is string sValue)
+ {
+ // We could use Enum.TryParse here, but it accepts value combinations like
+ // "ReadOnly, ReadWrite" which are unwelcome here
+ // Also, Enum.TryParse is 100x slower than plain StringComparer.OrdinalIgnoreCase.Equals method.
+
+ if (TryConvertToApplicationIntent(sValue, out ApplicationIntent result))
+ {
+ return result;
+ }
+
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToApplicationIntent(sValue, out result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ ApplicationIntent eValue;
+
+ if (value is ApplicationIntent intent)
+ {
+ // quick path for the most common case
+ eValue = intent;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["ApplicationIntent"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-ApplicationIntent enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(ApplicationIntent), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (ApplicationIntent)Enum.ToObject(typeof(ApplicationIntent), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(ApplicationIntent), e);
+ }
+ }
+
+ // ensure value is in valid range
+ if (IsValidApplicationIntentValue(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(ApplicationIntent), (int)eValue);
+ }
+ }
+ }
+
+ const string SqlPasswordString = "Sql Password";
+ const string ActiveDirectoryPasswordString = "Active Directory Password";
+ const string ActiveDirectoryIntegratedString = "Active Directory Integrated";
+ const string ActiveDirectoryInteractiveString = "Active Directory Interactive";
+ const string ActiveDirectoryServicePrincipalString = "Active Directory Service Principal";
+ const string ActiveDirectoryDeviceCodeFlowString = "Active Directory Device Code Flow";
+ internal const string ActiveDirectoryManagedIdentityString = "Active Directory Managed Identity";
+ internal const string ActiveDirectoryMSIString = "Active Directory MSI";
+ internal const string ActiveDirectoryDefaultString = "Active Directory Default";
+ const string SqlCertificateString = "Sql Certificate";
+
+#if DEBUG
+ private static readonly string[] s_supportedAuthenticationModes =
+ {
+ "NotSpecified",
+ "SqlPassword",
+ "ActiveDirectoryPassword",
+ "ActiveDirectoryIntegrated",
+ "ActiveDirectoryInteractive",
+ "ActiveDirectoryServicePrincipal",
+ "ActiveDirectoryDeviceCodeFlow",
+ "ActiveDirectoryManagedIdentity",
+ "ActiveDirectoryMSI",
+ "ActiveDirectoryDefault"
+ };
+
+ private static bool IsValidAuthenticationMethodEnum()
+ {
+ string[] names = Enum.GetNames(typeof(SqlAuthenticationMethod));
+ int l = s_supportedAuthenticationModes.Length;
+ bool listValid;
+ if (listValid = names.Length == l)
+ {
+ for (int i = 0; i < l; i++)
+ {
+ if (s_supportedAuthenticationModes[i].CompareTo(names[i]) != 0)
+ {
+ listValid = false;
+ }
+ }
+ }
+ return listValid;
+ }
+#endif
+
+ internal static bool TryConvertToAuthenticationType(string value, out SqlAuthenticationMethod result)
+ {
+#if DEBUG
+ Debug.Assert(IsValidAuthenticationMethodEnum(), "SqlAuthenticationMethod enum has changed, update needed");
+#endif
+ bool isSuccess = false;
+
+ if (StringComparer.InvariantCultureIgnoreCase.Equals(value, SqlPasswordString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.SqlPassword, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.SqlPassword;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryPasswordString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryPassword, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryPassword;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryIntegratedString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryIntegrated, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryIntegrated;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryInteractiveString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryInteractive, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryInteractive;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryServicePrincipalString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryServicePrincipal, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryServicePrincipal;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryDeviceCodeFlowString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryManagedIdentityString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryManagedIdentity, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryManagedIdentity;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryMSIString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryMSI, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryMSI;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryDefaultString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryDefault, CultureInfo.InvariantCulture)))
+ {
+ result = SqlAuthenticationMethod.ActiveDirectoryDefault;
+ isSuccess = true;
+ }
+#if ADONET_CERT_AUTH && NETFRAMEWORK
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, SqlCertificateString)
+ || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.SqlCertificate, CultureInfo.InvariantCulture))) {
+ result = SqlAuthenticationMethod.SqlCertificate;
+ isSuccess = true;
+ }
+#endif
+ else
+ {
+ result = DbConnectionStringDefaults.Authentication;
+ }
+ return isSuccess;
+ }
+
+ ///
+ /// Convert a string value to the corresponding SqlConnectionColumnEncryptionSetting.
+ ///
+ ///
+ ///
+ ///
+ internal static bool TryConvertToColumnEncryptionSetting(string value, out SqlConnectionColumnEncryptionSetting result)
+ {
+ bool isSuccess = false;
+
+ if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionColumnEncryptionSetting.Enabled)))
+ {
+ result = SqlConnectionColumnEncryptionSetting.Enabled;
+ isSuccess = true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionColumnEncryptionSetting.Disabled)))
+ {
+ result = SqlConnectionColumnEncryptionSetting.Disabled;
+ isSuccess = true;
+ }
+ else
+ {
+ result = DbConnectionStringDefaults.ColumnEncryptionSetting;
+ }
+
+ return isSuccess;
+ }
+
+ ///
+ /// Is it a valid connection level column encryption setting ?
+ ///
+ ///
+ ///
+ internal static bool IsValidColumnEncryptionSetting(SqlConnectionColumnEncryptionSetting value)
+ {
+ Debug.Assert(Enum.GetNames(typeof(SqlConnectionColumnEncryptionSetting)).Length == 2, "SqlConnectionColumnEncryptionSetting enum has changed, update needed");
+ return value == SqlConnectionColumnEncryptionSetting.Enabled || value == SqlConnectionColumnEncryptionSetting.Disabled;
+ }
+
+ ///
+ /// Convert connection level column encryption setting value to string.
+ ///
+ ///
+ ///
+ internal static string ColumnEncryptionSettingToString(SqlConnectionColumnEncryptionSetting value)
+ {
+ Debug.Assert(IsValidColumnEncryptionSetting(value), "value is not a valid connection level column encryption setting.");
+
+ return value switch
+ {
+ SqlConnectionColumnEncryptionSetting.Enabled => nameof(SqlConnectionColumnEncryptionSetting.Enabled),
+ SqlConnectionColumnEncryptionSetting.Disabled => nameof(SqlConnectionColumnEncryptionSetting.Disabled),
+ _ => null,
+ };
+ }
+
+ internal static bool IsValidAuthenticationTypeValue(SqlAuthenticationMethod value)
+ {
+ Debug.Assert(Enum.GetNames(typeof(SqlAuthenticationMethod)).Length == 10, "SqlAuthenticationMethod enum has changed, update needed");
+ return value == SqlAuthenticationMethod.SqlPassword
+ || value == SqlAuthenticationMethod.ActiveDirectoryPassword
+ || value == SqlAuthenticationMethod.ActiveDirectoryIntegrated
+ || value == SqlAuthenticationMethod.ActiveDirectoryInteractive
+ || value == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal
+ || value == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
+ || value == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
+ || value == SqlAuthenticationMethod.ActiveDirectoryMSI
+ || value == SqlAuthenticationMethod.ActiveDirectoryDefault
+#if ADONET_CERT_AUTH && NETFRAMEWORK
+ || value == SqlAuthenticationMethod.SqlCertificate
+#endif
+ || value == SqlAuthenticationMethod.NotSpecified;
+ }
+
+ internal static string AuthenticationTypeToString(SqlAuthenticationMethod value)
+ {
+ Debug.Assert(IsValidAuthenticationTypeValue(value));
+
+ return value switch
+ {
+ SqlAuthenticationMethod.SqlPassword => SqlPasswordString,
+ SqlAuthenticationMethod.ActiveDirectoryPassword => ActiveDirectoryPasswordString,
+ SqlAuthenticationMethod.ActiveDirectoryIntegrated => ActiveDirectoryIntegratedString,
+ SqlAuthenticationMethod.ActiveDirectoryInteractive => ActiveDirectoryInteractiveString,
+ SqlAuthenticationMethod.ActiveDirectoryServicePrincipal => ActiveDirectoryServicePrincipalString,
+ SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow => ActiveDirectoryDeviceCodeFlowString,
+ SqlAuthenticationMethod.ActiveDirectoryManagedIdentity => ActiveDirectoryManagedIdentityString,
+ SqlAuthenticationMethod.ActiveDirectoryMSI => ActiveDirectoryMSIString,
+ SqlAuthenticationMethod.ActiveDirectoryDefault => ActiveDirectoryDefaultString,
+#if ADONET_CERT_AUTH && NETFRAMEWORK
+ SqlAuthenticationMethod.SqlCertificate => SqlCertificateString,
+#endif
+ _ => null
+ };
+ }
+
+ internal static SqlAuthenticationMethod ConvertToAuthenticationType(string keyword, object value)
+ {
+ if (null == value)
+ {
+ return DbConnectionStringDefaults.Authentication;
+ }
+
+ if (value is string sValue)
+ {
+ if (TryConvertToAuthenticationType(sValue, out SqlAuthenticationMethod result))
+ {
+ return result;
+ }
+
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToAuthenticationType(sValue, out result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ SqlAuthenticationMethod eValue;
+
+ if (value is SqlAuthenticationMethod method)
+ {
+ // quick path for the most common case
+ eValue = method;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["ApplicationIntent"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-ApplicationIntent enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlAuthenticationMethod), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (SqlAuthenticationMethod)Enum.ToObject(typeof(SqlAuthenticationMethod), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlAuthenticationMethod), e);
+ }
+ }
+
+ // ensure value is in valid range
+ if (IsValidAuthenticationTypeValue(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(SqlAuthenticationMethod), (int)eValue);
+ }
+ }
+ }
+
+ ///
+ /// Convert the provided value to a SqlConnectionColumnEncryptionSetting.
+ ///
+ ///
+ ///
+ ///
+ internal static SqlConnectionColumnEncryptionSetting ConvertToColumnEncryptionSetting(string keyword, object value)
+ {
+ if (null == value)
+ {
+ return DbConnectionStringDefaults.ColumnEncryptionSetting;
+ }
+
+ if (value is string sValue)
+ {
+ if (TryConvertToColumnEncryptionSetting(sValue, out SqlConnectionColumnEncryptionSetting result))
+ {
+ return result;
+ }
+
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToColumnEncryptionSetting(sValue, out result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ SqlConnectionColumnEncryptionSetting eValue;
+
+ if (value is SqlConnectionColumnEncryptionSetting setting)
+ {
+ // quick path for the most common case
+ eValue = setting;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["SqlConnectionColumnEncryptionSetting"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-SqlConnectionColumnEncryptionSetting enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionColumnEncryptionSetting), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (SqlConnectionColumnEncryptionSetting)Enum.ToObject(typeof(SqlConnectionColumnEncryptionSetting), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionColumnEncryptionSetting), e);
+ }
+ }
+
+ // ensure value is in valid range
+ if (IsValidColumnEncryptionSetting(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(SqlConnectionColumnEncryptionSetting), (int)eValue);
+ }
+ }
+ }
+
+ #region <>
+ ///
+ /// Convert a string value to the corresponding SqlConnectionAttestationProtocol
+ ///
+ ///
+ ///
+ ///
+ internal static bool TryConvertToAttestationProtocol(string value, out SqlConnectionAttestationProtocol result)
+ {
+ if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.HGS)))
+ {
+ result = SqlConnectionAttestationProtocol.HGS;
+ return true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.AAS)))
+ {
+ result = SqlConnectionAttestationProtocol.AAS;
+ return true;
+ }
+ else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.None)))
+ {
+ result = SqlConnectionAttestationProtocol.None;
+ return true;
+ }
+ else
+ {
+ result = DbConnectionStringDefaults.AttestationProtocol;
+ return false;
+ }
+ }
+
+ internal static bool IsValidAttestationProtocol(SqlConnectionAttestationProtocol value)
+ {
+ Debug.Assert(Enum.GetNames(typeof(SqlConnectionAttestationProtocol)).Length == 4, "SqlConnectionAttestationProtocol enum has changed, update needed");
+ return value == SqlConnectionAttestationProtocol.NotSpecified
+ || value == SqlConnectionAttestationProtocol.HGS
+ || value == SqlConnectionAttestationProtocol.AAS
+ || value == SqlConnectionAttestationProtocol.None;
+ }
+
+ internal static string AttestationProtocolToString(SqlConnectionAttestationProtocol value)
+ {
+ Debug.Assert(IsValidAttestationProtocol(value), "value is not a valid attestation protocol");
+
+ return value switch
+ {
+ SqlConnectionAttestationProtocol.AAS => nameof(SqlConnectionAttestationProtocol.AAS),
+ SqlConnectionAttestationProtocol.HGS => nameof(SqlConnectionAttestationProtocol.HGS),
+ SqlConnectionAttestationProtocol.None => nameof(SqlConnectionAttestationProtocol.None),
+ _ => null
+ };
+ }
+
+ internal static SqlConnectionAttestationProtocol ConvertToAttestationProtocol(string keyword, object value)
+ {
+ if (null == value)
+ {
+ return DbConnectionStringDefaults.AttestationProtocol;
+ }
+
+ if (value is string sValue)
+ {
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToAttestationProtocol(sValue, out SqlConnectionAttestationProtocol result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ SqlConnectionAttestationProtocol eValue;
+
+ if (value is SqlConnectionAttestationProtocol protocol)
+ {
+ eValue = protocol;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["SqlConnectionAttestationProtocol"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-SqlConnectionAttestationProtocol enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionAttestationProtocol), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (SqlConnectionAttestationProtocol)Enum.ToObject(typeof(SqlConnectionAttestationProtocol), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionAttestationProtocol), e);
+ }
+ }
+
+ if (IsValidAttestationProtocol(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(SqlConnectionAttestationProtocol), (int)eValue);
+ }
+ }
+ }
+
+ internal static SqlConnectionEncryptOption ConvertToSqlConnectionEncryptOption(string keyword, object value)
+ {
+ if (value is null)
+ {
+ return DbConnectionStringDefaults.Encrypt;
+ }
+ else if (value is string sValue)
+ {
+ return SqlConnectionEncryptOption.Parse(sValue);
+ }
+
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+
+ #endregion
+
+ #region <>
+ ///
+ /// IP Address Preference.
+ ///
+ private readonly static Dictionary s_preferenceNames = new(StringComparer.InvariantCultureIgnoreCase);
+
+ static DbConnectionStringBuilderUtil()
+ {
+ foreach (SqlConnectionIPAddressPreference item in Enum.GetValues(typeof(SqlConnectionIPAddressPreference)))
+ {
+ s_preferenceNames.Add(item.ToString(), item);
+ }
+ }
+
+ ///
+ /// Convert a string value to the corresponding IPAddressPreference.
+ ///
+ /// The string representation of the enumeration name to convert.
+ /// When this method returns, `result` contains an object of type `SqlConnectionIPAddressPreference` whose value is represented by `value` if the operation succeeds.
+ /// If the parse operation fails, `result` contains the default value of the `SqlConnectionIPAddressPreference` type.
+ /// `true` if the value parameter was converted successfully; otherwise, `false`.
+ internal static bool TryConvertToIPAddressPreference(string value, out SqlConnectionIPAddressPreference result)
+ {
+ if (!s_preferenceNames.TryGetValue(value, out result))
+ {
+ result = DbConnectionStringDefaults.IPAddressPreference;
+ return false;
+ }
+ return true;
+ }
+
+ ///
+ /// Verifies if the `value` is defined in the expected Enum.
+ ///
+ internal static bool IsValidIPAddressPreference(SqlConnectionIPAddressPreference value)
+ => value == SqlConnectionIPAddressPreference.IPv4First
+ || value == SqlConnectionIPAddressPreference.IPv6First
+ || value == SqlConnectionIPAddressPreference.UsePlatformDefault;
+
+ internal static string IPAddressPreferenceToString(SqlConnectionIPAddressPreference value)
+ => Enum.GetName(typeof(SqlConnectionIPAddressPreference), value);
+
+ internal static SqlConnectionIPAddressPreference ConvertToIPAddressPreference(string keyword, object value)
+ {
+ if (value is null)
+ {
+ return DbConnectionStringDefaults.IPAddressPreference; // IPv4First
+ }
+
+ if (value is string sValue)
+ {
+ // try again after remove leading & trailing whitespaces.
+ sValue = sValue.Trim();
+ if (TryConvertToIPAddressPreference(sValue, out SqlConnectionIPAddressPreference result))
+ {
+ return result;
+ }
+
+ // string values must be valid
+ throw ADP.InvalidConnectionOptionValue(keyword);
+ }
+ else
+ {
+ // the value is not string, try other options
+ SqlConnectionIPAddressPreference eValue;
+
+ if (value is SqlConnectionIPAddressPreference preference)
+ {
+ eValue = preference;
+ }
+ else if (value.GetType().IsEnum)
+ {
+ // explicitly block scenarios in which user tries to use wrong enum types, like:
+ // builder["SqlConnectionIPAddressPreference"] = EnvironmentVariableTarget.Process;
+ // workaround: explicitly cast non-SqlConnectionIPAddressPreference enums to int
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), null);
+ }
+ else
+ {
+ try
+ {
+ // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
+ eValue = (SqlConnectionIPAddressPreference)Enum.ToObject(typeof(SqlConnectionIPAddressPreference), value);
+ }
+ catch (ArgumentException e)
+ {
+ // to be consistent with the messages we send in case of wrong type usage, replace
+ // the error with our exception, and keep the original one as inner one for troubleshooting
+ throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), e);
+ }
+ }
+
+ if (IsValidIPAddressPreference(eValue))
+ {
+ return eValue;
+ }
+ else
+ {
+ throw ADP.InvalidEnumerationValue(typeof(SqlConnectionIPAddressPreference), (int)eValue);
+ }
+ }
+ }
+ #endregion
+
+#if ADONET_CERT_AUTH && NETFRAMEWORK
+ internal static bool IsValidCertificateValue(string value) => string.IsNullOrEmpty(value)
+ || value.StartsWith("subject:", StringComparison.OrdinalIgnoreCase)
+ || value.StartsWith("sha1:", StringComparison.OrdinalIgnoreCase);
+#endif
+ }
+
+ internal static class DbConnectionStringDefaults
+ {
+ internal const ApplicationIntent ApplicationIntent = Microsoft.Data.SqlClient.ApplicationIntent.ReadWrite;
+ internal const string ApplicationName =
+#if NETFRAMEWORK
+ "Framework Microsoft SqlClient Data Provider";
+#else
+ "Core Microsoft SqlClient Data Provider";
+#endif
+ internal const string AttachDBFilename = "";
+ internal const int CommandTimeout = 30;
+ internal const int ConnectTimeout = 15;
+
+#if NETFRAMEWORK
+ internal const bool ConnectionReset = true;
+ internal const bool ContextConnection = false;
+ internal static readonly bool TransparentNetworkIPResolution = !LocalAppContextSwitches.DisableTNIRByDefault;
+ internal const string NetworkLibrary = "";
+#if ADONET_CERT_AUTH
+ internal const string Certificate = "";
+#endif
+#endif
+ internal const string CurrentLanguage = "";
+ internal const string DataSource = "";
+ internal static readonly SqlConnectionEncryptOption Encrypt = SqlConnectionEncryptOption.Mandatory;
+ internal const string HostNameInCertificate = "";
+ internal const bool Enlist = true;
+ internal const string FailoverPartner = "";
+ internal const string InitialCatalog = "";
+ internal const bool IntegratedSecurity = false;
+ internal const int LoadBalanceTimeout = 0; // default of 0 means don't use
+ internal const bool MultipleActiveResultSets = false;
+ internal const bool MultiSubnetFailover = false;
+ internal const int MaxPoolSize = 100;
+ internal const int MinPoolSize = 0;
+ internal const int PacketSize = 8000;
+ internal const string Password = "";
+ internal const bool PersistSecurityInfo = false;
+ internal const bool Pooling = true;
+ internal const bool TrustServerCertificate = false;
+ internal const string TypeSystemVersion = "Latest";
+ internal const string UserID = "";
+ internal const bool UserInstance = false;
+ internal const bool Replication = false;
+ internal const string WorkstationID = "";
+ internal const string TransactionBinding = "Implicit Unbind";
+ internal const int ConnectRetryCount = 1;
+ internal const int ConnectRetryInterval = 10;
+ internal static readonly SqlAuthenticationMethod Authentication = SqlAuthenticationMethod.NotSpecified;
+ internal const SqlConnectionColumnEncryptionSetting ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
+ internal const string EnclaveAttestationUrl = "";
+ internal const SqlConnectionAttestationProtocol AttestationProtocol = SqlConnectionAttestationProtocol.NotSpecified;
+ internal const SqlConnectionIPAddressPreference IPAddressPreference = SqlConnectionIPAddressPreference.IPv4First;
+ internal const PoolBlockingPeriod PoolBlockingPeriod = SqlClient.PoolBlockingPeriod.Auto;
+ internal const string ServerSPN = "";
+ internal const string FailoverPartnerSPN = "";
+ }
+
+ internal static class DbConnectionStringKeywords
+ {
+#if NETFRAMEWORK
+ // Odbc
+ internal const string Driver = "Driver";
+ internal const string Dsn = "Dsn";
+ internal const string FileDsn = "FileDsn";
+ internal const string SaveFile = "SaveFile";
+
+ // OleDb
+ internal const string FileName = "File Name";
+ internal const string OleDbServices = "OLE DB Services";
+ internal const string Provider = "Provider";
+
+ // OracleClient
+ internal const string Unicode = "Unicode";
+ internal const string OmitOracleConnectionName = "Omit Oracle Connection Name";
+
+ // SqlClient
+ internal const string TransparentNetworkIPResolution = "Transparent Network IP Resolution";
+ internal const string Certificate = "Certificate";
+#endif
+ // SqlClient
+ internal const string ApplicationIntent = "Application Intent";
+ internal const string ApplicationName = "Application Name";
+ internal const string AttachDBFilename = "AttachDbFilename";
+ internal const string ConnectTimeout = "Connect Timeout";
+ internal const string CommandTimeout = "Command Timeout";
+ internal const string ConnectionReset = "Connection Reset";
+ internal const string ContextConnection = "Context Connection";
+ internal const string CurrentLanguage = "Current Language";
+ internal const string Encrypt = "Encrypt";
+ internal const string HostNameInCertificate = "Host Name In Certificate";
+ internal const string FailoverPartner = "Failover Partner";
+ internal const string InitialCatalog = "Initial Catalog";
+ internal const string MultipleActiveResultSets = "Multiple Active Result Sets";
+ internal const string MultiSubnetFailover = "Multi Subnet Failover";
+ internal const string NetworkLibrary = "Network Library";
+ internal const string PacketSize = "Packet Size";
+ internal const string Replication = "Replication";
+ internal const string TransactionBinding = "Transaction Binding";
+ internal const string TrustServerCertificate = "Trust Server Certificate";
+ internal const string TypeSystemVersion = "Type System Version";
+ internal const string UserInstance = "User Instance";
+ internal const string WorkstationID = "Workstation ID";
+ internal const string ConnectRetryCount = "Connect Retry Count";
+ internal const string ConnectRetryInterval = "Connect Retry Interval";
+ internal const string Authentication = "Authentication";
+ internal const string ColumnEncryptionSetting = "Column Encryption Setting";
+ internal const string EnclaveAttestationUrl = "Enclave Attestation Url";
+ internal const string AttestationProtocol = "Attestation Protocol";
+ internal const string IPAddressPreference = "IP Address Preference";
+ internal const string ServerSPN = "Server SPN";
+ internal const string FailoverPartnerSPN = "Failover Partner SPN";
+
+ // common keywords (OleDb, OracleClient, SqlClient)
+ internal const string DataSource = "Data Source";
+ internal const string IntegratedSecurity = "Integrated Security";
+ internal const string Password = "Password";
+ internal const string PersistSecurityInfo = "Persist Security Info";
+ internal const string UserID = "User ID";
+
+ // managed pooling (OracleClient, SqlClient)
+ internal const string Enlist = "Enlist";
+ internal const string LoadBalanceTimeout = "Load Balance Timeout";
+ internal const string MaxPoolSize = "Max Pool Size";
+ internal const string Pooling = "Pooling";
+ internal const string MinPoolSize = "Min Pool Size";
+ internal const string PoolBlockingPeriod = "Pool Blocking Period";
+ }
+
+ internal static class DbConnectionStringSynonyms
+ {
+#if NETFRAMEWORK
+ //internal const string TransparentNetworkIPResolution = TRANSPARENTNETWORKIPRESOLUTION;
+ internal const string TRANSPARENTNETWORKIPRESOLUTION = "transparentnetworkipresolution";
+#endif
+ //internal const string ApplicationName = APP;
+ internal const string APP = "app";
+
+ // internal const string IPAddressPreference = IPADDRESSPREFERENCE;
+ internal const string IPADDRESSPREFERENCE = "ipaddresspreference";
+
+ //internal const string ApplicationIntent = APPLICATIONINTENT;
+ internal const string APPLICATIONINTENT = "applicationintent";
+
+ //internal const string AttachDBFilename = EXTENDEDPROPERTIES+","+INITIALFILENAME;
+ internal const string EXTENDEDPROPERTIES = "extended properties";
+ internal const string INITIALFILENAME = "initial file name";
+
+ // internal const string HostNameInCertificate = HOSTNAMEINCERTIFICATE;
+ internal const string HOSTNAMEINCERTIFICATE = "hostnameincertificate";
+
+ //internal const string ConnectTimeout = CONNECTIONTIMEOUT+","+TIMEOUT;
+ internal const string CONNECTIONTIMEOUT = "connection timeout";
+ internal const string TIMEOUT = "timeout";
+
+ //internal const string ConnectRetryCount = CONNECTRETRYCOUNT;
+ internal const string CONNECTRETRYCOUNT = "connectretrycount";
+
+ //internal const string ConnectRetryInterval = CONNECTRETRYINTERVAL;
+ internal const string CONNECTRETRYINTERVAL = "connectretryinterval";
+
+ //internal const string CurrentLanguage = LANGUAGE;
+ internal const string LANGUAGE = "language";
+
+ //internal const string OraDataSource = SERVER;
+ //internal const string SqlDataSource = ADDR+","+ADDRESS+","+SERVER+","+NETWORKADDRESS;
+ internal const string ADDR = "addr";
+ internal const string ADDRESS = "address";
+ internal const string SERVER = "server";
+ internal const string NETWORKADDRESS = "network address";
+
+ //internal const string InitialCatalog = DATABASE;
+ internal const string DATABASE = "database";
+
+ //internal const string IntegratedSecurity = TRUSTEDCONNECTION;
+ internal const string TRUSTEDCONNECTION = "trusted_connection"; // underscore introduced in everett
+
+ //internal const string LoadBalanceTimeout = ConnectionLifetime;
+ internal const string ConnectionLifetime = "connection lifetime";
+
+ //internal const string MultipleActiveResultSets = MULTIPLEACTIVERESULTSETS;
+ internal const string MULTIPLEACTIVERESULTSETS = "multipleactiveresultsets";
+
+ //internal const string MultiSubnetFailover = MULTISUBNETFAILOVER;
+ internal const string MULTISUBNETFAILOVER = "multisubnetfailover";
+
+ //internal const string NetworkLibrary = NET+","+NETWORK;
+ internal const string NET = "net";
+ internal const string NETWORK = "network";
+
+ //internal const string PoolBlockingPeriod = POOLBLOCKINGPERIOD;
+ internal const string POOLBLOCKINGPERIOD = "poolblockingperiod";
+
+ //internal const string Password = Pwd;
+ internal const string Pwd = "pwd";
+
+ //internal const string PersistSecurityInfo = PERSISTSECURITYINFO;
+ internal const string PERSISTSECURITYINFO = "persistsecurityinfo";
+
+ //internal const string TrustServerCertificate = TRUSTSERVERCERTIFICATE;
+ internal const string TRUSTSERVERCERTIFICATE = "trustservercertificate";
+
+ //internal const string UserID = UID+","+User;
+ internal const string UID = "uid";
+ internal const string User = "user";
+
+ //internal const string WorkstationID = WSID;
+ internal const string WSID = "wsid";
+
+ //internal const string server SPNs
+ internal const string ServerSPN = "ServerSPN";
+ internal const string FailoverPartnerSPN = "FailoverPartnerSPN";
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs
new file mode 100644
index 0000000000..a30b462092
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs
@@ -0,0 +1,291 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+
+namespace Microsoft.Data.Common
+{
+ internal class MultipartIdentifier
+ {
+ private const int MaxParts = 4;
+ internal const int ServerIndex = 0;
+ internal const int CatalogIndex = 1;
+ internal const int SchemaIndex = 2;
+ internal const int TableIndex = 3;
+
+ /*
+ Left quote strings need to correspond 1 to 1 with the right quote strings
+ example: "ab" "cd", passed in for the left and the right quote
+ would set a or b as a starting quote character.
+ If a is the starting quote char then c would be the ending quote char
+ otherwise if b is the starting quote char then d would be the ending quote character.
+ */
+ internal static string[] ParseMultipartIdentifier(string name, string leftQuote, string rightQuote, string property, bool ThrowOnEmptyMultipartName)
+ {
+ return ParseMultipartIdentifier(name, leftQuote, rightQuote, '.', MaxParts, true, property, ThrowOnEmptyMultipartName);
+ }
+
+ private enum MPIState
+ {
+ MPI_Value,
+ MPI_ParseNonQuote,
+ MPI_LookForSeparator,
+ MPI_LookForNextCharOrSeparator,
+ MPI_ParseQuote,
+ MPI_RightQuote,
+ }
+
+ /* Core function for parsing the multipart identifier string.
+ * parameters: name - string to parse
+ * leftquote: set of characters which are valid quoting characters to initiate a quote
+ * rightquote: set of characters which are valid to stop a quote, array index's correspond to the leftquote array.
+ * separator: separator to use
+ * limit: number of names to parse out
+ * removequote:to remove the quotes on the returned string
+ */
+ private static void IncrementStringCount(string name, string[] ary, ref int position, string property)
+ {
+ ++position;
+ int limit = ary.Length;
+ if (position >= limit)
+ {
+ throw ADP.InvalidMultipartNameToManyParts(property, name, limit);
+ }
+ ary[position] = string.Empty;
+ }
+
+ private static bool IsWhitespace(char ch)
+ {
+ return char.IsWhiteSpace(ch);
+ }
+
+ internal static string[] ParseMultipartIdentifier(string name, string leftQuote, string rightQuote, char separator, int limit, bool removequotes, string property, bool ThrowOnEmptyMultipartName)
+ {
+ if (limit <= 0)
+ {
+ throw ADP.InvalidMultipartNameToManyParts(property, name, limit);
+ }
+
+ if (-1 != leftQuote.IndexOf(separator) || -1 != rightQuote.IndexOf(separator) || leftQuote.Length != rightQuote.Length)
+ {
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+
+ string[] parsedNames = new string[limit]; // return string array
+ int stringCount = 0; // index of current string in the buffer
+ MPIState state = MPIState.MPI_Value; // Initialize the starting state
+
+ StringBuilder sb = new StringBuilder(name.Length); // String buffer to hold the string being currently built, init the string builder so it will never be resized
+ StringBuilder whitespaceSB = null; // String buffer to hold whitespace used when parsing nonquoted strings 'a b . c d' = 'a b' and 'c d'
+ char rightQuoteChar = ' '; // Right quote character to use given the left quote character found.
+ for (int index = 0; index < name.Length; ++index)
+ {
+ char testchar = name[index];
+ switch (state)
+ {
+ case MPIState.MPI_Value:
+ {
+ int quoteIndex;
+ if (IsWhitespace(testchar))
+ { // Is White Space then skip the whitespace
+ continue;
+ }
+ else
+ if (testchar == separator)
+ { // If we found a separator, no string was found, initialize the string we are parsing to Empty and the next one to Empty.
+ // This is NOT a redundant setting of string.Empty it solves the case where we are parsing ".foo" and we should be returning null, null, empty, foo
+ parsedNames[stringCount] = string.Empty;
+ IncrementStringCount(name, parsedNames, ref stringCount, property);
+ }
+ else
+ if (-1 != (quoteIndex = leftQuote.IndexOf(testchar)))
+ { // If we are a left quote
+ rightQuoteChar = rightQuote[quoteIndex]; // record the corresponding right quote for the left quote
+ sb.Length = 0;
+ if (!removequotes)
+ {
+ sb.Append(testchar);
+ }
+ state = MPIState.MPI_ParseQuote;
+ }
+ else
+ if (-1 != rightQuote.IndexOf(testchar))
+ { // If we shouldn't see a right quote
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+ else
+ {
+ sb.Length = 0;
+ sb.Append(testchar);
+ state = MPIState.MPI_ParseNonQuote;
+ }
+ break;
+ }
+
+ case MPIState.MPI_ParseNonQuote:
+ {
+ if (testchar == separator)
+ {
+ parsedNames[stringCount] = sb.ToString(); // set the currently parsed string
+ IncrementStringCount(name, parsedNames, ref stringCount, property);
+ state = MPIState.MPI_Value;
+ }
+ else // Quotes are not valid inside a non-quoted name
+ if (-1 != rightQuote.IndexOf(testchar))
+ {
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+ else
+ if (-1 != leftQuote.IndexOf(testchar))
+ {
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+ else
+ if (IsWhitespace(testchar))
+ { // If it is Whitespace
+ parsedNames[stringCount] = sb.ToString(); // Set the currently parsed string
+ if (null == whitespaceSB)
+ {
+ whitespaceSB = new StringBuilder();
+ }
+ whitespaceSB.Length = 0;
+ whitespaceSB.Append(testchar); // start to record the whitespace, if we are parsing a name like "foo bar" we should return "foo bar"
+ state = MPIState.MPI_LookForNextCharOrSeparator;
+ }
+ else
+ {
+ sb.Append(testchar);
+ }
+ break;
+ }
+
+ case MPIState.MPI_LookForNextCharOrSeparator:
+ {
+ if (!IsWhitespace(testchar))
+ { // If it is not whitespace
+ if (testchar == separator)
+ {
+ IncrementStringCount(name, parsedNames, ref stringCount, property);
+ state = MPIState.MPI_Value;
+ }
+ else
+ { // If its not a separator and not whitespace
+ sb.Append(whitespaceSB);
+ sb.Append(testchar);
+ parsedNames[stringCount] = sb.ToString(); // Need to set the name here in case the string ends here.
+ state = MPIState.MPI_ParseNonQuote;
+ }
+ }
+ else
+ {
+ whitespaceSB.Append(testchar);
+ }
+ break;
+ }
+
+ case MPIState.MPI_ParseQuote:
+ {
+ if (testchar == rightQuoteChar)
+ { // if se are on a right quote see if we are escaping the right quote or ending the quoted string
+ if (!removequotes)
+ {
+ sb.Append(testchar);
+ }
+ state = MPIState.MPI_RightQuote;
+ }
+ else
+ {
+ sb.Append(testchar); // Append what we are currently parsing
+ }
+ break;
+ }
+
+ case MPIState.MPI_RightQuote:
+ {
+ if (testchar == rightQuoteChar)
+ { // If the next char is another right quote then we were escaping the right quote
+ sb.Append(testchar);
+ state = MPIState.MPI_ParseQuote;
+ }
+ else
+ if (testchar == separator)
+ { // If its a separator then record what we've parsed
+ parsedNames[stringCount] = sb.ToString();
+ IncrementStringCount(name, parsedNames, ref stringCount, property);
+ state = MPIState.MPI_Value;
+ }
+ else
+ if (!IsWhitespace(testchar))
+ { // If it is not whitespace we got problems
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+ else
+ { // It is a whitespace character so the following char should be whitespace, separator, or end of string anything else is bad
+ parsedNames[stringCount] = sb.ToString();
+ state = MPIState.MPI_LookForSeparator;
+ }
+ break;
+ }
+
+ case MPIState.MPI_LookForSeparator:
+ {
+ if (!IsWhitespace(testchar))
+ { // If it is not whitespace
+ if (testchar == separator)
+ { // If it is a separator
+ IncrementStringCount(name, parsedNames, ref stringCount, property);
+ state = MPIState.MPI_Value;
+ }
+ else
+ { // Otherwise not a separator
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+ }
+ break;
+ }
+ }
+ }
+
+ // Resolve final states after parsing the string
+ switch (state)
+ {
+ case MPIState.MPI_Value: // These states require no extra action
+ case MPIState.MPI_LookForSeparator:
+ case MPIState.MPI_LookForNextCharOrSeparator:
+ break;
+
+ case MPIState.MPI_ParseNonQuote: // Dump what ever was parsed
+ case MPIState.MPI_RightQuote:
+ parsedNames[stringCount] = sb.ToString();
+ break;
+
+ case MPIState.MPI_ParseQuote: // Invalid Ending States
+ default:
+ throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name);
+ }
+
+ if (parsedNames[0] == null)
+ {
+ if (ThrowOnEmptyMultipartName)
+ {
+ throw ADP.InvalidMultipartName(property, name); // Name is entirely made up of whitespace
+ }
+ }
+ else
+ {
+ // Shuffle the parsed name, from left justification to right justification, i.e. [a][b][null][null] goes to [null][null][a][b]
+ int offset = limit - stringCount - 1;
+ if (offset > 0)
+ {
+ for (int x = limit - 1; x >= offset; --x)
+ {
+ parsedNames[x] = parsedNames[x - offset];
+ parsedNames[x - offset] = null;
+ }
+ }
+ }
+ return parsedNames;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs
new file mode 100644
index 0000000000..f0cfc71d53
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs
@@ -0,0 +1,53 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Diagnostics;
+using System.Runtime.Serialization;
+
+namespace Microsoft.Data.Common
+{
+ [Serializable]
+ internal sealed class NameValuePair
+ {
+ readonly private string _name;
+ readonly private string _value;
+ [OptionalField(VersionAdded = 2)]
+ readonly private int _length;
+ private NameValuePair _next;
+
+ internal NameValuePair(string name, string value, int length)
+ {
+ Debug.Assert(!string.IsNullOrEmpty(name), "empty keyname");
+ _name = name;
+ _value = value;
+ _length = length;
+ }
+
+ internal int Length
+ {
+ get
+ {
+ Debug.Assert(0 < _length, "NameValuePair zero Length usage");
+ return _length;
+ }
+ }
+
+ internal string Name => _name;
+ internal string Value => _value;
+
+ internal NameValuePair Next
+ {
+ get => _next;
+ set
+ {
+ if ((null != _next) || (null == value))
+ {
+ throw ADP.InternalError(ADP.InternalErrorCode.NameValuePairNext);
+ }
+ _next = value;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs
new file mode 100644
index 0000000000..e8a49ffa33
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs
@@ -0,0 +1,56 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Diagnostics;
+using Microsoft.Data.SqlClient;
+
+namespace Microsoft.Data
+{
+ internal static class ExceptionBuilder
+ {
+ // The class defines the exceptions that are specific to the DataSet.
+ // The class contains functions that take the proper informational variables and then construct
+ // the appropriate exception with an error string obtained from the resource Data.txt.
+ // The exception is then returned to the caller, so that the caller may then throw from its
+ // location so that the catcher of the exception will have the appropriate call stack.
+ // This class is used so that there will be compile time checking of error messages.
+ // The resource Data.txt will ensure proper string text based on the appropriate
+ // locale.
+
+ private static void TraceException(string trace, Exception e)
+ {
+ Debug.Assert(null != e, "TraceException: null Exception");
+ if (null != e)
+ {
+ SqlClientEventSource.Log.TryAdvancedTraceEvent(trace, e.Message);
+ try
+ {
+ SqlClientEventSource.Log.TryAdvancedTraceEvent(" Environment StackTrace = '{0}'", Environment.StackTrace);
+ }
+ catch (System.Security.SecurityException)
+ {
+ // if you don't have permission - you don't get the stack trace
+ }
+ }
+ }
+
+ internal static void TraceExceptionAsReturnValue(Exception e)
+ {
+ TraceException(" Message='{0}'", e);
+ }
+
+ internal static ArgumentException _Argument(string error)
+ {
+ ArgumentException e = new ArgumentException(error);
+ ExceptionBuilder.TraceExceptionAsReturnValue(e);
+ return e;
+ }
+
+ internal static Exception InvalidOffsetLength()
+ {
+ return _Argument(StringsHelper.GetString(Strings.Data_InvalidOffsetLength));
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs
new file mode 100644
index 0000000000..537a4ac0a1
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs
@@ -0,0 +1,40 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.Serialization;
+using Microsoft.Data.Common;
+
+namespace Microsoft.Data
+{
+ ///
+ [Serializable]
+ [System.Runtime.CompilerServices.TypeForwardedFrom("System.Data, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
+ public sealed class OperationAbortedException : SystemException
+ {
+ private OperationAbortedException(string message, Exception innerException) : base(message, innerException)
+ {
+ HResult = unchecked((int)0x80131936);
+ }
+
+ private OperationAbortedException(SerializationInfo info, StreamingContext context) : base(info, context)
+ {
+ }
+
+ internal static OperationAbortedException Aborted(Exception inner)
+ {
+ OperationAbortedException e;
+ if (inner == null)
+ {
+ e = new OperationAbortedException(Strings.ADP_OperationAborted, null);
+ }
+ else
+ {
+ e = new OperationAbortedException(Strings.ADP_OperationAbortedExceptionMessage, inner);
+ }
+ ADP.TraceExceptionAsReturnValue(e);
+ return e;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs
new file mode 100644
index 0000000000..ec6b695429
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs
@@ -0,0 +1,112 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Diagnostics;
+using System.Runtime.ConstrainedExecution;
+using System.Threading;
+
+namespace Microsoft.Data.ProviderBase
+{
+ ///
+ /// Represents the context of an authentication attempt when using the new active directory based authentication mechanisms.
+ /// All data members, except_isUpdateInProgressCounter, should be immutable.
+ ///
+ sealed internal class DbConnectionPoolAuthenticationContext
+ {
+ ///
+ /// The value expected in _isUpdateInProgress if a thread has taken a lock on this context,
+ /// to perform the update on the context.
+ ///
+ private const int STATUS_LOCKED = 1;
+
+ ///
+ /// The value expected in _isUpdateInProgress if no thread has taken a lock on this context.
+ ///
+ private const int STATUS_UNLOCKED = 0;
+
+ ///
+ /// Access Token, which is obtained from Active Directory Authentication Library for SQL Server, and needs to be sent to SQL Server
+ /// as part of TDS Token type Federated Authentication Token.
+ ///
+ private readonly byte[] _accessToken;
+
+ ///
+ /// Expiration time of the above access token.
+ ///
+ private readonly DateTime _expirationTime;
+
+ ///
+ /// A member which is used to achieve a lock to control refresh attempt on this context.
+ ///
+ private int _isUpdateInProgress;
+
+ ///
+ /// Constructor.
+ ///
+ /// Access Token that will be used to connect to SQL Server. Carries identity information about a user.
+ /// The expiration time in UTC for the above accessToken.
+ internal DbConnectionPoolAuthenticationContext(byte[] accessToken, DateTime expirationTime)
+ {
+
+ Debug.Assert(accessToken != null && accessToken.Length > 0);
+ Debug.Assert(expirationTime > DateTime.MinValue && expirationTime < DateTime.MaxValue);
+
+ _accessToken = accessToken;
+ _expirationTime = expirationTime;
+ _isUpdateInProgress = STATUS_UNLOCKED;
+ }
+
+ ///
+ /// Static Method.
+ /// Given two contexts, choose one to update in the cache. Chooses based on expiration time.
+ ///
+ /// Context1
+ /// Context2
+ internal static DbConnectionPoolAuthenticationContext ChooseAuthenticationContextToUpdate(DbConnectionPoolAuthenticationContext context1, DbConnectionPoolAuthenticationContext context2)
+ {
+
+ Debug.Assert(context1 != null, "context1 should not be null.");
+ Debug.Assert(context2 != null, "context2 should not be null.");
+
+ return context1.ExpirationTime > context2.ExpirationTime ? context1 : context2;
+ }
+
+ internal byte[] AccessToken
+ {
+ get
+ {
+ return _accessToken;
+ }
+ }
+
+ internal DateTime ExpirationTime
+ {
+ get
+ {
+ return _expirationTime;
+ }
+ }
+
+ ///
+ /// Try locking the variable _isUpdateInProgressCounter and return if this thread got the lock to update.
+ /// Whichever thread got the chance to update this variable to 1 wins the lock.
+ ///
+ internal bool LockToUpdate()
+ {
+ int oldValue = Interlocked.CompareExchange(ref _isUpdateInProgress, STATUS_LOCKED, STATUS_UNLOCKED);
+ return (oldValue == STATUS_UNLOCKED);
+ }
+
+ ///
+ /// Release the lock which was obtained through LockToUpdate.
+ ///
+ [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
+ internal void ReleaseLockToUpdate()
+ {
+ int oldValue = Interlocked.CompareExchange(ref _isUpdateInProgress, STATUS_UNLOCKED, STATUS_LOCKED);
+ Debug.Assert(oldValue == STATUS_LOCKED);
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs
new file mode 100644
index 0000000000..a6f15ca999
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs
@@ -0,0 +1,112 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Diagnostics;
+
+namespace Microsoft.Data.ProviderBase
+{
+ ///
+ /// Represents the key of dbConnectionPoolAuthenticationContext.
+ /// All data members should be immutable and so, hashCode is pre-computed.
+ ///
+ sealed internal class DbConnectionPoolAuthenticationContextKey
+ {
+ ///
+ /// Security Token Service Authority.
+ ///
+ private readonly string _stsAuthority;
+
+ ///
+ /// Service Principal Name.
+ ///
+ private readonly string _servicePrincipalName;
+
+ ///
+ /// Pre-Computed Hash Code.
+ ///
+ private readonly int _hashCode;
+
+ internal string StsAuthority
+ {
+ get
+ {
+ return _stsAuthority;
+ }
+ }
+
+ internal string ServicePrincipalName
+ {
+ get
+ {
+ return _servicePrincipalName;
+ }
+ }
+
+ ///
+ /// Constructor for the type.
+ ///
+ /// Token Endpoint URL
+ /// SPN representing the SQL service in an active directory.
+ internal DbConnectionPoolAuthenticationContextKey(string stsAuthority, string servicePrincipalName)
+ {
+ Debug.Assert(!string.IsNullOrWhiteSpace(stsAuthority));
+ Debug.Assert(!string.IsNullOrWhiteSpace(servicePrincipalName));
+
+ _stsAuthority = stsAuthority;
+ _servicePrincipalName = servicePrincipalName;
+
+ // Pre-compute hash since data members are not going to change.
+ _hashCode = ComputeHashCode();
+ }
+
+ ///
+ /// Override the default Equals implementation.
+ ///
+ ///
+ ///
+ public override bool Equals(object obj)
+ {
+ if (obj == null)
+ {
+ return false;
+ }
+
+ DbConnectionPoolAuthenticationContextKey otherKey = obj as DbConnectionPoolAuthenticationContextKey;
+ if (otherKey == null)
+ {
+ return false;
+ }
+
+ return (String.Equals(StsAuthority, otherKey.StsAuthority, StringComparison.InvariantCultureIgnoreCase)
+ && String.Equals(ServicePrincipalName, otherKey.ServicePrincipalName, StringComparison.InvariantCultureIgnoreCase));
+ }
+
+ ///
+ /// Override the default GetHashCode implementation.
+ ///
+ ///
+ public override int GetHashCode()
+ {
+ return _hashCode;
+ }
+
+ ///
+ /// Compute the hash code for this object.
+ ///
+ ///
+ private int ComputeHashCode()
+ {
+ int hashCode = 33;
+
+ unchecked
+ {
+ hashCode = (hashCode * 17) + StsAuthority.GetHashCode();
+ hashCode = (hashCode * 17) + ServicePrincipalName.GetHashCode();
+ }
+
+ return hashCode;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs
new file mode 100644
index 0000000000..7568340594
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs
@@ -0,0 +1,312 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+using Microsoft.Data.Common;
+using Microsoft.Data.SqlClient;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+
+namespace Microsoft.Data.ProviderBase
+{
+ // set_ConnectionString calls DbConnectionFactory.GetConnectionPoolGroup
+ // when not found a new pool entry is created and potentially added
+ // DbConnectionPoolGroup starts in the Active state
+
+ // Open calls DbConnectionFactory.GetConnectionPool
+ // if the existing pool entry is Disabled, GetConnectionPoolGroup is called for a new entry
+ // DbConnectionFactory.GetConnectionPool calls DbConnectionPoolGroup.GetConnectionPool
+
+ // DbConnectionPoolGroup.GetConnectionPool will return pool for the current identity
+ // or null if identity is restricted or pooling is disabled or state is disabled at time of add
+ // state changes are Active->Active, Idle->Active
+
+ // DbConnectionFactory.PruneConnectionPoolGroups calls Prune
+ // which will QueuePoolForRelease on all empty pools
+ // and once no pools remain, change state from Active->Idle->Disabled
+ // Once Disabled, factory can remove its reference to the pool entry
+
+ sealed internal class DbConnectionPoolGroup
+ {
+ private readonly DbConnectionOptions _connectionOptions;
+ private readonly DbConnectionPoolKey _poolKey;
+ private readonly DbConnectionPoolGroupOptions _poolGroupOptions;
+ private ConcurrentDictionary _poolCollection;
+
+ private int _state; // see PoolGroupState* below
+
+ private DbConnectionPoolGroupProviderInfo _providerInfo;
+ private DbMetaDataFactory _metaDataFactory;
+
+ private static int s_objectTypeCount; // EventSource counter
+
+ // always lock this before changing _state, we don't want to move out of the 'Disabled' state
+ // PoolGroupStateUninitialized = 0;
+ private const int PoolGroupStateActive = 1; // initial state, GetPoolGroup from cache, connection Open
+ private const int PoolGroupStateIdle = 2; // all pools are pruned via Clear
+ private const int PoolGroupStateDisabled = 4; // factory pool entry pruning method
+
+ internal DbConnectionPoolGroup(DbConnectionOptions connectionOptions, DbConnectionPoolKey key, DbConnectionPoolGroupOptions poolGroupOptions)
+ {
+ Debug.Assert(null != connectionOptions, "null connection options");
+#if NETFRAMEWORK
+ Debug.Assert(null == poolGroupOptions || ADP.s_isWindowsNT, "should not have pooling options on Win9x");
+#endif
+
+ _connectionOptions = connectionOptions;
+ _poolKey = key;
+ _poolGroupOptions = poolGroupOptions;
+
+ // always lock this object before changing state
+ // HybridDictionary does not create any sub-objects until add
+ // so it is safe to use for non-pooled connection as long as
+ // we check _poolGroupOptions first
+ _poolCollection = new ConcurrentDictionary();
+ _state = PoolGroupStateActive;
+ }
+
+ internal DbConnectionOptions ConnectionOptions => _connectionOptions;
+
+ internal DbConnectionPoolKey PoolKey => _poolKey;
+
+ internal DbConnectionPoolGroupProviderInfo ProviderInfo
+ {
+ get
+ {
+ return _providerInfo;
+ }
+ set
+ {
+ _providerInfo = value;
+ if (null != value)
+ {
+ _providerInfo.PoolGroup = this;
+ }
+ }
+ }
+
+ internal bool IsDisabled => (PoolGroupStateDisabled == _state);
+
+ internal int ObjectID { get; } = Interlocked.Increment(ref s_objectTypeCount);
+
+ internal DbConnectionPoolGroupOptions PoolGroupOptions => _poolGroupOptions;
+
+ internal DbMetaDataFactory MetaDataFactory
+ {
+ get
+ {
+ return _metaDataFactory;
+ }
+
+ set
+ {
+ _metaDataFactory = value;
+ }
+ }
+
+ internal int Clear()
+ {
+ // must be multi-thread safe with competing calls by Clear and Prune via background thread
+ // will return the number of connections in the group after clearing has finished
+
+ // First, note the old collection and create a new collection to be used
+ ConcurrentDictionary oldPoolCollection = null;
+ lock (this)
+ {
+ if (_poolCollection.Count > 0)
+ {
+ oldPoolCollection = _poolCollection;
+ _poolCollection = new ConcurrentDictionary();
+ }
+ }
+
+ // Then, if a new collection was created, release the pools from the old collection
+ if (oldPoolCollection != null)
+ {
+ foreach (KeyValuePair entry in oldPoolCollection)
+ {
+ DbConnectionPool pool = entry.Value;
+ if (pool != null)
+ {
+ DbConnectionFactory connectionFactory = pool.ConnectionFactory;
+#if NETFRAMEWORK
+ connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Decrement();
+#endif
+ connectionFactory.QueuePoolForRelease(pool, true);
+ }
+ }
+ }
+
+ // Finally, return the pool collection count - this may be non-zero if something was added while we were clearing
+ return _poolCollection.Count;
+ }
+
+ internal DbConnectionPool GetConnectionPool(DbConnectionFactory connectionFactory)
+ {
+ // When this method returns null it indicates that the connection
+ // factory should not use pooling.
+
+ // We don't support connection pooling on Win9x;
+ // PoolGroupOptions will only be null when we're not supposed to pool
+ // connections.
+ DbConnectionPool pool = null;
+ if (null != _poolGroupOptions)
+ {
+#if NETFRAMEWORK
+ Debug.Assert(ADP.s_isWindowsNT, "should not be pooling on Win9x");
+#endif
+
+ DbConnectionPoolIdentity currentIdentity = DbConnectionPoolIdentity.NoIdentity;
+
+ if (_poolGroupOptions.PoolByIdentity)
+ {
+ // if we're pooling by identity (because integrated security is
+ // being used for these connections) then we need to go out and
+ // search for the connectionPool that matches the current identity.
+
+ currentIdentity = DbConnectionPoolIdentity.GetCurrent();
+
+ // If the current token is restricted in some way, then we must
+ // not attempt to pool these connections.
+ if (currentIdentity.IsRestricted)
+ {
+ currentIdentity = null;
+ }
+ }
+
+ if (null != currentIdentity)
+ {
+ if (!_poolCollection.TryGetValue(currentIdentity, out pool)) // find the pool
+ {
+ lock (this)
+ {
+ // Did someone already add it to the list?
+ if (!_poolCollection.TryGetValue(currentIdentity, out pool))
+ {
+ DbConnectionPoolProviderInfo connectionPoolProviderInfo = connectionFactory.CreateConnectionPoolProviderInfo(ConnectionOptions);
+ DbConnectionPool newPool = new(connectionFactory, this, currentIdentity, connectionPoolProviderInfo);
+
+ if (MarkPoolGroupAsActive())
+ {
+ // If we get here, we know for certain that we there isn't
+ // a pool that matches the current identity, so we have to
+ // add the optimistically created one
+ newPool.Startup(); // must start pool before usage
+ bool addResult = _poolCollection.TryAdd(currentIdentity, newPool);
+ Debug.Assert(addResult, "No other pool with current identity should exist at this point");
+ SqlClientEventSource.Log.EnterActiveConnectionPool();
+#if NETFRAMEWORK
+ connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Increment();
+#endif
+ pool = newPool;
+ }
+ else
+ {
+ // else pool entry has been disabled so don't create new pools
+ Debug.Assert(PoolGroupStateDisabled == _state, "state should be disabled");
+
+ // don't need to call connectionFactory.QueuePoolForRelease(newPool) because
+ // pool callbacks were delayed and no risk of connections being created
+ newPool.Shutdown();
+ }
+ }
+ else
+ {
+ // else found an existing pool to use instead
+ Debug.Assert(PoolGroupStateActive == _state, "state should be active since a pool exists and lock holds");
+ }
+ }
+ }
+ // the found pool could be in any state
+ }
+ }
+
+ if (null == pool)
+ {
+ lock (this)
+ {
+ // keep the pool entry state active when not pooling
+ MarkPoolGroupAsActive();
+ }
+ }
+ return pool;
+ }
+
+ private bool MarkPoolGroupAsActive()
+ {
+ // when getting a connection, make the entry active if it was idle (but not disabled)
+ // must always lock this before calling
+
+ if (PoolGroupStateIdle == _state)
+ {
+ _state = PoolGroupStateActive;
+ SqlClientEventSource.Log.TryTraceEvent(" {0}, Active", ObjectID);
+ }
+ return (PoolGroupStateActive == _state);
+ }
+
+ internal bool Prune()
+ {
+ // must only call from DbConnectionFactory.PruneConnectionPoolGroups on background timer thread
+ // must lock(DbConnectionFactory._connectionPoolGroups.SyncRoot) before calling ReadyToRemove
+ // to avoid conflict with DbConnectionFactory.CreateConnectionPoolGroup replacing pool entry
+ lock (this)
+ {
+ if (_poolCollection.Count > 0)
+ {
+ var newPoolCollection = new ConcurrentDictionary();
+
+ foreach (KeyValuePair entry in _poolCollection)
+ {
+ DbConnectionPool pool = entry.Value;
+ if (pool != null)
+ {
+ // Actually prune the pool if there are no connections in the pool and no errors occurred.
+ // Empty pool during pruning indicates zero or low activity, but
+ // an error state indicates the pool needs to stay around to
+ // throttle new connection attempts.
+ if ((!pool.ErrorOccurred) && (0 == pool.Count))
+ {
+ // Order is important here. First we remove the pool
+ // from the collection of pools so no one will try
+ // to use it while we're processing and finally we put the
+ // pool into a list of pools to be released when they
+ // are completely empty.
+ DbConnectionFactory connectionFactory = pool.ConnectionFactory;
+#if NETFRAMEWORK
+ connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Decrement();
+#endif
+ connectionFactory.QueuePoolForRelease(pool, false);
+ }
+ else
+ {
+ newPoolCollection.TryAdd(entry.Key, entry.Value);
+ }
+ }
+ }
+ _poolCollection = newPoolCollection;
+ }
+
+ // must be pruning thread to change state and no connections
+ // otherwise pruning thread risks making entry disabled soon after user calls ClearPool
+ if (0 == _poolCollection.Count)
+ {
+ if (PoolGroupStateActive == _state)
+ {
+ _state = PoolGroupStateIdle;
+ SqlClientEventSource.Log.TryTraceEvent(" {0}, Idle", ObjectID);
+ }
+ else if (PoolGroupStateIdle == _state)
+ {
+ _state = PoolGroupStateDisabled;
+ SqlClientEventSource.Log.TryTraceEvent(" {0}, Disabled", ObjectID);
+ }
+ }
+ return (PoolGroupStateDisabled == _state);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs
new file mode 100644
index 0000000000..3eceb6d3e3
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Data.ProviderBase
+{
+ internal class DbConnectionPoolGroupProviderInfo
+ {
+ private DbConnectionPoolGroup _poolGroup;
+
+ internal DbConnectionPoolGroup PoolGroup
+ {
+ get
+ {
+ return _poolGroup;
+ }
+ set
+ {
+ _poolGroup = value;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs
new file mode 100644
index 0000000000..866453432c
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs
@@ -0,0 +1,73 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.Data.ProviderBase
+{
+ internal sealed class DbConnectionPoolGroupOptions
+ {
+ private readonly bool _poolByIdentity;
+ private readonly int _minPoolSize;
+ private readonly int _maxPoolSize;
+ private readonly int _creationTimeout;
+ private readonly TimeSpan _loadBalanceTimeout;
+ private readonly bool _hasTransactionAffinity;
+ private readonly bool _useLoadBalancing;
+
+ public DbConnectionPoolGroupOptions(
+ bool poolByIdentity,
+ int minPoolSize,
+ int maxPoolSize,
+ int creationTimeout,
+ int loadBalanceTimeout,
+ bool hasTransactionAffinity
+ )
+ {
+ _poolByIdentity = poolByIdentity;
+ _minPoolSize = minPoolSize;
+ _maxPoolSize = maxPoolSize;
+ _creationTimeout = creationTimeout;
+
+ if (0 != loadBalanceTimeout)
+ {
+ _loadBalanceTimeout = new TimeSpan(0, 0, loadBalanceTimeout);
+ _useLoadBalancing = true;
+ }
+
+ _hasTransactionAffinity = hasTransactionAffinity;
+ }
+
+ public int CreationTimeout
+ {
+ get { return _creationTimeout; }
+ }
+ public bool HasTransactionAffinity
+ {
+ get { return _hasTransactionAffinity; }
+ }
+ public TimeSpan LoadBalanceTimeout
+ {
+ get { return _loadBalanceTimeout; }
+ }
+ public int MaxPoolSize
+ {
+ get { return _maxPoolSize; }
+ }
+ public int MinPoolSize
+ {
+ get { return _minPoolSize; }
+ }
+ public bool PoolByIdentity
+ {
+ get { return _poolByIdentity; }
+ }
+ public bool UseLoadBalancing
+ {
+ get { return _useLoadBalancing; }
+ }
+ }
+}
+
+
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs
new file mode 100644
index 0000000000..5392795dff
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Data.ProviderBase
+{
+ internal class DbConnectionPoolProviderInfo
+ {
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs
new file mode 100644
index 0000000000..6e907d26e1
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs
@@ -0,0 +1,558 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Data.Common;
+using System;
+using System.Data;
+using System.Data.Common;
+using System.Diagnostics;
+using System.Globalization;
+using System.IO;
+
+namespace Microsoft.Data.ProviderBase
+{
+ internal class DbMetaDataFactory
+ {
+
+ private DataSet _metaDataCollectionsDataSet;
+ private string _normalizedServerVersion;
+ private string _serverVersionString;
+ // well known column names
+ private const string CollectionNameKey = "CollectionName";
+ private const string PopulationMechanismKey = "PopulationMechanism";
+ private const string PopulationStringKey = "PopulationString";
+ private const string MaximumVersionKey = "MaximumVersion";
+ private const string MinimumVersionKey = "MinimumVersion";
+ private const string DataSourceProductVersionNormalizedKey = "DataSourceProductVersionNormalized";
+ private const string DataSourceProductVersionKey = "DataSourceProductVersion";
+ private const string RestrictionNumberKey = "RestrictionNumber";
+ private const string NumberOfRestrictionsKey = "NumberOfRestrictions";
+ private const string RestrictionNameKey = "RestrictionName";
+ private const string ParameterNameKey = "ParameterName";
+
+ // population mechanisms
+ private const string DataTableKey = "DataTable";
+ private const string SqlCommandKey = "SQLCommand";
+ private const string PrepareCollectionKey = "PrepareCollection";
+
+ public DbMetaDataFactory(Stream xmlStream, string serverVersion, string normalizedServerVersion)
+ {
+ ADP.CheckArgumentNull(xmlStream, nameof(xmlStream));
+ ADP.CheckArgumentNull(serverVersion, nameof(serverVersion));
+ ADP.CheckArgumentNull(normalizedServerVersion, nameof(normalizedServerVersion));
+
+ LoadDataSetFromXml(xmlStream);
+
+ _serverVersionString = serverVersion;
+ _normalizedServerVersion = normalizedServerVersion;
+ }
+
+ protected DataSet CollectionDataSet => _metaDataCollectionsDataSet;
+
+ protected string ServerVersion => _serverVersionString;
+
+ protected string ServerVersionNormalized => _normalizedServerVersion;
+
+ protected DataTable CloneAndFilterCollection(string collectionName, string[] hiddenColumnNames)
+ {
+ DataTable destinationTable;
+ DataColumn[] filteredSourceColumns;
+ DataColumnCollection destinationColumns;
+ DataRow newRow;
+
+ DataTable sourceTable = _metaDataCollectionsDataSet.Tables[collectionName];
+
+ if ((sourceTable == null) || (collectionName != sourceTable.TableName))
+ {
+ throw ADP.DataTableDoesNotExist(collectionName);
+ }
+
+ destinationTable = new DataTable(collectionName)
+ {
+ Locale = CultureInfo.InvariantCulture
+ };
+ destinationColumns = destinationTable.Columns;
+
+ filteredSourceColumns = FilterColumns(sourceTable, hiddenColumnNames, destinationColumns);
+
+ foreach (DataRow row in sourceTable.Rows)
+ {
+ if (SupportedByCurrentVersion(row))
+ {
+ newRow = destinationTable.NewRow();
+ for (int i = 0; i < destinationColumns.Count; i++)
+ {
+ newRow[destinationColumns[i]] = row[filteredSourceColumns[i], DataRowVersion.Current];
+ }
+ destinationTable.Rows.Add(newRow);
+ newRow.AcceptChanges();
+ }
+ }
+
+ return destinationTable;
+ }
+
+ public void Dispose() => Dispose(true);
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ _normalizedServerVersion = null;
+ _serverVersionString = null;
+ _metaDataCollectionsDataSet.Dispose();
+ }
+ }
+
+ private DataTable ExecuteCommand(DataRow requestedCollectionRow, string[] restrictions, DbConnection connection)
+ {
+ DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections];
+ DataColumn populationStringColumn = metaDataCollectionsTable.Columns[PopulationStringKey];
+ DataColumn numberOfRestrictionsColumn = metaDataCollectionsTable.Columns[NumberOfRestrictionsKey];
+ DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[CollectionNameKey];
+
+ DataTable resultTable = null;
+
+ Debug.Assert(requestedCollectionRow != null);
+ string sqlCommand = requestedCollectionRow[populationStringColumn, DataRowVersion.Current] as string;
+ int numberOfRestrictions = (int)requestedCollectionRow[numberOfRestrictionsColumn, DataRowVersion.Current];
+ string collectionName = requestedCollectionRow[collectionNameColumn, DataRowVersion.Current] as string;
+
+ if ((restrictions != null) && (restrictions.Length > numberOfRestrictions))
+ {
+ throw ADP.TooManyRestrictions(collectionName);
+ }
+
+ DbCommand command = connection.CreateCommand();
+ command.CommandText = sqlCommand;
+ command.CommandTimeout = Math.Max(command.CommandTimeout, 180);
+
+ for (int i = 0; i < numberOfRestrictions; i++)
+ {
+
+ DbParameter restrictionParameter = command.CreateParameter();
+
+ if ((restrictions != null) && (restrictions.Length > i) && (restrictions[i] != null))
+ {
+ restrictionParameter.Value = restrictions[i];
+ }
+ else
+ {
+ // This is where we have to assign null to the value of the parameter.
+ restrictionParameter.Value = DBNull.Value;
+ }
+
+ restrictionParameter.ParameterName = GetParameterName(collectionName, i + 1);
+ restrictionParameter.Direction = ParameterDirection.Input;
+ command.Parameters.Add(restrictionParameter);
+ }
+
+ DbDataReader reader = null;
+ try
+ {
+ try
+ {
+ reader = command.ExecuteReader();
+ }
+ catch (Exception e)
+ {
+ if (!ADP.IsCatchableExceptionType(e))
+ {
+ throw;
+ }
+ throw ADP.QueryFailed(collectionName, e);
+ }
+
+ // Build a DataTable from the reader
+ resultTable = new DataTable(collectionName)
+ {
+ Locale = CultureInfo.InvariantCulture
+ };
+
+ DataTable schemaTable = reader.GetSchemaTable();
+ foreach (DataRow row in schemaTable.Rows)
+ {
+ resultTable.Columns.Add(row["ColumnName"] as string, (Type)row["DataType"] as Type);
+ }
+ object[] values = new object[resultTable.Columns.Count];
+ while (reader.Read())
+ {
+ reader.GetValues(values);
+ resultTable.Rows.Add(values);
+ }
+ }
+ finally
+ {
+ reader?.Dispose();
+ }
+ return resultTable;
+ }
+
+ private DataColumn[] FilterColumns(DataTable sourceTable, string[] hiddenColumnNames, DataColumnCollection destinationColumns)
+ {
+ int columnCount = 0;
+ foreach (DataColumn sourceColumn in sourceTable.Columns)
+ {
+ if (IncludeThisColumn(sourceColumn, hiddenColumnNames))
+ {
+ columnCount++;
+ }
+ }
+
+ if (columnCount == 0)
+ {
+ throw ADP.NoColumns();
+ }
+
+ int currentColumn = 0;
+ DataColumn[] filteredSourceColumns = new DataColumn[columnCount];
+
+ foreach (DataColumn sourceColumn in sourceTable.Columns)
+ {
+ if (IncludeThisColumn(sourceColumn, hiddenColumnNames))
+ {
+ DataColumn newDestinationColumn = new(sourceColumn.ColumnName, sourceColumn.DataType);
+ destinationColumns.Add(newDestinationColumn);
+ filteredSourceColumns[currentColumn] = sourceColumn;
+ currentColumn++;
+ }
+ }
+ return filteredSourceColumns;
+ }
+
+ internal DataRow FindMetaDataCollectionRow(string collectionName)
+ {
+ bool versionFailure;
+ bool haveExactMatch;
+ bool haveMultipleInexactMatches;
+ string candidateCollectionName;
+
+ DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections];
+ if (metaDataCollectionsTable == null)
+ {
+ throw ADP.InvalidXml();
+ }
+
+ DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[DbMetaDataColumnNames.CollectionName];
+
+ if ((null == collectionNameColumn) || (typeof(string) != collectionNameColumn.DataType))
+ {
+ throw ADP.InvalidXmlMissingColumn(DbMetaDataCollectionNames.MetaDataCollections, DbMetaDataColumnNames.CollectionName);
+ }
+
+ DataRow requestedCollectionRow = null;
+ string exactCollectionName = null;
+
+ // find the requested collection
+ versionFailure = false;
+ haveExactMatch = false;
+ haveMultipleInexactMatches = false;
+
+ foreach (DataRow row in metaDataCollectionsTable.Rows)
+ {
+
+ candidateCollectionName = row[collectionNameColumn, DataRowVersion.Current] as string;
+ if (string.IsNullOrEmpty(candidateCollectionName))
+ {
+ throw ADP.InvalidXmlInvalidValue(DbMetaDataCollectionNames.MetaDataCollections, DbMetaDataColumnNames.CollectionName);
+ }
+
+ if (ADP.CompareInsensitiveInvariant(candidateCollectionName, collectionName))
+ {
+ if (!SupportedByCurrentVersion(row))
+ {
+ versionFailure = true;
+ }
+ else
+ {
+ if (collectionName == candidateCollectionName)
+ {
+ if (haveExactMatch)
+ {
+ throw ADP.CollectionNameIsNotUnique(collectionName);
+ }
+ requestedCollectionRow = row;
+ exactCollectionName = candidateCollectionName;
+ haveExactMatch = true;
+ }
+ else if (!haveExactMatch)
+ {
+ // have an inexact match - ok only if it is the only one
+ if (exactCollectionName != null)
+ {
+ // can't fail here becasue we may still find an exact match
+ haveMultipleInexactMatches = true;
+ }
+ requestedCollectionRow = row;
+ exactCollectionName = candidateCollectionName;
+ }
+ }
+ }
+ }
+
+ if (requestedCollectionRow == null)
+ {
+ if (!versionFailure)
+ {
+ throw ADP.UndefinedCollection(collectionName);
+ }
+ else
+ {
+ throw ADP.UnsupportedVersion(collectionName);
+ }
+ }
+
+ if (!haveExactMatch && haveMultipleInexactMatches)
+ {
+ throw ADP.AmbiguousCollectionName(collectionName);
+ }
+
+ return requestedCollectionRow;
+
+ }
+
+ private void FixUpVersion(DataTable dataSourceInfoTable)
+ {
+ Debug.Assert(dataSourceInfoTable.TableName == DbMetaDataCollectionNames.DataSourceInformation);
+ DataColumn versionColumn = dataSourceInfoTable.Columns[DataSourceProductVersionKey];
+ DataColumn normalizedVersionColumn = dataSourceInfoTable.Columns[DataSourceProductVersionNormalizedKey];
+
+ if ((versionColumn == null) || (normalizedVersionColumn == null))
+ {
+ throw ADP.MissingDataSourceInformationColumn();
+ }
+
+ if (dataSourceInfoTable.Rows.Count != 1)
+ {
+ throw ADP.IncorrectNumberOfDataSourceInformationRows();
+ }
+
+ DataRow dataSourceInfoRow = dataSourceInfoTable.Rows[0];
+
+ dataSourceInfoRow[versionColumn] = _serverVersionString;
+ dataSourceInfoRow[normalizedVersionColumn] = _normalizedServerVersion;
+ dataSourceInfoRow.AcceptChanges();
+ }
+
+
+ private string GetParameterName(string neededCollectionName, int neededRestrictionNumber)
+ {
+ DataColumn collectionName = null;
+ DataColumn parameterName = null;
+ DataColumn restrictionName = null;
+ DataColumn restrictionNumber = null;
+
+ string result = null;
+
+ DataTable restrictionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.Restrictions];
+ if (restrictionsTable != null)
+ {
+ DataColumnCollection restrictionColumns = restrictionsTable.Columns;
+ if (restrictionColumns != null)
+ {
+ collectionName = restrictionColumns[DbMetaDataFactory.CollectionNameKey];
+ parameterName = restrictionColumns[ParameterNameKey];
+ restrictionName = restrictionColumns[RestrictionNameKey];
+ restrictionNumber = restrictionColumns[RestrictionNumberKey];
+ }
+ }
+
+ if ((parameterName == null) || (collectionName == null) || (restrictionName == null) || (restrictionNumber == null))
+ {
+ throw ADP.MissingRestrictionColumn();
+ }
+
+ foreach (DataRow restriction in restrictionsTable.Rows)
+ {
+
+ if (((string)restriction[collectionName] == neededCollectionName) &&
+ ((int)restriction[restrictionNumber] == neededRestrictionNumber) &&
+ (SupportedByCurrentVersion(restriction)))
+ {
+
+ result = (string)restriction[parameterName];
+ break;
+ }
+ }
+
+ if (result == null)
+ {
+ throw ADP.MissingRestrictionRow();
+ }
+
+ return result;
+ }
+
+ public virtual DataTable GetSchema(DbConnection connection, string collectionName, string[] restrictions)
+ {
+ Debug.Assert(_metaDataCollectionsDataSet != null);
+
+ DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections];
+ DataColumn populationMechanismColumn = metaDataCollectionsTable.Columns[PopulationMechanismKey];
+ DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[DbMetaDataColumnNames.CollectionName];
+
+ string[] hiddenColumns;
+
+ DataRow requestedCollectionRow = FindMetaDataCollectionRow(collectionName);
+ string exactCollectionName = requestedCollectionRow[collectionNameColumn, DataRowVersion.Current] as string;
+
+ if (!ADP.IsEmptyArray(restrictions))
+ {
+
+ for (int i = 0; i < restrictions.Length; i++)
+ {
+ if ((restrictions[i] != null) && (restrictions[i].Length > 4096))
+ {
+ // use a non-specific error because no new beta 2 error messages are allowed
+ // TODO: will add a more descriptive error in RTM
+ throw ADP.NotSupported();
+ }
+ }
+ }
+
+ string populationMechanism = requestedCollectionRow[populationMechanismColumn, DataRowVersion.Current] as string;
+
+ DataTable requestedSchema;
+ switch (populationMechanism)
+ {
+
+ case DataTableKey:
+ if (exactCollectionName == DbMetaDataCollectionNames.MetaDataCollections)
+ {
+ hiddenColumns = new string[2];
+ hiddenColumns[0] = PopulationMechanismKey;
+ hiddenColumns[1] = PopulationStringKey;
+ }
+ else
+ {
+ hiddenColumns = null;
+ }
+ // none of the datatable collections support restrictions
+ if (!ADP.IsEmptyArray(restrictions))
+ {
+ throw ADP.TooManyRestrictions(exactCollectionName);
+ }
+
+
+ requestedSchema = CloneAndFilterCollection(exactCollectionName, hiddenColumns);
+
+ // TODO: Consider an alternate method that doesn't involve special casing -- perhaps _prepareCollection
+
+ // for the data source information table we need to fix up the version columns at run time
+ // since the version is determined at run time
+ if (exactCollectionName == DbMetaDataCollectionNames.DataSourceInformation)
+ {
+ FixUpVersion(requestedSchema);
+ }
+ break;
+
+ case SqlCommandKey:
+ requestedSchema = ExecuteCommand(requestedCollectionRow, restrictions, connection);
+ break;
+
+ case PrepareCollectionKey:
+ requestedSchema = PrepareCollection(exactCollectionName, restrictions, connection);
+ break;
+
+ default:
+ throw ADP.UndefinedPopulationMechanism(populationMechanism);
+ }
+
+ return requestedSchema;
+ }
+
+ private bool IncludeThisColumn(DataColumn sourceColumn, string[] hiddenColumnNames)
+ {
+
+ bool result = true;
+ string sourceColumnName = sourceColumn.ColumnName;
+
+ switch (sourceColumnName)
+ {
+
+ case MinimumVersionKey:
+ case MaximumVersionKey:
+ result = false;
+ break;
+
+ default:
+ if (hiddenColumnNames == null)
+ {
+ break;
+ }
+ for (int i = 0; i < hiddenColumnNames.Length; i++)
+ {
+ if (hiddenColumnNames[i] == sourceColumnName)
+ {
+ result = false;
+ break;
+ }
+ }
+ break;
+ }
+
+ return result;
+ }
+
+ private void LoadDataSetFromXml(Stream XmlStream)
+ {
+ _metaDataCollectionsDataSet = new DataSet
+ {
+ Locale = System.Globalization.CultureInfo.InvariantCulture
+ };
+ _metaDataCollectionsDataSet.ReadXml(XmlStream);
+ }
+
+ protected virtual DataTable PrepareCollection(string collectionName, string[] restrictions, DbConnection connection)
+ {
+ throw ADP.NotSupported();
+ }
+
+ private bool SupportedByCurrentVersion(DataRow requestedCollectionRow)
+ {
+ bool result = true;
+ DataColumnCollection tableColumns = requestedCollectionRow.Table.Columns;
+ DataColumn versionColumn;
+ object version;
+
+ // check the minimum version first
+ versionColumn = tableColumns[MinimumVersionKey];
+ if (versionColumn != null)
+ {
+ version = requestedCollectionRow[versionColumn];
+ if (version != null)
+ {
+ if (version != DBNull.Value)
+ {
+ if (0 > string.Compare(_normalizedServerVersion, (string)version, StringComparison.OrdinalIgnoreCase))
+ {
+ result = false;
+ }
+ }
+ }
+ }
+
+ // if the minimum version was ok what about the maximum version
+ if (result)
+ {
+ versionColumn = tableColumns[MaximumVersionKey];
+ if (versionColumn != null)
+ {
+ version = requestedCollectionRow[versionColumn];
+ if (version != null)
+ {
+ if (version != DBNull.Value)
+ {
+ if (0 < string.Compare(_normalizedServerVersion, (string)version, StringComparison.OrdinalIgnoreCase))
+ {
+ result = false;
+ }
+ }
+ }
+ }
+ }
+ return result;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs
new file mode 100644
index 0000000000..41f67f9403
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs
@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Data;
+using System.Globalization;
+using Microsoft.Data.Common;
+
+namespace Microsoft.Data.ProviderBase
+{
+ internal sealed class FieldNameLookup
+ {
+ private readonly string[] _fieldNames;
+ private readonly int _defaultLocaleID;
+
+ private Dictionary _fieldNameLookup;
+ private CompareInfo _compareInfo;
+
+ public FieldNameLookup(string[] fieldNames, int defaultLocaleID)
+ {
+ _defaultLocaleID = defaultLocaleID;
+ if (fieldNames == null)
+ {
+ throw ADP.ArgumentNull(nameof(fieldNames));
+ }
+ _fieldNames = fieldNames;
+ }
+
+ public FieldNameLookup(IDataReader reader, int defaultLocaleID)
+ {
+ _defaultLocaleID = defaultLocaleID;
+ string[] fieldNames = new string[reader.FieldCount];
+ for (int i = 0; i < fieldNames.Length; ++i)
+ {
+ fieldNames[i] = reader.GetName(i);
+ }
+ _fieldNames = fieldNames;
+ }
+
+ public int GetOrdinal(string fieldName)
+ {
+ if (fieldName == null)
+ {
+ throw ADP.ArgumentNull(nameof(fieldName));
+ }
+ int index = IndexOf(fieldName);
+ if (index == -1)
+ {
+ throw ADP.IndexOutOfRange(fieldName);
+ }
+ return index;
+ }
+
+ private int IndexOf(string fieldName)
+ {
+ if (_fieldNameLookup == null)
+ {
+ GenerateLookup();
+ }
+ if (!_fieldNameLookup.TryGetValue(fieldName, out int index))
+ {
+ index = LinearIndexOf(fieldName, CompareOptions.IgnoreCase);
+ if (index == -1)
+ {
+ // do the slow search now (kana, width insensitive comparison)
+ index = LinearIndexOf(fieldName, ADP.DefaultCompareOptions);
+ }
+ }
+
+ return index;
+ }
+
+ private CompareInfo GetCompareInfo()
+ {
+ if (_defaultLocaleID != -1)
+ {
+ return CompareInfo.GetCompareInfo(_defaultLocaleID);
+ }
+ return CultureInfo.InvariantCulture.CompareInfo;
+ }
+
+ private int LinearIndexOf(string fieldName, CompareOptions compareOptions)
+ {
+ if (_compareInfo == null)
+ {
+ _compareInfo = GetCompareInfo();
+ }
+
+ for (int index = 0; index < _fieldNames.Length; index++)
+ {
+ if (_compareInfo.Compare(fieldName, _fieldNames[index], compareOptions) == 0)
+ {
+ _fieldNameLookup[fieldName] = index;
+ return index;
+ }
+ }
+ return -1;
+ }
+
+ private void GenerateLookup()
+ {
+ int length = _fieldNames.Length;
+ Dictionary lookup = new Dictionary(length);
+
+ // walk the field names from the end to the beginning so that if a name exists
+ // multiple times the first (from beginning to end) index of it is stored
+ // in the hash table
+ for (int index = length - 1; 0 <= index; --index)
+ {
+ string fieldName = _fieldNames[index];
+ lookup[fieldName] = index;
+ }
+ _fieldNameLookup = lookup;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs
new file mode 100644
index 0000000000..9948b223d1
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs
@@ -0,0 +1,185 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Data.Common;
+using System;
+using System.Diagnostics;
+
+namespace Microsoft.Data.ProviderBase
+{
+ // Purpose:
+ // Manages determining and tracking timeouts
+ //
+ // Intended use:
+ // Call StartXXXXTimeout() to get a timer with the given expiration point
+ // Get remaining time in appropriate format to pass to subsystem timeouts
+ // Check for timeout via IsExpired for checks in managed code.
+ // Simply abandon to GC when done.
+ internal class TimeoutTimer
+ {
+ //-------------------
+ // Fields
+ //-------------------
+ private long _timerExpire;
+ private bool _isInfiniteTimeout;
+ private long _originalTimerTicks;
+
+ //-------------------
+ // Timeout-setting methods
+ //-------------------
+
+ // Get a new timer that will expire in the given number of seconds
+ // For input, a value of zero seconds indicates infinite timeout
+ internal static TimeoutTimer StartSecondsTimeout(int seconds)
+ {
+ //--------------------
+ // Preconditions: None (seconds must conform to SetTimeoutSeconds requirements)
+
+ //--------------------
+ // Method body
+ var timeout = new TimeoutTimer();
+ timeout.SetTimeoutSeconds(seconds);
+
+ //---------------------
+ // Postconditions
+ Debug.Assert(timeout != null); // Need a valid timeouttimer if no error
+
+ return timeout;
+ }
+
+ // Get a new timer that will expire in the given number of milliseconds
+ // No current need to support infinite milliseconds timeout
+ internal static TimeoutTimer StartMillisecondsTimeout(long milliseconds)
+ {
+ //--------------------
+ // Preconditions
+ Debug.Assert(0 <= milliseconds);
+
+ //--------------------
+ // Method body
+ var timeout = new TimeoutTimer();
+ timeout._originalTimerTicks = milliseconds * TimeSpan.TicksPerMillisecond;
+ timeout._timerExpire = checked(ADP.TimerCurrent() + timeout._originalTimerTicks);
+ timeout._isInfiniteTimeout = false;
+
+ //---------------------
+ // Postconditions
+ Debug.Assert(timeout != null); // Need a valid timeouttimer if no error
+
+ return timeout;
+ }
+
+ //-------------------
+ // Methods for changing timeout
+ //-------------------
+
+ internal void SetTimeoutSeconds(int seconds)
+ {
+ //--------------------
+ // Preconditions
+ Debug.Assert(0 <= seconds || InfiniteTimeout == seconds); // no need to support negative seconds at present
+
+ //--------------------
+ // Method body
+ if (InfiniteTimeout == seconds)
+ {
+ _isInfiniteTimeout = true;
+ }
+ else
+ {
+ // Stash current time + timeout
+ _originalTimerTicks = ADP.TimerFromSeconds(seconds);
+ _timerExpire = checked(ADP.TimerCurrent() + _originalTimerTicks);
+ _isInfiniteTimeout = false;
+ }
+
+ //---------------------
+ // Postconditions:None
+ }
+
+ // Reset timer to original duration.
+ internal void Reset()
+ {
+ if (InfiniteTimeout == _originalTimerTicks)
+ {
+ _isInfiniteTimeout = true;
+ }
+ else
+ {
+ _timerExpire = checked(ADP.TimerCurrent() + _originalTimerTicks);
+ _isInfiniteTimeout = false;
+ }
+ }
+
+ //-------------------
+ // Timeout info properties
+ //-------------------
+
+ // Indicator for infinite timeout when starting a timer
+ internal static readonly long InfiniteTimeout = 0;
+
+ // Is this timer in an expired state?
+ internal bool IsExpired
+ {
+ get
+ {
+ return !IsInfinite && ADP.TimerHasExpired(_timerExpire);
+ }
+ }
+
+ // is this an infinite-timeout timer?
+ internal bool IsInfinite
+ {
+ get
+ {
+ return _isInfiniteTimeout;
+ }
+ }
+
+ // Special accessor for TimerExpire for use when thunking to legacy timeout methods.
+ internal long LegacyTimerExpire
+ {
+ get
+ {
+ return (_isInfiniteTimeout) ? long.MaxValue : _timerExpire;
+ }
+ }
+
+ // Returns milliseconds remaining trimmed to zero for none remaining
+ // and long.MaxValue for infinite
+ // This method should be preferred for internal calculations that are not
+ // yet common enough to code into the TimeoutTimer class itself.
+ internal long MillisecondsRemaining
+ {
+ get
+ {
+ //-------------------
+ // Preconditions: None
+
+ //-------------------
+ // Method Body
+ long milliseconds;
+ if (_isInfiniteTimeout)
+ {
+ milliseconds = long.MaxValue;
+ }
+ else
+ {
+ milliseconds = ADP.TimerRemainingMilliseconds(_timerExpire);
+ if (0 > milliseconds)
+ {
+ milliseconds = 0;
+ }
+ }
+
+ //--------------------
+ // Postconditions
+ Debug.Assert(0 <= milliseconds); // This property guarantees no negative return values
+
+ return milliseconds;
+ }
+ }
+ }
+}
+
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs
new file mode 100644
index 0000000000..83ce5085e7
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+using System.Data;
+using System.Data.Common;
+using Microsoft.Data.SqlClient.Server;
+
+namespace Microsoft.Data.Sql
+{
+ ///
+ public sealed partial class SqlDataSourceEnumerator : DbDataSourceEnumerator
+ {
+ private partial DataTable GetDataSourcesInternal()
+ {
+#if NETFRAMEWORK
+ return SqlDataSourceEnumeratorNativeHelper.GetDataSources();
+#else
+ return SqlClient.TdsParserStateObjectFactory.UseManagedSNI ? SqlDataSourceEnumeratorManagedHelper.GetDataSources() : SqlDataSourceEnumeratorNativeHelper.GetDataSources();
+#endif
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs
new file mode 100644
index 0000000000..e8f7aac29c
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs
@@ -0,0 +1,25 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+using System;
+using System.Data;
+using System.Data.Common;
+
+namespace Microsoft.Data.Sql
+{
+ ///
+ public sealed partial class SqlDataSourceEnumerator : DbDataSourceEnumerator
+ {
+ private static readonly Lazy s_singletonInstance = new(() => new SqlDataSourceEnumerator());
+
+ private SqlDataSourceEnumerator() : base(){}
+
+ ///
+ public static SqlDataSourceEnumerator Instance => s_singletonInstance.Value;
+
+ ///
+ override public DataTable GetDataSources() => GetDataSourcesInternal();
+
+ private partial DataTable GetDataSourcesInternal();
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs
new file mode 100644
index 0000000000..43be666e0d
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs
@@ -0,0 +1,75 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+using System.Collections.Generic;
+using System.Data;
+using Microsoft.Data.Sql;
+
+namespace Microsoft.Data.SqlClient.Server
+{
+ ///
+ /// Provides a mechanism for enumerating all available instances of SQL Server within the local network
+ ///
+ internal static class SqlDataSourceEnumeratorManagedHelper
+ {
+ ///
+ /// Provides a mechanism for enumerating all available instances of SQL Server within the local network.
+ ///
+ /// DataTable with ServerName,InstanceName,IsClustered and Version
+ internal static DataTable GetDataSources()
+ {
+ // TODO: Implement multicast request besides the implemented broadcast request.
+ throw new System.NotImplementedException(StringsHelper.net_MethodNotImplementedException);
+ }
+
+ private static DataTable ParseServerEnumString(string serverInstances)
+ {
+ DataTable dataTable = SqlDataSourceEnumeratorUtil.PrepareDataTable();
+ DataRow dataRow;
+
+ if (serverInstances.Length == 0)
+ {
+ return dataTable;
+ }
+
+ string[] numOfServerInstances = serverInstances.Split(SqlDataSourceEnumeratorUtil.s_endOfServerInstanceDelimiter_Managed, System.StringSplitOptions.None);
+ SqlClientEventSource.Log.TryTraceEvent(" Number of recieved server instances are {2}",
+ nameof(SqlDataSourceEnumeratorManagedHelper), nameof(ParseServerEnumString), numOfServerInstances.Length);
+
+ foreach (string currentServerInstance in numOfServerInstances)
+ {
+ Dictionary InstanceDetails = new();
+ string[] delimitedKeyValues = currentServerInstance.Split(SqlDataSourceEnumeratorUtil.InstanceKeysDelimiter);
+ string currentKey = string.Empty;
+
+ for (int keyvalue = 0; keyvalue < delimitedKeyValues.Length; keyvalue++)
+ {
+ if (keyvalue % 2 == 0)
+ {
+ currentKey = delimitedKeyValues[keyvalue];
+ }
+ else if (currentKey != string.Empty)
+ {
+ InstanceDetails.Add(currentKey, delimitedKeyValues[keyvalue]);
+ }
+ }
+
+ if (InstanceDetails.Count > 0)
+ {
+ dataRow = dataTable.NewRow();
+ dataRow[0] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.ServerNameCol) == true ?
+ InstanceDetails[SqlDataSourceEnumeratorUtil.ServerNameCol] : string.Empty;
+ dataRow[1] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.InstanceNameCol) == true ?
+ InstanceDetails[SqlDataSourceEnumeratorUtil.InstanceNameCol] : string.Empty;
+ dataRow[2] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.IsClusteredCol) == true ?
+ InstanceDetails[SqlDataSourceEnumeratorUtil.IsClusteredCol] : string.Empty;
+ dataRow[3] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.VersionNameCol) == true ?
+ InstanceDetails[SqlDataSourceEnumeratorUtil.VersionNameCol] : string.Empty;
+
+ dataTable.Rows.Add(dataRow);
+ }
+ }
+ return dataTable.SetColumnsReadOnly();
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs
new file mode 100644
index 0000000000..f6ebfc4b8f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs
@@ -0,0 +1,179 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Data;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Security;
+using System.Text;
+using Microsoft.Data.Common;
+using Microsoft.Data.SqlClient;
+using static Microsoft.Data.Sql.SqlDataSourceEnumeratorUtil;
+
+namespace Microsoft.Data.Sql
+{
+ ///
+ /// Provides a mechanism for enumerating all available instances of SQL Server within the local network
+ ///
+ internal static class SqlDataSourceEnumeratorNativeHelper
+ {
+ ///
+ /// Retrieves a DataTable containing information about all visible SQL Server instances
+ ///
+ ///
+ internal static DataTable GetDataSources()
+ {
+ (new NamedPermissionSet("FullTrust")).Demand(); // SQLBUDT 244304
+ char[] buffer = null;
+ StringBuilder strbldr = new();
+
+ int bufferSize = 1024;
+ int readLength = 0;
+ buffer = new char[bufferSize];
+ bool more = true;
+ bool failure = false;
+ IntPtr handle = ADP.s_ptrZero;
+
+ RuntimeHelpers.PrepareConstrainedRegions();
+ try
+ {
+ long s_timeoutTime = TdsParserStaticMethods.GetTimeoutSeconds(ADP.DefaultCommandTimeout);
+ RuntimeHelpers.PrepareConstrainedRegions();
+ try
+ { }
+ finally
+ {
+ handle = SNINativeMethodWrapper.SNIServerEnumOpen();
+ SqlClientEventSource.Log.TryTraceEvent(" {2} returned handle = {3}.",
+ nameof(SqlDataSourceEnumeratorNativeHelper),
+ nameof(GetDataSources),
+ nameof(SNINativeMethodWrapper.SNIServerEnumOpen), handle);
+ }
+
+ if (handle != ADP.s_ptrZero)
+ {
+ while (more && !TdsParserStaticMethods.TimeoutHasExpired(s_timeoutTime))
+ {
+ readLength = SNINativeMethodWrapper.SNIServerEnumRead(handle, buffer, bufferSize, out more);
+
+ SqlClientEventSource.Log.TryTraceEvent(" {2} returned 'readlength':{3}, and 'more':{4} with 'bufferSize' of {5}",
+ nameof(SqlDataSourceEnumeratorNativeHelper),
+ nameof(GetDataSources),
+ nameof(SNINativeMethodWrapper.SNIServerEnumRead),
+ readLength, more, bufferSize);
+ if (readLength > bufferSize)
+ {
+ failure = true;
+ more = false;
+ }
+ else if (readLength > 0)
+ {
+ strbldr.Append(buffer, 0, readLength);
+ }
+ }
+ }
+ }
+ finally
+ {
+ if (handle != ADP.s_ptrZero)
+ {
+ SNINativeMethodWrapper.SNIServerEnumClose(handle);
+ SqlClientEventSource.Log.TryTraceEvent(" {2} called.",
+ nameof(SqlDataSourceEnumeratorNativeHelper),
+ nameof(GetDataSources),
+ nameof(SNINativeMethodWrapper.SNIServerEnumClose));
+ }
+ }
+
+ if (failure)
+ {
+ Debug.Assert(false, $"{nameof(GetDataSources)}:{nameof(SNINativeMethodWrapper.SNIServerEnumRead)} returned bad length");
+ SqlClientEventSource.Log.TryTraceEvent(" {2} returned bad length, requested buffer {3}, received {4}",
+ nameof(SqlDataSourceEnumeratorNativeHelper),
+ nameof(GetDataSources),
+ nameof(SNINativeMethodWrapper.SNIServerEnumRead),
+ bufferSize, readLength);
+
+ throw ADP.ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, readLength), nameof(readLength));
+ }
+ return ParseServerEnumString(strbldr.ToString());
+ }
+
+ private static DataTable ParseServerEnumString(string serverInstances)
+ {
+ DataTable dataTable = PrepareDataTable();
+ string serverName = null;
+ string instanceName = null;
+ string isClustered = null;
+ string version = null;
+ string[] serverinstanceslist = serverInstances.Split(EndOfServerInstanceDelimiter_Native);
+ SqlClientEventSource.Log.TryTraceEvent(" Number of recieved server instances are {2}",
+ nameof(SqlDataSourceEnumeratorNativeHelper), nameof(ParseServerEnumString), serverinstanceslist.Length);
+
+ // Every row comes in the format "serverName\instanceName;Clustered:[Yes|No];Version:.."
+ // Every row is terminated by a null character.
+ // Process one row at a time
+ foreach (string instance in serverinstanceslist)
+ {
+ string value = instance.Trim(EndOfServerInstanceDelimiter_Native); // MDAC 91934
+ if (value.Length == 0)
+ {
+ continue;
+ }
+ foreach (string instance2 in value.Split(InstanceKeysDelimiter))
+ {
+ if (serverName == null)
+ {
+ foreach (string instance3 in instance2.Split(ServerNamesAndInstanceDelimiter))
+ {
+ if (serverName == null)
+ {
+ serverName = instance3;
+ continue;
+ }
+ Debug.Assert(instanceName == null, $"{nameof(instanceName)}({instanceName}) is not null.");
+ instanceName = instance3;
+ }
+ continue;
+ }
+ if (isClustered == null)
+ {
+ Debug.Assert(string.Compare(Clustered, 0, instance2, 0, s_clusteredLength, StringComparison.OrdinalIgnoreCase) == 0,
+ $"{nameof(Clustered)} ({Clustered}) doesn't equal {nameof(instance2)} ({instance2})");
+ isClustered = instance2.Substring(s_clusteredLength);
+ continue;
+ }
+ Debug.Assert(version == null, $"{nameof(version)}({version}) is not null.");
+ Debug.Assert(string.Compare(SqlDataSourceEnumeratorUtil.Version, 0, instance2, 0, s_versionLength, StringComparison.OrdinalIgnoreCase) == 0,
+ $"{nameof(SqlDataSourceEnumeratorUtil.Version)} ({SqlDataSourceEnumeratorUtil.Version}) doesn't equal {nameof(instance2)} ({instance2})");
+ version = instance2.Substring(s_versionLength);
+ }
+
+ string query = "ServerName='" + serverName + "'";
+
+ if (!ADP.IsEmpty(instanceName))
+ { // SQL BU DT 20006584: only append instanceName if present.
+ query += " AND InstanceName='" + instanceName + "'";
+ }
+
+ // SNI returns dupes - do not add them. SQL BU DT 290323
+ if (dataTable.Select(query).Length == 0)
+ {
+ DataRow dataRow = dataTable.NewRow();
+ dataRow[0] = serverName;
+ dataRow[1] = instanceName;
+ dataRow[2] = isClustered;
+ dataRow[3] = version;
+ dataTable.Rows.Add(dataRow);
+ }
+ serverName = null;
+ instanceName = null;
+ isClustered = null;
+ version = null;
+ }
+ return dataTable.SetColumnsReadOnly();
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs
new file mode 100644
index 0000000000..fb6972d8cf
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Data;
+using System.Globalization;
+
+namespace Microsoft.Data.Sql
+{
+ ///
+ /// const values for SqlDataSourceEnumerator
+ ///
+ internal static class SqlDataSourceEnumeratorUtil
+ {
+ internal const string ServerNameCol = "ServerName";
+ internal const string InstanceNameCol = "InstanceName";
+ internal const string IsClusteredCol = "IsClustered";
+ internal const string VersionNameCol = "Version";
+
+ internal const string Version = "Version:";
+ internal const string Clustered = "Clustered:";
+ internal static readonly int s_versionLength = Version.Length;
+ internal static readonly int s_clusteredLength = Clustered.Length;
+
+ internal static readonly string[] s_endOfServerInstanceDelimiter_Managed = new[] { ";;" };
+ internal const char EndOfServerInstanceDelimiter_Native = '\0';
+ internal const char InstanceKeysDelimiter = ';';
+ internal const char ServerNamesAndInstanceDelimiter = '\\';
+
+ internal static DataTable PrepareDataTable()
+ {
+ DataTable dataTable = new("SqlDataSources");
+ dataTable.Locale = CultureInfo.InvariantCulture;
+ dataTable.Columns.Add(ServerNameCol, typeof(string));
+ dataTable.Columns.Add(InstanceNameCol, typeof(string));
+ dataTable.Columns.Add(IsClusteredCol, typeof(string));
+ dataTable.Columns.Add(VersionNameCol, typeof(string));
+
+ return dataTable;
+ }
+
+ ///
+ /// Sets all columns read-only.
+ ///
+ internal static DataTable SetColumnsReadOnly(this DataTable dataTable)
+ {
+ foreach (DataColumn column in dataTable.Columns)
+ {
+ column.ReadOnly = true;
+ }
+ return dataTable;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs
new file mode 100644
index 0000000000..ccbff8fc0f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs
@@ -0,0 +1,80 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Data.Common;
+using Microsoft.Data.SqlClient;
+
+namespace Microsoft.Data.Sql
+{
+ ///
+ public sealed class SqlNotificationRequest
+ {
+ private string _userData;
+ private string _options;
+ private int _timeout;
+
+ ///
+ public SqlNotificationRequest()
+ : this(null, null, SQL.SqlDependencyTimeoutDefault) { }
+
+ ///
+ public SqlNotificationRequest(string userData, string options, int timeout)
+ {
+ UserData = userData;
+ Timeout = timeout;
+ Options = options;
+ }
+
+ ///
+ public string Options
+ {
+ get
+ {
+ return _options;
+ }
+ set
+ {
+ if ((null != value) && (ushort.MaxValue < value.Length))
+ {
+ throw ADP.ArgumentOutOfRange(string.Empty, nameof(Options));
+ }
+ _options = value;
+ }
+ }
+
+ ///
+ public int Timeout
+ {
+ get
+ {
+ return _timeout;
+ }
+ set
+ {
+ if (0 > value)
+ {
+ throw ADP.ArgumentOutOfRange(string.Empty, nameof(Timeout));
+ }
+ _timeout = value;
+ }
+ }
+
+ ///
+ public string UserData
+ {
+ get
+ {
+ return _userData;
+ }
+ set
+ {
+ if ((null != value) && (ushort.MaxValue < value.Length))
+ {
+ throw ADP.ArgumentOutOfRange(string.Empty, nameof(UserData));
+ }
+ _userData = value;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs
new file mode 100644
index 0000000000..a8fdf219d3
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs
@@ -0,0 +1,516 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Security;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core;
+using Azure.Identity;
+using Microsoft.Identity.Client;
+using Microsoft.Identity.Client.Extensibility;
+
+namespace Microsoft.Data.SqlClient
+{
+ ///
+ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
+ {
+ ///
+ /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey.
+ /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
+ /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
+ ///
+ private static ConcurrentDictionary s_pcaMap
+ = new ConcurrentDictionary();
+ private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
+ private static readonly string s_defaultScopeSuffix = "/.default";
+ private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
+ private readonly SqlClientLogger _logger = new SqlClientLogger();
+ private Func _deviceCodeFlowCallback;
+ private ICustomWebUi _customWebUI = null;
+ private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId;
+
+ ///
+ public ActiveDirectoryAuthenticationProvider()
+ : this(DefaultDeviceFlowCallback)
+ {
+ }
+
+ ///
+ public ActiveDirectoryAuthenticationProvider(string applicationClientId)
+ : this(DefaultDeviceFlowCallback, applicationClientId)
+ {
+ }
+
+ ///
+ public ActiveDirectoryAuthenticationProvider(Func deviceCodeFlowCallbackMethod, string applicationClientId = null)
+ {
+ if (applicationClientId != null)
+ {
+ _applicationClientId = applicationClientId;
+ }
+ SetDeviceCodeFlowCallback(deviceCodeFlowCallbackMethod);
+ }
+
+ ///
+ public static void ClearUserTokenCache()
+ {
+ if (!s_pcaMap.IsEmpty)
+ {
+ s_pcaMap.Clear();
+ }
+ }
+
+ ///
+ public void SetDeviceCodeFlowCallback(Func deviceCodeFlowCallbackMethod) => _deviceCodeFlowCallback = deviceCodeFlowCallbackMethod;
+
+ ///
+ public void SetAcquireAuthorizationCodeAsyncCallback(Func> acquireAuthorizationCodeAsyncCallback) => _customWebUI = new CustomWebUi(acquireAuthorizationCodeAsyncCallback);
+
+ ///
+ public override bool IsSupported(SqlAuthenticationMethod authentication)
+ {
+ return authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryPassword
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
+ || authentication == SqlAuthenticationMethod.ActiveDirectoryDefault;
+ }
+
+ ///
+ public override void BeforeLoad(SqlAuthenticationMethod authentication)
+ {
+ _logger.LogInfo(_type, "BeforeLoad", $"being loaded into SqlAuthProviders for {authentication}.");
+ }
+
+ ///
+ public override void BeforeUnload(SqlAuthenticationMethod authentication)
+ {
+ _logger.LogInfo(_type, "BeforeUnload", $"being unloaded from SqlAuthProviders for {authentication}.");
+ }
+
+#if NETSTANDARD
+ private Func