Skip to content

Commit

Permalink
ratelimti
Browse files Browse the repository at this point in the history
  • Loading branch information
KSemenenko committed May 19, 2023
1 parent 4d27258 commit 888ccc9
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,18 @@ public AuthorizedIpRateLimiterAttribute(string configurationName)
{
ConfigurationName = configurationName;
}
}


[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)]
public class InRoleIpRateLimiterAttribute : Attribute, IRateLimiterAttribute
{
public string ConfigurationName { get; }
public string Role { get; }

public InRoleIpRateLimiterAttribute(string configurationName, string role)
{
ConfigurationName = configurationName;
Role = role;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

<ItemGroup>

<PackageReference Include="ManagedCode.Communication" Version="2.0.22" />

<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="7.0.0" />
<PackageReference Include="Microsoft.Orleans.Client" Version="7.1.1" />
<PackageReference Include="System.Threading.RateLimiting" Version="7.0.0" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Reflection;
using System.Threading.Tasks;
using ManagedCode.Communication;
using ManagedCode.Orleans.RateLimiting.Client.Attributes;
using ManagedCode.Orleans.RateLimiting.Client.Extensions;
using ManagedCode.Orleans.RateLimiting.Core.Extensions;
Expand Down Expand Up @@ -39,34 +41,90 @@ public async Task Invoke(HttpContext httpContext)

AddIpRateLimiter(httpContext, holder);
AddAnonymousIpRateLimiter(httpContext, holder);
AddAuthorizedIpRateLimiter(httpContext, holder);

// if user is authenticated add in role limiter
if (!AddInRoleIpRateLimiter(httpContext, holder))
{
// if user is not authenticated add authorized limiter
AddAuthorizedIpRateLimiter(httpContext, holder);
}

await holder.AcquireAsync();
await _next(httpContext);
// throw too many requests if any of the limiters is null code 429
var error = await holder.AcquireAsync();
if (error is null)
{
await _next(httpContext);
}
else
{
httpContext.Response.Clear();
httpContext.Response.StatusCode = (int)HttpStatusCode.TooManyRequests;
await httpContext.Response.WriteAsJsonAsync(Result.Fail(HttpStatusCode.TooManyRequests,error.ToException()));
}
}


void AddIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
private bool AddIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
{
holder.AddLimiter(TryGetLimiterHolder<IpRateLimiterAttribute>(httpContext, httpContext.Request.GetClientIpAddress()));
var attribute = TryGetAttribute<IpRateLimiterAttribute>(httpContext);
if (attribute.HasValue)
{
return holder.AddLimiter(TryGetLimiterHolder(httpContext, CreateKey(httpContext.Request.GetClientIpAddress(), attribute.Value.postfix!),
attribute.Value.postfix!));
}

return false;
}

void AddAnonymousIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
private bool AddAnonymousIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
{
if(httpContext.User?.Identity?.IsAuthenticated is not true)
holder.AddLimiter(TryGetLimiterHolder<AnonymousIpRateLimiterAttribute>(httpContext, httpContext.Request.GetClientIpAddress()));
if (httpContext.User?.Identity?.IsAuthenticated is not true)
{
var attribute = TryGetAttribute<AnonymousIpRateLimiterAttribute>(httpContext);
if (attribute.HasValue)
{
return holder.AddLimiter(TryGetLimiterHolder(httpContext, CreateKey(httpContext.Request.GetClientIpAddress(), attribute.Value.postfix!),
attribute.Value.postfix!));
}
}

return false;
}

void AddAuthorizedIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
private bool AddAuthorizedIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
{
if(httpContext.User?.Identity?.IsAuthenticated is true)
holder.AddLimiter(TryGetLimiterHolder<AuthorizedIpRateLimiterAttribute>(httpContext,
CreateKey(httpContext.User.Identity.Name,httpContext.Request.GetClientIpAddress())));
if (httpContext.User?.Identity?.IsAuthenticated is true)
{
var attribute = TryGetAttribute<AuthorizedIpRateLimiterAttribute>(httpContext);
if (attribute.HasValue)
{
return holder.AddLimiter(TryGetLimiterHolder(httpContext,
CreateKey(httpContext.Request.GetClientIpAddress(), httpContext.User.Identity.Name!, attribute.Value.postfix!),
attribute.Value.postfix!));
}
}

return false;
}

private bool AddInRoleIpRateLimiter(HttpContext httpContext, GroupLimiterHolder holder)
{
var attribute = TryGetAttribute<InRoleIpRateLimiterAttribute>(httpContext);
if (attribute.HasValue)
{
if (httpContext.User?.Identity?.IsAuthenticated is true && httpContext.User.IsInRole(attribute.Value.attribute.Role))
{
return holder.AddLimiter(TryGetLimiterHolder(httpContext,
CreateKey(httpContext.Request.GetClientIpAddress(), httpContext.User.Identity.Name!, attribute.Value.attribute.Role, attribute.Value.postfix!),
attribute.Value.postfix!));
}
}

return false;
}
private ILimiterHolder? TryGetLimiterHolder<T>(HttpContext httpContext, string key) where T : Attribute, IRateLimiterAttribute


private (T attribute, string? postfix)? TryGetAttribute<T>(HttpContext httpContext) where T : Attribute, IRateLimiterAttribute
{
var endpoint = httpContext.GetEndpoint();

Expand All @@ -89,21 +147,24 @@ void AddAuthorizedIpRateLimiter(HttpContext httpContext, GroupLimiterHolder hold
}
}

if (attribute != null)
{
var limiter = _client.GetRateLimiterByConfig(CreateKey(key,postfix), attribute.ConfigurationName, _services.GetService<IEnumerable<RateLimiterConfig>>());

if(limiter is null)
_logger.LogError($"Configuration {attribute.ConfigurationName} not found for RateLimiter");

return limiter;
}
if (attribute is null)
return null;

return null;
return (attribute, postfix);
}

