diff --git a/src/Microsoft.FeatureManagement/FeatureManagerSnapshot.cs b/src/Microsoft.FeatureManagement/FeatureManagerSnapshot.cs index a218e1cc..1c676cbb 100644 --- a/src/Microsoft.FeatureManagement/FeatureManagerSnapshot.cs +++ b/src/Microsoft.FeatureManagement/FeatureManagerSnapshot.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. // using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading.Tasks; @@ -13,7 +14,7 @@ namespace Microsoft.FeatureManagement class FeatureManagerSnapshot : IFeatureManagerSnapshot { private readonly IFeatureManager _featureManager; - private readonly IDictionary _flagCache = new Dictionary(); + private readonly ConcurrentDictionary> _flagCache = new ConcurrentDictionary>(); private IEnumerable _featureNames; public FeatureManagerSnapshot(IFeatureManager featureManager) @@ -41,36 +42,18 @@ await foreach (string featureName in _featureManager.GetFeatureNamesAsync().Conf } } - public async Task IsEnabledAsync(string feature) + public Task IsEnabledAsync(string feature) { - // - // First, check local cache - if (_flagCache.ContainsKey(feature)) - { - return _flagCache[feature]; - } - - bool enabled = await _featureManager.IsEnabledAsync(feature).ConfigureAwait(false); - - _flagCache[feature] = enabled; - - return enabled; + return _flagCache.GetOrAdd( + feature, + (key) => _featureManager.IsEnabledAsync(key)); } - public async Task IsEnabledAsync(string feature, TContext context) + public Task IsEnabledAsync(string feature, TContext context) { - // - // First, check local cache - if (_flagCache.ContainsKey(feature)) - { - return _flagCache[feature]; - } - - bool enabled = await _featureManager.IsEnabledAsync(feature, context).ConfigureAwait(false); - - _flagCache[feature] = enabled; - - return enabled; + return _flagCache.GetOrAdd( + feature, + (key) => _featureManager.IsEnabledAsync(key, context)); } } } diff --git a/tests/Tests.FeatureManagement/FeatureManagement.cs b/tests/Tests.FeatureManagement/FeatureManagement.cs index d7e85366..2dade89c 100644 --- a/tests/Tests.FeatureManagement/FeatureManagement.cs +++ b/tests/Tests.FeatureManagement/FeatureManagement.cs @@ -63,7 +63,7 @@ public async Task ReadsConfiguration() Assert.Equal(ConditionalFeature, evaluationContext.FeatureName); - return true; + return Task.FromResult(true); }; await featureManager.IsEnabledAsync(ConditionalFeature); @@ -106,14 +106,14 @@ public async Task Integrates() TestFilter testFeatureFilter = (TestFilter)featureFilters.First(f => f is TestFilter); - testFeatureFilter.Callback = _ => true; + testFeatureFilter.Callback = _ => Task.FromResult(true); HttpResponseMessage res = await testServer.CreateClient().GetAsync(""); Assert.True(res.Headers.Contains(nameof(MvcFilter))); Assert.True(res.Headers.Contains(nameof(RouterMiddleware))); - testFeatureFilter.Callback = _ => false; + testFeatureFilter.Callback = _ => Task.FromResult(false); res = await testServer.CreateClient().GetAsync(""); @@ -143,7 +143,7 @@ public async Task GatesFeatures() // // Enable all features - testFeatureFilter.Callback = ctx => true; + testFeatureFilter.Callback = ctx => Task.FromResult(true); HttpResponseMessage gateAllResponse = await testServer.CreateClient().GetAsync("gateAll"); HttpResponseMessage gateAnyResponse = await testServer.CreateClient().GetAsync("gateAny"); @@ -153,7 +153,7 @@ public async Task GatesFeatures() // // Enable 1/2 features - testFeatureFilter.Callback = ctx => ctx.FeatureName == Enum.GetName(typeof(Features), Features.ConditionalFeature); + testFeatureFilter.Callback = ctx => Task.FromResult(ctx.FeatureName == Enum.GetName(typeof(Features), Features.ConditionalFeature)); gateAllResponse = await testServer.CreateClient().GetAsync("gateAll"); gateAnyResponse = await testServer.CreateClient().GetAsync("gateAny"); @@ -163,7 +163,7 @@ public async Task GatesFeatures() // // Enable no - testFeatureFilter.Callback = ctx => false; + testFeatureFilter.Callback = ctx => Task.FromResult(false); gateAllResponse = await testServer.CreateClient().GetAsync("gateAll"); gateAnyResponse = await testServer.CreateClient().GetAsync("gateAny"); @@ -555,7 +555,7 @@ public async Task CustomFeatureDefinitionProvider() Assert.Equal(ConditionalFeature, evaluationContext.FeatureName); - return true; + return Task.FromResult(true); }; await featureManager.IsEnabledAsync(ConditionalFeature); @@ -563,6 +563,58 @@ public async Task CustomFeatureDefinitionProvider() Assert.True(called); } + [Fact] + public async Task ThreadsafeSnapshot() + { + IConfiguration config = new ConfigurationBuilder().AddJsonFile("appsettings.json").Build(); + + var services = new ServiceCollection(); + + services + .AddSingleton(config) + .AddFeatureManagement() + .AddFeatureFilter(); + + ServiceProvider serviceProvider = services.BuildServiceProvider(); + + IFeatureManager featureManager = serviceProvider.GetRequiredService(); + + IEnumerable featureFilters = serviceProvider.GetRequiredService>(); + + // + // Sync filter + TestFilter testFeatureFilter = (TestFilter)featureFilters.First(f => f is TestFilter); + + bool called = false; + + testFeatureFilter.Callback = async (evaluationContext) => + { + called = true; + + await Task.Delay(10); + + return new Random().Next(0, 100) > 50; + }; + + var tasks = new List>(); + + for (int i = 0; i < 1000; i++) + { + tasks.Add(featureManager.IsEnabledAsync(ConditionalFeature)); + } + + Assert.True(called); + + await Task.WhenAll(tasks); + + bool result = tasks.First().Result; + + foreach (Task t in tasks) + { + Assert.Equal(result, t.Result); + } + } + private static void DisableEndpointRouting(MvcOptions options) { #if NET5_0 || NETCOREAPP3_1 diff --git a/tests/Tests.FeatureManagement/TestFilter.cs b/tests/Tests.FeatureManagement/TestFilter.cs index ed236742..d3db4225 100644 --- a/tests/Tests.FeatureManagement/TestFilter.cs +++ b/tests/Tests.FeatureManagement/TestFilter.cs @@ -9,11 +9,11 @@ namespace Tests.FeatureManagement { class TestFilter : IFeatureFilter { - public Func Callback { get; set; } + public Func> Callback { get; set; } public Task EvaluateAsync(FeatureFilterEvaluationContext context) { - return Task.FromResult(Callback?.Invoke(context) ?? false); + return Callback?.Invoke(context) ?? Task.FromResult(false); } } }