Skip to content

Commit

Permalink
Revert "Fix tsan problem where the per-thread shared_ptr() can be loc…
Browse files Browse the repository at this point in the history
…ked right before the cache is destroyed causing a race where it tries to remove an entry from a destroyed cache."

This reverts commit bcc1081.

Reason: Broke the aarch64-asan bot. More information available in the
Phabricator review: https://reviews.llvm.org/D140931
  • Loading branch information
hctim committed Feb 1, 2023
1 parent 5cff68f commit 44003d7
Showing 1 changed file with 27 additions and 38 deletions.
65 changes: 27 additions & 38 deletions mlir/include/mlir/Support/ThreadLocalCache.h
Expand Up @@ -25,40 +25,12 @@ namespace mlir {
/// cache has very large lock contention.
template <typename ValueT>
class ThreadLocalCache {
// Keep a separate shared_ptr protected state that can be acquired atomically
// instead of using shared_ptr's for each value. This avoids a problem
// where the instance shared_ptr is locked() successfully, and then the
// ThreadLocalCache gets destroyed before remove() can be called successfully.
struct PerInstanceState {
/// Remove the given value entry. This is generally called when a thread
/// local cache is destructing.
void remove(ValueT *value) {
// Erase the found value directly, because it is guaranteed to be in the
// list.
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
auto it =
llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
return instance.get() == value;
});
assert(it != instances.end() && "expected value to exist in cache");
instances.erase(it);
}

/// Owning pointers to all of the values that have been constructed for this
/// object in the static cache.
SmallVector<std::unique_ptr<ValueT>, 1> instances;

/// A mutex used when a new thread instance has been added to the cache for
/// this object.
llvm::sys::SmartMutex<true> instanceMutex;
};

/// The type used for the static thread_local cache. This is a map between an
/// instance of the non-static cache and a weak reference to an instance of
/// ValueT. We use a weak reference here so that the object can be destroyed
/// without needing to lock access to the cache itself.
struct CacheType
: public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
std::weak_ptr<ValueT>> {
~CacheType() {
// Remove the values of this cache that haven't already expired.
for (auto &it : *this)
Expand Down Expand Up @@ -88,16 +60,15 @@ class ThreadLocalCache {
ValueT &get() {
// Check for an already existing instance for this thread.
CacheType &staticCache = getStaticCache();
std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
std::weak_ptr<ValueT> &threadInstance = staticCache[this];
if (std::shared_ptr<ValueT> value = threadInstance.lock())
return *value;

// Otherwise, create a new instance for this thread.
llvm::sys::SmartScopedLock<true> threadInstanceLock(
perInstanceState->instanceMutex);
perInstanceState->instances.push_back(std::make_unique<ValueT>());
ValueT *instance = perInstanceState->instances.back().get();
threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
instances.push_back(std::make_shared<ValueT>());
std::shared_ptr<ValueT> &instance = instances.back();
threadInstance = instance;

// Before returning the new instance, take the chance to clear out any used
// entries in the static map. The cache is only cleared within the same
Expand All @@ -119,8 +90,26 @@ class ThreadLocalCache {
return cache;
}

std::shared_ptr<PerInstanceState> perInstanceState =
std::make_shared<PerInstanceState>();
/// Remove the given value entry. This is generally called when a thread local
/// cache is destructing.
void remove(ValueT *value) {
// Erase the found value directly, because it is guaranteed to be in the
// list.
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
return instance.get() == value;
});
assert(it != instances.end() && "expected value to exist in cache");
instances.erase(it);
}

/// Owning pointers to all of the values that have been constructed for this
/// object in the static cache.
SmallVector<std::shared_ptr<ValueT>, 1> instances;

/// A mutex used when a new thread instance has been added to the cache for
/// this object.
llvm::sys::SmartMutex<true> instanceMutex;
};
} // namespace mlir

Expand Down

0 comments on commit 44003d7

Please sign in to comment.