private ILimiterHolder? TryGetLimiterHolder(HttpContext httpContext, string key, string configurationName)
{
var limiter = _client.GetRateLimiterByConfig(key, configurationName, _services.GetService<IEnumerable<RateLimiterConfig>>());

if(limiter is null)
_logger.LogError($"Configuration {configurationName} not found for RateLimiter");

return limiter;
}

string CreateKey(string key, string postfix)
string CreateKey(params string[] parts)
{
return $"{key}:{postfix}";
return string.Join(":", parts);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ public static class GrainFactoryExtensions

ILimiterHolder? limiter = option.Configuration switch
{
FixedWindowRateLimiterOptions => factory.GetFixedWindowRateLimiter(key),
ConcurrencyLimiterOptions => factory.GetConcurrencyLimiter(key),
SlidingWindowRateLimiterOptions => factory.GetSlidingWindowRateLimiter(key),
TokenBucketRateLimiterOptions=> factory.GetTokenBucketRateLimiter(key),
FixedWindowRateLimiterOptions options => factory.GetFixedWindowRateLimiter(key, options),
ConcurrencyLimiterOptions options => factory.GetConcurrencyLimiter(key, options),
SlidingWindowRateLimiterOptions options => factory.GetSlidingWindowRateLimiter(key, options),
TokenBucketRateLimiterOptions options => factory.GetTokenBucketRateLimiter(key, options),

_ => null //throw new ArgumentException("Unknown rate limiter grain type")
};
Expand All @@ -54,19 +54,39 @@ public static FixedWindowRateLimiterHolder GetFixedWindowRateLimiter(this IGrain
{
return new FixedWindowRateLimiterHolder(factory.GetGrain<IFixedWindowRateLimiterGrain>(key), factory);
}

public static FixedWindowRateLimiterHolder GetFixedWindowRateLimiter(this IGrainFactory factory, string key, FixedWindowRateLimiterOptions options)
{
return new FixedWindowRateLimiterHolder(factory.GetGrain<IFixedWindowRateLimiterGrain>(key), factory, options);
}

public static ConcurrencyLimiterHolder GetConcurrencyLimiter(this IGrainFactory factory, string key)
{
return new ConcurrencyLimiterHolder(factory.GetGrain<IConcurrencyLimiterGrain>(key), factory);
}

public static ConcurrencyLimiterHolder GetConcurrencyLimiter(this IGrainFactory factory, string key, ConcurrencyLimiterOptions options)
{
return new ConcurrencyLimiterHolder(factory.GetGrain<IConcurrencyLimiterGrain>(key), factory, options);
}

public static SlidingWindowRateLimiterHolder GetSlidingWindowRateLimiter(this IGrainFactory factory, string key)
{
return new SlidingWindowRateLimiterHolder(factory.GetGrain<ISlidingWindowRateLimiterGrain>(key), factory);
}

public static SlidingWindowRateLimiterHolder GetSlidingWindowRateLimiter(this IGrainFactory factory, string key, SlidingWindowRateLimiterOptions options)
{
return new SlidingWindowRateLimiterHolder(factory.GetGrain<ISlidingWindowRateLimiterGrain>(key), factory, options);
}

public static TokenBucketRateLimiterHolder GetTokenBucketRateLimiter(this IGrainFactory factory, string key)
{
return new TokenBucketRateLimiterHolder(factory.GetGrain<ITokenBucketRateLimiterGrain>(key), factory);
}

public static TokenBucketRateLimiterHolder GetTokenBucketRateLimiter(this IGrainFactory factory, string key, TokenBucketRateLimiterOptions options)
{
return new TokenBucketRateLimiterHolder(factory.GetGrain<ITokenBucketRateLimiterGrain>(key), factory, options);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,34 @@ namespace ManagedCode.Orleans.RateLimiting.Core.Models.Holders;
public class GroupLimiterHolder : IAsyncDisposable, IDisposable
{
private Dictionary<ILimiterHolder, OrleansRateLimitLease?> _holders = new();
public void AddLimiter(ILimiterHolder? holder)
public bool AddLimiter(ILimiterHolder? holder)
{
if(holder is not null)
_holders.Add(holder, null);
if (holder is not null)
{
_holders.Add(holder, null);
return true;
}

return false;

}

public async Task AcquireAsync()
public async Task<OrleansRateLimitLease?> AcquireAsync()
{
foreach (var holder in _holders.Keys)
{
var lease = await holder.AcquireAndConfigureAsync();
lease.ThrowIfNotAcquired();
if (lease.IsAcquired)
_holders[holder] = lease;
{
_holders[holder] = lease;
}
else
{
return lease;
}
}

return null;
}

public async ValueTask DisposeAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public void ThrowIfNotAcquired([CallerMemberName] string? caller = null, [Caller
throw new RateLimitExceededException(Reason, RetryAfter);
}

public RateLimitExceededException ToException()
{
return new(Reason, RetryAfter);
}

public void Dispose()
{
_ = DisposeAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,17 @@
namespace ManagedCode.Orleans.RateLimiting.Tests.TestApp.Controllers;

[Route("test")]
[AnonymousIpRateLimiter("ipsdfsdf")]
[IpRateLimiter("ip")]
public class TestController : ControllerBase
{
[AuthorizedIpRateLimiter("ipsdfsdf")]
[AuthorizedIpRateLimiter("Authorized")]
[AnonymousIpRateLimiter("Authorized")]
[InRoleIpRateLimiter("Authorized", "Admin")]
[HttpGet("authorize")]
public ActionResult<string> Authorize()
public async Task<ActionResult<string>> Authorize()
{
await Task.Delay(300);
return "Authorize";
}

[ConcurrencyLimiter("LimitByUser")]
[HttpGet("anonymous")]
public ActionResult<string> Anonymous()
{
return "Anonymous";
}

[HttpGet("admin")]
public ActionResult<string> Admin()
{
return "admin";
}

[HttpGet("moderator")]
public ActionResult<string> Moderator()
{
return "moderator";
}

[HttpGet("common")]
public ActionResult<string> Common()
{
return "common";
}
}
21 changes: 21 additions & 0 deletions ManagedCode.Orleans.RateLimiting.Tests/TestApp/HttpHostProgram.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,26 @@ public static void Main(string[] args)
builder.Services.AddControllers();
builder.Services.AddSignalR();

builder.Services.AddRateLimiterOptions("ip", new FixedWindowRateLimiterOptions()
{
QueueLimit = 5,
PermitLimit = 10,
Window = TimeSpan.FromSeconds(1)
});

builder.Services.AddRateLimiterOptions("Anonymous", new FixedWindowRateLimiterOptions()
{
QueueLimit = 1,
PermitLimit = 1,
Window = TimeSpan.FromSeconds(1)
});

builder.Services.AddRateLimiterOptions("Authorized", new FixedWindowRateLimiterOptions()
{
QueueLimit = 2,
PermitLimit = 2,
Window = TimeSpan.FromSeconds(1)
});


var app = builder.Build();
Expand All @@ -26,6 +45,8 @@ public static void Main(string[] args)

app.UseMiddleware<RateLimitingMiddleware>();

app.UseRateLimiter();

app.Run();
}
}
Loading

0 comments on commit 888ccc9

Please sign in to comment.