Skip to content

Commit

Permalink
fix: Make expression cache thread safe (casbin#261)
Browse files Browse the repository at this point in the history
Signed-off-by: sagilio <sagilio@outlook.com>
  • Loading branch information
sagilio committed Jun 12, 2022
1 parent 820900f commit e65b782
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 73 deletions.
7 changes: 2 additions & 5 deletions Casbin/Abstractions/Evaluation/IExpressionHandler.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
using System;
using Casbin.Caching;
using Casbin.Effect;
using Casbin.Model;

namespace Casbin.Evaluation;

public interface IExpressionHandler
{
public IExpressionCachePool Cache { get; set; }

public void SetFunction(string name, Delegate function);

public bool Invoke<TRequest, TPolicy>(in EnforceContext context, string expressionString, in TRequest request, in TPolicy policy)
public bool Invoke<TRequest, TPolicy>(in EnforceContext context, string expressionString, in TRequest request,
in TPolicy policy)
where TRequest : IRequestValues
where TPolicy : IPolicyValues;
}
22 changes: 2 additions & 20 deletions Casbin/Caching/ExpressionCache.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,11 @@
using System;
using System.Collections.Generic;
using DynamicExpresso;
using System.Collections.Concurrent;

namespace Casbin.Caching;

internal class ExpressionCache : IExpressionCache<Lambda>
{
private readonly Dictionary<string, Lambda> _cache = new();

public bool TryGet(string expressionString, out Lambda lambda)
{
return _cache.TryGetValue(expressionString, out lambda);
}

public void Set(string expressionString, Lambda lambda)
{
_cache[expressionString] = lambda;
}

public void Clear() => _cache.Clear();
}

internal class ExpressionCache<TFunc> : IExpressionCache<TFunc> where TFunc : Delegate
{
private readonly Dictionary<string, TFunc> _cache = new();
private readonly ConcurrentDictionary<string, TFunc> _cache = new();

public bool TryGet(string expressionString, out TFunc func)
{
Expand Down
37 changes: 6 additions & 31 deletions Casbin/Caching/ExpressionCachePool.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,12 @@
using System;
using System.Collections.Generic;
using DynamicExpresso;
using System.Collections.Concurrent;
using System.Threading;

namespace Casbin.Caching;

public class ExpressionCachePool : IExpressionCachePool
{
private readonly Dictionary<Type, IExpressionCache> _cachePool = new();

public void SetLambda(string expression, Lambda lambda)
{
Type type = typeof(Lambda);
if (_cachePool.TryGetValue(type, out IExpressionCache cache) is false)
{
cache = new ExpressionCache();
_cachePool[type] = cache;
}
var cacheImpl = (ExpressionCache) cache;
cacheImpl.Set(expression, lambda);
}

public bool TryGetLambda(string expression, out Lambda lambda)
{
Type type = typeof(Lambda);
if (_cachePool.TryGetValue(type, out IExpressionCache cache) is false)
{
lambda = default;
return false;
}
var cacheImpl = (ExpressionCache) cache;
return cacheImpl.TryGet(expression, out lambda);
}
private ConcurrentDictionary<Type, IExpressionCache> _cachePool = new();

public void SetFunc<TFunc>(string expression, TFunc func) where TFunc : Delegate
{
Expand All @@ -53,15 +29,14 @@ public bool TryGetFunc<TFunc>(string expression, out TFunc func) where TFunc : D
cache = new ExpressionCache<TFunc>();
_cachePool[type] = cache;
}

var cacheImpl = (IExpressionCache<TFunc>)cache;
return cacheImpl.TryGet(expression, out func);
}

public void Clear()
{
foreach (IExpressionCache cache in _cachePool.Values)
{
cache?.Clear();
}
ConcurrentDictionary<Type, IExpressionCache> cachePool = new ConcurrentDictionary<Type, IExpressionCache>();
Interlocked.Exchange(ref _cachePool, cachePool);
}
}
16 changes: 8 additions & 8 deletions Casbin/Casbin.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,46 +47,46 @@
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net7.0'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0-preview.1.22076.8"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net6.0'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net5.0'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="5.0.0"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp3.1'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="3.1.22"/>
<PackageReference Include="IsExternalInit" Version="1.0.2" PrivateAssets="all"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0"/>
<PackageReference Include="IsExternalInit" Version="1.0.2" PrivateAssets="all"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0"/>
<PackageReference Include="IsExternalInit" Version="1.0.2" PrivateAssets="all"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net462'">
<PackageReference Include="DynamicExpresso.Core" Version="2.11.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0"/>
<PackageReference Include="IsExternalInit" Version="1.0.2" PrivateAssets="all"/>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net461'">
<PackageReference Include="DynamicExpresso.Core" Version="2.10.0"/>
<PackageReference Include="DynamicExpresso.Core" Version="2.12.0"/>
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0"/>
<PackageReference Include="IsExternalInit" Version="1.0.2" PrivateAssets="all"/>
</ItemGroup>
Expand Down
33 changes: 24 additions & 9 deletions Casbin/Evaluation/ExpressionHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Casbin.Caching;
using Casbin.Model;
Expand All @@ -10,9 +9,13 @@ namespace Casbin.Evaluation;

internal class ExpressionHandler : IExpressionHandler
{
public IExpressionCachePool Cache { get; set; } = new ExpressionCachePool();
private readonly FunctionMap _functionMap = FunctionMap.LoadFunctionMap();
private IExpressionCachePool _cachePool = new ExpressionCachePool();
#if !NET452
private readonly Interpreter _interpreter;
#else
private Interpreter _interpreter;
#endif

public ExpressionHandler()
{
Expand All @@ -24,6 +27,9 @@ public ExpressionHandler()

public void SetFunction(string name, Delegate function)
{
#if !NET452
_interpreter.SetFunction(name, function);
#else
List<Identifier> identifiers = new();
bool exist = false;
foreach (var identifier in _interpreter.Identifiers)
Expand All @@ -46,34 +52,42 @@ public void SetFunction(string name, Delegate function)
interpreter.SetIdentifiers(identifiers);
interpreter.SetFunction(name, function);
Interlocked.Exchange(ref _interpreter, interpreter);
Cache.Clear();
#endif
ExpressionCachePool cachePool = new ExpressionCachePool();
Interlocked.Exchange(ref _cachePool, cachePool);
}

public bool Invoke<TRequest, TPolicy>(in EnforceContext context, string expressionString, in TRequest request, in TPolicy policy)
public bool Invoke<TRequest, TPolicy>(in EnforceContext context, string expressionString, in TRequest request,
in TPolicy policy)
where TRequest : IRequestValues
where TPolicy : IPolicyValues
{
if (context.View.SupportGeneric is false)
{
if (Cache.TryGetFunc<Func<IRequestValues, IPolicyValues, bool>>(expressionString, out var func))
if (_cachePool.TryGetFunc<Func<IRequestValues, IPolicyValues, bool>>(expressionString,
out Func<IRequestValues, IPolicyValues, bool> func))
{
return func(request, policy);
}

func = CompileExpression<IRequestValues, IPolicyValues>(in context, expressionString);
Cache.SetFunc(expressionString, func);
_cachePool.SetFunc(expressionString, func);
return func(request, policy);
}

if (Cache.TryGetFunc<Func<TRequest, TPolicy, bool>>(expressionString, out var genericFunc) is not false)
if (_cachePool.TryGetFunc<Func<TRequest, TPolicy, bool>>(expressionString,
out Func<TRequest, TPolicy, bool> genericFunc) is not false)
{
return genericFunc(request, policy);
}

genericFunc = CompileExpression<TRequest, TPolicy>(in context, expressionString);
Cache.SetFunc(expressionString, genericFunc);
_cachePool.SetFunc(expressionString, genericFunc);
return genericFunc(request, policy);
}

private Func<TRequest, TPolicy, bool> CompileExpression<TRequest, TPolicy>(in EnforceContext context, string expressionString)
private Func<TRequest, TPolicy, bool> CompileExpression<TRequest, TPolicy>(in EnforceContext context,
string expressionString)
where TRequest : IRequestValues
where TPolicy : IPolicyValues
{
Expand All @@ -89,6 +103,7 @@ private Interpreter CreateInterpreter()
{
interpreter.SetFunction(functionKeyValue.Key, functionKeyValue.Value);
}

return interpreter;
}
}
3 changes: 3 additions & 0 deletions Casbin/Util/Utility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ internal static bool SetEquals(List<string> a, List<string> b)
{
a = new List<string>();
}

if (b == null)
{
b = new List<string>();
}

if (a.Count != b.Count)
{
return false;
Expand All @@ -40,6 +42,7 @@ internal static bool SetEquals(List<string> a, List<string> b)
return false;
}
}

return true;
}
}
Expand Down

0 comments on commit e65b782

Please sign in to comment.