diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 08a2e25b97d85..051882da7c6c7 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -39,12 +39,28 @@ using namespace llvm::omp::target; using namespace llvm::omp::target::plugin; using namespace error; +struct ol_platform_impl_t { + ol_platform_impl_t(std::unique_ptr Plugin, + ol_platform_backend_t BackendType) + : Plugin(std::move(Plugin)), BackendType(BackendType) {} + std::unique_ptr Plugin; + llvm::SmallVector> Devices; + ol_platform_backend_t BackendType; + + /// Complete all pending work for this platform and perform any needed + /// cleanup. + /// + /// After calling this function, no liboffload functions should be called with + /// this platform handle. + llvm::Error destroy(); +}; + // Handle type definitions. Ideally these would be 1:1 with the plugins, but // we add some additional data here for now to avoid churn in the plugin // interface. struct ol_device_impl_t { ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, - ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) + ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo) : DeviceNum(DeviceNum), Device(Device), Platform(Platform), Info(std::forward(DevInfo)) {} @@ -55,7 +71,7 @@ struct ol_device_impl_t { int DeviceNum; GenericDeviceTy *Device; - ol_platform_handle_t Platform; + ol_platform_impl_t &Platform; InfoTreeNode Info; llvm::SmallVector<__tgt_async_info *> OutstandingQueues; @@ -102,31 +118,17 @@ struct ol_device_impl_t { } }; -struct ol_platform_impl_t { - ol_platform_impl_t(std::unique_ptr Plugin, - ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), BackendType(BackendType) {} - std::unique_ptr Plugin; - llvm::SmallVector> Devices; - ol_platform_backend_t BackendType; - - /// Complete all pending work for this platform and perform any needed - /// cleanup. - /// - /// After calling this function, no liboffload functions should be called with - /// this platform handle. - llvm::Error destroy() { - llvm::Error Result = Plugin::success(); - for (auto &D : Devices) - if (auto Err = D->destroy()) - Result = llvm::joinErrors(std::move(Result), std::move(Err)); +llvm::Error ol_platform_impl_t::destroy() { + llvm::Error Result = Plugin::success(); + for (auto &D : Devices) + if (auto Err = D->destroy()) + Result = llvm::joinErrors(std::move(Result), std::move(Err)); - if (auto Res = Plugin->deinit()) - Result = llvm::joinErrors(std::move(Result), std::move(Res)); + if (auto Res = Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); - return Result; - } -}; + return Result; +} struct ol_queue_impl_t { ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) @@ -206,12 +208,12 @@ struct OffloadContext { // Partitioned list of memory base addresses. Each element in this list is a // key in AllocInfoMap llvm::SmallVector AllocBases{}; - SmallVector Platforms{}; + SmallVector, 4> Platforms{}; size_t RefCount; ol_device_handle_t HostDevice() { // The host platform is always inserted last - return Platforms.back().Devices[0].get(); + return Platforms.back()->Devices[0].get(); } static OffloadContext &get() { @@ -251,35 +253,34 @@ Error initPlugins(OffloadContext &Context) { #define PLUGIN_TARGET(Name) \ do { \ if (StringRef(#Name) != "host") \ - Context.Platforms.emplace_back(ol_platform_impl_t{ \ + Context.Platforms.emplace_back(std::make_unique( \ std::unique_ptr(createPlugin_##Name()), \ - pluginNameToBackend(#Name)}); \ + pluginNameToBackend(#Name))); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin for (auto &Platform : Context.Platforms) { - auto Err = Platform.Plugin->init(); + auto Err = Platform->Plugin->init(); [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); - for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); + for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices(); DevNum++) { - if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - auto Device = &Platform.Plugin->getDevice(DevNum); + if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { + auto Device = &Platform->Plugin->getDevice(DevNum); auto Info = Device->obtainInfoImpl(); if (auto Err = Info.takeError()) return Err; - Platform.Devices.emplace_back(std::make_unique( - DevNum, Device, &Platform, std::move(*Info))); + Platform->Devices.emplace_back(std::make_unique( + DevNum, Device, *Platform, std::move(*Info))); } } } // Add the special host device auto &HostPlatform = Context.Platforms.emplace_back( - ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); - HostPlatform.Devices.emplace_back( - std::make_unique(-1, nullptr, nullptr, InfoTreeNode{})); - Context.HostDevice()->Platform = &HostPlatform; + std::make_unique(nullptr, OL_PLATFORM_BACKEND_HOST)); + HostPlatform->Devices.emplace_back(std::make_unique( + -1, nullptr, *HostPlatform, InfoTreeNode{})); Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); @@ -316,10 +317,10 @@ Error olShutDown_impl() { for (auto &P : OldContext->Platforms) { // Host plugin is nullptr and has no deinit - if (!P.Plugin || !P.Plugin->is_initialized()) + if (!P->Plugin || !P->Plugin->is_initialized()) continue; - if (auto Res = P.destroy()) + if (auto Res = P->destroy()) Result = llvm::joinErrors(std::move(Result), std::move(Res)); } @@ -384,7 +385,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, // These are not implemented by the plugin interface switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return Info.write(Device->Platform); + return Info.write(&Device->Platform); case OL_DEVICE_INFO_TYPE: return Info.write(OL_DEVICE_TYPE_GPU); @@ -517,7 +518,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return Info.write(Device->Platform); + return Info.write(&Device->Platform); case OL_DEVICE_INFO_TYPE: return Info.write(OL_DEVICE_TYPE_HOST); case OL_DEVICE_INFO_NAME: @@ -595,7 +596,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { for (auto &Platform : OffloadContext::get().Platforms) { - for (auto &Device : Platform.Devices) { + for (auto &Device : Platform->Devices) { if (!Callback(Device.get(), UserData)) { break; }