Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update middleware that assumes UseRouting is called after them, for minimal hosting #35426

Merged
merged 14 commits into from Aug 24, 2021
@@ -1,9 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.Builder
Expand All @@ -26,7 +28,7 @@ public static IApplicationBuilder UseExceptionHandler(this IApplicationBuilder a
throw new ArgumentNullException(nameof(app));
}

return app.UseMiddleware<ExceptionHandlerMiddleware>();
return SetExceptionHandlerMiddleware(app, options: null);
}

/// <summary>
Expand Down Expand Up @@ -95,7 +97,52 @@ public static IApplicationBuilder UseExceptionHandler(this IApplicationBuilder a
throw new ArgumentNullException(nameof(options));
}

return app.UseMiddleware<ExceptionHandlerMiddleware>(Options.Create(options));
var iOptions = Options.Create(options);
return SetExceptionHandlerMiddleware(app, iOptions);
}

private static IApplicationBuilder SetExceptionHandlerMiddleware(IApplicationBuilder app, IOptions<ExceptionHandlerOptions>? options)
{
const string globalRouteBuilderKey = "__GlobalEndpointRouteBuilder";
// Only use this path if there's a global router (in the 'WebApplication' case).
if (app.Properties.TryGetValue(globalRouteBuilderKey, out var routeBuilder))
{
return app.Use(next =>
{
var loggerFactory = app.ApplicationServices.GetRequiredService<ILoggerFactory>();
var diagnosticListener = app.ApplicationServices.GetRequiredService<DiagnosticListener>();

if (options is null)
{
options = app.ApplicationServices.GetRequiredService<IOptions<ExceptionHandlerOptions>>();
}

if (!string.IsNullOrEmpty(options.Value.ExceptionHandlingPath) && options.Value.ExceptionHandler is null)
{
// start a new middleware pipeline
var builder = app.New();
if (routeBuilder is not null)
{
// use the old routing pipeline if it exists so we preserve all the routes and matching logic
builder.Properties[globalRouteBuilderKey] = routeBuilder;
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
}
builder.UseRouting();
// apply the next middleware
builder.Run(next);
// store the pipeline for the error case
options.Value.ExceptionHandler = builder.Build();
}

return new ExceptionHandlerMiddleware(next, loggerFactory, options, diagnosticListener).Invoke;
});
}

if (options is null)
{
return app.UseMiddleware<ExceptionHandlerMiddleware>();
}

return app.UseMiddleware<ExceptionHandlerMiddleware>(options);
}
}
}
Expand Up @@ -174,7 +174,36 @@ public static IApplicationBuilder UseStatusCodePages(this IApplicationBuilder ap
throw new ArgumentNullException(nameof(app));
}

return app.UseStatusCodePages(async context =>
const string globalRouteBuilderKey = "__GlobalEndpointRouteBuilder";
// Only use this path if there's a global router (in the 'WebApplication' case).
if (app.Properties.TryGetValue(globalRouteBuilderKey, out var routeBuilder))
{
return app.Use(next =>
{
RequestDelegate? newNext = null;
// start a new middleware pipeline
var builder = app.New();
if (routeBuilder is not null)
{
// use the old routing pipeline if it exists so we preserve all the routes and matching logic
builder.Properties[globalRouteBuilderKey] = routeBuilder;
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
builder.UseRouting();
// apply the next middleware
builder.Run(next);
newNext = builder.Build();
}

return new StatusCodePagesMiddleware(next,
Options.Create(new StatusCodePagesOptions() { HandleAsync = CreateHandler(pathFormat, queryFormat, newNext) })).Invoke;
});
}

return app.UseStatusCodePages(CreateHandler(pathFormat, queryFormat));
}

private static Func<StatusCodeContext, Task> CreateHandler(string pathFormat, string? queryFormat, RequestDelegate? next = null)
{
var handler = async (StatusCodeContext context) =>
{
var newPath = new PathString(
string.Format(CultureInfo.InvariantCulture, pathFormat, context.HttpContext.Response.StatusCode));
Expand Down Expand Up @@ -202,15 +231,24 @@ public static IApplicationBuilder UseStatusCodePages(this IApplicationBuilder ap
context.HttpContext.Request.QueryString = newQueryString;
try
{
await context.Next(context.HttpContext);
if (next is not null)
{
await next(context.HttpContext);
}
else
{
await context.Next(context.HttpContext);
}
}
finally
{
context.HttpContext.Request.QueryString = originalQueryString;
context.HttpContext.Request.Path = originalPath;
context.HttpContext.Features.Set<IStatusCodeReExecuteFeature?>(null);
}
});
};

return handler;
}
}
}