Skip to content

Commit

Permalink
[OpenMP][NFC] Encapsulate Devices.size() (#74010)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdoerfert committed Dec 1, 2023
1 parent b6cad75 commit 1035cc7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
5 changes: 5 additions & 0 deletions openmp/libomptarget/include/PluginManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ struct PluginManager {
DelayedBinDesc.clear();
}

int getNumDevices() {
std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
return Devices.size();
}

private:
bool RTLsLoaded = false;
llvm::SmallVector<__tgt_bin_desc *> DelayedBinDesc;
Expand Down
24 changes: 8 additions & 16 deletions openmp/libomptarget/src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@

EXTERN int omp_get_num_devices(void) {
TIMESCOPE();
PM->RTLsMtx.lock();
size_t DevicesSize = PM->Devices.size();
PM->RTLsMtx.unlock();
size_t NumDevices = PM->getNumDevices();

DP("Call to omp_get_num_devices returning %zd\n", DevicesSize);
DP("Call to omp_get_num_devices returning %zd\n", NumDevices);

return DevicesSize;
return NumDevices;
}

EXTERN int omp_get_device_num(void) {
Expand Down Expand Up @@ -112,10 +110,8 @@ EXTERN int omp_target_is_present(const void *Ptr, int DeviceNum) {
return true;
}

PM->RTLsMtx.lock();
size_t DevicesSize = PM->Devices.size();
PM->RTLsMtx.unlock();
if (DevicesSize <= (size_t)DeviceNum) {
size_t NumDevices = PM->getNumDevices();
if (NumDevices <= (size_t)DeviceNum) {
DP("Call to omp_target_is_present with invalid device ID, returning "
"false\n");
return false;
Expand Down Expand Up @@ -562,18 +558,14 @@ EXTERN void *omp_get_mapped_ptr(const void *Ptr, int DeviceNum) {
return nullptr;
}

if (DeviceNum == omp_get_initial_device()) {
size_t NumDevices = omp_get_initial_device();
if (DeviceNum == NumDevices) {
REPORT("Device %d is initial device, returning Ptr " DPxMOD ".\n",
DeviceNum, DPxPTR(Ptr));
return const_cast<void *>(Ptr);
}

int DevicesSize = omp_get_initial_device();
{
std::lock_guard<std::mutex> LG(PM->RTLsMtx);
DevicesSize = PM->Devices.size();
}
if (DevicesSize <= DeviceNum) {
if (NumDevices <= DeviceNum) {
DP("DeviceNum %d is invalid, returning nullptr.\n", DeviceNum);
return nullptr;
}
Expand Down

0 comments on commit 1035cc7

Please sign in to comment.