diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index d0a6170805bfd..a3429ac14e56e 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -231,28 +231,6 @@ class StorageUniquer { return mutateImpl(id, mutationFn); } - /// Erases a uniqued instance of 'Storage'. This function is used for derived - /// types that have complex storage or uniquing constraints. - template - void erase(TypeID id, Arg &&arg, Args &&...args) { - // Construct a value of the derived key type. - auto derivedKey = - getKey(std::forward(arg), std::forward(args)...); - - // Create a hash of the derived key. - unsigned hashValue = getHash(derivedKey); - - // Generate an equality function for the derived storage. - auto isEqual = [&derivedKey](const BaseStorage *existing) { - return static_cast(*existing) == derivedKey; - }; - - // Attempt to erase the storage instance. - eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) { - static_cast(storage)->cleanup(); - }); - } - private: /// Implementation for getting/creating an instance of a derived type with /// parametric storage. @@ -275,12 +253,6 @@ class StorageUniquer { registerSingletonImpl(TypeID id, function_ref ctorFn); - /// Implementation for erasing an instance of a derived type with complex - /// storage. - void eraseImpl(TypeID id, unsigned hashValue, - function_ref isEqual, - function_ref cleanupFn); - /// Implementation for mutating an instance of a derived storage. LogicalResult mutateImpl(TypeID id, diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h new file mode 100644 index 0000000000000..3b5d6f0f424f9 --- /dev/null +++ b/mlir/include/mlir/Support/ThreadLocalCache.h @@ -0,0 +1,117 @@ +//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a definition of the ThreadLocalCache class. This class +// provides support for defining thread local objects with non-static duration. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H +#define MLIR_SUPPORT_THREADLOCALCACHE_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Support/ThreadLocal.h" + +namespace mlir { +/// This class provides support for defining a thread local object with non +/// static storage duration. This is very useful for situations in which a data +/// cache has very large lock contention. +template +class ThreadLocalCache { + /// The type used for the static thread_local cache. This is a map between an + /// instance of the non-static cache and a weak reference to an instance of + /// ValueT. We use a weak reference here so that the object can be destroyed + /// without needing to lock access to the cache itself. + struct CacheType : public llvm::SmallDenseMap *, + std::weak_ptr> { + ~CacheType() { + // Remove the values of this cache that haven't already expired. + for (auto &it : *this) + if (std::shared_ptr value = it.second.lock()) + it.first->remove(value.get()); + } + + /// Clear out any unused entries within the map. This method is not + /// thread-safe, and should only be called by the same thread as the cache. 
+ void clearExpiredEntries() { + for (auto it = this->begin(), e = this->end(); it != e;) { + auto curIt = it++; + if (curIt->second.expired()) + this->erase(curIt); + } + } + }; + +public: + ThreadLocalCache() = default; + ~ThreadLocalCache() { + // No cleanup is necessary here as the shared_pointer memory will go out of + // scope and invalidate the weak pointers held by the thread_local caches. + } + + /// Return an instance of the value type for the current thread. + ValueT &get() { + // Check for an already existing instance for this thread. + CacheType &staticCache = getStaticCache(); + std::weak_ptr &threadInstance = staticCache[this]; + if (std::shared_ptr value = threadInstance.lock()) + return *value; + + // Otherwise, create a new instance for this thread. + llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); + instances.push_back(std::make_shared()); + std::shared_ptr &instance = instances.back(); + threadInstance = instance; + + // Before returning the new instance, take the chance to clear out any expired + // entries in the static map. The cache is only cleared within the same + // thread to remove the need to lock the cache itself. + staticCache.clearExpiredEntries(); + return *instance; + } + ValueT &operator*() { return get(); } + ValueT *operator->() { return &get(); } + +private: + ThreadLocalCache(ThreadLocalCache &&) = delete; + ThreadLocalCache(const ThreadLocalCache &) = delete; + ThreadLocalCache &operator=(const ThreadLocalCache &) = delete; + + /// Return the static thread local instance of the cache type. + static CacheType &getStaticCache() { + static LLVM_THREAD_LOCAL CacheType cache; + return cache; + } + + /// Remove the given value entry. This is generally called when a thread local + /// cache is destructing. + void remove(ValueT *value) { + // Erase the found value directly, because it is guaranteed to be in the + // list. 
+ llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); + auto it = llvm::find_if(instances, [&](std::shared_ptr &instance) { + return instance.get() == value; + }); + assert(it != instances.end() && "expected value to exist in cache"); + instances.erase(it); + } + + /// Owning pointers to all of the values that have been constructed for this + /// object in the static cache. + SmallVector, 1> instances; + + /// A mutex used when a new thread instance has been added to the cache for + /// this object. + llvm::sys::SmartMutex instanceMutex; +}; +} // end namespace mlir + +#endif // MLIR_SUPPORT_THREADLOCALCACHE_H diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 7fffb51a1d1a3..7551bb929970e 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" +#include "mlir/Support/ThreadLocalCache.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -291,8 +292,12 @@ class MLIRContextImpl { /// operations. llvm::StringMap registeredOperations; - /// These are identifiers uniqued into this MLIRContext. + /// Identifiers are uniqued by string value and use the internal string set for + /// storage. llvm::StringSet identifiers; + /// A thread local cache of identifiers to reduce lock contention. + ThreadLocalCache *>> + localIdentifierCache; /// An allocator used for AbstractAttribute and AbstractType objects. llvm::BumpPtrAllocator abstractDialectSymbolAllocator; @@ -703,16 +708,6 @@ const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { /// Return an identifier for the specified string. Identifier Identifier::get(StringRef str, MLIRContext *context) { - auto &impl = context->getImpl(); - - // Check for an existing identifier in read-only mode. 
- if (context->isMultithreadingEnabled()) { - llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); - auto it = impl.identifiers.find(str); - if (it != impl.identifiers.end()) - return Identifier(&*it); - } - // Check invariants after seeing if we already have something in the // identifier table - if we already had it in the table, then it already // passed invariant checks. @@ -720,10 +715,30 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) { assert(str.find('\0') == StringRef::npos && "Cannot create an identifier with a nul character"); + auto &impl = context->getImpl(); + if (!context->isMultithreadingEnabled()) + return Identifier(&*impl.identifiers.insert(str).first); + + // Check for an existing instance in the local cache. + auto *&localEntry = (*impl.localIdentifierCache)[str]; + if (localEntry) + return Identifier(localEntry); + + // Check for an existing identifier in read-only mode. + { + llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); + auto it = impl.identifiers.find(str); + if (it != impl.identifiers.end()) { + localEntry = &*it; + return Identifier(localEntry); + } + } + // Acquire a writer-lock so that we can safely create the new instance. 
- ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled); + llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); auto it = impl.identifiers.insert(str).first; - return Identifier(&*it); + localEntry = &*it; + return Identifier(localEntry); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index a3e296e99e738..8e0ef6b8f2765 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -9,6 +9,7 @@ #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/ThreadLocalCache.h" #include "mlir/Support/TypeID.h" #include "llvm/Support/RWMutex.h" @@ -37,6 +38,8 @@ struct ParametricStorageUniquer { /// A utility wrapper object representing a hashed storage object. This class /// contains a storage object and an existing computed hash value. struct HashedStorage { + HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr) + : hashValue(hashValue), storage(storage) {} unsigned hashValue; BaseStorage *storage; }; @@ -44,10 +47,10 @@ struct ParametricStorageUniquer { /// Storage info for derived TypeStorage objects. struct StorageKeyInfo : DenseMapInfo { static HashedStorage getEmptyKey() { - return HashedStorage{0, DenseMapInfo::getEmptyKey()}; + return HashedStorage(0, DenseMapInfo::getEmptyKey()); } static HashedStorage getTombstoneKey() { - return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; + return HashedStorage(0, DenseMapInfo::getTombstoneKey()); } static unsigned getHashValue(const HashedStorage &key) { @@ -70,6 +73,10 @@ struct ParametricStorageUniquer { using StorageTypeSet = DenseSet; StorageTypeSet instances; + /// A thread local cache for storage objects. This helps to reduce the lock + /// contention when an object already exists in the cache. 
+ ThreadLocalCache localCache; + /// Allocator to use when constructing derived instances. StorageAllocator allocator; @@ -104,25 +111,31 @@ struct StorageUniquerImpl { if (!threadingIsEnabled) return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn); + // Check for an instance of this object in the local cache. + auto localIt = storageUniquer.localCache->insert_as({hashValue}, lookupKey); + BaseStorage *&localInst = localIt.first->storage; + if (localInst) + return localInst; + // Check for an existing instance in read-only mode. { llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); auto it = storageUniquer.instances.find_as(lookupKey); if (it != storageUniquer.instances.end()) - return it->storage; + return localInst = it->storage; } // Acquire a writer-lock so that we can safely create the new type instance. llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); - return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn); + return localInst = getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn); } - /// Get or create an instance of a complex derived type in an thread-unsafe + /// Get or create an instance of a parametric derived type in a thread-unsafe /// fashion. BaseStorage * getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer, - ParametricStorageUniquer::LookupKey &lookupKey, + ParametricStorageUniquer::LookupKey &key, function_ref ctorFn) { - auto existing = storageUniquer.instances.insert_as({}, lookupKey); + auto existing = storageUniquer.instances.insert_as({key.hashValue}, key); if (!existing.second) return existing.first->storage; @@ -130,30 +143,10 @@ struct StorageUniquerImpl { // instance. BaseStorage *storage = ctorFn(storageUniquer.allocator); *existing.first = - ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage}; + ParametricStorageUniquer::HashedStorage{key.hashValue, storage}; return storage; } - /// Erase an instance of a parametric derived type. 
- void erase(TypeID id, unsigned hashValue, - function_ref isEqual, - function_ref cleanupFn) { - assert(parametricUniquers.count(id) && - "erasing unregistered storage instance"); - ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; - ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; - - // Acquire a writer-lock so that we can safely erase the type instance. - llvm::sys::SmartScopedWriter lock(storageUniquer.mutex); - auto existing = storageUniquer.instances.find_as(lookupKey); - if (existing == storageUniquer.instances.end()) - return; - - // Cleanup the storage and remove it from the map. - cleanupFn(existing->storage); - storageUniquer.instances.erase(existing); - } - /// Mutates an instance of a derived storage in a thread-safe way. LogicalResult mutate(TypeID id, @@ -252,14 +245,6 @@ void StorageUniquer::registerSingletonImpl( impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator)); } -/// Implementation for erasing an instance of a derived type with parametric -/// storage. -void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue, - function_ref isEqual, - function_ref cleanupFn) { - impl->erase(id, hashValue, isEqual, cleanupFn); -} - /// Implementation for mutating an instance of a derived storage. LogicalResult StorageUniquer::mutateImpl( TypeID id, function_ref mutationFn) {