From c0d47f098d1028ac8cf8d29ee973b334de29628a Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Tue, 14 Apr 2026 08:44:12 -0500 Subject: [PATCH 1/9] initial work for transaction --- src/Core/Platform/Data/ITransactionManager.cs | 21 ++++++ src/Core/Platform/Data/ITransactionScope.cs | 11 +++ .../Platform/Data/NestedTransactionScope.cs | 37 ++++++++++ .../Platform/Data/RootTransactionScope.cs | 42 +++++++++++ src/Core/Platform/Data/TransactionState.cs | 56 ++++++++++++++ .../Data/DapperTransactionManager.cs | 40 ++++++++++ .../Repositories/BaseRepository.cs | 73 ++++++++++++++++++- .../Repositories/Repository.cs | 21 +++--- .../Data/EfTransactionManager.cs | 47 ++++++++++++ .../BaseEntityFrameworkRepository.cs | 53 ++++++++++++++ .../Repositories/Repository.cs | 49 +++++++------ .../Utilities/ServiceCollectionExtensions.cs | 5 ++ 12 files changed, 422 insertions(+), 33 deletions(-) create mode 100644 src/Core/Platform/Data/ITransactionManager.cs create mode 100644 src/Core/Platform/Data/ITransactionScope.cs create mode 100644 src/Core/Platform/Data/NestedTransactionScope.cs create mode 100644 src/Core/Platform/Data/RootTransactionScope.cs create mode 100644 src/Core/Platform/Data/TransactionState.cs create mode 100644 src/Infrastructure.Dapper/Data/DapperTransactionManager.cs create mode 100644 src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs diff --git a/src/Core/Platform/Data/ITransactionManager.cs b/src/Core/Platform/Data/ITransactionManager.cs new file mode 100644 index 000000000000..c9a38adf9e89 --- /dev/null +++ b/src/Core/Platform/Data/ITransactionManager.cs @@ -0,0 +1,21 @@ +namespace Bit.Core.Platform.Data; + +/// +/// Manages ambient database transactions that span multiple repository calls. +/// Implementations are singleton-safe; transaction state is stored per async flow. +/// +public interface ITransactionManager +{ + /// + /// Begins a new ambient transaction. All repository operations on the current + /// async flow will use the same connection and transaction until disposed. + /// Supports nesting: inner calls increment a reference count; only the + /// outermost Dispose/Commit actually affects the database. + /// + Task BeginTransactionAsync(CancellationToken cancellationToken = default); + + /// + /// Returns true if the current async flow has an active ambient transaction. + /// + bool HasActiveTransaction { get; } +} diff --git a/src/Core/Platform/Data/ITransactionScope.cs b/src/Core/Platform/Data/ITransactionScope.cs new file mode 100644 index 000000000000..998c2f5856c3 --- /dev/null +++ b/src/Core/Platform/Data/ITransactionScope.cs @@ -0,0 +1,11 @@ +namespace Bit.Core.Platform.Data; + +/// +/// Represents an ambient transaction scope. Commit must be called explicitly; +/// disposing without committing triggers rollback. +/// +public interface ITransactionScope : IAsyncDisposable +{ + Task CommitAsync(CancellationToken cancellationToken = default); + Task RollbackAsync(CancellationToken cancellationToken = default); +} diff --git a/src/Core/Platform/Data/NestedTransactionScope.cs b/src/Core/Platform/Data/NestedTransactionScope.cs new file mode 100644 index 000000000000..f049190fc739 --- /dev/null +++ b/src/Core/Platform/Data/NestedTransactionScope.cs @@ -0,0 +1,37 @@ +namespace Bit.Core.Platform.Data; + +public sealed class NestedTransactionScope : ITransactionScope +{ + private readonly TransactionHolder _holder; + private bool _disposed; + + public NestedTransactionScope(TransactionHolder holder) + { + _holder = holder; + } + + public Task CommitAsync(CancellationToken cancellationToken = default) + { + // Nested scope commit is a no-op; only the root scope commits. + return Task.CompletedTask; + } + + public Task RollbackAsync(CancellationToken cancellationToken = default) + { + // Mark the transaction as doomed so the root scope cannot commit. + _holder.Doomed = true; + return Task.CompletedTask; + } + + public ValueTask DisposeAsync() + { + if (_disposed) + { + return ValueTask.CompletedTask; + } + + _disposed = true; + _holder.ReferenceCount--; + return ValueTask.CompletedTask; + } +} diff --git a/src/Core/Platform/Data/RootTransactionScope.cs b/src/Core/Platform/Data/RootTransactionScope.cs new file mode 100644 index 000000000000..9574724ea835 --- /dev/null +++ b/src/Core/Platform/Data/RootTransactionScope.cs @@ -0,0 +1,42 @@ +namespace Bit.Core.Platform.Data; + +public sealed class RootTransactionScope : ITransactionScope +{ + private readonly TransactionHolder _holder; + private bool _disposed; + + public RootTransactionScope(TransactionHolder holder) + { + _holder = holder; + } + + public async Task CommitAsync(CancellationToken cancellationToken = default) + { + if (_holder.Doomed) + { + throw new InvalidOperationException( + "Cannot commit a transaction that has been marked for rollback by a nested scope."); + } + + _holder.Committed = true; + await _holder.Transaction.CommitAsync(cancellationToken); + } + + public async Task RollbackAsync(CancellationToken cancellationToken = default) + { + _holder.Doomed = true; + await _holder.Transaction.RollbackAsync(cancellationToken); + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + TransactionState.Current = null; + await _holder.DisposeAsync(); + } +} diff --git a/src/Core/Platform/Data/TransactionState.cs b/src/Core/Platform/Data/TransactionState.cs new file mode 100644 index 000000000000..5a4c9c9d5628 --- /dev/null +++ b/src/Core/Platform/Data/TransactionState.cs @@ -0,0 +1,56 @@ +using System.Data.Common; + +namespace Bit.Core.Platform.Data; + +public static class TransactionState +{ + private static readonly AsyncLocal _current = new(); + + public static TransactionHolder? Current + { + get => _current.Value; + set => _current.Value = value; + } +} + +public sealed class TransactionHolder : IAsyncDisposable +{ + public required DbConnection Connection { get; init; } + public required DbTransaction Transaction { get; init; } + public int ReferenceCount { get; set; } = 1; + public bool Committed { get; set; } + public bool Doomed { get; set; } + + /// + /// For EF: the DatabaseContext associated with this transaction. + /// + public object? DbContext { get; set; } + + /// + /// For EF: the IServiceScope that owns the DatabaseContext. + /// + public IAsyncDisposable? Scope { get; set; } + + public async ValueTask DisposeAsync() + { + if (!Committed) + { + try + { + await Transaction.RollbackAsync(); + } + catch + { + // Best-effort rollback; connection may already be broken + } + } + + await Transaction.DisposeAsync(); + await Connection.DisposeAsync(); + + if (Scope is not null) + { + await Scope.DisposeAsync(); + } + } +} diff --git a/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs b/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs new file mode 100644 index 000000000000..fc9bc8f9cef5 --- /dev/null +++ b/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs @@ -0,0 +1,40 @@ +using System.Data.Common; +using Bit.Core.Platform.Data; +using Bit.Core.Settings; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.Data; + +public sealed class DapperTransactionManager : ITransactionManager +{ + private readonly string _connectionString; + + public DapperTransactionManager(GlobalSettings globalSettings) + { + _connectionString = globalSettings.SqlServer.ConnectionString; + } + + public bool HasActiveTransaction => TransactionState.Current is not null; + + public async Task BeginTransactionAsync(CancellationToken cancellationToken = default) + { + var existing = TransactionState.Current; + if (existing is not null) + { + existing.ReferenceCount++; + return new NestedTransactionScope(existing); + } + + var connection = new SqlConnection(_connectionString); + await connection.OpenAsync(cancellationToken); + var transaction = (DbTransaction)await connection.BeginTransactionAsync(cancellationToken); + + var holder = new TransactionHolder + { + Connection = connection, + Transaction = transaction, + }; + TransactionState.Current = holder; + return new RootTransactionScope(holder); + } +} diff --git a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs index 317e7ebbb3ba..37d79475e014 100644 --- a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs @@ -1,4 +1,7 @@ -using Dapper; +using System.Data.Common; +using Bit.Core.Platform.Data; +using Dapper; +using Microsoft.Data.SqlClient; #nullable enable @@ -29,4 +32,72 @@ public BaseRepository(string connectionString, string readOnlyConnectionString) protected string ConnectionString { get; private set; } protected string ReadOnlyConnectionString { get; private set; } + + /// + /// Returns the ambient connection and transaction if an ambient transaction is active, + /// or creates a new owned connection. The caller must dispose the connection only if + /// Owned is true. + /// + protected (SqlConnection Connection, DbTransaction? Transaction, bool Owned) GetConnection() + { + var holder = TransactionState.Current; + if (holder is not null) + { + return ((SqlConnection)holder.Connection, holder.Transaction, false); + } + + return (new SqlConnection(ConnectionString), null, true); + } + + /// + /// Executes an action using the ambient transaction connection (if active) or a new + /// owned connection. The connection is opened and disposed automatically when owned. + /// + protected async Task ExecuteWithConnectionAsync( + Func> action) + { + var (connection, transaction, owned) = GetConnection(); + try + { + if (owned) + { + await connection.OpenAsync(); + } + + return await action(connection, transaction); + } + finally + { + if (owned) + { + await connection.DisposeAsync(); + } + } + } + + /// + /// Executes an action using the ambient transaction connection (if active) or a new + /// owned connection. The connection is opened and disposed automatically when owned. + /// + protected async Task ExecuteWithConnectionAsync( + Func action) + { + var (connection, transaction, owned) = GetConnection(); + try + { + if (owned) + { + await connection.OpenAsync(); + } + + await action(connection, transaction); + } + finally + { + if (owned) + { + await connection.DisposeAsync(); + } + } + } } diff --git a/src/Infrastructure.Dapper/Repositories/Repository.cs b/src/Infrastructure.Dapper/Repositories/Repository.cs index 43bffb359892..d0c4026713e3 100644 --- a/src/Infrastructure.Dapper/Repositories/Repository.cs +++ b/src/Infrastructure.Dapper/Repositories/Repository.cs @@ -2,7 +2,6 @@ using Bit.Core.Entities; using Bit.Core.Repositories; using Dapper; -using Microsoft.Data.SqlClient; #nullable enable @@ -32,21 +31,22 @@ public Repository(string connectionString, string readOnlyConnectionString, public virtual async Task GetByIdAsync(TId id) { - using (var connection = new SqlConnection(ConnectionString)) + return await ExecuteWithConnectionAsync(async (connection, transaction) => { var results = await connection.QueryAsync( $"[{Schema}].[{Table}_ReadById]", new { Id = id }, + transaction: transaction, commandType: CommandType.StoredProcedure); return results.SingleOrDefault(); - } + }); } public virtual async Task CreateAsync(T obj) { obj.SetNewId(); - using (var connection = new SqlConnection(ConnectionString)) + await ExecuteWithConnectionAsync(async (connection, transaction) => { var parameters = new DynamicParameters(); parameters.AddDynamicParams(obj); @@ -54,21 +54,23 @@ public virtual async Task CreateAsync(T obj) await connection.ExecuteAsync( $"[{Schema}].[{Table}_Create]", parameters, + transaction: transaction, commandType: CommandType.StoredProcedure); obj.Id = parameters.Get(nameof(obj.Id)); - } + }); return obj; } public virtual async Task ReplaceAsync(T obj) { - using (var connection = new SqlConnection(ConnectionString)) + await ExecuteWithConnectionAsync(async (connection, transaction) => { await connection.ExecuteAsync( $"[{Schema}].[{Table}_Update]", obj, + transaction: transaction, commandType: CommandType.StoredProcedure); - } + }); } public virtual async Task UpsertAsync(T obj) @@ -85,12 +87,13 @@ public virtual async Task UpsertAsync(T obj) public virtual async Task DeleteAsync(T obj) { - using (var connection = new SqlConnection(ConnectionString)) + await ExecuteWithConnectionAsync(async (connection, transaction) => { await connection.ExecuteAsync( $"[{Schema}].[{Table}_DeleteById]", new { Id = obj.Id }, + transaction: transaction, commandType: CommandType.StoredProcedure); - } + }); } } diff --git a/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs new file mode 100644 index 000000000000..b004f56fc70d --- /dev/null +++ b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs @@ -0,0 +1,47 @@ +using System.Data.Common; +using Bit.Core.Platform.Data; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.Data; + +public sealed class EfTransactionManager : ITransactionManager +{ + private readonly IServiceScopeFactory _serviceScopeFactory; + + public EfTransactionManager(IServiceScopeFactory serviceScopeFactory) + { + _serviceScopeFactory = serviceScopeFactory; + } + + public bool HasActiveTransaction => TransactionState.Current is not null; + + public async Task BeginTransactionAsync(CancellationToken cancellationToken = default) + { + var existing = TransactionState.Current; + if (existing is not null) + { + existing.ReferenceCount++; + return new NestedTransactionScope(existing); + } + + var scope = _serviceScopeFactory.CreateAsyncScope(); + var dbContext = scope.ServiceProvider.GetRequiredService(); + var connection = dbContext.Database.GetDbConnection(); + await connection.OpenAsync(cancellationToken); + var transaction = (DbTransaction)await connection.BeginTransactionAsync(cancellationToken); + + await dbContext.Database.UseTransactionAsync(transaction, cancellationToken); + + var holder = new TransactionHolder + { + Connection = connection, + Transaction = transaction, + DbContext = dbContext, + Scope = scope, + }; + TransactionState.Current = holder; + return new RootTransactionScope(holder); + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs index 6cf7cbb46efc..8b1edea1dd73 100644 --- a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs @@ -1,5 +1,6 @@ using System.Text.Json; using AutoMapper; +using Bit.Core.Platform.Data; using Bit.Infrastructure.EntityFramework.AdminConsole.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; using LinqToDB.Data; @@ -31,6 +32,58 @@ public DatabaseContext GetDatabaseContext(IServiceScope serviceScope) return serviceScope.ServiceProvider.GetRequiredService(); } + /// + /// Returns the ambient DatabaseContext if a transaction is active, or creates a new + /// scope and resolves a fresh DatabaseContext. The caller must dispose the returned + /// scope only if it is non-null (i.e., when not using the ambient context). + /// + protected (DatabaseContext DbContext, IServiceScope? OwnedScope) GetDatabaseContextOrAmbient() + { + var holder = TransactionState.Current; + if (holder?.DbContext is DatabaseContext ambientContext) + { + return (ambientContext, null); + } + + var scope = ServiceScopeFactory.CreateScope(); + return (GetDatabaseContext(scope), scope); + } + + /// + /// Executes an action using the ambient transaction's DatabaseContext (if active) or a + /// new scoped DatabaseContext. The scope is disposed automatically when owned. + /// + protected async Task ExecuteWithContextAsync( + Func> action) + { + var (dbContext, ownedScope) = GetDatabaseContextOrAmbient(); + try + { + return await action(dbContext); + } + finally + { + ownedScope?.Dispose(); + } + } + + /// + /// Executes an action using the ambient transaction's DatabaseContext (if active) or a + /// new scoped DatabaseContext. The scope is disposed automatically when owned. + /// + protected async Task ExecuteWithContextAsync(Func action) + { + var (dbContext, ownedScope) = GetDatabaseContextOrAmbient(); + try + { + await action(dbContext); + } + finally + { + ownedScope?.Dispose(); + } + } + public void ClearChangeTracking() { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/Repositories/Repository.cs b/src/Infrastructure.EntityFramework/Repositories/Repository.cs index e26db55d714c..b823ecf04e00 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Repository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Repository.cs @@ -22,33 +22,30 @@ public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func public virtual async Task GetByIdAsync(TId id) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var entity = await GetDbSet(dbContext).FindAsync(id); return Mapper.Map(entity); - } + }); } public virtual async Task CreateAsync(T obj) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); obj.SetNewId(); var entity = Mapper.Map(obj); await dbContext.AddAsync(entity); await dbContext.SaveChangesAsync(); obj.Id = entity.Id; return obj; - } + }); } public virtual async Task ReplaceAsync(T obj) { - using (var scope = ServiceScopeFactory.CreateScope()) + await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var entity = await GetDbSet(dbContext).FindAsync(obj.Id); if (entity != null) { @@ -56,7 +53,7 @@ public virtual async Task ReplaceAsync(T obj) dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); await dbContext.SaveChangesAsync(); } - } + }); } public virtual async Task UpsertAsync(T obj) @@ -73,28 +70,26 @@ public virtual async Task UpsertAsync(T obj) public virtual async Task DeleteAsync(T obj) { - using (var scope = ServiceScopeFactory.CreateScope()) + await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var entity = Mapper.Map(obj); dbContext.Remove(entity); await dbContext.SaveChangesAsync(); - } + }); } public virtual async Task RefreshDb() { - using (var scope = ServiceScopeFactory.CreateScope()) + await ExecuteWithContextAsync(async dbContext => { - var context = GetDatabaseContext(scope); - await context.Database.EnsureDeletedAsync(); - await context.Database.EnsureCreatedAsync(); - } + await dbContext.Database.EnsureDeletedAsync(); + await dbContext.Database.EnsureCreatedAsync(); + }); } public virtual async Task> CreateMany(List objs) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { var entities = new List(); foreach (var o in objs) @@ -103,19 +98,27 @@ public virtual async Task> CreateMany(List objs) var entity = Mapper.Map(o); entities.Add(entity); } - var dbContext = GetDatabaseContext(scope); await GetDbSet(dbContext).AddRangeAsync(entities); await dbContext.SaveChangesAsync(); return objs; - } + }); } public IQueryable Run(IQuery query) { - using (var scope = ServiceScopeFactory.CreateScope()) + var (dbContext, ownedScope) = GetDatabaseContextOrAmbient(); + // Note: IQueryable is deferred, so disposing the scope here would break it. + // This matches the existing behavior where the scope is disposed before + // the query is materialized. Callers must materialize within scope. + if (ownedScope is not null) { - var dbContext = GetDatabaseContext(scope); - return query.Run(dbContext); + // Fall back to original behavior for non-transactional context + using (var scope = ServiceScopeFactory.CreateScope()) + { + var context = GetDatabaseContext(scope); + return query.Run(context); + } } + return query.Run(dbContext); } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 85886027ac2d..eea77aa97698 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -36,6 +36,7 @@ using Bit.Core.NotificationCenter; using Bit.Core.OrganizationFeatures; using Bit.Core.Platform; +using Bit.Core.Platform.Data; using Bit.Core.Platform.Mail.Delivery; using Bit.Core.Platform.Mail.Enqueuing; using Bit.Core.Platform.Mail.Mailer; @@ -57,7 +58,9 @@ using Bit.Core.Vault; using Bit.Core.Vault.Services; using Bit.Infrastructure.Dapper; +using Bit.Infrastructure.Dapper.Data; using Bit.Infrastructure.EntityFramework; +using Bit.Infrastructure.EntityFramework.Data; using Bit.SharedWeb.Play; using DnsClient; using Duende.IdentityModel; @@ -101,10 +104,12 @@ public static SupportedDatabaseProviders AddDatabaseRepositories(this IServiceCo if (provider != SupportedDatabaseProviders.SqlServer) { services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted); + services.AddSingleton(); } else { services.AddDapperRepositories(globalSettings.SelfHosted); + services.AddSingleton(); } if (globalSettings.SelfHosted) From 53759ec6291375208e7ce66a24e24c4ec5befe78 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Mon, 20 Apr 2026 09:26:48 -0500 Subject: [PATCH 2/9] added the transaction --- bitwarden_license/src/Scim/Users/PostUserCommand.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 696d6003482b..8074af7c9280 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -12,6 +12,7 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Platform.Data; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Scim.Context; @@ -30,7 +31,8 @@ public class PostUserCommand( IFeatureService featureService, IInviteOrganizationUsersCommand inviteOrganizationUsersCommand, TimeProvider timeProvider, - IPricingClient pricingClient) + IPricingClient pricingClient, + ITransactionManager transactionManager) : IPostUserCommand { public async Task PostUserAsync(Guid organizationId, ScimUserRequestModel model) @@ -48,6 +50,8 @@ public class PostUserCommand( Guid organizationId, ScimProviderType scimProvider) { + await transactionManager.BeginTransactionAsync(); + var organization = await organizationRepository.GetByIdAsync(organizationId); if (organization is null) From 9863df49049f8a2dff9e2c13b9effd1d93283418 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Tue, 28 Apr 2026 14:58:13 -0500 Subject: [PATCH 3/9] Added transaction isolation and wrapped all calls (cept sm) in transaaction --- .../src/Scim/Users/PostUserCommand.cs | 3 +- .../InviteOrganizationUsersCommand.cs | 6 +- src/Core/Platform/Data/ITransactionManager.cs | 11 ++- .../Repositories/OrganizationRepository.cs | 17 ++-- .../OrganizationUserRepository.cs | 66 ++++++++------- .../Repositories/ProviderRepository.cs | 5 +- .../Data/DapperTransactionManager.cs | 8 +- .../Repositories/OrganizationRepository.cs | 38 +++++---- .../OrganizationUserRepository.cs | 81 +++++++++---------- .../Repositories/ProviderRepository.cs | 5 +- .../Data/EfTransactionManager.cs | 8 +- .../BaseEntityFrameworkRepository.cs | 5 +- 12 files changed, 133 insertions(+), 120 deletions(-) diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 8074af7c9280..364b630a64e0 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -1,5 +1,6 @@ #nullable enable +using System.Data; using Bit.Core; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Business; @@ -50,7 +51,7 @@ public class PostUserCommand( Guid organizationId, ScimProviderType scimProvider) { - await transactionManager.BeginTransactionAsync(); + await transactionManager.BeginTransactionAsync(IsolationLevel.Serializable); var organization = await organizationRepository.GetByIdAsync(organizationId); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs index abcc39aea02c..fb6e318b550f 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs @@ -156,12 +156,14 @@ private async Task> InviteOrganiz { logger.LogError(ex, FailedToInviteUsersError.Code); - await organizationUserRepository.DeleteManyAsync(organizationUserToInviteEntities.Select(x => x.OrganizationUser.Id)); + // this should already be done + //await organizationUserRepository.DeleteManyAsync(organizationUserToInviteEntities.Select(x => x.OrganizationUser.Id)); // Do this first so that SmSeats never exceed PM seats (due to current billing requirements) await RevertSecretsManagerChangesAsync(validatedRequest, organization, validatedRequest.Value.InviteOrganization.SmSeats); - await RevertPasswordManagerChangesAsync(validatedRequest, organization); + //this should already be done + //await RevertPasswordManagerChangesAsync(validatedRequest, organization); return new Failure( new FailedToInviteUsersError( diff --git a/src/Core/Platform/Data/ITransactionManager.cs b/src/Core/Platform/Data/ITransactionManager.cs index c9a38adf9e89..d3e0910f47cc 100644 --- a/src/Core/Platform/Data/ITransactionManager.cs +++ b/src/Core/Platform/Data/ITransactionManager.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.Platform.Data; +using System.Data; + +namespace Bit.Core.Platform.Data; /// /// Manages ambient database transactions that span multiple repository calls. @@ -10,9 +12,12 @@ public interface ITransactionManager /// Begins a new ambient transaction. All repository operations on the current /// async flow will use the same connection and transaction until disposed. /// Supports nesting: inner calls increment a reference count; only the - /// outermost Dispose/Commit actually affects the database. + /// outermost Dispose/Commit actually affects the database. The isolation level + /// on a nested call is ignored — the inner scope joins the outer transaction. /// - Task BeginTransactionAsync(CancellationToken cancellationToken = default); + Task BeginTransactionAsync( + IsolationLevel isolationLevel = IsolationLevel.ReadCommitted, + CancellationToken cancellationToken = default); /// /// Returns true if the current async flow has an active ambient transaction. diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs index cce80a9eb4b7..eb104c393366 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs @@ -250,15 +250,16 @@ public async Task> GetManyByIdsAsync(IEnumerable public async Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { - using (var connection = new SqlConnection(ConnectionString)) + return await ExecuteWithConnectionAsync(async (connection, transaction) => { var result = await connection.QueryAsync( "[dbo].[Organization_ReadOccupiedSeatCountByOrganizationId]", new { OrganizationId = organizationId }, + transaction: transaction, commandType: CommandType.StoredProcedure); return result.SingleOrDefault() ?? new OrganizationSeatCounts(); - } + }); } public async Task> GetOrganizationsForSubscriptionSyncAsync() @@ -285,11 +286,13 @@ await connection.ExecuteAsync("[dbo].[Organization_UpdateSubscriptionStatus]", public async Task IncrementSeatCountAsync(Guid organizationId, int increaseAmount, DateTime requestDate) { - await using var connection = new SqlConnection(ConnectionString); - - await connection.ExecuteAsync("[dbo].[Organization_IncrementSeatCount]", - new { OrganizationId = organizationId, SeatsToAdd = increaseAmount, RequestDate = requestDate }, - commandType: CommandType.StoredProcedure); + await ExecuteWithConnectionAsync(async (connection, transaction) => + { + await connection.ExecuteAsync("[dbo].[Organization_IncrementSeatCount]", + new { OrganizationId = organizationId, SeatsToAdd = increaseAmount, RequestDate = requestDate }, + transaction: transaction, + commandType: CommandType.StoredProcedure); + }); } public async Task InitializeOrganizationAsync(Organization organization, Func confirmOwnerAction) diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index 8d1fe3565f9d..520040c7c79c 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -94,15 +94,16 @@ public async Task GetCountByOrganizationAsync(Guid organizationId, string e public async Task GetOccupiedSmSeatCountByOrganizationIdAsync(Guid organizationId) { - using (var connection = new SqlConnection(ConnectionString)) + return await ExecuteWithConnectionAsync(async (connection, transaction) => { var result = await connection.ExecuteScalarAsync( "[dbo].[OrganizationUser_ReadOccupiedSmSeatCountByOrganizationId]", new { OrganizationId = organizationId }, + transaction: transaction, commandType: CommandType.StoredProcedure); return result; - } + }); } public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, @@ -205,11 +206,12 @@ public async Task> GetManyByOrganizationAsync(Guid public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeSharedCollections) { - using (var connection = new SqlConnection(ConnectionString)) + return await ExecuteWithConnectionAsync(async (connection, transaction) => { var results = await connection.QueryAsync( "[dbo].[OrganizationUserUserDetails_ReadByOrganizationId]", new { OrganizationId = organizationId }, + transaction: transaction, commandType: CommandType.StoredProcedure); List>? userGroups = null; @@ -229,6 +231,7 @@ public async Task> GetManyDetailsByOrga userGroups = (await connection.QueryAsync( "[dbo].[GroupUser_ReadByOrganizationUserIds]", new { OrganizationUserIds = orgUserIds }, + transaction: transaction, commandType: CommandType.StoredProcedure)).GroupBy(u => u.OrganizationUserId).ToList(); } @@ -237,6 +240,7 @@ public async Task> GetManyDetailsByOrga userCollections = (await connection.QueryAsync( "[dbo].[CollectionUser_ReadSharedCollectionsByOrganizationUserIds]", new { OrganizationUserIds = orgUserIds }, + transaction: transaction, commandType: CommandType.StoredProcedure)).GroupBy(u => u.OrganizationUserId).ToList(); } @@ -265,7 +269,7 @@ public async Task> GetManyDetailsByOrga } return users; - } + }); } public async Task> GetManyDetailsByOrganizationAsync_vNext(Guid organizationId, bool includeGroups, bool includeSharedCollections) @@ -664,38 +668,40 @@ public async Task> GetManyDetailsByRole public async Task CreateManyAsync(IEnumerable organizationUserCollection) { - await using var connection = new SqlConnection(_marsConnectionString); - var organizationUsersList = organizationUserCollection.ToList(); if (organizationUsersList.Count == 0) { return; } - await connection.ExecuteAsync( - $"[{Schema}].[OrganizationUser_CreateManyWithCollectionsAndGroups]", - new - { - OrganizationUserData = JsonSerializer.Serialize(organizationUsersList.Select(x => x.OrganizationUser)), - CollectionData = JsonSerializer.Serialize(organizationUsersList - .SelectMany(x => x.Collections, (user, collection) => new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = user.OrganizationUser.Id, - ReadOnly = collection.ReadOnly, - HidePasswords = collection.HidePasswords, - Manage = collection.Manage - })), - GroupData = JsonSerializer.Serialize(organizationUsersList - .SelectMany(x => x.Groups, (user, group) => new GroupUser - { - GroupId = group, - OrganizationUserId = user.OrganizationUser.Id - })), - // Use the same RevisionDate as the created OrganizationUsers - RevisionDate = organizationUsersList.First().OrganizationUser.RevisionDate - }, - commandType: CommandType.StoredProcedure); + await ExecuteWithConnectionAsync(async (connection, transaction) => + { + await connection.ExecuteAsync( + $"[{Schema}].[OrganizationUser_CreateManyWithCollectionsAndGroups]", + new + { + OrganizationUserData = + JsonSerializer.Serialize(organizationUsersList.Select(x => x.OrganizationUser)), + CollectionData = JsonSerializer.Serialize(organizationUsersList + .SelectMany(x => x.Collections, + (user, collection) => new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = user.OrganizationUser.Id, + ReadOnly = collection.ReadOnly, + HidePasswords = collection.HidePasswords, + Manage = collection.Manage + })), + GroupData = JsonSerializer.Serialize(organizationUsersList + .SelectMany(x => x.Groups, + (user, group) => + new GroupUser { GroupId = group, OrganizationUserId = user.OrganizationUser.Id })), + // Use the same RevisionDate as the created OrganizationUsers + RevisionDate = organizationUsersList.First().OrganizationUser.RevisionDate + }, + transaction: transaction, + commandType: CommandType.StoredProcedure); + }); } public async Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderRepository.cs index 5a8ff286b006..b492a9840864 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderRepository.cs @@ -49,15 +49,16 @@ public ProviderRepository(string connectionString, string readOnlyConnectionStri public async Task GetByOrganizationIdAsync(Guid organizationId) { - using (var connection = new SqlConnection(ConnectionString)) + return await ExecuteWithConnectionAsync(async (connection, transaction) => { var results = await connection.QueryAsync( "[dbo].[Provider_ReadByOrganizationId]", new { OrganizationId = organizationId }, + transaction: transaction, commandType: CommandType.StoredProcedure); return results.FirstOrDefault(); - } + }); } public async Task> SearchAsync(string name, string userEmail, int skip, int take) diff --git a/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs b/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs index fc9bc8f9cef5..83d44ecfe1b9 100644 --- a/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs +++ b/src/Infrastructure.Dapper/Data/DapperTransactionManager.cs @@ -1,4 +1,4 @@ -using System.Data.Common; +using System.Data; using Bit.Core.Platform.Data; using Bit.Core.Settings; using Microsoft.Data.SqlClient; @@ -16,7 +16,9 @@ public DapperTransactionManager(GlobalSettings globalSettings) public bool HasActiveTransaction => TransactionState.Current is not null; - public async Task BeginTransactionAsync(CancellationToken cancellationToken = default) + public async Task BeginTransactionAsync( + IsolationLevel isolationLevel = IsolationLevel.ReadCommitted, + CancellationToken cancellationToken = default) { var existing = TransactionState.Current; if (existing is not null) @@ -27,7 +29,7 @@ public async Task BeginTransactionAsync(CancellationToken can var connection = new SqlConnection(_connectionString); await connection.OpenAsync(cancellationToken); - var transaction = (DbTransaction)await connection.BeginTransactionAsync(cancellationToken); + var transaction = await connection.BeginTransactionAsync(isolationLevel, cancellationToken); var holder = new TransactionHolder { diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 8e72fe78abaf..553b0e5ef474 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -424,26 +424,24 @@ where ids.Contains(organization.Id) public async Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var users = await dbContext.OrganizationUsers .Where(ou => ou.OrganizationId == organizationId && ou.Status >= 0) .CountAsync(); var sponsored = await dbContext.OrganizationSponsorships .Where(os => os.SponsoringOrganizationId == organizationId && - os.IsAdminInitiated && - (os.ToDelete == false || (os.ToDelete == true && os.ValidUntil != null && os.ValidUntil > DateTime.UtcNow)) && - (os.SponsoredOrganizationId == null || (os.SponsoredOrganizationId != null && (os.ValidUntil == null || os.ValidUntil > DateTime.UtcNow)))) + os.IsAdminInitiated && + (os.ToDelete == false || (os.ToDelete == true && os.ValidUntil != null && + os.ValidUntil > DateTime.UtcNow)) && + (os.SponsoredOrganizationId == null || (os.SponsoredOrganizationId != null && + (os.ValidUntil == null || + os.ValidUntil > DateTime.UtcNow)))) .CountAsync(); - return new OrganizationSeatCounts - { - Users = users, - Sponsored = sponsored - }; - } + return new OrganizationSeatCounts { Users = users, Sponsored = sponsored }; + }); } public async Task> GetOrganizationsForSubscriptionSyncAsync() @@ -472,15 +470,15 @@ await dbContext.Organizations public async Task IncrementSeatCountAsync(Guid organizationId, int increaseAmount, DateTime requestDate) { - using var scope = ServiceScopeFactory.CreateScope(); - await using var dbContext = GetDatabaseContext(scope); - - await dbContext.Organizations - .Where(o => o.Id == organizationId) - .ExecuteUpdateAsync(s => s - .SetProperty(o => o.Seats, o => o.Seats + increaseAmount) - .SetProperty(o => o.SyncSeats, true) - .SetProperty(o => o.RevisionDate, requestDate)); + await ExecuteWithContextAsync(async dbContext => + { + await dbContext.Organizations + .Where(o => o.Id == organizationId) + .ExecuteUpdateAsync(s => s + .SetProperty(o => o.Seats, o => o.Seats + increaseAmount) + .SetProperty(o => o.SyncSeats, true) + .SetProperty(o => o.RevisionDate, requestDate)); + }); } public async Task InitializeOrganizationAsync(Core.AdminConsole.Entities.Organization organization, Func confirmOwnerAction) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index 27e71a485e48..8575d083b1fe 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -447,9 +447,8 @@ where Ids.Contains(ou.Id) public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeSharedCollections) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var view = new OrganizationUserUserDetailsViewQuery(); var users = await (from ou in view.Run(dbContext) where ou.OrganizationId == organizationId @@ -509,7 +508,7 @@ join c in dbContext.Collections on cu.CollectionId equals c.Id } return users; - } + }); } public async Task> GetManyDetailsByOrganizationAsync_vNext( @@ -952,52 +951,50 @@ on ou.UserId equals u.Id public async Task CreateManyAsync(IEnumerable organizationUserCollection) { - if (!organizationUserCollection.Any()) + var organizationUserCollectionList = organizationUserCollection.ToList(); + + if (organizationUserCollectionList.Count == 0) { return; } - using var scope = ServiceScopeFactory.CreateScope(); - - await using var dbContext = GetDatabaseContext(scope); - - dbContext.OrganizationUsers.AddRange(Mapper.Map>(organizationUserCollection.Select(x => x.OrganizationUser))); - dbContext.CollectionUsers.AddRange(organizationUserCollection.SelectMany(x => x.Collections, (user, collection) => new CollectionUser - { - CollectionId = collection.Id, - HidePasswords = collection.HidePasswords, - OrganizationUserId = user.OrganizationUser.Id, - Manage = collection.Manage, - ReadOnly = collection.ReadOnly - })); - dbContext.GroupUsers.AddRange(organizationUserCollection.SelectMany(x => x.Groups, (user, group) => new GroupUser + await ExecuteWithContextAsync(async dbContext => { - GroupId = group, - OrganizationUserId = user.OrganizationUser.Id - })); - - // Bump RevisionDate on all affected collections - var affectedCollectionIds = organizationUserCollection - .SelectMany(x => x.Collections) - .Select(c => c.Id) - .Distinct() - .ToList(); - if (affectedCollectionIds.Count > 0) - { - var organizationId = organizationUserCollection.First().OrganizationUser.OrganizationId; - var affectedCollections = await dbContext.Collections - .Where(c => c.OrganizationId == organizationId - && affectedCollectionIds.Contains(c.Id)) - .ToListAsync(); - // Use the same RevisionDate as the created OrganizationUsers - var revisionDate = organizationUserCollection.First().OrganizationUser.RevisionDate; - foreach (var c in affectedCollections) + dbContext.OrganizationUsers.AddRange( + Mapper.Map>(organizationUserCollectionList.Select(x => x.OrganizationUser))); + dbContext.CollectionUsers.AddRange(organizationUserCollectionList.SelectMany(x => x.Collections, + (user, collection) => new CollectionUser + { + CollectionId = collection.Id, + HidePasswords = collection.HidePasswords, + OrganizationUserId = user.OrganizationUser.Id, + Manage = collection.Manage, + ReadOnly = collection.ReadOnly + })); + dbContext.GroupUsers.AddRange(organizationUserCollectionList.SelectMany(x => x.Groups, + (user, group) => new GroupUser { GroupId = group, OrganizationUserId = user.OrganizationUser.Id })); + + // Bump RevisionDate on all affected collections + var affectedCollectionIds = organizationUserCollectionList + .SelectMany(x => x.Collections) + .Select(c => c.Id) + .Distinct() + .ToList(); + if (affectedCollectionIds.Count > 0) { - c.RevisionDate = revisionDate; + var organizationId = organizationUserCollectionList.First().OrganizationUser.OrganizationId; + var affectedCollections = await dbContext.Collections + .Where(c => c.OrganizationId == organizationId + && affectedCollectionIds.Contains(c.Id)) + .ToListAsync(); + // Use the same RevisionDate as the created OrganizationUsers + var revisionDate = organizationUserCollectionList.First().OrganizationUser.RevisionDate; + foreach (var c in affectedCollections) + { + c.RevisionDate = revisionDate; + } } - } - - await dbContext.SaveChangesAsync(); + }); } public async Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderRepository.cs index 0450dc19495c..c9abbc580275 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderRepository.cs @@ -55,16 +55,15 @@ public async Task GetByGatewaySubscriptionIdAsync(string gatewaySubscr public async Task GetByOrganizationIdAsync(Guid organizationId) { - using (var scope = ServiceScopeFactory.CreateScope()) + return await ExecuteWithContextAsync(async dbContext => { - var dbContext = GetDatabaseContext(scope); var query = from p in dbContext.Providers join po in dbContext.ProviderOrganizations on p.Id equals po.ProviderId where po.OrganizationId == organizationId select p; return await query.FirstOrDefaultAsync(); - } + }); } public async Task> SearchAsync(string name, string userEmail, int skip, int take) diff --git a/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs index b004f56fc70d..49aec3f4c04a 100644 --- a/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs +++ b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs @@ -1,4 +1,4 @@ -using System.Data.Common; +using System.Data; using Bit.Core.Platform.Data; using Bit.Infrastructure.EntityFramework.Repositories; using Microsoft.EntityFrameworkCore; @@ -17,7 +17,9 @@ public EfTransactionManager(IServiceScopeFactory serviceScopeFactory) public bool HasActiveTransaction => TransactionState.Current is not null; - public async Task BeginTransactionAsync(CancellationToken cancellationToken = default) + public async Task BeginTransactionAsync( + IsolationLevel isolationLevel = IsolationLevel.ReadCommitted, + CancellationToken cancellationToken = default) { var existing = TransactionState.Current; if (existing is not null) @@ -30,7 +32,7 @@ public async Task BeginTransactionAsync(CancellationToken can var dbContext = scope.ServiceProvider.GetRequiredService(); var connection = dbContext.Database.GetDbConnection(); await connection.OpenAsync(cancellationToken); - var transaction = (DbTransaction)await connection.BeginTransactionAsync(cancellationToken); + var transaction = await connection.BeginTransactionAsync(isolationLevel, cancellationToken); await dbContext.Database.UseTransactionAsync(transaction, cancellationToken); diff --git a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs index 8b1edea1dd73..026803f31a6a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs @@ -95,10 +95,7 @@ public void ClearChangeTracking() public async Task GetCountFromQuery(IQuery query) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - return await query.Run(GetDatabaseContext(scope)).CountAsync(); - } + return await ExecuteWithContextAsync(dbContext => query.Run(dbContext).CountAsync()); } protected async Task OrganizationUpdateStorage(Guid organizationId) From 23027577795918235d5821f14c1f5d8a70641149 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Tue, 28 Apr 2026 15:20:42 -0500 Subject: [PATCH 4/9] Corrected scope. --- .../src/Scim/Users/PostUserCommand.cs | 3 +- .../v2/UsersControllerConcurrencyTests.cs | 92 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 364b630a64e0..a10470213fe7 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -51,7 +51,7 @@ public class PostUserCommand( Guid organizationId, ScimProviderType scimProvider) { - await transactionManager.BeginTransactionAsync(IsolationLevel.Serializable); + await using var transactionScope = await transactionManager.BeginTransactionAsync(IsolationLevel.Serializable); var organization = await organizationRepository.GetByIdAsync(organizationId); @@ -91,6 +91,7 @@ public class PostUserCommand( ? await organizationUserRepository.GetDetailsByIdAsync(invitedOrganizationUserId.Value) : null; + await transactionScope.CommitAsync(); return organizationUser; } diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs new file mode 100644 index 000000000000..9e29597828de --- /dev/null +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs @@ -0,0 +1,92 @@ +using Bit.Core; +using Bit.Core.Billing.Enums; +using Bit.Core.Services; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.IntegrationTestCommon; +using Bit.Scim.IntegrationTest.Factories; +using Bit.Scim.Models; +using Bit.Scim.Utilities; +using NSubstitute; +using Xunit; + +namespace Bit.Scim.IntegrationTest.Controllers.v2; + +/// +/// Verifies seat-count integrity when SCIM invite requests run concurrently. +/// Requires a real SQL Server (vault_test) — SQLite serializes writes globally and +/// cannot reproduce the read-modify-write race on Organization.Seats. +/// +public class UsersControllerConcurrencyTests +{ + [Fact] + public async Task Post_ConcurrentInvites_DoNotOvershootMaxAutoscaleSeats() + { + const short startingSeats = 3; + const int availableSeats = 2; + const int concurrentInvites = 6; + + var factory = new ScimApplicationFactory + { + TestDatabase = new SqlServerTestDatabase() + }; + factory.SubstituteService((IFeatureService f) + => f.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization).Returns(true)); + + try + { + factory.ReinitializeDbForTests(factory.GetDatabaseContext()); + + using (var setupScope = factory.Services.CreateScope()) + { + var setupContext = setupScope.ServiceProvider.GetRequiredService(); + var org = setupContext.Organizations.Single(o => o.Id == ScimApplicationFactory.TestOrganizationId1); + org.PlanType = PlanType.EnterpriseAnnually; + org.Plan = "Enterprise (Annually)"; + org.Seats = startingSeats; + org.MaxAutoscaleSeats = startingSeats + availableSeats; + await setupContext.SaveChangesAsync(); + } + + var inputs = Enumerable.Range(0, concurrentInvites).Select(BuildInvite).ToArray(); + + var responses = await Task.WhenAll( + inputs.Select(input => + factory.UsersPostAsync(ScimApplicationFactory.TestOrganizationId1, input))); + + var successfulInvites = responses.Count(r => r.Response.StatusCode == StatusCodes.Status201Created); + + using var verifyScope = factory.Services.CreateScope(); + var verifyContext = verifyScope.ServiceProvider.GetRequiredService(); + var finalOrg = verifyContext.Organizations + .Single(o => o.Id == ScimApplicationFactory.TestOrganizationId1); + var finalActiveUserCount = verifyContext.OrganizationUsers + .Count(ou => ou.OrganizationId == ScimApplicationFactory.TestOrganizationId1 && ou.Status >= 0); + + Assert.All(responses, r => Assert.True(r.Response.StatusCode < 500, + $"Expected non-5xx status, got {r.Response.StatusCode}")); + + Assert.Equal(startingSeats + successfulInvites, finalOrg.Seats); + + Assert.Equal(startingSeats + successfulInvites, finalActiveUserCount); + + Assert.True(finalOrg.Seats <= finalOrg.MaxAutoscaleSeats, + $"Seats {finalOrg.Seats} exceeded MaxAutoscaleSeats {finalOrg.MaxAutoscaleSeats}"); + } + finally + { + factory.Dispose(); + } + } + + private static ScimUserRequestModel BuildInvite(int i) => new() + { + DisplayName = $"Concurrent User {i}", + Emails = new List + { + new() { Primary = true, Type = "work", Value = $"concurrent-{i}@example.com" } + }, + ExternalId = $"CONC-{i}", + Active = true, + Schemas = new List { ScimConstants.Scim2SchemaUser } + }; +} From b9de163bb8ecbce8a76a756cd95342d3dbc0e953 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Wed, 29 Apr 2026 15:54:35 -0500 Subject: [PATCH 5/9] Added db.SaveContext --- .../AdminConsole/Repositories/OrganizationUserRepository.cs | 2 ++ .../InviteUsers/InviteOrganizationUserCommandTests.cs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index 8575d083b1fe..ab36a1286baf 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -994,6 +994,8 @@ await ExecuteWithContextAsync(async dbContext => c.RevisionDate = revisionDate; } } + + await dbContext.SaveChangesAsync(); }); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs index 5d82f0717d0c..956aec1b6909 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs @@ -617,7 +617,7 @@ public async Task InviteScimOrganizationUserAsync_WhenAnErrorOccursWhileInviting Assert.Equal(FailedToInviteUsersError.Code, (result as Failure)!.Error.Message); // org user revert - await orgUserRepository.Received(1).DeleteManyAsync(Arg.Is>(x => x.Count() == 1)); + // await orgUserRepository.Received(1).DeleteManyAsync(Arg.Is>(x => x.Count() == 1)); // SM revert await sutProvider.GetDependency() @@ -625,7 +625,7 @@ await sutProvider.GetDependency() .UpdateSubscriptionAsync(Arg.Any()); // PM revert - await orgRepository.Received(1).ReplaceAsync(Arg.Any()); + // await orgRepository.Received(1).ReplaceAsync(Arg.Any()); await sutProvider.GetDependency().Received(2) .UpsertOrganizationAbilityAsync(Arg.Any()); From 465410473ef927fc1df7ba1ba9f3bbcfe9ad1de5 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Wed, 29 Apr 2026 15:55:27 -0500 Subject: [PATCH 6/9] commenting temporarily --- .../InviteUsers/InviteOrganizationUsersCommand.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs index fb6e318b550f..badfa2993f19 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs @@ -163,7 +163,7 @@ private async Task> InviteOrganiz await RevertSecretsManagerChangesAsync(validatedRequest, organization, validatedRequest.Value.InviteOrganization.SmSeats); //this should already be done - //await RevertPasswordManagerChangesAsync(validatedRequest, organization); + await RevertPasswordManagerChangesAsync(validatedRequest, organization); return new Failure( new FailedToInviteUsersError( @@ -193,7 +193,7 @@ private async Task RevertPasswordManagerChangesAsync(Valid Date: Thu, 30 Apr 2026 10:08:22 -0500 Subject: [PATCH 7/9] Correcting ownership of EF transaction (it disposes of the context itself later) --- .../v2/UsersControllerConcurrencyTests.cs | 7 ++++--- src/Core/Platform/Data/TransactionState.cs | 13 ++++++++++++- .../Data/EfTransactionManager.cs | 1 + 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs index 9e29597828de..c4fdd139c656 100644 --- a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs @@ -29,8 +29,9 @@ public async Task Post_ConcurrentInvites_DoNotOvershootMaxAutoscaleSeats() { TestDatabase = new SqlServerTestDatabase() }; - factory.SubstituteService((IFeatureService f) - => f.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization).Returns(true)); + + factory.SubstituteService((IFeatureService f) => f.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization) + .Returns(true)); try { @@ -74,7 +75,7 @@ public async Task Post_ConcurrentInvites_DoNotOvershootMaxAutoscaleSeats() } finally { - factory.Dispose(); + await factory.DisposeAsync(); } } diff --git a/src/Core/Platform/Data/TransactionState.cs b/src/Core/Platform/Data/TransactionState.cs index 5a4c9c9d5628..d33daf1c4a6b 100644 --- a/src/Core/Platform/Data/TransactionState.cs +++ b/src/Core/Platform/Data/TransactionState.cs @@ -21,6 +21,13 @@ public sealed class TransactionHolder : IAsyncDisposable public bool Committed { get; set; } public bool Doomed { get; set; } + /// + /// True when this holder is responsible for disposing . + /// EF reuses the DbContext's connection and must leave its lifetime to the scope; + /// Dapper opens its own connection and must dispose it here. + /// + public bool OwnsConnection { get; init; } = true; + /// /// For EF: the DatabaseContext associated with this transaction. /// @@ -46,7 +53,11 @@ public async ValueTask DisposeAsync() } await Transaction.DisposeAsync(); - await Connection.DisposeAsync(); + + if (OwnsConnection) + { + await Connection.DisposeAsync(); + } if (Scope is not null) { diff --git a/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs index 49aec3f4c04a..f75694411d47 100644 --- a/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs +++ b/src/Infrastructure.EntityFramework/Data/EfTransactionManager.cs @@ -40,6 +40,7 @@ public async Task BeginTransactionAsync( { Connection = connection, Transaction = transaction, + OwnsConnection = false, DbContext = dbContext, Scope = scope, }; From cc9df675d0072ffe33f3a4d779fb37e9c4ba76de Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Fri, 1 May 2026 10:41:59 -0500 Subject: [PATCH 8/9] testing envsqlserver config --- .../v2/UsersControllerConcurrencyTests.cs | 3 +- .../EnvSqlServerTestDatabase.cs | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs index c4fdd139c656..380f1e5cf93a 100644 --- a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs @@ -2,7 +2,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Services; using Bit.Infrastructure.EntityFramework.Repositories; -using Bit.IntegrationTestCommon; using Bit.Scim.IntegrationTest.Factories; using Bit.Scim.Models; using Bit.Scim.Utilities; @@ -27,7 +26,7 @@ public async Task Post_ConcurrentInvites_DoNotOvershootMaxAutoscaleSeats() var factory = new ScimApplicationFactory { - TestDatabase = new SqlServerTestDatabase() + TestDatabase = new EnvSqlServerTestDatabase() }; factory.SubstituteService((IFeatureService f) => f.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization) diff --git a/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs b/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs new file mode 100644 index 000000000000..724ccdb36615 --- /dev/null +++ b/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs @@ -0,0 +1,103 @@ +using Bit.Core.Enums; +using Bit.Core.Settings; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.IntegrationTestCommon; +using Bit.Migrator; +using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; + +namespace Bit.Scim.IntegrationTest; + +/// +/// SQL Server test database that resolves its connection string from the same +/// BW_TEST_DATABASES__n__CONNECTIONSTRING env vars used by DatabaseDataAttribute, +/// with a fallback to Identity user secrets for local dev. +/// +public class EnvSqlServerTestDatabase : ITestDatabase +{ + private readonly string _connectionString; + + public EnvSqlServerTestDatabase() + { + var config = new ConfigurationBuilder() + .AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true) + .AddEnvironmentVariables("BW_TEST_") + .Build(); + + var resolved = + config.Get()?.Databases? + .FirstOrDefault(d => d.Type == SupportedDatabaseProviders.SqlServer)?.ConnectionString + ?? config.GetSection("globalSettings:sqlServer:connectionString").Value + ?? throw new InvalidOperationException( + "No SQL Server connection string found. Set BW_TEST_DATABASES__n__TYPE=SqlServer " + + "and BW_TEST_DATABASES__n__CONNECTIONSTRING, or configure Identity user secrets locally."); + + _connectionString = new SqlConnectionStringBuilder(resolved) + { + InitialCatalog = "vault_test" + }.ConnectionString; + } + + public void ModifyGlobalSettings(Dictionary config) + { + config["globalSettings:databaseProvider"] = "sqlserver"; + config["globalSettings:sqlServer:connectionString"] = _connectionString; + } + + public void AddDatabase(IServiceCollection serviceCollection) + { + serviceCollection.AddScoped(s => new DbContextOptionsBuilder() + .UseSqlServer(_connectionString) + .UseApplicationServiceProvider(s) + .Options); + } + + public void Migrate(IServiceCollection serviceCollection) + { + var serviceProvider = serviceCollection.BuildServiceProvider(); + using var scope = serviceProvider.CreateScope(); + var services = scope.ServiceProvider; + var globalSettings = services.GetRequiredService(); + var logger = services.GetRequiredService>(); + + var migrator = new SqlServerDbMigrator(globalSettings, logger); + migrator.MigrateDatabase(); + } + + public void Dispose() + { + var masterConnectionString = new SqlConnectionStringBuilder(_connectionString) + { + InitialCatalog = "master" + }.ConnectionString; + + using var connection = new SqlConnection(masterConnectionString); + var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; + + connection.Open(); + + var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); + + using (var cmd = new SqlCommand( + $"ALTER DATABASE {databaseNameQuoted} SET single_user WITH rollback IMMEDIATE", connection)) + { + cmd.ExecuteNonQuery(); + } + + using (var cmd = new SqlCommand($"DROP DATABASE {databaseNameQuoted}", connection)) + { + cmd.ExecuteNonQuery(); + } + } + + private class TypedConfig + { + public DatabaseEntry[]? Databases { get; set; } + } + + private class DatabaseEntry + { + public SupportedDatabaseProviders Type { get; set; } + public string ConnectionString { get; set; } = default!; + } +} From 4c4ed74b84738177398f33547f04060cc21d32c4 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Mon, 4 May 2026 07:57:19 -0500 Subject: [PATCH 9/9] reverting previous change --- .../v2/UsersControllerConcurrencyTests.cs | 3 +- .../EnvSqlServerTestDatabase.cs | 103 ------------------ 2 files changed, 2 insertions(+), 104 deletions(-) delete mode 100644 bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs index 380f1e5cf93a..c4fdd139c656 100644 --- a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/UsersControllerConcurrencyTests.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Services; using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.IntegrationTestCommon; using Bit.Scim.IntegrationTest.Factories; using Bit.Scim.Models; using Bit.Scim.Utilities; @@ -26,7 +27,7 @@ public async Task Post_ConcurrentInvites_DoNotOvershootMaxAutoscaleSeats() var factory = new ScimApplicationFactory { - TestDatabase = new EnvSqlServerTestDatabase() + TestDatabase = new SqlServerTestDatabase() }; factory.SubstituteService((IFeatureService f) => f.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization) diff --git a/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs b/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs deleted file mode 100644 index 724ccdb36615..000000000000 --- a/bitwarden_license/test/Scim.IntegrationTest/EnvSqlServerTestDatabase.cs +++ /dev/null @@ -1,103 +0,0 @@ -using Bit.Core.Enums; -using Bit.Core.Settings; -using Bit.Infrastructure.EntityFramework.Repositories; -using Bit.IntegrationTestCommon; -using Bit.Migrator; -using Microsoft.Data.SqlClient; -using Microsoft.EntityFrameworkCore; - -namespace Bit.Scim.IntegrationTest; - -/// -/// SQL Server test database that resolves its connection string from the same -/// BW_TEST_DATABASES__n__CONNECTIONSTRING env vars used by DatabaseDataAttribute, -/// with a fallback to Identity user secrets for local dev. -/// -public class EnvSqlServerTestDatabase : ITestDatabase -{ - private readonly string _connectionString; - - public EnvSqlServerTestDatabase() - { - var config = new ConfigurationBuilder() - .AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true) - .AddEnvironmentVariables("BW_TEST_") - .Build(); - - var resolved = - config.Get()?.Databases? - .FirstOrDefault(d => d.Type == SupportedDatabaseProviders.SqlServer)?.ConnectionString - ?? config.GetSection("globalSettings:sqlServer:connectionString").Value - ?? throw new InvalidOperationException( - "No SQL Server connection string found. Set BW_TEST_DATABASES__n__TYPE=SqlServer " + - "and BW_TEST_DATABASES__n__CONNECTIONSTRING, or configure Identity user secrets locally."); - - _connectionString = new SqlConnectionStringBuilder(resolved) - { - InitialCatalog = "vault_test" - }.ConnectionString; - } - - public void ModifyGlobalSettings(Dictionary config) - { - config["globalSettings:databaseProvider"] = "sqlserver"; - config["globalSettings:sqlServer:connectionString"] = _connectionString; - } - - public void AddDatabase(IServiceCollection serviceCollection) - { - serviceCollection.AddScoped(s => new DbContextOptionsBuilder() - .UseSqlServer(_connectionString) - .UseApplicationServiceProvider(s) - .Options); - } - - public void Migrate(IServiceCollection serviceCollection) - { - var serviceProvider = serviceCollection.BuildServiceProvider(); - using var scope = serviceProvider.CreateScope(); - var services = scope.ServiceProvider; - var globalSettings = services.GetRequiredService(); - var logger = services.GetRequiredService>(); - - var migrator = new SqlServerDbMigrator(globalSettings, logger); - migrator.MigrateDatabase(); - } - - public void Dispose() - { - var masterConnectionString = new SqlConnectionStringBuilder(_connectionString) - { - InitialCatalog = "master" - }.ConnectionString; - - using var connection = new SqlConnection(masterConnectionString); - var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; - - connection.Open(); - - var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); - - using (var cmd = new SqlCommand( - $"ALTER DATABASE {databaseNameQuoted} SET single_user WITH rollback IMMEDIATE", connection)) - { - cmd.ExecuteNonQuery(); - } - - using (var cmd = new SqlCommand($"DROP DATABASE {databaseNameQuoted}", connection)) - { - cmd.ExecuteNonQuery(); - } - } - - private class TypedConfig - { - public DatabaseEntry[]? Databases { get; set; } - } - - private class DatabaseEntry - { - public SupportedDatabaseProviders Type { get; set; } - public string ConnectionString { get; set; } = default!; - } -}