Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 45 additions & 44 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,28 @@ using namespace llvm::omp::target;
using namespace llvm::omp::target::plugin;
using namespace error;

struct ol_platform_impl_t {
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
ol_platform_backend_t BackendType)
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
std::unique_ptr<GenericPluginTy> Plugin;
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
ol_platform_backend_t BackendType;

/// Complete all pending work for this platform and perform any needed
/// cleanup.
///
/// After calling this function, no liboffload functions should be called with
/// this platform handle.
llvm::Error destroy();
};

// Handle type definitions. Ideally these would be 1:1 with the plugins, but
// we add some additional data here for now to avoid churn in the plugin
// interface.
struct ol_device_impl_t {
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
Info(std::forward<InfoTreeNode>(DevInfo)) {}

Expand All @@ -55,7 +71,7 @@ struct ol_device_impl_t {

int DeviceNum;
GenericDeviceTy *Device;
ol_platform_handle_t Platform;
ol_platform_impl_t &Platform;
InfoTreeNode Info;

llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
Expand Down Expand Up @@ -102,31 +118,17 @@ struct ol_device_impl_t {
}
};

struct ol_platform_impl_t {
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
ol_platform_backend_t BackendType)
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
std::unique_ptr<GenericPluginTy> Plugin;
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
ol_platform_backend_t BackendType;

/// Complete all pending work for this platform and perform any needed
/// cleanup.
///
/// After calling this function, no liboffload functions should be called with
/// this platform handle.
llvm::Error destroy() {
llvm::Error Result = Plugin::success();
for (auto &D : Devices)
if (auto Err = D->destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Err));
llvm::Error ol_platform_impl_t::destroy() {
llvm::Error Result = Plugin::success();
for (auto &D : Devices)
if (auto Err = D->destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Err));

if (auto Res = Plugin->deinit())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
if (auto Res = Plugin->deinit())
Result = llvm::joinErrors(std::move(Result), std::move(Res));

return Result;
}
};
return Result;
}

struct ol_queue_impl_t {
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
Expand Down Expand Up @@ -206,12 +208,12 @@ struct OffloadContext {
// Partitioned list of memory base addresses. Each element in this list is a
// key in AllocInfoMap
llvm::SmallVector<void *> AllocBases{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
size_t RefCount;

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return Platforms.back().Devices[0].get();
return Platforms.back()->Devices[0].get();
}

static OffloadContext &get() {
Expand Down Expand Up @@ -251,35 +253,34 @@ Error initPlugins(OffloadContext &Context) {
#define PLUGIN_TARGET(Name) \
do { \
if (StringRef(#Name) != "host") \
Context.Platforms.emplace_back(ol_platform_impl_t{ \
Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>( \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
pluginNameToBackend(#Name)}); \
pluginNameToBackend(#Name))); \
} while (false);
#include "Shared/Targets.def"

// Preemptively initialize all devices in the plugin
for (auto &Platform : Context.Platforms) {
auto Err = Platform.Plugin->init();
auto Err = Platform->Plugin->init();
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
DevNum++) {
if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
auto Device = &Platform.Plugin->getDevice(DevNum);
if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
auto Device = &Platform->Plugin->getDevice(DevNum);
auto Info = Device->obtainInfoImpl();
if (auto Err = Info.takeError())
return Err;
Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
DevNum, Device, &Platform, std::move(*Info)));
Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
DevNum, Device, *Platform, std::move(*Info)));
}
}
}

// Add the special host device
auto &HostPlatform = Context.Platforms.emplace_back(
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
HostPlatform.Devices.emplace_back(
std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
Context.HostDevice()->Platform = &HostPlatform;
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
-1, nullptr, *HostPlatform, InfoTreeNode{}));

Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
Expand Down Expand Up @@ -316,10 +317,10 @@ Error olShutDown_impl() {

for (auto &P : OldContext->Platforms) {
// Host plugin is nullptr and has no deinit
if (!P.Plugin || !P.Plugin->is_initialized())
if (!P->Plugin || !P->Plugin->is_initialized())
continue;

if (auto Res = P.destroy())
if (auto Res = P->destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
}

Expand Down Expand Up @@ -384,7 +385,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
// These are not implemented by the plugin interface
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
return Info.write<void *>(Device->Platform);
return Info.write<void *>(&Device->Platform);

case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
Expand Down Expand Up @@ -517,7 +518,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,

switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
return Info.write<void *>(Device->Platform);
return Info.write<void *>(&Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
case OL_DEVICE_INFO_NAME:
Expand Down Expand Up @@ -595,7 +596,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,

Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform.Devices) {
for (auto &Device : Platform->Devices) {
if (!Callback(Device.get(), UserData)) {
break;
}
Expand Down
Loading