diff --git a/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs b/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs index e74b26d4..73e9c762 100644 --- a/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs +++ b/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs @@ -43,6 +43,16 @@ public void WhenKeyIsRequestedItIsCreatedAndCached() result1.Should().Be(result2); } + [Fact] + public void WhenKeyIsRequestedWithArgItIsCreatedAndCached() + { + var result1 = cache.GetOrAdd(1, valueFactory.Create, 9); + var result2 = cache.GetOrAdd(1, valueFactory.Create, 17); + + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + [Fact] public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() { @@ -53,6 +63,16 @@ public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() result1.Should().Be(result2); } + [Fact] + public async Task WhenKeyIsRequestedWithArgItIsCreatedAndCachedAsync() + { + var result1 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 9).ConfigureAwait(false); + var result2 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 17).ConfigureAwait(false); + + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + [Fact] public void WhenItemsAddedExceedsCapacityItemsAreDiscarded() { @@ -852,11 +872,23 @@ public int Create(int key) return key; } + public int Create(int key, int arg) + { + timesCalled++; + return key + arg; + } + public Task CreateAsync(int key) { timesCalled++; return Task.FromResult(key); } + + public Task CreateAsync(int key, int arg) + { + timesCalled++; + return Task.FromResult(key + arg); + } } } } diff --git a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs index 9158d9a1..dc592913 100644 --- a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs +++ b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs @@ -246,7 +246,17 @@ public void WhenKeyIsRequestedItIsCreatedAndCached() } [Fact] - public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() + public void WhenKeyIsRequestedWithArgItIsCreatedAndCached() + { + var result1 = lru.GetOrAdd(1, valueFactory.Create, "x"); + var result2 = lru.GetOrAdd(1, valueFactory.Create, "y"); + + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + + [Fact] + public async Task WhenKeyIsRequestedItIsCreatedAndCachedAsync() { var result1 = await lru.GetOrAddAsync(1, valueFactory.CreateAsync).ConfigureAwait(false); var result2 = await lru.GetOrAddAsync(1, valueFactory.CreateAsync).ConfigureAwait(false); @@ -255,6 +265,16 @@ public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() result1.Should().Be(result2); } + [Fact] + public async Task WhenKeyIsRequestedWithArgItIsCreatedAndCachedAsync() + { + var result1 = await lru.GetOrAddAsync(1, valueFactory.CreateAsync, "x").ConfigureAwait(false); + var result2 = await lru.GetOrAddAsync(1, valueFactory.CreateAsync, "y").ConfigureAwait(false); + + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + [Fact] public void WhenDifferentKeysAreRequestedValueIsCreatedForEach() { diff --git a/BitFaster.Caching.UnitTests/Lru/ValueFactory.cs b/BitFaster.Caching.UnitTests/Lru/ValueFactory.cs index 1aa61229..18156cfd 100644 --- a/BitFaster.Caching.UnitTests/Lru/ValueFactory.cs +++ b/BitFaster.Caching.UnitTests/Lru/ValueFactory.cs @@ -15,10 +15,22 @@ public string Create(int key) return key.ToString(); } + public string Create(int key, TArg arg) + { + timesCalled++; + return $"{key}{arg}"; + } + public Task CreateAsync(int key) { timesCalled++; return Task.FromResult(key.ToString()); } + + public Task CreateAsync(int key, TArg arg) + { + timesCalled++; + return Task.FromResult($"{key}{arg}"); + } } } diff --git a/BitFaster.Caching/Lfu/ConcurrentLfu.cs b/BitFaster.Caching/Lfu/ConcurrentLfu.cs index b102bcee..1ca2bce5 100644 --- a/BitFaster.Caching/Lfu/ConcurrentLfu.cs +++ b/BitFaster.Caching/Lfu/ConcurrentLfu.cs @@ -194,6 +194,20 @@ public void Trim(int itemCount) } } + private bool TryAdd(K key, V value) + { + var node = new LfuNode(key, value); + + if (this.dictionary.TryAdd(key, node)) + { + AfterWrite(node); + return true; + } + + Disposer.Dispose(node.Value); + return false; + } + /// public V GetOrAdd(K key, Func valueFactory) { @@ -204,14 +218,38 @@ public V GetOrAdd(K key, Func valueFactory) return value; } - var node = new LfuNode(key, valueFactory(key)); - if (this.dictionary.TryAdd(key, node)) + value = valueFactory(key); + if (this.TryAdd(key, value)) { - AfterWrite(node); - return node.Value; + return value; + } + } + } + + /// + /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the + /// existing value if the key already exists. + /// + /// The type of an argument to pass into valueFactory. + /// The key of the element to add. + /// The factory function used to generate a value for the key. + /// An argument value to pass into valueFactory. + /// The value for the key. This will be either the existing value for the key if the key is already + /// in the cache, or the new value if the key was not in the cache. + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + { + while (true) + { + if (this.TryGet(key, out V value)) + { + return value; } - Disposer.Dispose(node.Value); + value = valueFactory(key, factoryArgument); + if (this.TryAdd(key, value)) + { + return value; + } } } @@ -225,14 +263,37 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) return value; } - var node = new LfuNode(key, await valueFactory(key).ConfigureAwait(false)); - if (this.dictionary.TryAdd(key, node)) + value = await valueFactory(key).ConfigureAwait(false); + if (this.TryAdd(key, value)) { - AfterWrite(node); - return node.Value; + return value; } + } + } - Disposer.Dispose(node.Value); + /// + /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the + /// existing value if the key already exists. + /// + /// The type of an argument to pass into valueFactory. + /// The key of the element to add. + /// The factory function used to asynchronously generate a value for the key. + /// An argument value to pass into valueFactory. + /// A task that represents the asynchronous GetOrAdd operation. + public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + { + while (true) + { + if (this.TryGet(key, out V value)) + { + return value; + } + + value = await valueFactory(key, factoryArgument).ConfigureAwait(false); + if (this.TryAdd(key, value)) + { + return value; + } } } diff --git a/BitFaster.Caching/Lru/ConcurrentLruCore.cs b/BitFaster.Caching/Lru/ConcurrentLruCore.cs index 493ad2c1..5c6fc95d 100644 --- a/BitFaster.Caching/Lru/ConcurrentLruCore.cs +++ b/BitFaster.Caching/Lru/ConcurrentLruCore.cs @@ -184,6 +184,21 @@ private bool GetOrDiscard(I item, out V value) return true; } + private bool TryAdd(K key, V value) + { + var newItem = this.itemPolicy.CreateItem(key, value); + + if (this.dictionary.TryAdd(key, newItem)) + { + this.hotQueue.Enqueue(newItem); + Cycle(Interlocked.Increment(ref counter.hot)); + return true; + } + + Disposer.Dispose(newItem.Value); + return false; + } + /// public V GetOrAdd(K key, Func valueFactory) { @@ -195,17 +210,41 @@ public V GetOrAdd(K key, Func valueFactory) } // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. - // This is identical logic in ConcurrentDictionary.GetOrAdd method. - var newItem = this.itemPolicy.CreateItem(key, valueFactory(key)); + value = valueFactory(key); - if (this.dictionary.TryAdd(key, newItem)) + if (TryAdd(key, value)) { - this.hotQueue.Enqueue(newItem); - Cycle(Interlocked.Increment(ref counter.hot)); - return newItem.Value; + return value; } + } + } - Disposer.Dispose(newItem.Value); + /// + /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the + /// existing value if the key already exists. + /// + /// The type of an argument to pass into valueFactory. + /// The key of the element to add. + /// The factory function used to generate a value for the key. + /// An argument value to pass into valueFactory. + /// The value for the key. This will be either the existing value for the key if the key is already + /// in the cache, or the new value if the key was not in the cache. + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + { + while (true) + { + if (this.TryGet(key, out var value)) + { + return value; + } + + // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. + value = valueFactory(key, factoryArgument); + + if (TryAdd(key, value)) + { + return value; + } } } @@ -221,16 +260,40 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. // This is identical logic in ConcurrentDictionary.GetOrAdd method. - var newItem = this.itemPolicy.CreateItem(key, await valueFactory(key).ConfigureAwait(false)); + value = await valueFactory(key).ConfigureAwait(false); - if (this.dictionary.TryAdd(key, newItem)) + if (TryAdd(key, value)) { - this.hotQueue.Enqueue(newItem); - Cycle(Interlocked.Increment(ref counter.hot)); - return newItem.Value; + return value; } + } + } - Disposer.Dispose(newItem.Value); + /// + /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the + /// existing value if the key already exists. + /// + /// The type of an argument to pass into valueFactory. + /// The key of the element to add. + /// The factory function used to asynchronously generate a value for the key. + /// An argument value to pass into valueFactory. + /// A task that represents the asynchronous GetOrAdd operation. + public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + { + while (true) + { + if (this.TryGet(key, out var value)) + { + return value; + } + + // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. + value = await valueFactory(key, factoryArgument).ConfigureAwait(false); + + if (TryAdd(key, value)) + { + return value; + } } }