diff --git a/src/MediatR/Mediator.cs b/src/MediatR/Mediator.cs index d3948c97..c4978e3c 100644 --- a/src/MediatR/Mediator.cs +++ b/src/MediatR/Mediator.cs @@ -48,19 +48,22 @@ public Task Send(IRequest request, Cancellation throw new ArgumentNullException(nameof(request)); } var requestType = request.GetType(); - var requestInterfaceType = requestType - .GetInterfaces() - .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>)); - var isValidRequest = requestInterfaceType != null; - - if (!isValidRequest) - { - throw new ArgumentException($"{nameof(request)} does not implement ${nameof(IRequest)}"); - } - - var responseType = requestInterfaceType!.GetGenericArguments()[0]; var handler = _requestHandlers.GetOrAdd(requestType, - t => (RequestHandlerBase)Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestType, responseType))); + requestTypeKey => + { + var requestInterfaceType = requestTypeKey + .GetInterfaces() + .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>)); + var isValidRequest = requestInterfaceType != null; + + if (!isValidRequest) + { + throw new ArgumentException($"{requestType.Name} does not implement {nameof(IRequest)}", nameof(request)); + } + + var responseType = requestInterfaceType!.GetGenericArguments()[0]; + return (RequestHandlerBase)Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType)); + }); // call via dynamic dispatch to avoid calling through reflection for performance reasons return handler.Handle(request, cancellationToken, _serviceFactory); diff --git a/test/MediatR.Tests/ExceptionTests.cs b/test/MediatR.Tests/ExceptionTests.cs index 16a1f26d..fbe39570 100644 --- a/test/MediatR.Tests/ExceptionTests.cs +++ b/test/MediatR.Tests/ExceptionTests.cs @@ -244,9 +244,9 @@ public async Task Should_throw_argument_exception_for_publish_when_request_is_no public class PingException : IRequest { - + } - + public class PingExceptionHandler : IRequestHandler { public Task Handle(PingException request, CancellationToken cancellationToken) @@ -254,7 +254,7 @@ public Task Handle(PingException request, CancellationToken cancellationTo throw new NotImplementedException(); } } - + [Fact] public async Task Should_throw_exception_for_non_generic_send_when_exception_occurs() { @@ -271,12 +271,40 @@ public async Task Should_throw_exception_for_non_generic_send_when_exception_occ cfg.For().Use(); }); var mediator = container.GetInstance(); - + object pingException = new PingException(); - + await Should.ThrowAsync(async () => await mediator.Send(pingException)); } - + + [Fact] + public async Task Should_throw_exception_for_non_request_send() + { + var container = new Container(cfg => + { + cfg.Scan(scanner => + { + scanner.AssemblyContainingType(typeof(NullPinged)); + scanner.IncludeNamespaceContainingType(); + scanner.WithDefaultConventions(); + scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); + }); + cfg.For().Use(ctx => t => ctx.GetInstance(t)); + cfg.For().Use(); + }); + var mediator = container.GetInstance(); + + object nonRequest = new NonRequest(); + + var argumentException = await Should.ThrowAsync(async () => await mediator.Send(nonRequest)); + Assert.StartsWith("NonRequest does not implement IRequest", argumentException.Message); + } + + public class NonRequest + { + + } + [Fact] public async Task Should_throw_exception_for_generic_send_when_exception_occurs() { @@ -293,9 +321,9 @@ public async Task Should_throw_exception_for_generic_send_when_exception_occurs( cfg.For().Use(); }); var mediator = container.GetInstance(); - + PingException pingException = new PingException(); - + await Should.ThrowAsync(async () => await mediator.Send(pingException)); } }