diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 2a4aaef2eb..41b433776e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -150,6 +150,9 @@ Microsoft\Data\SqlClient\ConnectionPool\SqlConnectionPoolProviderInfo.cs + + Microsoft\Data\SqlClient\ConnectionPool\TransactedConnectionPool.cs + Microsoft\Data\SqlClient\ConnectionPool\WaitHandleDbConnectionPool.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 3f693db65e..684bd3cf39 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -345,6 +345,9 @@ Microsoft\Data\SqlClient\ConnectionPool\SqlConnectionPoolProviderInfo.cs + + Microsoft\Data\SqlClient\ConnectionPool\TransactedConnectionPool.cs + Microsoft\Data\SqlClient\ConnectionPool\WaitHandleDbConnectionPool.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/TransactedConnectionPool.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/TransactedConnectionPool.cs new file mode 100644 index 0000000000..b8e56bd9e3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/TransactedConnectionPool.cs @@ -0,0 +1,353 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Transactions; +using Microsoft.Data.ProviderBase; + +#nullable enable + +namespace Microsoft.Data.SqlClient.ConnectionPool; + +/// +/// Manages database connections that are enlisted in transactions, providing a specialized +/// pool that groups connections by their associated transactions. This class ensures that +/// connections participating in the same transaction can be efficiently reused while +/// maintaining transaction integrity and thread safety. +/// +/// +/// The TransactedConnectionPool works in conjunction with the main connection pool to handle +/// connections that are enlisted in System.Transactions. When a connection is enlisted in a +/// transaction, it cannot be returned to the general pool until the transaction completes. +/// This class provides temporary storage for such connections, organized by transaction, +/// allowing for efficient reuse within the same transaction scope. +/// +internal class TransactedConnectionPool +{ + /// + /// A specialized list that holds database connections associated with a specific transaction. + /// Maintains a reference to the transaction for proper cleanup when the transaction completes. + /// + private sealed class TransactedConnectionList : List + { + private readonly Transaction _transaction; + + /// + /// Initializes a new instance of the TransactedConnectionList class with the specified + /// initial capacity and associated transaction. + /// + /// The initial number of elements that the list can contain. + /// The transaction associated with the connections in this list. + internal TransactedConnectionList(int initialAllocation, Transaction tx) : base(initialAllocation) + { + _transaction = tx; + } + + /// + /// Releases the resources used by the TransactedConnectionList, including + /// disposing of the associated transaction reference. + /// + internal void Dispose() + { + if (_transaction != null) + { + _transaction.Dispose(); + } + } + } + + #region Fields + + private readonly Dictionary _transactedCxns; + + private static int _objectTypeCount; + internal readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); + + #endregion + + /// + /// Initializes a new instance of the TransactedConnectionPool class for the specified connection pool. + /// + /// The main connection pool that this transacted pool is associated with. + /// + /// The transacted connection pool works as a companion to the main connection pool, + /// temporarily holding connections that are enlisted in transactions until those + /// transactions complete. + /// + internal TransactedConnectionPool(IDbConnectionPool pool) + { + Pool = pool; + _transactedCxns = new Dictionary(); + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Constructed for connection pool {1}", Id, Pool.Id); + } + + #region Properties + + /// + /// Gets the unique identifier for this transacted connection pool instance. + /// + /// A unique integer identifier used for logging and diagnostics. + internal int Id => _objectID; + + /// + /// Gets the main connection pool that this transacted pool is associated with. + /// + /// The IDbConnectionPool instance that owns this transacted pool. + internal IDbConnectionPool Pool { get; } + + #endregion + + #region Methods + + /// + /// Retrieves a database connection that is already enlisted in the specified transaction. + /// + /// The transaction to look for an existing enlisted connection. + /// + /// A DbConnectionInternal instance that is enlisted in the specified transaction, + /// or null if no such connection is available in the pool. + /// + /// + /// This method is only used when AutoEnlist is true and there is a valid ambient transaction. + /// The method is thread-safe and will return the most recently added connection for the + /// specified transaction. If a connection is found and returned, it is removed from the + /// transacted pool. + /// + internal DbConnectionInternal? GetTransactedObject(Transaction transaction) + { + DbConnectionInternal? transactedObject = null; + + TransactedConnectionList? connections; + bool txnFound = false; + + lock (_transactedCxns) + { + txnFound = _transactedCxns.TryGetValue(transaction, out connections); + } + + // NOTE: GetTransactedObject is only used when AutoEnlist = True and the ambient transaction + // (Sys.Txns.Txn.Current) is still valid/non-null. This, in turn, means that we don't need + // to worry about a pending asynchronous TransactionCompletedEvent to trigger processing in + // TransactionEnded below and potentially wipe out the connections list underneath us. It + // is similarly alright if a pending addition to the connections list in PutTransactedObject + // below is not completed prior to the lock on the connections object here...getting a new + // connection is probably better than unnecessarily locking + if (txnFound && connections is not null) + { + + // synchronize multi-threaded access with PutTransactedObject (TransactionEnded should + // not be a concern, see comments above) + lock (connections) + { + int i = connections.Count - 1; + if (0 <= i) + { + transactedObject = connections[i]; + connections.RemoveAt(i); + } + } + } + + if (transactedObject != null) + { + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Popped.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + } + return transactedObject; + } + + /// + /// Adds a database connection to the transacted pool, associating it with the specified transaction. + /// + /// The transaction that the connection is enlisted in. + /// The database connection to add to the transacted pool. + /// + /// This method handles the complex synchronization required when multiple threads may be + /// attempting to add connections to the same transaction pool simultaneously. If a pool + /// for the specified transaction doesn't exist, it will be created. The method uses + /// transaction cloning to ensure that the pool maintains a valid reference to the transaction + /// even if the original transaction object is disposed. + /// + /// Due to the asynchronous nature of transaction completion notifications, this method + /// must handle race conditions where TransactionEnded might be called before or during + /// the execution of this method. + /// + internal void PutTransactedObject(Transaction transaction, DbConnectionInternal transactedObject) + { + TransactedConnectionList? connections; + bool txnFound = false; + + // NOTE: because TransactionEnded is an asynchronous notification, there's no guarantee + // around the order in which PutTransactionObject and TransactionEnded are called. + + lock (_transactedCxns) + { + // Check if a transacted pool has been created for this transaction + if ((txnFound = _transactedCxns.TryGetValue(transaction, out connections)) + && connections is not null) + { + // synchronize multi-threaded access with GetTransactedObject + lock (connections) + { + // TODO: validate that we're not adding the same connection twice? + // Debug.Assert(0 > connections.IndexOf(transactedObject), "adding to pool a second time?"); + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Pushing.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + connections.Add(transactedObject); + } + } + } + + // CONSIDER: the following code is more complicated than it needs to be to avoid cloning the + // transaction and allocating memory within a lock. Is that complexity really necessary? + if (!txnFound) + { + // create the transacted pool, making sure to clone the associated transaction + // for use as a key in our internal dictionary of transactions and connections + Transaction? transactionClone = null; + TransactedConnectionList? newConnections = null; + + try + { + transactionClone = transaction.Clone(); + newConnections = new TransactedConnectionList(2, transactionClone); // start with only two connections in the list; most times we won't need that many. + + lock (_transactedCxns) + { + // NOTE: in the interim between the locks on the transacted pool (this) during + // execution of this method, another thread (threadB) may have attempted to + // add a different connection to the transacted pool under the same + // transaction. As a result, threadB may have completed creating the + // transacted pool while threadA was processing the above instructions. + if (_transactedCxns.TryGetValue(transaction, out connections) + && connections is not null) + { + // synchronize multi-threaded access with GetTransactedObject + lock (connections) + { + Debug.Assert(0 > connections.IndexOf(transactedObject), "adding to pool a second time?"); + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Pushing.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + connections.Add(transactedObject); + } + } + else + { + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Adding List to transacted pool.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + + // add the connection/transacted object to the list + newConnections.Add(transactedObject); + + _transactedCxns.Add(transactionClone, newConnections); + transactionClone = null; // we've used it -- don't throw it or the TransactedConnectionList that references it away. + } + } + } + finally + { + if (transactionClone != null) + { + if (newConnections != null) + { + // another thread created the transaction pool and thus the new + // TransactedConnectionList was not used, so dispose of it and + // the transaction clone that it incorporates. + newConnections.Dispose(); + } + else + { + // memory allocation for newConnections failed...clean up unused transactionClone + transactionClone.Dispose(); + } + } + } + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Added.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + } + + SqlClientEventSource.Metrics.EnterFreeConnection(); + } + + /// + /// Handles the completion of a transaction by removing the associated connection from the + /// transacted pool and returning it to the main connection pool. + /// + /// The transaction that has completed. + /// The database connection that was enlisted in the completed transaction. + /// + /// This method is called when a transaction completes (either by committing or rolling back). + /// It removes the specified connection from the transacted pool and returns it to the main + /// connection pool for reuse. If this was the last connection in the transaction's pool, + /// the entire transaction pool entry is removed and disposed. + /// + /// Due to the asynchronous nature of transaction completion notifications, there's no guarantee + /// about the order in which PutTransactedObject and TransactionEnded are called. This method + /// handles cases where the transaction pool may not yet exist when the transaction completes. + /// + internal void TransactionEnded(Transaction transaction, DbConnectionInternal transactedObject) + { + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Transaction Completed", Id, transaction.GetHashCode(), transactedObject.ObjectID); + TransactedConnectionList? connections; + int entry = -1; + + // NOTE: because TransactionEnded is an asynchronous notification, there's no guarantee + // around the order in which PutTransactionObject and TransactionEnded are called. As + // such, it is possible that the transaction does not yet have a pool created. + + // TODO: is this a plausible and/or likely scenario? Do we need to have a mechanism to ensure + // TODO: that the pending creation of a transacted pool for this transaction is aborted when + // TODO: PutTransactedObject finally gets some CPU time? + + lock (_transactedCxns) + { + if (_transactedCxns.TryGetValue(transaction, out connections) + && connections is not null) + { + bool shouldDisposeConnections = false; + + // Lock connections to avoid conflict with GetTransactionObject + lock (connections) + { + entry = connections.IndexOf(transactedObject); + + if (entry >= 0) + { + connections.RemoveAt(entry); + } + + // Once we've completed all the ended notifications, we can + // safely remove the list from the transacted pool. + if (0 >= connections.Count) + { + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Removing List from transacted pool.", Id, transaction.GetHashCode()); + _transactedCxns.Remove(transaction); + + // we really need to dispose our connection list; it may have + // native resources via the tx and GC may not happen soon enough. + shouldDisposeConnections = true; + } + } + + if (shouldDisposeConnections) + { + connections.Dispose(); + } + } + else + { + SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Transacted pool not yet created prior to transaction completing. Connection may be leaked.", Id, transaction.GetHashCode(), transactedObject.ObjectID); + } + } + + // If (and only if) we found the connection in the list of + // connections, we'll put it back... + if (0 <= entry) + { + // TODO: can we give this responsibility to the main pool? + // The bi-directional dependency between the main pool and this pool + // is messy and hard to understand. + SqlClientEventSource.Metrics.ExitFreeConnection(); + Pool.PutObjectFromTransactedPool(transactedObject); + } + } + + #endregion +} \ No newline at end of file diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/WaitHandleDbConnectionPool.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/WaitHandleDbConnectionPool.cs index ab54f4731e..c588fcd3f7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/WaitHandleDbConnectionPool.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/WaitHandleDbConnectionPool.cs @@ -59,22 +59,6 @@ internal sealed class WaitHandleDbConnectionPool : IDbConnectionPool // This class is a way to stash our cloned Tx key for later disposal when it's no longer needed. // We can't get at the key in the dictionary without enumerating entries, so we stash an extra // copy as part of the value. - private sealed class TransactedConnectionList : List - { - private Transaction _transaction; - internal TransactedConnectionList(int initialAllocation, Transaction tx) : base(initialAllocation) - { - _transaction = tx; - } - - internal void Dispose() - { - if (_transaction != null) - { - _transaction.Dispose(); - } - } - } private sealed class PendingGetConnection { @@ -91,250 +75,6 @@ public PendingGetConnection(long dueTime, DbConnection owner, TaskCompletionSour public DbConnectionOptions UserOptions { get; private set; } } - private sealed class TransactedConnectionPool - { - Dictionary _transactedCxns; - - IDbConnectionPool _pool; - - private static int _objectTypeCount; // EventSource Counter - internal readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); - - internal TransactedConnectionPool(IDbConnectionPool pool) - { - Debug.Assert(pool != null, "null pool?"); - - _pool = pool; - _transactedCxns = new Dictionary(); - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Constructed for connection pool {1}", ObjectID, _pool.Id); - } - - internal int ObjectID - { - get - { - return _objectID; - } - } - - internal IDbConnectionPool Pool - { - get - { - return _pool; - } - } - - internal DbConnectionInternal GetTransactedObject(Transaction transaction) - { - Debug.Assert(transaction != null, "null transaction?"); - - DbConnectionInternal transactedObject = null; - - TransactedConnectionList connections; - bool txnFound = false; - - lock (_transactedCxns) - { - txnFound = _transactedCxns.TryGetValue(transaction, out connections); - } - - // NOTE: GetTransactedObject is only used when AutoEnlist = True and the ambient transaction - // (Sys.Txns.Txn.Current) is still valid/non-null. This, in turn, means that we don't need - // to worry about a pending asynchronous TransactionCompletedEvent to trigger processing in - // TransactionEnded below and potentially wipe out the connections list underneath us. It - // is similarly alright if a pending addition to the connections list in PutTransactedObject - // below is not completed prior to the lock on the connections object here...getting a new - // connection is probably better than unnecessarily locking - if (txnFound) - { - Debug.Assert(connections != null); - - // synchronize multi-threaded access with PutTransactedObject (TransactionEnded should - // not be a concern, see comments above) - lock (connections) - { - int i = connections.Count - 1; - if (0 <= i) - { - transactedObject = connections[i]; - connections.RemoveAt(i); - } - } - } - - if (transactedObject != null) - { - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Popped.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - } - return transactedObject; - } - - internal void PutTransactedObject(Transaction transaction, DbConnectionInternal transactedObject) - { - Debug.Assert(transaction != null, "null transaction?"); - Debug.Assert(transactedObject != null, "null transactedObject?"); - - TransactedConnectionList connections; - bool txnFound = false; - - // NOTE: because TransactionEnded is an asynchronous notification, there's no guarantee - // around the order in which PutTransactionObject and TransactionEnded are called. - - lock (_transactedCxns) - { - // Check if a transacted pool has been created for this transaction - if (txnFound = _transactedCxns.TryGetValue(transaction, out connections)) - { - Debug.Assert(connections != null); - - // synchronize multi-threaded access with GetTransactedObject - lock (connections) - { - Debug.Assert(0 > connections.IndexOf(transactedObject), "adding to pool a second time?"); - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Pushing.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - connections.Add(transactedObject); - } - } - } - - // CONSIDER: the following code is more complicated than it needs to be to avoid cloning the - // transaction and allocating memory within a lock. Is that complexity really necessary? - if (!txnFound) - { - // create the transacted pool, making sure to clone the associated transaction - // for use as a key in our internal dictionary of transactions and connections - Transaction transactionClone = null; - TransactedConnectionList newConnections = null; - - try - { - transactionClone = transaction.Clone(); - newConnections = new TransactedConnectionList(2, transactionClone); // start with only two connections in the list; most times we won't need that many. - - lock (_transactedCxns) - { - // NOTE: in the interim between the locks on the transacted pool (this) during - // execution of this method, another thread (threadB) may have attempted to - // add a different connection to the transacted pool under the same - // transaction. As a result, threadB may have completed creating the - // transacted pool while threadA was processing the above instructions. - if (txnFound = _transactedCxns.TryGetValue(transaction, out connections)) - { - Debug.Assert(connections != null); - - // synchronize multi-threaded access with GetTransactedObject - lock (connections) - { - Debug.Assert(0 > connections.IndexOf(transactedObject), "adding to pool a second time?"); - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Pushing.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - connections.Add(transactedObject); - } - } - else - { - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Adding List to transacted pool.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - - // add the connection/transacted object to the list - newConnections.Add(transactedObject); - - _transactedCxns.Add(transactionClone, newConnections); - transactionClone = null; // we've used it -- don't throw it or the TransactedConnectionList that references it away. - } - } - } - finally - { - if (transactionClone != null) - { - if (newConnections != null) - { - // another thread created the transaction pool and thus the new - // TransactedConnectionList was not used, so dispose of it and - // the transaction clone that it incorporates. - newConnections.Dispose(); - } - else - { - // memory allocation for newConnections failed...clean up unused transactionClone - transactionClone.Dispose(); - } - } - } - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Added.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - } - - SqlClientEventSource.Metrics.EnterFreeConnection(); - } - - internal void TransactionEnded(Transaction transaction, DbConnectionInternal transactedObject) - { - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Transaction Completed", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - TransactedConnectionList connections; - int entry = -1; - - // NOTE: because TransactionEnded is an asynchronous notification, there's no guarantee - // around the order in which PutTransactionObject and TransactionEnded are called. As - // such, it is possible that the transaction does not yet have a pool created. - - // TODO: is this a plausible and/or likely scenario? Do we need to have a mechanism to ensure - // TODO: that the pending creation of a transacted pool for this transaction is aborted when - // TODO: PutTransactedObject finally gets some CPU time? - - lock (_transactedCxns) - { - if (_transactedCxns.TryGetValue(transaction, out connections)) - { - Debug.Assert(connections != null); - - bool shouldDisposeConnections = false; - - // Lock connections to avoid conflict with GetTransactionObject - lock (connections) - { - entry = connections.IndexOf(transactedObject); - - if (entry >= 0) - { - connections.RemoveAt(entry); - } - - // Once we've completed all the ended notifications, we can - // safely remove the list from the transacted pool. - if (0 >= connections.Count) - { - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Removing List from transacted pool.", ObjectID, transaction.GetHashCode()); - _transactedCxns.Remove(transaction); - - // we really need to dispose our connection list; it may have - // native resources via the tx and GC may not happen soon enough. - shouldDisposeConnections = true; - } - } - - if (shouldDisposeConnections) - { - connections.Dispose(); - } - } - else - { - SqlClientEventSource.Log.TryPoolerTraceEvent(" {0}, Transaction {1}, Connection {2}, Transacted pool not yet created prior to transaction completing. Connection may be leaked.", ObjectID, transaction.GetHashCode(), transactedObject.ObjectID); - } - } - - // If (and only if) we found the connection in the list of - // connections, we'll put it back... - if (0 <= entry) - { - - SqlClientEventSource.Metrics.ExitFreeConnection(); - Pool.PutObjectFromTransactedPool(transactedObject); - } - } - - } - private sealed class PoolWaitHandles { private readonly Semaphore _poolSemaphore; diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/ConnectionPool/ChannelDbConnectionPoolTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/ChannelDbConnectionPoolTest.cs similarity index 100% rename from src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/ConnectionPool/ChannelDbConnectionPoolTest.cs rename to src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/ChannelDbConnectionPoolTest.cs diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/ConnectionPool/ConnectionPoolSlotsTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/ConnectionPoolSlotsTest.cs similarity index 100% rename from src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/ConnectionPool/ConnectionPoolSlotsTest.cs rename to src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/ConnectionPoolSlotsTest.cs diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/TransactedConnectionPoolTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/TransactedConnectionPoolTest.cs new file mode 100644 index 0000000000..8139767c74 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/ConnectionPool/TransactedConnectionPoolTest.cs @@ -0,0 +1,738 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.ConnectionPool; +using Microsoft.Data.ProviderBase; +using Xunit; +using System.Data; +using System.Data.Common; +using System.Transactions; +using System.Collections.Concurrent; +using System.Threading; +using Microsoft.Data.Common.ConnectionString; +using System.Collections.Generic; +using System.Linq; + +#nullable enable + +namespace Microsoft.Data.SqlClient.UnitTests.ConnectionPool; + +public class TransactedConnectionPoolTest +{ + #region Constructor Tests + + [Fact] + public void Constructor_WithValidPool_SetsPoolProperty() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + + // Act + var transactedPool = new TransactedConnectionPool(mockPool); + + // Assert + Assert.Same(mockPool, transactedPool.Pool); + Assert.True(transactedPool.Id > 0); + } + + [Fact] + public void Constructor_UniqueIds() + { + // Arrange + var pool1 = new TransactedConnectionPool(new MockDbConnectionPool()); + var pool2 = new TransactedConnectionPool(new MockDbConnectionPool()); + + // Act & Assert + Assert.NotEqual(pool1.Id, pool2.Id); + Assert.True(pool1.Id > 0); + Assert.True(pool2.Id > 0); + } + + #endregion + + #region GetTransactedObject Tests + + [Fact] + public void GetTransactedObject_WithNonExistentTransaction_ReturnsNull() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act + var result = transactedPool.GetTransactedObject(transaction); + + // Assert + Assert.Null(result); + } + + [Fact] + public void GetTransactedObject_WithExistingTransaction_ReturnsAndRemovesConnection() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // First add a connection + transactedPool.PutTransactedObject(transaction, connection); + + // Act + var result = transactedPool.GetTransactedObject(transaction); + + // Assert + Assert.Same(connection, result); + + // Verify the connection is removed (second call should return null) + var secondResult = transactedPool.GetTransactedObject(transaction); + Assert.Null(secondResult); + } + + [Fact] + public void GetTransactedObject_WithMultipleConnections_ReturnsLastAdded() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection1 = new MockDbConnectionInternal(); + var connection2 = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add multiple connections + transactedPool.PutTransactedObject(transaction, connection1); + transactedPool.PutTransactedObject(transaction, connection2); + + // Act + var result = transactedPool.GetTransactedObject(transaction); + + // Assert + Assert.Same(connection2, result); // Should return the last added (LIFO behavior) + } + + [Fact] + public void GetTransactedObject_ConcurrentAccess_ThreadSafe() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connections = new DbConnectionInternal[10]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new MockDbConnectionInternal(); + } + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add all connections + foreach (var conn in connections) + { + transactedPool.PutTransactedObject(transaction, conn); + } + + var retrievedConnections = new ConcurrentBag(); + var tasks = new Task[connections.Length]; + + // Act - retrieve connections concurrently + for (int i = 0; i < tasks.Length; i++) + { + tasks[i] = Task.Run(() => + { + var conn = transactedPool.GetTransactedObject(transaction); + Assert.NotNull(conn); + retrievedConnections.Add(conn); + }); + } + + Task.WaitAll(tasks); + + // Assert + Assert.Equal(connections.Length, retrievedConnections.Count); + Assert.True(connections.All(retrievedConnections.Contains)); + } + + #endregion + + #region PutTransactedObject Tests + + [Fact] + public void PutTransactedObject_WithNullConnection_ThrowsArgumentNullException() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act & Assert + Assert.Throws(() => + transactedPool.PutTransactedObject(transaction, null!)); + } + + [Fact] + public void PutTransactedObject_WithNewTransaction_CreatesNewConnectionList() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act + transactedPool.PutTransactedObject(transaction, connection); + + // Assert + var retrievedConnection = transactedPool.GetTransactedObject(transaction); + Assert.Same(connection, retrievedConnection); + } + + [Fact] + public void PutTransactedObject_WithExistingTransaction_AddsToExistingConnectionList() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection1 = new MockDbConnectionInternal(); + var connection2 = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act + transactedPool.PutTransactedObject(transaction, connection1); + transactedPool.PutTransactedObject(transaction, connection2); + + // Assert + var retrieved1 = transactedPool.GetTransactedObject(transaction); + var retrieved2 = transactedPool.GetTransactedObject(transaction); + + Assert.Same(connection2, retrieved1); // Last in, first out + Assert.Same(connection1, retrieved2); + } + + [Fact] + public void PutTransactedObject_ConcurrentAccess_ThreadSafe() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connections = new DbConnectionInternal[10]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new MockDbConnectionInternal(); + } + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + var tasks = new Task[connections.Length]; + + // Act - add connections concurrently + for (int i = 0; i < tasks.Length; i++) + { + var connection = connections[i]; + tasks[i] = Task.Run(() => transactedPool.PutTransactedObject(transaction, connection)); + } + + Task.WaitAll(tasks); + + // Assert - all connections should be retrievable + var retrievedConnections = new List(); + DbConnectionInternal? conn; + while ((conn = transactedPool.GetTransactedObject(transaction)) != null) + { + retrievedConnections.Add(conn); + } + + Assert.Equal(connections.Length, retrievedConnections.Count); + Assert.True(connections.All(retrievedConnections.Contains)); + } + + [Fact] + public void PutTransactedObject_SameConnectionTwice_AddsToPoolTwice() + { + // TODO: this behavior is suspicious should we prevent this? + + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act + transactedPool.PutTransactedObject(transaction, connection); + transactedPool.PutTransactedObject(transaction, connection); + + // Assert + var retrieved1 = transactedPool.GetTransactedObject(transaction); + var retrieved2 = transactedPool.GetTransactedObject(transaction); + + Assert.Same(connection, retrieved1); + Assert.Same(connection, retrieved2); + } + + #endregion + + #region TransactionEnded Tests + + [Fact] + public void TransactionEnded_WithNullConnection_ThrowsNullReferenceException() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act & Assert + Assert.Throws(() => + transactedPool.TransactionEnded(transaction, null!)); + } + + [Fact] + public void TransactionEnded_WithNonExistentTransaction_DoesNotThrow() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act & Assert (should not throw) + transactedPool.TransactionEnded(transaction, connection); + // TODO: is this really the behavior we want? + } + + [Fact] + public void TransactionEnded_WithExistingConnection_RemovesConnectionAndReturnsToPool() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add connection to transacted pool + transactedPool.PutTransactedObject(transaction, connection); + + // Act + transactedPool.TransactionEnded(transaction, connection); + + // Assert + Assert.Contains(connection, mockPool.ReturnedConnections); + + // Verify connection is no longer in transacted pool + var retrievedConnection = transactedPool.GetTransactedObject(transaction); + Assert.Null(retrievedConnection); + } + + [Fact] + public void TransactionEnded_WithMultipleConnections_RemovesOnlySpecifiedConnection() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection1 = new MockDbConnectionInternal(); + var connection2 = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add multiple connections + transactedPool.PutTransactedObject(transaction, connection1); + transactedPool.PutTransactedObject(transaction, connection2); + + // Act - end only one connection + transactedPool.TransactionEnded(transaction, connection1); + + // Assert + Assert.Contains(connection1, mockPool.ReturnedConnections); + Assert.DoesNotContain(connection2, mockPool.ReturnedConnections); + + // Verify other connection is still in pool + // TODO: there shouldn't be partial state in the pool after the transaction ends + // May be a way to register a single callback to clear the whole list. + var retrievedConnection = transactedPool.GetTransactedObject(transaction); + Assert.Same(connection2, retrievedConnection); + } + + [Fact] + public void TransactionEnded_WithConnectionNotInPool_DoesNotReturnToMainPool() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Don't add connection to transacted pool + + // Act + transactedPool.TransactionEnded(transaction, connection); + + // Assert + Assert.DoesNotContain(connection, mockPool.ReturnedConnections); + } + + [Fact] + public void TransactionEnded_ConcurrentAccess_ThreadSafe() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connections = new DbConnectionInternal[10]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new MockDbConnectionInternal(); + } + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add all connections + foreach (var conn in connections) + { + transactedPool.PutTransactedObject(transaction, conn); + } + + var tasks = new Task[connections.Length]; + + // Act - end transactions concurrently + for (int i = 0; i < tasks.Length; i++) + { + var connection = connections[i]; + tasks[i] = Task.Run(() => transactedPool.TransactionEnded(transaction, connection)); + } + + Task.WaitAll(tasks); + + // Assert + Assert.Equal(connections.Length, mockPool.ReturnedConnections.Count); + Assert.True(connections.All(mockPool.ReturnedConnections.Contains)); + } + + [Fact] + public void TransactionEnded_MultipleCallsWithSameConnection_OnlyReturnsOnce() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add connection to transacted pool + transactedPool.PutTransactedObject(transaction, connection); + + // Act - call TransactionEnded multiple times + transactedPool.TransactionEnded(transaction, connection); + transactedPool.TransactionEnded(transaction, connection); + transactedPool.TransactionEnded(transaction, connection); + + // Assert - connection should only be returned to pool once + Assert.Single(mockPool.ReturnedConnections); + Assert.Contains(connection, mockPool.ReturnedConnections); + } + + [Fact] + public void TransactionEnded_CalledBeforePut_HandlesRaceCondition() + { + // TODO: this test shows that we actually don't handle the race correctly + // we shouldn't allow connections associated with ended transactions in the pool + + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act - simulate race condition where TransactionEnded is called before PutTransactedObject + transactedPool.TransactionEnded(transaction, connection); + transactedPool.PutTransactedObject(transaction, connection); + + // Assert - connection should still be in the transacted pool + var retrievedConnection = transactedPool.GetTransactedObject(transaction); + Assert.Same(connection, retrievedConnection); + } + + #endregion + + #region Integration Tests + + [Fact] + public void FullLifecycle_PutGetEnd_WorksCorrectly() + { + // Arrange + var mockPool = new MockDbConnectionPool(); + var transactedPool = new TransactedConnectionPool(mockPool); + var connection = new MockDbConnectionInternal(); + + using var transactionScope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Act & Assert + // 1. Put connection in transacted pool + transactedPool.PutTransactedObject(transaction, connection); + + // 2. Get connection from transacted pool + var retrievedConnection = transactedPool.GetTransactedObject(transaction); + Assert.Same(connection, retrievedConnection); + + // 3. Put it back + transactedPool.PutTransactedObject(transaction, connection); + + // 4. End transaction + transactedPool.TransactionEnded(transaction, connection); + + // 5. Verify connection returned to main pool + Assert.Contains(connection, mockPool.ReturnedConnections); + + // 6. Verify transacted pool is empty + var finalRetrieved = transactedPool.GetTransactedObject(transaction); + Assert.Null(finalRetrieved); + } + + [Fact] + public void MultipleTransactions_IsolatedCorrectly() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection1 = new MockDbConnectionInternal(); + var connection2 = new MockDbConnectionInternal(); + + Transaction? transaction1 = null; + Transaction? transaction2 = null; + + using (new TransactionScope()) + { + transaction1 = Transaction.Current; + transactedPool.PutTransactedObject(transaction1!, connection1); + } + + using (new TransactionScope()) + { + transaction2 = Transaction.Current; + transactedPool.PutTransactedObject(transaction2!, connection2); + } + + // Act & Assert + var retrieved1 = transactedPool.GetTransactedObject(transaction1!); + var retrieved2 = transactedPool.GetTransactedObject(transaction2!); + + Assert.Same(connection1, retrieved1); + Assert.Same(connection2, retrieved2); + } + + [Fact] + public void ConcurrentPutAndGet_DifferentTransactions_Isolated() + { + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var numberOfTransactions = 5; + var connectionsPerTransaction = 3; + var results = new ConcurrentDictionary>(); + using var countdown = new CountdownEvent(numberOfTransactions); + + // Act - create multiple transactions concurrently + var tasks = Enumerable.Range(0, numberOfTransactions).Select(txIndex => + { + return Task.Run(() => + { + try + { + using var scope = new TransactionScope(); + var transaction = Transaction.Current!; + + // Add connections to this transaction + for (int i = 0; i < connectionsPerTransaction; i++) + { + var conn = new MockDbConnectionInternal(); + transactedPool.PutTransactedObject(transaction, conn); + } + + // Retrieve connections from this transaction + var retrieved = new List(); + DbConnectionInternal? retrievedConn; + while ((retrievedConn = transactedPool.GetTransactedObject(transaction)) != null) + { + retrieved.Add(retrievedConn); + } + + results[txIndex] = retrieved; + } + finally + { + countdown.Signal(); + } + }); + }).ToList(); + + // Wait for all tasks to complete + Task.WaitAll(tasks.ToArray()); + + // Assert - each transaction should have isolated connections + Assert.Equal(numberOfTransactions, results.Count); + + foreach (var result in results.Values) + { + Assert.Equal(connectionsPerTransaction, result.Count); + } + + // Verify no overlap between transactions + var allConnections = results.Values.SelectMany(r => r).ToList(); + Assert.Equal(allConnections.Count, allConnections.Distinct().Count()); + } + + [Fact] + public void TransactionScope_CompleteAndDispose_HandledCorrectly() + { + // TODO: this test shows that we don't give strong guarantees that + // the pool state will match the transaction state. + + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + Transaction? capturedTransaction = null; + + // Act + using (var scope = new TransactionScope()) + { + capturedTransaction = Transaction.Current!; + transactedPool.PutTransactedObject(capturedTransaction, connection); + scope.Complete(); + } // TransactionScope disposes here + + // Assert - connection should still be retrievable if transaction completed + var retrieved = transactedPool.GetTransactedObject(capturedTransaction!); + Assert.Same(connection, retrieved); + } + + [Fact] + public void PutTransactedObject_WithDisposedTransaction_HandlesGracefully() + { + //TODO: this test should not pass! why would we store connections from a disposed transaction? + + // Arrange + var transactedPool = new TransactedConnectionPool(new MockDbConnectionPool()); + var connection = new MockDbConnectionInternal(); + Transaction? disposedTransaction = null; + + using (var scope = new TransactionScope()) + { + disposedTransaction = Transaction.Current!; + } // Transaction is now disposed + + // Act & Assert - should handle gracefully without throwing + try + { + transactedPool.PutTransactedObject(disposedTransaction!, connection); + // If no exception, test passes + Assert.True(true); + } + catch (ObjectDisposedException) + { + // This is expected behavior and acceptable + Assert.True(true); + } + } + + #endregion + + #region Mock Classes + + internal class MockDbConnectionPool : IDbConnectionPool + { + public ConcurrentDictionary AuthenticationContexts { get; } = new(); + public SqlConnectionFactory ConnectionFactory => throw new NotImplementedException(); + public int Count => throw new NotImplementedException(); + public bool ErrorOccurred => throw new NotImplementedException(); + public int Id { get; } = 1; + public DbConnectionPoolIdentity Identity => throw new NotImplementedException(); + public bool IsRunning => throw new NotImplementedException(); + public TimeSpan LoadBalanceTimeout => throw new NotImplementedException(); + public DbConnectionPoolGroup PoolGroup => throw new NotImplementedException(); + public DbConnectionPoolGroupOptions PoolGroupOptions => throw new NotImplementedException(); + public DbConnectionPoolProviderInfo ProviderInfo => throw new NotImplementedException(); + public DbConnectionPoolState State => throw new NotImplementedException(); + public bool UseLoadBalancing => throw new NotImplementedException(); + + public ConcurrentBag ReturnedConnections { get; } = new(); + + public void Clear() => throw new NotImplementedException(); + + public bool TryGetConnection(DbConnection owningObject, TaskCompletionSource taskCompletionSource, DbConnectionOptions userOptions, out DbConnectionInternal? connection) + { + throw new NotImplementedException(); + } + + public DbConnectionInternal ReplaceConnection(DbConnection owningObject, DbConnectionOptions userOptions, DbConnectionInternal oldConnection) + { + throw new NotImplementedException(); + } + + public void ReturnInternalConnection(DbConnectionInternal obj, DbConnection owningObject) + { + throw new NotImplementedException(); + } + + public void PutObjectFromTransactedPool(DbConnectionInternal obj) + { + ReturnedConnections.Add(obj); + } + + public void Startup() => throw new NotImplementedException(); + + public void Shutdown() => throw new NotImplementedException(); + + public void TransactionEnded(Transaction transaction, DbConnectionInternal transactedObject) + { + throw new NotImplementedException(); + } + } + + internal class MockDbConnectionInternal : DbConnectionInternal + { + private static int s_nextId = 1; + public int MockId { get; } = Interlocked.Increment(ref s_nextId); + + public override string ServerVersion => "Mock"; + + public override DbTransaction BeginTransaction(System.Data.IsolationLevel il) + { + throw new NotImplementedException(); + } + + public override void EnlistTransaction(Transaction transaction) + { + // Mock implementation - do nothing + } + + protected override void Activate(Transaction transaction) + { + // Mock implementation - do nothing + } + + protected override void Deactivate() + { + // Mock implementation - do nothing + } + + public override string ToString() => $"MockConnection_{MockId}"; + } + + #endregion +} \ No newline at end of file