Skip to content
8 changes: 1 addition & 7 deletions offload/include/OpenMP/InteropAPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,11 @@ struct InteropTableEntry {
Interops.push_back(obj);
}

template <class ClearFuncTy> void clear(ClearFuncTy f) {
for (auto &Obj : Interops) {
f(Obj);
}
}

/// Vector-like interface: thin forwards to the underlying Interops container.

/// Number of stored objects, narrowed to int.
int size() const { return static_cast<int>(Interops.size()); }

/// Iterator to the first stored object.
iterator begin() { return Interops.begin(); }

/// Past-the-end iterator.
iterator end() { return Interops.end(); }

/// Remove the element at \p it and return whatever the container's erase
/// returns.
iterator erase(iterator it) { return Interops.erase(it); }

/// Drop every stored object without invoking any callback on them.
void clear() { Interops.clear(); }
};

struct InteropTblTy
Expand Down
236 changes: 206 additions & 30 deletions offload/include/PerThreadTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,106 @@
#define OFFLOAD_PERTHREADTABLE_H

#include <list>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Error.h>
#include <memory>
#include <mutex>
#include <type_traits>

/// Holds one lazily created \c ObjectType per thread.
///
/// A thread reaches its own object through get(); the object is
/// default-constructed the first time that thread asks for it. Every thread
/// that has asked is also recorded in ThreadDataList so that clear() can visit
/// all per-thread objects.
template <typename ObjectType> struct PerThread {
  /// Per-thread record; ThreadEntry stays null until first requested.
  struct PerThreadData {
    std::unique_ptr<ObjectType> ThreadEntry;
  };

  /// Guards insertions into ThreadDataList.
  std::mutex Mutex;
  /// One shared record per thread that has registered.
  llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;

  // Default-constructible only; copy and move are disabled.
  PerThread() = default;
  PerThread(const PerThread &) = delete;
  PerThread(PerThread &&) = delete;
  PerThread &operator=(const PerThread &) = delete;
  PerThread &operator=(PerThread &&) = delete;
  ~PerThread() {
    // Review note: destroying a std::mutex that is still owned by a thread is
    // undefined behavior, and std::mutex has no "is it locked?" query. The
    // check therefore probes with try_lock() and releases immediately, so the
    // mutex is never left owned when it is destroyed.
    assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
           "Cannot be deleted while other threads are adding entries");
    ThreadDataList.clear();
  }

private:
  /// Return the calling thread's record, creating and registering it on the
  /// thread's first call.
  PerThreadData &getThreadData() {
    // NOTE(review): this is a function-local static, so on a given thread the
    // slot is shared by every PerThread<ObjectType> object of the same
    // ObjectType -- confirm that a single instance per type is intended.
    static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
    if (!ThreadData) {
      ThreadData = std::make_shared<PerThreadData>();
      std::lock_guard<std::mutex> Lock(Mutex);
      ThreadDataList.push_back(ThreadData);
    }
    return *ThreadData;
  }

protected:
  /// Return the calling thread's object, default-constructing it if absent.
  ObjectType &getThreadEntry() {
    PerThreadData &ThreadData = getThreadData();
    if (!ThreadData.ThreadEntry)
      ThreadData.ThreadEntry = std::make_unique<ObjectType>();
    return *ThreadData.ThreadEntry;
  }

public:
  /// Public accessor for the calling thread's object.
  ObjectType &get() { return getThreadEntry(); }

  /// Call \p ClearFunc on every per-thread object that has been created, then
  /// forget all registered thread records.
  ///
  /// Must not run while other threads are registering; this is only checked
  /// (not enforced) in builds with assertions enabled.
  template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
    // Probe-and-release: the mutex must not stay locked after the check,
    // otherwise the next thread that registers would block forever.
    assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
           "Clear cannot be called while other threads are adding entries");
    // Iterate by reference to avoid a shared_ptr refcount round-trip per
    // element.
    for (const std::shared_ptr<PerThreadData> &ThreadData : ThreadDataList) {
      if (ThreadData->ThreadEntry)
        ClearFunc(*ThreadData->ThreadEntry);
    }
    ThreadDataList.clear();
  }
};

// Using an STL container (such as std::vector) indexed by thread ID is prone
// to race conditions, so each thread's entry is stored in a thread_local
// variable instead.
// ContainerType is the per-thread container used to store the objects, e.g.,
// std::vector, std::set, etc. ObjectType is the type of the stored objects,
// e.g., omp_interop_val_t *, ...

