Skip to content

Commit

Permalink
[OpenMP] Ensure Devices is accessed exclusively (#74374)
Browse files Browse the repository at this point in the history
We accessed the `Devices` container most of the time while holding the
RTLsMtx, but not always. Sometimes we used the mutex for the size query,
but then accessed Devices again unguarded. From now on we properly
encapsulate the container in a ProtectedObj which ensures exclusive
accesses. We also hide the "isReady" part in the `getDevice` accessor
and use an `llvm::Expected` to allow returning errors.
  • Loading branch information
jdoerfert committed Dec 5, 2023
1 parent d6f4d52 commit 66784dc
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 226 deletions.
37 changes: 30 additions & 7 deletions openmp/libomptarget/include/PluginManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define OMPTARGET_PLUGIN_MANAGER_H

#include "DeviceImage.h"
#include "ExclusiveAccess.h"
#include "Shared/APITypes.h"
#include "Shared/PluginAPI.h"
#include "Shared/Requirements.h"
Expand All @@ -25,6 +26,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"

#include <cstdint>
#include <list>
Expand Down Expand Up @@ -75,6 +77,13 @@ struct PluginAdaptorTy {

/// Struct for the data required to handle plugins
struct PluginManager {
/// Type of the devices container. We hand out DeviceTy& to queries which are
/// stable addresses regardless if the container changes.
using DeviceContainerTy = llvm::SmallVector<std::unique_ptr<DeviceTy>>;

/// Exclusive accessor type for the device container.
using ExclusiveDevicesAccessorTy = Accessor<DeviceContainerTy>;

PluginManager() {}

void init();
Expand All @@ -89,13 +98,19 @@ struct PluginManager {
DeviceImages.emplace_back(std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
}

/// Return the device presented to the user as device \p DeviceNo if it is
/// initialized and ready. Otherwise return an error explaining the problem.
llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);

/// Build a range over the devices stored in the container that
/// \p DevicesAccessor gives access to. The range dereferences the owning
/// pointers, so iteration yields the pointed-to objects directly. No filtering
/// is applied here; every entry of the container is part of the range.
auto devices(ExclusiveDevicesAccessorTy &DevicesAccessor) {
  auto &Container = *DevicesAccessor;
  return llvm::make_pointee_range(Container);
}

/// Build a range over the entries of `DeviceImages` that dereferences the
/// owning pointers, so iteration yields the image objects themselves.
auto deviceImages() {
  return llvm::make_pointee_range(DeviceImages);
}

/// Devices associated with RTLs
llvm::SmallVector<std::unique_ptr<DeviceTy>> Devices;
std::mutex RTLsMtx; ///< For RTLs and Devices

/// Translation table retrieved from the binary
HostEntriesBeginToTransTableTy HostEntriesBeginToTransTable;
std::mutex TrlTblMtx; ///< For Translation Table
Expand Down Expand Up @@ -124,9 +139,12 @@ struct PluginManager {
DelayedBinDesc.clear();
}

int getNumDevices() {
std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
return Devices.size();
/// Report how many entries the devices container currently holds. The size is
/// read through the handle returned by getExclusiveDevicesAccessor().
int getNumDevices() {
  auto DevicesAccessor = getExclusiveDevicesAccessor();
  return DevicesAccessor->size();
}

/// Return an exclusive handle to access the devices container. The handle is
/// requested from `Devices`, the ProtectedObj that wraps the container, and is
/// returned to the caller by value.
ExclusiveDevicesAccessorTy getExclusiveDevicesAccessor() {
return Devices.getExclusiveAccessor();
}

int getNumUsedPlugins() const {
Expand Down Expand Up @@ -166,6 +184,11 @@ struct PluginManager {

/// The user provided requirements.
RequirementCollection Requirements;

std::mutex RTLsMtx; ///< For RTLs

/// Devices associated with plugins, accesses to the container are exclusive.
ProtectedObj<DeviceContainerTy> Devices;
};

extern PluginManager *PM;
Expand Down
7 changes: 4 additions & 3 deletions openmp/libomptarget/include/Shared/Debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,16 @@ inline uint32_t getDebugLevel() {
/// Print fatal error message with an error string and error identifier
#define FATAL_MESSAGE0(_num, _str) \
do { \
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", _num, _str); \
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", (int)_num, \
_str); \
abort(); \
} while (0)

/// Print fatal error message with a printf string and error identifier
#define FATAL_MESSAGE(_num, _str, ...) \
do { \
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", _num, \
__VA_ARGS__); \
fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", \
(int)_num, __VA_ARGS__); \
abort(); \
} while (0)

Expand Down
7 changes: 2 additions & 5 deletions openmp/libomptarget/include/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,8 @@ struct DeviceTy {
/// completed and AsyncInfo.isDone() returns true.
int32_t queryAsync(AsyncInfoTy &AsyncInfo);

/// Calls the corresponding print in the \p RTLDEVID
/// device RTL to obtain the information of the specific device.
bool printDeviceInfo(int32_t RTLDevID);
/// Calls the corresponding print device info function in the plugin.
bool printDeviceInfo();

/// Event related interfaces.
/// {
Expand Down Expand Up @@ -245,6 +244,4 @@ struct DeviceTy {
llvm::DenseMap<llvm::StringRef, OffloadEntryTy *> DeviceOffloadEntries;
};

extern bool deviceIsReady(int DeviceNum);

#endif
29 changes: 22 additions & 7 deletions openmp/libomptarget/src/OpenMP/InteropAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "PluginManager.h"
#include "device.h"
#include "omptarget.h"
#include "llvm/Support/Error.h"
#include <cstdlib>
#include <cstring>

extern "C" {

Expand Down Expand Up @@ -190,6 +193,14 @@ __OMP_GET_INTEROP_TY3(const char *, type_desc)
__OMP_GET_INTEROP_TY3(const char *, rc_desc)
#undef __OMP_GET_INTEROP_TY3

/// Consume \p Err and return its message as a NUL-terminated C string.
///
/// On success the returned buffer is allocated with malloc and is not released
/// by this function (see the TODO below). If the allocation fails, a pointer
/// to a string literal is returned instead, so callers must not assume the
/// result can be passed to free().
static const char *copyErrorString(llvm::Error &&Err) {
  // TODO: Use the error string while avoiding leaks.
  const std::string ErrMsg = llvm::toString(std::move(Err));

  // Size includes the terminating NUL that c_str() guarantees.
  const size_t Size = ErrMsg.size() + 1;
  char *UsrMsg = static_cast<char *>(malloc(Size));

  // Never write through a null pointer just because we ran out of memory
  // while reporting another error.
  if (!UsrMsg)
    return "Out of memory while copying the error message";

  memcpy(UsrMsg, ErrMsg.c_str(), Size);
  return UsrMsg;
}

extern "C" {

void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
Expand All @@ -211,12 +222,14 @@ void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
}

InteropPtr = new omp_interop_val_t(DeviceId, InteropType);
if (!deviceIsReady(DeviceId)) {
InteropPtr->err_str = "Device not ready!";

auto DeviceOrErr = PM->getDevice(DeviceId);
if (!DeviceOrErr) {
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
return;
}

DeviceTy &Device = *PM->Devices[DeviceId];
DeviceTy &Device = *DeviceOrErr;
if (!Device.RTL || !Device.RTL->init_device_info ||
Device.RTL->init_device_info(DeviceId, &(InteropPtr)->device_info,
&(InteropPtr)->err_str)) {
Expand Down Expand Up @@ -248,8 +261,9 @@ void __tgt_interop_use(ident_t *LocRef, int32_t Gtid,
assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
"Inconsistent device-id usage!");

if (!deviceIsReady(DeviceId)) {
InteropPtr->err_str = "Device not ready!";
auto DeviceOrErr = PM->getDevice(DeviceId);
if (!DeviceOrErr) {
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
return;
}

Expand Down Expand Up @@ -277,8 +291,9 @@ void __tgt_interop_destroy(ident_t *LocRef, int32_t Gtid,

assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
"Inconsistent device-id usage!");
if (!deviceIsReady(DeviceId)) {
InteropPtr->err_str = "Device not ready!";
auto DeviceOrErr = PM->getDevice(DeviceId);
if (!DeviceOrErr) {
InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
return;
}

Expand Down
52 changes: 45 additions & 7 deletions openmp/libomptarget/src/PluginManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
//===----------------------------------------------------------------------===//

#include "PluginManager.h"
#include "Shared/Debug.h"

#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;
using namespace llvm::sys;
Expand Down Expand Up @@ -71,7 +75,12 @@ PluginAdaptorTy::PluginAdaptorTy(const std::string &Name) : Name(Name) {

void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
for (int32_t I = 0; I < NumberOfDevices; ++I) {
DeviceTy &Device = *PM->Devices[DeviceOffset + I];
auto DeviceOrErr = PM->getDevice(DeviceOffset + I);
if (!DeviceOrErr)
FATAL_MESSAGE(DeviceOffset + I, "%s",
toString(DeviceOrErr.takeError()).c_str());

DeviceTy &Device = *DeviceOrErr;
for (OffloadEntryTy &Entry : DI.entries())
Device.addOffloadEntry(Entry);
}
Expand All @@ -97,14 +106,15 @@ void PluginManager::initPlugin(PluginAdaptorTy &Plugin) {
return;

// Initialize the device information for the RTL we are about to use.
const size_t Start = Devices.size();
Devices.reserve(Start + Plugin.NumberOfDevices);
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
const size_t Start = ExclusiveDevicesAccessor->size();
ExclusiveDevicesAccessor->reserve(Start + Plugin.NumberOfDevices);
for (int32_t DeviceId = 0; DeviceId < Plugin.NumberOfDevices; DeviceId++) {
Devices.push_back(std::make_unique<DeviceTy>(&Plugin));
ExclusiveDevicesAccessor->push_back(std::make_unique<DeviceTy>(&Plugin));
// global device ID
Devices[Start + DeviceId]->DeviceID = Start + DeviceId;
(*ExclusiveDevicesAccessor)[Start + DeviceId]->DeviceID = Start + DeviceId;
// RTL local device ID
Devices[Start + DeviceId]->RTLDeviceID = DeviceId;
(*ExclusiveDevicesAccessor)[Start + DeviceId]->RTLDeviceID = DeviceId;
}

// Initialize the index of this RTL and save it in the used RTLs.
Expand Down Expand Up @@ -254,7 +264,12 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
// Execute dtors for static objects if the device has been used, i.e.
// if its PendingCtors list has been emptied.
for (int32_t I = 0; I < FoundRTL->NumberOfDevices; ++I) {
DeviceTy &Device = *PM->Devices[FoundRTL->DeviceOffset + I];
auto DeviceOrErr = PM->getDevice(FoundRTL->DeviceOffset + I);
if (!DeviceOrErr)
FATAL_MESSAGE(FoundRTL->DeviceOffset + I, "%s",
toString(DeviceOrErr.takeError()).c_str());

DeviceTy &Device = *DeviceOrErr;
Device.PendingGlobalsMtx.lock();
if (Device.PendingCtorsDtors[Desc].PendingCtors.empty()) {
AsyncInfoTy AsyncInfo(Device);
Expand Down Expand Up @@ -313,3 +328,26 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {

DP("Done unregistering library!\n");
}

/// Look up the device with number \p DeviceNo and make sure it is initialized.
///
/// \returns a reference to the device on success. An error is returned if
/// \p DeviceNo is not a valid index into the devices container, or if the
/// device was not yet initialized and initOnce() did not report
/// OFFLOAD_SUCCESS.
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
  // The accessor lives until the function returns, so the range check, the
  // element access, and the initialization below all happen while it is held.
  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();

  const size_t NumDevices = ExclusiveDevicesAccessor->size();
  if (DeviceNo >= NumDevices)
    // Use specifiers that match the argument types: DeviceNo is uint32_t and
    // the container size is size_t.
    return createStringError(
        inconvertibleErrorCode(),
        "Device number '%u' out of range, only %zu devices available",
        DeviceNo, NumDevices);

  // The container stores std::unique_ptr<DeviceTy>; dereference to get the
  // device object itself.
  DeviceTy &Device = *(*ExclusiveDevicesAccessor)[DeviceNo];

  DP("Is the device %u (local ID %d) initialized? %d\n", DeviceNo,
     Device.RTLDeviceID, Device.IsInit);

  // Init the device if not done before
  if (!Device.IsInit && Device.initOnce() != OFFLOAD_SUCCESS) {
    return createStringError(inconvertibleErrorCode(),
                             "Failed to init device %u\n", DeviceNo);
  }

  DP("Device %u is ready to use.\n", DeviceNo);
  return Device;
}
Loading

0 comments on commit 66784dc

Please sign in to comment.