diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index d8d71e3e65a4a0..83f6e8d76fec75 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -54,6 +54,7 @@ namespace plugin { struct GenericPluginTy; struct GenericKernelTy; struct GenericDeviceTy; +struct RecordReplayTy; /// Class that wraps the __tgt_async_info to simply its usage. In case the /// object is constructed without a valid __tgt_async_info, the object will use @@ -958,7 +959,8 @@ struct GenericPluginTy { /// Construct a plugin instance. GenericPluginTy(Triple::ArchType TA) - : GlobalHandler(nullptr), JIT(TA), RPCServer(nullptr) {} + : GlobalHandler(nullptr), JIT(TA), RPCServer(nullptr), + RecordReplay(nullptr) {} virtual ~GenericPluginTy() {} @@ -1027,6 +1029,12 @@ struct GenericPluginTy { return *RPCServer; } + /// Get a reference to the record and replay interface for the plugin. + RecordReplayTy &getRecordReplay() { + assert(RecordReplay && "RR interface not initialized"); + return *RecordReplay; + } + /// Initialize a device within the plugin. Error initDevice(int32_t DeviceId); @@ -1204,6 +1212,9 @@ struct GenericPluginTy { /// The interface between the plugin and the GPU for host services. RPCServerTy *RPCServer; + + /// The record and replay interface for the plugin. + RecordReplayTy *RecordReplay; }; namespace Plugin { diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 253acacc3a9dc8..550ebc9c28b250 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -39,6 +39,7 @@ using namespace target; using namespace plugin; // TODO: Fix any thread safety issues for multi-threaded kernel recording. 
+namespace llvm::omp::target::plugin { struct RecordReplayTy { // Describes the state of the record replay mechanism. @@ -358,8 +359,7 @@ struct RecordReplayTy { } } }; - -static RecordReplayTy RecordReplay; +} // namespace llvm::omp::target::plugin // Extract the mapping of host function pointers to device function pointers // from the entry table. Functions marked as 'indirect' in OpenMP will have @@ -470,7 +470,7 @@ GenericKernelTy::getKernelLaunchEnvironment( // Ctor/Dtor have no arguments, replaying uses the original kernel launch // environment. Older versions of the compiler do not generate a kernel // launch environment. - if (RecordReplay.isReplaying() || + if (GenericDevice.Plugin.getRecordReplay().isReplaying() || Version < OMP_KERNEL_ARG_MIN_VERSION_WITH_DYN_PTR) return nullptr; @@ -559,6 +559,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs, // Record the kernel description after we modified the argument count and num // blocks/threads. + RecordReplayTy &RecordReplay = GenericDevice.Plugin.getRecordReplay(); if (RecordReplay.isRecording()) { RecordReplay.saveImage(getName(), getImage()); RecordReplay.saveKernelInput(getName(), getImage()); @@ -833,6 +834,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) { delete MemoryManager; MemoryManager = nullptr; + RecordReplayTy &RecordReplay = Plugin.getRecordReplay(); if (RecordReplay.isRecordingOrReplaying()) RecordReplay.deinit(); @@ -886,7 +888,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin, return std::move(Err); // Setup the global device memory pool if needed. 
- if (!RecordReplay.isReplaying() && shouldSetupDeviceMemoryPool()) { + if (!Plugin.getRecordReplay().isReplaying() && + shouldSetupDeviceMemoryPool()) { uint64_t HeapSize; auto SizeOrErr = getDeviceHeapSize(HeapSize); if (SizeOrErr) { @@ -1301,8 +1304,8 @@ Expected GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr, TargetAllocTy Kind) { void *Alloc = nullptr; - if (RecordReplay.isRecordingOrReplaying()) - return RecordReplay.alloc(Size); + if (Plugin.getRecordReplay().isRecordingOrReplaying()) + return Plugin.getRecordReplay().alloc(Size); switch (Kind) { case TARGET_ALLOC_DEFAULT: @@ -1338,7 +1341,7 @@ Expected GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr, Error GenericDeviceTy::dataDelete(void *TgtPtr, TargetAllocTy Kind) { // Free is a noop when recording or replaying. - if (RecordReplay.isRecordingOrReplaying()) + if (Plugin.getRecordReplay().isRecordingOrReplaying()) return Plugin::success(); int Res; @@ -1405,7 +1408,8 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs, KernelArgsTy &KernelArgs, __tgt_async_info *AsyncInfo) { AsyncInfoWrapperTy AsyncInfoWrapper( - *this, RecordReplay.isRecordingOrReplaying() ? nullptr : AsyncInfo); + *this, + Plugin.getRecordReplay().isRecordingOrReplaying() ? 
nullptr : AsyncInfo); GenericKernelTy &GenericKernel = *reinterpret_cast(EntryPtr); @@ -1416,6 +1420,7 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs, // 'finalize' here to guarantee next record-replay actions are in-sync AsyncInfoWrapper.finalize(Err); + RecordReplayTy &RecordReplay = Plugin.getRecordReplay(); if (RecordReplay.isRecordingOrReplaying() && RecordReplay.isSaveOutputEnabled()) RecordReplay.saveKernelOutputInfo(GenericKernel.getName()); @@ -1503,6 +1508,9 @@ Error GenericPluginTy::init() { RPCServer = new RPCServerTy(*this); assert(RPCServer && "Invalid RPC server"); + RecordReplay = new RecordReplayTy(); + assert(RecordReplay && "Invalid RR interface"); + return Plugin::success(); } @@ -1523,6 +1531,9 @@ Error GenericPluginTy::deinit() { if (RPCServer) delete RPCServer; + if (RecordReplay) + delete RecordReplay; + // Perform last deinitializations on the plugin. return deinitImpl(); } @@ -1633,12 +1644,12 @@ int32_t GenericPluginTy::initialize_record_replay(int32_t DeviceId, isRecord ? RecordReplayTy::RRStatusTy::RRRecording : RecordReplayTy::RRStatusTy::RRReplaying; - if (auto Err = RecordReplay.init(&Device, MemorySize, VAddr, Status, - SaveOutput, ReqPtrArgOffset)) { + if (auto Err = RecordReplay->init(&Device, MemorySize, VAddr, Status, + SaveOutput, ReqPtrArgOffset)) { REPORT("WARNING RR did not intialize RR-properly with %lu bytes" "(Error: %s)\n", MemorySize, toString(std::move(Err)).data()); - RecordReplay.setStatus(RecordReplayTy::RRStatusTy::RRDeactivated); + RecordReplay->setStatus(RecordReplayTy::RRStatusTy::RRDeactivated); if (!isRecord) { return OFFLOAD_FAIL; @@ -1982,6 +1993,7 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size, assert(DevicePtr && "Invalid device global's address"); // Save the loaded globals if we are recording. + RecordReplayTy &RecordReplay = Device.Plugin.getRecordReplay(); if (RecordReplay.isRecording()) RecordReplay.addEntry(Name, Size, *DevicePtr);