diff --git a/libc/src/__support/threads/linux/CMakeLists.txt b/libc/src/__support/threads/linux/CMakeLists.txt index 364e7e2b90585..ba25f9b6a75fc 100644 --- a/libc/src/__support/threads/linux/CMakeLists.txt +++ b/libc/src/__support/threads/linux/CMakeLists.txt @@ -69,6 +69,7 @@ add_header_library( .futex_utils .raw_mutex libc.src.__support.threads.mutex_common + libc.src.__support.threads.identifier ) add_object_library( diff --git a/libc/src/__support/threads/linux/mutex.h b/libc/src/__support/threads/linux/mutex.h index 0c4b1ae09af6f..e161159ec06e5 100644 --- a/libc/src/__support/threads/linux/mutex.h +++ b/libc/src/__support/threads/linux/mutex.h @@ -13,6 +13,7 @@ #include "src/__support/CPP/optional.h" #include "src/__support/libc_assert.h" #include "src/__support/macros/config.h" +#include "src/__support/threads/identifier.h" #include "src/__support/threads/linux/futex_utils.h" #include "src/__support/threads/linux/raw_mutex.h" #include "src/__support/threads/mutex_common.h" @@ -31,6 +32,25 @@ class Mutex final : private RawMutex { pid_t owner; unsigned long long lock_count; + LIBC_INLINE bool increase_lock_count() { + if (__builtin_add_overflow(this->lock_count, 1, &this->lock_count)) + return false; + return true; + } + + LIBC_INLINE MutexError recursive_lock() { + if (!this->recursive) + return MutexError::DEADLOCK; + if (increase_lock_count()) + return MutexError::NONE; + return MutexError::MAX_RECURSION; + } + + LIBC_INLINE void post_raw_lock() { + this->lock_count++; + this->owner = internal::gettid(); + } + public: LIBC_INLINE constexpr Mutex(bool is_timed, bool is_recursive, bool is_robust, bool is_pshared) @@ -56,29 +76,50 @@ class Mutex final : private RawMutex { return MutexError::NONE; } - // TODO: record owner and lock count. LIBC_INLINE MutexError lock() { + if (this->owner == internal::gettid()) + return recursive_lock(); // Since timeout is not specified, we do not need to check the return value. this->RawMutex::lock( /* timeout=*/cpp::nullopt, this->pshared); + post_raw_lock(); return MutexError::NONE; } - // TODO: record owner and lock count. LIBC_INLINE MutexError timed_lock(internal::AbsTimeout abs_time) { - if (this->RawMutex::lock(abs_time, this->pshared)) + if (this->owner == internal::gettid()) + return recursive_lock(); + if (this->RawMutex::lock(abs_time, this->pshared)) { + post_raw_lock(); return MutexError::NONE; + } return MutexError::TIMEOUT; } LIBC_INLINE MutexError unlock() { - if (this->RawMutex::unlock(this->pshared)) + if (this->owner != internal::gettid()) + return MutexError::UNLOCK_WITHOUT_LOCK; + + if (this->lock_count > 1) { + this->lock_count--; return MutexError::NONE; - return MutexError::UNLOCK_WITHOUT_LOCK; + } + + // no longer holding the lock + if (this->RawMutex::unlock(this->pshared)) { + this->lock_count--; + this->owner = 0; + return MutexError::NONE; + } + + // memory corrupted + return MutexError::BAD_LOCK_STATE; } // TODO: record owner and lock count. LIBC_INLINE MutexError try_lock() { + if (this->owner == internal::gettid()) + return recursive_lock(); if (this->RawMutex::try_lock()) return MutexError::NONE; return MutexError::BUSY; diff --git a/libc/src/__support/threads/mutex_common.h b/libc/src/__support/threads/mutex_common.h index 9913f69a6a61a..c849420d0d50f 100644 --- a/libc/src/__support/threads/mutex_common.h +++ b/libc/src/__support/threads/mutex_common.h @@ -19,6 +19,8 @@ enum class MutexError : int { TIMEOUT, UNLOCK_WITHOUT_LOCK, BAD_LOCK_STATE, + DEADLOCK, + MAX_RECURSION, }; } // namespace LIBC_NAMESPACE_DECL