From 224be8ffe4f3d436b4cd3b3d7b73f47c105446f7 Mon Sep 17 00:00:00 2001 From: Kahbazi Date: Wed, 24 Feb 2021 17:35:46 +0330 Subject: [PATCH 1/3] Support optional input for MapAction --- .../MapActionExpressionTreeBuilder.cs | 108 +++++++---- .../MapActionExpressionTreeBuilderTest.cs | 175 ++++++++++++++++-- 2 files changed, 228 insertions(+), 55 deletions(-) diff --git a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs index 54b931cd9875..66f8f761a925 100644 --- a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs +++ b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs @@ -24,7 +24,8 @@ internal static class MapActionExpressionTreeBuilder private static readonly MethodInfo ChangeTypeMethodInfo = GetMethodInfo>((value, type) => Convert.ChangeType(value, type, CultureInfo.InvariantCulture)); private static readonly MethodInfo ExecuteTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; - private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskResultOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueResultTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -71,28 +72,31 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) // This argument represents the deserialized body returned from IHttpRequestReader // when the method has a FromBody attribute declared - var args = new List(); + var methodParameters = method.GetParameters(); + var args = new List(methodParameters.Length); - foreach (var parameter in method.GetParameters()) + foreach (var parameter in methodParameters) { Expression paramterExpression = Expression.Default(parameter.ParameterType); - if (parameter.GetCustomAttributes().OfType().FirstOrDefault() is { } routeAttribute) + var parameterCustomAttributes = parameter.GetCustomAttributes(); + + if (parameterCustomAttributes.OfType().FirstOrDefault() is { } routeAttribute) { var routeValuesProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.RouteValues)); paramterExpression = BindParamenter(routeValuesProperty, parameter, routeAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType().FirstOrDefault() is { } queryAttribute) + else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } queryAttribute) { var queryProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Query)); paramterExpression = BindParamenter(queryProperty, parameter, queryAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType().FirstOrDefault() is { } headerAttribute) + else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } headerAttribute) { var headersProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Headers)); paramterExpression = BindParamenter(headersProperty, parameter, headerAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType().FirstOrDefault() is { } bodyAttribute) + else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } bodyAttribute) { if (consumeBodyDirectly) { @@ -109,7 +113,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) bodyType = parameter.ParameterType; paramterExpression = Expression.Convert(DeserializedBodyArg, bodyType); } - else if (parameter.GetCustomAttributes().OfType().FirstOrDefault() is { } formAttribute) + else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } formAttribute) { if (consumeBodyDirectly) { @@ -125,27 +129,24 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) { paramterExpression = Expression.Call(GetRequiredServiceMethodInfo.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); } - else + else if (parameter.ParameterType == typeof(IFormCollection)) { - if (parameter.ParameterType == typeof(IFormCollection)) + if (consumeBodyDirectly) { - if (consumeBodyDirectly) - { - ThrowCannotReadBodyDirectlyAndAsForm(); - } + ThrowCannotReadBodyDirectlyAndAsForm(); + } - consumeBodyAsForm = true; + consumeBodyAsForm = true; - paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form)); - } - else if (parameter.ParameterType == typeof(HttpContext)) - { - paramterExpression = HttpContextParameter; - } - else if (parameter.ParameterType == typeof(CancellationToken)) - { - paramterExpression = RequestAbortedExpr; - } + paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form)); + } + else if (parameter.ParameterType == typeof(HttpContext)) + { + paramterExpression = HttpContextParameter; + } + else if (parameter.ParameterType == typeof(CancellationToken)) + { + paramterExpression = RequestAbortedExpr; } args.Add(paramterExpression); @@ -182,6 +183,12 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) { body = methodCall; } + else if (method.ReturnType == typeof(ValueTask)) + { + body = Expression.Call( + ExecuteValueTaskMethodInfo, + methodCall); + } else if (method.ReturnType.IsGenericType && method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>)) { @@ -263,7 +270,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) var box = Expression.TypeAs(methodCall, typeof(object)); body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, box, Expression.Constant(CancellationToken.None)); } - else + else { body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, methodCall, Expression.Constant(CancellationToken.None)); } @@ -294,7 +301,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) { try { - bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType!); + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType!, httpContext.RequestAborted); } catch (IOException ex) { @@ -324,7 +331,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) // so the within the method it's cached try { - await httpContext.Request.ReadFormAsync(); + await httpContext.Request.ReadFormAsync(httpContext.RequestAborted); } catch (IOException ex) { @@ -398,10 +405,20 @@ private static Expression BindParamenter(Expression sourceExpression, ParameterI expr = Expression.Convert(expr, parameter.ParameterType); } + Expression defaultExpression; + if (parameter.HasDefaultValue) + { + defaultExpression = Expression.Constant(parameter.DefaultValue); + } + else + { + defaultExpression = Expression.Default(parameter.ParameterType); + } + // property[key] == null ? default : (ParameterType){Type}.Parse(property[key]); expr = Expression.Condition( Expression.Equal(valueArg, Expression.Constant(null)), - Expression.Default(parameter.ParameterType), + defaultExpression, expr); return expr; @@ -423,12 +440,12 @@ private static Task ExecuteTask(Task task, HttpContext httpContext) { static async Task ExecuteAwaited(Task task, HttpContext httpContext) { - await httpContext.Response.WriteAsJsonAsync(await task); + await httpContext.Response.WriteAsJsonAsync(await task, httpContext.RequestAborted); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); } return ExecuteAwaited(task, httpContext); @@ -438,27 +455,42 @@ private static Task ExecuteTaskOfString(Task task, HttpContext httpConte { static async Task ExecuteAwaited(Task task, HttpContext httpContext) { - await httpContext.Response.WriteAsync(await task); + await httpContext.Response.WriteAsync(await task, httpContext.RequestAborted); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); } return ExecuteAwaited(task, httpContext); } - private static Task ExecuteValueTask(ValueTask task, HttpContext httpContext) + private static Task ExecuteValueTask(ValueTask task) + { + static async Task ExecuteAwaited(ValueTask task) + { + await task; + } + + if (task.IsCompletedSuccessfully) + { + task.GetAwaiter().GetResult(); + } + + return ExecuteAwaited(task); + } + + private static Task ExecuteValueTaskOfT(ValueTask task, HttpContext httpContext) { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { - await httpContext.Response.WriteAsJsonAsync(await task); + await httpContext.Response.WriteAsJsonAsync(await task, httpContext.RequestAborted); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); } return ExecuteAwaited(task, httpContext); @@ -468,12 +500,12 @@ private static Task ExecuteValueTaskOfString(ValueTask task, HttpContext { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { - await httpContext.Response.WriteAsync(await task); + await httpContext.Response.WriteAsync(await task, httpContext.RequestAborted); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); } return ExecuteAwaited(task, httpContext); diff --git a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs index d9c7754327e5..1a3112aa3593 100644 --- a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs +++ b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs @@ -24,44 +24,185 @@ namespace Microsoft.AspNetCore.Routing.Internal { public class MapActionExpressionTreeBuilderTest { - [Fact] - public async Task RequestDelegateInvokesAction() + public static IEnumerable NoResult { - var invoked = false; - - void TestAction() + get { - invoked = true; + void TestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + } + + Task TaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return Task.CompletedTask; + } + + ValueTask ValueTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return ValueTask.CompletedTask; + } + + void StaticTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + } + + Task StaticTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return Task.CompletedTask; + } + + ValueTask StaticValueTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return ValueTask.CompletedTask; + } + + void MarkAsInvoked(HttpContext httpContext) + { + httpContext.Items.Add("invoked", true); + } + + return new List + { + new object[] { (Action)TestAction }, + new object[] { (Func)TaskTestAction }, + new object[] { (Func)ValueTaskTestAction }, + new object[] { (Action)StaticTestAction }, + new object[] { (Func)StaticTaskTestAction }, + new object[] { (Func)StaticValueTaskTestAction }, + }; } + } - var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action)TestAction); + [Theory] + [MemberData(nameof(NoResult))] + public async Task RequestDelegateInvokesAction(Delegate @delegate) + { + var httpContext = new DefaultHttpContext(); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); - await requestDelegate(null!); + await requestDelegate(httpContext); - Assert.True(invoked); + Assert.True(httpContext.Items["invoked"] as bool?); } - [Fact] - public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName() + public static IEnumerable FromRouteResult + { + get + { + void TestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + }; + + Task TaskTestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + return Task.CompletedTask; + } + + ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + return ValueTask.CompletedTask; + } + + + + return new List + { + new object[] { (Action)TestAction }, + new object[] { (Func)TaskTestAction }, + new object[] { (Func)ValueTaskTestAction }, + }; + } + } + private static void StoreInput(HttpContext httpContext, object value) + { + httpContext.Items.Add("input", value); + } + + [Theory] + [MemberData(nameof(FromRouteResult))] + public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName(Delegate @delegate) { const string paramName = "value"; const int originalRouteParam = 42; - int? deserializedRouteParam = null; + var httpContext = new DefaultHttpContext(); + httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); + + await requestDelegate(httpContext); + + Assert.Equal(originalRouteParam, httpContext.Items["value"] as int?); + } - void TestAction([FromRoute] int value) + public static IEnumerable FromRouteOptionalResult + { + get { - deserializedRouteParam = value; + return new List + { + new object[] { (Action)TestAction }, + new object[] { (Func)TaskTestAction }, + new object[] { (Func)ValueTaskTestAction } + }; } + } + + private static void TestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + } + private static Task TaskTestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + return Task.CompletedTask; + } + + private static ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + return ValueTask.CompletedTask; + } + + [Theory] + [MemberData(nameof(FromRouteOptionalResult))] + public async Task RequestDelegatePopulatesFromRouteOptionalParameter(Delegate @delegate) + { var httpContext = new DefaultHttpContext(); - httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action)TestAction); + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); await requestDelegate(httpContext); - Assert.Equal(originalRouteParam, deserializedRouteParam); + Assert.Equal(42, httpContext.Items["value"] as int?); + } + + [Theory] + [MemberData(nameof(FromRouteOptionalResult))] + public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParameterName(Delegate @delegate) + { + const string paramName = "value"; + const int originalRouteParam = 420; + + var httpContext = new DefaultHttpContext(); + httpContext.Items.Add("expected", originalRouteParam); + + httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); + + await requestDelegate(httpContext); } [Fact] From 0dea9678096a4d18b9b8836eec3a6a320e13e402 Mon Sep 17 00:00:00 2001 From: Kahbazi Date: Thu, 25 Feb 2021 13:08:11 +0330 Subject: [PATCH 2/3] Revert passing HttpContext.RequestAborted --- .../MapActionExpressionTreeBuilder.cs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs index 66f8f761a925..1e83d92c3497 100644 --- a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs +++ b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs @@ -301,7 +301,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) { try { - bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType!, httpContext.RequestAborted); + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType!); } catch (IOException ex) { @@ -331,7 +331,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action) // so the within the method it's cached try { - await httpContext.Request.ReadFormAsync(httpContext.RequestAborted); + await httpContext.Request.ReadFormAsync(); } catch (IOException ex) { @@ -440,12 +440,12 @@ private static Task ExecuteTask(Task task, HttpContext httpContext) { static async Task ExecuteAwaited(Task task, HttpContext httpContext) { - await httpContext.Response.WriteAsJsonAsync(await task, httpContext.RequestAborted); + await httpContext.Response.WriteAsJsonAsync(await task); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); + return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult()); } return ExecuteAwaited(task, httpContext); @@ -455,12 +455,12 @@ private static Task ExecuteTaskOfString(Task task, HttpContext httpConte { static async Task ExecuteAwaited(Task task, HttpContext httpContext) { - await httpContext.Response.WriteAsync(await task, httpContext.RequestAborted); + await httpContext.Response.WriteAsync(await task); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); } return ExecuteAwaited(task, httpContext); @@ -485,12 +485,12 @@ private static Task ExecuteValueTaskOfT(ValueTask task, HttpContext httpCo { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { - await httpContext.Response.WriteAsJsonAsync(await task, httpContext.RequestAborted); + await httpContext.Response.WriteAsJsonAsync(await task); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); + return httpContext.Response.WriteAsJsonAsync(task.GetAwaiter().GetResult()); } return ExecuteAwaited(task, httpContext); @@ -500,12 +500,12 @@ private static Task ExecuteValueTaskOfString(ValueTask task, HttpContext { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { - await httpContext.Response.WriteAsync(await task, httpContext.RequestAborted); + await httpContext.Response.WriteAsync(await task); } if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult(), httpContext.RequestAborted); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); } return ExecuteAwaited(task, httpContext); From a724f9ee1641abd44e2df7cb6c86e541bbd7d3f1 Mon Sep 17 00:00:00 2001 From: Kahbazi Date: Fri, 26 Feb 2021 22:52:02 +0330 Subject: [PATCH 3/3] Fix tests --- .../Internal/MapActionExpressionTreeBuilderTest.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs index 1a3112aa3593..543e8a17d808 100644 --- a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs +++ b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs @@ -142,7 +142,7 @@ public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName await requestDelegate(httpContext); - Assert.Equal(originalRouteParam, httpContext.Items["value"] as int?); + Assert.Equal(originalRouteParam, httpContext.Items["input"] as int?); } public static IEnumerable FromRouteOptionalResult @@ -185,7 +185,7 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameter(Delegate @d await requestDelegate(httpContext); - Assert.Equal(42, httpContext.Items["value"] as int?); + Assert.Equal(42, httpContext.Items["input"] as int?); } [Theory] @@ -193,16 +193,17 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameter(Delegate @d public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParameterName(Delegate @delegate) { const string paramName = "value"; - const int originalRouteParam = 420; + const int originalRouteParam = 47; var httpContext = new DefaultHttpContext(); - httpContext.Items.Add("expected", originalRouteParam); httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); await requestDelegate(httpContext); + + Assert.Equal(47, httpContext.Items["input"] as int?); } [Fact]