diff --git a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs index 76603e23..ed2bf960 100644 --- a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs +++ b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs @@ -7,7 +7,10 @@ using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; - +using System.Collections.Concurrent; +using System.Reflection; +using System.Runtime.CompilerServices; + namespace BitFaster.Caching.UnitTests.Lru { public class ConcurrentLruTests @@ -1088,19 +1091,149 @@ public void WhenItemsAreTrimmedAnEventIsFired() } [Fact] - public async Task WhenItemsAreScannedInParallelCapacityIsNotExceeded() + public async Task WhenSoakConcurrentGetCacheEndsInConsistentState() { - await Threaded.Run(4, () => { - for (int i = 0; i < 100000; i++) - { - lru.GetOrAdd(i + 1, i =>i.ToString()); - } - }); + for (int i = 0; i < 10; i++) + { + await Threaded.Run(4, () => { + for (int i = 0; i < 100000; i++) + { + lru.GetOrAdd(i + 1, i =>i.ToString()); + } + }); - this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); - // allow +/- 1 variance for capacity - lru.Count.Should().BeCloseTo(9, 1); + // allow +/- 1 variance for capacity + lru.Count.Should().BeCloseTo(9, 1); + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetAsyncCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.RunAsync(4, async () => { + for (int i = 0; i < 100000; i++) + { + await lru.GetOrAddAsync(i + 1, i => Task.FromResult(i.ToString())); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + // allow +/- 1 variance for capacity + lru.Count.Should().BeCloseTo(9, 1); + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetWithArgCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.Run(4, () => { + for (int i = 0; i < 100000; i++) + { + // use the arg overload + lru.GetOrAdd(i + 1, (i, s) => i.ToString(), "Foo"); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + // allow +/- 1 variance for capacity + lru.Count.Should().BeCloseTo(9, 1); + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetAsyncWithArgCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.RunAsync(4, async () => { + for (int i = 0; i < 100000; i++) + { + // use the arg overload + await lru.GetOrAddAsync(i + 1, (i, s) => Task.FromResult(i.ToString()), "Foo"); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + // allow +/- 1 variance for capacity + lru.Count.Should().BeCloseTo(9, 1); + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetAndRemoveCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.Run(4, () => { + for (int i = 0; i < 100000; i++) + { + lru.TryRemove(i + 1); + lru.GetOrAdd(i + 1, i => i.ToString()); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetAndUpdateCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.Run(4, () => { + for (int i = 0; i < 100000; i++) + { + lru.TryUpdate(i + 1, i.ToString()); + lru.GetOrAdd(i + 1, i => i.ToString()); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + RunIntegrityCheck(); + } + } + + [Fact] + public async Task WhenSoakConcurrentGetAndAddCacheEndsInConsistentState() + { + for (int i = 0; i < 10; i++) + { + await Threaded.Run(4, () => { + for (int i = 0; i < 100000; i++) + { + lru.AddOrUpdate(i + 1, i.ToString()); + lru.GetOrAdd(i + 1, i => i.ToString()); + } + }); + + this.testOutputHelper.WriteLine($"{lru.HotCount} {lru.WarmCount} {lru.ColdCount}"); + this.testOutputHelper.WriteLine(string.Join(" ", lru.Keys)); + + RunIntegrityCheck(); + } } private void Warmup() @@ -1115,5 +1248,75 @@ private void Warmup() lru.GetOrAdd(-8, valueFactory.Create); lru.GetOrAdd(-9, valueFactory.Create); } + + private void RunIntegrityCheck() + { + new ConcurrentLruIntegrityChecker, LruPolicy, TelemetryPolicy>(this.lru).Validate(); + } + } + + public class ConcurrentLruIntegrityChecker + where I : LruItem + where P : struct, IItemPolicy + where T : struct, ITelemetryPolicy + { + private readonly ConcurrentLruCore cache; + + private readonly ConcurrentQueue hotQueue; + private readonly ConcurrentQueue warmQueue; + private readonly ConcurrentQueue coldQueue; + + private static FieldInfo hotQueueField = typeof(ConcurrentLruCore).GetField("hotQueue", BindingFlags.NonPublic | BindingFlags.Instance); + private static FieldInfo warmQueueField = typeof(ConcurrentLruCore).GetField("warmQueue", BindingFlags.NonPublic | BindingFlags.Instance); + private static FieldInfo coldQueueField = typeof(ConcurrentLruCore).GetField("coldQueue", BindingFlags.NonPublic | BindingFlags.Instance); + + public ConcurrentLruIntegrityChecker(ConcurrentLruCore cache) + { + this.cache = cache; + + // get queues via reflection + this.hotQueue = (ConcurrentQueue)hotQueueField.GetValue(cache); + this.warmQueue = (ConcurrentQueue)warmQueueField.GetValue(cache); + this.coldQueue = (ConcurrentQueue)coldQueueField.GetValue(cache); + } + + public void Validate() + { + // queue counters must be consistent with queues + this.hotQueue.Count.Should().Be(cache.HotCount, "hot queue has a corrupted count"); + this.warmQueue.Count.Should().Be(cache.WarmCount, "warm queue has a corrupted count"); + this.coldQueue.Count.Should().Be(cache.ColdCount, "cold queue has a corrupted count"); + + // cache contents must be consistent with queued items + ValidateQueue(cache, this.hotQueue, "hot"); + ValidateQueue(cache, this.warmQueue, "warm"); + ValidateQueue(cache, this.coldQueue, "cold"); + + // cache must be within capacity + cache.Count.Should().BeLessThanOrEqualTo(cache.Capacity + 1, "capacity out of valid range"); + } + + private void ValidateQueue(ConcurrentLruCore cache, ConcurrentQueue queue, string queueName) + { + foreach (var item in queue) + { + if (item.WasRemoved) + { + // It is possible for the queues to contain 2 (or more) instances of the same key/item. One that was removed, + // and one that was added after the other was removed. + // In this case, the dictionary may contain the value only if the queues contain an entry for that key marked as WasRemoved == false. + if (cache.TryGet(item.Key, out var value)) + { + hotQueue.Union(warmQueue).Union(coldQueue) + .Any(i => i.Key.Equals(item.Key) && !i.WasRemoved) + .Should().BeTrue($"{queueName} removed item {item.Key} was not removed"); + } + } + else + { + cache.TryGet(item.Key, out var value).Should().BeTrue($"{queueName} item {item.Key} was not present"); + } + } + } } } diff --git a/BitFaster.Caching.UnitTests/Threaded.cs b/BitFaster.Caching.UnitTests/Threaded.cs index 50e3f10b..a3f575f5 100644 --- a/BitFaster.Caching.UnitTests/Threaded.cs +++ b/BitFaster.Caching.UnitTests/Threaded.cs @@ -33,5 +33,30 @@ public static async Task Run(int threadCount, Action action) await Task.WhenAll(tasks); } + + public static Task RunAsync(int threadCount, Func action) + { + return Run(threadCount, i => action()); + } + + public static async Task RunAsync(int threadCount, Func action) + { + var tasks = new Task[threadCount]; + ManualResetEvent mre = new ManualResetEvent(false); + + for (int i = 0; i < threadCount; i++) + { + int run = i; + tasks[i] = Task.Run(async () => + { + mre.WaitOne(); + await action(run); + }); + } + + mre.Set(); + + await Task.WhenAll(tasks); + } } }