diff --git a/pkg/Microsoft.Private.PackageBaseline/packageIndex.json b/pkg/Microsoft.Private.PackageBaseline/packageIndex.json index 6723f074967c..f9bd6629cf89 100644 --- a/pkg/Microsoft.Private.PackageBaseline/packageIndex.json +++ b/pkg/Microsoft.Private.PackageBaseline/packageIndex.json @@ -1143,7 +1143,7 @@ "4.1.0.0": "4.1.0", "4.1.1.0": "4.3.0", "4.2.0.0": "4.4.0", - "4.2.1.0": "4.5.0" + "4.3.0.0": "4.5.0" } }, "System.Data.SqlXml": { diff --git a/src/System.Data.SqlClient/dir.props b/src/System.Data.SqlClient/dir.props index 4888b29a2299..a30234936837 100644 --- a/src/System.Data.SqlClient/dir.props +++ b/src/System.Data.SqlClient/dir.props @@ -2,7 +2,7 @@ - 4.2.1.0 + 4.3.0.0 MSFT \ No newline at end of file diff --git a/src/System.Data.SqlClient/ref/System.Data.SqlClient.cs b/src/System.Data.SqlClient/ref/System.Data.SqlClient.cs index 3071b6f05229..0f636b5134a8 100644 --- a/src/System.Data.SqlClient/ref/System.Data.SqlClient.cs +++ b/src/System.Data.SqlClient/ref/System.Data.SqlClient.cs @@ -184,6 +184,17 @@ public sealed partial class SqlMetaData public static Microsoft.SqlServer.Server.SqlMetaData InferFromValue(object value, string name) { throw null; } } } +namespace System.Data.Sql +{ + public sealed partial class SqlNotificationRequest + { + public SqlNotificationRequest() { } + public SqlNotificationRequest(string userData, string options, int timeout) { } + public string Options { get { throw null; } set { } } + public int Timeout { get { throw null; } set { } } + public string UserData { get { throw null; } set { } } + } +} namespace System.Data.SqlClient { public enum ApplicationIntent @@ -331,6 +342,7 @@ public sealed partial class SqlCommand : System.Data.Common.DbCommand, System.IC public System.Threading.Tasks.Task ExecuteXmlReaderAsync() { throw null; } public System.Threading.Tasks.Task ExecuteXmlReaderAsync(System.Threading.CancellationToken cancellationToken) { throw null; } public override void Prepare() { } + public System.Data.Sql.SqlNotificationRequest Notification { get { throw null; } set { } } } public sealed partial class SqlConnection : System.Data.Common.DbConnection, System.ICloneable { @@ -426,6 +438,70 @@ public sealed partial class SqlDataAdapter : System.Data.Common.DbDataAdapter, S protected override void OnRowUpdating(System.Data.Common.RowUpdatingEventArgs value) { } object System.ICloneable.Clone() { throw null; } } + public sealed partial class SqlDependency + { + public SqlDependency() { } + public SqlDependency(SqlCommand command) { } + public SqlDependency(SqlCommand command, string options, int timeout) { } + public bool HasChanges { get { throw null; } } + public string Id { get { throw null; } } + public event OnChangeEventHandler OnChange { add { } remove { } } + public void AddCommandDependency(SqlCommand command) { } + public static bool Start(string connectionString) { throw null; } + public static bool Start(string connectionString, string queue) { throw null; } + public static bool Stop(string connectionString) { throw null; } + public static bool Stop(string connectionString, string queue) { throw null; } + } + public delegate void OnChangeEventHandler(object sender, SqlNotificationEventArgs e); + public partial class SqlNotificationEventArgs : System.EventArgs + { + public SqlNotificationEventArgs(SqlNotificationType type, SqlNotificationInfo info, SqlNotificationSource source) { } + public SqlNotificationType Type { get { throw null; } } + public SqlNotificationInfo Info { get { throw null; } } + public SqlNotificationSource Source { get { throw null; } } + } + public enum SqlNotificationInfo + { + Truncate = 0, + Insert = 1, + Update = 2, + Delete = 3, + Drop = 4, + Alter = 5, + Restart = 6, + Error = 7, + Query = 8, + Invalid = 9, + Options = 10, + Isolation = 11, + Expired = 12, + Resource = 13, + PreviousFire = 14, + TemplateLimit = 15, + Merge = 16, + Unknown = -1, + AlreadyChanged = -2 + } + public enum SqlNotificationSource + { + Data = 0, + Timeout = 1, + Object = 2, + Database = 3, + System = 4, + Statement = 5, + Environment = 6, + Execution = 7, + Owner = 8, + Unknown = -1, + Client = -2 + } + public enum SqlNotificationType + { + Change = 0, + Subscribe = 1, + Unknown = -1 + } public sealed partial class SqlRowUpdatedEventArgs : System.Data.Common.RowUpdatedEventArgs { public SqlRowUpdatedEventArgs(DataRow row, IDbCommand command, StatementType statementType, System.Data.Common.DataTableMapping tableMapping) diff --git a/src/System.Data.SqlClient/src/Resources/Strings.resx b/src/System.Data.SqlClient/src/Resources/Strings.resx index bb9f94ce37f3..dc59e7026d80 100644 --- a/src/System.Data.SqlClient/src/Resources/Strings.resx +++ b/src/System.Data.SqlClient/src/Resources/Strings.resx @@ -667,6 +667,33 @@ Number of fields in record '{0}' does not match the number in the original record. + + This SqlCommand object is already associated with another SqlDependency object. + + + The SQL Server Service Broker for the current database is not enabled, and as a result query notifications are not supported. Please enable the Service Broker for this database if you wish to use notifications. + + + When using SqlDependency without providing an options value, SqlDependency.Start() must be called prior to execution of a command added to the SqlDependency instance. + + + When using SqlDependency without providing an options value, SqlDependency.Start() must be called for each server that is being executed against. + + + SqlDependency.Start has been called for the server the command is executing against more than once, but there is no matching server/user/database Start() call for current command. + + + SqlDependency.OnChange does not support multiple event registrations for the same delegate. + + + No SqlDependency exists for the key. + + + Timeout specified is invalid. Timeout cannot be < 0. + + + SqlDependency does not support calling Start() with different connection strings having the same server, user, and database in the same app domain. + The dbType {0} is invalid for this constructor. diff --git a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj index a529906a63be..749e243a2f51 100644 --- a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj +++ b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj @@ -86,6 +86,7 @@ + @@ -113,6 +114,9 @@ + + + @@ -123,6 +127,11 @@ + + + + + diff --git a/src/System.Data.SqlClient/src/System/Data/Common/AdapterUtil.SqlClient.cs b/src/System.Data.SqlClient/src/System/Data/Common/AdapterUtil.SqlClient.cs index eb411a5a9246..05614321ba6f 100644 --- a/src/System.Data.SqlClient/src/System/Data/Common/AdapterUtil.SqlClient.cs +++ b/src/System.Data.SqlClient/src/System/Data/Common/AdapterUtil.SqlClient.cs @@ -47,6 +47,12 @@ internal static Exception ExceptionWithStackTrace(Exception e) } } + internal static void TraceExceptionWithoutRethrow(Exception e) + { + Debug.Assert(ADP.IsCatchableExceptionType(e), "Invalid exception type, should have been re-thrown!"); + TraceException(" '%ls'\n", e); + } + // // COM+ exceptions // diff --git a/src/System.Data.SqlClient/src/System/Data/Sql/SqlNotificationRequest.cs b/src/System.Data.SqlClient/src/System/Data/Sql/SqlNotificationRequest.cs new file mode 100644 index 000000000000..88c2dc82226c --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/Sql/SqlNotificationRequest.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Data.Common; +using System.Data.SqlClient; + +namespace System.Data.Sql +{ + public sealed class SqlNotificationRequest + { + private string _userData; + private string _options; + private int _timeout; + + public SqlNotificationRequest() + : this(null, null, SQL.SqlDependencyTimeoutDefault) { } + + public SqlNotificationRequest(string userData, string options, int timeout) + { + UserData = userData; + Timeout = timeout; + Options = options; + } + + public string Options + { + get + { + return _options; + } + set + { + if ((null != value) && (ushort.MaxValue < value.Length)) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(Options)); + } + _options = value; + } + } + + public int Timeout + { + get + { + return _timeout; + } + set + { + if (0 > value) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(Timeout)); + } + _timeout = value; + } + } + + public string UserData + { + get + { + return _userData; + } + set + { + if ((null != value) && (ushort.MaxValue < value.Length)) + { + throw ADP.ArgumentOutOfRange(string.Empty, nameof(UserData)); + } + _userData = value; + } + } + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/OnChangedEventHandler.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/OnChangedEventHandler.cs new file mode 100644 index 000000000000..ce8b2cb8efec --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/OnChangedEventHandler.cs @@ -0,0 +1,8 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Data.SqlClient +{ + public delegate void OnChangeEventHandler(object sender, SqlNotificationEventArgs e); +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs index 75cc51100ce0..7f5d1c4807ea 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs @@ -486,7 +486,7 @@ private Task CreateAndExecuteInitialQueryAsync(out Bulk { string TDSCommand = CreateInitialQuery(); - Task executeTask = _parser.TdsExecuteSQLBatch(TDSCommand, this.BulkCopyTimeout, _stateObj, sync: !_isAsyncBulkCopy, callerHasConnectionLock: true); + Task executeTask = _parser.TdsExecuteSQLBatch(TDSCommand, this.BulkCopyTimeout, null, _stateObj, sync: !_isAsyncBulkCopy, callerHasConnectionLock: true); if (executeTask == null) { @@ -743,7 +743,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i private Task SubmitUpdateBulkCommand(string TDSCommand) { - Task executeTask = _parser.TdsExecuteSQLBatch(TDSCommand, this.BulkCopyTimeout, _stateObj, sync: !_isAsyncBulkCopy, callerHasConnectionLock: true); + Task executeTask = _parser.TdsExecuteSQLBatch(TDSCommand, this.BulkCopyTimeout, null, _stateObj, sync: !_isAsyncBulkCopy, callerHasConnectionLock: true); if (executeTask == null) { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs index 1f936a210f93..84cdab38de85 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Data.Common; +using System.Data.Sql; using System.Data.SqlTypes; using System.Diagnostics; using System.Runtime.CompilerServices; @@ -23,6 +24,8 @@ public sealed class SqlCommand : DbCommand, ICloneable private UpdateRowSource _updatedRowSource = UpdateRowSource.Both; private bool _designTimeInvisible; + internal SqlDependency _sqlDep; + private static readonly DiagnosticListener _diagnosticListener = new DiagnosticListener(SqlClientDiagnosticListenerExtensions.DiagnosticListenerName); private bool _parentOperationStarted = false; @@ -182,6 +185,7 @@ private CachedAsyncState cachedAsyncState // _rowsAffected is cumulative for ExecuteNonQuery across all rpc batches internal int _rowsAffected = -1; // rows affected by the command + private SqlNotificationRequest _notification; // transaction support private SqlTransaction _transaction; @@ -304,6 +308,27 @@ override protected DbConnection DbConnection } } + private SqlInternalConnectionTds InternalTdsConnection + { + get + { + return (SqlInternalConnectionTds)_activeConnection.InnerConnection; + } + } + + public SqlNotificationRequest Notification + { + get + { + return _notification; + } + set + { + _sqlDep = null; + _notification = value; + } + } + internal SqlStatistics Statistics { get @@ -1104,6 +1129,8 @@ private Task InternalExecuteNonQuery(TaskCompletionSource completion, bo // returns false for empty command text ValidateCommand(async, methodName); + CheckNotificationStateAndAutoEnlist(); // Only call after validate - requires non null connection! + Task task = null; // only send over SQL Batch command if we are not a stored proc and have no parameters and not in batch RPC mode @@ -1398,7 +1425,7 @@ override protected DbDataReader ExecuteDbDataReader(CommandBehavior behavior) } - private SqlDataReader EndExecuteReader(IAsyncResult asyncResult) + internal SqlDataReader EndExecuteReader(IAsyncResult asyncResult) { Exception asyncException = ((Task)asyncResult).Exception; if (asyncException != null) @@ -1448,7 +1475,7 @@ private SqlDataReader EndExecuteReaderInternal(IAsyncResult asyncResult) } } - private IAsyncResult BeginExecuteReader(CommandBehavior behavior, AsyncCallback callback, object stateObject) + internal IAsyncResult BeginExecuteReader(CommandBehavior behavior, AsyncCallback callback, object stateObject) { // Reset _pendingCancel upon entry into any Execute - used to synchronize state // between entry into Execute* API and the thread obtaining the stateObject. @@ -1869,6 +1896,58 @@ internal _SqlMetaDataSet MetaData } } + // Check to see if notificactions auto enlistment is turned on. Enlist if so. + private void CheckNotificationStateAndAutoEnlist() + { + // Auto-enlist not supported in Core + + // If we have a notification with a dependency, setup the notification options at this time. + + // If user passes options, then we will always have option data at the time the SqlDependency + // ctor is called. But, if we are using default queue, then we do not have this data until + // Start(). Due to this, we always delay setting options until execute. + + // There is a variance in order between Start(), SqlDependency(), and Execute. This is the + // best way to solve that problem. + if (null != Notification) + { + if (_sqlDep != null) + { + if (null == _sqlDep.Options) + { + // If null, SqlDependency was not created with options, so we need to obtain default options now. + // GetDefaultOptions can and will throw under certain conditions. + + // In order to match to the appropriate start - we need 3 pieces of info: + // 1) server 2) user identity (SQL Auth or Int Sec) 3) database + + SqlDependency.IdentityUserNamePair identityUserName = null; + + // Obtain identity from connection. + SqlInternalConnectionTds internalConnection = _activeConnection.InnerConnection as SqlInternalConnectionTds; + if (internalConnection.Identity != null) + { + identityUserName = new SqlDependency.IdentityUserNamePair(internalConnection.Identity, null); + } + else + { + identityUserName = new SqlDependency.IdentityUserNamePair(null, internalConnection.ConnectionOptions.UserID); + } + + Notification.Options = SqlDependency.GetDefaultComposedOptions(_activeConnection.DataSource, + InternalTdsConnection.ServerProvidedFailOverPartner, + identityUserName, _activeConnection.Database); + } + + // Set UserData on notifications, as well as adding to the appdomain dispatcher. The value is + // computed by an algorithm on the dependency - fixed and will always produce the same value + // given identical commandtext + parameter values. + Notification.UserData = _sqlDep.ComputeHashAndAddToDispatcher(this); + // Maintain server list for SqlDependency. + _sqlDep.AddToServerList(_activeConnection.DataSource); + } + } + } // Tds-specific logic for ExecuteNonQuery run handling private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, bool asyncWrite) @@ -1928,9 +2007,10 @@ private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, b // no parameters are sent over // no data reader is returned // use this overload for "batch SQL" tds token type - Task executeTask = _stateObj.Parser.TdsExecuteSQLBatch(this.CommandText, timeout, _stateObj, sync: true); + Task executeTask = _stateObj.Parser.TdsExecuteSQLBatch(this.CommandText, timeout, this.Notification, _stateObj, sync: true); Debug.Assert(executeTask == null, "Shouldn't get a task when doing sync writes"); + NotifyDependency(); if (async) { _activeConnection.GetOpenTdsConnection(methodName).IncrementAsyncCount(); @@ -1986,6 +2066,9 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior // this function may throw for an invalid connection // returns false for empty command text ValidateCommand(async, method); + + CheckNotificationStateAndAutoEnlist(); // Only call after validate - requires non null connection! + SqlStatistics statistics = Statistics; if (null != statistics) { @@ -2094,7 +2177,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi // Send over SQL Batch command if we are not a stored proc and have no parameters Debug.Assert(!IsUserPrepared, "CommandType.Text with no params should not be prepared!"); string text = GetCommandText(cmdBehavior) + GetResetOptionsString(cmdBehavior); - writeTask = _stateObj.Parser.TdsExecuteSQLBatch(text, timeout, _stateObj, sync: !asyncWrite); + writeTask = _stateObj.Parser.TdsExecuteSQLBatch(text, timeout, this.Notification, _stateObj, sync: !asyncWrite); } else if (System.Data.CommandType.Text == this.CommandType) { @@ -2138,7 +2221,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi rpc.options = TdsEnums.RPC_NOMETADATA; Debug.Assert(_rpcArrayOf1[0] == rpc); - writeTask = _stateObj.Parser.TdsExecuteRPC(_rpcArrayOf1, timeout, inSchema, _stateObj, CommandType.StoredProcedure == CommandType, sync: !asyncWrite); + writeTask = _stateObj.Parser.TdsExecuteRPC(_rpcArrayOf1, timeout, inSchema, this.Notification, _stateObj, CommandType.StoredProcedure == CommandType, sync: !asyncWrite); } else { @@ -2153,7 +2236,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi // turn set options ON if (null != optionSettings) { - Task executeTask = _stateObj.Parser.TdsExecuteSQLBatch(optionSettings, timeout, _stateObj, sync: true); + Task executeTask = _stateObj.Parser.TdsExecuteSQLBatch(optionSettings, timeout, this.Notification, _stateObj, sync: true); Debug.Assert(executeTask == null, "Shouldn't get a task when doing sync writes"); bool dataReady; Debug.Assert(_stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); @@ -2166,7 +2249,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi // execute sp Debug.Assert(_rpcArrayOf1[0] == rpc); - writeTask = _stateObj.Parser.TdsExecuteRPC(_rpcArrayOf1, timeout, inSchema, _stateObj, CommandType.StoredProcedure == CommandType, sync: !asyncWrite); + writeTask = _stateObj.Parser.TdsExecuteRPC(_rpcArrayOf1, timeout, inSchema, this.Notification, _stateObj, CommandType.StoredProcedure == CommandType, sync: !asyncWrite); } Debug.Assert(writeTask == null || async, "Returned task in sync mode"); @@ -2253,6 +2336,8 @@ private void FinishExecuteReader(SqlDataReader ds, RunBehavior runBehavior, stri { // always wrap with a try { FinishExecuteReader(...) } finally { PutStateObject(); } + NotifyDependency(); + if (runBehavior == RunBehavior.UntilDone) { try @@ -3261,6 +3346,14 @@ internal void CancelIgnoreFailure() } } + private void NotifyDependency() + { + if (_sqlDep != null) + { + _sqlDep.StartTimer(Notification); + } + } + object ICloneable.Clone() => Clone(); public SqlCommand Clone() => new SqlCommand(this); diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs index 66a350f9cfef..ba831dbfaa85 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs @@ -3574,7 +3574,7 @@ private void RestoreServerSettings(TdsParser parser, TdsParserStateObject stateO // broken connection, so check state first. if (parser.State == TdsParserState.OpenLoggedIn) { - Task executeTask = parser.TdsExecuteSQLBatch(_resetOptionsString, (_command != null) ? _command.CommandTimeout : 0, stateObj, sync: true); + Task executeTask = parser.TdsExecuteSQLBatch(_resetOptionsString, (_command != null) ? _command.CommandTimeout : 0, null, stateObj, sync: true); Debug.Assert(executeTask == null, "Shouldn't get a task when doing sync writes"); // must execute this one synchronously as we can't retry diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependency.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependency.cs new file mode 100644 index 000000000000..d28c1d492e95 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependency.cs @@ -0,0 +1,975 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.Common; +using System.Data.ProviderBase; +using System.Diagnostics; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Xml; +using System.Data.Sql; + +namespace System.Data.SqlClient +{ + public sealed class SqlDependency + { + // Private class encapsulating the user/identity information - either SQL Auth username or Windows identity. + internal class IdentityUserNamePair + { + private DbConnectionPoolIdentity _identity; + private string _userName; + + internal IdentityUserNamePair(DbConnectionPoolIdentity identity, string userName) + { + Debug.Assert((identity == null && userName != null) || + (identity != null && userName == null), "Unexpected arguments!"); + _identity = identity; + _userName = userName; + } + + internal DbConnectionPoolIdentity Identity => _identity; + + internal string UserName => _userName; + + public override bool Equals(object value) + { + IdentityUserNamePair temp = (IdentityUserNamePair)value; + + bool result = false; + + if (null == temp) + { // If passed value null - false. + result = false; + } + else if (this == temp) + { // If instances equal - true. + result = true; + } + else + { + if (_identity != null) + { + if (_identity.Equals(temp._identity)) + { + result = true; + } + } + else if (_userName == temp._userName) + { + result = true; + } + } + + return result; + } + + public override int GetHashCode() + { + int hashValue = 0; + + if (null != _identity) + { + hashValue = _identity.GetHashCode(); + } + else + { + hashValue = _userName.GetHashCode(); + } + + return hashValue; + } + } + + // Private class encapsulating the database, service info and hash logic. + private class DatabaseServicePair + { + private string _database; + private string _service; // Store the value, but don't use for equality or hashcode! + + internal DatabaseServicePair(string database, string service) + { + Debug.Assert(database != null, "Unexpected argument!"); + _database = database; + _service = service; + } + + internal string Database => _database; + + internal string Service => _service; + + public override bool Equals(object value) + { + DatabaseServicePair temp = (DatabaseServicePair)value; + + bool result = false; + + if (null == temp) + { // If passed value null - false. + result = false; + } + else if (this == temp) + { // If instances equal - true. + result = true; + } + else if (_database == temp._database) + { + result = true; + } + + return result; + } + + public override int GetHashCode() + { + return _database.GetHashCode(); + } + } + + // Private class encapsulating the event and it's registered execution context. + internal class EventContextPair + { + private OnChangeEventHandler _eventHandler; + private ExecutionContext _context; + private SqlDependency _dependency; + private SqlNotificationEventArgs _args; + + private static ContextCallback s_contextCallback = new ContextCallback(InvokeCallback); + + internal EventContextPair(OnChangeEventHandler eventHandler, SqlDependency dependency) + { + Debug.Assert(eventHandler != null && dependency != null, "Unexpected arguments!"); + _eventHandler = eventHandler; + _context = ExecutionContext.Capture(); + _dependency = dependency; + } + + public override bool Equals(object value) + { + EventContextPair temp = (EventContextPair)value; + + bool result = false; + + if (null == temp) + { // If passed value null - false. + result = false; + } + else if (this == temp) + { // If instances equal - true. + result = true; + } + else + { + if (_eventHandler == temp._eventHandler) + { // Handler for same delegates are reference equivalent. + result = true; + } + } + + return result; + } + + public override int GetHashCode() + { + return _eventHandler.GetHashCode(); + } + + internal void Invoke(SqlNotificationEventArgs args) + { + _args = args; + ExecutionContext.Run(_context, s_contextCallback, this); + } + + private static void InvokeCallback(object eventContextPair) + { + EventContextPair pair = (EventContextPair)eventContextPair; + pair._eventHandler(pair._dependency, (SqlNotificationEventArgs)pair._args); + } + } + + // Instance members + + // SqlNotificationRequest required state members + + // Only used for SqlDependency.Id. + private readonly string _id = Guid.NewGuid().ToString() + ";" + s_appDomainKey; + private string _options; // Concat of service & db, in the form "service=x;local database=y". + private int _timeout; + + // Various SqlDependency required members + private bool _dependencyFired = false; + // We are required to implement our own event collection to preserve ExecutionContext on callback. + private List _eventList = new List(); + private object _eventHandlerLock = new object(); // Lock for event serialization. + // Track the time that this dependency should time out. If the server didn't send a change + // notification or a time-out before this point then the client will perform a client-side + // timeout. + private DateTime _expirationTime = DateTime.MaxValue; + // Used for invalidation of dependencies based on which servers they rely upon. + // It's possible we will over invalidate if unexpected server failure occurs (but not server down). + private List _serverList = new List(); + + // Static members + + private static object s_startStopLock = new object(); + private static readonly string s_appDomainKey = Guid.NewGuid().ToString(); + // Hashtable containing all information to match from a server, user, database triple to the service started for that + // triple. For each server, there can be N users. For each user, there can be N databases. For each server, user, + // database, there can only be one service. + private static Dictionary>> s_serverUserHash = + new Dictionary>>(StringComparer.OrdinalIgnoreCase); + private static SqlDependencyProcessDispatcher s_processDispatcher = null; + // The following two strings are used for AppDomain.CreateInstance. + private static readonly string s_assemblyName = (typeof(SqlDependencyProcessDispatcher)).Assembly.FullName; + private static readonly string s_typeName = (typeof(SqlDependencyProcessDispatcher)).FullName; + + // Constructors + + public SqlDependency() : this(null, null, SQL.SqlDependencyTimeoutDefault) + { + } + + public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependencyTimeoutDefault) + { + } + + public SqlDependency(SqlCommand command, string options, int timeout) + { + if (timeout < 0) + { + throw SQL.InvalidSqlDependencyTimeout(nameof(timeout)); + } + _timeout = timeout; + + if (null != options) + { // Ignore null value - will force to default. + _options = options; + } + + AddCommandInternal(command); + SqlDependencyPerAppDomainDispatcher.SingletonInstance.AddDependencyEntry(this); // Add dep to hashtable with Id. + } + + // Public Properties + + public bool HasChanges => _dependencyFired; + + public string Id => _id; + + // Internal Properties + + internal static string AppDomainKey => s_appDomainKey; + + internal DateTime ExpirationTime => _expirationTime; + + internal string Options => _options; + + internal static SqlDependencyProcessDispatcher ProcessDispatcher => s_processDispatcher; + + internal int Timeout => _timeout; + + // Events + + public event OnChangeEventHandler OnChange + { + // EventHandlers to be fired when dependency is notified. + add + { + if (null != value) + { + SqlNotificationEventArgs sqlNotificationEvent = null; + + lock (_eventHandlerLock) + { + if (_dependencyFired) + { // If fired, fire the new event immediately. + sqlNotificationEvent = new SqlNotificationEventArgs(SqlNotificationType.Subscribe, SqlNotificationInfo.AlreadyChanged, SqlNotificationSource.Client); + } + else + { + EventContextPair pair = new EventContextPair(value, this); + if (!_eventList.Contains(pair)) + { + _eventList.Add(pair); + } + else + { + throw SQL.SqlDependencyEventNoDuplicate(); + } + } + } + + if (null != sqlNotificationEvent) + { // Delay firing the event until outside of lock. + value(this, sqlNotificationEvent); + } + } + } + remove + { + if (null != value) + { + EventContextPair pair = new EventContextPair(value, this); + lock (_eventHandlerLock) + { + int index = _eventList.IndexOf(pair); + if (0 <= index) + { + _eventList.RemoveAt(index); + } + } + } + } + } + + // Public Methods + + public void AddCommandDependency(SqlCommand command) + { + // Adds command to dependency collection so we automatically create the SqlNotificationsRequest object + // and listen for a notification for the added commands. + if (command == null) + { + throw ADP.ArgumentNull(nameof(command)); + } + + AddCommandInternal(command); + } + + // Static Methods - public & internal + + // Static Start/Stop methods + + public static bool Start(string connectionString) + { + return Start(connectionString, null, true); + } + + public static bool Start(string connectionString, string queue) + { + return Start(connectionString, queue, false); + } + + internal static bool Start(string connectionString, string queue, bool useDefaults) + { + if (string.IsNullOrEmpty(connectionString)) + { + if (null == connectionString) + { + throw ADP.ArgumentNull(nameof(connectionString)); + } + else + { + throw ADP.Argument(nameof(connectionString)); + } + } + + if (!useDefaults && string.IsNullOrEmpty(queue)) + { // If specified but null or empty, use defaults. + useDefaults = true; + queue = null; // Force to null - for proper hashtable comparison for default case. + } + + // End duplicate Start/Stop logic. + + bool errorOccurred = false; + bool result = false; + + lock (s_startStopLock) + { + try + { + if (null == s_processDispatcher) + { // Ensure _processDispatcher reference is present - inside lock. + s_processDispatcher = SqlDependencyProcessDispatcher.SingletonProcessDispatcher; + } + + if (useDefaults) + { // Default listener. + string server = null; + DbConnectionPoolIdentity identity = null; + string user = null; + string database = null; + string service = null; + bool appDomainStart = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { // CER to ensure that if Start succeeds we add to hash completing setup. + // Start using process wide default service/queue & database from connection string. + result = s_processDispatcher.StartWithDefault( + connectionString, + out server, + out identity, + out user, + out database, + ref service, + s_appDomainKey, + SqlDependencyPerAppDomainDispatcher.SingletonInstance, + out errorOccurred, + out appDomainStart); + } + finally + { + if (appDomainStart && !errorOccurred) + { // If success, add to hashtable. + IdentityUserNamePair identityUser = new IdentityUserNamePair(identity, user); + DatabaseServicePair databaseService = new DatabaseServicePair(database, service); + if (!AddToServerUserHash(server, identityUser, databaseService)) + { + try + { + Stop(connectionString, queue, useDefaults, true); + } + catch (Exception e) + { // Discard stop failure! + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + } + throw SQL.SqlDependencyDuplicateStart(); + } + } + } + } + else + { // Start with specified service/queue & database. + result = s_processDispatcher.Start( + connectionString, + queue, + s_appDomainKey, + SqlDependencyPerAppDomainDispatcher.SingletonInstance); + // No need to call AddToServerDatabaseHash since if not using default queue user is required + // to provide options themselves. + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + + throw; + } + } + + return result; + } + + public static bool Stop(string connectionString) + { + return Stop(connectionString, null, true, false); + } + + public static bool Stop(string connectionString, string queue) + { + return Stop(connectionString, queue, false, false); + } + + internal static bool Stop(string connectionString, string queue, bool useDefaults, bool startFailed) + { + if (string.IsNullOrEmpty(connectionString)) + { + if (null == connectionString) + { + throw ADP.ArgumentNull(nameof(connectionString)); + } + else + { + throw ADP.Argument(nameof(connectionString)); + } + } + + if (!useDefaults && string.IsNullOrEmpty(queue)) + { // If specified but null or empty, use defaults. + useDefaults = true; + queue = null; // Force to null - for proper hashtable comparison for default case. + } + + // End duplicate Start/Stop logic. + + bool result = false; + + lock (s_startStopLock) + { + if (null != s_processDispatcher) + { // If _processDispatcher null, no Start has been called. + try + { + string server = null; + DbConnectionPoolIdentity identity = null; + string user = null; + string database = null; + string service = null; + + if (useDefaults) + { + bool appDomainStop = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { // CER to ensure that if Stop succeeds we remove from hash completing teardown. + // Start using process wide default service/queue & database from connection string. + result = s_processDispatcher.Stop( + connectionString, + out server, + out identity, + out user, + out database, + ref service, + s_appDomainKey, + out appDomainStop); + } + finally + { + if (appDomainStop && !startFailed) + { // If success, remove from hashtable. + Debug.Assert(!string.IsNullOrEmpty(server) && !string.IsNullOrEmpty(database), "Server or Database null/Empty upon successfull Stop()!"); + IdentityUserNamePair identityUser = new IdentityUserNamePair(identity, user); + DatabaseServicePair databaseService = new DatabaseServicePair(database, service); + RemoveFromServerUserHash(server, identityUser, databaseService); + } + } + } + else + { + result = s_processDispatcher.Stop( + connectionString, + out server, + out identity, + out user, + out database, + ref queue, + s_appDomainKey, + out bool ignored); + // No need to call RemoveFromServerDatabaseHash since if not using default queue user is required + // to provide options themselves. + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + } + } + } + return result; + } + + // General static utility functions + + private static bool AddToServerUserHash(string server, IdentityUserNamePair identityUser, DatabaseServicePair databaseService) + { + bool result = false; + + lock (s_serverUserHash) + { + Dictionary> identityDatabaseHash; + + if (!s_serverUserHash.ContainsKey(server)) + { + identityDatabaseHash = new Dictionary>(); + s_serverUserHash.Add(server, identityDatabaseHash); + } + else + { + identityDatabaseHash = s_serverUserHash[server]; + } + + List databaseServiceList; + + if (!identityDatabaseHash.ContainsKey(identityUser)) + { + databaseServiceList = new List(); + identityDatabaseHash.Add(identityUser, databaseServiceList); + } + else + { + databaseServiceList = identityDatabaseHash[identityUser]; + } + + if (!databaseServiceList.Contains(databaseService)) + { + databaseServiceList.Add(databaseService); + result = true; + } + } + + return result; + } + + private static void RemoveFromServerUserHash(string server, IdentityUserNamePair identityUser, DatabaseServicePair databaseService) + { + lock (s_serverUserHash) + { + Dictionary> identityDatabaseHash; + + if (s_serverUserHash.ContainsKey(server)) + { + identityDatabaseHash = s_serverUserHash[server]; + + List databaseServiceList; + + if (identityDatabaseHash.ContainsKey(identityUser)) + { + databaseServiceList = identityDatabaseHash[identityUser]; + + int index = databaseServiceList.IndexOf(databaseService); + if (index >= 0) + { + databaseServiceList.RemoveAt(index); + + if (databaseServiceList.Count == 0) + { + identityDatabaseHash.Remove(identityUser); + + if (identityDatabaseHash.Count == 0) + { + s_serverUserHash.Remove(server); + } + } + } + else + { + Debug.Fail("Unexpected state - hash did not contain database!"); + } + } + else + { + Debug.Fail("Unexpected state - hash did not contain user!"); + } + } + else + { + Debug.Fail("Unexpected state - hash did not contain server!"); + } + } + } + + internal static string GetDefaultComposedOptions(string server, string failoverServer, IdentityUserNamePair identityUser, string database) + { + // Server must be an exact match, but user and database only needs to match exactly if there is more than one + // for the given user or database passed. That is ambiguious and we must fail. + string result; + + lock (s_serverUserHash) + { + if (!s_serverUserHash.ContainsKey(server)) + { + if (0 == s_serverUserHash.Count) + { // Special error for no calls to start. + throw SQL.SqlDepDefaultOptionsButNoStart(); + } + else if (!string.IsNullOrEmpty(failoverServer) && s_serverUserHash.ContainsKey(failoverServer)) + { + server = failoverServer; + } + else + { + throw SQL.SqlDependencyNoMatchingServerStart(); + } + } + + Dictionary> identityDatabaseHash = s_serverUserHash[server]; + + List databaseList = null; + + if (!identityDatabaseHash.ContainsKey(identityUser)) + { + if (identityDatabaseHash.Count > 1) + { + throw SQL.SqlDependencyNoMatchingServerStart(); + } + else + { + // Since only one user, - use that. + // Foreach - but only one value present. + foreach (KeyValuePair> entry in identityDatabaseHash) + { + databaseList = entry.Value; + break; // Only iterate once. + } + } + } + else + { + databaseList = identityDatabaseHash[identityUser]; + } + + DatabaseServicePair pair = new DatabaseServicePair(database, null); + DatabaseServicePair resultingPair = null; + int index = databaseList.IndexOf(pair); + if (index != -1) + { // Exact match found, use it. + resultingPair = databaseList[index]; + } + + if (null != resultingPair) + { // Exact database match. + database = FixupServiceOrDatabaseName(resultingPair.Database); // Fixup in place. + string quotedService = FixupServiceOrDatabaseName(resultingPair.Service); + result = "Service=" + quotedService + ";Local Database=" + database; + } + else + { // No exact database match found. + if (databaseList.Count == 1) + { // If only one database for this server/user, use it. + object[] temp = databaseList.ToArray(); // Must copy, no other choice but foreach. + resultingPair = (DatabaseServicePair)temp[0]; + Debug.Assert(temp.Length == 1, "If databaseList.Count==1, why does copied array have length other than 1?"); + string quotedDatabase = FixupServiceOrDatabaseName(resultingPair.Database); + string quotedService = FixupServiceOrDatabaseName(resultingPair.Service); + result = "Service=" + quotedService + ";Local Database=" + quotedDatabase; + } + else + { // More than one database for given server, ambiguous - fail the default case! + throw SQL.SqlDependencyNoMatchingServerDatabaseStart(); + } + } + } + + Debug.Assert(!string.IsNullOrEmpty(result), "GetDefaultComposedOptions should never return null or empty string!"); + return result; + } + + // Internal Methods + + // Called by SqlCommand upon execution of a SqlNotificationRequest class created by this dependency. We + // use this list for a reverse lookup based on server. + internal void AddToServerList(string server) + { + lock (_serverList) + { + int index = _serverList.BinarySearch(server, StringComparer.OrdinalIgnoreCase); + if (0 > index) + { // If less than 0, item was not found in list. + index = ~index; // BinarySearch returns the 2's compliment of where the item should be inserted to preserver a sorted list after insertion. + _serverList.Insert(index, server); + + } + } + } + + internal bool ContainsServer(string server) + { + lock (_serverList) + { + return _serverList.Contains(server); + } + } + + internal string ComputeHashAndAddToDispatcher(SqlCommand command) + { + // Create a string representing the concatenation of the connection string, command text and .ToString on all parameter values. + // This string will then be mapped to unique notification ID (new GUID). We add the guid and the hash to the app domain + // dispatcher to be able to map back to the dependency that needs to be fired for a notification of this + // command. + + // Add Connection string to prevent redundant notifications when same command is running against different databases or SQL servers + string commandHash = ComputeCommandHash(command.Connection.ConnectionString, command); // calculate the string representation of command + + string idString = SqlDependencyPerAppDomainDispatcher.SingletonInstance.AddCommandEntry(commandHash, this); // Add to map. + return idString; + } + + internal void Invalidate(SqlNotificationType type, SqlNotificationInfo info, SqlNotificationSource source) + { + List eventList = null; + + lock (_eventHandlerLock) + { + if (_dependencyFired && + SqlNotificationInfo.AlreadyChanged != info && + SqlNotificationSource.Client != source) + { + + if (ExpirationTime >= DateTime.UtcNow) + { + Debug.Fail("Received notification twice - we should never enter this state!"); + } + } + else + { + // It is the invalidators responsibility to remove this dependency from the app domain static hash. + _dependencyFired = true; + eventList = _eventList; + _eventList = new List(); // Since we are firing the events, null so we do not fire again. + } + } + + if (eventList != null) + { + foreach (EventContextPair pair in eventList) + { + pair.Invoke(new SqlNotificationEventArgs(type, info, source)); + } + } + } + + // This method is used by SqlCommand. + internal void StartTimer(SqlNotificationRequest notificationRequest) + { + if (_expirationTime == DateTime.MaxValue) + { + int seconds = SQL.SqlDependencyServerTimeout; + if (0 != _timeout) + { + seconds = _timeout; + } + if (notificationRequest != null && notificationRequest.Timeout < seconds && notificationRequest.Timeout != 0) + { + seconds = notificationRequest.Timeout; + } + + // We use UTC to check if SqlDependency is expired, need to use it here as well. + _expirationTime = DateTime.UtcNow.AddSeconds(seconds); + SqlDependencyPerAppDomainDispatcher.SingletonInstance.StartTimer(this); + } + } + + // Private Methods + + private void AddCommandInternal(SqlCommand cmd) + { + if (cmd != null) + { + SqlConnection connection = cmd.Connection; + + if (cmd.Notification != null) + { + // Fail if cmd has notification that is not already associated with this dependency. + if (cmd._sqlDep == null || cmd._sqlDep != this) + { + throw SQL.SqlCommandHasExistingSqlNotificationRequest(); + } + } + else + { + bool needToInvalidate = false; + + lock (_eventHandlerLock) + { + if (!_dependencyFired) + { + cmd.Notification = new SqlNotificationRequest() + { + Timeout = _timeout + }; + + // Add the command - A dependancy should always map to a set of commands which haven't fired. + if (null != _options) + { // Assign options if user provided. + cmd.Notification.Options = _options; + } + + cmd._sqlDep = this; + } + else + { + // We should never be able to enter this state, since if we've fired our event list is cleared + // and the event method will immediately fire if a new event is added. So, we should never have + // an event to fire in the event list once we've fired. + Debug.Assert(0 == _eventList.Count, "How can we have an event at this point?"); + if (0 == _eventList.Count) + { // Keep logic just in case. + needToInvalidate = true; // Delay invalidation until outside of lock. + } + } + } + + if (needToInvalidate) + { + Invalidate(SqlNotificationType.Subscribe, SqlNotificationInfo.AlreadyChanged, SqlNotificationSource.Client); + } + } + } + } + + private string ComputeCommandHash(string connectionString, SqlCommand command) + { + // Create a string representing the concatenation of the connection string, the command text and .ToString on all its parameter values. + // This string will then be mapped to the notification ID. + + // All types should properly support a .ToString for the values except + // byte[], char[], and XmlReader. + + StringBuilder builder = new StringBuilder(); + + // add the Connection string and the Command text + builder.AppendFormat("{0};{1}", connectionString, command.CommandText); + + // append params + for (int i = 0; i < command.Parameters.Count; i++) + { + object value = command.Parameters[i].Value; + + if (value == null || value == DBNull.Value) + { + builder.Append("; NULL"); + } + else + { + Type type = value.GetType(); + + if (type == typeof(byte[])) + { + builder.Append(";"); + byte[] temp = (byte[])value; + for (int j = 0; j < temp.Length; j++) + { + builder.Append(temp[j].ToString("x2", CultureInfo.InvariantCulture)); + } + } + else if (type == typeof(char[])) + { + builder.Append((char[])value); + } + else if (type == typeof(XmlReader)) + { + builder.Append(";"); + // Cannot .ToString XmlReader - just allocate GUID. + // This means if XmlReader is used, we will not reuse IDs. + builder.Append(Guid.NewGuid().ToString()); + } + else + { + builder.Append(";"); + builder.Append(value.ToString()); + } + } + } + + string result = builder.ToString(); + + return result; + } + + // Basic copy of function in SqlConnection.cs for ChangeDatabase and similar functionality. Since this will + // only be used for default service and database provided by server, we do not need to worry about an already + // quoted value. + internal static string FixupServiceOrDatabaseName(string name) + { + if (!string.IsNullOrEmpty(name)) + { + return "\"" + name.Replace("\"", "\"\"") + "\""; + } + else + { + return name; + } + } + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyListener.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyListener.cs new file mode 100644 index 000000000000..ce30233356f3 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyListener.cs @@ -0,0 +1,1473 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.ProviderBase; +using System.Data.SqlClient; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Threading; +using System.Xml; +using System.Runtime.Versioning; +using System.Diagnostics.CodeAnalysis; + +// This class is the process wide dependency dispatcher. It contains all connection listeners for the entire process and +// receives notifications on those connections to dispatch to the corresponding AppDomain dispatcher to notify the +// appropriate dependencies. + +internal class SqlDependencyProcessDispatcher : MarshalByRefObject +{ + // Class to contain/store all relevant information about a connection that waits on the SSB queue. + private class SqlConnectionContainer + { + private SqlConnection _con; + private SqlCommand _com; + private SqlParameter _conversationGuidParam; + private SqlParameter _timeoutParam; + private SqlConnectionContainerHashHelper _hashHelper; + private string _queue; + private string _receiveQuery; + private string _beginConversationQuery; + private string _endConversationQuery; + private string _concatQuery; + private readonly int _defaultWaitforTimeout = 60000; // Waitfor(Receive) timeout (milleseconds) + private string _escapedQueueName; + private string _sprocName; + private string _dialogHandle; + private string _cachedServer; + private string _cachedDatabase; + private volatile bool _errorState = false; + private volatile bool _stop = false; // Can probably simplify this slightly - one bool instead of two. + private volatile bool _stopped = false; + private volatile bool _serviceQueueCreated = false; + private int _startCount = 0; // Each container class is called once per Start() - we refCount + // to track when we can dispose. + private Timer _retryTimer = null; + private Dictionary _appDomainKeyHash = null; // AppDomainKey->Open RefCount + + // Constructor + + internal SqlConnectionContainer(SqlConnectionContainerHashHelper hashHelper, string appDomainKey, bool useDefaults) + { + bool setupCompleted = false; + try + { + _hashHelper = hashHelper; + string guid = null; + + // If default, queue name is not present on hashHelper at this point - so we need to + // generate one and complete initialization. + if (useDefaults) + { + guid = Guid.NewGuid().ToString(); + _queue = SQL.SqlNotificationServiceDefault + "-" + guid; + _hashHelper.ConnectionStringBuilder.ApplicationName = _queue; // Used by cleanup sproc. + } + else + { + _queue = _hashHelper.Queue; + } + + // Always use ConnectionStringBuilder since in default case it is different from the + // connection string used in the hashHelper. + _con = new SqlConnection(_hashHelper.ConnectionStringBuilder.ConnectionString); // Create connection and open. + + // Assert permission for this particular connection string since it differs from the user passed string + // which we have already demanded upon. + SqlConnectionString connStringObj = (SqlConnectionString)_con.ConnectionOptions; + + _con.Open(); + + _cachedServer = _con.DataSource; + + _escapedQueueName = SqlConnection.FixupDatabaseTransactionName(_queue); // Properly escape to prevent SQL Injection. + _appDomainKeyHash = new Dictionary(); // Dictionary stores the Start/Stop refcount per AppDomain for this container. + _com = new SqlCommand() + { + Connection = _con, + // Determine if broker is enabled on current database. + CommandText = "select is_broker_enabled from sys.databases where database_id=db_id()" + }; + + if (!(bool)_com.ExecuteScalar()) + { + throw SQL.SqlDependencyDatabaseBrokerDisabled(); + } + + _conversationGuidParam = new SqlParameter("@p1", SqlDbType.UniqueIdentifier); + _timeoutParam = new SqlParameter("@p2", SqlDbType.Int) + { + Value = 0 // Timeout set to 0 for initial sync query. + }; + _com.Parameters.Add(_timeoutParam); + + setupCompleted = true; + // connection with the server has been setup - from this point use TearDownAndDispose() in case of error + + // Create standard query. + _receiveQuery = "WAITFOR(RECEIVE TOP (1) message_type_name, conversation_handle, cast(message_body AS XML) as message_body from " + _escapedQueueName + "), TIMEOUT @p2;"; + + // Create queue, service, sync query, and async query on user thread to ensure proper + // init prior to return. + + if (useDefaults) + { // Only create if user did not specify service & database. + _sprocName = SqlConnection.FixupDatabaseTransactionName(SQL.SqlNotificationStoredProcedureDefault + "-" + guid); + CreateQueueAndService(false); // Fail if we cannot create service, queue, etc. + } + else + { + // Continue query setup. + _com.CommandText = _receiveQuery; + _endConversationQuery = "END CONVERSATION @p1; "; + _concatQuery = _endConversationQuery + _receiveQuery; + } + + IncrementStartCount(appDomainKey, out bool ignored); + // Query synchronously once to ensure everything is working correctly. + // We want the exception to occur on start to immediately inform caller. + SynchronouslyQueryServiceBrokerQueue(); + _timeoutParam.Value = _defaultWaitforTimeout; // Sync successful, extend timeout to 60 seconds. + AsynchronouslyQueryServiceBrokerQueue(); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + if (setupCompleted) + { + // Be sure to drop service & queue. This may fail if create service & queue failed. + // This method will not drop unless we created or service & queue ref-count is 0. + TearDownAndDispose(); + } + else + { + // connection has not been fully setup yet - cannot use TearDownAndDispose(); + // we have to dispose the command and the connection to avoid connection leaks (until GC collects them). + if (_com != null) + { + _com.Dispose(); + _com = null; + } + if (_con != null) + { + _con.Dispose(); + _con = null; + } + + } + throw; + } + } + + // Properties + + internal string Database + { + get + { + if (_cachedDatabase == null) + { + _cachedDatabase = _con.Database; + } + return _cachedDatabase; + } + } + + internal SqlConnectionContainerHashHelper HashHelper => _hashHelper; + + internal bool InErrorState => _errorState; + + internal string Queue => _queue; + + internal string Server => _cachedServer; + + // Methods + + // This function is called by a ThreadPool thread as a result of an AppDomain calling + // SqlDependencyProcessDispatcher.QueueAppDomainUnload on AppDomain.Unload. + internal bool AppDomainUnload(string appDomainKey) + { + Debug.Assert(!string.IsNullOrEmpty(appDomainKey), "Unexpected empty appDomainKey!"); + + // Dictionary used to track how many times start has been called per app domain. + // For each decrement, subtract from count, and delete if we reach 0. + lock (_appDomainKeyHash) + { + if (_appDomainKeyHash.ContainsKey(appDomainKey)) + { // Do nothing if AppDomain did not call Start! + int value = _appDomainKeyHash[appDomainKey]; + Debug.Assert(value > 0, "Why is value 0 or less?"); + + bool ignored = false; + while (value > 0) + { + Debug.Assert(!_stopped, "We should not yet be stopped!"); + Stop(appDomainKey, out ignored); // Stop will decrement value and remove if necessary from _appDomainKeyHash. + value--; + } + + // Stop will remove key when decremented to 0 for this AppDomain, which should now be the case. + Debug.Assert(0 == value, "We did not reach 0 at end of loop in AppDomainUnload!"); + Debug.Assert(!_appDomainKeyHash.ContainsKey(appDomainKey), "Key not removed after AppDomainUnload!"); + } + } + + return _stopped; + } + + private void AsynchronouslyQueryServiceBrokerQueue() + { + AsyncCallback callback = new AsyncCallback(AsyncResultCallback); + _com.BeginExecuteReader(CommandBehavior.Default, callback, null); // NO LOCK NEEDED + } + + private void AsyncResultCallback(IAsyncResult asyncResult) + { + try + { + using (SqlDataReader reader = _com.EndExecuteReader(asyncResult)) + { + ProcessNotificationResults(reader); + } + + // Successfull completion of query - no errors. + if (!_stop) + { + AsynchronouslyQueryServiceBrokerQueue(); // Requeue... + } + else + { + TearDownAndDispose(); + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + // Let the waiting thread detect the error and exit (otherwise, the Stop call loops forever) + _errorState = true; + throw; + } + + if (!_stop) + { // Only assert if not in cancel path. + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + } + + // Failure - likely due to cancelled command. Check _stop state. + if (_stop) + { + TearDownAndDispose(); + } + else + { + _errorState = true; + Restart(null); // Error code path. Will Invalidate based on server if 1st retry fails. + } + } + } + + private void CreateQueueAndService(bool restart) + { + SqlCommand com = new SqlCommand() + { + Connection = _con + }; + SqlTransaction trans = null; + + try + { + trans = _con.BeginTransaction(); // Since we cannot batch proc creation, start transaction. + com.Transaction = trans; + + string nameLiteral = SqlServerEscapeHelper.MakeStringLiteral(_queue); + + com.CommandText = + "CREATE PROCEDURE " + _sprocName + " AS" + + " BEGIN" + + " BEGIN TRANSACTION;" + + " RECEIVE TOP(0) conversation_handle FROM " + _escapedQueueName + ";" + + " IF (SELECT COUNT(*) FROM " + _escapedQueueName + " WHERE message_type_name = 'http://schemas.microsoft.com/SQL/ServiceBroker/DialogTimer') > 0" + + " BEGIN" + + " if ((SELECT COUNT(*) FROM sys.services WHERE name = " + nameLiteral + ") > 0)" + + " DROP SERVICE " + _escapedQueueName + ";" + + " if (OBJECT_ID(" + nameLiteral + ", 'SQ') IS NOT NULL)" + + " DROP QUEUE " + _escapedQueueName + ";" + + " DROP PROCEDURE " + _sprocName + ";" // Don't need conditional because this is self + + " END" + + " COMMIT TRANSACTION;" + + " END"; + + if (!restart) + { + com.ExecuteNonQuery(); + } + else + { // Upon restart, be resilient to the user dropping queue, service, or procedure. + try + { + com.ExecuteNonQuery(); // Cannot add 'IF OBJECT_ID' to create procedure query - wrap and discard failure. + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); + + try + { // Since the failure will result in a rollback, rollback our object. + if (null != trans) + { + trans.Rollback(); + trans = null; + } + } + catch (Exception f) + { + if (!ADP.IsCatchableExceptionType(f)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(f); // Discard failure, but trace for now. + } + } + + if (null == trans) + { // Create a new transaction for next operations. + trans = _con.BeginTransaction(); + com.Transaction = trans; + } + } + + + com.CommandText = + "IF OBJECT_ID(" + nameLiteral + ", 'SQ') IS NULL" + + " BEGIN" + + " CREATE QUEUE " + _escapedQueueName + " WITH ACTIVATION (PROCEDURE_NAME=" + _sprocName + ", MAX_QUEUE_READERS=1, EXECUTE AS OWNER);" + + " END;" + + " IF (SELECT COUNT(*) FROM sys.services WHERE NAME=" + nameLiteral + ") = 0" + + " BEGIN" + + " CREATE SERVICE " + _escapedQueueName + " ON QUEUE " + _escapedQueueName + " ([http://schemas.microsoft.com/SQL/Notifications/PostQueryNotification]);" + + " IF (SELECT COUNT(*) FROM sys.database_principals WHERE name='sql_dependency_subscriber' AND type='R') <> 0" + + " BEGIN" + + " GRANT SEND ON SERVICE::" + _escapedQueueName + " TO sql_dependency_subscriber;" + + " END; " + + " END;" + + " BEGIN DIALOG @dialog_handle FROM SERVICE " + _escapedQueueName + " TO SERVICE " + nameLiteral; + + SqlParameter param = new SqlParameter() + { + ParameterName = "@dialog_handle", + DbType = DbType.Guid, + Direction = ParameterDirection.Output + }; + com.Parameters.Add(param); + com.ExecuteNonQuery(); + + // Finish setting up queries and state. For re-start, we need to ensure we begin a new dialog above and reset + // our queries to use the new dialogHandle. + _dialogHandle = ((Guid)param.Value).ToString(); + _beginConversationQuery = "BEGIN CONVERSATION TIMER ('" + _dialogHandle + "') TIMEOUT = 120; " + _receiveQuery; + _com.CommandText = _beginConversationQuery; + _endConversationQuery = "END CONVERSATION @p1; "; + _concatQuery = _endConversationQuery + _com.CommandText; + + trans.Commit(); + trans = null; + _serviceQueueCreated = true; + } + finally + { + if (null != trans) + { + try + { + trans.Rollback(); + trans = null; + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + } + } + } + } + + internal void IncrementStartCount(string appDomainKey, out bool appDomainStart) + { + appDomainStart = false; // Reset out param. + int result = Interlocked.Increment(ref _startCount); // Add to refCount. + + // Dictionary used to track how many times start has been called per app domain. + // For each increment, add to count, and create entry if not present. + lock (_appDomainKeyHash) + { + if (_appDomainKeyHash.ContainsKey(appDomainKey)) + { + _appDomainKeyHash[appDomainKey] = _appDomainKeyHash[appDomainKey] + 1; + } + else + { + _appDomainKeyHash[appDomainKey] = 1; + appDomainStart = true; + } + } + } + + private void ProcessNotificationResults(SqlDataReader reader) + { + Guid handle = Guid.Empty; // Conversation_handle. Always close this! + try + { + if (!_stop) + { + while (reader.Read()) + { + string msgType = reader.GetString(0); + handle = reader.GetGuid(1); + + // Only process QueryNotification messages. + if (0 == string.Compare(msgType, "http://schemas.microsoft.com/SQL/Notifications/QueryNotification", StringComparison.OrdinalIgnoreCase)) + { + SqlXml payload = reader.GetSqlXml(2); + if (null != payload) + { + SqlNotification notification = SqlNotificationParser.ProcessMessage(payload); + if (null != notification) + { + string key = notification.Key; + int index = key.IndexOf(';'); // Our format is simple: "AppDomainKey;commandHash" + + if (index >= 0) + { // Ensure ';' present. + string appDomainKey = key.Substring(0, index); + SqlDependencyPerAppDomainDispatcher dispatcher; + lock (s_staticInstance._sqlDependencyPerAppDomainDispatchers) + { + dispatcher = s_staticInstance._sqlDependencyPerAppDomainDispatchers[appDomainKey]; + } + if (null != dispatcher) + { + try + { + dispatcher.InvalidateCommandID(notification); // CROSS APP-DOMAIN CALL! + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure. User event could throw exception. + } + } + else + { + Debug.Fail("Received notification but do not have an associated PerAppDomainDispatcher!"); + } + } + else + { + Debug.Fail("Unexpected ID format received!"); + } + } + else + { + Debug.Fail("Null notification returned from ProcessMessage!"); + } + } + else + { + Debug.Fail("Null payload for QN notification type!"); + } + } + else + { + handle = Guid.Empty; + } + } + } + } + finally + { + // Since we do not want to make a separate round trip just for the end conversation call, we need to + // batch it with the next command. + if (handle == Guid.Empty) + { // This should only happen if failure occurred, or if non-QN format received. + _com.CommandText = _beginConversationQuery ?? _receiveQuery; // If we're doing the initial query, we won't have a conversation Guid to begin yet. + if (_com.Parameters.Count > 1) + { // Remove conversation param since next execute is only query. + _com.Parameters.Remove(_conversationGuidParam); + } + Debug.Assert(_com.Parameters.Count == 1, "Unexpected number of parameters!"); + } + else + { + _com.CommandText = _concatQuery; // END query + WAITFOR RECEIVE query. + _conversationGuidParam.Value = handle; // Set value for conversation handle. + if (_com.Parameters.Count == 1) + { // Add parameter if previous execute was only query. + _com.Parameters.Add(_conversationGuidParam); + } + Debug.Assert(_com.Parameters.Count == 2, "Unexpected number of parameters!"); + } + } + } + + // SxS: this method uses WindowsIdentity.Impersonate to impersonate the current thread with the + // credentials used to create this SqlConnectionContainer. + [ResourceExposure(ResourceScope.None)] + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + private void Restart(object unused) + { // Unused arg required by TimerCallback. + try + { + lock (this) + { + if (!_stop) + { // Only execute if we are still in running state. + try + { + _con.Close(); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard close failure, if it occurs. Only trace it. + } + } + } + + // Rather than one long lock - take lock 3 times for shorter periods. + + lock (this) + { + if (!_stop) + { + _con.Open(); + } + } + + lock (this) + { + if (!_stop) + { + if (_serviceQueueCreated) + { + bool failure = false; + + try + { + CreateQueueAndService(true); // Ensure service, queue, etc is present, if we created it. + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace for now. + failure = true; + } + + if (failure) + { + // If we failed to re-created queue, service, sproc - invalidate! + s_staticInstance.Invalidate(Server, + new SqlNotification(SqlNotificationInfo.Error, + SqlNotificationSource.Client, + SqlNotificationType.Change, + null)); + + } + } + } + } + + lock (this) + { + if (!_stop) + { + _timeoutParam.Value = 0; // Reset timeout to zero - we do not want to block. + SynchronouslyQueryServiceBrokerQueue(); + // If the above succeeds, we are back in success case - requeue for async call. + _timeoutParam.Value = _defaultWaitforTimeout; // If success, reset to default for re-queue. + AsynchronouslyQueryServiceBrokerQueue(); + _errorState = false; + _retryTimer = null; + } + } + + if (_stop) + { + TearDownAndDispose(); // Function will lock(this). + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); + + try + { + // If unexpected query or connection failure, invalidate all dependencies against this server. + // We may over-notify if only some of the connections to a particular database were affected, + // but this should not be frequent enough to be a concern. + // NOTE - we invalidate after failure first occurs and then retry fails. We will then continue + // to invalidate every time the retry fails. + s_staticInstance.Invalidate(Server, + new SqlNotification(SqlNotificationInfo.Error, + SqlNotificationSource.Client, + SqlNotificationType.Change, + null)); + } + catch (Exception f) + { + if (!ADP.IsCatchableExceptionType(f)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(f); // Discard exception from Invalidate. User events can throw. + } + + try + { + _con.Close(); + } + catch (Exception f) + { + if (!ADP.IsCatchableExceptionType(f)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(f); // Discard close failure, if it occurs. Only trace it. + } + + if (!_stop) + { + // Create a timer to callback in one minute, retrying the call to Restart(). + _retryTimer = new Timer(new TimerCallback(Restart), null, _defaultWaitforTimeout, Timeout.Infinite); + // We will retry this indefinitely, until success - or Stop(); + } + } + } + + internal bool Stop(string appDomainKey, out bool appDomainStop) + { + appDomainStop = false; + + // Dictionary used to track how many times start has been called per app domain. + // For each decrement, subtract from count, and delete if we reach 0. + + if (null != appDomainKey) + { + // If null, then this was called from SqlDependencyProcessDispatcher, we ignore appDomainKeyHash. + lock (_appDomainKeyHash) + { + if (_appDomainKeyHash.ContainsKey(appDomainKey)) + { // Do nothing if AppDomain did not call Start! + int value = _appDomainKeyHash[appDomainKey]; + + Debug.Assert(value > 0, "Unexpected count for appDomainKey"); + + if (value > 0) + { + _appDomainKeyHash[appDomainKey] = value - 1; + } + else + { + Debug.Fail("Unexpected AppDomainKey count in Stop()"); + } + + if (1 == value) + { // Remove from dictionary if pre-decrement count was one. + _appDomainKeyHash.Remove(appDomainKey); + appDomainStop = true; + } + } + else + { + Debug.Fail("Unexpected state on Stop() - no AppDomainKey entry in hashtable!"); + } + } + } + + Debug.Assert(_startCount > 0, "About to decrement _startCount less than 0!"); + int result = Interlocked.Decrement(ref _startCount); + + if (0 == result) + { // If we've reached refCount 0, destroy. + // Lock to ensure Cancel() complete prior to other thread calling TearDown. + lock (this) + { + try + { + // Race condition with executing thread - will throw if connection is closed due to failure. + // Rather than fighting the race condition, just call it and discard any potential failure. + _com.Cancel(); // Cancel the pending command. No-op if connection closed. + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, if it should occur. + } + _stop = true; + } + + // Wait until stopped and service & queue are dropped. + Stopwatch retryStopwatch = Stopwatch.StartNew(); + while (true) + { + lock (this) + { + if (_stopped) + { + break; + } + + // If we are in error state (_errorState is true), force a tear down. + // Likewise, if we have exceeded the maximum retry period (30 seconds) waiting for cleanup, force a tear down. + // In rare cases during app domain unload, the async cleanup performed by AsyncResultCallback + // may fail to execute TearDownAndDispose, leaving this method in an infinite loop. + // To avoid the infinite loop, we force the cleanup here after 30 seconds. Since we have reached + // refcount of 0, either this method call or the thread running AsyncResultCallback is responsible for calling + // TearDownAndDispose when transitioning to the _stopped state. Failing to call TearDownAndDispose means we leak + // the service broker objects created by this SqlDependency instance, so we make a best effort here to call + // TearDownAndDispose in the maximum retry period case as well as in the _errorState case. + if (_errorState || retryStopwatch.Elapsed.Seconds >= 30) + { + Timer retryTimer = _retryTimer; + _retryTimer = null; + if (retryTimer != null) + { + retryTimer.Dispose(); // Dispose timer - stop retry loop! + } + TearDownAndDispose(); // Will not hit server unless connection open! + break; + } + } + + // Yield the thread since the stop has not yet completed. + // To avoid CPU spikes while waiting, yield and wait for at least one millisecond + Thread.Sleep(1); + } + } + + Debug.Assert(0 <= _startCount, "Invalid start count state"); + + return _stopped; + } + + private void SynchronouslyQueryServiceBrokerQueue() + { + using (SqlDataReader reader = _com.ExecuteReader()) + { + ProcessNotificationResults(reader); + } + } + + [SuppressMessage("Microsoft.Security", "CA2100:ReviewSqlQueriesForSecurityVulnerabilities")] + private void TearDownAndDispose() + { + lock (this) + { // Lock to ensure Stop() (with Cancel()) complete prior to TearDown. + try + { + // Only execute if connection is still up and open. + if (ConnectionState.Closed != _con.State && ConnectionState.Broken != _con.State) + { + if (_com.Parameters.Count > 1) + { // Need to close dialog before completing. + // In the normal case, the "End Conversation" query is executed before a + // receive query and upon return we will clear the state. However, unless + // a non notification query result is returned, we will not clear it. That + // means a query is generally always executing with an "end conversation" on + // the wire. Rather than synchronize for success of the other "end conversation", + // simply re-execute. + try + { + _com.CommandText = _endConversationQuery; + _com.Parameters.Remove(_timeoutParam); + _com.ExecuteNonQuery(); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure. + } + } + + if (_serviceQueueCreated && !_errorState) + { + /* + BEGIN TRANSACTION; + DROP SERVICE "+_escapedQueueName+"; + DROP QUEUE "+_escapedQueueName+"; + DROP PROCEDURE "+_sprocName+"; + COMMIT TRANSACTION; + */ + _com.CommandText = "BEGIN TRANSACTION; DROP SERVICE " + _escapedQueueName + "; DROP QUEUE " + _escapedQueueName + "; DROP PROCEDURE " + _sprocName + "; COMMIT TRANSACTION;"; + try + { + _com.ExecuteNonQuery(); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure. + } + } + } + } + finally + { + _stopped = true; + _con.Dispose(); // Close and dispose connection. + } + } + } + } + + // Private class encapsulating the notification payload parsing logic. + + private class SqlNotificationParser + { + [Flags] + private enum MessageAttributes + { + None = 0, + Type = 1, + Source = 2, + Info = 4, + All = Type + Source + Info, + } + + // node names in the payload + private const string RootNode = "QueryNotification"; + private const string MessageNode = "Message"; + + // attribute names (on the QueryNotification element) + private const string InfoAttribute = "info"; + private const string SourceAttribute = "source"; + private const string TypeAttribute = "type"; + + internal static SqlNotification ProcessMessage(SqlXml xmlMessage) + { + using (XmlReader xmlReader = xmlMessage.CreateReader()) + { + string keyvalue = string.Empty; + + MessageAttributes messageAttributes = MessageAttributes.None; + + SqlNotificationType type = SqlNotificationType.Unknown; + SqlNotificationInfo info = SqlNotificationInfo.Unknown; + SqlNotificationSource source = SqlNotificationSource.Unknown; + + string key = string.Empty; + + // Move to main node, expecting "QueryNotification". + xmlReader.Read(); + if ((XmlNodeType.Element == xmlReader.NodeType) && + (RootNode == xmlReader.LocalName) && + (3 <= xmlReader.AttributeCount)) + { + // Loop until we've processed all the attributes. + while ((MessageAttributes.All != messageAttributes) && (xmlReader.MoveToNextAttribute())) + { + try + { + switch (xmlReader.LocalName) + { + case TypeAttribute: + try + { + SqlNotificationType temp = (SqlNotificationType)Enum.Parse(typeof(SqlNotificationType), xmlReader.Value, true); + if (Enum.IsDefined(typeof(SqlNotificationType), temp)) + { + type = temp; + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, if it should occur. + } + messageAttributes |= MessageAttributes.Type; + break; + case SourceAttribute: + try + { + SqlNotificationSource temp = (SqlNotificationSource)Enum.Parse(typeof(SqlNotificationSource), xmlReader.Value, true); + if (Enum.IsDefined(typeof(SqlNotificationSource), temp)) + { + source = temp; + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, if it should occur. + } + messageAttributes |= MessageAttributes.Source; + break; + case InfoAttribute: + try + { + string value = xmlReader.Value; + // 3 of the server info values do not match client values - map. + switch (value) + { + case "set options": + info = SqlNotificationInfo.Options; + break; + case "previous invalid": + info = SqlNotificationInfo.PreviousFire; + break; + case "query template limit": + info = SqlNotificationInfo.TemplateLimit; + break; + default: + SqlNotificationInfo temp = (SqlNotificationInfo)Enum.Parse(typeof(SqlNotificationInfo), value, true); + if (Enum.IsDefined(typeof(SqlNotificationInfo), temp)) + { + info = temp; + } + break; + } + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, if it should occur. + } + messageAttributes |= MessageAttributes.Info; + break; + default: + break; + } + } + catch (ArgumentException e) + { + ADP.TraceExceptionWithoutRethrow(e); // Discard failure, but trace. + return null; + } + } + + if (MessageAttributes.All != messageAttributes) + { + return null; + } + + // Proceed to the "Message" node. + if (!xmlReader.Read()) + { + return null; + } + + // Verify state after Read(). + if ((XmlNodeType.Element != xmlReader.NodeType) || (0 != string.Compare(xmlReader.LocalName, MessageNode, StringComparison.OrdinalIgnoreCase))) + { + return null; + } + + // Proceed to the Text Node. + if (!xmlReader.Read()) + { + return null; + } + + // Verify state after Read(). + if (xmlReader.NodeType != XmlNodeType.Text) + { + return null; + } + + // Create a new XmlTextReader on the Message node value. + using (XmlTextReader xmlMessageReader = new XmlTextReader(xmlReader.Value, XmlNodeType.Element, null)) + { + // Proceed to the Text Node. + if (!xmlMessageReader.Read()) + { + return null; + } + + if (xmlMessageReader.NodeType == XmlNodeType.Text) + { + key = xmlMessageReader.Value; + xmlMessageReader.Close(); + } + else + { + return null; + } + } + + return new SqlNotification(info, source, type, key); + } + else + { + return null; // failure + } + } + } + } + + // Private class encapsulating the SqlConnectionContainer hash logic. + + private class SqlConnectionContainerHashHelper + { + // For default, queue is computed in SqlConnectionContainer constructor, so queue will be empty and + // connection string will not include app name based on queue. As a result, the connection string + // builder will always contain up to date info, but _connectionString and _queue will not. + + // As a result, we will not use _connectionStringBuilder as part of Equals or GetHashCode. + + private DbConnectionPoolIdentity _identity; + private string _connectionString; + private string _queue; + private SqlConnectionStringBuilder _connectionStringBuilder; // Not to be used for comparison! + + internal SqlConnectionContainerHashHelper(DbConnectionPoolIdentity identity, string connectionString, + string queue, SqlConnectionStringBuilder connectionStringBuilder) + { + _identity = identity; + _connectionString = connectionString; + _queue = queue; + _connectionStringBuilder = connectionStringBuilder; + } + + // Not to be used for comparison! + internal SqlConnectionStringBuilder ConnectionStringBuilder => _connectionStringBuilder; + + internal DbConnectionPoolIdentity Identity => _identity; + + internal string Queue => _queue; + + public override bool Equals(object value) + { + SqlConnectionContainerHashHelper temp = (SqlConnectionContainerHashHelper)value; + + bool result = false; + + // Ignore SqlConnectionStringBuilder, since it is present largely for debug purposes. + + if (null == temp) + { // If passed value null - false. + result = false; + } + else if (this == temp) + { // If instances equal - true. + result = true; + } + else + { + if ((_identity != null && temp._identity == null) || // If XOR of null identities false - false. + (_identity == null && temp._identity != null)) + { + result = false; + } + else if (_identity == null && temp._identity == null) + { + if (temp._connectionString == _connectionString && + string.Equals(temp._queue, _queue, StringComparison.OrdinalIgnoreCase)) + { + result = true; + } + else + { + result = false; + } + } + else + { + if (temp._identity.Equals(_identity) && + temp._connectionString == _connectionString && + string.Equals(temp._queue, _queue, StringComparison.OrdinalIgnoreCase)) + { + result = true; + } + else + { + result = false; + } + } + } + + return result; + } + + public override int GetHashCode() + { + int hashValue = 0; + + if (null != _identity) + { + hashValue = _identity.GetHashCode(); + } + + if (null != _queue) + { + hashValue = unchecked(_connectionString.GetHashCode() + _queue.GetHashCode() + hashValue); + } + else + { + hashValue = unchecked(_connectionString.GetHashCode() + hashValue); + } + + return hashValue; + } + } + + // SqlDependencyProcessDispatcher static members + + private static SqlDependencyProcessDispatcher s_staticInstance = new SqlDependencyProcessDispatcher(null); + + // Dictionaries used as maps. + private Dictionary _connectionContainers; // NT_ID+ConStr+Service->Container + private Dictionary _sqlDependencyPerAppDomainDispatchers; // AppDomainKey->Callback + + // Constructors + + // Private constructor - only called by public constructor for static initialization. + private SqlDependencyProcessDispatcher(object dummyVariable) + { + Debug.Assert(null == s_staticInstance, "Real constructor called with static instance already created!"); + + _connectionContainers = new Dictionary(); + _sqlDependencyPerAppDomainDispatchers = new Dictionary(); + } + + // Constructor is only called by remoting. + // Required to be public, even on internal class, for Remoting infrastructure. + public SqlDependencyProcessDispatcher() + { + // Empty constructor and object - dummy to obtain singleton. + } + + // Properties + + internal static SqlDependencyProcessDispatcher SingletonProcessDispatcher => s_staticInstance; + + // Various private methods + + private static SqlConnectionContainerHashHelper GetHashHelper( + string connectionString, + out SqlConnectionStringBuilder connectionStringBuilder, + out DbConnectionPoolIdentity identity, + out string user, + string queue) + { + // Force certain connection string properties to be used by SqlDependencyProcessDispatcher. + // This logic is done here to enable us to have the complete connection string now to be used + // for tracing as we flow through the logic. + connectionStringBuilder = new SqlConnectionStringBuilder(connectionString) + { + Pooling = false, + Enlist = false, + ConnectRetryCount = 0 + }; + if (null != queue) + { // User provided! + connectionStringBuilder.ApplicationName = queue; // ApplicationName will be set to queue name. + } + + if (connectionStringBuilder.IntegratedSecurity) + { + // Use existing identity infrastructure for error cases and proper hash value. + identity = DbConnectionPoolIdentity.GetCurrent(); + user = null; + } + else + { + identity = null; + user = connectionStringBuilder.UserID; + } + + return new SqlConnectionContainerHashHelper(identity, connectionStringBuilder.ConnectionString, + queue, connectionStringBuilder); + } + + // Needed for remoting to prevent lifetime issues and default GC cleanup. + public override object InitializeLifetimeService() + { + return null; + } + + private void Invalidate(string server, SqlNotification sqlNotification) + { + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + lock (_sqlDependencyPerAppDomainDispatchers) + { + + foreach (KeyValuePair entry in _sqlDependencyPerAppDomainDispatchers) + { + SqlDependencyPerAppDomainDispatcher perAppDomainDispatcher = entry.Value; + try + { + perAppDomainDispatcher.InvalidateServer(server, sqlNotification); + } + catch (Exception f) + { + // Since we are looping over dependency dispatchers, do not allow one Invalidate + // that results in a throw prevent us from invalidating all dependencies + // related to this server. + // NOTE - SqlDependencyPerAppDomainDispatcher already wraps individual dependency invalidates + // with try/catch, but we should be careful and do the same here. + if (!ADP.IsCatchableExceptionType(f)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(f); // Discard failure, but trace. + } + } + } + } + + // Clean-up method initiated by other AppDomain.Unloads + + // Individual AppDomains upon AppDomain.UnloadEvent will call this method. + internal void QueueAppDomainUnloading(string appDomainKey) + { + ThreadPool.QueueUserWorkItem(new WaitCallback(AppDomainUnloading), appDomainKey); + } + + // This method is only called by queued work-items from the method above. + private void AppDomainUnloading(object state) + { + string appDomainKey = (string)state; + + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + lock (_connectionContainers) + { + List containersToRemove = new List(); + + foreach (KeyValuePair entry in _connectionContainers) + { + SqlConnectionContainer container = entry.Value; + if (container.AppDomainUnload(appDomainKey)) + { // Perhaps wrap in try catch. + containersToRemove.Add(container.HashHelper); + } + } + + foreach (SqlConnectionContainerHashHelper hashHelper in containersToRemove) + { + _connectionContainers.Remove(hashHelper); + } + } + + lock (_sqlDependencyPerAppDomainDispatchers) + { // Remove from global Dictionary. + _sqlDependencyPerAppDomainDispatchers.Remove(appDomainKey); + } + } + + // ------------- + // Start methods + // ------------- + + internal bool StartWithDefault( + string connectionString, + out string server, + out DbConnectionPoolIdentity identity, + out string user, + out string database, + ref string service, + string appDomainKey, + SqlDependencyPerAppDomainDispatcher dispatcher, + out bool errorOccurred, + out bool appDomainStart) + { + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + return Start( + connectionString, + out server, + out identity, + out user, + out database, + ref service, + appDomainKey, + dispatcher, + out errorOccurred, + out appDomainStart, + true); + } + + internal bool Start( + string connectionString, + string queue, + string appDomainKey, + SqlDependencyPerAppDomainDispatcher dispatcher) + { + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + return Start( + connectionString, + out string dummyValue1, + out DbConnectionPoolIdentity dummyValue3, + out dummyValue1, + out dummyValue1, + ref queue, + appDomainKey, + dispatcher, + out bool dummyValue2, + out dummyValue2, + false); + } + + private bool Start( + string connectionString, + out string server, + out DbConnectionPoolIdentity identity, + out string user, + out string database, + ref string queueService, + string appDomainKey, + SqlDependencyPerAppDomainDispatcher dispatcher, + out bool errorOccurred, + out bool appDomainStart, + bool useDefaults) + { + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + server = null; // Reset out params. + identity = null; + user = null; + database = null; + errorOccurred = false; + appDomainStart = false; + + lock (_sqlDependencyPerAppDomainDispatchers) + { + if (!_sqlDependencyPerAppDomainDispatchers.ContainsKey(appDomainKey)) + { + _sqlDependencyPerAppDomainDispatchers[appDomainKey] = dispatcher; + } + } + + SqlConnectionContainerHashHelper hashHelper = GetHashHelper(connectionString, + out SqlConnectionStringBuilder connectionStringBuilder, + out identity, + out user, + queueService); + + bool started = false; + + SqlConnectionContainer container = null; + lock (_connectionContainers) + { + if (!_connectionContainers.ContainsKey(hashHelper)) + { + container = new SqlConnectionContainer(hashHelper, appDomainKey, useDefaults); + _connectionContainers.Add(hashHelper, container); + started = true; + appDomainStart = true; + } + else + { + container = _connectionContainers[hashHelper]; + if (container.InErrorState) + { + errorOccurred = true; // Set outparam errorOccurred true so we invalidate on Start(). + } + else + { + container.IncrementStartCount(appDomainKey, out appDomainStart); + } + } + } + + if (useDefaults && !errorOccurred) + { // Return server, database, and queue for use by SqlDependency. + server = container.Server; + database = container.Database; + queueService = container.Queue; + } + + return started; + } + + // Stop methods + + internal bool Stop( + string connectionString, + out string server, + out DbConnectionPoolIdentity identity, + out string user, + out string database, + ref string queueService, + string appDomainKey, + out bool appDomainStop) + { + Debug.Assert(this == s_staticInstance, "Instance method called on non _staticInstance instance!"); + server = null; // Reset out param. + identity = null; + user = null; + database = null; + appDomainStop = false; + + SqlConnectionContainerHashHelper hashHelper = GetHashHelper(connectionString, + out SqlConnectionStringBuilder connectionStringBuilder, + out identity, + out user, + queueService); + + bool stopped = false; + + lock (_connectionContainers) + { + if (_connectionContainers.ContainsKey(hashHelper)) + { + SqlConnectionContainer container = _connectionContainers[hashHelper]; + server = container.Server; // Return server, database, and queue info for use by calling SqlDependency. + database = container.Database; + queueService = container.Queue; + + if (container.Stop(appDomainKey, out appDomainStop)) + { // Stop can be blocking if refCount == 0 on container. + stopped = true; + _connectionContainers.Remove(hashHelper); // Remove from collection. + } + } + } + + return stopped; + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyUtils.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyUtils.cs new file mode 100644 index 000000000000..ba360d3b4af3 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDependencyUtils.cs @@ -0,0 +1,510 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics; +using System.Threading; + +namespace System.Data.SqlClient +{ + // This is a singleton instance per AppDomain that acts as the notification dispatcher for + // that AppDomain. It receives calls from the SqlDependencyProcessDispatcher with an ID or a server name + // to invalidate matching dependencies in the given AppDomain. + + internal class SqlDependencyPerAppDomainDispatcher : MarshalByRefObject + { + // Instance members + + internal static readonly SqlDependencyPerAppDomainDispatcher + SingletonInstance = new SqlDependencyPerAppDomainDispatcher(); // singleton object + + internal object _instanceLock = new object(); + + // Dependency ID -> Dependency hashtable. 1 -> 1 mapping. + // 1) Used for ASP.Net to map from ID to dependency. + // 2) Used to enumerate dependencies to invalidate based on server. + private Dictionary _dependencyIdToDependencyHash; + + // holds dependencies list per notification and the command hash from which this notification was generated + // command hash is needed to remove its entry from _commandHashToNotificationId when the notification is removed + private sealed class DependencyList : List + { + public readonly string CommandHash; + + internal DependencyList(string commandHash) + { + CommandHash = commandHash; + } + } + + // notificationId -> Dependencies hashtable: 1 -> N mapping. notificationId == appDomainKey + commandHash. + // More than one dependency can be using the same command hash values resulting in a hash to the same value. + // We use this to cache mapping between command to dependencies such that we may reduce the notification + // resource effect on SQL Server. The Guid identifier is sent to the server during notification enlistment, + // and returned during the notification event. Dependencies look up existing Guids, if one exists, to ensure + // they are re-using notification ids. + private Dictionary _notificationIdToDependenciesHash; + + // CommandHash value -> notificationId associated with it: 1->1 mapping. This map is used to quickly find if we need to create + // new notification or hookup into existing one. + // CommandHash is built from connection string, command text and parameters + private Dictionary _commandHashToNotificationId; + + // TIMEOUT LOGIC DESCRIPTION + // + // Every time we add a dependency we compute the next, earlier timeout. + // + // We setup a timer to get a callback every 15 seconds. In the call back: + // - If there are no active dependencies, we just return. + // - If there are dependencies but none of them timed-out (compared to the "next timeout"), + // we just return. + // - Otherwise we Invalidate() those that timed-out. + // + // So the client-generated timeouts have a granularity of 15 seconds. This allows + // for a simple and low-resource-consumption implementation. + // + // LOCKS: don't update _nextTimeout outside of the _dependencyHash.SyncRoot lock. + + private bool _sqlDependencyTimeOutTimerStarted = false; + // Next timeout for any of the dependencies in the dependency table. + private DateTime _nextTimeout; + // Timer to periodically check the dependencies in the table and see if anyone needs + // a timeout. We'll enable this only on demand. + private Timer _timeoutTimer; + + private SqlDependencyPerAppDomainDispatcher() + { + _dependencyIdToDependencyHash = new Dictionary(); + _notificationIdToDependenciesHash = new Dictionary(); + _commandHashToNotificationId = new Dictionary(); + + _timeoutTimer = new Timer(new TimerCallback(TimeoutTimerCallback), null, Timeout.Infinite, Timeout.Infinite); + + // If rude abort - we'll leak. This is acceptable for now. + AppDomain.CurrentDomain.DomainUnload += new EventHandler(UnloadEventHandler); + } + + // When remoted across appdomains, MarshalByRefObject links by default time out if there is no activity + // within a few minutes. Add this override to prevent marshaled links from timing out. + public override object InitializeLifetimeService() + { + return null; + } + + // Events + + private void UnloadEventHandler(object sender, EventArgs e) + { + // Make non-blocking call to ProcessDispatcher to ThreadPool.QueueUserWorkItem to complete + // stopping of all start calls in this AppDomain. For containers shared among various AppDomains, + // this will just be a ref-count subtract. For non-shared containers, we will close the container + // and clean-up. + SqlDependencyProcessDispatcher dispatcher = SqlDependency.ProcessDispatcher; + if (null != dispatcher) + { + dispatcher.QueueAppDomainUnloading(SqlDependency.AppDomainKey); + } + } + + // Methods for dependency hash manipulation and firing. + + // This method is called upon SqlDependency constructor. + internal void AddDependencyEntry(SqlDependency dep) + { + lock (_instanceLock) + { + _dependencyIdToDependencyHash.Add(dep.Id, dep); + } + } + + // This method is called upon Execute of a command associated with a SqlDependency object. + internal string AddCommandEntry(string commandHash, SqlDependency dep) + { + string notificationId = string.Empty; + lock (_instanceLock) + { + if (_dependencyIdToDependencyHash.ContainsKey(dep.Id)) + { + // check if we already have notification associated with given command hash + if (_commandHashToNotificationId.TryGetValue(commandHash, out notificationId)) + { + // we have one or more SqlDependency instances with same command hash + + DependencyList dependencyList = null; + if (!_notificationIdToDependenciesHash.TryGetValue(notificationId, out dependencyList)) + { + // this should not happen since _commandHashToNotificationId and _notificationIdToDependenciesHash are always + // updated together + Debug.Fail("_commandHashToNotificationId has entries that were removed from _notificationIdToDependenciesHash. Remember to keep them in sync"); + throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyCommandHashIsNotAssociatedWithNotification); + } + + // join the new dependency to the list + if (!dependencyList.Contains(dep)) + { + dependencyList.Add(dep); + } + } + else + { + // we did not find notification ID with the same app domain and command hash, create a new one + // use unique guid to avoid duplicate IDs + // prepend app domain ID to the key - SqlConnectionContainer::ProcessNotificationResults (SqlDependencyListener.cs) + // uses this app domain ID to route the message back to the app domain in which this SqlDependency was created + notificationId = string.Format(System.Globalization.CultureInfo.InvariantCulture, + "{0};{1}", + SqlDependency.AppDomainKey, // must be first + Guid.NewGuid().ToString("D", System.Globalization.CultureInfo.InvariantCulture) + ); + + DependencyList dependencyList = new DependencyList(commandHash); + dependencyList.Add(dep); + + // map command hash to notification we just created to reuse it for the next client + _commandHashToNotificationId.Add(commandHash, notificationId); + _notificationIdToDependenciesHash.Add(notificationId, dependencyList); + } + + Debug.Assert(_notificationIdToDependenciesHash.Count == _commandHashToNotificationId.Count, "always keep these maps in sync!"); + } + } + + return notificationId; + } + + // This method is called by the ProcessDispatcher upon a notification for this AppDomain. + internal void InvalidateCommandID(SqlNotification sqlNotification) + { + List dependencyList = null; + + lock (_instanceLock) + { + dependencyList = LookupCommandEntryWithRemove(sqlNotification.Key); + + if (null != dependencyList) + { + foreach (SqlDependency dependency in dependencyList) + { + // Ensure we remove from process static app domain hash for dependency initiated invalidates. + LookupDependencyEntryWithRemove(dependency.Id); + + // Completely remove Dependency from commandToDependenciesHash. + RemoveDependencyFromCommandToDependenciesHash(dependency); + } + } + } + + if (null != dependencyList) + { + // After removal from hashtables, invalidate. + foreach (SqlDependency dependency in dependencyList) + { + try + { + dependency.Invalidate(sqlNotification.Type, sqlNotification.Info, sqlNotification.Source); + } + catch (Exception e) + { + // Since we are looping over dependencies, do not allow one Invalidate + // that results in a throw prevent us from invalidating all dependencies + // related to this server. + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); + } + } + } + } + + // This method is called when a connection goes down or other unknown error occurs in the ProcessDispatcher. + internal void InvalidateServer(string server, SqlNotification sqlNotification) + { + List dependencies = new List(); + + lock (_instanceLock) + { // Copy inside of lock, but invalidate outside of lock. + foreach (KeyValuePair entry in _dependencyIdToDependencyHash) + { + SqlDependency dependency = entry.Value; + if (dependency.ContainsServer(server)) + { + dependencies.Add(dependency); + } + } + + foreach (SqlDependency dependency in dependencies) + { // Iterate over resulting list removing from our hashes. + // Ensure we remove from process static app domain hash for dependency initiated invalidates. + LookupDependencyEntryWithRemove(dependency.Id); + + // Completely remove Dependency from commandToDependenciesHash. + RemoveDependencyFromCommandToDependenciesHash(dependency); + } + } + + foreach (SqlDependency dependency in dependencies) + { // Iterate and invalidate. + try + { + dependency.Invalidate(sqlNotification.Type, sqlNotification.Info, sqlNotification.Source); + } + catch (Exception e) + { + // Since we are looping over dependencies, do not allow one Invalidate + // that results in a throw prevent us from invalidating all dependencies + // related to this server. + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + ADP.TraceExceptionWithoutRethrow(e); + } + } + } + + // This method is called by SqlCommand to enable ASP.Net scenarios - map from ID to Dependency. + internal SqlDependency LookupDependencyEntry(string id) + { + if (null == id) + { + throw ADP.ArgumentNull(nameof(id)); + } + if (string.IsNullOrEmpty(id)) + { + throw SQL.SqlDependencyIdMismatch(); + } + + SqlDependency entry = null; + + lock (_instanceLock) + { + if (_dependencyIdToDependencyHash.ContainsKey(id)) + { + entry = _dependencyIdToDependencyHash[id]; + } + } + + return entry; + } + + // Remove the dependency from the hashtable with the passed id. + private void LookupDependencyEntryWithRemove(string id) + { + lock (_instanceLock) + { + if (_dependencyIdToDependencyHash.ContainsKey(id)) + { + _dependencyIdToDependencyHash.Remove(id); + + // if there are no more dependencies then we can dispose the timer. + if (0 == _dependencyIdToDependencyHash.Count) + { + _timeoutTimer.Change(Timeout.Infinite, Timeout.Infinite); + _sqlDependencyTimeOutTimerStarted = false; + } + } + } + } + + // Find and return arraylist, and remove passed hash value. + private List LookupCommandEntryWithRemove(string notificationId) + { + DependencyList entry = null; + + lock (_instanceLock) + { + if (_notificationIdToDependenciesHash.TryGetValue(notificationId, out entry)) + { + // update the tables + _notificationIdToDependenciesHash.Remove(notificationId); + // Cleanup the map between the command hash and associated notification ID + _commandHashToNotificationId.Remove(entry.CommandHash); + } + + Debug.Assert(_notificationIdToDependenciesHash.Count == _commandHashToNotificationId.Count, "always keep these maps in sync!"); + } + + return entry; // DependencyList inherits from List + } + + // Remove from commandToDependenciesHash all references to the passed dependency. + private void RemoveDependencyFromCommandToDependenciesHash(SqlDependency dependency) + { + lock (_instanceLock) + { + List notificationIdsToRemove = new List(); + List commandHashesToRemove = new List(); + + foreach (KeyValuePair entry in _notificationIdToDependenciesHash) + { + DependencyList dependencies = entry.Value; + if (dependencies.Remove(dependency)) + { + if (dependencies.Count == 0) + { + // this dependency was the last associated with this notification ID, remove the entry + // note: cannot do it inside foreach over dictionary + notificationIdsToRemove.Add(entry.Key); + commandHashesToRemove.Add(entry.Value.CommandHash); + } + } + + // same SqlDependency can be associated with more than one command, so we have to continue till the end... + } + + Debug.Assert(commandHashesToRemove.Count == notificationIdsToRemove.Count, "maps should be kept in sync"); + for (int i = 0; i < notificationIdsToRemove.Count; i++) + { + // cleanup the entry outside of foreach + _notificationIdToDependenciesHash.Remove(notificationIdsToRemove[i]); + // Cleanup the map between the command hash and associated notification ID + _commandHashToNotificationId.Remove(commandHashesToRemove[i]); + } + + Debug.Assert(_notificationIdToDependenciesHash.Count == _commandHashToNotificationId.Count, "always keep these maps in sync!"); + } + } + + // Methods for Timer maintenance and firing. + + internal void StartTimer(SqlDependency dep) + { + // If this dependency expires sooner than the current next timeout, change + // the timeout and enable timer callback as needed. Note that we change _nextTimeout + // only inside the hashtable syncroot. + lock (_instanceLock) + { + // Enable the timer if needed (disable when empty, enable on the first addition). + if (!_sqlDependencyTimeOutTimerStarted) + { + _timeoutTimer.Change(15000 /* 15 secs */, 15000 /* 15 secs */); + + // Save this as the earlier timeout to come. + _nextTimeout = dep.ExpirationTime; + _sqlDependencyTimeOutTimerStarted = true; + } + else if (_nextTimeout > dep.ExpirationTime) + { + // Save this as the earlier timeout to come. + _nextTimeout = dep.ExpirationTime; + } + } + } + + private static void TimeoutTimerCallback(object state) + { + SqlDependency[] dependencies; + + // Only take the lock for checking whether there is work to do + // if we do have work, we'll copy the hashtable and scan it after releasing + // the lock. + lock (SingletonInstance._instanceLock) + { + if (0 == SingletonInstance._dependencyIdToDependencyHash.Count) + { + // Nothing to check. + return; + } + if (SingletonInstance._nextTimeout > DateTime.UtcNow) + { + // No dependency timed-out yet. + return; + } + + // If at least one dependency timed-out do a scan of the table. + // NOTE: we could keep a shadow table sorted by expiration time, but + // given the number of typical simultaneously alive dependencies it's + // probably not worth the optimization. + dependencies = new SqlDependency[SingletonInstance._dependencyIdToDependencyHash.Count]; + SingletonInstance._dependencyIdToDependencyHash.Values.CopyTo(dependencies, 0); + } + + // Scan the active dependencies if needed. + DateTime now = DateTime.UtcNow; + DateTime newNextTimeout = DateTime.MaxValue; + + for (int i = 0; i < dependencies.Length; i++) + { + // If expired fire the change notification. + if (dependencies[i].ExpirationTime <= now) + { + try + { + // This invokes user-code which may throw exceptions. + // NOTE: this is intentionally outside of the lock, we don't want + // to invoke user-code while holding an internal lock. + dependencies[i].Invalidate(SqlNotificationType.Change, SqlNotificationInfo.Error, SqlNotificationSource.Timeout); + } + catch (Exception e) + { + if (!ADP.IsCatchableExceptionType(e)) + { + throw; + } + + // This is an exception in user code, and we're in a thread-pool thread + // without user's code up in the stack, no much we can do other than + // eating the exception. + ADP.TraceExceptionWithoutRethrow(e); + } + } + else + { + if (dependencies[i].ExpirationTime < newNextTimeout) + { + newNextTimeout = dependencies[i].ExpirationTime; // Track the next earlier timeout. + } + dependencies[i] = null; // Null means "don't remove it from the hashtable" in the loop below. + } + } + + // Remove timed-out dependencies from the hashtable. + lock (SingletonInstance._instanceLock) + { + for (int i = 0; i < dependencies.Length; i++) + { + if (null != dependencies[i]) + { + SingletonInstance._dependencyIdToDependencyHash.Remove(dependencies[i].Id); + } + } + if (newNextTimeout < SingletonInstance._nextTimeout) + { + SingletonInstance._nextTimeout = newNextTimeout; // We're inside the lock so ok to update. + } + } + } + } + + // Simple class used to encapsulate all data in a notification. + internal class SqlNotification : MarshalByRefObject + { + // This class could be Serializable rather than MBR... + + private readonly SqlNotificationInfo _info; + private readonly SqlNotificationSource _source; + private readonly SqlNotificationType _type; + private readonly string _key; + + internal SqlNotification(SqlNotificationInfo info, SqlNotificationSource source, SqlNotificationType type, string key) + { + _info = info; + _source = source; + _type = type; + _key = key; + } + + internal SqlNotificationInfo Info => _info; + + internal string Key => _key; + + internal SqlNotificationSource Source => _source; + + internal SqlNotificationType Type => _type; + } +} + diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlInternalConnectionTds.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlInternalConnectionTds.cs index 38af50966fe0..8d71238057ac 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlInternalConnectionTds.cs @@ -527,7 +527,7 @@ override protected void ChangeDatabaseInternal(string database) { // Add brackets around database database = SqlConnection.FixupDatabaseTransactionName(database); - Task executeTask = _parser.TdsExecuteSQLBatch("use " + database, ConnectionOptions.ConnectTimeout, _parser._physicalStateObj, sync: true); + Task executeTask = _parser.TdsExecuteSQLBatch("use " + database, ConnectionOptions.ConnectTimeout, null, _parser._physicalStateObj, sync: true); Debug.Assert(executeTask == null, "Shouldn't get a task when doing sync writes"); _parser.Run(RunBehavior.UntilDone, null, null, null, _parser._physicalStateObj); } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationEventArgs.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationEventArgs.cs new file mode 100644 index 000000000000..c3577f065cba --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationEventArgs.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Data.SqlClient +{ + public class SqlNotificationEventArgs : EventArgs + { + private SqlNotificationType _type; + private SqlNotificationInfo _info; + private SqlNotificationSource _source; + + public SqlNotificationEventArgs(SqlNotificationType type, SqlNotificationInfo info, SqlNotificationSource source) + { + _info = info; + _source = source; + _type = type; + } + + public SqlNotificationType Type => _type; + + public SqlNotificationInfo Info => _info; + + public SqlNotificationSource Source => _source; + + internal static SqlNotificationEventArgs s_notifyError = new SqlNotificationEventArgs(SqlNotificationType.Subscribe, SqlNotificationInfo.Error, SqlNotificationSource.Object); + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationInfo.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationInfo.cs new file mode 100644 index 000000000000..983919a70dce --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationInfo.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Data.SqlClient +{ + public enum SqlNotificationInfo + { + Truncate = 0, + Insert = 1, + Update = 2, + Delete = 3, + Drop = 4, + Alter = 5, + Restart = 6, + Error = 7, + Query = 8, + Invalid = 9, + Options = 10, + Isolation = 11, + Expired = 12, + Resource = 13, + PreviousFire = 14, + TemplateLimit = 15, + Merge = 16, + + // use negative values for client-only-generated values + Unknown = -1, + AlreadyChanged = -2 + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationSource.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationSource.cs new file mode 100644 index 000000000000..52be708d1bf4 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationSource.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Data.SqlClient +{ + public enum SqlNotificationSource + { + Data = 0, + Timeout = 1, + Object = 2, + Database = 3, + System = 4, + Statement = 5, + Environment = 6, + Execution = 7, + Owner = 8, + + // use negative values for client-only-generated values + Unknown = -1, + Client = -2 + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationType.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationType.cs new file mode 100644 index 000000000000..e7bb989ae042 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlNotificationType.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Data.SqlClient +{ + public enum SqlNotificationType + { + Change = 0, + Subscribe = 1, + + // use negative values for client-only-generated values + Unknown = -1 + } +} \ No newline at end of file diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs index 161764b8840d..9b4deb0a63fa 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs @@ -276,6 +276,10 @@ internal static Exception PendingBeginXXXExists() return ADP.InvalidOperation(SR.GetString(SR.SQL_PendingBeginXXXExists)); } + internal static ArgumentOutOfRangeException InvalidSqlDependencyTimeout(string param) + { + return ADP.ArgumentOutOfRange(SR.GetString(SR.SqlDependency_InvalidTimeout), param); + } internal static Exception NonXmlResult() { @@ -413,6 +417,50 @@ internal static Exception XmlReaderNotSupportOnColumnType(string columnName) return ADP.InvalidCast(SR.GetString(SR.SQL_XmlReaderNotSupportOnColumnType, columnName)); } + // + // SQL.SqlDependency + // + internal static Exception SqlCommandHasExistingSqlNotificationRequest() + { + return ADP.InvalidOperation(SR.GetString(SR.SQLNotify_AlreadyHasCommand)); + } + + internal static Exception SqlDepDefaultOptionsButNoStart() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_DefaultOptionsButNoStart)); + } + + internal static Exception SqlDependencyDatabaseBrokerDisabled() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_DatabaseBrokerDisabled)); + } + + internal static Exception SqlDependencyEventNoDuplicate() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_EventNoDuplicate)); + } + + internal static Exception SqlDependencyDuplicateStart() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_DuplicateStart)); + } + + internal static Exception SqlDependencyIdMismatch() + { + // do not include the id because it may require SecurityPermission(Infrastructure) permission + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_IdMismatch)); + } + + internal static Exception SqlDependencyNoMatchingServerStart() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_NoMatchingServerStart)); + } + + internal static Exception SqlDependencyNoMatchingServerDatabaseStart() + { + return ADP.InvalidOperation(SR.GetString(SR.SqlDependency_NoMatchingServerDatabaseStart)); + } + // // SQL.SqlMetaData // @@ -801,6 +849,12 @@ internal static string GetSNIErrorMessage(int sniError) string errorMessageId = String.Format((IFormatProvider)null, "SNI_ERROR_{0}", sniError); return SR.GetResourceString(errorMessageId, errorMessageId); } + + // Default values for SqlDependency and SqlNotificationRequest + internal const int SqlDependencyTimeoutDefault = 0; + internal const int SqlDependencyServerTimeout = 5 * 24 * 3600; // 5 days - used to compute default TTL of the dependency + internal const string SqlNotificationServiceDefault = "SqlQueryNotificationService"; + internal const string SqlNotificationStoredProcedureDefault = "SqlQueryNotificationStoredProcedure"; } sealed internal class SQLMessage @@ -971,5 +1025,24 @@ internal static string EscapeStringAsLiteral(string input) Debug.Assert(input != null, "input string cannot be null"); return input.Replace("'", "''"); } + + /// + /// Escape a string as a TSQL literal, wrapping it around with single quotes. + /// Use this method to escape input strings to prevent SQL injection + /// and to get correct behavior for embedded quotes. + /// + /// unescaped string + /// escaped and quoted literal string + internal static string MakeStringLiteral(string input) + { + if (string.IsNullOrEmpty(input)) + { + return "''"; + } + else + { + return "'" + EscapeStringAsLiteral(input) + "'"; + } + } } }//namespace diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs index 2603d762ff56..558fcf1356b7 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Data.Common; +using System.Data.Sql; using System.Data.SqlTypes; using System.Diagnostics; using System.Globalization; @@ -6459,7 +6460,7 @@ internal void FailureCleanup(TdsParserStateObject stateObj, Exception e) } } - internal Task TdsExecuteSQLBatch(string text, int timeout, TdsParserStateObject stateObj, bool sync, bool callerHasConnectionLock = false) + internal Task TdsExecuteSQLBatch(string text, int timeout, SqlNotificationRequest notificationRequest, TdsParserStateObject stateObj, bool sync, bool callerHasConnectionLock = false) { if (TdsParserState.Broken == State || TdsParserState.Closed == State) { @@ -6509,7 +6510,7 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, TdsParserStateObject stateObj.SetTimeoutSeconds(timeout); stateObj.SniContext = SniContext.Snix_Execute; - WriteRPCBatchHeaders(stateObj); + WriteRPCBatchHeaders(stateObj, notificationRequest); stateObj._outputMessageType = TdsEnums.MT_SQL; @@ -6575,7 +6576,7 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, TdsParserStateObject } } - internal Task TdsExecuteRPC(_SqlRPC[] rpcArray, int timeout, bool inSchema, TdsParserStateObject stateObj, bool isCommandProc, bool sync = true, + internal Task TdsExecuteRPC(_SqlRPC[] rpcArray, int timeout, bool inSchema, SqlNotificationRequest notificationRequest, TdsParserStateObject stateObj, bool isCommandProc, bool sync = true, TaskCompletionSource completion = null, int startRpc = 0, int startParam = 0) { bool firstCall = (completion == null); @@ -6624,7 +6625,7 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, TdsParserStateObject stateObj.SetTimeoutSeconds(timeout); stateObj.SniContext = SniContext.Snix_Execute; - WriteRPCBatchHeaders(stateObj); + WriteRPCBatchHeaders(stateObj, notificationRequest); stateObj._outputMessageType = TdsEnums.MT_RPC; } @@ -6981,7 +6982,7 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, TdsParserStateObject } AsyncHelper.ContinueTask(writeParamTask, completion, - () => TdsExecuteRPC(rpcArray, timeout, inSchema, stateObj, isCommandProc, sync, completion, + () => TdsExecuteRPC(rpcArray, timeout, inSchema, notificationRequest, stateObj, isCommandProc, sync, completion, startRpc: ii, startParam: i + 1), connectionToDoom: _connHandler, onFailure: exc => TdsExecuteRPC_OnFailure(exc, stateObj)); @@ -7853,7 +7854,89 @@ private void WriteMarsHeaderData(TdsParserStateObject stateObj, SqlInternalTrans } } - private void WriteRPCBatchHeaders(TdsParserStateObject stateObj) + private int GetNotificationHeaderSize(SqlNotificationRequest notificationRequest) + { + if (null != notificationRequest) + { + string callbackId = notificationRequest.UserData; + string service = notificationRequest.Options; + int timeout = notificationRequest.Timeout; + + if (null == callbackId) + { + throw ADP.ArgumentNull(nameof(callbackId)); + } + else if (ushort.MaxValue < callbackId.Length) + { + throw ADP.ArgumentOutOfRange(nameof(callbackId)); + } + + if (null == service) + { + throw ADP.ArgumentNull(nameof(service)); + } + else if (ushort.MaxValue < service.Length) + { + throw ADP.ArgumentOutOfRange(nameof(service)); + } + + if (-1 > timeout) + { + throw ADP.ArgumentOutOfRange(nameof(timeout)); + } + + // Header Length (uint) (included in size) (already written to output buffer) + // Header Type (ushort) + // NotifyID Length (ushort) + // NotifyID UnicodeStream (unicode text) + // SSBDeployment Length (ushort) + // SSBDeployment UnicodeStream (unicode text) + // Timeout (uint) -- optional + // Don't send timeout value if it is 0 + + int headerLength = 4 + 2 + 2 + (callbackId.Length * 2) + 2 + (service.Length * 2); + if (timeout > 0) + headerLength += 4; + return headerLength; + } + else + { + return 0; + } + } + + // Write query notificaiton header data, not including the notificaiton header length + private void WriteQueryNotificationHeaderData(SqlNotificationRequest notificationRequest, TdsParserStateObject stateObj) + { + Debug.Assert(_isYukon, "WriteQueryNotificationHeaderData called on a non-Yukon server"); + + // We may need to update the notification header length if the header is changed in the future + + Debug.Assert(null != notificationRequest, "notificaitonRequest is null"); + + string callbackId = notificationRequest.UserData; + string service = notificationRequest.Options; + int timeout = notificationRequest.Timeout; + + // we did verification in GetNotificationHeaderSize, so just assert here. + Debug.Assert(null != callbackId, "CallbackId is null"); + Debug.Assert(ushort.MaxValue >= callbackId.Length, "CallbackId length is out of range"); + Debug.Assert(null != service, "Service is null"); + Debug.Assert(ushort.MaxValue >= service.Length, "Service length is out of range"); + Debug.Assert(-1 <= timeout, "Timeout"); + + WriteShort(TdsEnums.HEADERTYPE_QNOTIFICATION, stateObj); // Query notifications Type + + WriteShort(callbackId.Length * 2, stateObj); // Length in bytes + WriteString(callbackId, stateObj); + + WriteShort(service.Length * 2, stateObj); // Length in bytes + WriteString(service, stateObj); + if (timeout > 0) + WriteInt(timeout, stateObj); + } + + private void WriteRPCBatchHeaders(TdsParserStateObject stateObj, SqlNotificationRequest notificationRequest) { /* Header: TotalLength - DWORD - including all headers and lengths, including itself @@ -7865,10 +7948,11 @@ private void WriteRPCBatchHeaders(TdsParserStateObject stateObj) } */ + int notificationHeaderSize = GetNotificationHeaderSize(notificationRequest); const int marsHeaderSize = 18; // 4 + 2 + 8 + 4 - int totalHeaderLength = 4 + marsHeaderSize; + int totalHeaderLength = 4 + marsHeaderSize + notificationHeaderSize; Debug.Assert(stateObj._outBytesUsed == stateObj._outputHeaderLen, "Output bytes written before total header length"); // Write total header length WriteInt(totalHeaderLength, stateObj); @@ -7877,6 +7961,14 @@ private void WriteRPCBatchHeaders(TdsParserStateObject stateObj) WriteInt(marsHeaderSize, stateObj); // Write Mars header data WriteMarsHeaderData(stateObj, CurrentTransaction); + + if (0 != notificationHeaderSize) + { + // Write Notification header length + WriteInt(notificationHeaderSize, stateObj); + // Write notificaiton header data + WriteQueryNotificationHeaderData(notificationRequest, stateObj); + } } diff --git a/src/System.Data.SqlClient/tests/ManualTests/SQL/SqlNotificationTest/SqlNotificationTest.cs b/src/System.Data.SqlClient/tests/ManualTests/SQL/SqlNotificationTest/SqlNotificationTest.cs new file mode 100644 index 000000000000..44625e1f5658 --- /dev/null +++ b/src/System.Data.SqlClient/tests/ManualTests/SQL/SqlNotificationTest/SqlNotificationTest.cs @@ -0,0 +1,351 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using Xunit; + +namespace System.Data.SqlClient.ManualTesting.Tests +{ + public class SqlNotificationTest : IDisposable + { + // Misc constants + private const int CALLBACK_TIMEOUT = 5000; // milliseconds + + // Database schema + private readonly string _tableName = "dbo.[SQLDEP_" + Guid.NewGuid().ToString() + "]"; + private readonly string _queueName = "SQLDEP_" + Guid.NewGuid().ToString(); + private readonly string _serviceName = "SQLDEP_" + Guid.NewGuid().ToString(); + private readonly string _schemaQueue; + + // Connection information used by all tests + private readonly string _startConnectionString; + private readonly string _execConnectionString; + + public SqlNotificationTest() + { + _startConnectionString = DataTestUtility.TcpConnStr; + _execConnectionString = DataTestUtility.TcpConnStr; + + var startBuilder = new SqlConnectionStringBuilder(_startConnectionString); + if (startBuilder.IntegratedSecurity) + { + _schemaQueue = string.Format("[{0}]", _queueName); + } + else + { + _schemaQueue = string.Format("[{0}].[{1}]", startBuilder.UserID, _queueName); + } + + Setup(); + } + + public void Dispose() + { + Cleanup(); + } + + #region StartStop_Tests + + [CheckConnStrSetupFact] + public void Test_DoubleStart_SameConnStr() + { + Assert.True(SqlDependency.Start(_startConnectionString), "Failed to start listener."); + + Assert.False(SqlDependency.Start(_startConnectionString), "Expected failure when trying to start listener."); + + Assert.False(SqlDependency.Stop(_startConnectionString), "Expected failure when trying to completely stop listener."); + + Assert.True(SqlDependency.Stop(_startConnectionString), "Failed to stop listener."); + } + + [CheckConnStrSetupFact] + public void Test_DoubleStart_DifferentConnStr() + { + SqlConnectionStringBuilder cb = new SqlConnectionStringBuilder(_startConnectionString); + + // just change something that doesn't impact the dependency dispatcher + if (cb.ShouldSerialize("connect timeout")) + cb.ConnectTimeout = cb.ConnectTimeout + 1; + else + cb.ConnectTimeout = 50; + + Assert.True(SqlDependency.Start(_startConnectionString), "Failed to start listener."); + + try + { + DataTestUtility.AssertThrowsWrapper(() => SqlDependency.Start(cb.ToString())); + } + finally + { + Assert.True(SqlDependency.Stop(_startConnectionString), "Failed to stop listener."); + + Assert.False(SqlDependency.Stop(cb.ToString()), "Expected failure when trying to completely stop listener."); + } + } + + [CheckConnStrSetupFact] + public void Test_Start_DifferentDB() + { + SqlConnectionStringBuilder cb = new SqlConnectionStringBuilder(_startConnectionString) + { + InitialCatalog = "tempdb" + }; + string altDatabaseConnectionString = cb.ToString(); + + Assert.True(SqlDependency.Start(_startConnectionString), "Failed to start listener."); + + Assert.True(SqlDependency.Start(altDatabaseConnectionString), "Failed to start listener."); + + Assert.True(SqlDependency.Stop(_startConnectionString), "Failed to stop listener."); + + Assert.True(SqlDependency.Stop(altDatabaseConnectionString), "Failed to stop listener."); + } + #endregion + + #region SqlDependency_Tests + + [CheckConnStrSetupFact] + public void Test_SingleDependency_NoStart() + { + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + using (SqlCommand cmd = new SqlCommand("SELECT a, b, c FROM " + _tableName, conn)) + { + conn.Open(); + + SqlDependency dep = new SqlDependency(cmd); + dep.OnChange += delegate (object o, SqlNotificationEventArgs args) + { + Console.WriteLine("4 Notification callback. Type={0}, Info={1}, Source={2}", args.Type, args.Info, args.Source); + }; + + DataTestUtility.AssertThrowsWrapper(() => cmd.ExecuteReader()); + } + } + + [CheckConnStrSetupFact] + public void Test_SingleDependency_Stopped() + { + SqlDependency.Start(_startConnectionString); + SqlDependency.Stop(_startConnectionString); + + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + using (SqlCommand cmd = new SqlCommand("SELECT a, b, c FROM " + _tableName, conn)) + { + conn.Open(); + + SqlDependency dep = new SqlDependency(cmd); + dep.OnChange += delegate (object o, SqlNotificationEventArgs args) + { + // Delegate won't be called, since notifications were stoppped + Console.WriteLine("5 Notification callback. Type={0}, Info={1}, Source={2}", args.Type, args.Info, args.Source); + }; + + DataTestUtility.AssertThrowsWrapper(() => cmd.ExecuteReader()); + } + } + + [CheckConnStrSetupFact] + public void Test_SingleDependency_AllDefaults_SqlAuth() + { + Assert.True(SqlDependency.Start(_startConnectionString), "Failed to start listener."); + + try + { + // create a new event every time to avoid mixing notification callbacks + ManualResetEvent notificationReceived = new ManualResetEvent(false); + ManualResetEvent updateCompleted = new ManualResetEvent(false); + + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + using (SqlCommand cmd = new SqlCommand("SELECT a, b, c FROM " + _tableName, conn)) + { + conn.Open(); + + SqlDependency dep = new SqlDependency(cmd); + dep.OnChange += delegate (object o, SqlNotificationEventArgs arg) + { + Assert.True(updateCompleted.WaitOne(CALLBACK_TIMEOUT, false), "Received notification, but update did not complete."); + + DataTestUtility.AssertEqualsWithDescription(SqlNotificationType.Change, arg.Type, "Unexpected Type value."); + DataTestUtility.AssertEqualsWithDescription(SqlNotificationInfo.Update, arg.Info, "Unexpected Info value."); + DataTestUtility.AssertEqualsWithDescription(SqlNotificationSource.Data, arg.Source, "Unexpected Source value."); + + notificationReceived.Set(); + }; + + cmd.ExecuteReader(); + } + + int count = RunSQL("UPDATE " + _tableName + " SET c=" + Environment.TickCount); + DataTestUtility.AssertEqualsWithDescription(1, count, "Unexpected count value."); + + updateCompleted.Set(); + + Assert.True(notificationReceived.WaitOne(CALLBACK_TIMEOUT, false), "Notification not received within the timeout period"); + } + finally + { + Assert.True(SqlDependency.Stop(_startConnectionString), "Failed to stop listener."); + } + } + + [CheckConnStrSetupFact] + public void Test_SingleDependency_CustomQueue_SqlAuth() + { + Assert.True(SqlDependency.Start(_startConnectionString, _queueName), "Failed to start listener."); + + try + { + // create a new event every time to avoid mixing notification callbacks + ManualResetEvent notificationReceived = new ManualResetEvent(false); + ManualResetEvent updateCompleted = new ManualResetEvent(false); + + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + using (SqlCommand cmd = new SqlCommand("SELECT a, b, c FROM " + _tableName, conn)) + { + conn.Open(); + + SqlDependency dep = new SqlDependency(cmd, "service=" + _serviceName + ";local database=msdb", 0); + dep.OnChange += delegate (object o, SqlNotificationEventArgs args) + { + Assert.True(updateCompleted.WaitOne(CALLBACK_TIMEOUT, false), "Received notification, but update did not complete."); + + Console.WriteLine("7 Notification callback. Type={0}, Info={1}, Source={2}", args.Type, args.Info, args.Source); + notificationReceived.Set(); + }; + + cmd.ExecuteReader(); + } + + int count = RunSQL("UPDATE " + _tableName + " SET c=" + Environment.TickCount); + DataTestUtility.AssertEqualsWithDescription(1, count, "Unexpected count value."); + + updateCompleted.Set(); + + Assert.False(notificationReceived.WaitOne(CALLBACK_TIMEOUT, false), "Notification should not be received."); + } + finally + { + Assert.True(SqlDependency.Stop(_startConnectionString, _queueName), "Failed to stop listener."); + } + } + + /// + /// SqlDependecy premature timeout + /// + [CheckConnStrSetupFact] + public void Test_SingleDependency_Timeout() + { + Assert.True(SqlDependency.Start(_startConnectionString), "Failed to start listener."); + + try + { + // with resolution of 15 seconds, SqlDependency should fire timeout notification only after 45 seconds, leave 5 seconds gap from both sides. + const int SqlDependencyTimerResolution = 15; // seconds + const int testTimeSeconds = SqlDependencyTimerResolution * 3 - 5; + const int minTimeoutEventInterval = testTimeSeconds - 1; + const int maxTimeoutEventInterval = testTimeSeconds + SqlDependencyTimerResolution + 1; + + // create a new event every time to avoid mixing notification callbacks + ManualResetEvent notificationReceived = new ManualResetEvent(false); + DateTime startUtcTime; + + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + using (SqlCommand cmd = new SqlCommand("SELECT a, b, c FROM " + _tableName, conn)) + { + conn.Open(); + + // create SqlDependency with timeout + SqlDependency dep = new SqlDependency(cmd, null, testTimeSeconds); + dep.OnChange += delegate (object o, SqlNotificationEventArgs arg) + { + // notification of Timeout can arrive either from server or from client timer. Handle both situations here: + SqlNotificationInfo info = arg.Info; + if (info == SqlNotificationInfo.Unknown) + { + // server timed out before the client, replace it with Error to produce consistent output for trun + info = SqlNotificationInfo.Error; + } + + DataTestUtility.AssertEqualsWithDescription(SqlNotificationType.Change, arg.Type, "Unexpected Type value."); + DataTestUtility.AssertEqualsWithDescription(SqlNotificationInfo.Error, arg.Info, "Unexpected Info value."); + DataTestUtility.AssertEqualsWithDescription(SqlNotificationSource.Timeout, arg.Source, "Unexpected Source value."); + notificationReceived.Set(); + }; + + cmd.ExecuteReader(); + startUtcTime = DateTime.UtcNow; + } + + Assert.True( + notificationReceived.WaitOne(TimeSpan.FromSeconds(maxTimeoutEventInterval), false), + string.Format("Notification not received within the maximum timeout period of {0} seconds", maxTimeoutEventInterval)); + + // notification received in time, check that it is not too early + TimeSpan notificationTime = DateTime.UtcNow - startUtcTime; + Assert.True( + notificationTime >= TimeSpan.FromSeconds(minTimeoutEventInterval), + string.Format( + "Notification was not expected before {0} seconds: received after {1} seconds", + minTimeoutEventInterval, notificationTime.TotalSeconds)); + } + finally + { + Assert.True(SqlDependency.Stop(_startConnectionString), "Failed to stop listener."); + } + } + + #endregion + + #region Utility_Methods + private static string[] CreateSqlSetupStatements(string tableName, string queueName, string serviceName) + { + return new string[] { + string.Format("CREATE TABLE {0}(a INT NOT NULL, b NVARCHAR(10), c INT NOT NULL)", tableName), + string.Format("INSERT INTO {0} (a, b, c) VALUES (1, 'foo', 0)", tableName), + string.Format("CREATE QUEUE {0}", queueName), + string.Format("CREATE SERVICE [{0}] ON QUEUE {1} ([http://schemas.microsoft.com/SQL/Notifications/PostQueryNotification])", serviceName, queueName) + }; + } + + private static string[] CreateSqlCleanupStatements(string tableName, string queueName, string serviceName) + { + return new string[] { + string.Format("DROP TABLE {0}", tableName), + string.Format("DROP SERVICE [{0}]", serviceName), + string.Format("DROP QUEUE {0}", queueName) + }; + } + + private void Setup() + { + RunSQL(CreateSqlSetupStatements(_tableName, _schemaQueue, _serviceName)); + } + + private void Cleanup() + { + RunSQL(CreateSqlCleanupStatements(_tableName, _schemaQueue, _serviceName)); + } + + private int RunSQL(params string[] stmts) + { + int count = -1; + using (SqlConnection conn = new SqlConnection(_execConnectionString)) + { + conn.Open(); + + SqlCommand cmd = conn.CreateCommand(); + + foreach (string stmt in stmts) + { + cmd.CommandText = stmt; + int tmp = cmd.ExecuteNonQuery(); + count = ((0 <= tmp) ? ((0 <= count) ? count + tmp : tmp) : count); + } + } + return count; + } + + #endregion + } +} diff --git a/src/System.Data.SqlClient/tests/ManualTests/System.Data.SqlClient.ManualTesting.Tests.csproj b/src/System.Data.SqlClient/tests/ManualTests/System.Data.SqlClient.ManualTesting.Tests.csproj index afa45def3d91..64ceb991d9f3 100644 --- a/src/System.Data.SqlClient/tests/ManualTests/System.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/System.Data.SqlClient/tests/ManualTests/System.Data.SqlClient.ManualTesting.Tests.csproj @@ -106,6 +106,7 @@ +