Skip to content

Commit

Permalink
Merge pull request #838 from jbogard/publish-strategies
Browse files Browse the repository at this point in the history
Adding notification publisher strategies
  • Loading branch information
jbogard committed Feb 14, 2023
2 parents 2bf8b6b + 9da86cd commit ba9d3ee
Show file tree
Hide file tree
Showing 14 changed files with 437 additions and 34 deletions.
8 changes: 4 additions & 4 deletions samples/MediatR.Examples.PublishStrategies/CustomMediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ namespace MediatR.Examples.PublishStrategies;

public class CustomMediator : Mediator
{
private readonly Func<IEnumerable<Func<INotification, CancellationToken, Task>>, INotification, CancellationToken, Task> _publish;
private readonly Func<IEnumerable<NotificationHandlerExecutor>, INotification, CancellationToken, Task> _publish;

public CustomMediator(IServiceProvider serviceFactory, Func<IEnumerable<Func<INotification, CancellationToken, Task>>, INotification, CancellationToken, Task> publish) : base(serviceFactory)
public CustomMediator(IServiceProvider serviceFactory, Func<IEnumerable<NotificationHandlerExecutor>, INotification, CancellationToken, Task> publish) : base(serviceFactory)
=> _publish = publish;

protected override Task PublishCore(IEnumerable<Func<INotification, CancellationToken, Task>> allHandlers, INotification notification, CancellationToken cancellationToken)
=> _publish(allHandlers, notification, cancellationToken);
protected override Task PublishCore(IEnumerable<NotificationHandlerExecutor> handlerExecutors, INotification notification, CancellationToken cancellationToken)
=> _publish(handlerExecutors, notification, cancellationToken);
}
24 changes: 12 additions & 12 deletions samples/MediatR.Examples.PublishStrategies/Publisher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,41 @@ public Task Publish<TNotification>(TNotification notification, PublishStrategy s
return mediator.Publish(notification, cancellationToken);
}

