/
ConcurrencyLimiterMiddleware.cs
99 lines (86 loc) · 4.03 KB
/
ConcurrencyLimiterMiddleware.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.ConcurrencyLimiter;
/// <summary>
/// Limits the number of concurrent requests allowed in the application.
/// </summary>
public partial class ConcurrencyLimiterMiddleware
{
private readonly IQueuePolicy _queuePolicy;
private readonly RequestDelegate _next;
private readonly RequestDelegate _onRejected;
private readonly ILogger _logger;
/// <summary>
/// Creates a new <see cref="ConcurrencyLimiterMiddleware"/>.
/// </summary>
/// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging.</param>
/// <param name="queue">The queueing strategy to use for the server.</param>
/// <param name="options">The options for the middleware, currently containing the 'OnRejected' callback.</param>
public ConcurrencyLimiterMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IQueuePolicy queue, IOptions<ConcurrencyLimiterOptions> options)
{
if (options.Value.OnRejected == null)
{
throw new ArgumentException("The value of 'options.OnRejected' must not be null.", nameof(options));
}
_next = next;
_logger = loggerFactory.CreateLogger<ConcurrencyLimiterMiddleware>();
_onRejected = options.Value.OnRejected;
_queuePolicy = queue;
}
/// <summary>
/// Invokes the logic of the middleware.
/// </summary>
/// <param name="context">The <see cref="HttpContext"/>.</param>
/// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
public async Task Invoke(HttpContext context)
{
var waitInQueueTask = _queuePolicy.TryEnterAsync();
// Make sure we only ever call GetResult once on the TryEnterAsync ValueTask b/c it resets.
bool result;
if (waitInQueueTask.IsCompleted)
{
ConcurrencyLimiterEventSource.Log.QueueSkipped();
result = waitInQueueTask.Result;
}
else
{
using (ConcurrencyLimiterEventSource.Log.QueueTimer())
{
result = await waitInQueueTask;
}
}
if (result)
{
try
{
await _next(context);
}
finally
{
_queuePolicy.OnExit();
}
}
else
{
ConcurrencyLimiterEventSource.Log.RequestRejected();
ConcurrencyLimiterLog.RequestRejectedQueueFull(_logger);
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await _onRejected(context);
}
}
private static partial class ConcurrencyLimiterLog
{
[LoggerMessage(1, LogLevel.Debug, "MaxConcurrentRequests limit reached, request has been queued. Current active requests: {ActiveRequests}.", EventName = "RequestEnqueued")]
internal static partial void RequestEnqueued(ILogger logger, int activeRequests);
[LoggerMessage(2, LogLevel.Debug, "Request dequeued. Current active requests: {ActiveRequests}.", EventName = "RequestDequeued")]
internal static partial void RequestDequeued(ILogger logger, int activeRequests);
[LoggerMessage(3, LogLevel.Debug, "Below MaxConcurrentRequests limit, running request immediately. Current active requests: {ActiveRequests}", EventName = "RequestRunImmediately")]
internal static partial void RequestRunImmediately(ILogger logger, int activeRequests);
[LoggerMessage(4, LogLevel.Debug, "Currently at the 'RequestQueueLimit', rejecting this request with a '503 server not available' error", EventName = "RequestRejectedQueueFull")]
internal static partial void RequestRejectedQueueFull(ILogger logger);
}
}