template <typename ContainerType, typename ObjectType> struct PerThreadTable {
using iterator = typename ContainerType::iterator;

/// Detects whether a type declares a nested \c iterator type.
/// Primary template: selected when \c T::iterator is not a valid type.
template <class, class = void> struct has_iterator : std::false_type {};
/// Specialization: selected when \c T::iterator names a type.
template <class T>
struct has_iterator<T, std::void_t<typename T::iterator>> : std::true_type {};

/// Detects whether an object of type \c T can be asked to \c clear() with no
/// arguments.
/// Primary template: selected when that call expression is ill-formed.
template <class T, class = void> struct has_clear : std::false_type {};
/// Specialization: selected when \c std::declval<T>().clear() is well-formed.
template <class T>
struct has_clear<T, std::void_t<decltype(std::declval<T>().clear())>>
    : std::true_type {};

/// Detects whether an object of type \c T has a \c clearAll member that can be
/// called with a single argument.
/// NOTE(review): the probe passes an int literal rather than a callable, so a
/// non-template clearAll that only accepts a function object would not be
/// detected -- confirm this matches the containers in use.
/// Primary template: selected when the probe expression is ill-formed.
template <class T, class = void> struct has_clearAll : std::false_type {};
/// Specialization: selected when \c std::declval<T>().clearAll(1) is
/// well-formed.
template <class T>
struct has_clearAll<T, std::void_t<decltype(std::declval<T>().clearAll(1))>>
    : std::true_type {};

/// Classifies a container as associative when it declares a nested
/// \c mapped_type.
/// Primary template: selected when \c T::mapped_type is not a valid type.
template <class, class = void> struct is_associative : std::false_type {};
/// Specialization: selected when \c T::mapped_type names a type.
template <class T>
struct is_associative<T, std::void_t<typename T::mapped_type>>
    : std::true_type {};

/// Bookkeeping kept for each thread that has touched the table.
struct PerThreadData {
  /// Element count tracked for this thread's container.
  size_t NElements = 0;
  /// The thread's container; null until it is first created.
  std::unique_ptr<ContainerType> ThreadEntry;
};

/// Guards insertions into ThreadDataList.
std::mutex Mutex;
/// One shared record per thread that has registered with this table.
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;

// define default constructors, disable copy and move constructors
PerThreadTable() = default;
Expand All @@ -42,46 +122,52 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
PerThreadTable &operator=(const PerThreadTable &) = delete;
PerThreadTable &operator=(PerThreadTable &&) = delete;
/// Drops all registered thread records.
///
/// Must not run while other threads are still registering; this is only
/// checked in builds with assertions enabled.
~PerThreadTable() {
  // Destroying a std::mutex that is still owned is undefined behavior, so the
  // probe releases the lock again right after acquiring it.
  assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
         "Cannot be deleted while other threads are adding entries");
  ThreadDataList.clear();
}

private:
/// Return the calling thread's bookkeeping record, creating it and appending
/// it to ThreadDataList (under Mutex) on the thread's first call.
PerThreadData &getThreadData() {
  // NOTE(review): this is a function-local static, so on a given thread the
  // slot is shared by every PerThreadTable object with the same template
  // arguments -- confirm that a single instance per instantiation is intended.
  static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
  if (!ThreadData) {
    ThreadData = std::make_shared<PerThreadData>();
    std::lock_guard<std::mutex> Lock(Mutex);
    ThreadDataList.push_back(ThreadData);
  }
  return *ThreadData;
}

protected:
/// Return the calling thread's container, default-constructing it if it does
/// not exist yet.
ContainerType &getThreadEntry() {
  PerThreadData &ThreadData = getThreadData();
  if (!ThreadData.ThreadEntry)
    ThreadData.ThreadEntry = std::make_unique<ContainerType>();
  return *ThreadData.ThreadEntry;
}

/// Return a mutable reference to the calling thread's element counter.
size_t &getThreadNElements() {
  PerThreadData &ThreadData = getThreadData();
  return ThreadData.NElements;
}

/// Overwrite the calling thread's element counter with \p Size.
void setNElements(size_t Size) { getThreadNElements() = Size; }

public:
/// Add \p obj to the calling thread's container via the container's own
/// add(), counting it in the thread's element counter first.
void add(ObjectType obj) {
  ContainerType &Entry = getThreadEntry();
  size_t &NElements = getThreadNElements();
  NElements++;
  Entry.add(obj);
}

/// Erase the element at \p it from the calling thread's container and
/// decrement the thread's element counter.
/// \returns whatever the container's erase() returns for \p it.
iterator erase(iterator it) {
  ContainerType &Entry = getThreadEntry();
  size_t &NElements = getThreadNElements();
  NElements--;
  return Entry.erase(it);
}
Expand All @@ -91,24 +177,114 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
// Iterators to traverse objects owned by
// the current thread
/// Iterator to the first object in the calling thread's container.
iterator begin() {
  ContainerType &Entry = getThreadEntry();
  return Entry.begin();
}
/// Past-the-end iterator of the calling thread's container.
iterator end() {
  ContainerType &Entry = getThreadEntry();
  return Entry.end();
}

