diff --git a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h index 79e8464bfda5c..7f05464f36c1f 100644 --- a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h +++ b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h @@ -45,6 +45,8 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" +struct RecordReplayTy; + namespace llvm { namespace omp { namespace target { @@ -1031,6 +1033,12 @@ struct GenericPluginTy { return *RPCServer; } + /// Get a reference to the R&R interface for this plugin. + RecordReplayTy &getRecordAndReplay() const { + assert(RecordReplay && "R&R not initialized"); + return *RecordReplay; + } + /// Get the OpenMP requires flags set for this plugin. int64_t getRequiresFlags() const { return RequiresFlags; } @@ -1220,6 +1228,9 @@ struct GenericPluginTy { /// The interface between the plugin and the GPU for host services. RPCServerTy *RPCServer; + + /// The interface into the record-and-replay functionality. + RecordReplayTy *RecordReplay; }; namespace Plugin { diff --git a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp index b5f3c45c835fd..6df9798f12e3d 100644 --- a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp @@ -362,8 +362,6 @@ struct RecordReplayTy { } }; -static RecordReplayTy RecordReplay; - // Extract the mapping of host function pointers to device function pointers // from the entry table. Functions marked as 'indirect' in OpenMP will have // offloading entries generated for them which map the host's function pointer @@ -473,7 +471,8 @@ GenericKernelTy::getKernelLaunchEnvironment( // Ctor/Dtor have no arguments, replaying uses the original kernel launch // environment. Older versions of the compiler do not generate a kernel // launch environment. - if (isCtorOrDtor() || RecordReplay.isReplaying() || + if (isCtorOrDtor() || + GenericDevice.Plugin.getRecordAndReplay().isReplaying() || Version < OMP_KERNEL_ARG_MIN_VERSION_WITH_DYN_PTR) return nullptr; @@ -562,6 +561,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs, // Record the kernel description after we modified the argument count and num // blocks/threads. + RecordReplayTy &RecordReplay = GenericDevice.Plugin.getRecordAndReplay(); if (RecordReplay.isRecording()) { RecordReplay.saveImage(getName(), getImage()); RecordReplay.saveKernelInput(getName(), getImage()); @@ -839,9 +839,6 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) { delete MemoryManager; MemoryManager = nullptr; - if (RecordReplay.isRecordingOrReplaying()) - RecordReplay.deinit(); - if (RPCServer) if (auto Err = RPCServer->deinitDevice(*this)) return Err; @@ -858,6 +855,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) { return deinitImpl(); } + Expected GenericDeviceTy::loadBinary(GenericPluginTy &Plugin, const __tgt_device_image *InputTgtImage) { @@ -892,7 +890,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin, return std::move(Err); // Setup the global device memory pool if needed. - if (!RecordReplay.isReplaying() && shouldSetupDeviceMemoryPool()) { + if (!Plugin.getRecordAndReplay().isReplaying() && + shouldSetupDeviceMemoryPool()) { uint64_t HeapSize; auto SizeOrErr = getDeviceHeapSize(HeapSize); if (SizeOrErr) { @@ -1307,8 +1306,8 @@ Expected GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr, TargetAllocTy Kind) { void *Alloc = nullptr; - if (RecordReplay.isRecordingOrReplaying()) - return RecordReplay.alloc(Size); + if (Plugin.getRecordAndReplay().isRecordingOrReplaying()) + return Plugin.getRecordAndReplay().alloc(Size); switch (Kind) { case TARGET_ALLOC_DEFAULT: @@ -1344,7 +1343,7 @@ Expected GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr, Error GenericDeviceTy::dataDelete(void *TgtPtr, TargetAllocTy Kind) { // Free is a noop when recording or replaying. - if (RecordReplay.isRecordingOrReplaying()) + if (Plugin.getRecordAndReplay().isRecordingOrReplaying()) return Plugin::success(); int Res; @@ -1396,6 +1395,7 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs, ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs, __tgt_async_info *AsyncInfo) { + RecordReplayTy &RecordReplay = Plugin.getRecordAndReplay(); AsyncInfoWrapperTy AsyncInfoWrapper( *this, RecordReplay.isRecordingOrReplaying() ? nullptr : AsyncInfo); @@ -1495,6 +1495,9 @@ Error GenericPluginTy::init() { RPCServer = new RPCServerTy(*this); assert(RPCServer && "Invalid RPC server"); + RecordReplay = new RecordReplayTy(); + assert(RecordReplay && "Invalid Record and Replay handler"); + return Plugin::success(); } @@ -1508,6 +1511,9 @@ Error GenericPluginTy::deinit() { assert(!Devices[DeviceId] && "Device was not deinitialized"); } + if (RecordReplay && RecordReplay->isRecordingOrReplaying()) + RecordReplay->deinit(); + // There is no global handler if no device is available. if (GlobalHandler) delete GlobalHandler; @@ -1515,6 +1521,9 @@ Error GenericPluginTy::deinit() { if (RPCServer) delete RPCServer; + if (RecordReplay) + delete RecordReplay; + // Perform last deinitializations on the plugin. return deinitImpl(); } @@ -1630,12 +1639,12 @@ int32_t GenericPluginTy::initialize_record_replay(int32_t DeviceId, isRecord ? RecordReplayTy::RRStatusTy::RRRecording : RecordReplayTy::RRStatusTy::RRReplaying; - if (auto Err = RecordReplay.init(&Device, MemorySize, VAddr, Status, - SaveOutput, ReqPtrArgOffset)) { + if (auto Err = RecordReplay->init(&Device, MemorySize, VAddr, Status, + SaveOutput, ReqPtrArgOffset)) { REPORT("WARNING RR did not intialize RR-properly with %lu bytes" "(Error: %s)\n", MemorySize, toString(std::move(Err)).data()); - RecordReplay.setStatus(RecordReplayTy::RRStatusTy::RRDeactivated); + RecordReplay->setStatus(RecordReplayTy::RRStatusTy::RRDeactivated); if (!isRecord) { return OFFLOAD_FAIL; @@ -1984,8 +1993,8 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size, assert(DevicePtr && "Invalid device global's address"); // Save the loaded globals if we are recording. - if (RecordReplay.isRecording()) - RecordReplay.addEntry(Name, Size, *DevicePtr); + if (getRecordAndReplay().isRecording()) + getRecordAndReplay().addEntry(Name, Size, *DevicePtr); return OFFLOAD_SUCCESS; }