diff --git a/BitFaster.Caching.UnitTests/Atomic/ConcurrentDictionaryExtensionTests.cs b/BitFaster.Caching.UnitTests/Atomic/ConcurrentDictionaryExtensionTests.cs new file mode 100644 index 00000000..f5076a20 --- /dev/null +++ b/BitFaster.Caching.UnitTests/Atomic/ConcurrentDictionaryExtensionTests.cs @@ -0,0 +1,62 @@ + +using System.Collections.Concurrent; +using System.Collections.Generic; +using BitFaster.Caching.Atomic; +using FluentAssertions; +using Xunit; + +namespace BitFaster.Caching.UnitTests.Atomic +{ + public class ConcurrentDictionaryExtensionTests + { + private ConcurrentDictionary> dictionary = new ConcurrentDictionary>(); + + [Fact] + public void WhenItemIsAddedItCanBeRetrieved() + { + dictionary.GetOrAdd(1, k => k); + + dictionary.TryGetValue(1, out int value).Should().BeTrue(); + value.Should().Be(1); + } + + [Fact] + public void WhenItemIsAddedWithArgItCanBeRetrieved() + { + dictionary.GetOrAdd(1, (k,a) => k + a, 2); + + dictionary.TryGetValue(1, out int value).Should().BeTrue(); + value.Should().Be(3); + } + + [Fact] + public void WhenKeyDoesNotExistTryGetReturnsFalse() + { + dictionary.TryGetValue(1, out int _).Should().BeFalse(); + } + + [Fact] + public void WhenItemIsAddedItCanBeRemovedByKey() + { + dictionary.GetOrAdd(1, k => k); + + dictionary.TryRemove(1, out int value).Should().BeTrue(); + value.Should().Be(1); + } + + [Fact] + public void WhenItemIsAddedItCanBeRemovedByKvp() + { + dictionary.GetOrAdd(1, k => k); + + dictionary.TryRemove(new KeyValuePair(1, 1)).Should().BeTrue(); + dictionary.TryGetValue(1, out _).Should().BeFalse(); + } + + [Fact] + public void WhenKeyDoesNotExistTryRemoveReturnsFalse() + { + dictionary.TryRemove(1, out int _).Should().BeFalse(); + } + } +} diff --git a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs index 6af07f8c..0433b682 100644 --- a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs +++ b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs @@ -2,7 +2,6 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; -using System.Linq.Expressions; namespace BitFaster.Caching.Atomic { diff --git a/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs b/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs new file mode 100644 index 00000000..333b8339 --- /dev/null +++ b/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs @@ -0,0 +1,97 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace BitFaster.Caching.Atomic +{ + /// + /// Convenience methods for using AtomicFactory with ConcurrentDictionary. + /// + public static class ConcurrentDictionaryExtensions + { + /// + /// Adds a key/value pair to the ConcurrentDictionary if the key does not already exist. Returns the new value, or the existing value if the key already exists. + /// + /// The ConcurrentDictionary to use. + /// The key of the element to add. + /// The function used to generate a value for the key. + /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary. + public static V GetOrAdd(this ConcurrentDictionary> dictionary, K key, Func valueFactory) + { + var atomicFactory = dictionary.GetOrAdd(key, _ => new AtomicFactory()); + return atomicFactory.GetValue(key, valueFactory); + } + + /// + /// Adds a key/value pair to the ConcurrentDictionary by using the specified function and an argument if the key does not already exist, or returns the existing value if the key exists. + /// + /// The ConcurrentDictionary to use. + /// The key of the element to add. + /// The function used to generate a value for the key. + /// An argument value to pass into valueFactory. + /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary. + public static V GetOrAdd(this ConcurrentDictionary> dictionary, K key, Func valueFactory, TArg factoryArgument) + { + var atomicFactory = dictionary.GetOrAdd(key, _ => new AtomicFactory()); + return atomicFactory.GetValue(key, valueFactory, factoryArgument); + } + + /// + /// Attempts to get the value associated with the specified key from the ConcurrentDictionary. + /// + /// The ConcurrentDictionary to use. + /// The key of the value to get. + /// When this method returns, contains the object from the ConcurrentDictionary that has the specified key, or the default value of the type if the operation failed. + /// true if the key was found in the ConcurrentDictionary; otherwise, false. + public static bool TryGetValue(this ConcurrentDictionary> dictionary, K key, out V value) + { + AtomicFactory output; + var ret = dictionary.TryGetValue(key, out output); + + if (ret && output.IsValueCreated) + { + value = output.ValueIfCreated; + return true; + } + + value = default; + return false; + } + + /// + /// Removes a key and value from the dictionary. + /// + /// The ConcurrentDictionary to use. + /// The KeyValuePair representing the key and value to remove. + /// true if the object was removed successfully; otherwise, false. + public static bool TryRemove(this ConcurrentDictionary> dictionary, KeyValuePair item) + { + var kvp = new KeyValuePair>(item.Key, new AtomicFactory(item.Value)); +#if NET6_0_OR_GREATER + return dictionary.TryRemove(kvp); +#else + // https://devblogs.microsoft.com/pfxteam/little-known-gems-atomic-conditional-removals-from-concurrentdictionary/ + return ((ICollection>>)dictionary).Remove(kvp); +#endif + } + + /// + /// Attempts to remove and return the value that has the specified key from the ConcurrentDictionary. + /// + /// The ConcurrentDictionary to use. + /// The key of the element to remove and return. + /// When this method returns, contains the object removed from the ConcurrentDictionary, or the default value of the TValue type if key does not exist. + /// true if the object was removed successfully; otherwise, false. + public static bool TryRemove(this ConcurrentDictionary> dictionary, K key, out V value) + { + if (dictionary.TryRemove(key, out var atomic)) + { + value = atomic.ValueIfCreated; + return true; + } + + value = default; + return false; + } + } +}