diff --git a/.ci/build-steps.yml b/.ci/build-steps.yml
index c95f7acf3..e81044196 100644
--- a/.ci/build-steps.yml
+++ b/.ci/build-steps.yml
@@ -47,6 +47,18 @@ steps:
artifactName: 'Conformance.Tests-8.0-$(Agent.OS)'
targetPath: 'artifacts/publish/Conformance.Tests/release_net8.0'
+- task: DotNetCoreCLI@2
+ displayName: 'Publish MySqlConnector.DependencyInjection.Tests'
+ inputs:
+ command: 'publish'
+ arguments: '-c Release -f net8.0 --no-build tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj'
+ publishWebProjects: false
+ zipAfterPublish: false
+- task: PublishPipelineArtifact@0
+ inputs:
+ artifactName: 'MySqlConnector.DependencyInjection.Tests-8.0-$(Agent.OS)'
+ targetPath: 'artifacts/publish/MySqlConnector.DependencyInjection.Tests/release_net8.0'
+
- task: DotNetCoreCLI@2
displayName: 'Publish IntegrationTests (7.0)'
inputs:
diff --git a/.ci/mysqlconnector-tests-steps.yml b/.ci/mysqlconnector-tests-steps.yml
index ca4b1ce8f..012b11a0b 100644
--- a/.ci/mysqlconnector-tests-steps.yml
+++ b/.ci/mysqlconnector-tests-steps.yml
@@ -14,6 +14,16 @@ steps:
command: 'custom'
custom: 'vstest'
arguments: 'MySqlConnector.Tests.dll /logger:trx'
+- task: DownloadPipelineArtifact@0
+ inputs:
+ artifactName: 'MySqlConnector.DependencyInjection.Tests-8.0-$(Agent.OS)'
+ targetPath: $(System.DefaultWorkingDirectory)
+- task: DotNetCoreCLI@2
+ displayName: 'Run MySqlConnector.DependencyInjection.Tests'
+ inputs:
+ command: 'custom'
+ custom: 'vstest'
+ arguments: 'MySqlConnector.DependencyInjection.Tests.dll /logger:trx'
- task: PublishTestResults@2
inputs:
testResultsFormat: VSTest
diff --git a/.ci/test.ps1 b/.ci/test.ps1
index 7ebe26842..bf3b973c9 100644
--- a/.ci/test.ps1
+++ b/.ci/test.ps1
@@ -23,6 +23,12 @@ if ($LASTEXITCODE -ne 0){
exit $LASTEXITCODE;
}
popd
+pushd tests\MySqlConnector.DependencyIntegration.Tests
+dotnet test -c Release
+if ($LASTEXITCODE -ne 0){
+ exit $LASTEXITCODE;
+}
+popd
pushd .\tests\IntegrationTests
diff --git a/Directory.Packages.props b/Directory.Packages.props
index b319ba3da..d8225df5a 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -12,7 +12,8 @@
-
+
+
diff --git a/MySqlConnector.sln b/MySqlConnector.sln
index 0d56e5fa4..c3646c415 100644
--- a/MySqlConnector.sln
+++ b/MySqlConnector.sln
@@ -26,6 +26,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SchemaCollectionGenerator",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MySqlConnector.DependencyInjection", "src\MySqlConnector.DependencyInjection\MySqlConnector.DependencyInjection.csproj", "{D48B3619-7FE1-420C-A96C-B231B7EA73EA}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MySqlConnector.DependencyInjection.Tests", "tests\MySqlConnector.DependencyInjection.Tests\MySqlConnector.DependencyInjection.Tests.csproj", "{E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -76,6 +78,10 @@ Global
{D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs b/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs
index 782ac6183..f00ff40db 100644
--- a/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs
+++ b/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs
@@ -42,6 +42,45 @@ public static class MySqlConnectorServiceCollectionExtensions
ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) =>
DoAddMySqlDataSource(serviceCollection, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime);
+ ///
+ /// Registers a and a in the .
+ ///
+ /// The to add services to.
+ /// The of the service.
+ /// A MySQL connection string.
+ /// The lifetime with which to register the in the container. Defaults to .
+ /// The lifetime with which to register the service in the container. Defaults to .
+ /// The same service collection so that multiple calls can be chained.
+ /// If the is a , it will automatically be used to initialize the data source name.
+ public static IServiceCollection AddKeyedMySqlDataSource(
+ this IServiceCollection serviceCollection,
+ object? serviceKey,
+ string connectionString,
+ ServiceLifetime connectionLifetime = ServiceLifetime.Transient,
+ ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) =>
+ DoAddMySqlDataSource(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime);
+
+ ///
+ /// Registers a and a in the .
+ ///
+ /// The to add services to.
+ /// The of the service.
+ /// A MySQL connection string.
+ /// An action to configure the for further customizations of the .
+ /// The lifetime with which to register the in the container. Defaults to .
+ /// The lifetime with which to register the service in the container. Defaults to .
+ /// The same service collection so that multiple calls can be chained.
+ /// If the is a , it will automatically be used to initialize the data source name; this can
+ /// be overridden by the configuration action.
+ public static IServiceCollection AddKeyedMySqlDataSource(
+ this IServiceCollection serviceCollection,
+ object? serviceKey,
+ string connectionString,
+ Action dataSourceBuilderAction,
+ ServiceLifetime connectionLifetime = ServiceLifetime.Transient,
+ ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) =>
+ DoAddMySqlDataSource(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime);
+
private static IServiceCollection DoAddMySqlDataSource(
this IServiceCollection serviceCollection,
string connectionString,
@@ -52,10 +91,10 @@ public static class MySqlConnectorServiceCollectionExtensions
serviceCollection.TryAdd(
new ServiceDescriptor(
typeof(MySqlDataSource),
- x =>
+ serviceProvider =>
{
var dataSourceBuilder = new MySqlDataSourceBuilder(connectionString)
- .UseLoggerFactory(x.GetService());
+ .UseLoggerFactory(serviceProvider.GetService());
dataSourceBuilderAction?.Invoke(dataSourceBuilder);
return dataSourceBuilder.Build();
},
@@ -71,4 +110,37 @@ public static class MySqlConnectorServiceCollectionExtensions
return serviceCollection;
}
+
+ private static IServiceCollection DoAddMySqlDataSource(
+ this IServiceCollection serviceCollection,
+ object? serviceKey,
+ string connectionString,
+ Action? dataSourceBuilderAction,
+ ServiceLifetime connectionLifetime,
+ ServiceLifetime dataSourceLifetime)
+ {
+ serviceCollection.TryAdd(
+ new ServiceDescriptor(
+ typeof(MySqlDataSource),
+ serviceKey,
+ (serviceProvider, serviceKey) =>
+ {
+ var dataSourceBuilder = new MySqlDataSourceBuilder(connectionString)
+ .UseLoggerFactory(serviceProvider.GetService())
+ .UseName(serviceKey as string);
+ dataSourceBuilderAction?.Invoke(dataSourceBuilder);
+ return dataSourceBuilder.Build();
+ },
+ dataSourceLifetime));
+
+ serviceCollection.TryAdd(new ServiceDescriptor(typeof(MySqlConnection), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk).CreateConnection(), connectionLifetime));
+
+#if NET7_0_OR_GREATER
+ serviceCollection.TryAdd(new ServiceDescriptor(typeof(DbDataSource), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk), dataSourceLifetime));
+#endif
+
+ serviceCollection.TryAdd(new ServiceDescriptor(typeof(DbConnection), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk), connectionLifetime));
+
+ return serviceCollection;
+ }
}
diff --git a/src/MySqlConnector.DependencyInjection/docs/README.md b/src/MySqlConnector.DependencyInjection/docs/README.md
index d28d69389..a19a3f96b 100644
--- a/src/MySqlConnector.DependencyInjection/docs/README.md
+++ b/src/MySqlConnector.DependencyInjection/docs/README.md
@@ -48,3 +48,33 @@ builder.Services.AddMySqlDataSource("Server=server;User ID=test;Password=test;Da
x => x.UseRemoteCertificateValidationCallback((sender, certificate, chain, sslPolicyErrors) => { /* custom logic */ })
);
```
+
+## Keyed Services
+
+Use the `AddKeyedMySqlDataSource` method to register a `MySqlDataSource` as a [keyed service](https://learn.microsoft.com/en-us/dotnet/core/whats-new/dotnet-8#keyed-di-services).
+This is useful if you have multiple connection strings or need to connect to multiple databases.
+If the service key is a string, it will automatically be used as the `MySqlDataSource` name;
+to customize this, call the `AddKeyedMySqlDataSource(object?, string, Action)` overload and call `MySqlDataSourceBuilder.UseName`.
+
+```csharp
+builder.Services.AddKeyedMySqlDataSource("users", builder.Configuration.GetConnectionString("Users"));
+builder.Services.AddKeyedMySqlDataSource("products", builder.Configuration.GetConnectionString("Products"));
+
+app.MapGet("/users/{userId}", async (int userId, [FromKeyedServices("users")] MySqlConnection connection) =>
+{
+ await connection.OpenAsync();
+ await using var command = connection.CreateCommand();
+ command.CommandText = "SELECT name FROM users WHERE user_id = @userId LIMIT 1";
+ command.Parameters.AddWithValue("@userId", userId);
+ return $"Hello, {await command.ExecuteScalarAsync()}";
+});
+
+app.MapGet("/products/{productId}", async (int productId, [FromKeyedServices("products")] MySqlConnection connection) =>
+{
+ await connection.OpenAsync();
+ await using var command = connection.CreateCommand();
+ command.CommandText = "SELECT name FROM products WHERE product_id = @productId LIMIT 1";
+ command.Parameters.AddWithValue("@productId", productId);
+ return await command.ExecuteScalarAsync();
+});
+```
diff --git a/src/MySqlConnector/MySqlConnector.csproj b/src/MySqlConnector/MySqlConnector.csproj
index 5ce7293e0..536f437c3 100644
--- a/src/MySqlConnector/MySqlConnector.csproj
+++ b/src/MySqlConnector/MySqlConnector.csproj
@@ -34,6 +34,7 @@
+
diff --git a/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs b/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs
new file mode 100644
index 000000000..6d9bf4775
--- /dev/null
+++ b/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs
@@ -0,0 +1,174 @@
+namespace MySqlConnector.DependencyInjection.Tests;
+
+public class DependencyInjectionTests
+{
+ [Fact]
+ public async Task MySqlDataSourceIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddMySqlDataSource(c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ var dataSource = serviceProvider.GetRequiredService();
+ await using var connection = dataSource.CreateConnection();
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task MySqlConnectionIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddMySqlDataSource(c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var connection = serviceProvider.GetRequiredService();
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task DbConnectionIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddMySqlDataSource(c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var connection = serviceProvider.GetRequiredService();
+ Assert.IsAssignableFrom(connection);
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task DbDataSourceIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddMySqlDataSource(c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var dataSource = serviceProvider.GetRequiredService();
+ Assert.IsAssignableFrom(dataSource);
+ await using var connection = dataSource.CreateConnection();
+ Assert.IsAssignableFrom(connection);
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task MySqlDataSourceCanSetName()
+ {
+ var serviceCollection = new ServiceCollection();
+
+ serviceCollection.AddMySqlDataSource(c_connectionString, builder =>
+ {
+ builder.UseName("MyName");
+ });
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+ var dataSource = serviceProvider.GetRequiredService();
+ Assert.Equal("MyName", dataSource.Name);
+ }
+
+ [Fact]
+ public async Task KeyedMySqlDataSourceIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ var dataSource = serviceProvider.GetRequiredKeyedService(new object());
+ Assert.Null(dataSource.Name);
+ await using var connection = dataSource.CreateConnection();
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task StringKeyedMySqlDataSourceHasNameSet()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ var dataSource = serviceProvider.GetRequiredKeyedService("key");
+ Assert.Equal("key", dataSource.Name);
+ await using var connection = dataSource.CreateConnection();
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task KeyedMySqlDataSourceRetrievedWithStringKeyHasName()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ var dataSource = serviceProvider.GetRequiredKeyedService("key");
+ Assert.Equal("key", dataSource.Name);
+ await using var connection = dataSource.CreateConnection();
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task KeyedMySqlConnectionIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var connection = serviceProvider.GetRequiredKeyedService("key");
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task TwoKeyedMySqlDataConnectionsAreRegistered()
+ {
+ const string c_connectionString2 = c_connectionString + ";Database=test";
+
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString);
+ serviceCollection.AddKeyedMySqlDataSource("key2", c_connectionString2);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var connection1 = serviceProvider.GetRequiredKeyedService("key");
+ Assert.Equal(c_connectionString, connection1.ConnectionString);
+
+ await using var connection2 = serviceProvider.GetRequiredKeyedService("key2");
+ Assert.Equal(c_connectionString2, connection2.ConnectionString);
+ }
+
+ [Fact]
+ public async Task KeyedDbConnectionIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var connection = serviceProvider.GetRequiredKeyedService("key");
+ Assert.IsAssignableFrom(connection);
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ [Fact]
+ public async Task KeyedDbDataSourceIsRegistered()
+ {
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString);
+
+ await using var serviceProvider = serviceCollection.BuildServiceProvider();
+
+ await using var dataSource = serviceProvider.GetRequiredKeyedService("key");
+ Assert.IsAssignableFrom(dataSource);
+ await using var connection = dataSource.CreateConnection();
+ Assert.IsAssignableFrom(connection);
+ Assert.Equal(c_connectionString, connection.ConnectionString);
+ }
+
+ const string c_connectionString = "Server=localhost;User ID=root;Password=pass";
+}
diff --git a/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj b/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj
new file mode 100644
index 000000000..fac726a3a
--- /dev/null
+++ b/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj
@@ -0,0 +1,33 @@
+
+
+
+ net8.0
+ true
+ true
+ ..\..\MySqlConnector.snk
+ true
+ enable
+ enable
+
+
+
+
+
+
+
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+ all
+
+
+
+
+
+
+
+
+
+
+
+
+
+