Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Libomptarget] Factor functions out of 'Plugin' interface #86528

Merged
merged 1 commit into from
Mar 25, 2024

Conversation

jhuber6
Copy link
Contributor

@jhuber6 jhuber6 commented Mar 25, 2024

Summary:
This patch factors common functions out of the Plugin interface prior
to its removal in a future patch. This simply temporarily renames it to
PluginTy so that we could re-use Plugin::check internally as this
needs to be defined statically per plugin now. We can refactor this
later.

The future patch will delete PluginTy and PluginTy::get entirely.
This simply tries to minimize a few changes to make it easier to land.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Joseph Huber (jhuber6)

Changes

Summary:
This patch factors common functions out of the Plugin interface prior
to its removal in a future patch. This simply temporarily renames it to
PluginTy so that we could re-use Plugin::check internally as this
needs to be defined statically per plugin now. We can refactor this
later.

The future patch will delete PluginTy and PluginTy::get entirely.
This simply tries to minimize a few changes to make it easier to land.


Patch is 32.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86528.diff

6 Files Affected:

  • (modified) openmp/libomptarget/CMakeLists.txt (+1-2)
  • (modified) openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp (+20-19)
  • (modified) openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h (+37-33)
  • (modified) openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp (+39-39)
  • (modified) openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp (+14-12)
  • (modified) openmp/libomptarget/plugins-nextgen/host/src/rtl.cpp (+15-13)
diff --git a/openmp/libomptarget/CMakeLists.txt b/openmp/libomptarget/CMakeLists.txt
index b382137b70ee2a..531198fae01699 100644
--- a/openmp/libomptarget/CMakeLists.txt
+++ b/openmp/libomptarget/CMakeLists.txt
@@ -145,8 +145,7 @@ add_subdirectory(DeviceRTL)
 add_subdirectory(tools)
 
 # Build target agnostic offloading library.
-set(LIBOMPTARGET_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)
-add_subdirectory(${LIBOMPTARGET_SRC_DIR})
+add_subdirectory(src)
 
 # Add tests.
 add_subdirectory(test)
diff --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
index fce7454bf2800d..14461f14dd6706 100644
--- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -2088,7 +2088,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Allocate and construct an AMDGPU kernel.
   Expected<GenericKernelTy &> constructKernel(const char *Name) override {
     // Allocate and construct the AMDGPU kernel.
-    AMDGPUKernelTy *AMDGPUKernel = Plugin::get().allocate<AMDGPUKernelTy>();
+    AMDGPUKernelTy *AMDGPUKernel = PluginTy::get().allocate<AMDGPUKernelTy>();
     if (!AMDGPUKernel)
       return Plugin::error("Failed to allocate memory for AMDGPU kernel");
 
@@ -2139,7 +2139,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
                                            int32_t ImageId) override {
     // Allocate and initialize the image object.
     AMDGPUDeviceImageTy *AMDImage =
-        Plugin::get().allocate<AMDGPUDeviceImageTy>();
+        PluginTy::get().allocate<AMDGPUDeviceImageTy>();
     new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, TgtImage);
 
     // Load the HSA executable.
@@ -2697,7 +2697,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   }
   Error setDeviceHeapSize(uint64_t Value) override {
     for (DeviceImageTy *Image : LoadedImages)
-      if (auto Err = setupDeviceMemoryPool(Plugin::get(), *Image, Value))
+      if (auto Err = setupDeviceMemoryPool(PluginTy::get(), *Image, Value))
         return Err;
     DeviceMemoryPoolSize = Value;
     return Plugin::success();
@@ -2737,7 +2737,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     return utils::iterateAgentMemoryPools(
         Agent, [&](hsa_amd_memory_pool_t HSAMemoryPool) {
           AMDGPUMemoryPoolTy *MemoryPool =
-              Plugin::get().allocate<AMDGPUMemoryPoolTy>();
+              PluginTy::get().allocate<AMDGPUMemoryPoolTy>();
           new (MemoryPool) AMDGPUMemoryPoolTy(HSAMemoryPool);
           AllMemoryPools.push_back(MemoryPool);
           return HSA_STATUS_SUCCESS;
@@ -3115,6 +3115,17 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
     return Plugin::check(Status, "Error in hsa_shut_down: %s");
   }
 
+  /// Creates an AMDGPU device.
+  GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices) override {
+    return new AMDGPUDeviceTy(DeviceId, NumDevices, getHostDevice(),
+                              getKernelAgent(DeviceId));
+  }
+
+  /// Creates an AMDGPU global handler.
+  GenericGlobalHandlerTy *createGlobalHandler() override {
+    return new AMDGPUGlobalHandlerTy();
+  }
+
   Triple::ArchType getTripleArch() const override { return Triple::amdgcn; }
 
   /// Get the ELF code for recognizing the compatible image binary.
@@ -3237,7 +3248,7 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   // 56 bytes per allocation.
   uint32_t AllArgsSize = KernelArgsSize + ImplicitArgsSize;
 
-  AMDHostDeviceTy &HostDevice = Plugin::get<AMDGPUPluginTy>().getHostDevice();
+  AMDHostDeviceTy &HostDevice = PluginTy::get<AMDGPUPluginTy>().getHostDevice();
   AMDGPUMemoryManagerTy &ArgsMemoryManager = HostDevice.getArgsMemoryManager();
 
   void *AllArgs = nullptr;
@@ -3347,20 +3358,10 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
   return Plugin::success();
 }
 
