Skip to content

Commit

Permalink
Detect services based on service provider (#32737)
Browse files Browse the repository at this point in the history
* Detect services based on service provider
- Use IServiceProviderIsService to detect if a parameter is a service.
- As a final fallback, try to detect services from the DI container before falling back to body behavior.
  • Loading branch information
davidfowl committed Jun 13, 2021
1 parent d7d5f41 commit d2ab01b
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 62 deletions.
6 changes: 3 additions & 3 deletions src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.AppendList<T>(th
static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpRequest! request) -> Microsoft.AspNetCore.Http.Headers.RequestHeaders!
static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpResponse! response) -> Microsoft.AspNetCore.Http.Headers.ResponseHeaders!
static Microsoft.AspNetCore.Http.HttpContextServerVariableExtensions.GetServerVariable(this Microsoft.AspNetCore.Http.HttpContext! context, string! variableName) -> string?
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func<Microsoft.AspNetCore.Http.HttpContext!, object!>! targetFactory) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action, System.IServiceProvider? serviceProvider) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.IServiceProvider? serviceProvider) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.IServiceProvider? serviceProvider, System.Func<Microsoft.AspNetCore.Http.HttpContext!, object!>! targetFactory) -> Microsoft.AspNetCore.Http.RequestDelegate!
static Microsoft.AspNetCore.Http.ResponseExtensions.Clear(this Microsoft.AspNetCore.Http.HttpResponse! response) -> void
static Microsoft.AspNetCore.Http.ResponseExtensions.Redirect(this Microsoft.AspNetCore.Http.HttpResponse! response, string! location, bool permanent, bool preserveMethod) -> void
static Microsoft.AspNetCore.Http.SendFileResponseExtensions.SendFileAsync(this Microsoft.AspNetCore.Http.HttpResponse! response, Microsoft.Extensions.FileProviders.IFileInfo! file, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Expand Down
32 changes: 24 additions & 8 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ public static class RequestDelegateFactory
/// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="action"/>.
/// </summary>
/// <param name="action">A request handler with any number of custom parameters that often produces a response with its return value.</param>
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param>
/// <returns>The <see cref="RequestDelegate"/>.</returns>
public static RequestDelegate Create(Delegate action)
public static RequestDelegate Create(Delegate action, IServiceProvider? serviceProvider)
{
if (action is null)
{
Expand All @@ -76,7 +77,7 @@ public static RequestDelegate Create(Delegate action)
null => null,
};

var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, targetExpression);
var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, serviceProvider, targetExpression);

return httpContext =>
{
Expand All @@ -88,15 +89,16 @@ public static RequestDelegate Create(Delegate action)
/// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="methodInfo"/>.
/// </summary>
/// <param name="methodInfo">A static request handler with any number of custom parameters that often produces a response with its return value.</param>
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param>
/// <returns>The <see cref="RequestDelegate"/>.</returns>
public static RequestDelegate Create(MethodInfo methodInfo)
public static RequestDelegate Create(MethodInfo methodInfo, IServiceProvider? serviceProvider)
{
if (methodInfo is null)
{
throw new ArgumentNullException(nameof(methodInfo));
}

var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression: null);
var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, serviceProvider, targetExpression: null);

return httpContext =>
{
Expand All @@ -108,9 +110,10 @@ public static RequestDelegate Create(MethodInfo methodInfo)
/// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="methodInfo"/>.
/// </summary>
/// <param name="methodInfo">A request handler with any number of custom parameters that often produces a response with its return value.</param>
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param>
/// <param name="targetFactory">Creates the <see langword="this"/> for the non-static method.</param>
/// <returns>The <see cref="RequestDelegate"/>.</returns>
public static RequestDelegate Create(MethodInfo methodInfo, Func<HttpContext, object> targetFactory)
public static RequestDelegate Create(MethodInfo methodInfo, IServiceProvider? serviceProvider, Func<HttpContext, object> targetFactory)
{
if (methodInfo is null)
{
Expand All @@ -128,15 +131,15 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func<HttpContext, ob
}

var targetExpression = Expression.Convert(TargetExpr, methodInfo.DeclaringType);
var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression);
var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, serviceProvider, targetExpression);

return httpContext =>
{
return targetableRequestDelegate(targetFactory(httpContext), httpContext);
};
}

private static Func<object?, HttpContext, Task> CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression)
private static Func<object?, HttpContext, Task> CreateTargetableRequestDelegate(MethodInfo methodInfo, IServiceProvider? serviceProvider, Expression? targetExpression)
{
// Non void return type

Expand All @@ -154,7 +157,10 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func<HttpContext, ob
// return default;
// }

var factoryContext = new FactoryContext();
var factoryContext = new FactoryContext()
{
ServiceProvider = serviceProvider
};

var arguments = CreateArguments(methodInfo.GetParameters(), factoryContext);

Expand Down Expand Up @@ -234,6 +240,15 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
}
else
{
if (factoryContext.ServiceProvider?.GetService<IServiceProviderIsService>() is IServiceProviderIsService serviceProviderIsService)
{
// If the parameter resolves as a service then get it from services
if (serviceProviderIsService.IsService(parameter.ParameterType))
{
return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}
}

return BindParameterFromBody(parameter.ParameterType, allowEmpty: false, factoryContext);
}
}
Expand Down Expand Up @@ -788,6 +803,7 @@ private class FactoryContext
{
public Type? JsonRequestBodyType { get; set; }
public bool AllowEmptyRequestBody { get; set; }
public IServiceProvider? ServiceProvider { get; init; }

public bool UsingTempSourceString { get; set; }
public List<(ParameterExpression, Expression)> TryParseParams { get; } = new();
Expand Down
Loading

0 comments on commit d2ab01b

Please sign in to comment.