From 05db2283c7cb48abaa3ec6372dd65ec6a97d992b Mon Sep 17 00:00:00 2001 From: JRahnama Date: Thu, 20 Oct 2022 01:30:06 -0700 Subject: [PATCH] addresing the changes during conflict resolving --- .../add-ons/Directory.Build.props | 2 +- .../src/Microsoft.Data.SqlClient.csproj | 4 +- .../Data/SqlClient/AAsyncCallContext.cs | 69 +- .../Microsoft/Data/SqlClient/SqlCommand.cs | 64 +- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 97 +- .../netfx/ref/Microsoft.Data.SqlClient.cs | 4 + .../netfx/src/Microsoft.Data.SqlClient.csproj | 12 +- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 72 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 7 +- .../Data/SqlClient/TdsParserHelperClasses.cs | 105 +- .../Data/SqlClient/TdsParserStateObject.cs | 21 +- .../Data/Data/Common/ActivityCorrelator.cs | 66 + .../Data/Data/Common/AdapterUtil.Unix.cs | 24 + .../Data/Data/Common/AdapterUtil.Windows.cs | 49 + .../Microsoft/Data/Data/Common/AdapterUtil.cs | 1570 +++++++ .../Data/Common/DbConnectionOptions.Common.cs | 770 ++++ .../Data/Data/Common/DbConnectionPoolKey.cs | 58 + .../Data/Common/DbConnectionStringCommon.cs | 1153 +++++ .../Data/Data/Common/MultipartIdentifier.cs | 291 ++ .../Data/Data/Common/NameValuePair.cs | 53 + .../src/Microsoft/Data/Data/DataException.cs | 56 + .../Data/Data/OperationAbortedException.cs | 40 + .../DbConnectionPoolAuthenticationContext.cs | 112 + ...bConnectionPoolAuthenticationContextKey.cs | 112 + .../ProviderBase/DbConnectionPoolGroup.cs | 312 ++ .../DbConnectionPoolGroupProviderInfo.cs | 23 + .../ProviderBase/DbConnectionPoolOptions.cs | 73 + .../DbConnectionPoolProviderInfo.cs | 10 + .../Data/ProviderBase/DbMetaDataFactory.cs | 558 +++ .../Data/Data/ProviderBase/FieldNameLookup.cs | 117 + .../Data/Data/ProviderBase/TimeoutTimer.cs | 185 + .../Sql/SqlDataSourceEnumerator.Windows.cs | 22 + .../Data/Data/Sql/SqlDataSourceEnumerator.cs | 25 + .../SqlDataSourceEnumeratorManagedHelper.cs | 75 + .../SqlDataSourceEnumeratorNativeHelper.cs | 179 + .../Data/Sql/SqlDataSourceEnumeratorUtil.cs | 54 + .../Data/Data/Sql/SqlNotificationRequest.cs | 80 + .../ActiveDirectoryAuthenticationProvider.cs | 516 +++ ...rectoryAuthenticationTimeoutRetryHelper.cs | 139 + .../AlwaysEncryptedEnclaveProviderUtils.cs | 53 + .../Data/Data/SqlClient/ApplicationIntent.cs | 19 + .../Data/Data/SqlClient/AssemblyRef.cs | 22 + .../AzureAttestationBasedEnclaveProvider.cs | 544 +++ .../Data/SqlClient/ColumnEncryptionKeyInfo.cs | 125 + .../SensitivityClassification.cs | 119 + .../Data/SqlClient/EnclaveDelegate.Crypto.cs | 210 + .../SqlClient/EnclaveDelegate.NotSupported.cs | 55 + .../Data/Data/SqlClient/EnclaveDelegate.cs | 235 + .../Data/Data/SqlClient/EnclavePackage.cs | 27 + .../Data/SqlClient/EnclaveProviderBase.cs | 215 + .../Data/SqlClient/EnclaveSessionCache.cs | 75 + .../Data/SqlClient/LocalAppContextSwitches.cs | 99 + .../NoneAttestationEnclaveProvider.cs | 106 + .../Data/SqlClient/OnChangedEventHandler.cs | 10 + .../Data/SqlClient/ParameterPeekAheadValue.cs | 27 + .../Data/Data/SqlClient/PoolBlockingPeriod.cs | 22 + .../SqlClient/Reliability/AppConfigManager.cs | 136 + .../Common/SqlRetryIntervalBaseEnumerator.cs | 110 + .../Reliability/Common/SqlRetryLogic.cs | 107 + .../Reliability/Common/SqlRetryLogicBase.cs | 39 + .../Common/SqlRetryLogicBaseProvider.cs | 29 + .../Common/SqlRetryLogicProvider.cs | 221 + .../Common/SqlRetryingEventArgs.cs | 33 + .../SqlConfigurableRetryFactory.cs | 134 + .../SqlConfigurableRetryLogicLoader.cs | 277 ++ .../SqlConfigurableRetryLogicManager.cs | 110 + .../SqlRetryIntervalEnumerators.cs | 120 + .../Data/SqlClient/RowsCopiedEventArgs.cs | 41 + .../Data/SqlClient/RowsCopiedEventHandler.cs | 10 + .../Data/SqlClient/SQLFallbackDNSCache.cs | 86 + .../SqlClient/Server/ExtendedClrTypeCode.cs | 59 + .../Data/SqlClient/Server/ITypedGetters.cs | 99 + .../Data/SqlClient/Server/ITypedGettersV3.cs | 68 + .../Data/SqlClient/Server/ITypedSetters.cs | 96 + .../Data/SqlClient/Server/ITypedSettersV3.cs | 78 + .../SqlClient/Server/MemoryRecordBuffer.cs | 262 ++ .../Data/SqlClient/Server/MetadataUtilsSmi.cs | 984 +++++ .../Data/SqlClient/Server/SmiEventSink.cs | 130 + .../SqlClient/Server/SmiEventSink_Default.cs | 149 + .../Server/SmiEventSink_Default.netfx.cs | 261 ++ .../Data/SqlClient/Server/SmiGettersStream.cs | 106 + .../Data/Data/SqlClient/Server/SmiMetaData.cs | 1436 ++++++ .../SqlClient/Server/SmiMetaDataProperty.cs | 278 ++ .../Data/SqlClient/Server/SmiRecordBuffer.cs | 861 ++++ .../Data/SqlClient/Server/SmiSettersStream.cs | 107 + .../SqlClient/Server/SmiTypedGetterSetter.cs | 574 +++ .../Server/SmiXetterAccessMap.Common.cs | 105 + .../SqlClient/Server/SmiXetterTypeCode.cs | 29 + .../Data/SqlClient/Server/SqlDataRecord.cs | 422 ++ .../SqlClient/Server/SqlDataRecord.netcore.cs | 93 + .../SqlClient/Server/SqlDataRecord.netfx.cs | 135 + .../Data/Data/SqlClient/Server/SqlMetaData.cs | 2131 +++++++++ .../Data/SqlClient/Server/SqlNormalizer.cs | 617 +++ .../Data/SqlClient/Server/SqlRecordBuffer.cs | 625 +++ .../Data/Data/SqlClient/Server/SqlSer.cs | 229 + .../Data/SqlClient/Server/ValueUtilsSmi.cs | 3934 +++++++++++++++++ .../SqlClient/Server/ValueUtilsSmi.netfx.cs | 508 +++ .../SqlClient/SignatureVerificationCache.cs | 153 + .../Data/Data/SqlClient/SortOrder.cs | 20 + .../SqlAeadAes256CbcHmac256Algorithm.cs | 441 ++ .../SqlAeadAes256CbcHmac256EncryptionKey.cs | 124 + .../SqlAeadAes256CbcHmac256Factory.cs | 81 + .../SqlClient/SqlAuthenticationParameters.cs | 162 + .../SqlClient/SqlAuthenticationProvider.cs | 37 + .../Data/SqlClient/SqlAuthenticationToken.cs | 65 + .../Data/Data/SqlClient/SqlBuffer.cs | 1392 ++++++ .../SqlClient/SqlBulkCopyColumnMapping.cs | 134 + .../SqlBulkCopyColumnMappingCollection.cs | 148 + .../SqlClient/SqlBulkCopyColumnOrderHint.cs | 66 + .../SqlBulkCopyColumnOrderHintCollection.cs | 133 + .../Data/Data/SqlClient/SqlBulkCopyOptions.cs | 41 + .../Data/Data/SqlClient/SqlCachedBuffer.cs | 145 + .../SqlClient/SqlClientEncryptionAlgorithm.cs | 28 + .../SqlClientEncryptionAlgorithmFactory.cs | 23 + ...SqlClientEncryptionAlgorithmFactoryList.cs | 84 + .../Data/SqlClient/SqlClientEncryptionType.cs | 16 + .../Data/SqlClient/SqlClientEventSource.cs | 1167 +++++ .../Data/Data/SqlClient/SqlClientLogger.cs | 49 + .../SqlClientMetaDataCollectionNames.cs | 56 + .../Data/SqlClient/SqlClientSymmetricKey.cs | 65 + .../Data/Data/SqlClient/SqlCollation.cs | 230 + .../SqlColumnEncryptionKeyStoreProvider.cs | 34 + .../Data/Data/SqlClient/SqlCommandBuilder.cs | 371 ++ .../Data/Data/SqlClient/SqlCommandSet.cs | 330 ++ .../SqlClient/SqlConnectionEncryptOption.cs | 113 + .../SqlConnectionPoolGroupProviderInfo.cs | 137 + .../Data/SqlClient/SqlConnectionPoolKey.cs | 139 + .../SqlConnectionPoolProviderInfo.cs | 25 + .../Data/SqlClient/SqlConnectionString.cs | 1238 ++++++ .../SqlClient/SqlConnectionStringBuilder.cs | 1925 ++++++++ .../SqlConnectionTimeoutErrorInternal.cs | 239 + .../Data/Data/SqlClient/SqlCredential.cs | 55 + .../Data/Data/SqlClient/SqlDataAdapter.cs | 308 ++ .../Data/Data/SqlClient/SqlDependency.cs | 1352 ++++++ .../Data/SqlClient/SqlDependencyListener.cs | 1746 ++++++++ .../SqlClient/SqlDependencyUtils.AppDomain.cs | 20 + .../SqlDependencyUtils.AssemblyLoadContext.cs | 26 + .../Data/Data/SqlClient/SqlDependencyUtils.cs | 657 +++ .../SqlEnclaveAttestationParameters.Crypto.cs | 52 + ...claveAttestationParameters.NotSupported.cs | 19 + .../Data/Data/SqlClient/SqlEnclaveSession.cs | 73 + .../Microsoft/Data/Data/SqlClient/SqlEnums.cs | 1042 +++++ .../Data/Data/SqlClient/SqlEnvChange.cs | 72 + .../Microsoft/Data/Data/SqlClient/SqlError.cs | 91 + .../Data/Data/SqlClient/SqlErrorCollection.cs | 50 + .../Data/Data/SqlClient/SqlException.cs | 212 + .../Data/SqlClient/SqlInfoMessageEvent.cs | 34 + .../SqlClient/SqlInfoMessageEventHandler.cs | 9 + .../Data/SqlClient/SqlInternalConnection.cs | 255 +- .../Data/SqlClient/SqlInternalTransaction.cs | 501 +++ .../Data/Data/SqlClient/SqlMetadataFactory.cs | 288 ++ .../SqlClient/SqlNotificationEventArgs.cs | 35 + .../Data/SqlClient/SqlNotificationInfo.cs | 50 + .../Data/SqlClient/SqlNotificationSource.cs | 34 + .../Data/SqlClient/SqlNotificationType.cs | 18 + .../Data/Data/SqlClient/SqlObjectPool.cs | 67 + .../Data/Data/SqlClient/SqlParameter.cs | 2387 ++++++++++ .../Data/SqlClient/SqlParameterCollection.cs | 399 ++ .../Data/SqlClient/SqlQueryMetadataCache.cs | 320 ++ .../Data/SqlClient/SqlReferenceCollection.cs | 105 + .../Data/Data/SqlClient/SqlRowUpdatedEvent.cs | 28 + .../SqlClient/SqlRowUpdatedEventHandler.cs | 9 + .../Data/SqlClient/SqlRowUpdatingEvent.cs | 33 + .../SqlClient/SqlRowUpdatingEventHandler.cs | 9 + .../Data/Data/SqlClient/SqlSecurityUtility.cs | 461 ++ .../Data/SqlClient/SqlSequentialStream.cs | 338 ++ .../Data/SqlClient/SqlSequentialTextReader.cs | 521 +++ .../Data/Data/SqlClient/SqlStatistics.cs | 352 ++ .../Data/Data/SqlClient/SqlStream.cs | 647 +++ .../Data/SqlClient/SqlSymmetricKeyCache.cs | 103 + .../Data/SqlClient/SqlTransaction.Common.cs | 145 + .../Data/Data/SqlClient/SqlUdtInfo.cs | 66 + .../Microsoft/Data/Data/SqlClient/SqlUtil.cs | 32 + .../Microsoft/Data/Data/SqlClient/TdsEnums.cs | 1264 ++++++ .../Data/Data/SqlClient/TdsParameterSetter.cs | 66 + .../SqlClient/TdsParserSafeHandles.Windows.cs | 329 ++ .../Data/SqlClient/TdsParserSessionPool.cs | 241 + .../Data/SqlClient/TdsParserStaticMethods.cs | 305 ++ .../Data/SqlClient/TdsRecordBufferSetter.cs | 306 ++ .../Data/Data/SqlClient/TdsValueSetter.cs | 713 +++ .../VirtualSecureModeEnclaveProvider.cs | 500 +++ .../VirtualSecureModeEnclaveProviderBase.cs | 364 ++ .../Data/Data/SqlTypes/SQLResource.cs | 77 + .../Data/Data/SqlTypes/SqlTypeWorkarounds.cs | 94 + .../SqlClient/SqlConnectionStringBuilder.cs | 5 +- .../Data/SqlClient/SqlInternalConnection.cs | 809 ++++ .../SqlClient/TdsParserSafeHandles.Windows.cs | 329 ++ .../Data/SqlClient/TdsParserSessionPool.cs | 241 + 188 files changed, 52850 insertions(+), 196 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationTimeoutRetryHelper.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ApplicationIntent.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AssemblyRef.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ColumnEncryptionKeyInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/DataClassification/SensitivityClassification.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclaveDelegate.Crypto.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclaveDelegate.NotSupported.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclaveDelegate.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclavePackage.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclaveProviderBase.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/EnclaveSessionCache.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/LocalAppContextSwitches.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/NoneAttestationEnclaveProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/OnChangedEventHandler.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ParameterPeekAheadValue.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/PoolBlockingPeriod.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/AppConfigManager.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryIntervalBaseEnumerator.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryLogic.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryLogicBase.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryLogicBaseProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryLogicProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/Common/SqlRetryingEventArgs.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/SqlConfigurableRetryFactory.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/SqlConfigurableRetryLogicLoader.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/SqlConfigurableRetryLogicManager.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Reliability/SqlRetryIntervalEnumerators.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/RowsCopiedEventArgs.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/RowsCopiedEventHandler.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SQLFallbackDNSCache.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ExtendedClrTypeCode.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ITypedGetters.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ITypedGettersV3.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ITypedSetters.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ITypedSettersV3.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/MemoryRecordBuffer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/MetadataUtilsSmi.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiEventSink.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiEventSink_Default.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiEventSink_Default.netfx.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiGettersStream.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiMetaData.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiMetaDataProperty.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiRecordBuffer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiSettersStream.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiTypedGetterSetter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiXetterAccessMap.Common.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SmiXetterTypeCode.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlDataRecord.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlDataRecord.netcore.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlDataRecord.netfx.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlMetaData.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlNormalizer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlRecordBuffer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/SqlSer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ValueUtilsSmi.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/Server/ValueUtilsSmi.netfx.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SignatureVerificationCache.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SortOrder.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAeadAes256CbcHmac256Algorithm.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAeadAes256CbcHmac256EncryptionKey.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAeadAes256CbcHmac256Factory.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAuthenticationParameters.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAuthenticationProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlAuthenticationToken.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBuffer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBulkCopyColumnMapping.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBulkCopyColumnMappingCollection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBulkCopyColumnOrderHint.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBulkCopyColumnOrderHintCollection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlBulkCopyOptions.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlCachedBuffer.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientEncryptionAlgorithm.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientEncryptionAlgorithmFactory.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientEncryptionAlgorithmFactoryList.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientEncryptionType.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientEventSource.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientLogger.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientMetaDataCollectionNames.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlClientSymmetricKey.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlCollation.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlColumnEncryptionKeyStoreProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlCommandBuilder.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlCommandSet.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionEncryptOption.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionPoolGroupProviderInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionPoolKey.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionPoolProviderInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionString.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionStringBuilder.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlConnectionTimeoutErrorInternal.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlCredential.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDataAdapter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDependency.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDependencyListener.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDependencyUtils.AppDomain.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDependencyUtils.AssemblyLoadContext.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlDependencyUtils.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlEnclaveAttestationParameters.Crypto.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlEnclaveAttestationParameters.NotSupported.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlEnclaveSession.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlEnums.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlEnvChange.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlError.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlErrorCollection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlException.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlInfoMessageEvent.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlInfoMessageEventHandler.cs rename src/Microsoft.Data.SqlClient/{netcore/src/Microsoft => src/Microsoft/Data}/Data/SqlClient/SqlInternalConnection.cs (76%) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlInternalTransaction.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlMetadataFactory.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlNotificationEventArgs.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlNotificationInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlNotificationSource.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlNotificationType.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlObjectPool.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlParameter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlParameterCollection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlQueryMetadataCache.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlReferenceCollection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlRowUpdatedEvent.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlRowUpdatedEventHandler.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlRowUpdatingEvent.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlRowUpdatingEventHandler.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlSecurityUtility.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlSequentialStream.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlSequentialTextReader.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlStatistics.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlStream.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlSymmetricKeyCache.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlTransaction.Common.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlUdtInfo.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/SqlUtil.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsEnums.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsParameterSetter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsParserSafeHandles.Windows.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsParserSessionPool.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsParserStaticMethods.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsRecordBufferSetter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/TdsValueSetter.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/VirtualSecureModeEnclaveProvider.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlTypes/SQLResource.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlTypes/SqlTypeWorkarounds.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSessionPool.cs diff --git a/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props b/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props index 12a111dec0..762c5f9ed8 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props +++ b/src/Microsoft.Data.SqlClient/add-ons/Directory.Build.props @@ -18,7 +18,7 @@ net462 netstandard2.0 - net6.0 + net6.0 diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 83f3e3a53d..af1b58d64a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -394,6 +394,9 @@ Microsoft\Data\SqlClient\SqlMetadataFactory.cs + + Microsoft\Data\SqlClient\SqlInternalConnection.cs + Microsoft\Data\SqlClient\SqlNotificationEventArgs.cs @@ -628,7 +631,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs index 56e369593a..76710ff980 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs @@ -17,38 +17,68 @@ namespace Microsoft.Data.SqlClient // CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching // DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object - internal abstract class AAsyncCallContext : IDisposable + internal abstract class AAsyncCallContext : AAsyncBaseCallContext where TOwner : class + where TDisposable : IDisposable { - protected TOwner _owner; - protected TaskCompletionSource _source; - protected IDisposable _disposable; + protected TDisposable _disposable; protected AAsyncCallContext() { } - protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, TDisposable disposable = default) { Set(owner, source, disposable); } - protected void Set(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + protected void Set(TOwner owner, TaskCompletionSource source, TDisposable disposable = default) + { + base.Set(owner, source); + _disposable = disposable; + } + + protected override void DisposeCore() + { + TDisposable copyDisposable = _disposable; + _disposable = default; + copyDisposable?.Dispose(); + } + } + + internal abstract class AAsyncBaseCallContext + { + protected TOwner _owner; + protected TaskCompletionSource _source; + protected bool _isDisposed; + + protected AAsyncBaseCallContext() + { + } + + protected void Set(TOwner owner, TaskCompletionSource source) { _owner = owner ?? throw new ArgumentNullException(nameof(owner)); _source = source ?? throw new ArgumentNullException(nameof(source)); - _disposable = disposable; + _isDisposed = false; } protected void ClearCore() { _source = null; _owner = default; - IDisposable copyDisposable = _disposable; - _disposable = null; - copyDisposable?.Dispose(); + try + { + DisposeCore(); + } + finally + { + _isDisposed = true; + } } + protected abstract void DisposeCore(); + /// /// override this method to cleanup instance data before ClearCore is called which will blank the base data /// @@ -65,16 +95,19 @@ protected virtual void AfterCleared(TOwner owner) public void Dispose() { - TOwner owner = _owner; - try - { - Clear(); - } - finally + if (!_isDisposed) { - ClearCore(); + TOwner owner = _owner; + try + { + Clear(); + } + finally + { + ClearCore(); + } + AfterCleared(owner); } - AfterCleared(owner); } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 6c55dcb541..4cd75ecb4d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -46,7 +46,7 @@ public sealed partial class SqlCommand : DbCommand, ICloneable private static readonly Func s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback; private static readonly Func s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback; - internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext + internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext { public Guid OperationID; public CommandBehavior CommandBehavior; @@ -54,7 +54,7 @@ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext _owner; public TaskCompletionSource TaskCompletionSource => _source; - public void Set(SqlCommand command, TaskCompletionSource source, IDisposable disposable, CommandBehavior behavior, Guid operationID) + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID) { base.Set(command, source, disposable); CommandBehavior = behavior; @@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner) } } + internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + + public SqlCommand Command => _owner; + + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID) + { + base.Set(command, source, disposable); + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + + } + } + private CommandType _commandType; private int? _commandTimeout; private UpdateRowSource _updatedRowSource = UpdateRowSource.Both; @@ -2540,23 +2565,37 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok } Task returnedTask = source.Task; + returnedTask = RegisterForConnectionCloseNotification(returnedTask); + + ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext(); + context.Set(this, source, registration, operationId); try { - returnedTask = RegisterForConnectionCloseNotification(returnedTask); + Task.Factory.FromAsync( + static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject), + static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result), + state: context + ).ContinueWith( + static (Task task, object state) => + { + ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state; + + Guid operationId = context.OperationID; + SqlCommand command = context.Command; + TaskCompletionSource source = context.TaskCompletionSource; + + context.Dispose(); + context = null; - Task.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null) - .ContinueWith((Task task) => - { - registration.Dispose(); if (task.IsFaulted) { Exception e = task.Exception.InnerException; - s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); + s_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e); source.SetException(e); } else { - s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); + s_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction); if (task.IsCanceled) { source.SetCanceled(); @@ -2567,13 +2606,15 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok } } }, - TaskScheduler.Default + state: context, + scheduler: TaskScheduler.Default ); } catch (Exception e) { s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); source.SetException(e); + context.Dispose(); } return returnedTask; @@ -2648,11 +2689,11 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } Task returnedTask = source.Task; + ExecuteReaderAsyncCallContext context = null; try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - ExecuteReaderAsyncCallContext context = null; if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) { context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null); @@ -2680,6 +2721,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } source.SetException(e); + context.Dispose(); } return returnedTask; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index cc028e44c7..af5e0f1c90 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4408,7 +4408,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -4708,7 +4708,7 @@ out bytesRead Debug.Assert(context.Source != null, "context._source should not be null when continuing"); // setup for cleanup/completing retryTask.ContinueWith( - continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -4735,6 +4735,13 @@ public override Task ReadAsync(CancellationToken cancellationToken) return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } + // Register first to catch any already expired tokens to be able to trigger cancellation event. + CancellationTokenRegistration registration = default; + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + } + // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { @@ -4833,12 +4840,6 @@ public override Task ReadAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } - ReadAsyncCallContext context = null; if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) { @@ -4849,7 +4850,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) context = new ReadAsyncCallContext(); } - Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed"); + Debug.Assert(context.Reader == default && context.Source == null && context.Disposable == default, "cached ReadAsyncCallContext was not properly disposed"); context.Set(this, source, registration); context._hasMoreData = more; @@ -5007,7 +5008,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo } // Setup cancellations - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); @@ -5023,7 +5024,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo context = new IsDBNullAsyncCallContext(); } - Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); context.Set(this, source, registration); context._columnIndex = i; @@ -5154,7 +5155,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat } // Setup cancellations - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); @@ -5218,49 +5219,63 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi } #endif - - internal abstract class SqlDataReaderAsyncCallContext : AAsyncCallContext + + internal abstract class SqlDataReaderBaseAsyncCallContext : AAsyncBaseCallContext { internal static readonly Action, object> s_completeCallback = CompleteAsyncCallCallback; internal static readonly Func> s_executeCallback = ExecuteAsyncCallCallback; - protected SqlDataReaderAsyncCallContext() + protected SqlDataReaderBaseAsyncCallContext() { } - protected SqlDataReaderAsyncCallContext(SqlDataReader owner, TaskCompletionSource source, IDisposable disposable = null) + protected SqlDataReaderBaseAsyncCallContext(SqlDataReader owner, TaskCompletionSource source) { - Set(owner, source, disposable); + Set(owner, source); } internal abstract Func> Execute { get; } internal SqlDataReader Reader { get => _owner; set => _owner = value; } - public IDisposable Disposable { get => _disposable; set => _disposable = value; } - public TaskCompletionSource Source { get => _source; set => _source = value; } - new public void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - { - base.Set(reader, source, disposable); - } - private static Task ExecuteAsyncCallCallback(Task task, object state) { - SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state; + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; return context.Reader.ContinueAsyncCall(task, context); } private static void CompleteAsyncCallCallback(Task task, object state) { - SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state; + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; context.Reader.CompleteAsyncCall(task, context); } } - internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext + internal abstract class SqlDataReaderAsyncCallContext : SqlDataReaderBaseAsyncCallContext + where TDisposable : IDisposable + { + private TDisposable _disposable; + + public TDisposable Disposable { get => _disposable; set => _disposable = value; } + + public void Set(SqlDataReader owner, TaskCompletionSource source, TDisposable disposable) + { + base.Set(owner, source); + _disposable = disposable; + } + + protected override void DisposeCore() + { + TDisposable copy = _disposable; + _disposable = default; + copy.Dispose(); + } + } + + internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute; @@ -5279,7 +5294,7 @@ protected override void AfterCleared(SqlDataReader owner) } } - internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext + internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute; @@ -5295,19 +5310,19 @@ protected override void AfterCleared(SqlDataReader owner) } } - private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute; - public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - : base(reader, source, disposable) + public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { + Set(reader, source, disposable); } internal override Func> Execute => s_execute; } - private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext { internal enum OperationMode { @@ -5345,7 +5360,7 @@ protected override void Clear() } } - private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute; @@ -5353,9 +5368,9 @@ private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallCo internal GetFieldValueAsyncCallContext() { } - internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - : base(reader, source, disposable) + internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { + Set(reader, source, disposable); } protected override void Clear() @@ -5375,7 +5390,7 @@ protected override void Clear() /// /// /// - private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) + private Task InvokeAsyncCall(SqlDataReaderBaseAsyncCallContext context) { TaskCompletionSource source = context.Source; try @@ -5397,7 +5412,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) else { task.ContinueWith( - continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -5422,7 +5437,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) /// /// /// - private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) + private Task ExecuteAsyncCall(AAsyncBaseCallContext context) { // _networkPacketTaskSource could be null if the connection was closed // while an async invocation was outstanding. @@ -5435,7 +5450,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) else { return completionSource.Task.ContinueWith( - continuationFunction: SqlDataReaderAsyncCallContext.s_executeCallback, + continuationFunction: SqlDataReaderBaseAsyncCallContext.s_executeCallback, state: context, TaskScheduler.Default ).Unwrap(); @@ -5451,7 +5466,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) /// /// /// - private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext context) + private Task ContinueAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { // this function must be an instance function called from the static callback because otherwise a compiler error // is caused by accessing the _cancelAsyncOnCloseToken field of a MarshalByRefObject derived class @@ -5511,7 +5526,7 @@ private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext /// /// /// - private void CompleteAsyncCall(Task task, SqlDataReaderAsyncCallContext context) + private void CompleteAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { TaskCompletionSource source = context.Source; context.Dispose(); diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index 5d78fb574e..9cc91a1c38 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -1043,6 +1043,10 @@ public sealed partial class SqlConnectionStringBuilder : System.Data.Common.DbCo [System.ComponentModel.DisplayNameAttribute("Encrypt")] [System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)] public SqlConnectionEncryptOption Encrypt { get { throw null; } set { } } + /// + [System.ComponentModel.DisplayNameAttribute("Host Name In Certificate")] + [System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)] + public string HostNameInCertificate { get { throw null; } set { } } /// [System.ComponentModel.DisplayNameAttribute("Enlist")] [System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)] diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 12323365bf..bf6b3f3ff2 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -482,6 +482,9 @@ Microsoft\Data\SqlClient\SqlInfoMessageEventHandler.cs + + Microsoft\Data\SqlClient\SqlInternalConnection.cs + Microsoft\Data\SqlClient\SqlInternalTransaction.cs @@ -554,12 +557,18 @@ Microsoft\Data\SqlClient\TdsParameterSetter.cs + + Microsoft\Data\SqlClient\TdsParserSafeHandles.Windows.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs Microsoft\Data\SqlClient\TdsRecordBufferSetter.cs + + Microsoft\Data\SqlClient\TdsParserSessionPool.cs + Microsoft\Data\SqlClient\TdsValueSetter.cs @@ -639,7 +648,6 @@ - @@ -648,8 +656,6 @@ - - diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 09061277e0..c733b7fc8a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -329,8 +329,8 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData() if (null != metaData && 0 < metaData.Length) { - metaDataReturn = new SmiExtendedMetaData[metaData.visibleColumns]; - + metaDataReturn = new SmiExtendedMetaData[metaData.VisibleColumnCount]; + int returnIndex = 0; for (int index = 0; index < metaData.Length; index++) { _SqlMetaData colMetaData = metaData[index]; @@ -369,7 +369,7 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData() length /= ADP.CharSize; } - metaDataReturn[index] = new SmiQueryMetaData( + metaDataReturn[returnIndex] = new SmiQueryMetaData( colMetaData.type, length, colMetaData.precision, @@ -397,6 +397,7 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData() colMetaData.IsDifferentName, colMetaData.IsHidden ); + returnIndex += 1; } } } @@ -458,7 +459,7 @@ override public int VisibleFieldCount { return 0; } - return (md.visibleColumns); + return md.VisibleColumnCount; } } @@ -1352,31 +1353,6 @@ private bool TryConsumeMetaData() Debug.Assert(!ignored, "Parser read a row token while trying to read metadata"); } - // we hide hidden columns from the user so build an internal map - // that compacts all hidden columns from the array - if (null != _metaData) - { - - if (_snapshot != null && object.ReferenceEquals(_snapshot._metadata, _metaData)) - { - _metaData = (_SqlMetaDataSet)_metaData.Clone(); - } - - _metaData.visibleColumns = 0; - - Debug.Assert(null == _metaData.indexMap, "non-null metaData indexmap"); - int[] indexMap = new int[_metaData.Length]; - for (int i = 0; i < indexMap.Length; ++i) - { - indexMap[i] = _metaData.visibleColumns; - - if (!(_metaData[i].IsHidden)) - { - _metaData.visibleColumns++; - } - } - _metaData.indexMap = indexMap; - } return true; } @@ -1690,15 +1666,15 @@ override public DataTable GetSchemaTable() try { statistics = SqlStatistics.StartTimer(Statistics); - if (null == _metaData || null == _metaData.schemaTable) + if (null == _metaData || null == _metaData._schemaTable) { if (null != this.MetaData) { - _metaData.schemaTable = BuildSchemaTable(); - Debug.Assert(null != _metaData.schemaTable, "No schema information yet!"); + _metaData._schemaTable = BuildSchemaTable(); + Debug.Assert(null != _metaData._schemaTable, "No schema information yet!"); } } - return _metaData?.schemaTable; + return _metaData?._schemaTable; } finally { @@ -2994,11 +2970,11 @@ virtual public int GetSqlValues(object[] values) SetTimeout(_defaultTimeoutMilliseconds); - int copyLen = (values.Length < _metaData.visibleColumns) ? values.Length : _metaData.visibleColumns; + int copyLen = (values.Length < _metaData.VisibleColumnCount) ? values.Length : _metaData.VisibleColumnCount; for (int i = 0; i < copyLen; i++) { - values[_metaData.indexMap[i]] = GetSqlValueInternal(i); + values[_metaData.GetVisibleColumnIndex(i)] = GetSqlValueInternal(i); } return copyLen; } @@ -3398,7 +3374,7 @@ override public int GetValues(object[] values) CheckMetaDataIsReady(); - int copyLen = (values.Length < _metaData.visibleColumns) ? values.Length : _metaData.visibleColumns; + int copyLen = (values.Length < _metaData.VisibleColumnCount) ? values.Length : _metaData.VisibleColumnCount; int maximumColumn = copyLen - 1; SetTimeout(_defaultTimeoutMilliseconds); @@ -3414,12 +3390,19 @@ override public int GetValues(object[] values) for (int i = 0; i < copyLen; i++) { // Get the usable, TypeSystem-compatible value from the iternal buffer - values[_metaData.indexMap[i]] = GetValueFromSqlBufferInternal(_data[i], _metaData[i]); + int fieldIndex = _metaData.GetVisibleColumnIndex(i); + values[i] = GetValueFromSqlBufferInternal(_data[fieldIndex], _metaData[fieldIndex]); // If this is sequential access, then we need to wipe the internal buffer if ((sequentialAccess) && (i < maximumColumn)) { _data[i].Clear(); + if (fieldIndex > i && fieldIndex > 0) + { + // if we jumped an index forward because of a hidden column see if the buffer before the + // current one was populated by the seek forward and clear it if it was + _data[fieldIndex - 1].Clear(); + } } } @@ -4767,7 +4750,7 @@ internal bool TrySetMetaData(_SqlMetaDataSet metaData, bool moreInfo) _tableNames = null; if (_metaData != null) { - _metaData.schemaTable = null; + _metaData._schemaTable = null; _data = SqlBuffer.CreateBufferArray(metaData.Length); } @@ -5326,6 +5309,13 @@ public override Task ReadAsync(CancellationToken cancellationToken) return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("ReadAsync"))); } + // Register first to catch any already expired tokens to be able to trigger cancellation event. + IDisposable registration = null; + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + } + // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { @@ -5425,12 +5415,6 @@ public override Task ReadAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } - var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext(); Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 6daca4d771..1edad799ae 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -523,7 +523,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) // Clean up IsSQLDNSCachingSupported flag from previous status _connHandler.IsSQLDNSCachingSupported = false; - UInt32 sniStatus = SNILoadHandle.SingletonInstance.SNIStatus; + UInt32 sniStatus = SNILoadHandle.SingletonInstance.Status; if (sniStatus != TdsEnums.SNI_SUCCESS) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -5094,7 +5094,6 @@ internal bool TryProcessAltMetaData(int cColumns, TdsParserStateObject stateObj, metaData = null; _SqlMetaDataSet altMetaDataSet = new _SqlMetaDataSet(cColumns, null); - int[] indexMap = new int[cColumns]; if (!stateObj.TryReadUInt16(out altMetaDataSet.id)) { @@ -5191,12 +5190,8 @@ internal bool TryProcessAltMetaData(int cColumns, TdsParserStateObject stateObj, break; } } - indexMap[i] = i; } - altMetaDataSet.indexMap = indexMap; - altMetaDataSet.visibleColumns = cColumns; - metaData = altMetaDataSet; return true; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs index 21004f4be2..bf113efe3b 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs @@ -511,51 +511,63 @@ public object Clone() } } - sealed internal class _SqlMetaDataSet : ICloneable + sealed internal class _SqlMetaDataSet { internal ushort id; // for altrow-columns only - internal int[] indexMap; - internal int visibleColumns; - internal DataTable schemaTable; + internal DataTable _schemaTable; internal readonly SqlTceCipherInfoTable cekTable; // table of "column encryption keys" used for this metadataset - internal readonly _SqlMetaData[] metaDataArray; + internal readonly _SqlMetaData[] _metaDataArray; + private int _hiddenColumnCount; + private int[] _visibleColumnMap; internal _SqlMetaDataSet(int count, SqlTceCipherInfoTable cipherTable) { + _hiddenColumnCount = -1; cekTable = cipherTable; - metaDataArray = new _SqlMetaData[count]; - for (int i = 0; i < metaDataArray.Length; ++i) + _metaDataArray = new _SqlMetaData[count]; + for (int i = 0; i < _metaDataArray.Length; ++i) { - metaDataArray[i] = new _SqlMetaData(i); + _metaDataArray[i] = new _SqlMetaData(i); } } private _SqlMetaDataSet(_SqlMetaDataSet original) { - this.id = original.id; - // although indexMap is not immutable, in practice it is initialized once and then passed around - this.indexMap = original.indexMap; - this.visibleColumns = original.visibleColumns; - this.schemaTable = original.schemaTable; - if (original.metaDataArray == null) + id = original.id; + _hiddenColumnCount = original._hiddenColumnCount; + _visibleColumnMap = original._visibleColumnMap; + _schemaTable = original._schemaTable; + if (original._metaDataArray == null) { - metaDataArray = null; + _metaDataArray = null; } else { - metaDataArray = new _SqlMetaData[original.metaDataArray.Length]; - for (int idx = 0; idx < metaDataArray.Length; idx++) + _metaDataArray = new _SqlMetaData[original._metaDataArray.Length]; + for (int idx = 0; idx < _metaDataArray.Length; idx++) { - metaDataArray[idx] = (_SqlMetaData)original.metaDataArray[idx].Clone(); + _metaDataArray[idx] = (_SqlMetaData)original._metaDataArray[idx].Clone(); } } } + internal int VisibleColumnCount + { + get + { + if (_hiddenColumnCount == -1) + { + SetupHiddenColumns(); + } + return Length - _hiddenColumnCount; + } + } + internal int Length { get { - return metaDataArray.Length; + return _metaDataArray.Length; } } @@ -563,21 +575,66 @@ internal int Length { get { - return metaDataArray[index]; + return _metaDataArray[index]; } set { Debug.Assert(null == value, "used only by SqlBulkCopy"); - metaDataArray[index] = value; + _metaDataArray[index] = value; } } - public object Clone() + public int GetVisibleColumnIndex(int index) + { + if (_hiddenColumnCount == -1) + { + SetupHiddenColumns(); + } + if (_visibleColumnMap is null) + { + return index; + } + else + { + return _visibleColumnMap[index]; + } + } + + public _SqlMetaDataSet Clone() { return new _SqlMetaDataSet(this); } + + private void SetupHiddenColumns() + { + int hiddenColumnCount = 0; + for (int index = 0; index < Length; index++) + { + if (_metaDataArray[index].IsHidden) + { + hiddenColumnCount += 1; + } + } + + if (hiddenColumnCount > 0) + { + int[] visibleColumnMap = new int[Length - hiddenColumnCount]; + int mapIndex = 0; + for (int metaDataIndex = 0; metaDataIndex < Length; metaDataIndex++) + { + if (!_metaDataArray[metaDataIndex].IsHidden) + { + visibleColumnMap[mapIndex] = metaDataIndex; + mapIndex += 1; + } + } + _visibleColumnMap = visibleColumnMap; + } + _hiddenColumnCount = hiddenColumnCount; + } } + sealed internal class _SqlMetaDataSetCollection : ICloneable { private readonly List<_SqlMetaDataSet> altMetaDataSetArray; @@ -622,10 +679,10 @@ internal _SqlMetaDataSet GetAltMetaData(int id) public object Clone() { _SqlMetaDataSetCollection result = new _SqlMetaDataSetCollection(); - result.metaDataSet = metaDataSet == null ? null : (_SqlMetaDataSet)metaDataSet.Clone(); + result.metaDataSet = metaDataSet == null ? null : metaDataSet.Clone(); foreach (_SqlMetaDataSet set in altMetaDataSetArray) { - result.altMetaDataSetArray.Add((_SqlMetaDataSet)set.Clone()); + result.altMetaDataSetArray.Add(set.Clone()); } return result; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 6e8afce1ea..8d9057bc02 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2401,9 +2401,9 @@ private void OnTimeoutAsync(object state) } } - private bool OnTimeoutSync() + private bool OnTimeoutSync(bool asyncClose = false) { - return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync); + return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync, asyncClose); } /// @@ -2412,8 +2412,9 @@ private bool OnTimeoutSync() /// /// the state that is the expected current state, state will change only if this is correct /// the state that will be changed to if the expected state is correct + /// any close action to be taken by an async task to avoid deadlock. /// boolean value indicating whether the call changed the timeout state - private bool OnTimeoutCore(int expectedState, int targetState) + private bool OnTimeoutCore(int expectedState, int targetState, bool asyncClose = false) { Debug.Assert(targetState == TimeoutState.ExpiredAsync || targetState == TimeoutState.ExpiredSync, "OnTimeoutCore must have an expiry state as the targetState"); @@ -2447,7 +2448,7 @@ private bool OnTimeoutCore(int expectedState, int targetState) { try { - SendAttention(mustTakeWriteLock: true); + SendAttention(mustTakeWriteLock: true, asyncClose); } catch (Exception e) { @@ -2988,7 +2989,7 @@ public void ReadAsyncCallback(IntPtr key, IntPtr packet, UInt32 error) // synchrnously and then call OnTimeoutSync to force an atomic change of state. if (TimeoutHasExpired) { - OnTimeoutSync(); + OnTimeoutSync(asyncClose: true); } // try to change to the stopped state but only do so if currently in the running state @@ -3475,7 +3476,7 @@ private void CancelWritePacket() #pragma warning disable 420 // a reference to a volatile field will not be treated as volatile - private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniError, bool canAccumulate, bool callerHasConnectionLock) + private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniError, bool canAccumulate, bool callerHasConnectionLock, bool asyncClose = false) { // Check for a stored exception var delayedException = Interlocked.Exchange(ref _delayedWriteAsyncCallbackException, null); @@ -3566,7 +3567,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr SqlClientEventSource.Log.TryTraceEvent(" write async returned error code {0}", (int)error); AddError(_parser.ProcessSNIError(this)); - ThrowExceptionAndWarning(); + ThrowExceptionAndWarning(false, asyncClose); } AssertValidState(); completion.SetResult(null); @@ -3603,7 +3604,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr { SqlClientEventSource.Log.TryTraceEvent(" write async returned error code {0}", (int)sniError); AddError(_parser.ProcessSNIError(this)); - ThrowExceptionAndWarning(callerHasConnectionLock); + ThrowExceptionAndWarning(callerHasConnectionLock, false); } AssertValidState(); } @@ -3613,7 +3614,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out UInt32 sniEr #pragma warning restore 420 // Sends an attention signal - executing thread will consume attn. - internal void SendAttention(bool mustTakeWriteLock = false) + internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = false) { if (!_attentionSent) { @@ -3660,7 +3661,7 @@ internal void SendAttention(bool mustTakeWriteLock = false) UInt32 sniError; _parser._asyncWrite = false; // stop async write - SNIWritePacket(Handle, attnPacket, out sniError, canAccumulate: false, callerHasConnectionLock: false); + SNIWritePacket(Handle, attnPacket, out sniError, canAccumulate: false, callerHasConnectionLock: false, asyncClose); SqlClientEventSource.Log.TryTraceEvent(" Send Attention ASync.", "Info"); } finally diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs new file mode 100644 index 0000000000..ef6b9b6cd2 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/ActivityCorrelator.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Globalization; + +namespace Microsoft.Data.Common +{ + /// + /// This class defines the data structure for ActivityId used for correlated tracing between client (bid trace event) and server (XEvent). + /// It also includes all the APIs used to access the ActivityId. Note: ActivityId is thread based which is stored in TLS. + /// + + internal static class ActivityCorrelator + { + internal sealed class ActivityId + { + internal readonly Guid Id; + internal readonly uint Sequence; + + internal ActivityId(uint sequence) + { + this.Id = Guid.NewGuid(); + this.Sequence = sequence; + } + + public override string ToString() + { + return string.Format(CultureInfo.InvariantCulture, "{0}:{1}", this.Id, this.Sequence); + } + } + + // Declare the ActivityId which will be stored in TLS. The Id is unique for each thread. + // The Sequence number will be incremented when each event happens. + // Correlation along threads is consistent with the current XEvent mechanism at server. + [ThreadStatic] + private static ActivityId t_tlsActivity; + + /// + /// Get the current ActivityId + /// + internal static ActivityId Current + { + get + { + if (t_tlsActivity == null) + { + t_tlsActivity = new ActivityId(1); + } + return t_tlsActivity; + } + } + + /// + /// Increment the sequence number and generate the new ActivityId + /// + /// ActivityId + internal static ActivityId Next() + { + t_tlsActivity = new ActivityId( (t_tlsActivity?.Sequence ?? 0) + 1); + + return t_tlsActivity; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs new file mode 100644 index 0000000000..8b84feecfc --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Unix.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.Common +{ + /// + /// The class ADP defines the exceptions that are specific to the Adapters. + /// The class contains functions that take the proper informational variables and then construct + /// the appropriate exception with an error string obtained from the resource framework. + /// The exception is then returned to the caller, so that the caller may then throw from its + /// location so that the catcher of the exception will have the appropriate call stack. + /// This class is used so that there will be compile time checking of error messages. + /// The resource Framework.txt will ensure proper string text based on the appropriate locale. + /// + internal static partial class ADP + { + internal static object LocalMachineRegistryValue(string subkey, string queryvalue) + { + // No registry in non-Windows environments + return null; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs new file mode 100644 index 0000000000..c9d0f8d91a --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.Windows.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; +using System.Runtime.Versioning; +using System.Security; +using System.Security.Permissions; +using Microsoft.Win32; + +namespace Microsoft.Data.Common +{ + /// + /// The class ADP defines the exceptions that are specific to the Adapters. + /// The class contains functions that take the proper informational variables and then construct + /// the appropriate exception with an error string obtained from the resource framework. + /// The exception is then returned to the caller, so that the caller may then throw from its + /// location so that the catcher of the exception will have the appropriate call stack. + /// This class is used so that there will be compile time checking of error messages. + /// The resource Framework.txt will ensure proper string text based on the appropriate locale. + /// + internal static partial class ADP + { + [ResourceExposure(ResourceScope.Machine)] + [ResourceConsumption(ResourceScope.Machine)] + internal static object LocalMachineRegistryValue(string subkey, string queryvalue) + { // MDAC 77697 + (new RegistryPermission(RegistryPermissionAccess.Read, "HKEY_LOCAL_MACHINE\\" + subkey)).Assert(); // MDAC 62028 + try + { + using (RegistryKey key = Registry.LocalMachine.OpenSubKey(subkey, false)) + { + return key?.GetValue(queryvalue); + } + } + catch (SecurityException e) + { + // Even though we assert permission - it's possible there are + // ACL's on registry that cause SecurityException to be thrown. + ADP.TraceExceptionWithoutRethrow(e); + return null; + } + finally + { + RegistryPermission.RevertAssert(); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs new file mode 100644 index 0000000000..1866aa7fb3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/AdapterUtil.cs @@ -0,0 +1,1570 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Data; +using System.Data.Common; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; + +using System.Security; +using System.Security.Permissions; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; +using Microsoft.Data.SqlClient; +using IsolationLevel = System.Data.IsolationLevel; +using Microsoft.Identity.Client; +using Microsoft.SqlServer.Server; + +#if NETFRAMEWORK +using Microsoft.Win32; +using System.Reflection; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Runtime.Versioning; +#endif + +namespace Microsoft.Data.Common +{ + /// + /// The class ADP defines the exceptions that are specific to the Adapters. + /// The class contains functions that take the proper informational variables and then construct + /// the appropriate exception with an error string obtained from the resource framework. + /// The exception is then returned to the caller, so that the caller may then throw from its + /// location so that the catcher of the exception will have the appropriate call stack. + /// This class is used so that there will be compile time checking of error messages. + /// The resource Framework.txt will ensure proper string text based on the appropriate locale. + /// + internal static partial class ADP + { + // NOTE: Initializing a Task in SQL CLR requires the "UNSAFE" permission set (http://msdn.microsoft.com/en-us/library/ms172338.aspx) + // Therefore we are lazily initializing these Tasks to avoid forcing customers to use the "UNSAFE" set when they are actually using no Async features + private static Task s_trueTask; + internal static Task TrueTask => s_trueTask ??= Task.FromResult(true); + + private static Task s_falseTask; + internal static Task FalseTask => s_falseTask ??= Task.FromResult(false); + + internal const CompareOptions DefaultCompareOptions = CompareOptions.IgnoreKanaType | CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase; + internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout; + /// + /// Infinite connection timeout identifier in seconds + /// + internal const int InfiniteConnectionTimeout = 0; + /// + /// Max duration for buffer in seconds + /// + internal const int MaxBufferAccessTokenExpiry = 600; + + #region UDT +#if NETFRAMEWORK + private static readonly MethodInfo s_method = typeof(InvalidUdtException).GetMethod("Create", BindingFlags.NonPublic | BindingFlags.Static); +#endif + /// + /// Calls "InvalidUdtException.Create" method when an invalid UDT occurs. + /// + internal static InvalidUdtException CreateInvalidUdtException(Type udtType, string resourceReasonName) + { + InvalidUdtException e = +#if NETFRAMEWORK + (InvalidUdtException)s_method.Invoke(null, new object[] { udtType, resourceReasonName }); + ADP.TraceExceptionAsReturnValue(e); +#else + InvalidUdtException.Create(udtType, resourceReasonName); +#endif + return e; + } + #endregion + + static private void TraceException(string trace, Exception e) + { + Debug.Assert(null != e, "TraceException: null Exception"); + if (e is not null) + { + SqlClientEventSource.Log.TryTraceEvent(trace, e); + } + } + + internal static void TraceExceptionAsReturnValue(Exception e) + { + TraceException(" '{0}'", e); + } + + internal static void TraceExceptionWithoutRethrow(Exception e) + { + Debug.Assert(IsCatchableExceptionType(e), "Invalid exception type, should have been re-thrown!"); + TraceException(" '{0}'", e); + } + + internal static bool IsEmptyArray(string[] array) => (array is null) || (array.Length == 0); + + internal static bool IsNull(object value) + { + if ((value is null) || (DBNull.Value == value)) + { + return true; + } + INullable nullable = (value as INullable); + return ((nullable is not null) && nullable.IsNull); + } + + internal static Exception ExceptionWithStackTrace(Exception e) + { + try + { + throw e; + } + catch (Exception caught) + { + return caught; + } + } + +#region COM+ exceptions + internal static ArgumentException Argument(string error) + { + ArgumentException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException Argument(string error, Exception inner) + { + ArgumentException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException Argument(string error, string parameter) + { + ArgumentException e = new(error, parameter); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentNullException ArgumentNull(string parameter) + { + ArgumentNullException e = new(parameter); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentNullException ArgumentNull(string parameter, string error) + { + ArgumentNullException e = new(parameter, error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentOutOfRangeException ArgumentOutOfRange(string parameterName) + { + ArgumentOutOfRangeException e = new(parameterName); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentOutOfRangeException ArgumentOutOfRange(string message, string parameterName) + { + ArgumentOutOfRangeException e = new(parameterName, message); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static IndexOutOfRangeException IndexOutOfRange(string error) + { + IndexOutOfRangeException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static IndexOutOfRangeException IndexOutOfRange(int value) + { + IndexOutOfRangeException e = new(value.ToString(CultureInfo.InvariantCulture)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static IndexOutOfRangeException IndexOutOfRange() + { + IndexOutOfRangeException e = new(); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidOperationException InvalidOperation(string error, Exception inner) + { + InvalidOperationException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static OverflowException Overflow(string error) => Overflow(error, null); + + internal static OverflowException Overflow(string error, Exception inner) + { + OverflowException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static TimeoutException TimeoutException(string error, Exception inner = null) + { + TimeoutException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static TypeLoadException TypeLoad(string error) + { + TypeLoadException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidCastException InvalidCast() + { + InvalidCastException e = new(); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidCastException InvalidCast(string error) + { + return InvalidCast(error, null); + } + + internal static InvalidCastException InvalidCast(string error, Exception inner) + { + InvalidCastException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidOperationException InvalidOperation(string error) + { + InvalidOperationException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static IOException IO(string error) + { + IOException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + internal static IOException IO(string error, Exception inner) + { + IOException e = new(error, inner); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static NotSupportedException NotSupported() + { + NotSupportedException e = new(); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static NotSupportedException NotSupported(string error) + { + NotSupportedException e = new(error); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidOperationException DataAdapter(string error) => InvalidOperation(error); + + private static InvalidOperationException Provider(string error) => InvalidOperation(error); + + internal static ArgumentException InvalidMultipartName(string property, string value) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartName, StringsHelper.GetString(property), value)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException InvalidMultipartNameIncorrectUsageOfQuotes(string property, string value) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartNameQuoteUsage, StringsHelper.GetString(property), value)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException InvalidMultipartNameToManyParts(string property, string value, int limit) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidMultipartNameToManyParts, StringsHelper.GetString(property), value, limit)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ObjectDisposedException ObjectDisposed(object instance) + { + ObjectDisposedException e = new(instance.GetType().Name); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static InvalidOperationException MethodCalledTwice(string method) + { + InvalidOperationException e = new(StringsHelper.GetString(Strings.ADP_CalledTwice, method)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentOutOfRangeException ArgumentOutOfRange(string message, string parameterName, object value) + { + ArgumentOutOfRangeException e = new(parameterName, value, message); + TraceExceptionAsReturnValue(e); + return e; + } +#endregion + +#region Helper Functions + internal static ArgumentOutOfRangeException NotSupportedEnumerationValue(Type type, string value, string method) + => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_NotSupportedEnumerationValue, type.Name, value, method), type.Name); + + internal static void CheckArgumentNull(object value, string parameterName) + { + if (value is null) + { + throw ArgumentNull(parameterName); + } + } + + internal static bool IsCatchableExceptionType(Exception e) + { + // only StackOverflowException & ThreadAbortException are sealed classes + // a 'catchable' exception is defined by what it is not. + Debug.Assert(e != null, "Unexpected null exception!"); + Type type = e.GetType(); + + return ((type != typeof(StackOverflowException)) && + (type != typeof(OutOfMemoryException)) && + (type != typeof(ThreadAbortException)) && + (type != typeof(NullReferenceException)) && + (type != typeof(AccessViolationException)) && + !typeof(SecurityException).IsAssignableFrom(type)); + } + + internal static bool IsCatchableOrSecurityExceptionType(Exception e) + { + // a 'catchable' exception is defined by what it is not. + // since IsCatchableExceptionType defined SecurityException as not 'catchable' + // this method will return true for SecurityException has being catchable. + + // the other way to write this method is, but then SecurityException is checked twice + // return ((e is SecurityException) || IsCatchableExceptionType(e)); + + // only StackOverflowException & ThreadAbortException are sealed classes + Debug.Assert(e != null, "Unexpected null exception!"); + Type type = e.GetType(); + + return ((type != typeof(StackOverflowException)) && + (type != typeof(OutOfMemoryException)) && + (type != typeof(ThreadAbortException)) && + (type != typeof(NullReferenceException)) && + (type != typeof(AccessViolationException))); + } + + // Invalid Enumeration + internal static ArgumentOutOfRangeException InvalidEnumerationValue(Type type, int value) + => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidEnumerationValue, type.Name, value.ToString(CultureInfo.InvariantCulture)), type.Name); + + internal static ArgumentOutOfRangeException InvalidCommandBehavior(CommandBehavior value) + { + Debug.Assert((0 > (int)value) || ((int)value > 0x3F), "valid CommandType " + value.ToString()); + + return InvalidEnumerationValue(typeof(CommandBehavior), (int)value); + } + + internal static void ValidateCommandBehavior(CommandBehavior value) + { + if (((int)value < 0) || (0x3F < (int)value)) + { + throw InvalidCommandBehavior(value); + } + } + + internal static ArgumentOutOfRangeException InvalidUserDefinedTypeSerializationFormat(Format value) + { +#if DEBUG + switch (value) + { + case Format.Unknown: + case Format.Native: + case Format.UserDefined: + Debug.Assert(false, "valid UserDefinedTypeSerializationFormat " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(Format), (int)value); + } + + internal static ArgumentOutOfRangeException NotSupportedUserDefinedTypeSerializationFormat(Format value, string method) + => NotSupportedEnumerationValue(typeof(Format), value.ToString(), method); + + internal static ArgumentException InvalidArgumentLength(string argumentName, int limit) + => Argument(StringsHelper.GetString(Strings.ADP_InvalidArgumentLength, argumentName, limit)); + + internal static ArgumentException MustBeReadOnly(string argumentName) => Argument(StringsHelper.GetString(Strings.ADP_MustBeReadOnly, argumentName)); + + internal static Exception CreateSqlException(MsalException msalException, SqlConnectionString connectionOptions, SqlInternalConnectionTds sender, string username) + { + // Error[0] + SqlErrorCollection sqlErs = new(); + + sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, + connectionOptions.DataSource, + StringsHelper.GetString(Strings.SQL_MSALFailure, username, connectionOptions.Authentication.ToString("G")), + ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0)); + + // Error[1] + string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode); + sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, + connectionOptions.DataSource, errorMessage1, + ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0)); + + // Error[2] + if (!string.IsNullOrEmpty(msalException.Message)) + { + sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, + connectionOptions.DataSource, msalException.Message, + ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0)); + } + return SqlException.CreateException(sqlErs, "", sender); + } + +#endregion + +#region CommandBuilder, Command, BulkCopy + /// + /// This allows the caller to determine if it is an error or not for the quotedString to not be quoted + /// + /// The return value is true if the string was quoted and false if it was not + internal static bool RemoveStringQuotes(string quotePrefix, string quoteSuffix, string quotedString, out string unquotedString) + { + int prefixLength = quotePrefix is null ? 0 : quotePrefix.Length; + int suffixLength = quoteSuffix is null ? 0 : quoteSuffix.Length; + + if ((suffixLength + prefixLength) == 0) + { + unquotedString = quotedString; + return true; + } + + if (quotedString is null) + { + unquotedString = quotedString; + return false; + } + + int quotedStringLength = quotedString.Length; + + // is the source string too short to be quoted + if (quotedStringLength < prefixLength + suffixLength) + { + unquotedString = quotedString; + return false; + } + + // is the prefix present? + if (prefixLength > 0) + { + if (!quotedString.StartsWith(quotePrefix, StringComparison.Ordinal)) + { + unquotedString = quotedString; + return false; + } + } + + // is the suffix present? + if (suffixLength > 0) + { + if (!quotedString.EndsWith(quoteSuffix, StringComparison.Ordinal)) + { + unquotedString = quotedString; + return false; + } + unquotedString = quotedString.Substring(prefixLength, quotedStringLength - (prefixLength + suffixLength)) + .Replace(quoteSuffix + quoteSuffix, quoteSuffix); + } + else + { + unquotedString = quotedString.Substring(prefixLength, quotedStringLength - prefixLength); + } + return true; + } + + internal static string BuildQuotedString(string quotePrefix, string quoteSuffix, string unQuotedString) + { + var resultString = new StringBuilder(unQuotedString.Length + quoteSuffix.Length + quoteSuffix.Length); + AppendQuotedString(resultString, quotePrefix, quoteSuffix, unQuotedString); + return resultString.ToString(); + } + + internal static string AppendQuotedString(StringBuilder buffer, string quotePrefix, string quoteSuffix, string unQuotedString) + { + Debug.Assert(buffer is not null, "buffer parameter must be initialized!"); + + if (!string.IsNullOrEmpty(quotePrefix)) + { + buffer.Append(quotePrefix); + } + + // Assuming that the suffix is escaped by doubling it. i.e. foo"bar becomes "foo""bar". + if (!string.IsNullOrEmpty(quoteSuffix)) + { + int start = buffer.Length; + buffer.Append(unQuotedString); + buffer.Replace(quoteSuffix, quoteSuffix + quoteSuffix, start, unQuotedString.Length); + buffer.Append(quoteSuffix); + } + else + { + buffer.Append(unQuotedString); + } + + return buffer.ToString(); + } + + internal static string BuildMultiPartName(string[] strings) + { + StringBuilder bld = new(); + // Assume we want to build a full multi-part name with all parts except trimming separators for + // leading empty names (null or empty strings, but not whitespace). Separators in the middle + // should be added, even if the name part is null/empty, to maintain proper location of the parts. + for (int i = 0; i < strings.Length; i++) + { + if (0 < bld.Length) + { + bld.Append('.'); + } + if (strings[i] is not null && 0 != strings[i].Length) + { + bld.Append(BuildQuotedString("[", "]", strings[i])); + } + } + return bld.ToString(); + } + + // global constant strings + internal const string ColumnEncryptionSystemProviderNamePrefix = "MSSQL_"; + internal const string Command = "Command"; + internal const string Connection = "Connection"; + internal const string Parameter = "Parameter"; + internal const string ParameterName = "ParameterName"; + internal const string ParameterSetPosition = "set_Position"; + + internal const int DefaultCommandTimeout = 30; + internal const float FailoverTimeoutStep = 0.08F; // fraction of timeout to use for fast failover connections + + internal const int CharSize = UnicodeEncoding.CharSize; + + internal static Delegate FindBuilder(MulticastDelegate mcd) + { + foreach (Delegate del in mcd?.GetInvocationList()) + { + if (del.Target is DbCommandBuilder) + return del; + } + + return null; + } + + internal static long TimerCurrent() => DateTime.UtcNow.ToFileTimeUtc(); + + internal static long TimerFromSeconds(int seconds) + { + long result = checked((long)seconds * TimeSpan.TicksPerSecond); + return result; + } + + internal static long TimerFromMilliseconds(long milliseconds) + { + long result = checked(milliseconds * TimeSpan.TicksPerMillisecond); + return result; + } + + internal static bool TimerHasExpired(long timerExpire) + { + bool result = TimerCurrent() > timerExpire; + return result; + } + + internal static long TimerRemaining(long timerExpire) + { + long timerNow = TimerCurrent(); + long result = checked(timerExpire - timerNow); + return result; + } + + internal static long TimerRemainingMilliseconds(long timerExpire) + { + long result = TimerToMilliseconds(TimerRemaining(timerExpire)); + return result; + } + + internal static long TimerRemainingSeconds(long timerExpire) + { + long result = TimerToSeconds(TimerRemaining(timerExpire)); + return result; + } + + internal static long TimerToMilliseconds(long timerValue) + { + long result = timerValue / TimeSpan.TicksPerMillisecond; + return result; + } + + private static long TimerToSeconds(long timerValue) + { + long result = timerValue / TimeSpan.TicksPerSecond; + return result; + } + + /// + /// Note: In Longhorn you'll be able to rename a machine without + /// rebooting. Therefore, don't cache this machine name. + /// + [EnvironmentPermission(SecurityAction.Assert, Read = "COMPUTERNAME")] + internal static string MachineName() => Environment.MachineName; + + internal static Transaction GetCurrentTransaction() + { + Transaction transaction = Transaction.Current; + return transaction; + } + + internal static bool IsDirection(DbParameter value, ParameterDirection condition) + { +#if DEBUG + switch (condition) + { // @perfnote: Enum.IsDefined + case ParameterDirection.Input: + case ParameterDirection.Output: + case ParameterDirection.InputOutput: + case ParameterDirection.ReturnValue: + break; + default: + throw ADP.InvalidParameterDirection(condition); + } +#endif + return (condition == (condition & value.Direction)); + } + + internal static void IsNullOrSqlType(object value, out bool isNull, out bool isSqlType) + { + if ((value is null) || (value == DBNull.Value)) + { + isNull = true; + isSqlType = false; + } + else + { + if (value is INullable nullable) + { + isNull = nullable.IsNull; + // Duplicated from DataStorage.cs + // For back-compat, SqlXml is not in this list + isSqlType = ((value is SqlBinary) || + (value is SqlBoolean) || + (value is SqlByte) || + (value is SqlBytes) || + (value is SqlChars) || + (value is SqlDateTime) || + (value is SqlDecimal) || + (value is SqlDouble) || + (value is SqlGuid) || + (value is SqlInt16) || + (value is SqlInt32) || + (value is SqlInt64) || + (value is SqlMoney) || + (value is SqlSingle) || + (value is SqlString)); + } + else + { + isNull = false; + isSqlType = false; + } + } + } + + private static Version s_systemDataVersion; + + internal static Version GetAssemblyVersion() + { + // NOTE: Using lazy thread-safety since we don't care if two threads both happen to update the value at the same time + if (s_systemDataVersion is null) + { + s_systemDataVersion = new Version(ThisAssembly.InformationalVersion); + } + + return s_systemDataVersion; + } + + + private const string ONDEMAND_PREFIX = "-ondemand"; + private const string AZURE_SYNAPSE = "-ondemand.sql.azuresynapse."; + + internal static bool IsAzureSynapseOnDemandEndpoint(string dataSource) + { + return IsEndpoint(dataSource, ONDEMAND_PREFIX) || dataSource.Contains(AZURE_SYNAPSE); + } + + internal static readonly string[] s_azureSqlServerEndpoints = { StringsHelper.GetString(Strings.AZURESQL_GenericEndpoint), + StringsHelper.GetString(Strings.AZURESQL_GermanEndpoint), + StringsHelper.GetString(Strings.AZURESQL_UsGovEndpoint), + StringsHelper.GetString(Strings.AZURESQL_ChinaEndpoint)}; + + internal static bool IsAzureSqlServerEndpoint(string dataSource) + { + return IsEndpoint(dataSource, null); + } + + // This method assumes dataSource parameter is in TCP connection string format. + private static bool IsEndpoint(string dataSource, string prefix) + { + int length = dataSource.Length; + // remove server port + int foundIndex = dataSource.LastIndexOf(','); + if (foundIndex >= 0) + { + length = foundIndex; + } + + // check for the instance name + foundIndex = dataSource.LastIndexOf('\\', length - 1, length - 1); + if (foundIndex > 0) + { + length = foundIndex; + } + + // trim trailing whitespace + while (length > 0 && char.IsWhiteSpace(dataSource[length - 1])) + { + length -= 1; + } + + // check if servername ends with any endpoints + for (int index = 0; index < s_azureSqlServerEndpoints.Length; index++) + { + string endpoint = string.IsNullOrEmpty(prefix) ? s_azureSqlServerEndpoints[index] : prefix + s_azureSqlServerEndpoints[index]; + if (length > endpoint.Length) + { + if (string.Compare(dataSource, length - endpoint.Length, endpoint, 0, endpoint.Length, StringComparison.OrdinalIgnoreCase) == 0) + { + return true; + } + } + } + + return false; + } + + internal static ArgumentException SingleValuedProperty(string propertyName, string value) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_SingleValuedProperty, propertyName, value)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException DoubleValuedProperty(string propertyName, string value1, string value2) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_DoubleValuedProperty, propertyName, value1, value2)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static ArgumentException InvalidPrefixSuffix() + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_InvalidPrefixSuffix)); + TraceExceptionAsReturnValue(e); + return e; + } +#endregion + +#region DbConnectionOptions, DataAccess + internal static ArgumentException ConnectionStringSyntax(int index) => Argument(StringsHelper.GetString(Strings.ADP_ConnectionStringSyntax, index)); + + internal static ArgumentException KeywordNotSupported(string keyword) => Argument(StringsHelper.GetString(Strings.ADP_KeywordNotSupported, keyword)); + + internal static Exception InvalidConnectionOptionValue(string key) => InvalidConnectionOptionValue(key, null); + + internal static Exception InvalidConnectionOptionValue(string key, Exception inner) + => Argument(StringsHelper.GetString(Strings.ADP_InvalidConnectionOptionValue, key), inner); + + internal static Exception InvalidConnectionOptionValueLength(string key, int limit) + => Argument(StringsHelper.GetString(Strings.ADP_InvalidConnectionOptionValueLength, key, limit)); + + internal static Exception MissingConnectionOptionValue(string key, string requiredAdditionalKey) + => Argument(StringsHelper.GetString(Strings.ADP_MissingConnectionOptionValue, key, requiredAdditionalKey)); + + internal static InvalidOperationException InvalidDataDirectory() => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidDataDirectory)); + + internal static ArgumentException CollectionRemoveInvalidObject(Type itemType, ICollection collection) + => Argument(StringsHelper.GetString(Strings.ADP_CollectionRemoveInvalidObject, itemType.Name, collection.GetType().Name)); // MDAC 68201 + + internal static ArgumentNullException CollectionNullValue(string parameter, Type collection, Type itemType) + => ArgumentNull(parameter, StringsHelper.GetString(Strings.ADP_CollectionNullValue, collection.Name, itemType.Name)); + + internal static IndexOutOfRangeException CollectionIndexInt32(int index, Type collection, int count) + => IndexOutOfRange(StringsHelper.GetString(Strings.ADP_CollectionIndexInt32, index.ToString(CultureInfo.InvariantCulture), collection.Name, count.ToString(CultureInfo.InvariantCulture))); + + internal static IndexOutOfRangeException CollectionIndexString(Type itemType, string propertyName, string propertyValue, Type collection) + => IndexOutOfRange(StringsHelper.GetString(Strings.ADP_CollectionIndexString, itemType.Name, propertyName, propertyValue, collection.Name)); + + internal static InvalidCastException CollectionInvalidType(Type collection, Type itemType, object invalidValue) + => InvalidCast(StringsHelper.GetString(Strings.ADP_CollectionInvalidType, collection.Name, itemType.FullName, invalidValue.GetType().FullName)); + + internal static ArgumentException ConvertFailed(Type fromType, Type toType, Exception innerException) + => ADP.Argument(StringsHelper.GetString(Strings.SqlConvert_ConvertFailed, fromType.FullName, toType.FullName), innerException); + + internal static ArgumentException InvalidMinMaxPoolSizeValues() + => ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMinMaxPoolSizeValues)); +#endregion + +#region DbConnection + private static string ConnectionStateMsg(ConnectionState state) + { // MDAC 82165, if the ConnectionState enum to msg the localization looks weird + return state switch + { + (ConnectionState.Closed) or (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed), + (ConnectionState.Connecting) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Connecting), + (ConnectionState.Open) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Open), + (ConnectionState.Open | ConnectionState.Executing) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenExecuting), + (ConnectionState.Open | ConnectionState.Fetching) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenFetching), + _ => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg, state.ToString()), + }; + } + + internal static InvalidOperationException NoConnectionString() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_NoConnectionString)); + + internal static NotImplementedException MethodNotImplemented([CallerMemberName] string methodName = "") + { + NotImplementedException e = new(methodName); + TraceExceptionAsReturnValue(e); + return e; + } +#endregion + +#region Stream + internal static Exception StreamClosed([CallerMemberName] string method = "") => InvalidOperation(StringsHelper.GetString(Strings.ADP_StreamClosed, method)); + + static internal Exception InvalidSeekOrigin(string parameterName) => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidSeekOrigin), parameterName); + + internal static IOException ErrorReadingFromStream(Exception internalException) => IO(StringsHelper.GetString(Strings.SqlMisc_StreamErrorMessage), internalException); +#endregion + +#region Generic Data Provider Collection + internal static ArgumentException ParametersIsNotParent(Type parameterType, ICollection collection) + => Argument(StringsHelper.GetString(Strings.ADP_CollectionIsNotParent, parameterType.Name, collection.GetType().Name)); + + internal static ArgumentException ParametersIsParent(Type parameterType, ICollection collection) + => Argument(StringsHelper.GetString(Strings.ADP_CollectionIsNotParent, parameterType.Name, collection.GetType().Name)); +#endregion + +#region ConnectionUtil + internal enum InternalErrorCode + { + UnpooledObjectHasOwner = 0, + UnpooledObjectHasWrongOwner = 1, + PushingObjectSecondTime = 2, + PooledObjectHasOwner = 3, + PooledObjectInPoolMoreThanOnce = 4, + CreateObjectReturnedNull = 5, + NewObjectCannotBePooled = 6, + NonPooledObjectUsedMoreThanOnce = 7, + AttemptingToPoolOnRestrictedToken = 8, + // ConnectionOptionsInUse = 9, + ConvertSidToStringSidWReturnedNull = 10, + // UnexpectedTransactedObject = 11, + AttemptingToConstructReferenceCollectionOnStaticObject = 12, + AttemptingToEnlistTwice = 13, + CreateReferenceCollectionReturnedNull = 14, + PooledObjectWithoutPool = 15, + UnexpectedWaitAnyResult = 16, + SynchronousConnectReturnedPending = 17, + CompletedConnectReturnedPending = 18, + + NameValuePairNext = 20, + InvalidParserState1 = 21, + InvalidParserState2 = 22, + InvalidParserState3 = 23, + + InvalidBuffer = 30, + + UnimplementedSMIMethod = 40, + InvalidSmiCall = 41, + + SqlDependencyObtainProcessDispatcherFailureObjectHandle = 50, + SqlDependencyProcessDispatcherFailureCreateInstance = 51, + SqlDependencyProcessDispatcherFailureAppDomain = 52, + SqlDependencyCommandHashIsNotAssociatedWithNotification = 53, + + UnknownTransactionFailure = 60, + } + + internal static Exception InternalError(InternalErrorCode internalError) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InternalProviderError, (int)internalError)); + + internal static Exception ClosedConnectionError() => InvalidOperation(StringsHelper.GetString(Strings.ADP_ClosedConnectionError)); + internal static Exception ConnectionAlreadyOpen(ConnectionState state) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_ConnectionAlreadyOpen, ADP.ConnectionStateMsg(state))); + + internal static Exception TransactionPresent() => InvalidOperation(StringsHelper.GetString(Strings.ADP_TransactionPresent)); + + internal static Exception LocalTransactionPresent() => InvalidOperation(StringsHelper.GetString(Strings.ADP_LocalTransactionPresent)); + + internal static Exception OpenConnectionPropertySet(string property, ConnectionState state) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenConnectionPropertySet, property, ADP.ConnectionStateMsg(state))); + + internal static Exception EmptyDatabaseName() => Argument(StringsHelper.GetString(Strings.ADP_EmptyDatabaseName)); + + internal enum ConnectionError + { + BeginGetConnectionReturnsNull, + GetConnectionReturnsNull, + ConnectionOptionsMissing, + CouldNotSwitchToClosedPreviouslyOpenedState, + } + + internal static Exception InternalConnectionError(ConnectionError internalError) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InternalConnectionError, (int)internalError)); + + internal static Exception InvalidConnectRetryCountValue() => Argument(StringsHelper.GetString(Strings.SQLCR_InvalidConnectRetryCountValue)); + + internal static Exception InvalidConnectRetryIntervalValue() => Argument(StringsHelper.GetString(Strings.SQLCR_InvalidConnectRetryIntervalValue)); +#endregion + +#region DbDataReader + internal static Exception DataReaderClosed([CallerMemberName] string method = "") + => InvalidOperation(StringsHelper.GetString(Strings.ADP_DataReaderClosed, method)); + + internal static ArgumentOutOfRangeException InvalidSourceBufferIndex(int maxLen, long srcOffset, string parameterName) + => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidSourceBufferIndex, + maxLen.ToString(CultureInfo.InvariantCulture), + srcOffset.ToString(CultureInfo.InvariantCulture)), parameterName); + + internal static ArgumentOutOfRangeException InvalidDestinationBufferIndex(int maxLen, int dstOffset, string parameterName) + => ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_InvalidDestinationBufferIndex, + maxLen.ToString(CultureInfo.InvariantCulture), + dstOffset.ToString(CultureInfo.InvariantCulture)), parameterName); + + internal static IndexOutOfRangeException InvalidBufferSizeOrIndex(int numBytes, int bufferIndex) + => IndexOutOfRange(StringsHelper.GetString(Strings.SQL_InvalidBufferSizeOrIndex, + numBytes.ToString(CultureInfo.InvariantCulture), + bufferIndex.ToString(CultureInfo.InvariantCulture))); + + internal static Exception InvalidDataLength(long length) + => IndexOutOfRange(StringsHelper.GetString(Strings.SQL_InvalidDataLength, length.ToString(CultureInfo.InvariantCulture))); + + internal static bool CompareInsensitiveInvariant(string strvalue, string strconst) + => 0 == CultureInfo.InvariantCulture.CompareInfo.Compare(strvalue, strconst, CompareOptions.IgnoreCase); + + internal static int DstCompare(string strA, string strB) // this is null safe + => CultureInfo.CurrentCulture.CompareInfo.Compare(strA, strB, ADP.DefaultCompareOptions); + + internal static void SetCurrentTransaction(Transaction transaction) => Transaction.Current = transaction; + + internal static Exception NonSeqByteAccess(long badIndex, long currIndex, string method) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_NonSeqByteAccess, + badIndex.ToString(CultureInfo.InvariantCulture), + currIndex.ToString(CultureInfo.InvariantCulture), + method)); + + internal static Exception NegativeParameter(string parameterName) => InvalidOperation(StringsHelper.GetString(Strings.ADP_NegativeParameter, parameterName)); + + internal static Exception InvalidXmlMissingColumn(string collectionName, string columnName) + => Argument(StringsHelper.GetString(Strings.MDF_InvalidXmlMissingColumn, collectionName, columnName)); + + internal static InvalidOperationException AsyncOperationPending() => InvalidOperation(StringsHelper.GetString(Strings.ADP_PendingAsyncOperation)); +#endregion + +#region IDbCommand + // IDbCommand.CommandType + static internal ArgumentOutOfRangeException InvalidCommandType(CommandType value) + { +#if DEBUG + switch (value) + { + case CommandType.Text: + case CommandType.StoredProcedure: + case CommandType.TableDirect: + Debug.Assert(false, "valid CommandType " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(CommandType), (int)value); + } + + internal static Exception TooManyRestrictions(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_TooManyRestrictions, collectionName)); + + internal static Exception CommandTextRequired(string method) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_CommandTextRequired, method)); + + internal static Exception UninitializedParameterSize(int index, Type dataType) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_UninitializedParameterSize, index.ToString(CultureInfo.InvariantCulture), dataType.Name)); + + internal static Exception PrepareParameterType(DbCommand cmd) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterType, cmd.GetType().Name)); + + internal static Exception PrepareParameterSize(DbCommand cmd) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterSize, cmd.GetType().Name)); + + internal static Exception PrepareParameterScale(DbCommand cmd, string type) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_PrepareParameterScale, cmd.GetType().Name, type)); + + internal static Exception MismatchedAsyncResult(string expectedMethod, string gotMethod) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_MismatchedAsyncResult, expectedMethod, gotMethod)); + + // IDataParameter.SourceVersion + internal static ArgumentOutOfRangeException InvalidDataRowVersion(DataRowVersion value) + { +#if DEBUG + switch (value) + { + case DataRowVersion.Default: + case DataRowVersion.Current: + case DataRowVersion.Original: + case DataRowVersion.Proposed: + Debug.Fail($"Invalid DataRowVersion {value}"); + break; + } +#endif + return InvalidEnumerationValue(typeof(DataRowVersion), (int)value); + } + + internal static ArgumentOutOfRangeException NotSupportedCommandBehavior(CommandBehavior value, string method) + => NotSupportedEnumerationValue(typeof(CommandBehavior), value.ToString(), method); + + internal static ArgumentException BadParameterName(string parameterName) + { + ArgumentException e = new(StringsHelper.GetString(Strings.ADP_BadParameterName, parameterName)); + TraceExceptionAsReturnValue(e); + return e; + } + + internal static Exception DeriveParametersNotSupported(IDbCommand value) + => DataAdapter(StringsHelper.GetString(Strings.ADP_DeriveParametersNotSupported, value.GetType().Name, value.CommandType.ToString())); + + internal static Exception NoStoredProcedureExists(string sproc) => InvalidOperation(StringsHelper.GetString(Strings.ADP_NoStoredProcedureExists, sproc)); +#endregion + +#region DbMetaDataFactory + internal static Exception DataTableDoesNotExist(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_DataTableDoesNotExist, collectionName)); + + // IDbCommand.UpdateRowSource + internal static ArgumentOutOfRangeException InvalidUpdateRowSource(UpdateRowSource value) + { +#if DEBUG + switch (value) + { + case UpdateRowSource.None: + case UpdateRowSource.OutputParameters: + case UpdateRowSource.FirstReturnedRecord: + case UpdateRowSource.Both: + Debug.Fail("valid UpdateRowSource " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(UpdateRowSource), (int)value); + } + + internal static Exception QueryFailed(string collectionName, Exception e) + => InvalidOperation(StringsHelper.GetString(Strings.MDF_QueryFailed, collectionName), e); + + internal static Exception NoColumns() => Argument(StringsHelper.GetString(Strings.MDF_NoColumns)); + + internal static InvalidOperationException ConnectionRequired(string method) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_ConnectionRequired, method)); + + internal static InvalidOperationException OpenConnectionRequired(string method, ConnectionState state) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenConnectionRequired, method, ADP.ConnectionStateMsg(state))); + + internal static Exception OpenReaderExists(bool marsOn) => OpenReaderExists(null, marsOn); + + internal static Exception OpenReaderExists(Exception e, bool marsOn) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_OpenReaderExists, marsOn ? ADP.Command : ADP.Connection), e); + + internal static Exception InvalidXml() => Argument(StringsHelper.GetString(Strings.MDF_InvalidXml)); + + internal static Exception InvalidXmlInvalidValue(string collectionName, string columnName) + => Argument(StringsHelper.GetString(Strings.MDF_InvalidXmlInvalidValue, collectionName, columnName)); + + internal static Exception CollectionNameIsNotUnique(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_CollectionNameISNotUnique, collectionName)); + + internal static Exception UnableToBuildCollection(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_UnableToBuildCollection, collectionName)); + + internal static Exception UndefinedCollection(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_UndefinedCollection, collectionName)); + + internal static Exception UnsupportedVersion(string collectionName) => Argument(StringsHelper.GetString(Strings.MDF_UnsupportedVersion, collectionName)); + + internal static Exception AmbiguousCollectionName(string collectionName) + => Argument(StringsHelper.GetString(Strings.MDF_AmbiguousCollectionName, collectionName)); + + internal static Exception MissingDataSourceInformationColumn() => Argument(StringsHelper.GetString(Strings.MDF_MissingDataSourceInformationColumn)); + + internal static Exception IncorrectNumberOfDataSourceInformationRows() + => Argument(StringsHelper.GetString(Strings.MDF_IncorrectNumberOfDataSourceInformationRows)); + + internal static Exception MissingRestrictionColumn() => Argument(StringsHelper.GetString(Strings.MDF_MissingRestrictionColumn)); + + internal static Exception MissingRestrictionRow() => Argument(StringsHelper.GetString(Strings.MDF_MissingRestrictionRow)); + + internal static Exception UndefinedPopulationMechanism(string populationMechanism) +#if NETFRAMEWORK + => Argument(StringsHelper.GetString(Strings.MDF_UndefinedPopulationMechanism, populationMechanism)); +#else + => throw new NotImplementedException(); +#endif +#endregion + +#region DbConnectionPool and related + internal static Exception PooledOpenTimeout() + => ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_PooledOpenTimeout)); + + internal static Exception NonPooledOpenTimeout() + => ADP.TimeoutException(StringsHelper.GetString(Strings.ADP_NonPooledOpenTimeout)); +#endregion + +#region DbProviderException + internal static InvalidOperationException TransactionConnectionMismatch() + => Provider(StringsHelper.GetString(Strings.ADP_TransactionConnectionMismatch)); + + internal static InvalidOperationException TransactionRequired(string method) + => Provider(StringsHelper.GetString(Strings.ADP_TransactionRequired, method)); + + internal static InvalidOperationException TransactionCompletedButNotDisposed() => Provider(StringsHelper.GetString(Strings.ADP_TransactionCompletedButNotDisposed)); + +#endregion + +#region SqlMetaData, SqlTypes + internal static Exception InvalidMetaDataValue() => ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMetaDataValue)); + + internal static InvalidOperationException NonSequentialColumnAccess(int badCol, int currCol) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_NonSequentialColumnAccess, + badCol.ToString(CultureInfo.InvariantCulture), + currCol.ToString(CultureInfo.InvariantCulture))); +#endregion + +#region IDataParameter + internal static ArgumentException InvalidDataType(TypeCode typecode) => Argument(StringsHelper.GetString(Strings.ADP_InvalidDataType, typecode.ToString())); + + internal static ArgumentException UnknownDataType(Type dataType) => Argument(StringsHelper.GetString(Strings.ADP_UnknownDataType, dataType.FullName)); + + internal static ArgumentException DbTypeNotSupported(DbType type, Type enumtype) + => Argument(StringsHelper.GetString(Strings.ADP_DbTypeNotSupported, type.ToString(), enumtype.Name)); + + internal static ArgumentException UnknownDataTypeCode(Type dataType, TypeCode typeCode) + => Argument(StringsHelper.GetString(Strings.ADP_UnknownDataTypeCode, ((int)typeCode).ToString(CultureInfo.InvariantCulture), dataType.FullName)); + + internal static ArgumentException InvalidOffsetValue(int value) + => Argument(StringsHelper.GetString(Strings.ADP_InvalidOffsetValue, value.ToString(CultureInfo.InvariantCulture))); + + internal static ArgumentException InvalidSizeValue(int value) + => Argument(StringsHelper.GetString(Strings.ADP_InvalidSizeValue, value.ToString(CultureInfo.InvariantCulture))); + + internal static ArgumentException ParameterValueOutOfRange(decimal value) + => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value.ToString((IFormatProvider)null))); + + internal static ArgumentException ParameterValueOutOfRange(SqlDecimal value) => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value.ToString())); + + internal static ArgumentException ParameterValueOutOfRange(string value) => ADP.Argument(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, value)); + + internal static ArgumentException VersionDoesNotSupportDataType(string typeName) => Argument(StringsHelper.GetString(Strings.ADP_VersionDoesNotSupportDataType, typeName)); + + internal static Exception ParameterConversionFailed(object value, Type destType, Exception inner) + { + Debug.Assert(null != value, "null value on conversion failure"); + Debug.Assert(null != inner, "null inner on conversion failure"); + + Exception e; + string message = StringsHelper.GetString(Strings.ADP_ParameterConversionFailed, value.GetType().Name, destType.Name); + if (inner is ArgumentException) + { + e = new ArgumentException(message, inner); + } + else if (inner is FormatException) + { + e = new FormatException(message, inner); + } + else if (inner is InvalidCastException) + { + e = new InvalidCastException(message, inner); + } + else if (inner is OverflowException) + { + e = new OverflowException(message, inner); + } + else + { + e = inner; + } + TraceExceptionAsReturnValue(e); + return e; + } +#endregion + +#region IDataParameterCollection + internal static Exception ParametersMappingIndex(int index, DbParameterCollection collection) => CollectionIndexInt32(index, collection.GetType(), collection.Count); + + internal static Exception ParametersSourceIndex(string parameterName, DbParameterCollection collection, Type parameterType) + => CollectionIndexString(parameterType, ADP.ParameterName, parameterName, collection.GetType()); + + internal static Exception ParameterNull(string parameter, DbParameterCollection collection, Type parameterType) + => CollectionNullValue(parameter, collection.GetType(), parameterType); + + internal static Exception InvalidParameterType(DbParameterCollection collection, Type parameterType, object invalidValue) + => CollectionInvalidType(collection.GetType(), parameterType, invalidValue); +#endregion + +#region IDbTransaction + internal static Exception ParallelTransactionsNotSupported(DbConnection obj) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_ParallelTransactionsNotSupported, obj.GetType().Name)); + + internal static Exception TransactionZombied(DbTransaction obj) => InvalidOperation(StringsHelper.GetString(Strings.ADP_TransactionZombied, obj.GetType().Name)); +#endregion + +#region DbProviderConfigurationHandler + internal static InvalidOperationException InvalidMixedUsageOfSecureAndClearCredential() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureAndClearCredential)); + + internal static ArgumentException InvalidMixedArgumentOfSecureAndClearCredential() + => Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureAndClearCredential)); + + internal static InvalidOperationException InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity)); + + internal static ArgumentException InvalidMixedArgumentOfSecureCredentialAndIntegratedSecurity() + => Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndIntegratedSecurity)); + + internal static InvalidOperationException InvalidMixedUsageOfAccessTokenAndIntegratedSecurity() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity)); + + static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndUserIDPassword() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword)); + + static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndAuthentication() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndAuthentication)); + + static internal Exception InvalidMixedUsageOfCredentialAndAccessToken() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfCredentialAndAccessToken)); +#endregion + + internal static bool IsEmpty(string str) => string.IsNullOrEmpty(str); + internal static readonly IntPtr s_ptrZero = IntPtr.Zero; +#if NETFRAMEWORK +#region netfx project only + internal static Task CreatedTaskWithException(Exception ex) + { + TaskCompletionSource completion = new(); + completion.SetException(ex); + return completion.Task; + } + + internal static Task CreatedTaskWithCancellation() + { + TaskCompletionSource completion = new(); + completion.SetCanceled(); + return completion.Task; + } + + internal static void TraceExceptionForCapture(Exception e) + { + Debug.Assert(ADP.IsCatchableExceptionType(e), "Invalid exception type, should have been re-thrown!"); + TraceException(" '{0}'", e); + } + + // + // Helper Functions + // + internal static void CheckArgumentLength(string value, string parameterName) + { + CheckArgumentNull(value, parameterName); + if (0 == value.Length) + { + throw Argument(StringsHelper.GetString(Strings.ADP_EmptyString, parameterName)); // MDAC 94859 + } + } + + // IDbConnection.BeginTransaction, OleDbTransaction.Begin + internal static ArgumentOutOfRangeException InvalidIsolationLevel(IsolationLevel value) + { +#if DEBUG + switch (value) + { + case IsolationLevel.Unspecified: + case IsolationLevel.Chaos: + case IsolationLevel.ReadUncommitted: + case IsolationLevel.ReadCommitted: + case IsolationLevel.RepeatableRead: + case IsolationLevel.Serializable: + case IsolationLevel.Snapshot: + Debug.Assert(false, "valid IsolationLevel " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(IsolationLevel), (int)value); + } + + // DBDataPermissionAttribute.KeyRestrictionBehavior + internal static ArgumentOutOfRangeException InvalidKeyRestrictionBehavior(KeyRestrictionBehavior value) + { +#if DEBUG + switch (value) + { + case KeyRestrictionBehavior.PreventUsage: + case KeyRestrictionBehavior.AllowOnly: + Debug.Assert(false, "valid KeyRestrictionBehavior " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(KeyRestrictionBehavior), (int)value); + } + + // IDataParameter.Direction + internal static ArgumentOutOfRangeException InvalidParameterDirection(ParameterDirection value) + { +#if DEBUG + switch (value) + { + case ParameterDirection.Input: + case ParameterDirection.Output: + case ParameterDirection.InputOutput: + case ParameterDirection.ReturnValue: + Debug.Assert(false, "valid ParameterDirection " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(ParameterDirection), (int)value); + } + + // + // DbConnectionOptions, DataAccess + // + internal static ArgumentException InvalidKeyname(string parameterName) + { + return Argument(StringsHelper.GetString(Strings.ADP_InvalidKey), parameterName); + } + internal static ArgumentException InvalidValue(string parameterName) + { + return Argument(StringsHelper.GetString(Strings.ADP_InvalidValue), parameterName); + } + internal static ArgumentException InvalidMixedArgumentOfSecureCredentialAndContextConnection() + { + return ADP.Argument(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfSecureCredentialAndContextConnection)); + } + internal static InvalidOperationException InvalidMixedUsageOfAccessTokenAndContextConnection() + { + return ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndContextConnection)); + } + internal static Exception InvalidMixedUsageOfAccessTokenAndCredential() + { + return ADP.InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndCredential)); + } + + // + // DBDataPermission, DataAccess, Odbc + // + internal static Exception InvalidXMLBadVersion() + { + return Argument(StringsHelper.GetString(Strings.ADP_InvalidXMLBadVersion)); + } + internal static Exception NotAPermissionElement() + { + return Argument(StringsHelper.GetString(Strings.ADP_NotAPermissionElement)); + } + internal static Exception PermissionTypeMismatch() + { + return Argument(StringsHelper.GetString(Strings.ADP_PermissionTypeMismatch)); + } + + // + // DbDataReader + // + internal static Exception NumericToDecimalOverflow() + { + return InvalidCast(StringsHelper.GetString(Strings.ADP_NumericToDecimalOverflow)); + } + + // + // : IDbCommand + // + internal static Exception InvalidCommandTimeout(int value, string name) + { + return Argument(StringsHelper.GetString(Strings.ADP_InvalidCommandTimeout, value.ToString(CultureInfo.InvariantCulture)), name); + } + + // + // : DbDataAdapter + // + internal static InvalidOperationException ComputerNameEx(int lastError) + { + return InvalidOperation(StringsHelper.GetString(Strings.ADP_ComputerNameEx, lastError)); + } + + // global constant strings + internal const float FailoverTimeoutStepForTnir = 0.125F; // Fraction of timeout to use in case of Transparent Network IP resolution. + internal const int MinimumTimeoutForTnirMs = 500; // The first login attempt in Transparent network IP Resolution + + internal static readonly int s_ptrSize = IntPtr.Size; + internal static readonly IntPtr s_invalidPtr = new(-1); // use for INVALID_HANDLE + + internal static readonly bool s_isWindowsNT = (PlatformID.Win32NT == Environment.OSVersion.Platform); + internal static readonly bool s_isPlatformNT5 = (ADP.s_isWindowsNT && (Environment.OSVersion.Version.Major >= 5)); + + [FileIOPermission(SecurityAction.Assert, AllFiles = FileIOPermissionAccess.PathDiscovery)] + [ResourceExposure(ResourceScope.Machine)] + [ResourceConsumption(ResourceScope.Machine)] + internal static string GetFullPath(string filename) + { // MDAC 77686 + return Path.GetFullPath(filename); + } + + // TODO: cache machine name and listen to longhorn event to reset it + internal static string GetComputerNameDnsFullyQualified() + { + const int ComputerNameDnsFullyQualified = 3; // winbase.h, enum COMPUTER_NAME_FORMAT + const int ERROR_MORE_DATA = 234; // winerror.h + + string value; + if (s_isPlatformNT5) + { + int length = 0; // length parameter must be zero if buffer is null + // query for the required length + // VSTFDEVDIV 479551 - ensure that GetComputerNameEx does not fail with unexpected values and that the length is positive + int getComputerNameExError = 0; + if (0 == SafeNativeMethods.GetComputerNameEx(ComputerNameDnsFullyQualified, null, ref length)) + { + getComputerNameExError = Marshal.GetLastWin32Error(); + } + if ((getComputerNameExError != 0 && getComputerNameExError != ERROR_MORE_DATA) || length <= 0) + { + throw ADP.ComputerNameEx(getComputerNameExError); + } + + StringBuilder buffer = new(length); + length = buffer.Capacity; + if (0 == SafeNativeMethods.GetComputerNameEx(ComputerNameDnsFullyQualified, buffer, ref length)) + { + throw ADP.ComputerNameEx(Marshal.GetLastWin32Error()); + } + + // Note: In Longhorn you'll be able to rename a machine without + // rebooting. Therefore, don't cache this machine name. + value = buffer.ToString(); + } + else + { + value = ADP.MachineName(); + } + return value; + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static IntPtr IntPtrOffset(IntPtr pbase, int offset) + { + if (4 == ADP.s_ptrSize) + { + return (IntPtr)checked(pbase.ToInt32() + offset); + } + Debug.Assert(8 == ADP.s_ptrSize, "8 != IntPtr.Size"); // MDAC 73747 + return (IntPtr)checked(pbase.ToInt64() + offset); + } + +#endregion +#else +#region netcore project only + internal static Timer UnsafeCreateTimer(TimerCallback callback, object state, int dueTime, int period) + { + // Don't capture the current ExecutionContext and its AsyncLocals onto + // a global timer causing them to live forever + bool restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + return new Timer(callback, state, dueTime, period); + } + finally + { + // Restore the current ExecutionContext + if (restoreFlow) + ExecutionContext.RestoreFlow(); + } + } + + // + // COM+ exceptions + // + internal static PlatformNotSupportedException DbTypeNotSupported(string dbType) => new(StringsHelper.GetString(Strings.SQL_DbTypeNotSupportedOnThisPlatform, dbType)); + + // IDbConnection.BeginTransaction, OleDbTransaction.Begin + internal static ArgumentOutOfRangeException InvalidIsolationLevel(IsolationLevel value) + { +#if DEBUG + switch (value) + { + case IsolationLevel.Unspecified: + case IsolationLevel.Chaos: + case IsolationLevel.ReadUncommitted: + case IsolationLevel.ReadCommitted: + case IsolationLevel.RepeatableRead: + case IsolationLevel.Serializable: + case IsolationLevel.Snapshot: + Debug.Fail("valid IsolationLevel " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(IsolationLevel), (int)value); + } + + // ConnectionUtil + internal static Exception IncorrectPhysicalConnectionType() => new ArgumentException(StringsHelper.GetString(StringsHelper.SNI_IncorrectPhysicalConnectionType)); + + // IDataParameter.Direction + internal static ArgumentOutOfRangeException InvalidParameterDirection(ParameterDirection value) + { +#if DEBUG + switch (value) + { + case ParameterDirection.Input: + case ParameterDirection.Output: + case ParameterDirection.InputOutput: + case ParameterDirection.ReturnValue: + Debug.Fail("valid ParameterDirection " + value.ToString()); + break; + } +#endif + return InvalidEnumerationValue(typeof(ParameterDirection), (int)value); + } + + // + // : IDbCommand + // + internal static Exception InvalidCommandTimeout(int value, [CallerMemberName] string property = "") + => Argument(StringsHelper.GetString(Strings.ADP_InvalidCommandTimeout, value.ToString(CultureInfo.InvariantCulture)), property); +#endregion +#endif + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs new file mode 100644 index 0000000000..65e425590e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionOptions.Common.cs @@ -0,0 +1,770 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.Data.SqlClient; + +namespace Microsoft.Data.Common +{ + partial class DbConnectionOptions + { + // instances of this class are intended to be immutable, i.e readonly + // used by pooling classes so it is much easier to verify correctness + // when not worried about the class being modified during execution + + // connection string common keywords + private static class KEY + { + internal const string Integrated_Security = DbConnectionStringKeywords.IntegratedSecurity; + internal const string Password = DbConnectionStringKeywords.Password; + internal const string Persist_Security_Info = DbConnectionStringKeywords.PersistSecurityInfo; + internal const string User_ID = DbConnectionStringKeywords.UserID; + internal const string Encrypt = DbConnectionStringKeywords.Encrypt; + } + + // known connection string common synonyms + private static class SYNONYM + { + internal const string Pwd = DbConnectionStringSynonyms.Pwd; + internal const string UID = DbConnectionStringSynonyms.UID; + } + +#if DEBUG + /*private const string ConnectionStringPatternV1 = + "[\\s;]*" + +"(?([^=\\s]|\\s+[^=\\s]|\\s+==|==)+)" + + "\\s*=(?!=)\\s*" + +"(?(" + + "(" + "\"" + "([^\"]|\"\")*" + "\"" + ")" + + "|" + + "(" + "'" + "([^']|'')*" + "'" + ")" + + "|" + + "(" + "(?![\"'])" + "([^\\s;]|\\s+[^\\s;])*" + "(?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}]|\\s+==|==)+)" // allow any visible character for keyname except '=' which must quoted as '==' + + "\\s*=(?!=)\\s*" // the equal sign divides the key and value parts + + "(?" + + "(\"([^\"\u0000]|\"\")*\")" // double quoted string, " must be quoted as "" + + "|" + + "('([^'\u0000]|'')*')" // single quoted string, ' must be quoted as '' + + "|" + + "((?![\"'\\s])" // unquoted value must not start with " or ' or space, would also like = but too late to change + + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" // control characters must be quoted + + "(?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}])+)" // allow any visible character for keyname except '=' + + "\\s*=\\s*" // the equal sign divides the key and value parts + + "(?" + + "(\\{([^\\}\u0000]|\\}\\})*\\})" // quoted string, starts with { and ends with } + + "|" + + "((?![\\{\\s])" // unquoted value must not start with { or space, would also like = but too late to change + + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" // control characters must be quoted + + + ")" // although the spec does not allow {} + // embedded within a value, the retail code does. + + ")(\\s*)(;|[\u0000\\s]*$)" // whitespace after value up to semicolon or end-of-line + + ")*" // repeat the key-value pair + + "[\\s;]*[\u0000\\s]*" // trailing whitespace/semicolons (DataSourceLocator), embedded nulls are allowed only in the end + ; + + private static readonly Regex s_connectionStringRegex = new Regex(ConnectionStringPattern, RegexOptions.ExplicitCapture | RegexOptions.Compiled); + private static readonly Regex s_connectionStringRegexOdbc = new Regex(ConnectionStringPatternOdbc, RegexOptions.ExplicitCapture | RegexOptions.Compiled); +#endif + private const string ConnectionStringValidKeyPattern = "^(?![;\\s])[^\\p{Cc}]+(? _parsetable; + + internal Dictionary Parsetable => _parsetable; + public bool IsEmpty => _keyChain == null; + + public DbConnectionOptions(string connectionString, Dictionary synonyms) + { + _parsetable = new Dictionary(StringComparer.InvariantCultureIgnoreCase); + _usersConnectionString = ((null != connectionString) ? connectionString : ""); + + // first pass on parsing, initial syntax check + if (0 < _usersConnectionString.Length) + { + _keyChain = ParseInternal(_parsetable, _usersConnectionString, true, synonyms, false); + _hasPasswordKeyword = (_parsetable.ContainsKey(KEY.Password) || _parsetable.ContainsKey(SYNONYM.Pwd)); + _hasUserIdKeyword = (_parsetable.ContainsKey(KEY.User_ID) || _parsetable.ContainsKey(SYNONYM.UID)); + } + } + + protected DbConnectionOptions(DbConnectionOptions connectionOptions) + { // Clone used by SqlConnectionString + _usersConnectionString = connectionOptions._usersConnectionString; + _parsetable = connectionOptions._parsetable; + _keyChain = connectionOptions._keyChain; + _hasPasswordKeyword = connectionOptions._hasPasswordKeyword; + _hasUserIdKeyword = connectionOptions._hasUserIdKeyword; + } + + internal bool TryGetParsetableValue(string key, out string value) => _parsetable.TryGetValue(key, out value); + + // same as Boolean, but with SSPI thrown in as valid yes + public bool ConvertValueToIntegratedSecurity() + { + return _parsetable.TryGetValue(KEY.Integrated_Security, out string value) && value != null ? + ConvertValueToIntegratedSecurityInternal(value) : + false; + } + + internal bool ConvertValueToIntegratedSecurityInternal(string stringValue) + { + if (CompareInsensitiveInvariant(stringValue, "sspi") || CompareInsensitiveInvariant(stringValue, "true") || CompareInsensitiveInvariant(stringValue, "yes")) + return true; + else if (CompareInsensitiveInvariant(stringValue, "false") || CompareInsensitiveInvariant(stringValue, "no")) + return false; + else + { + string tmp = stringValue.Trim(); // Remove leading & trailing whitespace. + if (CompareInsensitiveInvariant(tmp, "sspi") || CompareInsensitiveInvariant(tmp, "true") || CompareInsensitiveInvariant(tmp, "yes")) + return true; + else if (CompareInsensitiveInvariant(tmp, "false") || CompareInsensitiveInvariant(tmp, "no")) + return false; + else + { + throw ADP.InvalidConnectionOptionValue(KEY.Integrated_Security); + } + } + } + + public int ConvertValueToInt32(string keyName, int defaultValue) + { + return _parsetable.TryGetValue(keyName, out string value) && value != null ? + ConvertToInt32Internal(keyName, value) : + defaultValue; + } + + internal static int ConvertToInt32Internal(string keyname, string stringValue) + { + try + { + return int.Parse(stringValue, System.Globalization.NumberStyles.Integer, CultureInfo.InvariantCulture); + } + catch (FormatException e) + { + throw ADP.InvalidConnectionOptionValue(keyname, e); + } + catch (OverflowException e) + { + throw ADP.InvalidConnectionOptionValue(keyname, e); + } + } + + public string ConvertValueToString(string keyName, string defaultValue) + => _parsetable.TryGetValue(keyName, out string value) && value != null ? value : defaultValue; + + public bool ContainsKey(string keyword) => _parsetable.ContainsKey(keyword); + + protected internal virtual string Expand() => _usersConnectionString; + + public string UsersConnectionString(bool hidePassword) => UsersConnectionString(hidePassword, false); + + internal string UsersConnectionStringForTrace() => UsersConnectionString(true, true); + + private string UsersConnectionString(bool hidePassword, bool forceHidePassword) + { + string connectionString = _usersConnectionString; + if (_hasPasswordKeyword && (forceHidePassword || (hidePassword && !HasPersistablePassword))) + { + ReplacePasswordPwd(out connectionString, false); + } + return connectionString ?? string.Empty; + } + + internal bool HasPersistablePassword => _hasPasswordKeyword ? + ConvertValueToBoolean(KEY.Persist_Security_Info, DbConnectionStringDefaults.PersistSecurityInfo) : + true; // no password means persistable password so we don't have to munge + + public bool ConvertValueToBoolean(string keyName, bool defaultValue) + { + string value; + return _parsetable.TryGetValue(keyName, out value) ? + ConvertValueToBooleanInternal(keyName, value) : + defaultValue; + } + + internal static bool ConvertValueToBooleanInternal(string keyName, string stringValue) + { + if (CompareInsensitiveInvariant(stringValue, "true") || CompareInsensitiveInvariant(stringValue, "yes")) + return true; + else if (CompareInsensitiveInvariant(stringValue, "false") || CompareInsensitiveInvariant(stringValue, "no")) + return false; + else + { + string tmp = stringValue.Trim(); // Remove leading & trailing whitespace. + if (CompareInsensitiveInvariant(tmp, "true") || CompareInsensitiveInvariant(tmp, "yes")) + return true; + else if (CompareInsensitiveInvariant(tmp, "false") || CompareInsensitiveInvariant(tmp, "no")) + return false; + else + { + throw ADP.InvalidConnectionOptionValue(keyName); + } + } + } + + private static bool CompareInsensitiveInvariant(string strvalue, string strconst) + => (0 == StringComparer.OrdinalIgnoreCase.Compare(strvalue, strconst)); + + [System.Diagnostics.Conditional("DEBUG")] + private static void DebugTraceKeyValuePair(string keyname, string keyvalue, Dictionary synonyms) + { + if (SqlClientEventSource.Log.IsAdvancedTraceOn()) + { + Debug.Assert(string.Equals(keyname, keyname?.ToLower(), StringComparison.InvariantCulture), "missing ToLower"); + string realkeyname = ((null != synonyms) ? synonyms[keyname] : keyname); + + if (!string.Equals(KEY.Password, realkeyname, StringComparison.InvariantCultureIgnoreCase) && + !string.Equals(SYNONYM.Pwd, realkeyname, StringComparison.InvariantCultureIgnoreCase)) + { + // don't trace passwords ever! + if (null != keyvalue) + { + SqlClientEventSource.Log.AdvancedTraceEvent(" KeyName='{0}', KeyValue='{1}'", keyname, keyvalue); + } + else + { + SqlClientEventSource.Log.AdvancedTraceEvent(" KeyName='{0}'", keyname); + } + } + } + } + + private static string GetKeyName(StringBuilder buffer) + { + int count = buffer.Length; + while ((0 < count) && char.IsWhiteSpace(buffer[count - 1])) + { + count--; // trailing whitespace + } + return buffer.ToString(0, count).ToLower(CultureInfo.InvariantCulture); + } + + private static string GetKeyValue(StringBuilder buffer, bool trimWhitespace) + { + int count = buffer.Length; + int index = 0; + if (trimWhitespace) + { + while ((index < count) && char.IsWhiteSpace(buffer[index])) + { + index++; // leading whitespace + } + while ((0 < count) && char.IsWhiteSpace(buffer[count - 1])) + { + count--; // trailing whitespace + } + } + return buffer.ToString(index, count - index); + } + + // transition states used for parsing + private enum ParserState + { + NothingYet = 1, //start point + Key, + KeyEqual, + KeyEnd, + UnquotedValue, + DoubleQuoteValue, + DoubleQuoteValueQuote, + SingleQuoteValue, + SingleQuoteValueQuote, + BraceQuoteValue, + BraceQuoteValueQuote, + QuotedValueEnd, + NullTermination, + }; + + internal static int GetKeyValuePair(string connectionString, int currentPosition, StringBuilder buffer, bool useOdbcRules, out string keyname, out string keyvalue) + { + int startposition = currentPosition; + + buffer.Length = 0; + keyname = null; + keyvalue = null; + + char currentChar = '\0'; + + ParserState parserState = ParserState.NothingYet; + int length = connectionString.Length; + for (; currentPosition < length; ++currentPosition) + { + currentChar = connectionString[currentPosition]; + + switch (parserState) + { + case ParserState.NothingYet: // [\\s;]* + if ((';' == currentChar) || char.IsWhiteSpace(currentChar)) + { + continue; + } + if ('\0' == currentChar) + { parserState = ParserState.NullTermination; continue; } + if (char.IsControl(currentChar)) + { throw ADP.ConnectionStringSyntax(startposition); } + startposition = currentPosition; + if ('=' != currentChar) + { + parserState = ParserState.Key; + break; + } + else + { + parserState = ParserState.KeyEqual; + continue; + } + + case ParserState.Key: // (?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}]|\\s+==|==)+) + if ('=' == currentChar) + { parserState = ParserState.KeyEqual; continue; } + if (char.IsWhiteSpace(currentChar)) + { break; } + if (char.IsControl(currentChar)) + { throw ADP.ConnectionStringSyntax(startposition); } + break; + + case ParserState.KeyEqual: // \\s*=(?!=)\\s* + if (!useOdbcRules && '=' == currentChar) + { parserState = ParserState.Key; break; } + keyname = GetKeyName(buffer); + if (string.IsNullOrEmpty(keyname)) + { throw ADP.ConnectionStringSyntax(startposition); } + buffer.Length = 0; + parserState = ParserState.KeyEnd; + goto case ParserState.KeyEnd; + + case ParserState.KeyEnd: + if (char.IsWhiteSpace(currentChar)) + { continue; } + if (useOdbcRules) + { + if ('{' == currentChar) + { parserState = ParserState.BraceQuoteValue; break; } + } + else + { + if ('\'' == currentChar) + { parserState = ParserState.SingleQuoteValue; continue; } + if ('"' == currentChar) + { parserState = ParserState.DoubleQuoteValue; continue; } + } + if (';' == currentChar) + { goto ParserExit; } + if ('\0' == currentChar) + { goto ParserExit; } + if (char.IsControl(currentChar)) + { throw ADP.ConnectionStringSyntax(startposition); } + parserState = ParserState.UnquotedValue; + break; + + case ParserState.UnquotedValue: // "((?![\"'\\s])" + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" + "(? SplitConnectionString(string connectionString, Dictionary synonyms, bool firstKey) + { + var parsetable = new Dictionary(); + Regex parser = (firstKey ? s_connectionStringRegexOdbc : s_connectionStringRegex); + + const int KeyIndex = 1, ValueIndex = 2; + Debug.Assert(KeyIndex == parser.GroupNumberFromName("key"), "wrong key index"); + Debug.Assert(ValueIndex == parser.GroupNumberFromName("value"), "wrong value index"); + + if (null != connectionString) + { + Match match = parser.Match(connectionString); + if (!match.Success || (match.Length != connectionString.Length)) + { + throw ADP.ConnectionStringSyntax(match.Length); + } + int indexValue = 0; + CaptureCollection keyvalues = match.Groups[ValueIndex].Captures; + foreach (Capture keypair in match.Groups[KeyIndex].Captures) + { + string keyname = (firstKey ? keypair.Value : keypair.Value.Replace("==", "=")).ToLower(CultureInfo.InvariantCulture); + string keyvalue = keyvalues[indexValue++].Value; + if (0 < keyvalue.Length) + { + if (!firstKey) + { + switch (keyvalue[0]) + { + case '\"': + keyvalue = keyvalue.Substring(1, keyvalue.Length - 2).Replace("\"\"", "\""); + break; + case '\'': + keyvalue = keyvalue.Substring(1, keyvalue.Length - 2).Replace("\'\'", "\'"); + break; + default: + break; + } + } + } + else + { + keyvalue = null; + } + DebugTraceKeyValuePair(keyname, keyvalue, synonyms); + string synonym; + string realkeyname = null != synonyms ? + (synonyms.TryGetValue(keyname, out synonym) ? synonym : null) : keyname; + + if (!IsKeyNameValid(realkeyname)) + { + throw ADP.KeywordNotSupported(keyname); + } + if (!firstKey || !parsetable.ContainsKey(realkeyname)) + { + parsetable[realkeyname] = keyvalue; // last key-value pair wins (or first) + } + } + } + return parsetable; + } + + private static void ParseComparison(Dictionary parsetable, string connectionString, Dictionary synonyms, bool firstKey, Exception e) + { + try + { + var parsedvalues = SplitConnectionString(connectionString, synonyms, firstKey); + foreach (var entry in parsedvalues) + { + string keyname = entry.Key; + string value1 = entry.Value; + string value2; + bool parsetableContainsKey = parsetable.TryGetValue(keyname, out value2); + Debug.Assert(parsetableContainsKey, $"{nameof(ParseInternal)} code vs. regex mismatch keyname <{keyname}>"); + Debug.Assert(value1 == value2, $"{nameof(ParseInternal)} code vs. regex mismatch keyvalue <{value1}> <{value2}>"); + } + } + catch (ArgumentException f) + { + if (null != e) + { + string msg1 = e.Message; + string msg2 = f.Message; + + const string KeywordNotSupportedMessagePrefix = "Keyword not supported:"; + const string WrongFormatMessagePrefix = "Format of the initialization string"; + bool isEquivalent = (msg1 == msg2); + if (!isEquivalent) + { + // We also accept cases were Regex parser (debug only) reports "wrong format" and + // retail parsing code reports format exception in different location or "keyword not supported" + if (msg2.StartsWith(WrongFormatMessagePrefix, StringComparison.Ordinal)) + { + if (msg1.StartsWith(KeywordNotSupportedMessagePrefix, StringComparison.Ordinal) || msg1.StartsWith(WrongFormatMessagePrefix, StringComparison.Ordinal)) + { + isEquivalent = true; + } + } + } + Debug.Assert(isEquivalent, "ParseInternal code vs regex message mismatch: <" + msg1 + "> <" + msg2 + ">"); + } + else + { + Debug.Fail("ParseInternal code vs regex throw mismatch " + f.Message); + } + e = null; + } + if (null != e) + { + Debug.Fail("ParseInternal code threw exception vs regex mismatch"); + } + } +#endif + + private static NameValuePair ParseInternal(Dictionary parsetable, string connectionString, bool buildChain, Dictionary synonyms, bool firstKey) + { + Debug.Assert(null != connectionString, "null connectionstring"); + StringBuilder buffer = new StringBuilder(); + NameValuePair localKeychain = null, keychain = null; +#if DEBUG + try + { +#endif + int nextStartPosition = 0; + int endPosition = connectionString.Length; + while (nextStartPosition < endPosition) + { + int startPosition = nextStartPosition; + + string keyname, keyvalue; + nextStartPosition = GetKeyValuePair(connectionString, startPosition, buffer, firstKey, out keyname, out keyvalue); + if (string.IsNullOrEmpty(keyname)) + { + // if (nextStartPosition != endPosition) { throw; } + break; + } +#if DEBUG + DebugTraceKeyValuePair(keyname, keyvalue, synonyms); +#endif + Debug.Assert(IsKeyNameValid(keyname), "ParseFailure, invalid keyname"); + Debug.Assert(IsValueValidInternal(keyvalue), "parse failure, invalid keyvalue"); + + string realkeyname = (synonyms is not null) ? + (synonyms.TryGetValue(keyname, out string synonym) ? synonym : null) : + keyname; + + if (!IsKeyNameValid(realkeyname)) + { + throw ADP.KeywordNotSupported(keyname); + } + if (!firstKey || !parsetable.ContainsKey(realkeyname)) + { + parsetable[realkeyname] = keyvalue; // last key-value pair wins (or first) + } + + if (null != localKeychain) + { + localKeychain = localKeychain.Next = new NameValuePair(realkeyname, keyvalue, nextStartPosition - startPosition); + } + else if (buildChain) + { // first time only - don't contain modified chain from UDL file + keychain = localKeychain = new NameValuePair(realkeyname, keyvalue, nextStartPosition - startPosition); + } + } +#if DEBUG + } + catch (ArgumentException e) + { + ParseComparison(parsetable, connectionString, synonyms, firstKey, e); + throw; + } + ParseComparison(parsetable, connectionString, synonyms, firstKey, null); +#endif + return keychain; + } + + internal NameValuePair ReplacePasswordPwd(out string constr, bool fakePassword) + { + bool expanded = false; + int copyPosition = 0; + NameValuePair head = null, tail = null, next = null; + StringBuilder builder = new StringBuilder(_usersConnectionString.Length); + for (NameValuePair current = _keyChain; null != current; current = current.Next) + { + if (!string.Equals(KEY.Password, current.Name, StringComparison.InvariantCultureIgnoreCase) && + !string.Equals(SYNONYM.Pwd, current.Name, StringComparison.InvariantCultureIgnoreCase)) + { + builder.Append(_usersConnectionString, copyPosition, current.Length); + if (fakePassword) + { + next = new NameValuePair(current.Name, current.Value, current.Length); + } + } + else if (fakePassword) + { + // replace user password/pwd value with * + const string equalstar = "=*;"; + builder.Append(current.Name).Append(equalstar); + next = new NameValuePair(current.Name, "*", current.Name.Length + equalstar.Length); + expanded = true; + } + else + { + // drop the password/pwd completely in returning for user + expanded = true; + } + + if (fakePassword) + { + if (null != tail) + { + tail = tail.Next = next; + } + else + { + tail = head = next; + } + } + copyPosition += current.Length; + } + Debug.Assert(expanded, "password/pwd was not removed"); + constr = builder.ToString(); + return head; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs new file mode 100644 index 0000000000..7d2799289f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionPoolKey.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Data.Common +{ + // DbConnectionPoolKey: Base class implementation of a key to connection pool groups + // Only connection string is used as a key + internal class DbConnectionPoolKey : ICloneable + { + private string _connectionString; + + internal DbConnectionPoolKey(string connectionString) + { + _connectionString = connectionString; + } + + protected DbConnectionPoolKey(DbConnectionPoolKey key) + { + _connectionString = key.ConnectionString; + } + + public virtual object Clone() + { + return new DbConnectionPoolKey(this); + } + + internal virtual string ConnectionString + { + get + { + return _connectionString; + } + + set + { + _connectionString = value; + } + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + + return (obj is DbConnectionPoolKey key && _connectionString == key._connectionString); + } + + public override int GetHashCode() + { + return _connectionString == null ? 0 : _connectionString.GetHashCode(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs new file mode 100644 index 0000000000..ee6fafa0f0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/DbConnectionStringCommon.cs @@ -0,0 +1,1153 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using Microsoft.Data.SqlClient; + +namespace Microsoft.Data.Common +{ + internal static class DbConnectionStringBuilderUtil + { + internal static bool ConvertToBoolean(object value) + { + Debug.Assert(null != value, "ConvertToBoolean(null)"); + if (value is string svalue) + { + if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "true") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "yes")) + return true; + else if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "false") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "no")) + return false; + else + { + string tmp = svalue.Trim(); // Remove leading & trailing white space. + if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "true") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "yes")) + return true; + else if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "false") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "no")) + return false; + } + return bool.Parse(svalue); + } + try + { + return Convert.ToBoolean(value, CultureInfo.InvariantCulture); + } + catch (InvalidCastException e) + { + throw ADP.ConvertFailed(value.GetType(), typeof(bool), e); + } + } + + internal static bool ConvertToIntegratedSecurity(object value) + { + Debug.Assert(null != value, "ConvertToIntegratedSecurity(null)"); + if (value is string svalue) + { + if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "sspi") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "true") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "yes")) + return true; + else if (StringComparer.OrdinalIgnoreCase.Equals(svalue, "false") || StringComparer.OrdinalIgnoreCase.Equals(svalue, "no")) + return false; + else + { + string tmp = svalue.Trim(); // Remove leading & trailing white space. + if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "sspi") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "true") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "yes")) + return true; + else if (StringComparer.OrdinalIgnoreCase.Equals(tmp, "false") || StringComparer.OrdinalIgnoreCase.Equals(tmp, "no")) + return false; + } + return bool.Parse(svalue); + } + try + { + return Convert.ToBoolean(value, CultureInfo.InvariantCulture); + } + catch (InvalidCastException e) + { + throw ADP.ConvertFailed(value.GetType(), typeof(bool), e); + } + } + + internal static int ConvertToInt32(object value) + { + try + { + return Convert.ToInt32(value, CultureInfo.InvariantCulture); + } + catch (InvalidCastException e) + { + throw ADP.ConvertFailed(value.GetType(), typeof(int), e); + } + } + + internal static string ConvertToString(object value) + { + try + { + return Convert.ToString(value, CultureInfo.InvariantCulture); + } + catch (InvalidCastException e) + { + throw ADP.ConvertFailed(value.GetType(), typeof(string), e); + } + } + + #region <> + internal static bool TryConvertToPoolBlockingPeriod(string value, out PoolBlockingPeriod result) + { + Debug.Assert(Enum.GetNames(typeof(PoolBlockingPeriod)).Length == 3, "PoolBlockingPeriod enum has changed, update needed"); + Debug.Assert(null != value, "TryConvertToPoolBlockingPeriod(null,...)"); + + if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.Auto))) + { + result = PoolBlockingPeriod.Auto; + return true; + } + else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.AlwaysBlock))) + { + result = PoolBlockingPeriod.AlwaysBlock; + return true; + } + else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(PoolBlockingPeriod.NeverBlock))) + { + result = PoolBlockingPeriod.NeverBlock; + return true; + } + else + { + result = DbConnectionStringDefaults.PoolBlockingPeriod; + return false; + } + } + + internal static bool IsValidPoolBlockingPeriodValue(PoolBlockingPeriod value) + { + Debug.Assert(Enum.GetNames(typeof(PoolBlockingPeriod)).Length == 3, "PoolBlockingPeriod enum has changed, update needed"); + return value == PoolBlockingPeriod.Auto || value == PoolBlockingPeriod.AlwaysBlock || value == PoolBlockingPeriod.NeverBlock; + } + + internal static string PoolBlockingPeriodToString(PoolBlockingPeriod value) + { + Debug.Assert(IsValidPoolBlockingPeriodValue(value)); + + return value switch + { + PoolBlockingPeriod.AlwaysBlock => nameof(PoolBlockingPeriod.AlwaysBlock), + PoolBlockingPeriod.NeverBlock => nameof(PoolBlockingPeriod.NeverBlock), + _ => nameof(PoolBlockingPeriod.Auto), + }; + } + + /// + /// This method attempts to convert the given value to a PoolBlockingPeriod enum. The algorithm is: + /// * if the value is from type string, it will be matched against PoolBlockingPeriod enum names only, using ordinal, case-insensitive comparer + /// * if the value is from type PoolBlockingPeriod, it will be used as is + /// * if the value is from integral type (SByte, Int16, Int32, Int64, Byte, UInt16, UInt32, or UInt64), it will be converted to enum + /// * if the value is another enum or any other type, it will be blocked with an appropriate ArgumentException + /// + /// in any case above, if the converted value is out of valid range, the method raises ArgumentOutOfRangeException. + /// + /// PoolBlockingPeriod value in the valid range + internal static PoolBlockingPeriod ConvertToPoolBlockingPeriod(string keyword, object value) + { + Debug.Assert(null != value, "ConvertToPoolBlockingPeriod(null)"); + if (value is string sValue) + { + // We could use Enum.TryParse here, but it accepts value combinations like + // "ReadOnly, ReadWrite" which are unwelcome here + // Also, Enum.TryParse is 100x slower than plain StringComparer.OrdinalIgnoreCase.Equals method. + + if (TryConvertToPoolBlockingPeriod(sValue, out PoolBlockingPeriod result)) + { + return result; + } + + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToPoolBlockingPeriod(sValue, out result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + PoolBlockingPeriod eValue; + + if (value is PoolBlockingPeriod period) + { + // quick path for the most common case + eValue = period; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["PoolBlockingPeriod"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-PoolBlockingPeriod enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(PoolBlockingPeriod), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (PoolBlockingPeriod)Enum.ToObject(typeof(PoolBlockingPeriod), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(PoolBlockingPeriod), e); + } + } + + // ensure value is in valid range + if (IsValidPoolBlockingPeriodValue(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(ApplicationIntent), (int)eValue); + } + } + } + #endregion + + internal static bool TryConvertToApplicationIntent(string value, out ApplicationIntent result) + { + Debug.Assert(Enum.GetNames(typeof(ApplicationIntent)).Length == 2, "ApplicationIntent enum has changed, update needed"); + Debug.Assert(null != value, "TryConvertToApplicationIntent(null,...)"); + + if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(ApplicationIntent.ReadOnly))) + { + result = ApplicationIntent.ReadOnly; + return true; + } + else if (StringComparer.OrdinalIgnoreCase.Equals(value, nameof(ApplicationIntent.ReadWrite))) + { + result = ApplicationIntent.ReadWrite; + return true; + } + else + { + result = DbConnectionStringDefaults.ApplicationIntent; + return false; + } + } + + internal static bool IsValidApplicationIntentValue(ApplicationIntent value) + { + Debug.Assert(Enum.GetNames(typeof(ApplicationIntent)).Length == 2, "ApplicationIntent enum has changed, update needed"); + return value == ApplicationIntent.ReadOnly || value == ApplicationIntent.ReadWrite; + } + + internal static string ApplicationIntentToString(ApplicationIntent value) + { + Debug.Assert(IsValidApplicationIntentValue(value)); + if (value == ApplicationIntent.ReadOnly) + { + return nameof(ApplicationIntent.ReadOnly); + } + else + { + return nameof(ApplicationIntent.ReadWrite); + } + } + + /// + /// This method attempts to convert the given value tp ApplicationIntent enum. The algorithm is: + /// * if the value is from type string, it will be matched against ApplicationIntent enum names only, using ordinal, case-insensitive comparer + /// * if the value is from type ApplicationIntent, it will be used as is + /// * if the value is from integral type (SByte, Int16, Int32, Int64, Byte, UInt16, UInt32, or UInt64), it will be converted to enum + /// * if the value is another enum or any other type, it will be blocked with an appropriate ArgumentException + /// + /// in any case above, if the converted value is out of valid range, the method raises ArgumentOutOfRangeException. + /// + /// application intent value in the valid range + internal static ApplicationIntent ConvertToApplicationIntent(string keyword, object value) + { + Debug.Assert(null != value, "ConvertToApplicationIntent(null)"); + if (value is string sValue) + { + // We could use Enum.TryParse here, but it accepts value combinations like + // "ReadOnly, ReadWrite" which are unwelcome here + // Also, Enum.TryParse is 100x slower than plain StringComparer.OrdinalIgnoreCase.Equals method. + + if (TryConvertToApplicationIntent(sValue, out ApplicationIntent result)) + { + return result; + } + + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToApplicationIntent(sValue, out result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + ApplicationIntent eValue; + + if (value is ApplicationIntent intent) + { + // quick path for the most common case + eValue = intent; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["ApplicationIntent"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-ApplicationIntent enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(ApplicationIntent), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (ApplicationIntent)Enum.ToObject(typeof(ApplicationIntent), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(ApplicationIntent), e); + } + } + + // ensure value is in valid range + if (IsValidApplicationIntentValue(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(ApplicationIntent), (int)eValue); + } + } + } + + const string SqlPasswordString = "Sql Password"; + const string ActiveDirectoryPasswordString = "Active Directory Password"; + const string ActiveDirectoryIntegratedString = "Active Directory Integrated"; + const string ActiveDirectoryInteractiveString = "Active Directory Interactive"; + const string ActiveDirectoryServicePrincipalString = "Active Directory Service Principal"; + const string ActiveDirectoryDeviceCodeFlowString = "Active Directory Device Code Flow"; + internal const string ActiveDirectoryManagedIdentityString = "Active Directory Managed Identity"; + internal const string ActiveDirectoryMSIString = "Active Directory MSI"; + internal const string ActiveDirectoryDefaultString = "Active Directory Default"; + const string SqlCertificateString = "Sql Certificate"; + +#if DEBUG + private static readonly string[] s_supportedAuthenticationModes = + { + "NotSpecified", + "SqlPassword", + "ActiveDirectoryPassword", + "ActiveDirectoryIntegrated", + "ActiveDirectoryInteractive", + "ActiveDirectoryServicePrincipal", + "ActiveDirectoryDeviceCodeFlow", + "ActiveDirectoryManagedIdentity", + "ActiveDirectoryMSI", + "ActiveDirectoryDefault" + }; + + private static bool IsValidAuthenticationMethodEnum() + { + string[] names = Enum.GetNames(typeof(SqlAuthenticationMethod)); + int l = s_supportedAuthenticationModes.Length; + bool listValid; + if (listValid = names.Length == l) + { + for (int i = 0; i < l; i++) + { + if (s_supportedAuthenticationModes[i].CompareTo(names[i]) != 0) + { + listValid = false; + } + } + } + return listValid; + } +#endif + + internal static bool TryConvertToAuthenticationType(string value, out SqlAuthenticationMethod result) + { +#if DEBUG + Debug.Assert(IsValidAuthenticationMethodEnum(), "SqlAuthenticationMethod enum has changed, update needed"); +#endif + bool isSuccess = false; + + if (StringComparer.InvariantCultureIgnoreCase.Equals(value, SqlPasswordString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.SqlPassword, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.SqlPassword; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryPasswordString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryPassword, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryPassword; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryIntegratedString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryIntegrated, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryIntegrated; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryInteractiveString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryInteractive, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryInteractive; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryServicePrincipalString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryServicePrincipal, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryServicePrincipal; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryDeviceCodeFlowString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryManagedIdentityString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryManagedIdentity, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryManagedIdentity; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryMSIString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryMSI, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryMSI; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryDefaultString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryDefault, CultureInfo.InvariantCulture))) + { + result = SqlAuthenticationMethod.ActiveDirectoryDefault; + isSuccess = true; + } +#if ADONET_CERT_AUTH && NETFRAMEWORK + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, SqlCertificateString) + || StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.SqlCertificate, CultureInfo.InvariantCulture))) { + result = SqlAuthenticationMethod.SqlCertificate; + isSuccess = true; + } +#endif + else + { + result = DbConnectionStringDefaults.Authentication; + } + return isSuccess; + } + + /// + /// Convert a string value to the corresponding SqlConnectionColumnEncryptionSetting. + /// + /// + /// + /// + internal static bool TryConvertToColumnEncryptionSetting(string value, out SqlConnectionColumnEncryptionSetting result) + { + bool isSuccess = false; + + if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionColumnEncryptionSetting.Enabled))) + { + result = SqlConnectionColumnEncryptionSetting.Enabled; + isSuccess = true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionColumnEncryptionSetting.Disabled))) + { + result = SqlConnectionColumnEncryptionSetting.Disabled; + isSuccess = true; + } + else + { + result = DbConnectionStringDefaults.ColumnEncryptionSetting; + } + + return isSuccess; + } + + /// + /// Is it a valid connection level column encryption setting ? + /// + /// + /// + internal static bool IsValidColumnEncryptionSetting(SqlConnectionColumnEncryptionSetting value) + { + Debug.Assert(Enum.GetNames(typeof(SqlConnectionColumnEncryptionSetting)).Length == 2, "SqlConnectionColumnEncryptionSetting enum has changed, update needed"); + return value == SqlConnectionColumnEncryptionSetting.Enabled || value == SqlConnectionColumnEncryptionSetting.Disabled; + } + + /// + /// Convert connection level column encryption setting value to string. + /// + /// + /// + internal static string ColumnEncryptionSettingToString(SqlConnectionColumnEncryptionSetting value) + { + Debug.Assert(IsValidColumnEncryptionSetting(value), "value is not a valid connection level column encryption setting."); + + return value switch + { + SqlConnectionColumnEncryptionSetting.Enabled => nameof(SqlConnectionColumnEncryptionSetting.Enabled), + SqlConnectionColumnEncryptionSetting.Disabled => nameof(SqlConnectionColumnEncryptionSetting.Disabled), + _ => null, + }; + } + + internal static bool IsValidAuthenticationTypeValue(SqlAuthenticationMethod value) + { + Debug.Assert(Enum.GetNames(typeof(SqlAuthenticationMethod)).Length == 10, "SqlAuthenticationMethod enum has changed, update needed"); + return value == SqlAuthenticationMethod.SqlPassword + || value == SqlAuthenticationMethod.ActiveDirectoryPassword + || value == SqlAuthenticationMethod.ActiveDirectoryIntegrated + || value == SqlAuthenticationMethod.ActiveDirectoryInteractive + || value == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal + || value == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow + || value == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity + || value == SqlAuthenticationMethod.ActiveDirectoryMSI + || value == SqlAuthenticationMethod.ActiveDirectoryDefault +#if ADONET_CERT_AUTH && NETFRAMEWORK + || value == SqlAuthenticationMethod.SqlCertificate +#endif + || value == SqlAuthenticationMethod.NotSpecified; + } + + internal static string AuthenticationTypeToString(SqlAuthenticationMethod value) + { + Debug.Assert(IsValidAuthenticationTypeValue(value)); + + return value switch + { + SqlAuthenticationMethod.SqlPassword => SqlPasswordString, + SqlAuthenticationMethod.ActiveDirectoryPassword => ActiveDirectoryPasswordString, + SqlAuthenticationMethod.ActiveDirectoryIntegrated => ActiveDirectoryIntegratedString, + SqlAuthenticationMethod.ActiveDirectoryInteractive => ActiveDirectoryInteractiveString, + SqlAuthenticationMethod.ActiveDirectoryServicePrincipal => ActiveDirectoryServicePrincipalString, + SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow => ActiveDirectoryDeviceCodeFlowString, + SqlAuthenticationMethod.ActiveDirectoryManagedIdentity => ActiveDirectoryManagedIdentityString, + SqlAuthenticationMethod.ActiveDirectoryMSI => ActiveDirectoryMSIString, + SqlAuthenticationMethod.ActiveDirectoryDefault => ActiveDirectoryDefaultString, +#if ADONET_CERT_AUTH && NETFRAMEWORK + SqlAuthenticationMethod.SqlCertificate => SqlCertificateString, +#endif + _ => null + }; + } + + internal static SqlAuthenticationMethod ConvertToAuthenticationType(string keyword, object value) + { + if (null == value) + { + return DbConnectionStringDefaults.Authentication; + } + + if (value is string sValue) + { + if (TryConvertToAuthenticationType(sValue, out SqlAuthenticationMethod result)) + { + return result; + } + + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToAuthenticationType(sValue, out result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + SqlAuthenticationMethod eValue; + + if (value is SqlAuthenticationMethod method) + { + // quick path for the most common case + eValue = method; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["ApplicationIntent"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-ApplicationIntent enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(SqlAuthenticationMethod), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (SqlAuthenticationMethod)Enum.ToObject(typeof(SqlAuthenticationMethod), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(SqlAuthenticationMethod), e); + } + } + + // ensure value is in valid range + if (IsValidAuthenticationTypeValue(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(SqlAuthenticationMethod), (int)eValue); + } + } + } + + /// + /// Convert the provided value to a SqlConnectionColumnEncryptionSetting. + /// + /// + /// + /// + internal static SqlConnectionColumnEncryptionSetting ConvertToColumnEncryptionSetting(string keyword, object value) + { + if (null == value) + { + return DbConnectionStringDefaults.ColumnEncryptionSetting; + } + + if (value is string sValue) + { + if (TryConvertToColumnEncryptionSetting(sValue, out SqlConnectionColumnEncryptionSetting result)) + { + return result; + } + + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToColumnEncryptionSetting(sValue, out result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + SqlConnectionColumnEncryptionSetting eValue; + + if (value is SqlConnectionColumnEncryptionSetting setting) + { + // quick path for the most common case + eValue = setting; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["SqlConnectionColumnEncryptionSetting"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-SqlConnectionColumnEncryptionSetting enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionColumnEncryptionSetting), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (SqlConnectionColumnEncryptionSetting)Enum.ToObject(typeof(SqlConnectionColumnEncryptionSetting), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionColumnEncryptionSetting), e); + } + } + + // ensure value is in valid range + if (IsValidColumnEncryptionSetting(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(SqlConnectionColumnEncryptionSetting), (int)eValue); + } + } + } + + #region <> + /// + /// Convert a string value to the corresponding SqlConnectionAttestationProtocol + /// + /// + /// + /// + internal static bool TryConvertToAttestationProtocol(string value, out SqlConnectionAttestationProtocol result) + { + if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.HGS))) + { + result = SqlConnectionAttestationProtocol.HGS; + return true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.AAS))) + { + result = SqlConnectionAttestationProtocol.AAS; + return true; + } + else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, nameof(SqlConnectionAttestationProtocol.None))) + { + result = SqlConnectionAttestationProtocol.None; + return true; + } + else + { + result = DbConnectionStringDefaults.AttestationProtocol; + return false; + } + } + + internal static bool IsValidAttestationProtocol(SqlConnectionAttestationProtocol value) + { + Debug.Assert(Enum.GetNames(typeof(SqlConnectionAttestationProtocol)).Length == 4, "SqlConnectionAttestationProtocol enum has changed, update needed"); + return value == SqlConnectionAttestationProtocol.NotSpecified + || value == SqlConnectionAttestationProtocol.HGS + || value == SqlConnectionAttestationProtocol.AAS + || value == SqlConnectionAttestationProtocol.None; + } + + internal static string AttestationProtocolToString(SqlConnectionAttestationProtocol value) + { + Debug.Assert(IsValidAttestationProtocol(value), "value is not a valid attestation protocol"); + + return value switch + { + SqlConnectionAttestationProtocol.AAS => nameof(SqlConnectionAttestationProtocol.AAS), + SqlConnectionAttestationProtocol.HGS => nameof(SqlConnectionAttestationProtocol.HGS), + SqlConnectionAttestationProtocol.None => nameof(SqlConnectionAttestationProtocol.None), + _ => null + }; + } + + internal static SqlConnectionAttestationProtocol ConvertToAttestationProtocol(string keyword, object value) + { + if (null == value) + { + return DbConnectionStringDefaults.AttestationProtocol; + } + + if (value is string sValue) + { + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToAttestationProtocol(sValue, out SqlConnectionAttestationProtocol result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + SqlConnectionAttestationProtocol eValue; + + if (value is SqlConnectionAttestationProtocol protocol) + { + eValue = protocol; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["SqlConnectionAttestationProtocol"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-SqlConnectionAttestationProtocol enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionAttestationProtocol), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (SqlConnectionAttestationProtocol)Enum.ToObject(typeof(SqlConnectionAttestationProtocol), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionAttestationProtocol), e); + } + } + + if (IsValidAttestationProtocol(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(SqlConnectionAttestationProtocol), (int)eValue); + } + } + } + + internal static SqlConnectionEncryptOption ConvertToSqlConnectionEncryptOption(string keyword, object value) + { + if (value is null) + { + return DbConnectionStringDefaults.Encrypt; + } + else if (value is string sValue) + { + return SqlConnectionEncryptOption.Parse(sValue); + } + + throw ADP.InvalidConnectionOptionValue(keyword); + } + + #endregion + + #region <> + /// + /// IP Address Preference. + /// + private readonly static Dictionary s_preferenceNames = new(StringComparer.InvariantCultureIgnoreCase); + + static DbConnectionStringBuilderUtil() + { + foreach (SqlConnectionIPAddressPreference item in Enum.GetValues(typeof(SqlConnectionIPAddressPreference))) + { + s_preferenceNames.Add(item.ToString(), item); + } + } + + /// + /// Convert a string value to the corresponding IPAddressPreference. + /// + /// The string representation of the enumeration name to convert. + /// When this method returns, `result` contains an object of type `SqlConnectionIPAddressPreference` whose value is represented by `value` if the operation succeeds. + /// If the parse operation fails, `result` contains the default value of the `SqlConnectionIPAddressPreference` type. + /// `true` if the value parameter was converted successfully; otherwise, `false`. + internal static bool TryConvertToIPAddressPreference(string value, out SqlConnectionIPAddressPreference result) + { + if (!s_preferenceNames.TryGetValue(value, out result)) + { + result = DbConnectionStringDefaults.IPAddressPreference; + return false; + } + return true; + } + + /// + /// Verifies if the `value` is defined in the expected Enum. + /// + internal static bool IsValidIPAddressPreference(SqlConnectionIPAddressPreference value) + => value == SqlConnectionIPAddressPreference.IPv4First + || value == SqlConnectionIPAddressPreference.IPv6First + || value == SqlConnectionIPAddressPreference.UsePlatformDefault; + + internal static string IPAddressPreferenceToString(SqlConnectionIPAddressPreference value) + => Enum.GetName(typeof(SqlConnectionIPAddressPreference), value); + + internal static SqlConnectionIPAddressPreference ConvertToIPAddressPreference(string keyword, object value) + { + if (value is null) + { + return DbConnectionStringDefaults.IPAddressPreference; // IPv4First + } + + if (value is string sValue) + { + // try again after remove leading & trailing whitespaces. + sValue = sValue.Trim(); + if (TryConvertToIPAddressPreference(sValue, out SqlConnectionIPAddressPreference result)) + { + return result; + } + + // string values must be valid + throw ADP.InvalidConnectionOptionValue(keyword); + } + else + { + // the value is not string, try other options + SqlConnectionIPAddressPreference eValue; + + if (value is SqlConnectionIPAddressPreference preference) + { + eValue = preference; + } + else if (value.GetType().IsEnum) + { + // explicitly block scenarios in which user tries to use wrong enum types, like: + // builder["SqlConnectionIPAddressPreference"] = EnvironmentVariableTarget.Process; + // workaround: explicitly cast non-SqlConnectionIPAddressPreference enums to int + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), null); + } + else + { + try + { + // Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest + eValue = (SqlConnectionIPAddressPreference)Enum.ToObject(typeof(SqlConnectionIPAddressPreference), value); + } + catch (ArgumentException e) + { + // to be consistent with the messages we send in case of wrong type usage, replace + // the error with our exception, and keep the original one as inner one for troubleshooting + throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), e); + } + } + + if (IsValidIPAddressPreference(eValue)) + { + return eValue; + } + else + { + throw ADP.InvalidEnumerationValue(typeof(SqlConnectionIPAddressPreference), (int)eValue); + } + } + } + #endregion + +#if ADONET_CERT_AUTH && NETFRAMEWORK + internal static bool IsValidCertificateValue(string value) => string.IsNullOrEmpty(value) + || value.StartsWith("subject:", StringComparison.OrdinalIgnoreCase) + || value.StartsWith("sha1:", StringComparison.OrdinalIgnoreCase); +#endif + } + + internal static class DbConnectionStringDefaults + { + internal const ApplicationIntent ApplicationIntent = Microsoft.Data.SqlClient.ApplicationIntent.ReadWrite; + internal const string ApplicationName = +#if NETFRAMEWORK + "Framework Microsoft SqlClient Data Provider"; +#else + "Core Microsoft SqlClient Data Provider"; +#endif + internal const string AttachDBFilename = ""; + internal const int CommandTimeout = 30; + internal const int ConnectTimeout = 15; + +#if NETFRAMEWORK + internal const bool ConnectionReset = true; + internal const bool ContextConnection = false; + internal static readonly bool TransparentNetworkIPResolution = !LocalAppContextSwitches.DisableTNIRByDefault; + internal const string NetworkLibrary = ""; +#if ADONET_CERT_AUTH + internal const string Certificate = ""; +#endif +#endif + internal const string CurrentLanguage = ""; + internal const string DataSource = ""; + internal static readonly SqlConnectionEncryptOption Encrypt = SqlConnectionEncryptOption.Mandatory; + internal const string HostNameInCertificate = ""; + internal const bool Enlist = true; + internal const string FailoverPartner = ""; + internal const string InitialCatalog = ""; + internal const bool IntegratedSecurity = false; + internal const int LoadBalanceTimeout = 0; // default of 0 means don't use + internal const bool MultipleActiveResultSets = false; + internal const bool MultiSubnetFailover = false; + internal const int MaxPoolSize = 100; + internal const int MinPoolSize = 0; + internal const int PacketSize = 8000; + internal const string Password = ""; + internal const bool PersistSecurityInfo = false; + internal const bool Pooling = true; + internal const bool TrustServerCertificate = false; + internal const string TypeSystemVersion = "Latest"; + internal const string UserID = ""; + internal const bool UserInstance = false; + internal const bool Replication = false; + internal const string WorkstationID = ""; + internal const string TransactionBinding = "Implicit Unbind"; + internal const int ConnectRetryCount = 1; + internal const int ConnectRetryInterval = 10; + internal static readonly SqlAuthenticationMethod Authentication = SqlAuthenticationMethod.NotSpecified; + internal const SqlConnectionColumnEncryptionSetting ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled; + internal const string EnclaveAttestationUrl = ""; + internal const SqlConnectionAttestationProtocol AttestationProtocol = SqlConnectionAttestationProtocol.NotSpecified; + internal const SqlConnectionIPAddressPreference IPAddressPreference = SqlConnectionIPAddressPreference.IPv4First; + internal const PoolBlockingPeriod PoolBlockingPeriod = SqlClient.PoolBlockingPeriod.Auto; + internal const string ServerSPN = ""; + internal const string FailoverPartnerSPN = ""; + } + + internal static class DbConnectionStringKeywords + { +#if NETFRAMEWORK + // Odbc + internal const string Driver = "Driver"; + internal const string Dsn = "Dsn"; + internal const string FileDsn = "FileDsn"; + internal const string SaveFile = "SaveFile"; + + // OleDb + internal const string FileName = "File Name"; + internal const string OleDbServices = "OLE DB Services"; + internal const string Provider = "Provider"; + + // OracleClient + internal const string Unicode = "Unicode"; + internal const string OmitOracleConnectionName = "Omit Oracle Connection Name"; + + // SqlClient + internal const string TransparentNetworkIPResolution = "Transparent Network IP Resolution"; + internal const string Certificate = "Certificate"; +#endif + // SqlClient + internal const string ApplicationIntent = "Application Intent"; + internal const string ApplicationName = "Application Name"; + internal const string AttachDBFilename = "AttachDbFilename"; + internal const string ConnectTimeout = "Connect Timeout"; + internal const string CommandTimeout = "Command Timeout"; + internal const string ConnectionReset = "Connection Reset"; + internal const string ContextConnection = "Context Connection"; + internal const string CurrentLanguage = "Current Language"; + internal const string Encrypt = "Encrypt"; + internal const string HostNameInCertificate = "Host Name In Certificate"; + internal const string FailoverPartner = "Failover Partner"; + internal const string InitialCatalog = "Initial Catalog"; + internal const string MultipleActiveResultSets = "Multiple Active Result Sets"; + internal const string MultiSubnetFailover = "Multi Subnet Failover"; + internal const string NetworkLibrary = "Network Library"; + internal const string PacketSize = "Packet Size"; + internal const string Replication = "Replication"; + internal const string TransactionBinding = "Transaction Binding"; + internal const string TrustServerCertificate = "Trust Server Certificate"; + internal const string TypeSystemVersion = "Type System Version"; + internal const string UserInstance = "User Instance"; + internal const string WorkstationID = "Workstation ID"; + internal const string ConnectRetryCount = "Connect Retry Count"; + internal const string ConnectRetryInterval = "Connect Retry Interval"; + internal const string Authentication = "Authentication"; + internal const string ColumnEncryptionSetting = "Column Encryption Setting"; + internal const string EnclaveAttestationUrl = "Enclave Attestation Url"; + internal const string AttestationProtocol = "Attestation Protocol"; + internal const string IPAddressPreference = "IP Address Preference"; + internal const string ServerSPN = "Server SPN"; + internal const string FailoverPartnerSPN = "Failover Partner SPN"; + + // common keywords (OleDb, OracleClient, SqlClient) + internal const string DataSource = "Data Source"; + internal const string IntegratedSecurity = "Integrated Security"; + internal const string Password = "Password"; + internal const string PersistSecurityInfo = "Persist Security Info"; + internal const string UserID = "User ID"; + + // managed pooling (OracleClient, SqlClient) + internal const string Enlist = "Enlist"; + internal const string LoadBalanceTimeout = "Load Balance Timeout"; + internal const string MaxPoolSize = "Max Pool Size"; + internal const string Pooling = "Pooling"; + internal const string MinPoolSize = "Min Pool Size"; + internal const string PoolBlockingPeriod = "Pool Blocking Period"; + } + + internal static class DbConnectionStringSynonyms + { +#if NETFRAMEWORK + //internal const string TransparentNetworkIPResolution = TRANSPARENTNETWORKIPRESOLUTION; + internal const string TRANSPARENTNETWORKIPRESOLUTION = "transparentnetworkipresolution"; +#endif + //internal const string ApplicationName = APP; + internal const string APP = "app"; + + // internal const string IPAddressPreference = IPADDRESSPREFERENCE; + internal const string IPADDRESSPREFERENCE = "ipaddresspreference"; + + //internal const string ApplicationIntent = APPLICATIONINTENT; + internal const string APPLICATIONINTENT = "applicationintent"; + + //internal const string AttachDBFilename = EXTENDEDPROPERTIES+","+INITIALFILENAME; + internal const string EXTENDEDPROPERTIES = "extended properties"; + internal const string INITIALFILENAME = "initial file name"; + + // internal const string HostNameInCertificate = HOSTNAMEINCERTIFICATE; + internal const string HOSTNAMEINCERTIFICATE = "hostnameincertificate"; + + //internal const string ConnectTimeout = CONNECTIONTIMEOUT+","+TIMEOUT; + internal const string CONNECTIONTIMEOUT = "connection timeout"; + internal const string TIMEOUT = "timeout"; + + //internal const string ConnectRetryCount = CONNECTRETRYCOUNT; + internal const string CONNECTRETRYCOUNT = "connectretrycount"; + + //internal const string ConnectRetryInterval = CONNECTRETRYINTERVAL; + internal const string CONNECTRETRYINTERVAL = "connectretryinterval"; + + //internal const string CurrentLanguage = LANGUAGE; + internal const string LANGUAGE = "language"; + + //internal const string OraDataSource = SERVER; + //internal const string SqlDataSource = ADDR+","+ADDRESS+","+SERVER+","+NETWORKADDRESS; + internal const string ADDR = "addr"; + internal const string ADDRESS = "address"; + internal const string SERVER = "server"; + internal const string NETWORKADDRESS = "network address"; + + //internal const string InitialCatalog = DATABASE; + internal const string DATABASE = "database"; + + //internal const string IntegratedSecurity = TRUSTEDCONNECTION; + internal const string TRUSTEDCONNECTION = "trusted_connection"; // underscore introduced in everett + + //internal const string LoadBalanceTimeout = ConnectionLifetime; + internal const string ConnectionLifetime = "connection lifetime"; + + //internal const string MultipleActiveResultSets = MULTIPLEACTIVERESULTSETS; + internal const string MULTIPLEACTIVERESULTSETS = "multipleactiveresultsets"; + + //internal const string MultiSubnetFailover = MULTISUBNETFAILOVER; + internal const string MULTISUBNETFAILOVER = "multisubnetfailover"; + + //internal const string NetworkLibrary = NET+","+NETWORK; + internal const string NET = "net"; + internal const string NETWORK = "network"; + + //internal const string PoolBlockingPeriod = POOLBLOCKINGPERIOD; + internal const string POOLBLOCKINGPERIOD = "poolblockingperiod"; + + //internal const string Password = Pwd; + internal const string Pwd = "pwd"; + + //internal const string PersistSecurityInfo = PERSISTSECURITYINFO; + internal const string PERSISTSECURITYINFO = "persistsecurityinfo"; + + //internal const string TrustServerCertificate = TRUSTSERVERCERTIFICATE; + internal const string TRUSTSERVERCERTIFICATE = "trustservercertificate"; + + //internal const string UserID = UID+","+User; + internal const string UID = "uid"; + internal const string User = "user"; + + //internal const string WorkstationID = WSID; + internal const string WSID = "wsid"; + + //internal const string server SPNs + internal const string ServerSPN = "ServerSPN"; + internal const string FailoverPartnerSPN = "FailoverPartnerSPN"; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs new file mode 100644 index 0000000000..a30b462092 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/MultipartIdentifier.cs @@ -0,0 +1,291 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Text; + +namespace Microsoft.Data.Common +{ + internal class MultipartIdentifier + { + private const int MaxParts = 4; + internal const int ServerIndex = 0; + internal const int CatalogIndex = 1; + internal const int SchemaIndex = 2; + internal const int TableIndex = 3; + + /* + Left quote strings need to correspond 1 to 1 with the right quote strings + example: "ab" "cd", passed in for the left and the right quote + would set a or b as a starting quote character. + If a is the starting quote char then c would be the ending quote char + otherwise if b is the starting quote char then d would be the ending quote character. + */ + internal static string[] ParseMultipartIdentifier(string name, string leftQuote, string rightQuote, string property, bool ThrowOnEmptyMultipartName) + { + return ParseMultipartIdentifier(name, leftQuote, rightQuote, '.', MaxParts, true, property, ThrowOnEmptyMultipartName); + } + + private enum MPIState + { + MPI_Value, + MPI_ParseNonQuote, + MPI_LookForSeparator, + MPI_LookForNextCharOrSeparator, + MPI_ParseQuote, + MPI_RightQuote, + } + + /* Core function for parsing the multipart identifier string. + * parameters: name - string to parse + * leftquote: set of characters which are valid quoting characters to initiate a quote + * rightquote: set of characters which are valid to stop a quote, array index's correspond to the leftquote array. + * separator: separator to use + * limit: number of names to parse out + * removequote:to remove the quotes on the returned string + */ + private static void IncrementStringCount(string name, string[] ary, ref int position, string property) + { + ++position; + int limit = ary.Length; + if (position >= limit) + { + throw ADP.InvalidMultipartNameToManyParts(property, name, limit); + } + ary[position] = string.Empty; + } + + private static bool IsWhitespace(char ch) + { + return char.IsWhiteSpace(ch); + } + + internal static string[] ParseMultipartIdentifier(string name, string leftQuote, string rightQuote, char separator, int limit, bool removequotes, string property, bool ThrowOnEmptyMultipartName) + { + if (limit <= 0) + { + throw ADP.InvalidMultipartNameToManyParts(property, name, limit); + } + + if (-1 != leftQuote.IndexOf(separator) || -1 != rightQuote.IndexOf(separator) || leftQuote.Length != rightQuote.Length) + { + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + + string[] parsedNames = new string[limit]; // return string array + int stringCount = 0; // index of current string in the buffer + MPIState state = MPIState.MPI_Value; // Initialize the starting state + + StringBuilder sb = new StringBuilder(name.Length); // String buffer to hold the string being currently built, init the string builder so it will never be resized + StringBuilder whitespaceSB = null; // String buffer to hold whitespace used when parsing nonquoted strings 'a b . c d' = 'a b' and 'c d' + char rightQuoteChar = ' '; // Right quote character to use given the left quote character found. + for (int index = 0; index < name.Length; ++index) + { + char testchar = name[index]; + switch (state) + { + case MPIState.MPI_Value: + { + int quoteIndex; + if (IsWhitespace(testchar)) + { // Is White Space then skip the whitespace + continue; + } + else + if (testchar == separator) + { // If we found a separator, no string was found, initialize the string we are parsing to Empty and the next one to Empty. + // This is NOT a redundant setting of string.Empty it solves the case where we are parsing ".foo" and we should be returning null, null, empty, foo + parsedNames[stringCount] = string.Empty; + IncrementStringCount(name, parsedNames, ref stringCount, property); + } + else + if (-1 != (quoteIndex = leftQuote.IndexOf(testchar))) + { // If we are a left quote + rightQuoteChar = rightQuote[quoteIndex]; // record the corresponding right quote for the left quote + sb.Length = 0; + if (!removequotes) + { + sb.Append(testchar); + } + state = MPIState.MPI_ParseQuote; + } + else + if (-1 != rightQuote.IndexOf(testchar)) + { // If we shouldn't see a right quote + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + else + { + sb.Length = 0; + sb.Append(testchar); + state = MPIState.MPI_ParseNonQuote; + } + break; + } + + case MPIState.MPI_ParseNonQuote: + { + if (testchar == separator) + { + parsedNames[stringCount] = sb.ToString(); // set the currently parsed string + IncrementStringCount(name, parsedNames, ref stringCount, property); + state = MPIState.MPI_Value; + } + else // Quotes are not valid inside a non-quoted name + if (-1 != rightQuote.IndexOf(testchar)) + { + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + else + if (-1 != leftQuote.IndexOf(testchar)) + { + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + else + if (IsWhitespace(testchar)) + { // If it is Whitespace + parsedNames[stringCount] = sb.ToString(); // Set the currently parsed string + if (null == whitespaceSB) + { + whitespaceSB = new StringBuilder(); + } + whitespaceSB.Length = 0; + whitespaceSB.Append(testchar); // start to record the whitespace, if we are parsing a name like "foo bar" we should return "foo bar" + state = MPIState.MPI_LookForNextCharOrSeparator; + } + else + { + sb.Append(testchar); + } + break; + } + + case MPIState.MPI_LookForNextCharOrSeparator: + { + if (!IsWhitespace(testchar)) + { // If it is not whitespace + if (testchar == separator) + { + IncrementStringCount(name, parsedNames, ref stringCount, property); + state = MPIState.MPI_Value; + } + else + { // If its not a separator and not whitespace + sb.Append(whitespaceSB); + sb.Append(testchar); + parsedNames[stringCount] = sb.ToString(); // Need to set the name here in case the string ends here. + state = MPIState.MPI_ParseNonQuote; + } + } + else + { + whitespaceSB.Append(testchar); + } + break; + } + + case MPIState.MPI_ParseQuote: + { + if (testchar == rightQuoteChar) + { // if se are on a right quote see if we are escaping the right quote or ending the quoted string + if (!removequotes) + { + sb.Append(testchar); + } + state = MPIState.MPI_RightQuote; + } + else + { + sb.Append(testchar); // Append what we are currently parsing + } + break; + } + + case MPIState.MPI_RightQuote: + { + if (testchar == rightQuoteChar) + { // If the next char is another right quote then we were escaping the right quote + sb.Append(testchar); + state = MPIState.MPI_ParseQuote; + } + else + if (testchar == separator) + { // If its a separator then record what we've parsed + parsedNames[stringCount] = sb.ToString(); + IncrementStringCount(name, parsedNames, ref stringCount, property); + state = MPIState.MPI_Value; + } + else + if (!IsWhitespace(testchar)) + { // If it is not whitespace we got problems + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + else + { // It is a whitespace character so the following char should be whitespace, separator, or end of string anything else is bad + parsedNames[stringCount] = sb.ToString(); + state = MPIState.MPI_LookForSeparator; + } + break; + } + + case MPIState.MPI_LookForSeparator: + { + if (!IsWhitespace(testchar)) + { // If it is not whitespace + if (testchar == separator) + { // If it is a separator + IncrementStringCount(name, parsedNames, ref stringCount, property); + state = MPIState.MPI_Value; + } + else + { // Otherwise not a separator + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + } + break; + } + } + } + + // Resolve final states after parsing the string + switch (state) + { + case MPIState.MPI_Value: // These states require no extra action + case MPIState.MPI_LookForSeparator: + case MPIState.MPI_LookForNextCharOrSeparator: + break; + + case MPIState.MPI_ParseNonQuote: // Dump what ever was parsed + case MPIState.MPI_RightQuote: + parsedNames[stringCount] = sb.ToString(); + break; + + case MPIState.MPI_ParseQuote: // Invalid Ending States + default: + throw ADP.InvalidMultipartNameIncorrectUsageOfQuotes(property, name); + } + + if (parsedNames[0] == null) + { + if (ThrowOnEmptyMultipartName) + { + throw ADP.InvalidMultipartName(property, name); // Name is entirely made up of whitespace + } + } + else + { + // Shuffle the parsed name, from left justification to right justification, i.e. [a][b][null][null] goes to [null][null][a][b] + int offset = limit - stringCount - 1; + if (offset > 0) + { + for (int x = limit - 1; x >= offset; --x) + { + parsedNames[x] = parsedNames[x - offset]; + parsedNames[x - offset] = null; + } + } + } + return parsedNames; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs new file mode 100644 index 0000000000..f0cfc71d53 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Common/NameValuePair.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.Serialization; + +namespace Microsoft.Data.Common +{ + [Serializable] + internal sealed class NameValuePair + { + readonly private string _name; + readonly private string _value; + [OptionalField(VersionAdded = 2)] + readonly private int _length; + private NameValuePair _next; + + internal NameValuePair(string name, string value, int length) + { + Debug.Assert(!string.IsNullOrEmpty(name), "empty keyname"); + _name = name; + _value = value; + _length = length; + } + + internal int Length + { + get + { + Debug.Assert(0 < _length, "NameValuePair zero Length usage"); + return _length; + } + } + + internal string Name => _name; + internal string Value => _value; + + internal NameValuePair Next + { + get => _next; + set + { + if ((null != _next) || (null == value)) + { + throw ADP.InternalError(ADP.InternalErrorCode.NameValuePairNext); + } + _next = value; + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs new file mode 100644 index 0000000000..e8a49ffa33 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/DataException.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using Microsoft.Data.SqlClient; + +namespace Microsoft.Data +{ + internal static class ExceptionBuilder + { + // The class defines the exceptions that are specific to the DataSet. + // The class contains functions that take the proper informational variables and then construct + // the appropriate exception with an error string obtained from the resource Data.txt. + // The exception is then returned to the caller, so that the caller may then throw from its + // location so that the catcher of the exception will have the appropriate call stack. + // This class is used so that there will be compile time checking of error messages. + // The resource Data.txt will ensure proper string text based on the appropriate + // locale. + + private static void TraceException(string trace, Exception e) + { + Debug.Assert(null != e, "TraceException: null Exception"); + if (null != e) + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(trace, e.Message); + try + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(" Environment StackTrace = '{0}'", Environment.StackTrace); + } + catch (System.Security.SecurityException) + { + // if you don't have permission - you don't get the stack trace + } + } + } + + internal static void TraceExceptionAsReturnValue(Exception e) + { + TraceException(" Message='{0}'", e); + } + + internal static ArgumentException _Argument(string error) + { + ArgumentException e = new ArgumentException(error); + ExceptionBuilder.TraceExceptionAsReturnValue(e); + return e; + } + + internal static Exception InvalidOffsetLength() + { + return _Argument(StringsHelper.GetString(Strings.Data_InvalidOffsetLength)); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs new file mode 100644 index 0000000000..537a4ac0a1 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/OperationAbortedException.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.Serialization; +using Microsoft.Data.Common; + +namespace Microsoft.Data +{ + /// + [Serializable] + [System.Runtime.CompilerServices.TypeForwardedFrom("System.Data, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] + public sealed class OperationAbortedException : SystemException + { + private OperationAbortedException(string message, Exception innerException) : base(message, innerException) + { + HResult = unchecked((int)0x80131936); + } + + private OperationAbortedException(SerializationInfo info, StreamingContext context) : base(info, context) + { + } + + internal static OperationAbortedException Aborted(Exception inner) + { + OperationAbortedException e; + if (inner == null) + { + e = new OperationAbortedException(Strings.ADP_OperationAborted, null); + } + else + { + e = new OperationAbortedException(Strings.ADP_OperationAbortedExceptionMessage, inner); + } + ADP.TraceExceptionAsReturnValue(e); + return e; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs new file mode 100644 index 0000000000..ec6b695429 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContext.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.ConstrainedExecution; +using System.Threading; + +namespace Microsoft.Data.ProviderBase +{ + /// + /// Represents the context of an authentication attempt when using the new active directory based authentication mechanisms. + /// All data members, except_isUpdateInProgressCounter, should be immutable. + /// + sealed internal class DbConnectionPoolAuthenticationContext + { + /// + /// The value expected in _isUpdateInProgress if a thread has taken a lock on this context, + /// to perform the update on the context. + /// + private const int STATUS_LOCKED = 1; + + /// + /// The value expected in _isUpdateInProgress if no thread has taken a lock on this context. + /// + private const int STATUS_UNLOCKED = 0; + + /// + /// Access Token, which is obtained from Active Directory Authentication Library for SQL Server, and needs to be sent to SQL Server + /// as part of TDS Token type Federated Authentication Token. + /// + private readonly byte[] _accessToken; + + /// + /// Expiration time of the above access token. + /// + private readonly DateTime _expirationTime; + + /// + /// A member which is used to achieve a lock to control refresh attempt on this context. + /// + private int _isUpdateInProgress; + + /// + /// Constructor. + /// + /// Access Token that will be used to connect to SQL Server. Carries identity information about a user. + /// The expiration time in UTC for the above accessToken. + internal DbConnectionPoolAuthenticationContext(byte[] accessToken, DateTime expirationTime) + { + + Debug.Assert(accessToken != null && accessToken.Length > 0); + Debug.Assert(expirationTime > DateTime.MinValue && expirationTime < DateTime.MaxValue); + + _accessToken = accessToken; + _expirationTime = expirationTime; + _isUpdateInProgress = STATUS_UNLOCKED; + } + + /// + /// Static Method. + /// Given two contexts, choose one to update in the cache. Chooses based on expiration time. + /// + /// Context1 + /// Context2 + internal static DbConnectionPoolAuthenticationContext ChooseAuthenticationContextToUpdate(DbConnectionPoolAuthenticationContext context1, DbConnectionPoolAuthenticationContext context2) + { + + Debug.Assert(context1 != null, "context1 should not be null."); + Debug.Assert(context2 != null, "context2 should not be null."); + + return context1.ExpirationTime > context2.ExpirationTime ? context1 : context2; + } + + internal byte[] AccessToken + { + get + { + return _accessToken; + } + } + + internal DateTime ExpirationTime + { + get + { + return _expirationTime; + } + } + + /// + /// Try locking the variable _isUpdateInProgressCounter and return if this thread got the lock to update. + /// Whichever thread got the chance to update this variable to 1 wins the lock. + /// + internal bool LockToUpdate() + { + int oldValue = Interlocked.CompareExchange(ref _isUpdateInProgress, STATUS_LOCKED, STATUS_UNLOCKED); + return (oldValue == STATUS_UNLOCKED); + } + + /// + /// Release the lock which was obtained through LockToUpdate. + /// + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal void ReleaseLockToUpdate() + { + int oldValue = Interlocked.CompareExchange(ref _isUpdateInProgress, STATUS_UNLOCKED, STATUS_LOCKED); + Debug.Assert(oldValue == STATUS_LOCKED); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs new file mode 100644 index 0000000000..a6f15ca999 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolAuthenticationContextKey.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; + +namespace Microsoft.Data.ProviderBase +{ + /// + /// Represents the key of dbConnectionPoolAuthenticationContext. + /// All data members should be immutable and so, hashCode is pre-computed. + /// + sealed internal class DbConnectionPoolAuthenticationContextKey + { + /// + /// Security Token Service Authority. + /// + private readonly string _stsAuthority; + + /// + /// Service Principal Name. + /// + private readonly string _servicePrincipalName; + + /// + /// Pre-Computed Hash Code. + /// + private readonly int _hashCode; + + internal string StsAuthority + { + get + { + return _stsAuthority; + } + } + + internal string ServicePrincipalName + { + get + { + return _servicePrincipalName; + } + } + + /// + /// Constructor for the type. + /// + /// Token Endpoint URL + /// SPN representing the SQL service in an active directory. + internal DbConnectionPoolAuthenticationContextKey(string stsAuthority, string servicePrincipalName) + { + Debug.Assert(!string.IsNullOrWhiteSpace(stsAuthority)); + Debug.Assert(!string.IsNullOrWhiteSpace(servicePrincipalName)); + + _stsAuthority = stsAuthority; + _servicePrincipalName = servicePrincipalName; + + // Pre-compute hash since data members are not going to change. + _hashCode = ComputeHashCode(); + } + + /// + /// Override the default Equals implementation. + /// + /// + /// + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + + DbConnectionPoolAuthenticationContextKey otherKey = obj as DbConnectionPoolAuthenticationContextKey; + if (otherKey == null) + { + return false; + } + + return (String.Equals(StsAuthority, otherKey.StsAuthority, StringComparison.InvariantCultureIgnoreCase) + && String.Equals(ServicePrincipalName, otherKey.ServicePrincipalName, StringComparison.InvariantCultureIgnoreCase)); + } + + /// + /// Override the default GetHashCode implementation. + /// + /// + public override int GetHashCode() + { + return _hashCode; + } + + /// + /// Compute the hash code for this object. + /// + /// + private int ComputeHashCode() + { + int hashCode = 33; + + unchecked + { + hashCode = (hashCode * 17) + StsAuthority.GetHashCode(); + hashCode = (hashCode * 17) + ServicePrincipalName.GetHashCode(); + } + + return hashCode; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs new file mode 100644 index 0000000000..7568340594 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroup.cs @@ -0,0 +1,312 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +using Microsoft.Data.Common; +using Microsoft.Data.SqlClient; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; + +namespace Microsoft.Data.ProviderBase +{ + // set_ConnectionString calls DbConnectionFactory.GetConnectionPoolGroup + // when not found a new pool entry is created and potentially added + // DbConnectionPoolGroup starts in the Active state + + // Open calls DbConnectionFactory.GetConnectionPool + // if the existing pool entry is Disabled, GetConnectionPoolGroup is called for a new entry + // DbConnectionFactory.GetConnectionPool calls DbConnectionPoolGroup.GetConnectionPool + + // DbConnectionPoolGroup.GetConnectionPool will return pool for the current identity + // or null if identity is restricted or pooling is disabled or state is disabled at time of add + // state changes are Active->Active, Idle->Active + + // DbConnectionFactory.PruneConnectionPoolGroups calls Prune + // which will QueuePoolForRelease on all empty pools + // and once no pools remain, change state from Active->Idle->Disabled + // Once Disabled, factory can remove its reference to the pool entry + + sealed internal class DbConnectionPoolGroup + { + private readonly DbConnectionOptions _connectionOptions; + private readonly DbConnectionPoolKey _poolKey; + private readonly DbConnectionPoolGroupOptions _poolGroupOptions; + private ConcurrentDictionary _poolCollection; + + private int _state; // see PoolGroupState* below + + private DbConnectionPoolGroupProviderInfo _providerInfo; + private DbMetaDataFactory _metaDataFactory; + + private static int s_objectTypeCount; // EventSource counter + + // always lock this before changing _state, we don't want to move out of the 'Disabled' state + // PoolGroupStateUninitialized = 0; + private const int PoolGroupStateActive = 1; // initial state, GetPoolGroup from cache, connection Open + private const int PoolGroupStateIdle = 2; // all pools are pruned via Clear + private const int PoolGroupStateDisabled = 4; // factory pool entry pruning method + + internal DbConnectionPoolGroup(DbConnectionOptions connectionOptions, DbConnectionPoolKey key, DbConnectionPoolGroupOptions poolGroupOptions) + { + Debug.Assert(null != connectionOptions, "null connection options"); +#if NETFRAMEWORK + Debug.Assert(null == poolGroupOptions || ADP.s_isWindowsNT, "should not have pooling options on Win9x"); +#endif + + _connectionOptions = connectionOptions; + _poolKey = key; + _poolGroupOptions = poolGroupOptions; + + // always lock this object before changing state + // HybridDictionary does not create any sub-objects until add + // so it is safe to use for non-pooled connection as long as + // we check _poolGroupOptions first + _poolCollection = new ConcurrentDictionary(); + _state = PoolGroupStateActive; + } + + internal DbConnectionOptions ConnectionOptions => _connectionOptions; + + internal DbConnectionPoolKey PoolKey => _poolKey; + + internal DbConnectionPoolGroupProviderInfo ProviderInfo + { + get + { + return _providerInfo; + } + set + { + _providerInfo = value; + if (null != value) + { + _providerInfo.PoolGroup = this; + } + } + } + + internal bool IsDisabled => (PoolGroupStateDisabled == _state); + + internal int ObjectID { get; } = Interlocked.Increment(ref s_objectTypeCount); + + internal DbConnectionPoolGroupOptions PoolGroupOptions => _poolGroupOptions; + + internal DbMetaDataFactory MetaDataFactory + { + get + { + return _metaDataFactory; + } + + set + { + _metaDataFactory = value; + } + } + + internal int Clear() + { + // must be multi-thread safe with competing calls by Clear and Prune via background thread + // will return the number of connections in the group after clearing has finished + + // First, note the old collection and create a new collection to be used + ConcurrentDictionary oldPoolCollection = null; + lock (this) + { + if (_poolCollection.Count > 0) + { + oldPoolCollection = _poolCollection; + _poolCollection = new ConcurrentDictionary(); + } + } + + // Then, if a new collection was created, release the pools from the old collection + if (oldPoolCollection != null) + { + foreach (KeyValuePair entry in oldPoolCollection) + { + DbConnectionPool pool = entry.Value; + if (pool != null) + { + DbConnectionFactory connectionFactory = pool.ConnectionFactory; +#if NETFRAMEWORK + connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Decrement(); +#endif + connectionFactory.QueuePoolForRelease(pool, true); + } + } + } + + // Finally, return the pool collection count - this may be non-zero if something was added while we were clearing + return _poolCollection.Count; + } + + internal DbConnectionPool GetConnectionPool(DbConnectionFactory connectionFactory) + { + // When this method returns null it indicates that the connection + // factory should not use pooling. + + // We don't support connection pooling on Win9x; + // PoolGroupOptions will only be null when we're not supposed to pool + // connections. + DbConnectionPool pool = null; + if (null != _poolGroupOptions) + { +#if NETFRAMEWORK + Debug.Assert(ADP.s_isWindowsNT, "should not be pooling on Win9x"); +#endif + + DbConnectionPoolIdentity currentIdentity = DbConnectionPoolIdentity.NoIdentity; + + if (_poolGroupOptions.PoolByIdentity) + { + // if we're pooling by identity (because integrated security is + // being used for these connections) then we need to go out and + // search for the connectionPool that matches the current identity. + + currentIdentity = DbConnectionPoolIdentity.GetCurrent(); + + // If the current token is restricted in some way, then we must + // not attempt to pool these connections. + if (currentIdentity.IsRestricted) + { + currentIdentity = null; + } + } + + if (null != currentIdentity) + { + if (!_poolCollection.TryGetValue(currentIdentity, out pool)) // find the pool + { + lock (this) + { + // Did someone already add it to the list? + if (!_poolCollection.TryGetValue(currentIdentity, out pool)) + { + DbConnectionPoolProviderInfo connectionPoolProviderInfo = connectionFactory.CreateConnectionPoolProviderInfo(ConnectionOptions); + DbConnectionPool newPool = new(connectionFactory, this, currentIdentity, connectionPoolProviderInfo); + + if (MarkPoolGroupAsActive()) + { + // If we get here, we know for certain that we there isn't + // a pool that matches the current identity, so we have to + // add the optimistically created one + newPool.Startup(); // must start pool before usage + bool addResult = _poolCollection.TryAdd(currentIdentity, newPool); + Debug.Assert(addResult, "No other pool with current identity should exist at this point"); + SqlClientEventSource.Log.EnterActiveConnectionPool(); +#if NETFRAMEWORK + connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Increment(); +#endif + pool = newPool; + } + else + { + // else pool entry has been disabled so don't create new pools + Debug.Assert(PoolGroupStateDisabled == _state, "state should be disabled"); + + // don't need to call connectionFactory.QueuePoolForRelease(newPool) because + // pool callbacks were delayed and no risk of connections being created + newPool.Shutdown(); + } + } + else + { + // else found an existing pool to use instead + Debug.Assert(PoolGroupStateActive == _state, "state should be active since a pool exists and lock holds"); + } + } + } + // the found pool could be in any state + } + } + + if (null == pool) + { + lock (this) + { + // keep the pool entry state active when not pooling + MarkPoolGroupAsActive(); + } + } + return pool; + } + + private bool MarkPoolGroupAsActive() + { + // when getting a connection, make the entry active if it was idle (but not disabled) + // must always lock this before calling + + if (PoolGroupStateIdle == _state) + { + _state = PoolGroupStateActive; + SqlClientEventSource.Log.TryTraceEvent(" {0}, Active", ObjectID); + } + return (PoolGroupStateActive == _state); + } + + internal bool Prune() + { + // must only call from DbConnectionFactory.PruneConnectionPoolGroups on background timer thread + // must lock(DbConnectionFactory._connectionPoolGroups.SyncRoot) before calling ReadyToRemove + // to avoid conflict with DbConnectionFactory.CreateConnectionPoolGroup replacing pool entry + lock (this) + { + if (_poolCollection.Count > 0) + { + var newPoolCollection = new ConcurrentDictionary(); + + foreach (KeyValuePair entry in _poolCollection) + { + DbConnectionPool pool = entry.Value; + if (pool != null) + { + // Actually prune the pool if there are no connections in the pool and no errors occurred. + // Empty pool during pruning indicates zero or low activity, but + // an error state indicates the pool needs to stay around to + // throttle new connection attempts. + if ((!pool.ErrorOccurred) && (0 == pool.Count)) + { + // Order is important here. First we remove the pool + // from the collection of pools so no one will try + // to use it while we're processing and finally we put the + // pool into a list of pools to be released when they + // are completely empty. + DbConnectionFactory connectionFactory = pool.ConnectionFactory; +#if NETFRAMEWORK + connectionFactory.PerformanceCounters.NumberOfActiveConnectionPools.Decrement(); +#endif + connectionFactory.QueuePoolForRelease(pool, false); + } + else + { + newPoolCollection.TryAdd(entry.Key, entry.Value); + } + } + } + _poolCollection = newPoolCollection; + } + + // must be pruning thread to change state and no connections + // otherwise pruning thread risks making entry disabled soon after user calls ClearPool + if (0 == _poolCollection.Count) + { + if (PoolGroupStateActive == _state) + { + _state = PoolGroupStateIdle; + SqlClientEventSource.Log.TryTraceEvent(" {0}, Idle", ObjectID); + } + else if (PoolGroupStateIdle == _state) + { + _state = PoolGroupStateDisabled; + SqlClientEventSource.Log.TryTraceEvent(" {0}, Disabled", ObjectID); + } + } + return (PoolGroupStateDisabled == _state); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs new file mode 100644 index 0000000000..3eceb6d3e3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolGroupProviderInfo.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.ProviderBase +{ + internal class DbConnectionPoolGroupProviderInfo + { + private DbConnectionPoolGroup _poolGroup; + + internal DbConnectionPoolGroup PoolGroup + { + get + { + return _poolGroup; + } + set + { + _poolGroup = value; + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs new file mode 100644 index 0000000000..866453432c --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolOptions.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Data.ProviderBase +{ + internal sealed class DbConnectionPoolGroupOptions + { + private readonly bool _poolByIdentity; + private readonly int _minPoolSize; + private readonly int _maxPoolSize; + private readonly int _creationTimeout; + private readonly TimeSpan _loadBalanceTimeout; + private readonly bool _hasTransactionAffinity; + private readonly bool _useLoadBalancing; + + public DbConnectionPoolGroupOptions( + bool poolByIdentity, + int minPoolSize, + int maxPoolSize, + int creationTimeout, + int loadBalanceTimeout, + bool hasTransactionAffinity + ) + { + _poolByIdentity = poolByIdentity; + _minPoolSize = minPoolSize; + _maxPoolSize = maxPoolSize; + _creationTimeout = creationTimeout; + + if (0 != loadBalanceTimeout) + { + _loadBalanceTimeout = new TimeSpan(0, 0, loadBalanceTimeout); + _useLoadBalancing = true; + } + + _hasTransactionAffinity = hasTransactionAffinity; + } + + public int CreationTimeout + { + get { return _creationTimeout; } + } + public bool HasTransactionAffinity + { + get { return _hasTransactionAffinity; } + } + public TimeSpan LoadBalanceTimeout + { + get { return _loadBalanceTimeout; } + } + public int MaxPoolSize + { + get { return _maxPoolSize; } + } + public int MinPoolSize + { + get { return _minPoolSize; } + } + public bool PoolByIdentity + { + get { return _poolByIdentity; } + } + public bool UseLoadBalancing + { + get { return _useLoadBalancing; } + } + } +} + + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs new file mode 100644 index 0000000000..5392795dff --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbConnectionPoolProviderInfo.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.ProviderBase +{ + internal class DbConnectionPoolProviderInfo + { + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs new file mode 100644 index 0000000000..6e907d26e1 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/DbMetaDataFactory.cs @@ -0,0 +1,558 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Data.Common; +using System; +using System.Data; +using System.Data.Common; +using System.Diagnostics; +using System.Globalization; +using System.IO; + +namespace Microsoft.Data.ProviderBase +{ + internal class DbMetaDataFactory + { + + private DataSet _metaDataCollectionsDataSet; + private string _normalizedServerVersion; + private string _serverVersionString; + // well known column names + private const string CollectionNameKey = "CollectionName"; + private const string PopulationMechanismKey = "PopulationMechanism"; + private const string PopulationStringKey = "PopulationString"; + private const string MaximumVersionKey = "MaximumVersion"; + private const string MinimumVersionKey = "MinimumVersion"; + private const string DataSourceProductVersionNormalizedKey = "DataSourceProductVersionNormalized"; + private const string DataSourceProductVersionKey = "DataSourceProductVersion"; + private const string RestrictionNumberKey = "RestrictionNumber"; + private const string NumberOfRestrictionsKey = "NumberOfRestrictions"; + private const string RestrictionNameKey = "RestrictionName"; + private const string ParameterNameKey = "ParameterName"; + + // population mechanisms + private const string DataTableKey = "DataTable"; + private const string SqlCommandKey = "SQLCommand"; + private const string PrepareCollectionKey = "PrepareCollection"; + + public DbMetaDataFactory(Stream xmlStream, string serverVersion, string normalizedServerVersion) + { + ADP.CheckArgumentNull(xmlStream, nameof(xmlStream)); + ADP.CheckArgumentNull(serverVersion, nameof(serverVersion)); + ADP.CheckArgumentNull(normalizedServerVersion, nameof(normalizedServerVersion)); + + LoadDataSetFromXml(xmlStream); + + _serverVersionString = serverVersion; + _normalizedServerVersion = normalizedServerVersion; + } + + protected DataSet CollectionDataSet => _metaDataCollectionsDataSet; + + protected string ServerVersion => _serverVersionString; + + protected string ServerVersionNormalized => _normalizedServerVersion; + + protected DataTable CloneAndFilterCollection(string collectionName, string[] hiddenColumnNames) + { + DataTable destinationTable; + DataColumn[] filteredSourceColumns; + DataColumnCollection destinationColumns; + DataRow newRow; + + DataTable sourceTable = _metaDataCollectionsDataSet.Tables[collectionName]; + + if ((sourceTable == null) || (collectionName != sourceTable.TableName)) + { + throw ADP.DataTableDoesNotExist(collectionName); + } + + destinationTable = new DataTable(collectionName) + { + Locale = CultureInfo.InvariantCulture + }; + destinationColumns = destinationTable.Columns; + + filteredSourceColumns = FilterColumns(sourceTable, hiddenColumnNames, destinationColumns); + + foreach (DataRow row in sourceTable.Rows) + { + if (SupportedByCurrentVersion(row)) + { + newRow = destinationTable.NewRow(); + for (int i = 0; i < destinationColumns.Count; i++) + { + newRow[destinationColumns[i]] = row[filteredSourceColumns[i], DataRowVersion.Current]; + } + destinationTable.Rows.Add(newRow); + newRow.AcceptChanges(); + } + } + + return destinationTable; + } + + public void Dispose() => Dispose(true); + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + _normalizedServerVersion = null; + _serverVersionString = null; + _metaDataCollectionsDataSet.Dispose(); + } + } + + private DataTable ExecuteCommand(DataRow requestedCollectionRow, string[] restrictions, DbConnection connection) + { + DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections]; + DataColumn populationStringColumn = metaDataCollectionsTable.Columns[PopulationStringKey]; + DataColumn numberOfRestrictionsColumn = metaDataCollectionsTable.Columns[NumberOfRestrictionsKey]; + DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[CollectionNameKey]; + + DataTable resultTable = null; + + Debug.Assert(requestedCollectionRow != null); + string sqlCommand = requestedCollectionRow[populationStringColumn, DataRowVersion.Current] as string; + int numberOfRestrictions = (int)requestedCollectionRow[numberOfRestrictionsColumn, DataRowVersion.Current]; + string collectionName = requestedCollectionRow[collectionNameColumn, DataRowVersion.Current] as string; + + if ((restrictions != null) && (restrictions.Length > numberOfRestrictions)) + { + throw ADP.TooManyRestrictions(collectionName); + } + + DbCommand command = connection.CreateCommand(); + command.CommandText = sqlCommand; + command.CommandTimeout = Math.Max(command.CommandTimeout, 180); + + for (int i = 0; i < numberOfRestrictions; i++) + { + + DbParameter restrictionParameter = command.CreateParameter(); + + if ((restrictions != null) && (restrictions.Length > i) && (restrictions[i] != null)) + { + restrictionParameter.Value = restrictions[i]; + } + else + { + // This is where we have to assign null to the value of the parameter. + restrictionParameter.Value = DBNull.Value; + } + + restrictionParameter.ParameterName = GetParameterName(collectionName, i + 1); + restrictionParameter.Direction = ParameterDirection.Input; + command.Parameters.Add(restrictionParameter); + } + + DbDataReader reader = null; + try + { + try + { + reader = command.ExecuteReader(); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + throw ADP.QueryFailed(collectionName, e); + } + + // Build a DataTable from the reader + resultTable = new DataTable(collectionName) + { + Locale = CultureInfo.InvariantCulture + }; + + DataTable schemaTable = reader.GetSchemaTable(); + foreach (DataRow row in schemaTable.Rows) + { + resultTable.Columns.Add(row["ColumnName"] as string, (Type)row["DataType"] as Type); + } + object[] values = new object[resultTable.Columns.Count]; + while (reader.Read()) + { + reader.GetValues(values); + resultTable.Rows.Add(values); + } + } + finally + { + reader?.Dispose(); + } + return resultTable; + } + + private DataColumn[] FilterColumns(DataTable sourceTable, string[] hiddenColumnNames, DataColumnCollection destinationColumns) + { + int columnCount = 0; + foreach (DataColumn sourceColumn in sourceTable.Columns) + { + if (IncludeThisColumn(sourceColumn, hiddenColumnNames)) + { + columnCount++; + } + } + + if (columnCount == 0) + { + throw ADP.NoColumns(); + } + + int currentColumn = 0; + DataColumn[] filteredSourceColumns = new DataColumn[columnCount]; + + foreach (DataColumn sourceColumn in sourceTable.Columns) + { + if (IncludeThisColumn(sourceColumn, hiddenColumnNames)) + { + DataColumn newDestinationColumn = new(sourceColumn.ColumnName, sourceColumn.DataType); + destinationColumns.Add(newDestinationColumn); + filteredSourceColumns[currentColumn] = sourceColumn; + currentColumn++; + } + } + return filteredSourceColumns; + } + + internal DataRow FindMetaDataCollectionRow(string collectionName) + { + bool versionFailure; + bool haveExactMatch; + bool haveMultipleInexactMatches; + string candidateCollectionName; + + DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections]; + if (metaDataCollectionsTable == null) + { + throw ADP.InvalidXml(); + } + + DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[DbMetaDataColumnNames.CollectionName]; + + if ((null == collectionNameColumn) || (typeof(string) != collectionNameColumn.DataType)) + { + throw ADP.InvalidXmlMissingColumn(DbMetaDataCollectionNames.MetaDataCollections, DbMetaDataColumnNames.CollectionName); + } + + DataRow requestedCollectionRow = null; + string exactCollectionName = null; + + // find the requested collection + versionFailure = false; + haveExactMatch = false; + haveMultipleInexactMatches = false; + + foreach (DataRow row in metaDataCollectionsTable.Rows) + { + + candidateCollectionName = row[collectionNameColumn, DataRowVersion.Current] as string; + if (string.IsNullOrEmpty(candidateCollectionName)) + { + throw ADP.InvalidXmlInvalidValue(DbMetaDataCollectionNames.MetaDataCollections, DbMetaDataColumnNames.CollectionName); + } + + if (ADP.CompareInsensitiveInvariant(candidateCollectionName, collectionName)) + { + if (!SupportedByCurrentVersion(row)) + { + versionFailure = true; + } + else + { + if (collectionName == candidateCollectionName) + { + if (haveExactMatch) + { + throw ADP.CollectionNameIsNotUnique(collectionName); + } + requestedCollectionRow = row; + exactCollectionName = candidateCollectionName; + haveExactMatch = true; + } + else if (!haveExactMatch) + { + // have an inexact match - ok only if it is the only one + if (exactCollectionName != null) + { + // can't fail here becasue we may still find an exact match + haveMultipleInexactMatches = true; + } + requestedCollectionRow = row; + exactCollectionName = candidateCollectionName; + } + } + } + } + + if (requestedCollectionRow == null) + { + if (!versionFailure) + { + throw ADP.UndefinedCollection(collectionName); + } + else + { + throw ADP.UnsupportedVersion(collectionName); + } + } + + if (!haveExactMatch && haveMultipleInexactMatches) + { + throw ADP.AmbiguousCollectionName(collectionName); + } + + return requestedCollectionRow; + + } + + private void FixUpVersion(DataTable dataSourceInfoTable) + { + Debug.Assert(dataSourceInfoTable.TableName == DbMetaDataCollectionNames.DataSourceInformation); + DataColumn versionColumn = dataSourceInfoTable.Columns[DataSourceProductVersionKey]; + DataColumn normalizedVersionColumn = dataSourceInfoTable.Columns[DataSourceProductVersionNormalizedKey]; + + if ((versionColumn == null) || (normalizedVersionColumn == null)) + { + throw ADP.MissingDataSourceInformationColumn(); + } + + if (dataSourceInfoTable.Rows.Count != 1) + { + throw ADP.IncorrectNumberOfDataSourceInformationRows(); + } + + DataRow dataSourceInfoRow = dataSourceInfoTable.Rows[0]; + + dataSourceInfoRow[versionColumn] = _serverVersionString; + dataSourceInfoRow[normalizedVersionColumn] = _normalizedServerVersion; + dataSourceInfoRow.AcceptChanges(); + } + + + private string GetParameterName(string neededCollectionName, int neededRestrictionNumber) + { + DataColumn collectionName = null; + DataColumn parameterName = null; + DataColumn restrictionName = null; + DataColumn restrictionNumber = null; + + string result = null; + + DataTable restrictionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.Restrictions]; + if (restrictionsTable != null) + { + DataColumnCollection restrictionColumns = restrictionsTable.Columns; + if (restrictionColumns != null) + { + collectionName = restrictionColumns[DbMetaDataFactory.CollectionNameKey]; + parameterName = restrictionColumns[ParameterNameKey]; + restrictionName = restrictionColumns[RestrictionNameKey]; + restrictionNumber = restrictionColumns[RestrictionNumberKey]; + } + } + + if ((parameterName == null) || (collectionName == null) || (restrictionName == null) || (restrictionNumber == null)) + { + throw ADP.MissingRestrictionColumn(); + } + + foreach (DataRow restriction in restrictionsTable.Rows) + { + + if (((string)restriction[collectionName] == neededCollectionName) && + ((int)restriction[restrictionNumber] == neededRestrictionNumber) && + (SupportedByCurrentVersion(restriction))) + { + + result = (string)restriction[parameterName]; + break; + } + } + + if (result == null) + { + throw ADP.MissingRestrictionRow(); + } + + return result; + } + + public virtual DataTable GetSchema(DbConnection connection, string collectionName, string[] restrictions) + { + Debug.Assert(_metaDataCollectionsDataSet != null); + + DataTable metaDataCollectionsTable = _metaDataCollectionsDataSet.Tables[DbMetaDataCollectionNames.MetaDataCollections]; + DataColumn populationMechanismColumn = metaDataCollectionsTable.Columns[PopulationMechanismKey]; + DataColumn collectionNameColumn = metaDataCollectionsTable.Columns[DbMetaDataColumnNames.CollectionName]; + + string[] hiddenColumns; + + DataRow requestedCollectionRow = FindMetaDataCollectionRow(collectionName); + string exactCollectionName = requestedCollectionRow[collectionNameColumn, DataRowVersion.Current] as string; + + if (!ADP.IsEmptyArray(restrictions)) + { + + for (int i = 0; i < restrictions.Length; i++) + { + if ((restrictions[i] != null) && (restrictions[i].Length > 4096)) + { + // use a non-specific error because no new beta 2 error messages are allowed + // TODO: will add a more descriptive error in RTM + throw ADP.NotSupported(); + } + } + } + + string populationMechanism = requestedCollectionRow[populationMechanismColumn, DataRowVersion.Current] as string; + + DataTable requestedSchema; + switch (populationMechanism) + { + + case DataTableKey: + if (exactCollectionName == DbMetaDataCollectionNames.MetaDataCollections) + { + hiddenColumns = new string[2]; + hiddenColumns[0] = PopulationMechanismKey; + hiddenColumns[1] = PopulationStringKey; + } + else + { + hiddenColumns = null; + } + // none of the datatable collections support restrictions + if (!ADP.IsEmptyArray(restrictions)) + { + throw ADP.TooManyRestrictions(exactCollectionName); + } + + + requestedSchema = CloneAndFilterCollection(exactCollectionName, hiddenColumns); + + // TODO: Consider an alternate method that doesn't involve special casing -- perhaps _prepareCollection + + // for the data source information table we need to fix up the version columns at run time + // since the version is determined at run time + if (exactCollectionName == DbMetaDataCollectionNames.DataSourceInformation) + { + FixUpVersion(requestedSchema); + } + break; + + case SqlCommandKey: + requestedSchema = ExecuteCommand(requestedCollectionRow, restrictions, connection); + break; + + case PrepareCollectionKey: + requestedSchema = PrepareCollection(exactCollectionName, restrictions, connection); + break; + + default: + throw ADP.UndefinedPopulationMechanism(populationMechanism); + } + + return requestedSchema; + } + + private bool IncludeThisColumn(DataColumn sourceColumn, string[] hiddenColumnNames) + { + + bool result = true; + string sourceColumnName = sourceColumn.ColumnName; + + switch (sourceColumnName) + { + + case MinimumVersionKey: + case MaximumVersionKey: + result = false; + break; + + default: + if (hiddenColumnNames == null) + { + break; + } + for (int i = 0; i < hiddenColumnNames.Length; i++) + { + if (hiddenColumnNames[i] == sourceColumnName) + { + result = false; + break; + } + } + break; + } + + return result; + } + + private void LoadDataSetFromXml(Stream XmlStream) + { + _metaDataCollectionsDataSet = new DataSet + { + Locale = System.Globalization.CultureInfo.InvariantCulture + }; + _metaDataCollectionsDataSet.ReadXml(XmlStream); + } + + protected virtual DataTable PrepareCollection(string collectionName, string[] restrictions, DbConnection connection) + { + throw ADP.NotSupported(); + } + + private bool SupportedByCurrentVersion(DataRow requestedCollectionRow) + { + bool result = true; + DataColumnCollection tableColumns = requestedCollectionRow.Table.Columns; + DataColumn versionColumn; + object version; + + // check the minimum version first + versionColumn = tableColumns[MinimumVersionKey]; + if (versionColumn != null) + { + version = requestedCollectionRow[versionColumn]; + if (version != null) + { + if (version != DBNull.Value) + { + if (0 > string.Compare(_normalizedServerVersion, (string)version, StringComparison.OrdinalIgnoreCase)) + { + result = false; + } + } + } + } + + // if the minimum version was ok what about the maximum version + if (result) + { + versionColumn = tableColumns[MaximumVersionKey]; + if (versionColumn != null) + { + version = requestedCollectionRow[versionColumn]; + if (version != null) + { + if (version != DBNull.Value) + { + if (0 < string.Compare(_normalizedServerVersion, (string)version, StringComparison.OrdinalIgnoreCase)) + { + result = false; + } + } + } + } + } + return result; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs new file mode 100644 index 0000000000..41f67f9403 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/FieldNameLookup.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data; +using System.Globalization; +using Microsoft.Data.Common; + +namespace Microsoft.Data.ProviderBase +{ + internal sealed class FieldNameLookup + { + private readonly string[] _fieldNames; + private readonly int _defaultLocaleID; + + private Dictionary _fieldNameLookup; + private CompareInfo _compareInfo; + + public FieldNameLookup(string[] fieldNames, int defaultLocaleID) + { + _defaultLocaleID = defaultLocaleID; + if (fieldNames == null) + { + throw ADP.ArgumentNull(nameof(fieldNames)); + } + _fieldNames = fieldNames; + } + + public FieldNameLookup(IDataReader reader, int defaultLocaleID) + { + _defaultLocaleID = defaultLocaleID; + string[] fieldNames = new string[reader.FieldCount]; + for (int i = 0; i < fieldNames.Length; ++i) + { + fieldNames[i] = reader.GetName(i); + } + _fieldNames = fieldNames; + } + + public int GetOrdinal(string fieldName) + { + if (fieldName == null) + { + throw ADP.ArgumentNull(nameof(fieldName)); + } + int index = IndexOf(fieldName); + if (index == -1) + { + throw ADP.IndexOutOfRange(fieldName); + } + return index; + } + + private int IndexOf(string fieldName) + { + if (_fieldNameLookup == null) + { + GenerateLookup(); + } + if (!_fieldNameLookup.TryGetValue(fieldName, out int index)) + { + index = LinearIndexOf(fieldName, CompareOptions.IgnoreCase); + if (index == -1) + { + // do the slow search now (kana, width insensitive comparison) + index = LinearIndexOf(fieldName, ADP.DefaultCompareOptions); + } + } + + return index; + } + + private CompareInfo GetCompareInfo() + { + if (_defaultLocaleID != -1) + { + return CompareInfo.GetCompareInfo(_defaultLocaleID); + } + return CultureInfo.InvariantCulture.CompareInfo; + } + + private int LinearIndexOf(string fieldName, CompareOptions compareOptions) + { + if (_compareInfo == null) + { + _compareInfo = GetCompareInfo(); + } + + for (int index = 0; index < _fieldNames.Length; index++) + { + if (_compareInfo.Compare(fieldName, _fieldNames[index], compareOptions) == 0) + { + _fieldNameLookup[fieldName] = index; + return index; + } + } + return -1; + } + + private void GenerateLookup() + { + int length = _fieldNames.Length; + Dictionary lookup = new Dictionary(length); + + // walk the field names from the end to the beginning so that if a name exists + // multiple times the first (from beginning to end) index of it is stored + // in the hash table + for (int index = length - 1; 0 <= index; --index) + { + string fieldName = _fieldNames[index]; + lookup[fieldName] = index; + } + _fieldNameLookup = lookup; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs new file mode 100644 index 0000000000..9948b223d1 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/ProviderBase/TimeoutTimer.cs @@ -0,0 +1,185 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Data.Common; +using System; +using System.Diagnostics; + +namespace Microsoft.Data.ProviderBase +{ + // Purpose: + // Manages determining and tracking timeouts + // + // Intended use: + // Call StartXXXXTimeout() to get a timer with the given expiration point + // Get remaining time in appropriate format to pass to subsystem timeouts + // Check for timeout via IsExpired for checks in managed code. + // Simply abandon to GC when done. + internal class TimeoutTimer + { + //------------------- + // Fields + //------------------- + private long _timerExpire; + private bool _isInfiniteTimeout; + private long _originalTimerTicks; + + //------------------- + // Timeout-setting methods + //------------------- + + // Get a new timer that will expire in the given number of seconds + // For input, a value of zero seconds indicates infinite timeout + internal static TimeoutTimer StartSecondsTimeout(int seconds) + { + //-------------------- + // Preconditions: None (seconds must conform to SetTimeoutSeconds requirements) + + //-------------------- + // Method body + var timeout = new TimeoutTimer(); + timeout.SetTimeoutSeconds(seconds); + + //--------------------- + // Postconditions + Debug.Assert(timeout != null); // Need a valid timeouttimer if no error + + return timeout; + } + + // Get a new timer that will expire in the given number of milliseconds + // No current need to support infinite milliseconds timeout + internal static TimeoutTimer StartMillisecondsTimeout(long milliseconds) + { + //-------------------- + // Preconditions + Debug.Assert(0 <= milliseconds); + + //-------------------- + // Method body + var timeout = new TimeoutTimer(); + timeout._originalTimerTicks = milliseconds * TimeSpan.TicksPerMillisecond; + timeout._timerExpire = checked(ADP.TimerCurrent() + timeout._originalTimerTicks); + timeout._isInfiniteTimeout = false; + + //--------------------- + // Postconditions + Debug.Assert(timeout != null); // Need a valid timeouttimer if no error + + return timeout; + } + + //------------------- + // Methods for changing timeout + //------------------- + + internal void SetTimeoutSeconds(int seconds) + { + //-------------------- + // Preconditions + Debug.Assert(0 <= seconds || InfiniteTimeout == seconds); // no need to support negative seconds at present + + //-------------------- + // Method body + if (InfiniteTimeout == seconds) + { + _isInfiniteTimeout = true; + } + else + { + // Stash current time + timeout + _originalTimerTicks = ADP.TimerFromSeconds(seconds); + _timerExpire = checked(ADP.TimerCurrent() + _originalTimerTicks); + _isInfiniteTimeout = false; + } + + //--------------------- + // Postconditions:None + } + + // Reset timer to original duration. + internal void Reset() + { + if (InfiniteTimeout == _originalTimerTicks) + { + _isInfiniteTimeout = true; + } + else + { + _timerExpire = checked(ADP.TimerCurrent() + _originalTimerTicks); + _isInfiniteTimeout = false; + } + } + + //------------------- + // Timeout info properties + //------------------- + + // Indicator for infinite timeout when starting a timer + internal static readonly long InfiniteTimeout = 0; + + // Is this timer in an expired state? + internal bool IsExpired + { + get + { + return !IsInfinite && ADP.TimerHasExpired(_timerExpire); + } + } + + // is this an infinite-timeout timer? + internal bool IsInfinite + { + get + { + return _isInfiniteTimeout; + } + } + + // Special accessor for TimerExpire for use when thunking to legacy timeout methods. + internal long LegacyTimerExpire + { + get + { + return (_isInfiniteTimeout) ? long.MaxValue : _timerExpire; + } + } + + // Returns milliseconds remaining trimmed to zero for none remaining + // and long.MaxValue for infinite + // This method should be preferred for internal calculations that are not + // yet common enough to code into the TimeoutTimer class itself. + internal long MillisecondsRemaining + { + get + { + //------------------- + // Preconditions: None + + //------------------- + // Method Body + long milliseconds; + if (_isInfiniteTimeout) + { + milliseconds = long.MaxValue; + } + else + { + milliseconds = ADP.TimerRemainingMilliseconds(_timerExpire); + if (0 > milliseconds) + { + milliseconds = 0; + } + } + + //-------------------- + // Postconditions + Debug.Assert(0 <= milliseconds); // This property guarantees no negative return values + + return milliseconds; + } + } + } +} + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs new file mode 100644 index 0000000000..83ce5085e7 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.Windows.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +using System.Data; +using System.Data.Common; +using Microsoft.Data.SqlClient.Server; + +namespace Microsoft.Data.Sql +{ + /// + public sealed partial class SqlDataSourceEnumerator : DbDataSourceEnumerator + { + private partial DataTable GetDataSourcesInternal() + { +#if NETFRAMEWORK + return SqlDataSourceEnumeratorNativeHelper.GetDataSources(); +#else + return SqlClient.TdsParserStateObjectFactory.UseManagedSNI ? SqlDataSourceEnumeratorManagedHelper.GetDataSources() : SqlDataSourceEnumeratorNativeHelper.GetDataSources(); +#endif + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs new file mode 100644 index 0000000000..e8f7aac29c --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumerator.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +using System; +using System.Data; +using System.Data.Common; + +namespace Microsoft.Data.Sql +{ + /// + public sealed partial class SqlDataSourceEnumerator : DbDataSourceEnumerator + { + private static readonly Lazy s_singletonInstance = new(() => new SqlDataSourceEnumerator()); + + private SqlDataSourceEnumerator() : base(){} + + /// + public static SqlDataSourceEnumerator Instance => s_singletonInstance.Value; + + /// + override public DataTable GetDataSources() => GetDataSourcesInternal(); + + private partial DataTable GetDataSourcesInternal(); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs new file mode 100644 index 0000000000..43be666e0d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorManagedHelper.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +using System.Collections.Generic; +using System.Data; +using Microsoft.Data.Sql; + +namespace Microsoft.Data.SqlClient.Server +{ + /// + /// Provides a mechanism for enumerating all available instances of SQL Server within the local network + /// + internal static class SqlDataSourceEnumeratorManagedHelper + { + /// + /// Provides a mechanism for enumerating all available instances of SQL Server within the local network. + /// + /// DataTable with ServerName,InstanceName,IsClustered and Version + internal static DataTable GetDataSources() + { + // TODO: Implement multicast request besides the implemented broadcast request. + throw new System.NotImplementedException(StringsHelper.net_MethodNotImplementedException); + } + + private static DataTable ParseServerEnumString(string serverInstances) + { + DataTable dataTable = SqlDataSourceEnumeratorUtil.PrepareDataTable(); + DataRow dataRow; + + if (serverInstances.Length == 0) + { + return dataTable; + } + + string[] numOfServerInstances = serverInstances.Split(SqlDataSourceEnumeratorUtil.s_endOfServerInstanceDelimiter_Managed, System.StringSplitOptions.None); + SqlClientEventSource.Log.TryTraceEvent(" Number of recieved server instances are {2}", + nameof(SqlDataSourceEnumeratorManagedHelper), nameof(ParseServerEnumString), numOfServerInstances.Length); + + foreach (string currentServerInstance in numOfServerInstances) + { + Dictionary InstanceDetails = new(); + string[] delimitedKeyValues = currentServerInstance.Split(SqlDataSourceEnumeratorUtil.InstanceKeysDelimiter); + string currentKey = string.Empty; + + for (int keyvalue = 0; keyvalue < delimitedKeyValues.Length; keyvalue++) + { + if (keyvalue % 2 == 0) + { + currentKey = delimitedKeyValues[keyvalue]; + } + else if (currentKey != string.Empty) + { + InstanceDetails.Add(currentKey, delimitedKeyValues[keyvalue]); + } + } + + if (InstanceDetails.Count > 0) + { + dataRow = dataTable.NewRow(); + dataRow[0] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.ServerNameCol) == true ? + InstanceDetails[SqlDataSourceEnumeratorUtil.ServerNameCol] : string.Empty; + dataRow[1] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.InstanceNameCol) == true ? + InstanceDetails[SqlDataSourceEnumeratorUtil.InstanceNameCol] : string.Empty; + dataRow[2] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.IsClusteredCol) == true ? + InstanceDetails[SqlDataSourceEnumeratorUtil.IsClusteredCol] : string.Empty; + dataRow[3] = InstanceDetails.ContainsKey(SqlDataSourceEnumeratorUtil.VersionNameCol) == true ? + InstanceDetails[SqlDataSourceEnumeratorUtil.VersionNameCol] : string.Empty; + + dataTable.Rows.Add(dataRow); + } + } + return dataTable.SetColumnsReadOnly(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs new file mode 100644 index 0000000000..f6ebfc4b8f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs @@ -0,0 +1,179 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Security; +using System.Text; +using Microsoft.Data.Common; +using Microsoft.Data.SqlClient; +using static Microsoft.Data.Sql.SqlDataSourceEnumeratorUtil; + +namespace Microsoft.Data.Sql +{ + /// + /// Provides a mechanism for enumerating all available instances of SQL Server within the local network + /// + internal static class SqlDataSourceEnumeratorNativeHelper + { + /// + /// Retrieves a DataTable containing information about all visible SQL Server instances + /// + /// + internal static DataTable GetDataSources() + { + (new NamedPermissionSet("FullTrust")).Demand(); // SQLBUDT 244304 + char[] buffer = null; + StringBuilder strbldr = new(); + + int bufferSize = 1024; + int readLength = 0; + buffer = new char[bufferSize]; + bool more = true; + bool failure = false; + IntPtr handle = ADP.s_ptrZero; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + long s_timeoutTime = TdsParserStaticMethods.GetTimeoutSeconds(ADP.DefaultCommandTimeout); + RuntimeHelpers.PrepareConstrainedRegions(); + try + { } + finally + { + handle = SNINativeMethodWrapper.SNIServerEnumOpen(); + SqlClientEventSource.Log.TryTraceEvent(" {2} returned handle = {3}.", + nameof(SqlDataSourceEnumeratorNativeHelper), + nameof(GetDataSources), + nameof(SNINativeMethodWrapper.SNIServerEnumOpen), handle); + } + + if (handle != ADP.s_ptrZero) + { + while (more && !TdsParserStaticMethods.TimeoutHasExpired(s_timeoutTime)) + { + readLength = SNINativeMethodWrapper.SNIServerEnumRead(handle, buffer, bufferSize, out more); + + SqlClientEventSource.Log.TryTraceEvent(" {2} returned 'readlength':{3}, and 'more':{4} with 'bufferSize' of {5}", + nameof(SqlDataSourceEnumeratorNativeHelper), + nameof(GetDataSources), + nameof(SNINativeMethodWrapper.SNIServerEnumRead), + readLength, more, bufferSize); + if (readLength > bufferSize) + { + failure = true; + more = false; + } + else if (readLength > 0) + { + strbldr.Append(buffer, 0, readLength); + } + } + } + } + finally + { + if (handle != ADP.s_ptrZero) + { + SNINativeMethodWrapper.SNIServerEnumClose(handle); + SqlClientEventSource.Log.TryTraceEvent(" {2} called.", + nameof(SqlDataSourceEnumeratorNativeHelper), + nameof(GetDataSources), + nameof(SNINativeMethodWrapper.SNIServerEnumClose)); + } + } + + if (failure) + { + Debug.Assert(false, $"{nameof(GetDataSources)}:{nameof(SNINativeMethodWrapper.SNIServerEnumRead)} returned bad length"); + SqlClientEventSource.Log.TryTraceEvent(" {2} returned bad length, requested buffer {3}, received {4}", + nameof(SqlDataSourceEnumeratorNativeHelper), + nameof(GetDataSources), + nameof(SNINativeMethodWrapper.SNIServerEnumRead), + bufferSize, readLength); + + throw ADP.ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, readLength), nameof(readLength)); + } + return ParseServerEnumString(strbldr.ToString()); + } + + private static DataTable ParseServerEnumString(string serverInstances) + { + DataTable dataTable = PrepareDataTable(); + string serverName = null; + string instanceName = null; + string isClustered = null; + string version = null; + string[] serverinstanceslist = serverInstances.Split(EndOfServerInstanceDelimiter_Native); + SqlClientEventSource.Log.TryTraceEvent(" Number of recieved server instances are {2}", + nameof(SqlDataSourceEnumeratorNativeHelper), nameof(ParseServerEnumString), serverinstanceslist.Length); + + // Every row comes in the format "serverName\instanceName;Clustered:[Yes|No];Version:.." + // Every row is terminated by a null character. + // Process one row at a time + foreach (string instance in serverinstanceslist) + { + string value = instance.Trim(EndOfServerInstanceDelimiter_Native); // MDAC 91934 + if (value.Length == 0) + { + continue; + } + foreach (string instance2 in value.Split(InstanceKeysDelimiter)) + { + if (serverName == null) + { + foreach (string instance3 in instance2.Split(ServerNamesAndInstanceDelimiter)) + { + if (serverName == null) + { + serverName = instance3; + continue; + } + Debug.Assert(instanceName == null, $"{nameof(instanceName)}({instanceName}) is not null."); + instanceName = instance3; + } + continue; + } + if (isClustered == null) + { + Debug.Assert(string.Compare(Clustered, 0, instance2, 0, s_clusteredLength, StringComparison.OrdinalIgnoreCase) == 0, + $"{nameof(Clustered)} ({Clustered}) doesn't equal {nameof(instance2)} ({instance2})"); + isClustered = instance2.Substring(s_clusteredLength); + continue; + } + Debug.Assert(version == null, $"{nameof(version)}({version}) is not null."); + Debug.Assert(string.Compare(SqlDataSourceEnumeratorUtil.Version, 0, instance2, 0, s_versionLength, StringComparison.OrdinalIgnoreCase) == 0, + $"{nameof(SqlDataSourceEnumeratorUtil.Version)} ({SqlDataSourceEnumeratorUtil.Version}) doesn't equal {nameof(instance2)} ({instance2})"); + version = instance2.Substring(s_versionLength); + } + + string query = "ServerName='" + serverName + "'"; + + if (!ADP.IsEmpty(instanceName)) + { // SQL BU DT 20006584: only append instanceName if present. + query += " AND InstanceName='" + instanceName + "'"; + } + + // SNI returns dupes - do not add them. SQL BU DT 290323 + if (dataTable.Select(query).Length == 0) + { + DataRow dataRow = dataTable.NewRow(); + dataRow[0] = serverName; + dataRow[1] = instanceName; + dataRow[2] = isClustered; + dataRow[3] = version; + dataTable.Rows.Add(dataRow); + } + serverName = null; + instanceName = null; + isClustered = null; + version = null; + } + return dataTable.SetColumnsReadOnly(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs new file mode 100644 index 0000000000..fb6972d8cf --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlDataSourceEnumeratorUtil.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Data; +using System.Globalization; + +namespace Microsoft.Data.Sql +{ + /// + /// const values for SqlDataSourceEnumerator + /// + internal static class SqlDataSourceEnumeratorUtil + { + internal const string ServerNameCol = "ServerName"; + internal const string InstanceNameCol = "InstanceName"; + internal const string IsClusteredCol = "IsClustered"; + internal const string VersionNameCol = "Version"; + + internal const string Version = "Version:"; + internal const string Clustered = "Clustered:"; + internal static readonly int s_versionLength = Version.Length; + internal static readonly int s_clusteredLength = Clustered.Length; + + internal static readonly string[] s_endOfServerInstanceDelimiter_Managed = new[] { ";;" }; + internal const char EndOfServerInstanceDelimiter_Native = '\0'; + internal const char InstanceKeysDelimiter = ';'; + internal const char ServerNamesAndInstanceDelimiter = '\\'; + + internal static DataTable PrepareDataTable() + { + DataTable dataTable = new("SqlDataSources"); + dataTable.Locale = CultureInfo.InvariantCulture; + dataTable.Columns.Add(ServerNameCol, typeof(string)); + dataTable.Columns.Add(InstanceNameCol, typeof(string)); + dataTable.Columns.Add(IsClusteredCol, typeof(string)); + dataTable.Columns.Add(VersionNameCol, typeof(string)); + + return dataTable; + } + + /// + /// Sets all columns read-only. + /// + internal static DataTable SetColumnsReadOnly(this DataTable dataTable) + { + foreach (DataColumn column in dataTable.Columns) + { + column.ReadOnly = true; + } + return dataTable; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs new file mode 100644 index 0000000000..ccbff8fc0f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/Sql/SqlNotificationRequest.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Data.Common; +using Microsoft.Data.SqlClient; + +namespace Microsoft.Data.Sql +{ + /// + public sealed class SqlNotificationRequest + { + private string _userData; + private string _options; + private int _timeout; + + /// + public SqlNotificationRequest() + : this(null, null, SQL.SqlDependencyTimeoutDefault) { } + + /// + public SqlNotificationRequest(string userData, string options, int timeout) + { + UserData = userData; + Timeout = timeout; + Options = options; + } + + /// + public string Options + { + get + { + return _options; + } + set + { + if ((null != value) && (ushort.MaxValue < value.Length)) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(Options)); + } + _options = value; + } + } + + /// + public int Timeout + { + get + { + return _timeout; + } + set + { + if (0 > value) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(Timeout)); + } + _timeout = value; + } + } + + /// + public string UserData + { + get + { + return _userData; + } + set + { + if ((null != value) && (ushort.MaxValue < value.Length)) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(UserData)); + } + _userData = value; + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs new file mode 100644 index 0000000000..a8fdf219d3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -0,0 +1,516 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Identity; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensibility; + +namespace Microsoft.Data.SqlClient +{ + /// + public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + /// + /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey. + /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode + /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache. + /// + private static ConcurrentDictionary s_pcaMap + = new ConcurrentDictionary(); + private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; + private static readonly string s_defaultScopeSuffix = "/.default"; + private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; + private readonly SqlClientLogger _logger = new SqlClientLogger(); + private Func _deviceCodeFlowCallback; + private ICustomWebUi _customWebUI = null; + private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId; + + /// + public ActiveDirectoryAuthenticationProvider() + : this(DefaultDeviceFlowCallback) + { + } + + /// + public ActiveDirectoryAuthenticationProvider(string applicationClientId) + : this(DefaultDeviceFlowCallback, applicationClientId) + { + } + + /// + public ActiveDirectoryAuthenticationProvider(Func deviceCodeFlowCallbackMethod, string applicationClientId = null) + { + if (applicationClientId != null) + { + _applicationClientId = applicationClientId; + } + SetDeviceCodeFlowCallback(deviceCodeFlowCallbackMethod); + } + + /// + public static void ClearUserTokenCache() + { + if (!s_pcaMap.IsEmpty) + { + s_pcaMap.Clear(); + } + } + + /// + public void SetDeviceCodeFlowCallback(Func deviceCodeFlowCallbackMethod) => _deviceCodeFlowCallback = deviceCodeFlowCallbackMethod; + + /// + public void SetAcquireAuthorizationCodeAsyncCallback(Func> acquireAuthorizationCodeAsyncCallback) => _customWebUI = new CustomWebUi(acquireAuthorizationCodeAsyncCallback); + + /// + public override bool IsSupported(SqlAuthenticationMethod authentication) + { + return authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated + || authentication == SqlAuthenticationMethod.ActiveDirectoryPassword + || authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive + || authentication == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal + || authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow + || authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity + || authentication == SqlAuthenticationMethod.ActiveDirectoryMSI + || authentication == SqlAuthenticationMethod.ActiveDirectoryDefault; + } + + /// + public override void BeforeLoad(SqlAuthenticationMethod authentication) + { + _logger.LogInfo(_type, "BeforeLoad", $"being loaded into SqlAuthProviders for {authentication}."); + } + + /// + public override void BeforeUnload(SqlAuthenticationMethod authentication) + { + _logger.LogInfo(_type, "BeforeUnload", $"being unloaded from SqlAuthProviders for {authentication}."); + } + +#if NETSTANDARD + private Func _parentActivityOrWindowFunc = null; + + /// + public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc; +#endif + +#if NETFRAMEWORK + private Func _iWin32WindowFunc = null; + + /// + public void SetIWin32WindowFunc(Func iWin32WindowFunc) => this._iWin32WindowFunc = iWin32WindowFunc; +#endif + + /// + public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters) + { + CancellationTokenSource cts = new CancellationTokenSource(); + + // Use Connection timeout value to cancel token acquire request after certain period of time. + cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds + + string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix; + string[] scopes = new string[] { scope }; + TokenRequestContext tokenRequestContext = new(scopes); + + /* We split audience from Authority URL here. Audience can be one of the following: + * The Azure AD authority audience enumeration + * The tenant ID, which can be: + * - A GUID (the ID of your Azure AD instance), for single-tenant applications + * - A domain name associated with your Azure AD instance (also for single-tenant applications) + * One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration: + * - `organizations` for a multitenant application + * - `consumers` to sign in users only with their personal accounts + * - `common` to sign in users with their work and school accounts or their personal Microsoft accounts + * + * MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID. + * If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.) + * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration + **/ + + int seperatorIndex = parameters.Authority.LastIndexOf('/'); + string authority = parameters.Authority.Remove(seperatorIndex + 1); + string audience = parameters.Authority.Substring(seperatorIndex + 1); + string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; + + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) + { + DefaultAzureCredentialOptions defaultAzureCredentialOptions = new() + { + AuthorityHost = new Uri(authority), + SharedTokenCacheTenantId = audience, + VisualStudioCodeTenantId = audience, + VisualStudioTenantId = audience, + ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications. + }; + + // Optionally set clientId when available + if (clientId is not null) + { + defaultAzureCredentialOptions.ManagedIdentityClientId = clientId; + defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId; + } + AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions() { AuthorityHost = new Uri(authority) }; + + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) + { + AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + AuthenticationResult result; + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) + { + AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + /* + * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows + * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend + * that you use https://login.microsoftonline.com/common/oauth2/nativeclient. + * + * https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris + */ + string redirectUri = s_nativeClientRedirectUri; + +#if NETCOREAPP + if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + redirectUri = "http://localhost"; + } +#endif + PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId +#if NETFRAMEWORK + , _iWin32WindowFunc +#endif +#if NETSTANDARD + , _parentActivityOrWindowFunc +#endif + ); + + IPublicClientApplication app = GetPublicClientAppInstance(pcaKey); + + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + else + { + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) + { + SecureString password = new SecureString(); + foreach (char c in parameters.Password) + password.AppendChar(c); + password.MakeReadOnly(); + + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + // Fetch available accounts from 'app' instance + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + + IAccount account = default; + if (accounts.MoveNext()) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + do + { + IAccount currentVal = accounts.Current; + if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) + { + account = currentVal; + break; + } + } + while (accounts.MoveNext()); + } + else + { + account = accounts.Current; + } + } + + if (null != account) + { + try + { + // If 'account' is available in 'app', we use the same to acquire token silently. + // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + } + else + { + // If no existing 'account' is found, we request user to sign in interactively. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + } + else + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); + throw SQL.UnsupportedAuthenticationSpecified(parameters.AuthenticationMethod); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } + + private async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, + SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts) + { + try + { + if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) + { + CancellationTokenSource ctsInteractive = new CancellationTokenSource(); +#if NETCOREAPP + /* + * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser, + * but once the user finishes authentication, the web page is redirected in such a way that MSAL can intercept the Uri. + * MSAL cannot detect if the user navigates away or simply closes the browser. Apps using this technique are encouraged + * to define a timeout (via CancellationToken). We recommend a timeout of at least a few minutes, to take into account + * cases where the user is prompted to change password or perform 2FA. + * + * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/System-Browser-on-.Net-Core#system-browser-experience + */ + ctsInteractive.CancelAfter(180000); +#endif + if (_customWebUI != null) + { + return await app.AcquireTokenInteractive(scopes) + .WithCorrelationId(connectionId) + .WithCustomWebUi(_customWebUI) + .WithLoginHint(userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + } + else + { + /* + * We will use the MSAL Embedded or System web browser which changes by Default in MSAL according to this table: + * + * Framework Embedded System Default + * ------------------------------------------- + * .NET Classic Yes Yes^ Embedded + * .NET Core No Yes^ System + * .NET Standard No No NONE + * UWP Yes No Embedded + * Xamarin.Android Yes Yes System + * Xamarin.iOS Yes Yes System + * Xamarin.Mac Yes No Embedded + * + * ^ Requires "http://localhost" redirect URI + * + * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance + */ + return await app.AcquireTokenInteractive(scopes) + .WithCorrelationId(connectionId) + .WithLoginHint(userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + } + } + else + { + AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, + deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)) + .WithCorrelationId(connectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + return result; + } + } + catch (OperationCanceledException) + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation timed out while acquiring access token."); + throw (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? + SQL.ActiveDirectoryInteractiveTimeout() : + SQL.ActiveDirectoryDeviceFlowTimeout(); + } + } + + private static Task DefaultDeviceFlowCallback(DeviceCodeResult result) + { + // This will print the message on the console which tells the user where to go sign-in using + // a separate browser and the code to enter once they sign in. + // The AcquireTokenWithDeviceCode() method will poll the server after firing this + // device code callback to look for the successful login of the user via that browser. + // This background polling (whose interval and timeout data is also provided as fields in the + // deviceCodeCallback class) will occur until: + // * The user has successfully logged in via browser and entered the proper code + // * The timeout specified by the server for the lifetime of this code (typically ~15 minutes) has been reached + // * The developing application calls the Cancel() method on a CancellationToken sent into the method. + // If this occurs, an OperationCanceledException will be thrown (see catch below for more details). + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Callback triggered with Device Code Result: {0}", result.Message); + Console.WriteLine(result.Message); + return Task.FromResult(0); + } + + private class CustomWebUi : ICustomWebUi + { + private readonly Func> _acquireAuthorizationCodeAsyncCallback; + + internal CustomWebUi(Func> acquireAuthorizationCodeAsyncCallback) => _acquireAuthorizationCodeAsyncCallback = acquireAuthorizationCodeAsyncCallback; + + public Task AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) + => _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken); + } + + private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey) + { + if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance)) + { + clientApplicationInstance = CreateClientAppInstance(publicClientAppKey); + s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance); + } + return clientApplicationInstance; + } + + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) + { + IPublicClientApplication publicClientApplication; + +#if NETSTANDARD + if (_parentActivityOrWindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_parentActivityOrWindowFunc) + .Build(); + } +#endif +#if NETFRAMEWORK + if (_iWin32WindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_iWin32WindowFunc) + .Build(); + } +#endif +#if !NETCOREAPP + else +#endif + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .Build(); + } + + return publicClientApplication; + } + + internal class PublicClientAppKey + { + public readonly string _authority; + public readonly string _redirectUri; + public readonly string _applicationClientId; +#if NETFRAMEWORK + public readonly Func _iWin32WindowFunc; +#endif +#if NETSTANDARD + public readonly Func _parentActivityOrWindowFunc; +#endif + + public PublicClientAppKey(string authority, string redirectUri, string applicationClientId +#if NETFRAMEWORK + , Func iWin32WindowFunc +#endif +#if NETSTANDARD + , Func parentActivityOrWindowFunc +#endif + ) + { + _authority = authority; + _redirectUri = redirectUri; + _applicationClientId = applicationClientId; +#if NETFRAMEWORK + _iWin32WindowFunc = iWin32WindowFunc; +#endif +#if NETSTANDARD + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; +#endif + } + + public override bool Equals(object obj) + { + if (obj != null && obj is PublicClientAppKey pcaKey) + { + return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 + && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 + && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 +#if NETFRAMEWORK + && pcaKey._iWin32WindowFunc == _iWin32WindowFunc +#endif +#if NETSTANDARD + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc +#endif + ); + } + return false; + } + + public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId +#if NETFRAMEWORK + , _iWin32WindowFunc +#endif +#if NETSTANDARD + , _parentActivityOrWindowFunc +#endif + ).GetHashCode(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationTimeoutRetryHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationTimeoutRetryHelper.cs new file mode 100644 index 0000000000..473b3b638a --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ActiveDirectoryAuthenticationTimeoutRetryHelper.cs @@ -0,0 +1,139 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; +using System.Security.Cryptography; +using System.Text; + +namespace Microsoft.Data.SqlClient +{ + /// + /// AD auth retry states. + /// + internal enum ActiveDirectoryAuthenticationTimeoutRetryState + { + NotStarted = 0, + Retrying, + HasLoggedIn, + } + + /// + /// AD auth retry helper. + /// + internal class ActiveDirectoryAuthenticationTimeoutRetryHelper + { + private ActiveDirectoryAuthenticationTimeoutRetryState _state = ActiveDirectoryAuthenticationTimeoutRetryState.NotStarted; + private SqlFedAuthToken _token; + private readonly string _typeName; + private readonly SqlClientLogger _sqlAuthLogger = new SqlClientLogger(); + + /// + /// Constructor. + /// + public ActiveDirectoryAuthenticationTimeoutRetryHelper() + { + _typeName = GetType().Name; + } + + /// + /// Retry state. + /// + public ActiveDirectoryAuthenticationTimeoutRetryState State + { + get { return _state; } + set + { + switch (_state) + { + case ActiveDirectoryAuthenticationTimeoutRetryState.NotStarted: + if (value != ActiveDirectoryAuthenticationTimeoutRetryState.Retrying + && value != ActiveDirectoryAuthenticationTimeoutRetryState.HasLoggedIn) + { + throw new InvalidOperationException($"Cannot transit from {_state} to {value}."); + } + break; + case ActiveDirectoryAuthenticationTimeoutRetryState.Retrying: + if (value != ActiveDirectoryAuthenticationTimeoutRetryState.HasLoggedIn) + { + throw new InvalidOperationException($"Cannot transit from {_state} to {value}."); + } + break; + case ActiveDirectoryAuthenticationTimeoutRetryState.HasLoggedIn: + throw new InvalidOperationException($"Cannot transit from {_state} to {value}."); + default: + throw new InvalidOperationException($"Unsupported state: {value}."); + } + if (_sqlAuthLogger.IsLoggingEnabled) + { + _sqlAuthLogger.LogInfo(_typeName, "SetState", $"State changed from {_state} to {value}."); + } + _state = value; + } + } + + /// + /// Cached token. + /// + public SqlFedAuthToken CachedToken + { + get + { + if (_sqlAuthLogger.IsLoggingEnabled) + { + _sqlAuthLogger.LogInfo(_typeName, "GetCachedToken", $"Retrieved cached token {GetTokenHash(_token)}."); + } + return _token; + } + set + { + if (_sqlAuthLogger.IsLoggingEnabled) + { + _sqlAuthLogger.LogInfo(_typeName, "SetCachedToken", $"CachedToken changed from {GetTokenHash(_token)} to {GetTokenHash(value)}."); + } + _token = value; + } + } + + /// + /// Whether login can be retried after a client/server connection timeout due to a long-time token acquisition. + /// + public bool CanRetryWithSqlException(SqlException sqlex) + { + var methodName = "CheckCanRetry"; + if (_sqlAuthLogger.LogAssert(_state == ActiveDirectoryAuthenticationTimeoutRetryState.NotStarted, _typeName, methodName, $"Cannot retry due to state == {_state}.") + && _sqlAuthLogger.LogAssert(CachedToken != null, _typeName, methodName, $"Cannot retry when cached token is null.") + && _sqlAuthLogger.LogAssert(IsConnectTimeoutError(sqlex), _typeName, methodName, $"Cannot retry when exception is not timeout.")) + { + _sqlAuthLogger.LogInfo(_typeName, methodName, "All checks passed."); + return true; + } + return false; + } + + private static bool IsConnectTimeoutError(SqlException sqlex) + { + var innerException = sqlex.InnerException as Win32Exception; + if (innerException == null) + return false; + return innerException.NativeErrorCode == 10054 // Server timeout + || innerException.NativeErrorCode == 258; // Client timeout + } + + private static string GetTokenHash(SqlFedAuthToken token) + { + if (token == null) + return "null"; + + // Here we mimic how ADAL calculates hash for token. They use UTF8 instead of Unicode. + var originalTokenString = SqlAuthenticationToken.AccessTokenStringFromBytes(token.accessToken); + var bytesInUtf8 = Encoding.UTF8.GetBytes(originalTokenString); + using (var sha256 = SHA256.Create()) + { + var hash = sha256.ComputeHash(bytesInUtf8); + return Convert.ToBase64String(hash); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs new file mode 100644 index 0000000000..3263031dde --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Data.SqlClient +{ + internal class EnclavePublicKey + { + public byte[] PublicKey { get; set; } + + public EnclavePublicKey(byte[] payload) + { + PublicKey = payload; + } + } + + internal class EnclaveDiffieHellmanInfo + { + public int Size { get; private set; } + + public byte[] PublicKey { get; private set; } + + public byte[] PublicKeySignature { get; private set; } + + public EnclaveDiffieHellmanInfo(byte[] payload) + { + Size = payload.Length; + + int publicKeySize = BitConverter.ToInt32(payload, 0); + int publicKeySignatureSize = BitConverter.ToInt32(payload, 4); + + PublicKey = new byte[publicKeySize]; + PublicKeySignature = new byte[publicKeySignatureSize]; + Buffer.BlockCopy(payload, 8, PublicKey, 0, publicKeySize); + Buffer.BlockCopy(payload, 8 + publicKeySize, PublicKeySignature, 0, publicKeySignatureSize); + } + } + + internal enum EnclaveType + { + None = 0, + /// + /// Virtualization Based Security + /// + Vbs = 1, + /// + /// Intel SGX based security + /// + Sgx = 2 + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ApplicationIntent.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ApplicationIntent.cs new file mode 100644 index 0000000000..4e8bf1afa5 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ApplicationIntent.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient +{ + /// +#if NETFRAMEWORK + [System.Serializable] +#endif + public enum ApplicationIntent + { + /// + ReadWrite = 0, + + /// + ReadOnly = 1, + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AssemblyRef.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AssemblyRef.cs new file mode 100644 index 0000000000..b48251e9cf --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AssemblyRef.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +//------------------------------------------------------------------------------ +// This code was auto-generated by msbuild target. +// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated. +//------------------------------------------------------------------------------ + +namespace Microsoft.Data.SqlClient +{ + internal static class AssemblyRef + { + // NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available. + // New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future. + //internal const string MicrosoftVSDesigner = "Microsoft.VSDesigner, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"; + internal const string SystemDrawing = "System.Drawing, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"; + internal const string EcmaPublicKey = "b77a5c561934e089"; + internal const string EcmaPublicKeyFull = "00000000000000000400000000000000"; + internal const string SystemDesign = "System.Design, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a"; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs new file mode 100644 index 0000000000..433347ef10 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs @@ -0,0 +1,544 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.IdentityModel.Tokens.Jwt; +using System.Linq; +using System.Runtime.Caching; +using System.Security.Claims; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using Microsoft.IdentityModel.JsonWebTokens; +using Microsoft.IdentityModel.Logging; +using Microsoft.IdentityModel.Protocols; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; +using Microsoft.IdentityModel.Tokens; + +// Azure Attestation Protocol Flow +// To start the attestation process, Sql Client sends the Protocol Id (i.e. 1), Nonce, Attestation Url and ECDH Public Key +// Sql Server uses attestation Url to attest the enclave and send the JWT to Sql client. +// Along with JWT, Sql server also sends enclave RSA public key, enclave Type, Enclave ECDH Public key. + +// To verify the chain of trust here is how it works +// JWT is signed by well-known signing keys which Sql client can download over https (via OpenIdConnect protocol). +// JWT contains the Enclave public key to safeguard against spoofing enclave RSA public key. +// Enclave ECDH public key signed by enclave RSA key + +// JWT validation +// To get the signing key for the JWT, we use OpenIdConnect API's. It download the signing keys from the well-known endpoint. +// We validate that JWT is signed, valid (i.e. not expired) and check the Issuer. + +// Claim validation: +// Validate the RSA public key send by Sql server matches the value specified in JWT. + +// Enclave Specific checks +// VSM +// Validate the nonce send by Sql client during start of attestation is same as that of specified in the JWT + +// SGX +// JWT for SGX enclave does not contain nonce claim. To workaround this limitation Sql Server sends the RSA public key XOR with the Nonce. +// In Sql server tempered with the nonce value then both Sql Server and client will not able to compute the same shared secret. + +namespace Microsoft.Data.SqlClient +{ + // Implementation of an Enclave provider (both for Sgx and Vsm) with Azure Attestation + internal class AzureAttestationEnclaveProvider : EnclaveProviderBase + { + #region Constants + private const int DiffieHellmanKeySize = 384; + private const int AzureBasedAttestationProtocolId = (int)SqlConnectionAttestationProtocol.AAS; + private const int SigningKeyRetryInSec = 3; + #endregion + + #region Members + // this is meta data endpoint for AAS provided by Windows team + // i.e. https:///.well-known/openid-configuration + // such as https://sql.azure.attest.com/.well-known/openid-configuration + private const string AttestationUrlSuffix = @"/.well-known/openid-configuration"; + + private static readonly MemoryCache OpenIdConnectConfigurationCache = new MemoryCache("OpenIdConnectConfigurationCache"); + #endregion + + #region Internal methods + // When overridden in a derived class, looks up an existing enclave session information in the enclave session cache. + // If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter. + internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) + { + GetEnclaveSessionHelper(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength); + } + + // Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave. + internal override SqlEnclaveAttestationParameters GetAttestationParameters(string attestationUrl, byte[] customData, int customDataLength) + { + ECDiffieHellman clientDHKey = KeyConverter.CreateECDiffieHellman(DiffieHellmanKeySize); + byte[] attestationParam = PrepareAttestationParameters(attestationUrl, customData, customDataLength); + return new SqlEnclaveAttestationParameters(AzureBasedAttestationProtocolId, attestationParam, clientDHKey); + } + + // When overridden in a derived class, performs enclave attestation, generates a symmetric key for the session, creates a an enclave session and stores the session information in the cache. + internal override void CreateEnclaveSession(byte[] attestationInfo, ECDiffieHellman clientDHKey, EnclaveSessionParameters enclaveSessionParameters, byte[] customData, int customDataLength, out SqlEnclaveSession sqlEnclaveSession, out long counter) + { + sqlEnclaveSession = null; + counter = 0; + try + { + ThreadRetryCache.Remove(Thread.CurrentThread.ManagedThreadId.ToString()); + sqlEnclaveSession = GetEnclaveSessionFromCache(enclaveSessionParameters, out counter); + if (sqlEnclaveSession == null) + { + if (!string.IsNullOrEmpty(enclaveSessionParameters.AttestationUrl) && customData != null && customDataLength > 0) + { + byte[] nonce = customData; + + IdentityModelEventSource.ShowPII = true; + + // Deserialize the payload + AzureAttestationInfo attestInfo = new AzureAttestationInfo(attestationInfo); + + // Validate the attestation info + VerifyAzureAttestationInfo(enclaveSessionParameters.AttestationUrl, attestInfo.EnclaveType, attestInfo.AttestationToken.AttestationToken, attestInfo.Identity, nonce); + + // Set up shared secret and validate signature + byte[] sharedSecret = GetSharedSecret(attestInfo.Identity, nonce, attestInfo.EnclaveType, attestInfo.EnclaveDHInfo, clientDHKey); + + // add session to cache + sqlEnclaveSession = AddEnclaveSessionToCache(enclaveSessionParameters, sharedSecret, attestInfo.SessionId, out counter); + } + else + { + throw SQL.AttestationFailed(Strings.FailToCreateEnclaveSession); + } + } + } + finally + { + // As per current design, we want to minimize the number of create session calls. To achieve this we block all the GetEnclaveSession calls until the first call to + // GetEnclaveSession -> GetAttestationParameters -> CreateEnclaveSession completes or the event timeout happen. + // Case 1: When the first request successfully creates the session, then all outstanding GetEnclaveSession will use the current session. + // Case 2: When the first request unable to create the enclave session (may be due to some error or the first request doesn't require enclave computation) then in those case we set the event timeout to 0. + UpdateEnclaveSessionLockStatus(sqlEnclaveSession); + } + } + + // When overridden in a derived class, looks up and evicts an enclave session from the enclave session cache, if the provider implements session caching. + internal override void InvalidateEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, SqlEnclaveSession enclaveSessionToInvalidate) + { + InvalidateEnclaveSessionHelper(enclaveSessionParameters, enclaveSessionToInvalidate); + } + #endregion + + #region Internal Class + + // A model class representing the deserialization of the byte payload the client + // receives from SQL Server while setting up a session. + // Protocol format: + // 1. Total Size of the attestation blob as UINT + // 2. Size of Enclave RSA public key as UINT + // 3. Size of Attestation token as UINT + // 4. Enclave Type as UINT + // 5. Enclave RSA public key (raw key, of length #2) + // 6. Attestation token (of length #3) + // 7. Size of Session Id was UINT + // 8. Session id value + // 9. Size of enclave ECDH public key + // 10. Enclave ECDH public key (of length #9) + internal class AzureAttestationInfo + { + public uint TotalSize { get; set; } + + // The enclave's RSA Public Key. + // Needed to establish trust of the enclave. + // Used to verify the enclave's DiffieHellman info. + public EnclavePublicKey Identity { get; set; } + + // The enclave report from the SQL Server host's enclave. + public AzureAttestationToken AttestationToken { get; set; } + + // The id of the current session. + // Needed to set up a secure session between the client and enclave. + public long SessionId { get; set; } + + public EnclaveType EnclaveType { get; set; } + + // The DiffieHellman public key and signature of SQL Server host's enclave. + // Needed to set up a secure session between the client and enclave. + public EnclaveDiffieHellmanInfo EnclaveDHInfo { get; set; } + + public AzureAttestationInfo(byte[] attestationInfo) + { + try + { + int offset = 0; + + // Total size of the attestation info buffer + TotalSize = BitConverter.ToUInt32(attestationInfo, offset); + offset += sizeof(uint); + + // Size of the Enclave public key + int identitySize = BitConverter.ToInt32(attestationInfo, offset); + offset += sizeof(uint); + + // Size of the Azure attestation token + int attestationTokenSize = BitConverter.ToInt32(attestationInfo, offset); + offset += sizeof(uint); + + // Enclave type + int enclaveType = BitConverter.ToInt32(attestationInfo, offset); + EnclaveType = (EnclaveType)enclaveType; + offset += sizeof(uint); + + // Get the enclave public key + byte[] identityBuffer = attestationInfo.Skip(offset).Take(identitySize).ToArray(); + Identity = new EnclavePublicKey(identityBuffer); + offset += identitySize; + + // Get Azure attestation token + byte[] attestationTokenBuffer = attestationInfo.Skip(offset).Take(attestationTokenSize).ToArray(); + AttestationToken = new AzureAttestationToken(attestationTokenBuffer); + offset += attestationTokenSize; + + uint secureSessionInfoResponseSize = BitConverter.ToUInt32(attestationInfo, offset); + offset += sizeof(uint); + + SessionId = BitConverter.ToInt64(attestationInfo, offset); + offset += sizeof(long); + + int secureSessionBufferSize = Convert.ToInt32(secureSessionInfoResponseSize) - sizeof(uint); + byte[] secureSessionBuffer = attestationInfo.Skip(offset).Take(secureSessionBufferSize).ToArray(); + EnclaveDHInfo = new EnclaveDiffieHellmanInfo(secureSessionBuffer); + offset += Convert.ToInt32(EnclaveDHInfo.Size); + } + catch (Exception exception) + { + throw SQL.AttestationFailed(string.Format(Strings.FailToParseAttestationInfo, exception.Message)); + } + } + } + + // A managed model representing the output of EnclaveGetAttestationReport + // https://msdn.microsoft.com/en-us/library/windows/desktop/mt844233(v=vs.85).aspx + internal class AzureAttestationToken + { + public string AttestationToken { get; set; } + + public AzureAttestationToken(byte[] payload) + { + string jwt = System.Text.Encoding.Default.GetString(payload); + AttestationToken = jwt.Trim().Trim('"'); + } + } + #endregion Internal Class + + #region Private helpers + // Prepare the attestation data in following format + // Attestation Url length + // Attestation Url + // Size of nonce + // Nonce value + internal byte[] PrepareAttestationParameters(string attestationUrl, byte[] attestNonce, int attestNonceLength) + { + if (!string.IsNullOrEmpty(attestationUrl) && attestNonce != null && attestNonceLength > 0) + { + // In c# strings are not null terminated, so adding the null termination before serializing it + string attestationUrlLocal = attestationUrl + char.MinValue; + byte[] serializedAttestationUrl = Encoding.Unicode.GetBytes(attestationUrlLocal); + byte[] serializedAttestationUrlLength = BitConverter.GetBytes(serializedAttestationUrl.Length); + + // serializing nonce + byte[] serializedNonce = attestNonce; + byte[] serializedNonceLength = BitConverter.GetBytes(attestNonceLength); + + // Computing the total length of the data + int totalDataSize = serializedAttestationUrl.Length + serializedAttestationUrlLength.Length + serializedNonce.Length + serializedNonceLength.Length; + + int dataCopied = 0; + byte[] attestationParam = new byte[totalDataSize]; + + // copy the attestation url and url length + Buffer.BlockCopy(serializedAttestationUrlLength, 0, attestationParam, dataCopied, serializedAttestationUrlLength.Length); + dataCopied += serializedAttestationUrlLength.Length; + + Buffer.BlockCopy(serializedAttestationUrl, 0, attestationParam, dataCopied, serializedAttestationUrl.Length); + dataCopied += serializedAttestationUrl.Length; + + // copy the nonce and nonce length + Buffer.BlockCopy(serializedNonceLength, 0, attestationParam, dataCopied, serializedNonceLength.Length); + dataCopied += serializedNonceLength.Length; + + Buffer.BlockCopy(serializedNonce, 0, attestationParam, dataCopied, serializedNonce.Length); + dataCopied += serializedNonce.Length; + + return attestationParam; + } + else + { + throw SQL.AttestationFailed(Strings.FailToCreateEnclaveSession); + } + } + + // Performs Attestation per the protocol used by Azure Attestation Service + private void VerifyAzureAttestationInfo(string attestationUrl, EnclaveType enclaveType, string attestationToken, EnclavePublicKey enclavePublicKey, byte[] nonce) + { + bool shouldForceUpdateSigningKeys = false; + string attestationInstanceUrl = GetAttestationInstanceUrl(attestationUrl); + + bool shouldRetryValidation; + bool isSignatureValid; + string exceptionMessage = string.Empty; + do + { + shouldRetryValidation = false; + + // Get the OpenId config object for the signing keys + OpenIdConnectConfiguration openIdConfig = GetOpenIdConfigForSigningKeys(attestationInstanceUrl, shouldForceUpdateSigningKeys); + + // Verify the token signature against the signing keys downloaded from meta data end point + bool isKeySigningExpired; + isSignatureValid = VerifyTokenSignature(attestationToken, attestationInstanceUrl, openIdConfig.SigningKeys, out isKeySigningExpired, out exceptionMessage); + + // In cases if we fail to validate the token, since we are using the old signing keys + // let's re-download the signing keys again and re-validate the token signature + if (!isSignatureValid && isKeySigningExpired && !shouldForceUpdateSigningKeys) + { + shouldForceUpdateSigningKeys = true; + shouldRetryValidation = true; + } + } + while (shouldRetryValidation); + + if (!isSignatureValid) + { + throw SQL.AttestationFailed(string.Format(Strings.AttestationTokenSignatureValidationFailed, exceptionMessage)); + } + + // Validate claims in the token + ValidateAttestationClaims(enclaveType, attestationToken, enclavePublicKey, nonce); + } + + // Returns the innermost exception value + private static string GetInnerMostExceptionMessage(Exception exception) + { + Exception exLocal = exception; + while (exLocal.InnerException != null) + { + exLocal = exLocal.InnerException; + } + + return exLocal.Message; + } + + // For the given attestation url it downloads the token signing keys from the well-known openid configuration end point. + // It also caches that information for 1 day to avoid DDOS attacks. + private OpenIdConnectConfiguration GetOpenIdConfigForSigningKeys(string url, bool forceUpdate) + { + OpenIdConnectConfiguration openIdConnectConfig = OpenIdConnectConfigurationCache[url] as OpenIdConnectConfiguration; + if (forceUpdate || openIdConnectConfig == null) + { + // Compute the meta data endpoint + string openIdMetadataEndpoint = url + AttestationUrlSuffix; + + try + { + IConfigurationManager configurationManager = new ConfigurationManager(openIdMetadataEndpoint, new OpenIdConnectConfigurationRetriever()); + openIdConnectConfig = configurationManager.GetConfigurationAsync(CancellationToken.None).Result; + } + catch (Exception exception) + { + throw SQL.AttestationFailed(string.Format(Strings.GetAttestationTokenSigningKeysFailed, GetInnerMostExceptionMessage(exception)), exception); + } + + OpenIdConnectConfigurationCache.Add(url, openIdConnectConfig, DateTime.UtcNow.AddDays(1)); + } + + return openIdConnectConfig; + } + + // Return the attestation instance url for given attestation url + // such as for https://sql.azure.attest.com/attest/SgxEnclave?api-version=2017-11-01 + // It will return https://sql.azure.attest.com + private string GetAttestationInstanceUrl(string attestationUrl) + { + Uri attestationUri = new Uri(attestationUrl); + return attestationUri.GetLeftPart(UriPartial.Authority); + } + + // Generate the list of valid issuer Url's (in case if tokenIssuerUrl is using default port) + private static ICollection GenerateListOfIssuers(string tokenIssuerUrl) + { + List issuerUrls = new List(); + + Uri tokenIssuerUri = new Uri(tokenIssuerUrl); + int port = tokenIssuerUri.Port; + bool isDefaultPort = tokenIssuerUri.IsDefaultPort; + + string issuerUrl = tokenIssuerUri.GetLeftPart(UriPartial.Authority); + issuerUrls.Add(issuerUrl); + + if (isDefaultPort) + { + issuerUrls.Add(string.Concat(issuerUrl, ":", port.ToString())); + } + + return issuerUrls; + } + + // Verifies the attestation token is signed by correct signing keys. + private bool VerifyTokenSignature(string attestationToken, string tokenIssuerUrl, ICollection issuerSigningKeys, out bool isKeySigningExpired, out string exceptionMessage) + { + exceptionMessage = string.Empty; + bool isSignatureValid = false; + isKeySigningExpired = false; + + // Configure the TokenValidationParameters + TokenValidationParameters validationParameters = + new TokenValidationParameters + { + RequireExpirationTime = true, + ValidateLifetime = true, + ValidateIssuer = true, + ValidateAudience = false, + RequireSignedTokens = true, + ValidIssuers = GenerateListOfIssuers(tokenIssuerUrl), + IssuerSigningKeys = issuerSigningKeys + }; + + try + { + SecurityToken validatedToken; + JwtSecurityTokenHandler handler = new JwtSecurityTokenHandler(); + var token = handler.ValidateToken(attestationToken, validationParameters, out validatedToken); + isSignatureValid = true; + } + catch (SecurityTokenExpiredException securityException) + { + throw SQL.AttestationFailed(Strings.ExpiredAttestationToken, securityException); + } + catch (SecurityTokenValidationException securityTokenException) + { + isKeySigningExpired = true; + + // Sleep for SigningKeyRetryInSec sec before retrying to download the signing keys again. + Thread.Sleep(SigningKeyRetryInSec * 1000); + exceptionMessage = GetInnerMostExceptionMessage(securityTokenException); + } + catch (Exception exception) + { + throw SQL.AttestationFailed(string.Format(Strings.InvalidAttestationToken, GetInnerMostExceptionMessage(exception))); + } + + return isSignatureValid; + } + + // Computes the SHA256 hash of the byte array + private byte[] ComputeSHA256(byte[] data) + { + byte[] result = null; + try + { + using (SHA256 sha256 = SHA256.Create()) + { + result = sha256.ComputeHash(data); + } + } + catch (Exception argumentException) + { + throw SQL.AttestationFailed(Strings.InvalidArgumentToSHA256, argumentException); + } + return result; + } + + // Validate the claims in the attestation token + private void ValidateAttestationClaims(EnclaveType enclaveType, string attestationToken, EnclavePublicKey enclavePublicKey, byte[] nonce) + { + // Read the json token + JsonWebToken token; + try + { + JsonWebTokenHandler tokenHandler = new JsonWebTokenHandler(); + token = tokenHandler.ReadJsonWebToken(attestationToken); + } + catch (ArgumentException argumentException) + { + throw SQL.AttestationFailed(string.Format(Strings.FailToParseAttestationToken, argumentException.Message)); + } + + // Get all the claims from the token + Dictionary claims = new Dictionary(); + foreach (Claim claim in token.Claims.ToList()) + { + claims.Add(claim.Type, claim.Value); + } + + // Get Enclave held data claim and validate it with the Base64UrlEncode(enclave public key) + ValidateClaim(claims, "aas-ehd", enclavePublicKey.PublicKey); + + if (enclaveType == EnclaveType.Vbs) + { + // Get rp_data claim and validate it with the Base64UrlEncode(nonce) + ValidateClaim(claims, "rp_data", nonce); + } + } + + // Validate the claim value against the actual data + private void ValidateClaim(Dictionary claims, string claimName, byte[] actualData) + { + // Get required claim data + string claimData; + bool hasClaim = claims.TryGetValue(claimName, out claimData); + if (!hasClaim) + { + throw SQL.AttestationFailed(string.Format(Strings.MissingClaimInAttestationToken, claimName)); + } + + // Get the Base64Url of the actual data and compare it with claim + string encodedActualData = string.Empty; + try + { + encodedActualData = Base64UrlEncoder.Encode(actualData); + } + catch (Exception) + { + throw SQL.AttestationFailed(Strings.InvalidArgumentToBase64UrlDecoder); + } + + bool hasValidClaim = string.Equals(encodedActualData, claimData, StringComparison.Ordinal); + if (!hasValidClaim) + { + throw SQL.AttestationFailed(string.Format(Strings.InvalidClaimInAttestationToken, claimName, claimData)); + } + } + + private byte[] GetSharedSecret(EnclavePublicKey enclavePublicKey, byte[] nonce, EnclaveType enclaveType, EnclaveDiffieHellmanInfo enclaveDHInfo, ECDiffieHellman clientDHKey) + { + byte[] enclaveRsaPublicKey = enclavePublicKey.PublicKey; + + // For SGX enclave we Sql server sends the enclave public key XOR'ed with Nonce. + // In case if Sql server replayed old JWT then shared secret will not match and hence client will not able to determine the updated enclave keys. + if (enclaveType == EnclaveType.Sgx) + { + for (int iterator = 0; iterator < enclaveRsaPublicKey.Length; iterator++) + { + enclaveRsaPublicKey[iterator] = (byte)(enclaveRsaPublicKey[iterator] ^ nonce[iterator % nonce.Length]); + } + } + + // Perform signature verification. The enclave's DiffieHellman public key was signed by the enclave's RSA public key. + using (RSA rsa = KeyConverter.CreateRSAFromPublicKeyBlob(enclaveRsaPublicKey)) + { + if (!rsa.VerifyData(enclaveDHInfo.PublicKey, enclaveDHInfo.PublicKeySignature, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)) + { + throw new ArgumentException(Strings.GetSharedSecretFailed); + } + } + + using (ECDiffieHellman enclaveDHKey = KeyConverter.CreateECDiffieHellmanFromPublicKeyBlob(enclaveDHInfo.PublicKey)) + { + return KeyConverter.DeriveKey(clientDHKey, enclaveDHKey.PublicKey); + } + } + #endregion + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ColumnEncryptionKeyInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ColumnEncryptionKeyInfo.cs new file mode 100644 index 0000000000..1b1dbceba0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/ColumnEncryptionKeyInfo.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Data.SqlClient +{ + /// + /// Class encapsulating Column encryption key info + /// + internal class ColumnEncryptionKeyInfo + { + internal readonly int KeyId; + internal readonly int DatabaseId; + internal readonly byte[] DecryptedKeyBytes; + internal readonly byte[] KeyIdBytes; + internal readonly byte[] DatabaseIdBytes; + internal readonly byte[] KeyMetadataVersionBytes; + + private static readonly string _decryptedKeyName = "DecryptedKey"; + private static readonly string _keyMetadataVersionName = "KeyMetadataVersion"; + private static readonly string _className = "ColumnEncryptionKeyInfo"; + private static readonly string _bytePackageName = "BytePackage"; + private static readonly string _serializeToBufferMethodName = "SerializeToBuffer"; + private static readonly string _startOffsetName = "StartOffset"; + + /// + /// Constructor + /// + /// Decrypted key bytes + /// database id for this column encryption key + /// key metadata version for this column encryption key + /// key id for this column encryption key + internal ColumnEncryptionKeyInfo(byte[] decryptedKey, int databaseId, byte[] keyMetadataVersion, int keyid) + { + + if (null == decryptedKey) + throw SQL.NullArgumentInConstructorInternal(_decryptedKeyName, _className); + if (0 == decryptedKey.Length) + throw SQL.EmptyArgumentInConstructorInternal(_decryptedKeyName, _className); + if (null == keyMetadataVersion) + throw SQL.NullArgumentInConstructorInternal(_keyMetadataVersionName, _className); + if (0 == keyMetadataVersion.Length) + throw SQL.EmptyArgumentInConstructorInternal(_keyMetadataVersionName, _className); + + KeyId = keyid; + DatabaseId = databaseId; + DecryptedKeyBytes = decryptedKey; + KeyMetadataVersionBytes = keyMetadataVersion; + + //Covert keyId to Bytes + ushort keyIdUShort; + + try + { + keyIdUShort = (ushort)keyid; + } + catch (Exception e) + { + throw SQL.InvalidKeyIdUnableToCastToUnsignedShort(keyid, e); + } + + KeyIdBytes = BitConverter.GetBytes(keyIdUShort); + + //Covert databaseId to Bytes + uint databaseIdUInt; + + try + { + databaseIdUInt = (uint)databaseId; + } + catch (Exception e) + { + throw SQL.InvalidDatabaseIdUnableToCastToUnsignedInt(databaseId, e); + } + + DatabaseIdBytes = BitConverter.GetBytes(databaseIdUInt); + } + + /// + /// Calculates number of bytes required to serialize this object + /// + /// Number of bytes required for serialization + internal int GetLengthForSerialization() + { + int lengthForSerialization = 0; + lengthForSerialization += DecryptedKeyBytes.Length; + lengthForSerialization += KeyIdBytes.Length; + lengthForSerialization += DatabaseIdBytes.Length; + lengthForSerialization += KeyMetadataVersionBytes.Length; + return lengthForSerialization; + } + + /// + /// Serialize this object in a given byte[] starting at a given offset + /// + /// byte array for serialization + /// start offset in byte array + /// next available offset + internal int SerializeToBuffer(byte[] bytePackage, int startOffset) + { + + if (null == bytePackage) + throw SQL.NullArgumentInternal(_bytePackageName, _className, _serializeToBufferMethodName); + if (0 == bytePackage.Length) + throw SQL.EmptyArgumentInternal(_bytePackageName, _className, _serializeToBufferMethodName); + if (!(startOffset < bytePackage.Length)) + throw SQL.OffsetOutOfBounds(_startOffsetName, _className, _serializeToBufferMethodName); + if ((bytePackage.Length - startOffset) < GetLengthForSerialization()) + throw SQL.InsufficientBuffer(_bytePackageName, _className, _serializeToBufferMethodName); + + Buffer.BlockCopy(DatabaseIdBytes, 0, bytePackage, startOffset, DatabaseIdBytes.Length); + startOffset += DatabaseIdBytes.Length; + Buffer.BlockCopy(KeyMetadataVersionBytes, 0, bytePackage, startOffset, KeyMetadataVersionBytes.Length); + startOffset += KeyMetadataVersionBytes.Length; + Buffer.BlockCopy(KeyIdBytes, 0, bytePackage, startOffset, KeyIdBytes.Length); + startOffset += KeyIdBytes.Length; + Buffer.BlockCopy(DecryptedKeyBytes, 0, bytePackage, startOffset, DecryptedKeyBytes.Length); + startOffset += DecryptedKeyBytes.Length; + + return startOffset; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/DataClassification/SensitivityClassification.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/DataClassification/SensitivityClassification.cs new file mode 100644 index 0000000000..27acd8508b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Data/SqlClient/DataClassification/SensitivityClassification.cs @@ -0,0 +1,119 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Collections.ObjectModel; + +namespace Microsoft.Data.SqlClient.DataClassification +{ + /// + public sealed class Label + { + /// + public string Name { get; private set; } + + /// + public string Id { get; private set; } + + /// + public Label(string name, string id) + { + Name = name; + Id = id; + } + } + + /// + public sealed class InformationType + { + /// + public string Name { get; private set; } + + /// + public string Id { get; private set; } + + /// + public InformationType(string name, string id) + { + Name = name; + Id = id; + } + } + + /// + public enum SensitivityRank + { + /// + NOT_DEFINED = -1, + /// + NONE = 0, + /// + LOW = 10, + /// + MEDIUM = 20, + /// + HIGH = 30, + /// + CRITICAL = 40 + } + + /// + public sealed class SensitivityProperty + { + /// + public Label Label { get; private set; } + + /// + public InformationType InformationType { get; private set; } + + /// + public SensitivityRank SensitivityRank { get; private set; } + + /// + public SensitivityProperty(Label label, InformationType informationType, SensitivityRank sensitivityRank = SensitivityRank.NOT_DEFINED) // Default to NOT_DEFINED for backwards compatibility + { + Label = label; + InformationType = informationType; + SensitivityRank = sensitivityRank; + } + } + + /// + public sealed class ColumnSensitivity + { + /// + public ReadOnlyCollection SensitivityProperties { get; private set; } + + /// + public ColumnSensitivity(IList sensitivityProperties) + { + SensitivityProperties = new ReadOnlyCollection(sensitivityProperties); + } + } + + /// + public sealed class SensitivityClassification + { + /// + public ReadOnlyCollection