-GenericPluginTy *Plugin::createPlugin() { return new AMDGPUPluginTy(); }
-
-GenericDeviceTy *Plugin::createDevice(int32_t DeviceId, int32_t NumDevices) {
-  AMDGPUPluginTy &Plugin = get<AMDGPUPluginTy &>();
-  return new AMDGPUDeviceTy(DeviceId, NumDevices, Plugin.getHostDevice(),
-                            Plugin.getKernelAgent(DeviceId));
-}
-
-GenericGlobalHandlerTy *Plugin::createGlobalHandler() {
-  return new AMDGPUGlobalHandlerTy();
-}
+GenericPluginTy *PluginTy::createPlugin() { return new AMDGPUPluginTy(); }
 
 template <typename... ArgsTy>
-Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
+static Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
   hsa_status_t ResultCode = static_cast<hsa_status_t>(Code);
   if (ResultCode == HSA_STATUS_SUCCESS || ResultCode == HSA_STATUS_INFO_BREAK)
     return Error::success();
@@ -3384,7 +3385,7 @@ void *AMDGPUMemoryManagerTy::allocate(size_t Size, void *HstPtr,
   }
   assert(Ptr && "Invalid pointer");
 
-  auto &KernelAgents = Plugin::get<AMDGPUPluginTy>().getKernelAgents();
+  auto &KernelAgents = PluginTy::get<AMDGPUPluginTy>().getKernelAgents();
 
   // Allow all kernel agents to access the allocation.
   if (auto Err = MemoryPool->enableAccess(Ptr, Size, KernelAgents)) {
@@ -3427,7 +3428,7 @@ void *AMDGPUDeviceTy::allocate(size_t Size, void *, TargetAllocTy Kind) {
   }
 
   if (Alloc) {
-    auto &KernelAgents = Plugin::get<AMDGPUPluginTy>().getKernelAgents();
+    auto &KernelAgents = PluginTy::get<AMDGPUPluginTy>().getKernelAgents();
     // Inherently necessary for host or shared allocations
     // Also enabled for device memory to allow device to device memcpy
 
diff --git a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
index b7be7b645ba33e..1440f8f678d820 100644
--- a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
+++ b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
@@ -976,6 +976,13 @@ struct GenericPluginTy {
   Error deinit();
   virtual Error deinitImpl() = 0;
 
+  /// Create a new device for the underlying plugin.
+  virtual GenericDeviceTy *createDevice(int32_t DeviceID,
+                                        int32_t NumDevices) = 0;
+
+  /// Create a new global handler for the underlying plugin.
+  virtual GenericGlobalHandlerTy *createGlobalHandler() = 0;
+
   /// Get the reference to the device with a certain device id.
   GenericDeviceTy &getDevice(int32_t DeviceId) {
     assert(isValidDeviceId(DeviceId) && "Invalid device id");
@@ -1085,29 +1092,53 @@ struct GenericPluginTy {
   RPCServerTy *RPCServer;
 };
 
+namespace Plugin {
+/// Create a success error. This is the same as calling Error::success(), but
+/// it is recommended to use this one for consistency with Plugin::error() and
+/// Plugin::check().
+static Error success() { return Error::success(); }
+
+/// Create a string error.
+template <typename... ArgsTy>
+static Error error(const char *ErrFmt, ArgsTy... Args) {
+  return createStringError(inconvertibleErrorCode(), ErrFmt, Args...);
+}
+
+/// Check the plugin-specific error code and return an error or success
+/// accordingly. In case of an error, create a string error with the error
+/// description. The ErrFmt should follow the format:
+///     "Error in <function name>[<optional info>]: %s"
+/// The last format specifier "%s" is mandatory and will be used to place the
+/// error code's description. Notice this function should be only called from
+/// the plugin-specific code.
+/// TODO: Refactor this, must be defined individually by each plugin.
+template <typename... ArgsTy>
+static Error check(int32_t ErrorCode, const char *ErrFmt, ArgsTy... Args);
+} // namespace Plugin
+
 /// Class for simplifying the getter operation of the plugin. Anywhere on the
 /// code, the current plugin can be retrieved by Plugin::get(). The class also
 /// declares functions to create plugin-specific object instances. The check(),
 /// createPlugin(), createDevice() and createGlobalHandler() functions should be
 /// defined by each plugin implementation.
-class Plugin {
+class PluginTy {
   // Reference to the plugin instance.
   static GenericPluginTy *SpecificPlugin;
 
-  Plugin() {
+  PluginTy() {
     if (auto Err = init())
       REPORT("Failed to initialize plugin: %s\n",
              toString(std::move(Err)).data());
   }
 
-  ~Plugin() {
+  ~PluginTy() {
     if (auto Err = deinit())
       REPORT("Failed to deinitialize plugin: %s\n",
              toString(std::move(Err)).data());
   }
 
-  Plugin(const Plugin &) = delete;
-  void operator=(const Plugin &) = delete;
+  PluginTy(const PluginTy &) = delete;
+  void operator=(const PluginTy &) = delete;
 
   /// Create and intialize the plugin instance.
   static Error init() {
@@ -1158,7 +1189,7 @@ class Plugin {
     // This static variable will initialize the underlying plugin instance in
     // case there was no previous explicit initialization. The initialization is
     // thread safe.
-    static Plugin Plugin;
+    static PluginTy Plugin;
 
     assert(SpecificPlugin && "Plugin is not active");
     return *SpecificPlugin;
@@ -1170,35 +1201,8 @@ class Plugin {
   /// Indicate whether the plugin is active.
   static bool isActive() { return SpecificPlugin != nullptr; }
 
-  /// Create a success error. This is the same as calling Error::success(), but
-  /// it is recommended to use this one for consistency with Plugin::error() and
-  /// Plugin::check().
-  static Error success() { return Error::success(); }
-
-  /// Create a string error.
-  template <typename... ArgsTy>
-  static Error error(const char *ErrFmt, ArgsTy... Args) {
-    return createStringError(inconvertibleErrorCode(), ErrFmt, Args...);
-  }
-
-  /// Check the plugin-specific error code and return an error or success
-  /// accordingly. In case of an error, create a string error with the error
-  /// description. The ErrFmt should follow the format:
-  ///     "Error in <function name>[<optional info>]: %s"
-  /// The last format specifier "%s" is mandatory and will be used to place the
-  /// error code's description. Notice this function should be only called from
-  /// the plugin-specific code.
-  template <typename... ArgsTy>
-  static Error check(int32_t ErrorCode, const char *ErrFmt, ArgsTy... Args);
-
   /// Create a plugin instance.
   static GenericPluginTy *createPlugin();
-
-  /// Create a plugin-specific device.
-  static GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices);
-
-  /// Create a plugin-specific global handler.
-  static GenericGlobalHandlerTy *createGlobalHandler();
 };
 
 /// Auxiliary interface class for GenericDeviceResourceManagerTy. This class
diff --git a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
index f39f913d85eec2..fe937e984523fd 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
@@ -39,7 +39,7 @@ using namespace omp;
 using namespace target;
 using namespace plugin;
 
-GenericPluginTy *Plugin::SpecificPlugin = nullptr;
+GenericPluginTy *PluginTy::SpecificPlugin = nullptr;
 
 // TODO: Fix any thread safety issues for multi-threaded kernel recording.
 struct RecordReplayTy {
@@ -438,7 +438,7 @@ Error GenericKernelTy::init(GenericDeviceTy &GenericDevice,
   // Retrieve kernel environment object for the kernel.
   GlobalTy KernelEnv(std::string(Name) + "_kernel_environment",
                      sizeof(KernelEnvironment), &KernelEnvironment);
-  GenericGlobalHandlerTy &GHandler = Plugin::get().getGlobalHandler();
+  GenericGlobalHandlerTy &GHandler = PluginTy::get().getGlobalHandler();
   if (auto Err =
           GHandler.readGlobalFromImage(GenericDevice, *ImagePtr, KernelEnv)) {
     [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
@@ -1488,7 +1488,7 @@ Error GenericPluginTy::init() {
   assert(Devices.size() == 0 && "Plugin already initialized");
   Devices.resize(NumDevices, nullptr);
 
-  GlobalHandler = Plugin::createGlobalHandler();
+  GlobalHandler = createGlobalHandler();
   assert(GlobalHandler && "Invalid global handler");
 
   RPCServer = new RPCServerTy(NumDevices);
@@ -1522,7 +1522,7 @@ Error GenericPluginTy::initDevice(int32_t DeviceId) {
   assert(!Devices[DeviceId] && "Device already initialized");
 
   // Create the device and save the reference.
-  GenericDeviceTy *Device = Plugin::createDevice(DeviceId, NumDevices);
+  GenericDeviceTy *Device = createDevice(DeviceId, NumDevices);
   assert(Device && "Invalid device");
 
   // Save the device reference into the list.
@@ -1581,7 +1581,7 @@ extern "C" {
 #endif
 
 int32_t __tgt_rtl_init_plugin() {
-  auto Err = Plugin::initIfNeeded();
+  auto Err = PluginTy::initIfNeeded();
   if (Err) {
     [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
     DP("Failed to init plugin: %s", ErrStr.c_str());
@@ -1592,7 +1592,7 @@ int32_t __tgt_rtl_init_plugin() {
 }
 
 int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
-  if (!Plugin::isActive())
+  if (!PluginTy::isActive())
     return false;
 
   StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
@@ -1609,13 +1609,13 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
   case file_magic::elf_executable:
   case file_magic::elf_shared_object:
   case file_magic::elf_core: {
-    auto MatchOrErr = Plugin::get().checkELFImage(Buffer);
+    auto MatchOrErr = PluginTy::get().checkELFImage(Buffer);
     if (Error Err = MatchOrErr.takeError())
       return HandleError(std::move(Err));
     return *MatchOrErr;
   }
   case file_magic::bitcode: {
-    auto MatchOrErr = Plugin::get().getJIT().checkBitcodeImage(Buffer);
+    auto MatchOrErr = PluginTy::get().getJIT().checkBitcodeImage(Buffer);
     if (Error Err = MatchOrErr.takeError())
       return HandleError(std::move(Err));
     return *MatchOrErr;
@@ -1626,7 +1626,7 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
 }
 
 int32_t __tgt_rtl_init_device(int32_t DeviceId) {
-  auto Err = Plugin::get().initDevice(DeviceId);
+  auto Err = PluginTy::get().initDevice(DeviceId);
   if (Err) {
     REPORT("Failure to initialize device %d: %s\n", DeviceId,
            toString(std::move(Err)).data());
@@ -1636,23 +1636,23 @@ int32_t __tgt_rtl_init_device(int32_t DeviceId) {
   return OFFLOAD_SUCCESS;
 }
 
-int32_t __tgt_rtl_number_of_devices() { return Plugin::get().getNumDevices(); }
+int32_t __tgt_rtl_number_of_devices() { return PluginTy::get().getNumDevices(); }
 
 int64_t __tgt_rtl_init_requires(int64_t RequiresFlags) {
-  Plugin::get().setRequiresFlag(RequiresFlags);
+  PluginTy::get().setRequiresFlag(RequiresFlags);
   return OFFLOAD_SUCCESS;
 }
 
 int32_t __tgt_rtl_is_data_exchangable(int32_t SrcDeviceId,
                                       int32_t DstDeviceId) {
-  return Plugin::get().isDataExchangable(SrcDeviceId, DstDeviceId);
+  return PluginTy::get().isDataExchangable(SrcDeviceId, DstDeviceId);
 }
 
 int32_t __tgt_rtl_initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
                                            void *VAddr, bool isRecord,
                                            bool SaveOutput,
                                            uint64_t &ReqPtrArgOffset) {
-  GenericPluginTy &Plugin = Plugin::get();
+  GenericPluginTy &Plugin = PluginTy::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
   RecordReplayTy::RRStatusTy Status =
       isRecord ? RecordReplayTy::RRStatusTy::RRRecording
@@ -1674,7 +1674,7 @@ int32_t __tgt_rtl_initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
 
 int32_t __tgt_rtl_load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
                               __tgt_device_binary *Binary) {
-  GenericPluginTy &Plugin = Plugin::get();
+  GenericPluginTy &Plugin = PluginTy::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
 
   auto ImageOrErr = Device.loadBinary(Plugin, TgtImage);
@@ -1695,7 +1695,7 @@ int32_t __tgt_rtl_load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
 
 void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
                            int32_t Kind) {
-  auto AllocOrErr = Plugin::get().getDevice(DeviceId).dataAlloc(
+  auto AllocOrErr = PluginTy::get().getDevice(DeviceId).dataAlloc(
       Size, HostPtr, (TargetAllocTy)Kind);
   if (!AllocOrErr) {
     auto Err = AllocOrErr.takeError();
@@ -1710,7 +1710,7 @@ void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
 
 int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) {
   auto Err =
-      Plugin::get().getDevice(DeviceId).dataDelete(TgtPtr, (TargetAllocTy)Kind);
+      PluginTy::get().getDevice(DeviceId).dataDelete(TgtPtr, (TargetAllocTy)Kind);
   if (Err) {
     REPORT("Failure to deallocate device pointer %p: %s\n", TgtPtr,
            toString(std::move(Err)).data());
@@ -1722,7 +1722,7 @@ int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) {
 
 int32_t __tgt_rtl_data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
                             void **LockedPtr) {
-  auto LockedPtrOrErr = Plugin::get().getDevice(DeviceId).dataLock(Ptr, Size);
+  auto LockedPtrOrErr = PluginTy::get().getDevice(DeviceId).dataLock(Ptr, Size);
   if (!LockedPtrOrErr) {
     auto Err = LockedPtrOrErr.takeError();
     REPORT("Failure to lock memory %p: %s\n", Ptr,
@@ -1740,7 +1740,7 @@ int32_t __tgt_rtl_data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
 }
 
 int32_t __tgt_rtl_data_unlock(int32_t DeviceId, void *Ptr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataUnlock(Ptr);
+  auto Err = PluginTy::get().getDevice(DeviceId).dataUnlock(Ptr);
   if (Err) {
     REPORT("Failure to unlock memory %p: %s\n", Ptr,
            toString(std::move(Err)).data());
@@ -1752,7 +1752,7 @@ int32_t __tgt_rtl_data_unlock(int32_t DeviceId, void *Ptr) {
 
 int32_t __tgt_rtl_data_notify_mapped(int32_t DeviceId, void *HstPtr,
                                      int64_t Size) {
-  auto Err = Plugin::get().getDevice(DeviceId).notifyDataMapped(HstPtr, Size);
+  auto Err = PluginTy::get().getDevice(DeviceId).notifyDataMapped(HstPtr, Size);
   if (Err) {
     REPORT("Failure to notify data mapped %p: %s\n", HstPtr,
            toString(std::move(Err)).data());
@@ -1763,7 +1763,7 @@ int32_t __tgt_rtl_data_notify_mapped(int32_t DeviceId, void *HstPtr,
 }
 
 int32_t __tgt_rtl_data_notify_unmapped(int32_t DeviceId, void *HstPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).notifyDataUnmapped(HstPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).notifyDataUnmapped(HstPtr);
   if (Err) {
     REPORT("Failure to notify data unmapped %p: %s\n", HstPtr,
            toString(std::move(Err)).data());
@@ -1782,7 +1782,7 @@ int32_t __tgt_rtl_data_submit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
 int32_t __tgt_rtl_data_submit_async(int32_t DeviceId, void *TgtPtr,
                                     void *HstPtr, int64_t Size,
                                     __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataSubmit(TgtPtr, HstPtr, Size,
+  auto Err = PluginTy::get().getDevice(DeviceId).dataSubmit(TgtPtr, HstPtr, Size,
                                                           AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to copy data from host to device. Pointers: host "
@@ -1804,7 +1804,7 @@ int32_t __tgt_rtl_data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
 int32_t __tgt_rtl_data_retrieve_async(int32_t DeviceId, void *HstPtr,
                                       void *TgtPtr, int64_t Size,
                                       __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataRetrieve(HstPtr, TgtPtr,
+  auto Err = PluginTy::get().getDevice(DeviceId).dataRetrieve(HstPtr, TgtPtr,
                                                             Size, AsyncInfoPtr);
   if (Err) {
     REPORT("Faliure to copy data from device to host. Pointers: host "
@@ -1829,8 +1829,8 @@ int32_t __tgt_rtl_data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
                                       int DstDeviceId, void *DstPtr,
                                       int64_t Size,
                  ...
[truncated]

Copy link

github-actions bot commented Mar 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Summary:
This patch factors common functions out of the `Plugin` interface prior
to its removal in a future patch. This simply temporarily renames it to
`PluginTy` so that we could re-use `Plugin::check` internally as this
needs to be defined statically per plugin now. We can refactor this
later.

The future patch will delete `PluginTy` and `PluginTy::get` entirely.
This simply tries to minimize a few changes to make it easier to land.
Copy link
Member

@jdoerfert jdoerfert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@jhuber6
Copy link
Contributor Author

jhuber6 commented Mar 25, 2024

@jhuber6 jhuber6 merged commit 2cad43c into llvm:main Mar 25, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants