Skip to content

Commit

Permalink
[OpenMP] Replaced mutex lock/unlock in target with std::lock_guard
Browse files Browse the repository at this point in the history
Reviewed By: ye-luo

Differential Revision: https://reviews.llvm.org/D84799
  • Loading branch information
shiltian committed Jul 29, 2020
1 parent 69fc33f commit 3044092
Showing 1 changed file with 44 additions and 41 deletions.
85 changes: 44 additions & 41 deletions openmp/libomptarget/src/omptarget.cpp
Expand Up @@ -722,38 +722,38 @@ int target(int64_t DeviceId, void *HostPtr, int32_t ArgNum, void **ArgBases,
// Find the table information in the map or look it up in the translation
// tables.
TableMap *TM = 0;
TblMapMtx->lock();
HostPtrToTableMapTy::iterator TableMapIt = HostPtrToTableMap->find(HostPtr);
if (TableMapIt == HostPtrToTableMap->end()) {
// We don't have a map. So search all the registered libraries.
TrlTblMtx->lock();
for (HostEntriesBeginToTransTableTy::iterator
II = HostEntriesBeginToTransTable->begin(),
IE = HostEntriesBeginToTransTable->end();
!TM && II != IE; ++II) {
// get the translation table (which contains all the good info).
TranslationTable *TransTable = &II->second;
// iterate over all the host table entries to see if we can locate the
// host_ptr.
__tgt_offload_entry *Begin = TransTable->HostTable.EntriesBegin;
__tgt_offload_entry *End = TransTable->HostTable.EntriesEnd;
__tgt_offload_entry *Cur = Begin;
for (uint32_t I = 0; Cur < End; ++Cur, ++I) {
if (Cur->addr != HostPtr)
continue;
// we got a match, now fill the HostPtrToTableMap so that we
// may avoid this search next time.
TM = &(*HostPtrToTableMap)[HostPtr];
TM->Table = TransTable;
TM->Index = I;
break;
{
std::lock_guard<std::mutex> TblMapLock(*TblMapMtx);
HostPtrToTableMapTy::iterator TableMapIt = HostPtrToTableMap->find(HostPtr);
if (TableMapIt == HostPtrToTableMap->end()) {
// We don't have a map. So search all the registered libraries.
std::lock_guard<std::mutex> TrlTblLock(*TrlTblMtx);
for (HostEntriesBeginToTransTableTy::iterator
II = HostEntriesBeginToTransTable->begin(),
IE = HostEntriesBeginToTransTable->end();
!TM && II != IE; ++II) {
// get the translation table (which contains all the good info).
TranslationTable *TransTable = &II->second;
// iterate over all the host table entries to see if we can locate the
// host_ptr.
__tgt_offload_entry *Begin = TransTable->HostTable.EntriesBegin;
__tgt_offload_entry *End = TransTable->HostTable.EntriesEnd;
__tgt_offload_entry *Cur = Begin;
for (uint32_t I = 0; Cur < End; ++Cur, ++I) {
if (Cur->addr != HostPtr)
continue;
// we got a match, now fill the HostPtrToTableMap so that we
// may avoid this search next time.
TM = &(*HostPtrToTableMap)[HostPtr];
TM->Table = TransTable;
TM->Index = I;
break;
}
}
} else {
TM = &TableMapIt->second;
}
TrlTblMtx->unlock();
} else {
TM = &TableMapIt->second;
}
TblMapMtx->unlock();

// No map for this host pointer found!
if (!TM) {
Expand All @@ -763,11 +763,13 @@ int target(int64_t DeviceId, void *HostPtr, int32_t ArgNum, void **ArgBases,
}

// get target table.
TrlTblMtx->lock();
assert(TM->Table->TargetsTable.size() > (size_t)DeviceId &&
"Not expecting a device ID outside the table's bounds!");
__tgt_target_table *TargetTable = TM->Table->TargetsTable[DeviceId];
TrlTblMtx->unlock();
__tgt_target_table *TargetTable = nullptr;
{
std::lock_guard<std::mutex> TrlTblLock(*TrlTblMtx);
assert(TM->Table->TargetsTable.size() > (size_t)DeviceId &&
"Not expecting a device ID outside the table's bounds!");
TargetTable = TM->Table->TargetsTable[DeviceId];
}
assert(TargetTable && "Global data has not been mapped\n");

__tgt_async_info AsyncInfo;
Expand Down Expand Up @@ -899,14 +901,15 @@ int target(int64_t DeviceId, void *HostPtr, int32_t ArgNum, void **ArgBases,

// Pop loop trip count
uint64_t LoopTripCount = 0;
TblMapMtx->lock();
auto I = Device.LoopTripCnt.find(__kmpc_global_thread_num(NULL));
if (I != Device.LoopTripCnt.end()) {
LoopTripCount = I->second;
Device.LoopTripCnt.erase(I);
DP("loop trip count is %lu.\n", LoopTripCount);
{
std::lock_guard<std::mutex> TblMapLock(*TblMapMtx);
auto I = Device.LoopTripCnt.find(__kmpc_global_thread_num(NULL));
if (I != Device.LoopTripCnt.end()) {
LoopTripCount = I->second;
Device.LoopTripCnt.erase(I);
DP("loop trip count is %lu.\n", LoopTripCount);
}
}
TblMapMtx->unlock();

// Launch device execution.
DP("Launching target execution %s with pointer " DPxMOD " (index=%d).\n",
Expand Down

0 comments on commit 3044092

Please sign in to comment.