[OpenMP] Ensure Devices is accessed exclusively #74374

Merged
merged 1 commit into from
Dec 5, 2023

jdoerfert

We accessed the `Devices` container while holding `RTLsMtx` most of the time, but not always. In some places we took the mutex for the size query but then accessed `Devices` again unguarded. The container is now encapsulated in a `ProtectedObj`, which ensures exclusive access. We also fold the "isReady" check into the `getDevice` accessor, which returns an `llvm::Expected` so that errors can be reported to the caller.
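
As a rough illustration of the encapsulation described above, here is a minimal, self-contained sketch of a container that can only be reached through a lock-holding accessor. The names `Protected`, `ExclusiveAccessor`, `Device`, and `addDevices` are hypothetical and exist only for this sketch. It does not reproduce the actual `ProtectedObj`/`Accessor` definitions from `ExclusiveAccess.h`, which may differ in detail.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical, simplified stand-ins; NOT the types from ExclusiveAccess.h.
template <typename T> class ExclusiveAccessor;

// Owns an object together with the mutex that guards it. The object is
// private, so the only way to reach it is through an accessor.
template <typename T> class Protected {
  T Obj;
  std::mutex Mtx;
  friend class ExclusiveAccessor<T>;

public:
  ExclusiveAccessor<T> getExclusiveAccessor() {
    return ExclusiveAccessor<T>(*this);
  }
};

// Holds the lock for as long as the accessor is alive.
template <typename T> class ExclusiveAccessor {
  std::unique_lock<std::mutex> Lock;
  T *Obj;

public:
  explicit ExclusiveAccessor(Protected<T> &P) : Lock(P.Mtx), Obj(&P.Obj) {}
  ExclusiveAccessor(ExclusiveAccessor &&) = default;
  T &operator*() { return *Obj; }
  T *operator->() { return Obj; }
};

struct Device {
  int ID;
};
using DeviceList = std::vector<std::unique_ptr<Device>>;

// Appends Count devices and returns the new size. The size query and the
// updates share one accessor, so no other thread can interleave.
std::size_t addDevices(Protected<DeviceList> &Devices, int Count) {
  assert(Count >= 0 && "Count must be non-negative");
  auto Accessor = Devices.getExclusiveAccessor();
  const std::size_t Start = Accessor->size();
  for (int I = 0; I < Count; ++I)
    Accessor->push_back(
        std::make_unique<Device>(Device{static_cast<int>(Start) + I}));
  return Accessor->size();
}
```

Because the accessor owns the lock for its whole lifetime, a size query followed by an element access through the same accessor cannot race with another thread's update. In this sketch, requesting a second accessor on the same thread while the first is still alive would deadlock, since `std::mutex` is not recursive.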

@jdoerfert jdoerfert added openmp openmp:libomptarget OpenMP offload runtime labels Dec 4, 2023
@llvmbot
llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-openmp

Author: Johannes Doerfert (jdoerfert)


Patch is 35.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74374.diff

9 Files Affected:

  • (modified) openmp/libomptarget/include/PluginManager.h (+30-7)
  • (modified) openmp/libomptarget/include/Shared/Debug.h (+4-3)
  • (modified) openmp/libomptarget/include/device.h (+2-5)
  • (modified) openmp/libomptarget/src/OpenMP/InteropAPI.cpp (+22-7)
  • (modified) openmp/libomptarget/src/PluginManager.cpp (+45-7)
  • (modified) openmp/libomptarget/src/api.cpp (+52-56)
  • (modified) openmp/libomptarget/src/device.cpp (+2-35)
  • (modified) openmp/libomptarget/src/interface.cpp (+28-24)
  • (modified) openmp/libomptarget/src/omptarget.cpp (+52-82)
diff --git a/openmp/libomptarget/include/PluginManager.h b/openmp/libomptarget/include/PluginManager.h
index 94ecce01ca74c..bc71e5d70474b 100644
--- a/openmp/libomptarget/include/PluginManager.h
+++ b/openmp/libomptarget/include/PluginManager.h
@@ -14,6 +14,7 @@
 #define OMPTARGET_PLUGIN_MANAGER_H
 
 #include "DeviceImage.h"
+#include "ExclusiveAccess.h"
 #include "Shared/APITypes.h"
 #include "Shared/PluginAPI.h"
 #include "Shared/Requirements.h"
@@ -25,6 +26,7 @@
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
 
 #include <cstdint>
 #include <list>
@@ -75,6 +77,13 @@ struct PluginAdaptorTy {
 
 /// Struct for the data required to handle plugins
 struct PluginManager {
+  /// Type of the devices container. We hand out DeviceTy& to queries which are
+  /// stable addresses regardless if the container changes.
+  using DeviceContainerTy = llvm::SmallVector<std::unique_ptr<DeviceTy>>;
+
+  /// Exclusive accessor type for the device container.
+  using ExclusiveDevicesAccessorTy = Accessor<DeviceContainerTy>;
+
   PluginManager() {}
 
   void init();
@@ -89,13 +98,19 @@ struct PluginManager {
     DeviceImages.emplace_back(std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
   }
 
+  /// Return the device presented to the user as device \p DeviceNo if it is
+  /// initialized and ready. Otherwise return an error explaining the problem.
+  llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
+
+  /// Iterate over all initialized and ready devices registered with this
+  /// plugin.
+  auto devices(ExclusiveDevicesAccessorTy &DevicesAccessor) {
+    return llvm::make_pointee_range(*DevicesAccessor);
+  }
+
   /// Iterate over all device images registered with this plugin.
   auto deviceImages() { return llvm::make_pointee_range(DeviceImages); }
 
-  /// Devices associated with RTLs
-  llvm::SmallVector<std::unique_ptr<DeviceTy>> Devices;
-  std::mutex RTLsMtx; ///< For RTLs and Devices
-
   /// Translation table retreived from the binary
   HostEntriesBeginToTransTableTy HostEntriesBeginToTransTable;
   std::mutex TrlTblMtx; ///< For Translation Table
@@ -124,9 +139,12 @@ struct PluginManager {
     DelayedBinDesc.clear();
   }
 
-  int getNumDevices() {
-    std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
-    return Devices.size();
+  /// Return the number of usable devices.
+  int getNumDevices() { return getExclusiveDevicesAccessor()->size(); }
+
+  /// Return an exclusive handle to access the devices container.
+  ExclusiveDevicesAccessorTy getExclusiveDevicesAccessor() {
+    return Devices.getExclusiveAccessor();
   }
 
   int getNumUsedPlugins() const {
@@ -166,6 +184,11 @@ struct PluginManager {
 
   /// The user provided requirements.
   RequirementCollection Requirements;
+
+  std::mutex RTLsMtx; ///< For RTLs
+
+  /// Devices associated with plugins, accesses to the container are exclusive.
+  ProtectedObj<DeviceContainerTy> Devices;
 };
 
 extern PluginManager *PM;
diff --git a/openmp/libomptarget/include/Shared/Debug.h b/openmp/libomptarget/include/Shared/Debug.h
index 9f8818429c779..a39626d15386b 100644
--- a/openmp/libomptarget/include/Shared/Debug.h
+++ b/openmp/libomptarget/include/Shared/Debug.h
@@ -115,15 +115,16 @@ inline uint32_t getDebugLevel() {
 /// Print fatal error message with an error string and error identifier
 #define FATAL_MESSAGE0(_num, _str)                                             \
   do {                                                                         \
-    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", _num, _str); \
+    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", (int)_num,   \
+            _str);                                                             \
     abort();                                                                   \
   } while (0)
 
 /// Print fatal error message with a printf string and error identifier
 #define FATAL_MESSAGE(_num, _str, ...)                                         \
   do {                                                                         \
-    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", _num,  \
-            __VA_ARGS__);                                                      \
+    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n",        \
+            (int)_num, __VA_ARGS__);                                           \
     abort();                                                                   \
   } while (0)
 
diff --git a/openmp/libomptarget/include/device.h b/openmp/libomptarget/include/device.h
index 05ed6546557a4..5146fc1444b44 100644
--- a/openmp/libomptarget/include/device.h
+++ b/openmp/libomptarget/include/device.h
@@ -202,9 +202,8 @@ struct DeviceTy {
   /// completed and AsyncInfo.isDone() returns true.
   int32_t queryAsync(AsyncInfoTy &AsyncInfo);
 
-  /// Calls the corresponding print in the \p RTLDEVID
-  /// device RTL to obtain the information of the specific device.
-  bool printDeviceInfo(int32_t RTLDevID);
+  /// Calls the corresponding print device info function in the plugin.
+  bool printDeviceInfo();
 
   /// Event related interfaces.
   /// {
@@ -245,6 +244,4 @@ struct DeviceTy {
   llvm::DenseMap<llvm::StringRef, OffloadEntryTy *> DeviceOffloadEntries;
 };
 
-extern bool deviceIsReady(int DeviceNum);
-
 #endif
diff --git a/openmp/libomptarget/src/OpenMP/InteropAPI.cpp b/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
index 6a40dbca87afd..c96ce2ce60b75 100644
--- a/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
+++ b/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
@@ -13,6 +13,9 @@
 #include "PluginManager.h"
 #include "device.h"
 #include "omptarget.h"
+#include "llvm/Support/Error.h"
+#include <cstdlib>
+#include <cstring>
 
 extern "C" {
 
@@ -190,6 +193,14 @@ __OMP_GET_INTEROP_TY3(const char *, type_desc)
 __OMP_GET_INTEROP_TY3(const char *, rc_desc)
 #undef __OMP_GET_INTEROP_TY3
 
+static const char *copyErrorString(llvm::Error &&Err) {
+  // TODO: Use the error string while avoiding leaks.
+  std::string ErrMsg = llvm::toString(std::move(Err));
+  char *UsrMsg = reinterpret_cast<char *>(malloc(ErrMsg.size() + 1));
+  strcpy(UsrMsg, ErrMsg.c_str());
+  return UsrMsg;
+};
+
 extern "C" {
 
 void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
@@ -211,12 +222,14 @@ void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
   }
 
   InteropPtr = new omp_interop_val_t(DeviceId, InteropType);
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
-  DeviceTy &Device = *PM->Devices[DeviceId];
+  DeviceTy &Device = *DeviceOrErr;
   if (!Device.RTL || !Device.RTL->init_device_info ||
       Device.RTL->init_device_info(DeviceId, &(InteropPtr)->device_info,
                                    &(InteropPtr)->err_str)) {
@@ -248,8 +261,9 @@ void __tgt_interop_use(ident_t *LocRef, int32_t Gtid,
   assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
          "Inconsistent device-id usage!");
 
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
@@ -277,8 +291,9 @@ void __tgt_interop_destroy(ident_t *LocRef, int32_t Gtid,
 
   assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
          "Inconsistent device-id usage!");
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
diff --git a/openmp/libomptarget/src/PluginManager.cpp b/openmp/libomptarget/src/PluginManager.cpp
index e6dedeb699b14..931143ad2347d 100644
--- a/openmp/libomptarget/src/PluginManager.cpp
+++ b/openmp/libomptarget/src/PluginManager.cpp
@@ -11,6 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "PluginManager.h"
+#include "Shared/Debug.h"
+
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace llvm;
 using namespace llvm::sys;
@@ -71,7 +75,12 @@ PluginAdaptorTy::PluginAdaptorTy(const std::string &Name) : Name(Name) {
 
 void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
   for (int32_t I = 0; I < NumberOfDevices; ++I) {
-    DeviceTy &Device = *PM->Devices[DeviceOffset + I];
+    auto DeviceOrErr = PM->getDevice(DeviceOffset + I);
+    if (!DeviceOrErr)
+      FATAL_MESSAGE(DeviceOffset + I, "%s",
+                    toString(DeviceOrErr.takeError()).c_str());
+
+    DeviceTy &Device = *DeviceOrErr;
     for (OffloadEntryTy &Entry : DI.entries())
       Device.addOffloadEntry(Entry);
   }
@@ -97,14 +106,15 @@ void PluginManager::initPlugin(PluginAdaptorTy &Plugin) {
     return;
 
   // Initialize the device information for the RTL we are about to use.
-  const size_t Start = Devices.size();
-  Devices.reserve(Start + Plugin.NumberOfDevices);
+  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+  const size_t Start = ExclusiveDevicesAccessor->size();
+  ExclusiveDevicesAccessor->reserve(Start + Plugin.NumberOfDevices);
   for (int32_t DeviceId = 0; DeviceId < Plugin.NumberOfDevices; DeviceId++) {
-    Devices.push_back(std::make_unique<DeviceTy>(&Plugin));
+    ExclusiveDevicesAccessor->push_back(std::make_unique<DeviceTy>(&Plugin));
     // global device ID
-    Devices[Start + DeviceId]->DeviceID = Start + DeviceId;
+    (*ExclusiveDevicesAccessor)[Start + DeviceId]->DeviceID = Start + DeviceId;
     // RTL local device ID
-    Devices[Start + DeviceId]->RTLDeviceID = DeviceId;
+    (*ExclusiveDevicesAccessor)[Start + DeviceId]->RTLDeviceID = DeviceId;
   }
 
   // Initialize the index of this RTL and save it in the used RTLs.
@@ -254,7 +264,12 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
       // Execute dtors for static objects if the device has been used, i.e.
       // if its PendingCtors list has been emptied.
       for (int32_t I = 0; I < FoundRTL->NumberOfDevices; ++I) {
-        DeviceTy &Device = *PM->Devices[FoundRTL->DeviceOffset + I];
+        auto DeviceOrErr = PM->getDevice(FoundRTL->DeviceOffset + I);
+        if (!DeviceOrErr)
+          FATAL_MESSAGE(FoundRTL->DeviceOffset + I, "%s",
+                        toString(DeviceOrErr.takeError()).c_str());
+
+        DeviceTy &Device = *DeviceOrErr;
         Device.PendingGlobalsMtx.lock();
         if (Device.PendingCtorsDtors[Desc].PendingCtors.empty()) {
           AsyncInfoTy AsyncInfo(Device);
@@ -313,3 +328,26 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
 
   DP("Done unregistering library!\n");
 }
+
+Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
+  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+  if (DeviceNo >= ExclusiveDevicesAccessor->size())
+    return createStringError(
+        inconvertibleErrorCode(),
+        "Device number '%i' out of range, only %i devices available", DeviceNo,
+        ExclusiveDevicesAccessor->size());
+
+  DeviceTy &Device = *(*ExclusiveDevicesAccessor)[DeviceNo];
+
+  DP("Is the device %d (local ID %d) initialized? %d\n", DeviceNo,
+     Device.RTLDeviceID, Device.IsInit);
+
+  // Init the device if not done before
+  if (!Device.IsInit && Device.initOnce() != OFFLOAD_SUCCESS) {
+    return createStringError(inconvertibleErrorCode(),
+                             "Failed to init device %d\n", DeviceNo);
+  }
+
+  DP("Device %d is ready to use.\n", DeviceNo);
+  return Device;
+}
diff --git a/openmp/libomptarget/src/api.cpp b/openmp/libomptarget/src/api.cpp
index cc4cca286df51..0341e0c754649 100644
--- a/openmp/libomptarget/src/api.cpp
+++ b/openmp/libomptarget/src/api.cpp
@@ -110,21 +110,18 @@ EXTERN int omp_target_is_present(const void *Ptr, int DeviceNum) {
     return true;
   }
 
-  size_t NumDevices = PM->getNumDevices();
-  if (NumDevices <= (size_t)DeviceNum) {
-    DP("Call to omp_target_is_present with invalid device ID, returning "
-       "false\n");
-    return false;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
   // omp_target_is_present tests whether a host pointer refers to storage that
   // is mapped to a given device. However, due to the lack of the storage size,
   // only check 1 byte. Cannot set size 0 which checks whether the pointer (zero
   // lengh array) is mapped instead of the referred storage.
-  TargetPointerResultTy TPR = Device.getTgtPtrBegin(const_cast<void *>(Ptr), 1,
-                                                    /*UpdateRefCount=*/false,
-                                                    /*UseHoldRefCount=*/false);
+  TargetPointerResultTy TPR =
+      DeviceOrErr->getTgtPtrBegin(const_cast<void *>(Ptr), 1,
+                                  /*UpdateRefCount=*/false,
+                                  /*UseHoldRefCount=*/false);
   int Rc = TPR.isPresent();
   DP("Call to omp_target_is_present returns %d\n", Rc);
   return Rc;
@@ -150,16 +147,6 @@ EXTERN int omp_target_memcpy(void *Dst, const void *Src, size_t Length,
     return OFFLOAD_FAIL;
   }
 
-  if (SrcDevice != omp_get_initial_device() && !deviceIsReady(SrcDevice)) {
-    REPORT("omp_target_memcpy returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
-
-  if (DstDevice != omp_get_initial_device() && !deviceIsReady(DstDevice)) {
-    REPORT("omp_target_memcpy returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
-
   int Rc = OFFLOAD_SUCCESS;
   void *SrcAddr = (char *)const_cast<void *>(Src) + SrcOffset;
   void *DstAddr = (char *)Dst + DstOffset;
@@ -172,35 +159,49 @@ EXTERN int omp_target_memcpy(void *Dst, const void *Src, size_t Length,
       Rc = OFFLOAD_FAIL;
   } else if (SrcDevice == omp_get_initial_device()) {
     DP("copy from host to device\n");
-    DeviceTy &DstDev = *PM->Devices[DstDevice];
-    AsyncInfoTy AsyncInfo(DstDev);
-    Rc = DstDev.submitData(DstAddr, SrcAddr, Length, AsyncInfo);
+    auto DstDeviceOrErr = PM->getDevice(DstDevice);
+    if (!DstDeviceOrErr)
+      FATAL_MESSAGE(DstDevice, "%s",
+                    toString(DstDeviceOrErr.takeError()).c_str());
+    AsyncInfoTy AsyncInfo(*DstDeviceOrErr);
+    Rc = DstDeviceOrErr->submitData(DstAddr, SrcAddr, Length, AsyncInfo);
   } else if (DstDevice == omp_get_initial_device()) {
     DP("copy from device to host\n");
-    DeviceTy &SrcDev = *PM->Devices[SrcDevice];
-    AsyncInfoTy AsyncInfo(SrcDev);
-    Rc = SrcDev.retrieveData(DstAddr, SrcAddr, Length, AsyncInfo);
+    auto SrcDeviceOrErr = PM->getDevice(SrcDevice);
+    if (!SrcDeviceOrErr)
+      FATAL_MESSAGE(SrcDevice, "%s",
+                    toString(SrcDeviceOrErr.takeError()).c_str());
+    AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+    Rc = SrcDeviceOrErr->retrieveData(DstAddr, SrcAddr, Length, AsyncInfo);
   } else {
     DP("copy from device to device\n");
-    DeviceTy &SrcDev = *PM->Devices[SrcDevice];
-    DeviceTy &DstDev = *PM->Devices[DstDevice];
+    auto SrcDeviceOrErr = PM->getDevice(SrcDevice);
+    if (!SrcDeviceOrErr)
+      FATAL_MESSAGE(SrcDevice, "%s",
+                    toString(SrcDeviceOrErr.takeError()).c_str());
+    AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+    auto DstDeviceOrErr = PM->getDevice(DstDevice);
+    if (!DstDeviceOrErr)
+      FATAL_MESSAGE(DstDevice, "%s",
+                    toString(DstDeviceOrErr.takeError()).c_str());
     // First try to use D2D memcpy which is more efficient. If fails, fall back
     // to unefficient way.
-    if (SrcDev.isDataExchangable(DstDev)) {
-      AsyncInfoTy AsyncInfo(SrcDev);
-      Rc = SrcDev.dataExchange(SrcAddr, DstDev, DstAddr, Length, AsyncInfo);
+    if (SrcDeviceOrErr->isDataExchangable(*DstDeviceOrErr)) {
+      AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+      Rc = SrcDeviceOrErr->dataExchange(SrcAddr, *DstDeviceOrErr, DstAddr,
+                                        Length, AsyncInfo);
       if (Rc == OFFLOAD_SUCCESS)
         return OFFLOAD_SUCCESS;
     }
 
     void *Buffer = malloc(Length);
     {
-      AsyncInfoTy AsyncInfo(SrcDev);
-      Rc = SrcDev.retrieveData(Buffer, SrcAddr, Length, AsyncInfo);
+      AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+      Rc = SrcDeviceOrErr->retrieveData(Buffer, SrcAddr, Length, AsyncInfo);
     }
     if (Rc == OFFLOAD_SUCCESS) {
-      AsyncInfoTy AsyncInfo(DstDev);
-      Rc = DstDev.submitData(DstAddr, Buffer, Length, AsyncInfo);
+      AsyncInfoTy AsyncInfo(*DstDeviceOrErr);
+      Rc = DstDeviceOrErr->submitData(DstAddr, Buffer, Length, AsyncInfo);
     }
     free(Buffer);
   }
@@ -507,15 +508,13 @@ EXTERN int omp_target_associate_ptr(const void *HostPtr, const void *DevicePtr,
     return OFFLOAD_FAIL;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("omp_target_associate_ptr returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
   void *DeviceAddr = (void *)((uint64_t)DevicePtr + (uint64_t)DeviceOffset);
-  int Rc = Device.associatePtr(const_cast<void *>(HostPtr),
-                               const_cast<void *>(DeviceAddr), Size);
+  int Rc = DeviceOrErr->associatePtr(const_cast<void *>(HostPtr),
+                                     const_cast<void *>(DeviceAddr), Size);
   DP("omp_target_associate_ptr returns %d\n", Rc);
   return Rc;
 }
@@ -537,13 +536,11 @@ EXTERN int omp_target_disassociate_ptr(const void *HostPtr, int DeviceNum) {
     return OFFLOAD_FAIL;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("omp_target_disassociate_ptr returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
-  int Rc = Device.disassociatePtr(const_cast<void *>(HostPtr));
+  int Rc = DeviceOrErr->disassociatePtr(const_cast<void *>(HostPtr));
   DP("omp_target_disassociate_ptr returns %d\n", Rc);
   return Rc;
 }
@@ -570,15 +567,14 @@ EXTERN void *omp_get_mapped_ptr(const void *Ptr, int DeviceNum) {
     return nullptr;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("Device %d is not ready, returning nullptr.\n", DeviceNum);
-    return nullptr;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  auto &Device = *PM->Devices[DeviceNum];
-  TargetPointerResultTy TPR = Device.getTgtPtrBegin(const_cast<void *>(Ptr), 1,
-                                                    /*UpdateRefCount=*/false,
-                                                    /*UseHoldRefCount=*/false);
+  TargetPointerResultTy TPR =
+      DeviceOrErr->getTgtPtrBegin(const_cast<void *>(Ptr), 1,
+                                  /*UpdateRefCount=*/false,
+                                  /*UseHoldRefCount=*/false);
   if (!TPR.isPresent()) {
     DP("Ptr " DPxMOD "is not present on device %d, returning nullptr.\n",
        DPxPTR(Ptr), DeviceNum);
diff --git a/openmp/libomptarget/src/device.cpp b/openmp/libomptarget/src/device.cpp
index d3481d42af967..ad9563e04def4 100644
--- a/openmp/libomptarget/src/device.cpp
+++ b/openmp/libomptarget/src/device.cpp
@@ -711,10 +711,10 @@ int32_t DeviceTy::launchKernel(void *TgtEntryPtr, void **TgtVarsPtr,
 }
 
 // Run region on device
-bool DeviceTy::printDeviceInfo(int32_t RTLDevId) {
+bool DeviceTy::printDeviceInfo() {
   if (!RTL->print_device_info)
     return false;
-  RTL->print_device_info(RTLDevId);
+  RTL->print_device_info(RTLDeviceID);
   return true;
 }
 
@@ -778,39 +778,6 @@ int32_t DeviceTy::destroyEvent(void *Event) {
   return OFFLOAD_SUCCESS;
 }
 
-/// Check whether a device has an associated RTL and initialize it if it's not
-/// already initialized.
-bool deviceIsReady(int DeviceNum) {
-  DP("Checking whether device %d is ready.\n", DeviceNum);
-  // Devices.size() can only change whi...
[truncated]
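
The `getDevice` accessor added in `PluginManager.cpp` above combines three steps under one lock: a bounds check, lazy initialization, and returning either the device or an error. The following self-contained sketch mirrors that shape without depending on LLVM. `DeviceOrError` is a hypothetical stand-in for `llvm::Expected<DeviceTy &>`, and `Manager`, `Device`, `add`, and the `InitFails` flag are likewise assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

struct Device {
  bool IsInit = false;
  bool InitFails = false; // Hypothetical knob to simulate a failing init.
  bool initOnce() {
    IsInit = !InitFails;
    return IsInit;
  }
};

// Hypothetical stand-in for llvm::Expected<DeviceTy &>: holds either a
// device pointer or an error message.
struct DeviceOrError {
  Device *Dev;
  std::string Error;
  DeviceOrError(Device *D, std::string E) : Dev(D), Error(std::move(E)) {}
  explicit operator bool() const { return Dev != nullptr; }
};

class Manager {
  std::mutex Mtx;
  std::vector<std::unique_ptr<Device>> Devices;

public:
  void add(bool InitFails) {
    std::lock_guard<std::mutex> Lock(Mtx);
    Devices.push_back(std::make_unique<Device>());
    Devices.back()->InitFails = InitFails;
  }

  // Bounds check, lazy init, and lookup all happen under a single lock.
  DeviceOrError getDevice(std::uint32_t DeviceNo) {
    std::lock_guard<std::mutex> Lock(Mtx);
    if (DeviceNo >= Devices.size())
      return {nullptr, "Device number " + std::to_string(DeviceNo) +
                           " out of range, only " +
                           std::to_string(Devices.size()) +
                           " devices available"};
    Device &D = *Devices[DeviceNo];
    if (!D.IsInit && !D.initOnce())
      return {nullptr, "Failed to init device " + std::to_string(DeviceNo)};
    assert(D.IsInit && "device must be initialized at this point");
    return {&D, ""};
  }
};
```

Handing back a pointer to the device is safe here for the same reason given in the `DeviceContainerTy` comment in the diff: the container stores `std::unique_ptr` elements, so a device's address stays stable even if the container reallocates.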

@jdoerfert jdoerfert merged commit 66784dc into llvm:main Dec 5, 2023
3 checks passed
@jdoerfert jdoerfert deleted the offload_prep8 branch December 5, 2023 01:10
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Dec 6, 2023
Reverts: to work later
  [OpenMP] Ensure `Devices` is accessed exlusively (llvm#74374)
Change-Id: Ia8c9a666a385bebe11d4982a41b0538ddb98c3e1
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Dec 12, 2023
Restores:
  [OpenMP] Ensure `Devices` is accessed exlusively (llvm#74374)

Change-Id: I34e5814a76c61cba9deae2c129e3aae96116662e