template <class F> void clear(F f) {
std::lock_guard<std::mutex> Lock(Mtx);
for (auto ThData : ThreadDataList) {
if (!ThData->ThEntry || ThData->NElements == 0)
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
assert(Mutex.try_lock() &&
"Clear cannot be called while other threads are adding entries");
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
if (!ThreadData->ThreadEntry || ThreadData->NElements == 0)
continue;
ThData->ThEntry->clear(f);
ThData->NElements = 0;
if constexpr (has_clearAll<ContainerType>::value) {
ThreadData->ThreadEntry->clearAll(ClearFunc);
} else if constexpr (has_iterator<ContainerType>::value &&
has_clear<ContainerType>::value) {
for (auto &Obj : *ThreadData->ThreadEntry) {
if constexpr (is_associative<ContainerType>::value) {
ClearFunc(Obj.second);
} else {
ClearFunc(Obj);
}
}
ThreadData->ThreadEntry->clear();
} else {
static_assert(true, "Container type not supported");
}
ThreadData->NElements = 0;
}
ThreadDataList.clear();
}

/// Run \p DeinitFunc over every object stored by every registered thread
/// (passing the mapped value for associative containers).
///
/// Stops at, and returns, the first error \p DeinitFunc reports; remaining
/// objects are not visited. Containers and counters are left untouched.
/// \returns llvm::Error::success() when every call succeeded.
///
/// Must not run while other threads are registering; this is only checked
/// (not enforced) in builds with assertions enabled.
template <class DeinitFuncTy> llvm::Error deinit(DeinitFuncTy DeinitFunc) {
  // Probe-and-release: the mutex must not stay locked after the check.
  assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
         "Deinit cannot be called while other threads are adding entries");
  // Iterate by reference to avoid a shared_ptr refcount round-trip per
  // element.
  for (const std::shared_ptr<PerThreadData> &ThreadData : ThreadDataList) {
    if (!ThreadData->ThreadEntry || ThreadData->NElements == 0)
      continue;
    for (auto &Obj : *ThreadData->ThreadEntry) {
      if constexpr (is_associative<ContainerType>::value) {
        if (auto Err = DeinitFunc(Obj.second))
          return Err;
      } else {
        if (auto Err = DeinitFunc(Obj))
          return Err;
      }
    }
  }
  return llvm::Error::success();
}
};

/// Maps a container type to the type of object it stores.
/// Primary template: containers without a nested \c mapped_type expose their
/// \c value_type.
template <class T, class = void> struct ContainerValueType {
  using type = typename T::value_type;
};
/// Specialization: containers that declare \c mapped_type expose that instead.
template <class T>
struct ContainerValueType<T, std::void_t<typename T::mapped_type>> {
  using type = typename T::mapped_type;
};

/// Per-thread container with indexed access.
///
/// Builds on PerThreadTable and adds get(Index), which returns the element at
/// \p Index in the calling thread's container. Containers that support
/// resize() are grown on demand; when \p reserveSize is non-zero and the
/// container supports reserve(), capacity is raised to \p reserveSize before
/// growing.
template <typename ContainerType, size_t reserveSize = 0>
struct PerThreadContainer
    : public PerThreadTable<ContainerType,
                            typename ContainerValueType<ContainerType>::type> {

  // Helper traits.

  /// Index type: \c key_type when the container declares one, otherwise
  /// \c size_type.
  template <class T, class = void> struct indexType {
    using type = typename T::size_type;
  };
  template <class T> struct indexType<T, std::void_t<typename T::key_type>> {
    using type = typename T::key_type;
  };

  /// True when \c T has a resize() callable with one argument.
  template <class T, class = void> struct has_resize : std::false_type {};
  template <class T>
  struct has_resize<T, std::void_t<decltype(std::declval<T>().resize(1))>>
      : std::true_type {};

  /// True when \c T has a reserve() callable with one argument.
  template <class T, class = void> struct has_reserve : std::false_type {};
  template <class T>
  struct has_reserve<T, std::void_t<decltype(std::declval<T>().reserve(1))>>
      : std::true_type {};

  using IndexType = typename indexType<ContainerType>::type;
  using ObjectType = typename ContainerValueType<ContainerType>::type;

  /// Return the object stored at \p Index in the calling thread's container,
  /// then record the container's size as the thread's element count.
  ObjectType &get(IndexType Index) {
    ContainerType &Entry = this->getThreadEntry();

    // Vector-like containers: grow so that Index becomes a valid position.
    if constexpr (has_resize<ContainerType>::value) {
      if constexpr (has_reserve<ContainerType>::value && reserveSize > 0) {
        // Only pre-reserve when a resize is about to happen.
        if (Index >= Entry.size() && Entry.capacity() < reserveSize)
          Entry.reserve(reserveSize);
      }
      if (Index >= Entry.size())
        Entry.resize(Index + 1);
    }

    // Look the element up before reading size(): operator[] may insert.
    ObjectType &Ret = Entry[Index];
    this->setNElements(Entry.size());
    return Ret;
  }
};

#endif
Loading