Skip to content

Commit

Permalink
Allowing explicit processor registration
Browse files Browse the repository at this point in the history
  • Loading branch information
jbogard committed Jul 7, 2023
1 parent 4452ce8 commit 9ebdf7b
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 51 deletions.
118 changes: 118 additions & 0 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Reflection;
using MediatR;
using MediatR.NotificationPublishers;
using MediatR.Pipeline;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -52,6 +53,16 @@ public class MediatRServiceConfiguration
/// </summary>
public List<ServiceDescriptor> StreamBehaviorsToRegister { get; } = new();

/// <summary>
/// List of request pre processors to register in specific order
/// </summary>
public List<ServiceDescriptor> RequestPreProcessorsToRegister { get; } = new();

/// <summary>
/// List of request post processors to register in specific order
/// </summary>
public List<ServiceDescriptor> RequestPostProcessorsToRegister { get; } = new();

/// <summary>
/// Register various handlers from assembly containing given type
/// </summary>
Expand Down Expand Up @@ -200,4 +211,111 @@ public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType,
}


/// <summary>
/// Register a closed request pre processor type
/// </summary>
/// <typeparam name="TServiceType">Closed request pre processor interface type</typeparam>
/// <typeparam name="TImplementationType">Closed request pre processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request pre processor type
/// </summary>
/// <param name="serviceType">Closed request pre processor interface type</param>
/// <param name="implementationType">Closed request pre processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

/// <summary>
/// Registers an open request pre processor type against the <see cref="IRequestPreProcessor{TRequest}"/> open generic interface type
/// </summary>
/// <param name="openBehaviorType">An open generic request pre processor type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
if (!openBehaviorType.IsGenericType)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must be generic");
}

var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedOpenBehaviorInterfaces = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));

if (implementedOpenBehaviorInterfaces.Count == 0)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime));
}

return this;
}

/// <summary>
/// Register a closed request post processor type
/// </summary>
/// <typeparam name="TServiceType">Closed request post processor interface type</typeparam>
/// <typeparam name="TImplementationType">Closed request post processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request post processor type
/// </summary>
/// <param name="serviceType">Closed request post processor interface type</param>
/// <param name="implementationType">Closed request post processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

/// <summary>
/// Registers an open request post processor type against the <see cref="IRequestPostProcessor{TRequest,TResponse}"/> open generic interface type
/// </summary>
/// <param name="openBehaviorType">An open generic request post processor type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddOpenRequestPostProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
if (!openBehaviorType.IsGenericType)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must be generic");
}

var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedOpenBehaviorInterfaces = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));

if (implementedOpenBehaviorInterfaces.Count == 0)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime));
}

return this;
}


}
22 changes: 13 additions & 9 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@ public static void AddMediatRClasses(IServiceCollection services, MediatRService
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPreProcessor<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPostProcessor<,>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionAction<,>), services, assembliesToScan, true, configuration);

var multiOpenInterfaces = new[]
{
typeof(INotificationHandler<>),
typeof(IRequestPreProcessor<>),
typeof(IRequestPostProcessor<,>),
typeof(IRequestExceptionHandler<,,>),
typeof(IRequestExceptionAction<,>)
};
Expand Down Expand Up @@ -224,6 +220,19 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi

services.TryAdd(notificationPublisherServiceDescriptor);

// Register pre processors, then post processors, then behaviors
if (serviceConfiguration.RequestPreProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient));
services.TryAddEnumerable(serviceConfiguration.RequestPreProcessorsToRegister);
}

if (serviceConfiguration.RequestPostProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPostProcessorBehavior<,>), ServiceLifetime.Transient));
services.TryAddEnumerable(serviceConfiguration.RequestPostProcessorsToRegister);
}

foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.TryAddEnumerable(serviceDescriptor);
Expand All @@ -234,11 +243,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAddEnumerable(serviceDescriptor);
}

// Use built-in Microsoft TryAddEnumerable method, we do want to register our Pre/Post processor behavior,
// even if (a more concrete) registration for IPipelineBehavior<,> already exists. But only once.
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPreProcessorBehavior<,>), typeof(IRequestPreProcessor<>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPostProcessorBehavior<,>), typeof(IRequestPostProcessor<,>));

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli

