diff --git a/offload/include/OpenMP/InteropAPI.h b/offload/include/OpenMP/InteropAPI.h index 8c06ba36fc3f3..02e097eab5099 100644 --- a/offload/include/OpenMP/InteropAPI.h +++ b/offload/include/OpenMP/InteropAPI.h @@ -160,17 +160,11 @@ struct InteropTableEntry { Interops.push_back(obj); } - template void clear(ClearFuncTy f) { - for (auto &Obj : Interops) { - f(Obj); - } - } - /// vector interface int size() const { return Interops.size(); } iterator begin() { return Interops.begin(); } iterator end() { return Interops.end(); } - iterator erase(iterator it) { return Interops.erase(it); } + void clear() { Interops.clear(); } }; struct InteropTblTy diff --git a/offload/include/PerThreadTable.h b/offload/include/PerThreadTable.h index 45b196171b4c8..d8222d99b6515 100644 --- a/offload/include/PerThreadTable.h +++ b/offload/include/PerThreadTable.h @@ -14,8 +14,66 @@ #define OFFLOAD_PERTHREADTABLE_H #include +#include +#include #include #include +#include + +template struct PerThread { + struct PerThreadData { + std::unique_ptr ThreadEntry; + }; + + std::mutex Mutex; + llvm::SmallVector> ThreadDataList; + + // define default constructors, disable copy and move constructors + PerThread() = default; + PerThread(const PerThread &) = delete; + PerThread(PerThread &&) = delete; + PerThread &operator=(const PerThread &) = delete; + PerThread &operator=(PerThread &&) = delete; + ~PerThread() { + assert(Mutex.try_lock() && + "Cannot be deleted while other threads are adding entries"); + ThreadDataList.clear(); + } + +private: + PerThreadData &getThreadData() { + static thread_local std::shared_ptr ThreadData = nullptr; + if (!ThreadData) { + ThreadData = std::make_shared(); + std::lock_guard Lock(Mutex); + ThreadDataList.push_back(ThreadData); + } + return *ThreadData; + } + +protected: + ObjectType &getThreadEntry() { + PerThreadData &ThreadData = getThreadData(); + if (ThreadData.ThreadEntry) + return *ThreadData.ThreadEntry; + ThreadData.ThreadEntry = std::make_unique(); + return *ThreadData.ThreadEntry; + } + +public: + ObjectType &get() { return getThreadEntry(); } + + template void clear(ClearFuncTy ClearFunc) { + assert(Mutex.try_lock() && + "Clear cannot be called while other threads are adding entries"); + for (std::shared_ptr ThreadData : ThreadDataList) { + if (!ThreadData->ThreadEntry) + continue; + ClearFunc(*ThreadData->ThreadEntry); + } + ThreadDataList.clear(); + } +}; // Using an STL container (such as std::vector) indexed by thread ID has // too many race conditions issues so we store each thread entry into a @@ -23,17 +81,39 @@ // T is the container type used to store the objects, e.g., std::vector, // std::set, etc. by each thread. O is the type of the stored objects e.g., // omp_interop_val_t *, ... - template struct PerThreadTable { using iterator = typename ContainerType::iterator; + template > + struct has_iterator : std::false_type {}; + template + struct has_iterator> : std::true_type {}; + + template > + struct has_clear : std::false_type {}; + template + struct has_clear().clear())>> + : std::true_type {}; + + template > + struct has_clearAll : std::false_type {}; + template + struct has_clearAll().clearAll(1))>> + : std::true_type {}; + + template > + struct is_associative : std::false_type {}; + template + struct is_associative> + : std::true_type {}; + struct PerThreadData { size_t NElements = 0; - std::unique_ptr ThEntry; + std::unique_ptr ThreadEntry; }; - std::mutex Mtx; - std::list> ThreadDataList; + std::mutex Mutex; + llvm::SmallVector> ThreadDataList; // define default constructors, disable copy and move constructors PerThreadTable() = default; @@ -42,46 +122,52 @@ template struct PerThreadTable { PerThreadTable &operator=(const PerThreadTable &) = delete; PerThreadTable &operator=(PerThreadTable &&) = delete; ~PerThreadTable() { - std::lock_guard Lock(Mtx); + assert(Mutex.try_lock() && + "Cannot be deleted while other threads are adding entries"); ThreadDataList.clear(); } private: PerThreadData &getThreadData() { - static thread_local std::shared_ptr ThData = nullptr; - if (!ThData) { - ThData = std::make_shared(); - std::lock_guard Lock(Mtx); - ThreadDataList.push_back(ThData); + static thread_local std::shared_ptr ThreadData = nullptr; + if (!ThreadData) { + ThreadData = std::make_shared(); + std::lock_guard Lock(Mutex); + ThreadDataList.push_back(ThreadData); } - return *ThData; + return *ThreadData; } protected: ContainerType &getThreadEntry() { - auto &ThData = getThreadData(); - if (ThData.ThEntry) - return *ThData.ThEntry; - ThData.ThEntry = std::make_unique(); - return *ThData.ThEntry; + PerThreadData &ThreadData = getThreadData(); + if (ThreadData.ThreadEntry) + return *ThreadData.ThreadEntry; + ThreadData.ThreadEntry = std::make_unique(); + return *ThreadData.ThreadEntry; } size_t &getThreadNElements() { - auto &ThData = getThreadData(); - return ThData.NElements; + PerThreadData &ThreadData = getThreadData(); + return ThreadData.NElements; + } + + void setNElements(size_t Size) { + size_t &NElements = getThreadNElements(); + NElements = Size; } public: void add(ObjectType obj) { - auto &Entry = getThreadEntry(); - auto &NElements = getThreadNElements(); + ContainerType &Entry = getThreadEntry(); + size_t &NElements = getThreadNElements(); NElements++; Entry.add(obj); } iterator erase(iterator it) { - auto &Entry = getThreadEntry(); - auto &NElements = getThreadNElements(); + ContainerType &Entry = getThreadEntry(); + size_t &NElements = getThreadNElements(); NElements--; return Entry.erase(it); } @@ -91,24 +177,114 @@ template struct PerThreadTable { // Iterators to traverse objects owned by // the current thread iterator begin() { - auto &Entry = getThreadEntry(); + ContainerType &Entry = getThreadEntry(); return Entry.begin(); } iterator end() { - auto &Entry = getThreadEntry(); + ContainerType &Entry = getThreadEntry(); return Entry.end(); } - template void clear(F f) { - std::lock_guard Lock(Mtx); - for (auto ThData : ThreadDataList) { - if (!ThData->ThEntry || ThData->NElements == 0) + template void clear(ClearFuncTy ClearFunc) { + assert(Mutex.try_lock() && + "Clear cannot be called while other threads are adding entries"); + for (std::shared_ptr ThreadData : ThreadDataList) { + if (!ThreadData->ThreadEntry || ThreadData->NElements == 0) continue; - ThData->ThEntry->clear(f); - ThData->NElements = 0; + if constexpr (has_clearAll::value) { + ThreadData->ThreadEntry->clearAll(ClearFunc); + } else if constexpr (has_iterator::value && + has_clear::value) { + for (auto &Obj : *ThreadData->ThreadEntry) { + if constexpr (is_associative::value) { + ClearFunc(Obj.second); + } else { + ClearFunc(Obj); + } + } + ThreadData->ThreadEntry->clear(); + } else { + static_assert(true, "Container type not supported"); + } + ThreadData->NElements = 0; } ThreadDataList.clear(); } + + template llvm::Error deinit(DeinitFuncTy DeinitFunc) { + assert(Mutex.try_lock() && + "Deinit cannot be called while other threads are adding entries"); + for (std::shared_ptr ThreadData : ThreadDataList) { + if (!ThreadData->ThreadEntry || ThreadData->NElements == 0) + continue; + for (auto &Obj : *ThreadData->ThreadEntry) { + if constexpr (is_associative::value) { + if (auto Err = DeinitFunc(Obj.second)) + return Err; + } else { + if (auto Err = DeinitFunc(Obj)) + return Err; + } + } + } + return llvm::Error::success(); + } +}; + +template > struct ContainerValueType { + using type = typename T::value_type; +}; +template +struct ContainerValueType> { + using type = typename T::mapped_type; +}; + +template +struct PerThreadContainer + : public PerThreadTable::type> { + + // helpers + template > struct indexType { + using type = typename T::size_type; + }; + template struct indexType> { + using type = typename T::key_type; + }; + template > + struct has_resize : std::false_type {}; + template + struct has_resize().resize(1))>> + : std::true_type {}; + + template > + struct has_reserve : std::false_type {}; + template + struct has_reserve().reserve(1))>> + : std::true_type {}; + + using IndexType = typename indexType::type; + using ObjectType = typename ContainerValueType::type; + + // Get the object for the given index in the current thread + ObjectType &get(IndexType Index) { + ContainerType &Entry = this->getThreadEntry(); + + // specialized code for vector-like containers + if constexpr (has_resize::value) { + if (Index >= Entry.size()) { + if constexpr (has_reserve::value && reserveSize > 0) { + if (Entry.capacity() < reserveSize) + Entry.reserve(reserveSize); + } + // If the index is out of bounds, try resize the container + Entry.resize(Index + 1); + } + } + ObjectType &Ret = Entry[Index]; + this->setNElements(Entry.size()); + return Ret; + } }; #endif