Skip to content

Commit

Permalink
IExceptionHandler for exception handler middleware #46280 (#47923)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kahbazi committed Apr 28, 2023
1 parent b92218a commit b2ce346
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ private static IApplicationBuilder SetExceptionHandlerMiddleware(IApplicationBui
{
var loggerFactory = app.ApplicationServices.GetRequiredService<ILoggerFactory>();
var diagnosticListener = app.ApplicationServices.GetRequiredService<DiagnosticListener>();
var exceptionHandlers = app.ApplicationServices.GetRequiredService<IEnumerable<IExceptionHandler>>();
if (options is null)
{
Expand All @@ -110,7 +111,7 @@ private static IApplicationBuilder SetExceptionHandlerMiddleware(IApplicationBui
options.Value.ExceptionHandler = newNext;
}
return new ExceptionHandlerMiddlewareImpl(next, loggerFactory, options, diagnosticListener, problemDetailsService).Invoke;
return new ExceptionHandlerMiddlewareImpl(next, loggerFactory, options, diagnosticListener, exceptionHandlers, problemDetailsService).Invoke;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Linq;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -29,11 +30,12 @@ public class ExceptionHandlerMiddleware
IOptions<ExceptionHandlerOptions> options,
DiagnosticListener diagnosticListener)
{
_innerMiddlewareImpl = new (
_innerMiddlewareImpl = new(
next,
loggerFactory,
options,
diagnosticListener,
Enumerable.Empty<IExceptionHandler>(),
problemDetailsService: null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Diagnostics;
/// <summary>
/// A middleware for handling exceptions in the application.
/// </summary>
internal class ExceptionHandlerMiddlewareImpl
internal sealed class ExceptionHandlerMiddlewareImpl
{
private const int DefaultStatusCode = StatusCodes.Status500InternalServerError;

Expand All @@ -25,6 +25,7 @@ internal class ExceptionHandlerMiddlewareImpl
private readonly ILogger _logger;
private readonly Func<object, Task> _clearCacheHeadersDelegate;
private readonly DiagnosticListener _diagnosticListener;
private readonly IExceptionHandler[] _exceptionHandlers;
private readonly IProblemDetailsService? _problemDetailsService;

/// <summary>
Expand All @@ -34,19 +35,22 @@ internal class ExceptionHandlerMiddlewareImpl
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging.</param>
/// <param name="options">The options for configuring the middleware.</param>
/// <param name="diagnosticListener">The <see cref="DiagnosticListener"/> used for writing diagnostic messages.</param>
/// <param name="exceptionHandlers"></param>
/// <param name="problemDetailsService">The <see cref="IProblemDetailsService"/> used for writing <see cref="ProblemDetails"/> messages.</param>
public ExceptionHandlerMiddlewareImpl(
RequestDelegate next,
ILoggerFactory loggerFactory,
IOptions<ExceptionHandlerOptions> options,
DiagnosticListener diagnosticListener,
IEnumerable<IExceptionHandler> exceptionHandlers,
IProblemDetailsService? problemDetailsService = null)
{
_next = next;
_options = options.Value;
_logger = loggerFactory.CreateLogger<ExceptionHandlerMiddleware>();
_clearCacheHeadersDelegate = ClearCacheHeaders;
_diagnosticListener = diagnosticListener;
_exceptionHandlers = exceptionHandlers as IExceptionHandler[] ?? new List<IExceptionHandler>(exceptionHandlers).ToArray();
_problemDetailsService = problemDetailsService;

if (_options.ExceptionHandler == null)
Expand Down Expand Up @@ -133,7 +137,7 @@ private async Task HandleException(HttpContext context, ExceptionDispatchInfo ed
edi.Throw();
}

PathString originalPath = context.Request.Path;
var originalPath = context.Request.Path;
if (_options.ExceptionHandlingPath.HasValue)
{
context.Request.Path = _options.ExceptionHandlingPath;
Expand All @@ -155,24 +159,35 @@ private async Task HandleException(HttpContext context, ExceptionDispatchInfo ed
context.Response.StatusCode = DefaultStatusCode;
context.Response.OnStarting(_clearCacheHeadersDelegate, context.Response);

var problemDetailsWritten = false;
if (_options.ExceptionHandler != null)
var handled = false;
foreach (var exceptionHandler in _exceptionHandlers)
{
await _options.ExceptionHandler!(context);
handled = await exceptionHandler.TryHandleAsync(context, edi.SourceException, context.RequestAborted);
if (handled)
{
break;
}
}
else

if (!handled)
{
problemDetailsWritten = await _problemDetailsService!.TryWriteAsync(new()
if (_options.ExceptionHandler is not null)
{
HttpContext = context,
AdditionalMetadata = exceptionHandlerFeature.Endpoint?.Metadata,
ProblemDetails = { Status = DefaultStatusCode },
Exception = edi.SourceException,
});
await _options.ExceptionHandler!(context);
}
else
{
handled = await _problemDetailsService!.TryWriteAsync(new()
{
HttpContext = context,
AdditionalMetadata = exceptionHandlerFeature.Endpoint?.Metadata,
ProblemDetails = { Status = DefaultStatusCode },
Exception = edi.SourceException,
});
}
}

// If the response has already started, assume exception handler was successful.
if (context.Response.HasStarted || problemDetailsWritten || context.Response.StatusCode != StatusCodes.Status404NotFound || _options.AllowStatusCode404Response)
if (context.Response.HasStarted || handled || context.Response.StatusCode != StatusCodes.Status404NotFound || _options.AllowStatusCode404Response)
{
const string eventName = "Microsoft.AspNetCore.Diagnostics.HandledException";
if (_diagnosticListener.IsEnabled() && _diagnosticListener.IsEnabled(eventName))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Diagnostics;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -38,4 +40,16 @@ public static IServiceCollection AddExceptionHandler(this IServiceCollection ser
services.AddOptions<ExceptionHandlerOptions>().Configure(configureOptions);
return services;
}

/// <summary>
///
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="services"></param>
/// <returns></returns>

public static IServiceCollection AddExceptionHandler<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>(this IServiceCollection services) where T : class, IExceptionHandler
{
return services.AddSingleton<IExceptionHandler, T>();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.Diagnostics;

/// <summary>
///
/// </summary>
public interface IExceptionHandler
{
/// <summary>
///
/// </summary>
/// <param name="httpContext"></param>
/// <param name="exception"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
ValueTask<bool> TryHandleAsync(HttpContext httpContext, Exception exception, CancellationToken cancellationToken);
}
3 changes: 3 additions & 0 deletions src/Middleware/Diagnostics/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
#nullable enable
Microsoft.AspNetCore.Diagnostics.IExceptionHandler
Microsoft.AspNetCore.Diagnostics.IExceptionHandler.TryHandleAsync(Microsoft.AspNetCore.Http.HttpContext! httpContext, System.Exception! exception, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask<bool>
Microsoft.AspNetCore.Diagnostics.StatusCodeReExecuteFeature.OriginalStatusCode.get -> int
static Microsoft.Extensions.DependencyInjection.ExceptionHandlerServiceCollectionExtensions.AddExceptionHandler<T>(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,105 @@ public async Task Invoke_ExceptionHandlerCaptureRouteValuesAndEndpoint()
await middleware.Invoke(httpContext);
}

[Fact]
public async Task IExceptionHandlers_CallNextIfNotHandled()
{
// Arrange
var httpContext = CreateHttpContext();

var optionsAccessor = CreateOptionsAccessor();

var exceptionHandlers = new List<IExceptionHandler>
{
new TestExceptionHandler(false, "1"),
new TestExceptionHandler(false, "2"),
new TestExceptionHandler(true, "3"),
};

var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers);

// Act & Assert
await middleware.Invoke(httpContext);

Assert.True(httpContext.Items.ContainsKey("1"));
Assert.True(httpContext.Items.ContainsKey("2"));
Assert.True(httpContext.Items.ContainsKey("3"));
}

[Fact]
public async Task IExceptionHandlers_SkipIfOneHandle()
{
// Arrange
var httpContext = CreateHttpContext();

var optionsAccessor = CreateOptionsAccessor();

var exceptionHandlers = new List<IExceptionHandler>
{
new TestExceptionHandler(false, "1"),
new TestExceptionHandler(true, "2"),
new TestExceptionHandler(true, "3"),
};

var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers);

// Act & Assert
await middleware.Invoke(httpContext);

Assert.True(httpContext.Items.ContainsKey("1"));
Assert.True(httpContext.Items.ContainsKey("2"));
Assert.False(httpContext.Items.ContainsKey("3"));
}

[Fact]
public async Task IExceptionHandlers_CallOptionExceptionHandlerIfNobodyHandles()
{
// Arrange
var httpContext = CreateHttpContext();

var optionsAccessor = CreateOptionsAccessor(
(context) =>
{
context.Items["ExceptionHandler"] = true;
return Task.CompletedTask;
});

var exceptionHandlers = new List<IExceptionHandler>
{
new TestExceptionHandler(false, "1"),
new TestExceptionHandler(false, "2"),
new TestExceptionHandler(false, "3"),
};

var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor, exceptionHandlers);

// Act & Assert
await middleware.Invoke(httpContext);

Assert.True(httpContext.Items.ContainsKey("1"));
Assert.True(httpContext.Items.ContainsKey("2"));
Assert.True(httpContext.Items.ContainsKey("3"));
Assert.True(httpContext.Items.ContainsKey("ExceptionHandler"));
}

private class TestExceptionHandler : IExceptionHandler
{
private readonly bool _handle;
private readonly string _name;

public TestExceptionHandler(bool handle, string name)
{
_handle = handle;
_name = name;
}

public ValueTask<bool> TryHandleAsync(HttpContext httpContext, Exception exception, CancellationToken cancellationToken)
{
httpContext.Items[_name] = true;
return ValueTask.FromResult(_handle);
}
}

private HttpContext CreateHttpContext()
{
var httpContext = new DefaultHttpContext
Expand All @@ -138,18 +237,20 @@ private HttpContext CreateHttpContext()
return optionsAccessor;
}

private ExceptionHandlerMiddleware CreateMiddleware(
private ExceptionHandlerMiddlewareImpl CreateMiddleware(
RequestDelegate next,
IOptions<ExceptionHandlerOptions> options)
IOptions<ExceptionHandlerOptions> options,
IEnumerable<IExceptionHandler> exceptionHandlers = null)
{
next ??= c => Task.CompletedTask;
var listener = new DiagnosticListener("Microsoft.AspNetCore");

var middleware = new ExceptionHandlerMiddleware(
var middleware = new ExceptionHandlerMiddlewareImpl(
next,
NullLoggerFactory.Instance,
options,
listener);
listener,
exceptionHandlers ?? Enumerable.Empty<IExceptionHandler>());

return middleware;
}
Expand Down

0 comments on commit b2ce346

Please sign in to comment.