services.AddSingleton(output);
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(ConstructorTestBehavior<,>));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddOpenBehavior(typeof(ConstructorTestBehavior<,>));
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand All @@ -93,11 +97,7 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli
output.Messages.ShouldBe(new[]
{
"ConstructorTestBehavior before",
"First pre processor",
"Next pre processor",
"Handler",
"First post processor",
"Next post processor",
"ConstructorTestBehavior after"
});
ConstructorTestHandler.ConstructorCallCount.ShouldBe(1);
Expand Down
83 changes: 46 additions & 37 deletions test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,12 @@ public async Task Should_wrap_with_behavior()
var output = new Logger();
IServiceCollection services = new ServiceCollection();
services.AddSingleton(output);
services.AddTransient<IPipelineBehavior<Ping, Pong>, OuterBehavior>();
services.AddTransient<IPipelineBehavior<Ping, Pong>, InnerBehavior>();
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddBehavior<IPipelineBehavior<Ping, Pong>, OuterBehavior>();
cfg.AddBehavior<IPipelineBehavior<Ping, Pong>, InnerBehavior>();
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand All @@ -366,15 +369,7 @@ public async Task Should_wrap_with_behavior()
{
"Outer before",
"Inner before",
"First concrete pre processor",
"Next concrete pre processor",
"First pre processor",
"Next pre processor",
"Handler",
"First concrete post processor",
"Next concrete post processor",
"First post processor",
"Next post processor",
"Inner after",
"Outer after"
});
Expand Down Expand Up @@ -408,27 +403,30 @@ public async Task Should_wrap_generics_with_behavior()
{
"Outer generic before",
"Inner generic before",
"First concrete pre processor",
"Next concrete pre processor",
"First pre processor",
"Next pre processor",
"Handler",
"First concrete post processor",
"Next concrete post processor",
"First post processor",
"Next post processor",
"Inner generic after",
"Outer generic after",
});
}

[Fact]
public async Task Should_pick_up_pre_and_post_processors()
public async Task Should_register_pre_and_post_processors()
{
var output = new Logger();
IServiceCollection services = new ServiceCollection();
services.AddSingleton(output);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddRequestPreProcessor<IRequestPreProcessor<Ping>, FirstConcretePreProcessor>();
cfg.AddRequestPreProcessor<IRequestPreProcessor<Ping>, NextConcretePreProcessor>();
cfg.AddOpenRequestPreProcessor(typeof(FirstPreProcessor<>));
cfg.AddOpenRequestPreProcessor(typeof(NextPreProcessor<>));
cfg.AddRequestPostProcessor<IRequestPostProcessor<Ping, Pong>, FirstConcretePostProcessor>();
cfg.AddRequestPostProcessor<IRequestPostProcessor<Ping, Pong>, NextConcretePostProcessor>();
cfg.AddOpenRequestPostProcessor(typeof(FirstPostProcessor<,>));
cfg.AddOpenRequestPostProcessor(typeof(NextPostProcessor<,>));
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand Down Expand Up @@ -508,10 +506,21 @@ public async Task Should_handle_constrained_generics()
var output = new Logger();
IServiceCollection services = new ServiceCollection();
services.AddSingleton(output);
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(OuterBehavior<,>));
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(InnerBehavior<,>));
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(ConstrainedBehavior<,>));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddOpenBehavior(typeof(OuterBehavior<,>));
cfg.AddOpenBehavior(typeof(InnerBehavior<,>));
cfg.AddOpenBehavior(typeof(ConstrainedBehavior<,>));
cfg.AddRequestPreProcessor<IRequestPreProcessor<Ping>, FirstConcretePreProcessor>();
cfg.AddRequestPreProcessor<IRequestPreProcessor<Ping>, NextConcretePreProcessor>();
cfg.AddOpenRequestPreProcessor(typeof(FirstPreProcessor<>));
cfg.AddOpenRequestPreProcessor(typeof(NextPreProcessor<>));
cfg.AddRequestPostProcessor<IRequestPostProcessor<Ping, Pong>, FirstConcretePostProcessor>();
cfg.AddRequestPreProcessor<IRequestPostProcessor<Ping, Pong>, NextConcretePostProcessor>();
cfg.AddOpenRequestPostProcessor(typeof(FirstPostProcessor<,>));
cfg.AddOpenRequestPostProcessor(typeof(NextPostProcessor<,>));
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand All @@ -522,21 +531,21 @@ public async Task Should_handle_constrained_generics()

output.Messages.ShouldBe(new[]
{
"Outer generic before",
"Inner generic before",
"Constrained before",
"First concrete pre processor",
"Next concrete pre processor",
"First pre processor",
"Next pre processor",
"Outer generic before",
"Inner generic before",
"Constrained before",
"Handler",
"Constrained after",
"Inner generic after",
"Outer generic after",
"First concrete post processor",
"Next concrete post processor",
"First post processor",
"Next post processor",
"Constrained after",
"Inner generic after",
"Outer generic after"
"Next post processor"
});

output.Messages.Clear();
Expand All @@ -547,15 +556,15 @@ public async Task Should_handle_constrained_generics()

output.Messages.ShouldBe(new[]
{
"Outer generic before",
"Inner generic before",
"First pre processor",
"Next pre processor",
"Outer generic before",
"Inner generic before",
"Handler",
"First post processor",
"Next post processor",
"Inner generic after",
"Outer generic after"
"Outer generic after",
"First post processor",
"Next post processor"
});
}

Expand Down

0 comments on commit 9ebdf7b

Please sign in to comment.