private Task ParallelWhenAll(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private Task ParallelWhenAll(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
var tasks = new List<Task>();

foreach (var handler in handlers)
{
tasks.Add(Task.Run(() => handler(notification, cancellationToken)));
tasks.Add(Task.Run(() => handler.HandlerCallback(notification, cancellationToken)));
}

return Task.WhenAll(tasks);
}

private Task ParallelWhenAny(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private Task ParallelWhenAny(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
var tasks = new List<Task>();

foreach (var handler in handlers)
{
tasks.Add(Task.Run(() => handler(notification, cancellationToken)));
tasks.Add(Task.Run(() => handler.HandlerCallback(notification, cancellationToken)));
}

return Task.WhenAny(tasks);
}

private Task ParallelNoWait(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private Task ParallelNoWait(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
foreach (var handler in handlers)
{
Task.Run(() => handler(notification, cancellationToken));
Task.Run(() => handler.HandlerCallback(notification, cancellationToken));
}

return Task.CompletedTask;
}

private async Task AsyncContinueOnException(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private async Task AsyncContinueOnException(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
var tasks = new List<Task>();
var exceptions = new List<Exception>();
Expand All @@ -93,7 +93,7 @@ private async Task AsyncContinueOnException(IEnumerable<Func<INotification, Canc
{
try
{
tasks.Add(handler(notification, cancellationToken));
tasks.Add(handler.HandlerCallback(notification, cancellationToken));
}
catch (Exception ex) when (!(ex is OutOfMemoryException || ex is StackOverflowException))
{
Expand All @@ -120,23 +120,23 @@ private async Task AsyncContinueOnException(IEnumerable<Func<INotification, Canc
}
}

private async Task SyncStopOnException(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private async Task SyncStopOnException(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
foreach (var handler in handlers)
{
await handler(notification, cancellationToken).ConfigureAwait(false);
await handler.HandlerCallback(notification, cancellationToken).ConfigureAwait(false);
}
}

private async Task SyncContinueOnException(IEnumerable<Func<INotification, CancellationToken, Task>> handlers, INotification notification, CancellationToken cancellationToken)
private async Task SyncContinueOnException(IEnumerable<NotificationHandlerExecutor> handlers, INotification notification, CancellationToken cancellationToken)
{
var exceptions = new List<Exception>();

foreach (var handler in handlers)
{
try
{
await handler(notification, cancellationToken).ConfigureAwait(false);
await handler.HandlerCallback(notification, cancellationToken).ConfigureAwait(false);
}
catch (AggregateException ex)
{
Expand Down
11 changes: 11 additions & 0 deletions src/MediatR/INotificationPublisher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using System.Threading;

namespace MediatR;

public interface INotificationPublisher
{
Task Publish(IEnumerable<NotificationHandlerExecutor> handlerExecutors, INotification notification,
CancellationToken cancellationToken);
}
6 changes: 5 additions & 1 deletion src/MediatR/MediatR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="IsExternalInit" Version="1.0.3">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="MediatR.Contracts" Version="[2.0.0, 3.0.0)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.2.0" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.3.0" PrivateAssets="All" />
</ItemGroup>

</Project>
29 changes: 19 additions & 10 deletions src/MediatR/Mediator.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using MediatR.NotificationPublishers;

namespace MediatR;

using System;
Expand All @@ -14,6 +16,7 @@ namespace MediatR;
public class Mediator : IMediator
{
private readonly IServiceProvider _serviceProvider;
private readonly INotificationPublisher _publisher;
private static readonly ConcurrentDictionary<Type, RequestHandlerBase> _requestHandlers = new();
private static readonly ConcurrentDictionary<Type, NotificationHandlerWrapper> _notificationHandlers = new();
private static readonly ConcurrentDictionary<Type, StreamRequestHandlerBase> _streamRequestHandlers = new();
Expand All @@ -23,7 +26,18 @@ public class Mediator : IMediator
/// </summary>
/// <param name="serviceProvider">Service provider. Can be a scoped or root provider</param>
public Mediator(IServiceProvider serviceProvider)
=> _serviceProvider = serviceProvider;
: this(serviceProvider, new ForeachAwaitPublisher()) { }

/// <summary>
/// Initializes a new instance of the <see cref="Mediator"/> class.
/// </summary>
/// <param name="serviceProvider">Service provider. Can be a scoped or root provider</param>
/// <param name="publisher">Notification publisher. Defaults to <see cref="ForeachAwaitPublisher"/>.</param>
public Mediator(IServiceProvider serviceProvider, INotificationPublisher publisher)
{
_serviceProvider = serviceProvider;
_publisher = publisher;
}

public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -124,19 +138,14 @@ notification switch
};

/// <summary>
/// Override in a derived class to control how the tasks are awaited. By default the implementation is a foreach and await of each handler
/// Override in a derived class to control how the tasks are awaited. By default the implementation calls the <see cref="INotificationPublisher"/>.
/// </summary>
/// <param name="allHandlers">Enumerable of tasks representing invoking each notification handler</param>
/// <param name="handlerExecutors">Enumerable of tasks representing invoking each notification handler</param>
/// <param name="notification">The notification being published</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <returns>A task representing invoking all handlers</returns>
protected virtual async Task PublishCore(IEnumerable<Func<INotification, CancellationToken, Task>> allHandlers, INotification notification, CancellationToken cancellationToken)
{
foreach (var handler in allHandlers)
{
await handler(notification, cancellationToken).ConfigureAwait(false);
}
}
protected virtual Task PublishCore(IEnumerable<NotificationHandlerExecutor> handlerExecutors, INotification notification, CancellationToken cancellationToken)
=> _publisher.Publish(handlerExecutors, notification, cancellationToken);

private Task PublishNotification(INotification notification, CancellationToken cancellationToken = default)
{
Expand Down
69 changes: 69 additions & 0 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,83 @@
using System.Collections.Generic;
using System.Reflection;
using MediatR;
using MediatR.NotificationPublishers;

namespace Microsoft.Extensions.DependencyInjection;

public class MediatRServiceConfiguration
{
/// <summary>
/// Optional filter for types to register. Default value is a function returning true.
/// </summary>
public Func<Type, bool> TypeEvaluator { get; set; } = t => true;

/// <summary>
/// Mediator implementation type to register. Default is <see cref="Mediator"/>
/// </summary>
public Type MediatorImplementationType { get; set; } = typeof(Mediator);

/// <summary>
/// Strategy for publishing notifications. Defaults to <see cref="ForeachAwaitPublisher"/>
/// </summary>
public INotificationPublisher NotificationPublisher { get; set; } = new ForeachAwaitPublisher();

/// <summary>
/// Type of notification publisher strategy to register. If set, overrides <see cref="NotificationPublisher"/>
/// </summary>
public Type? NotificationPublisherType { get; set; }

/// <summary>
/// Service lifetime to register services under. Default value is <see cref="ServiceLifetime.Transient"/>
/// </summary>
public ServiceLifetime Lifetime { get; set; } = ServiceLifetime.Transient;

/// <summary>
/// Request exception action processor strategy. Default value is <see cref="DependencyInjection.RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions"/>
/// </summary>
public RequestExceptionActionProcessorStrategy RequestExceptionActionProcessorStrategy { get; set; }
= RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions;

internal List<Assembly> AssembliesToRegister { get; } = new();

/// <summary>
/// List of behaviors to register in specific order
/// </summary>
public List<ServiceDescriptor> BehaviorsToRegister { get; } = new();

/// <summary>
/// Register various handlers from assembly containing given type
/// </summary>
/// <typeparam name="T">Type from assembly to scan</typeparam>
/// <returns>This</returns>
public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining<T>()
=> RegisterServicesFromAssemblyContaining(typeof(T));

/// <summary>
/// Register various handlers from assembly containing given type
/// </summary>
/// <param name="type">Type from assembly to scan</param>
/// <returns>This</returns>
public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining(Type type)
=> RegisterServicesFromAssembly(type.Assembly);

/// <summary>
/// Register various handlers from assembly
/// </summary>
/// <param name="assembly">Assembly to scan</param>
/// <returns>This</returns>
public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembly)
{
AssembliesToRegister.Add(assembly);

return this;
}

/// <summary>
/// Register various handlers from assemblies
/// </summary>
/// <param name="assemblies">Assemblies to scan</param>
/// <returns>This</returns>
public MediatRServiceConfiguration RegisterServicesFromAssemblies(
params Assembly[] assemblies)
{
Expand All @@ -38,10 +87,24 @@ public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembl
return this;
}

/// <summary>
/// Register a closed behavior type
/// </summary>
/// <typeparam name="TServiceType">Closed behavior interface type</typeparam>
/// <typeparam name="TImplementationType">Closed behavior implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior<TServiceType, TImplementationType>(
ServiceLifetime serviceLifetime = ServiceLifetime.Transient) =>
AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed behavior type
/// </summary>
/// <param name="serviceType">Closed behavior interface type</param>
/// <param name="implementationType">Closed behavior implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior(
Type serviceType,
Type implementationType,
Expand All @@ -52,6 +115,12 @@ public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembl
return this;
}

/// <summary>
/// Registers an open behavior type against the <see cref="IPipelineBehavior{TRequest,TResponse}"/> open generic interface type
/// </summary>
/// <param name="openBehaviorType">An open generic behavior type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var serviceType = typeof(IPipelineBehavior<,>);
Expand Down
7 changes: 7 additions & 0 deletions src/MediatR/NotificationHandlerExecutor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace MediatR;

public record NotificationHandlerExecutor(object HandlerInstance, Func<INotification, CancellationToken, Task> HandlerCallback);
24 changes: 24 additions & 0 deletions src/MediatR/NotificationPublishers/ForeachAwaitPublisher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace MediatR.NotificationPublishers;

/// <summary>
/// Awaits each notification handler in a single foreach loop:
/// <code>
/// foreach (var handler in handlers) {
/// await handler(notification, cancellationToken);
/// }
/// </code>
/// </summary>
public class ForeachAwaitPublisher : INotificationPublisher
{
public async Task Publish(IEnumerable<NotificationHandlerExecutor> handlerExecutors, INotification notification, CancellationToken cancellationToken)
{
foreach (var handler in handlerExecutors)
{
await handler.HandlerCallback(notification, cancellationToken).ConfigureAwait(false);
}
}
}
28 changes: 28 additions & 0 deletions src/MediatR/NotificationPublishers/TaskWhenAllPublisher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace MediatR.NotificationPublishers;

/// <summary>
/// Uses Task.WhenAll with the list of Handler tasks:
/// <code>
/// var tasks = handlers
/// .Select(handler => handler.Handle(notification, cancellationToken))
/// .ToList();
///
/// return Task.WhenAll(tasks);
/// </code>
/// </summary>
public class TaskWhenAllPublisher : INotificationPublisher
{
public Task Publish(IEnumerable<NotificationHandlerExecutor> handlerExecutors, INotification notification, CancellationToken cancellationToken)
{
var tasks = handlerExecutors
.Select(handler => handler.HandlerCallback(notification, cancellationToken))
.ToArray();

return Task.WhenAll(tasks);
}
}
5 changes: 5 additions & 0 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAdd(new ServiceDescriptor(typeof(ISender), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));
services.TryAdd(new ServiceDescriptor(typeof(IPublisher), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));

services.TryAdd(serviceConfiguration.NotificationPublisherType != null
? new ServiceDescriptor(typeof(INotificationPublisher), serviceConfiguration.NotificationPublisherType,
serviceConfiguration.Lifetime)
: new ServiceDescriptor(typeof(INotificationPublisher), serviceConfiguration.NotificationPublisher));

foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.Add(serviceDescriptor);
Expand Down
Loading

0 comments on commit ba9d3ee

Please sign in to comment.