Skip to content

Commit

Permalink
[OpenMP][FIX] Ensure recording works properly w/ late allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
jdoerfert committed Nov 21, 2023
1 parent 6663df3 commit 41566fb
Showing 1 changed file with 28 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ struct RecordReplayTy {
OS.close();
}

void saveKernelDescr(const char *Name, void **ArgPtrs, ptrdiff_t *ArgOffsets,
int32_t NumArgs, uint64_t NumTeamsClause,
uint32_t ThreadLimitClause, uint64_t LoopTripCount) {
void saveKernelDescr(const char *Name, void **ArgPtrs, int32_t NumArgs,
uint64_t NumTeamsClause, uint32_t ThreadLimitClause,
uint64_t LoopTripCount) {
json::Object JsonKernelInfo;
JsonKernelInfo["Name"] = Name;
JsonKernelInfo["NumArgs"] = NumArgs;
Expand All @@ -251,7 +251,7 @@ struct RecordReplayTy {

json::Array JsonArgOffsets;
for (int I = 0; I < NumArgs; ++I)
JsonArgOffsets.push_back(ArgOffsets[I]);
JsonArgOffsets.push_back(0);
JsonKernelInfo["ArgOffsets"] = json::Value(std::move(JsonArgOffsets));

SmallString<128> JsonFilename = {Name, ".json"};
Expand Down Expand Up @@ -427,6 +427,11 @@ Expected<KernelLaunchEnvironmentTy *>
GenericKernelTy::getKernelLaunchEnvironment(
GenericDeviceTy &GenericDevice,
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
// Ctor/Dtor have no arguments, replaying uses the original kernel launch
// environment.
if (isCtorOrDtor() || RecordReplay.isReplaying())
return nullptr;

// TODO: Check if the kernel needs a launch environment.
auto AllocOrErr = GenericDevice.dataAlloc(sizeof(KernelLaunchEnvironmentTy),
/*HostPtr=*/nullptr,
Expand Down Expand Up @@ -501,6 +506,15 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
NumThreads, KernelArgs.ThreadLimit[0] > 0);

// Record the kernel description after we modified the argument count and num
// blocks/threads.
if (RecordReplay.isRecording()) {
RecordReplay.saveImage(getName(), getImage());
RecordReplay.saveKernelInput(getName(), getImage());
RecordReplay.saveKernelDescr(getName(), Ptrs.data(), KernelArgs.NumArgs,
NumBlocks, NumThreads, KernelArgs.Tripcount);
}

if (auto Err =
printLaunchInfo(GenericDevice, KernelArgs, NumThreads, NumBlocks))
return Err;
Expand All @@ -517,16 +531,20 @@ void *GenericKernelTy::prepareArgs(
if (isCtorOrDtor())
return nullptr;

NumArgs += 1;
uint32_t KLEOffset = !!KernelLaunchEnvironment;
NumArgs += KLEOffset;

Args.resize(NumArgs);
Ptrs.resize(NumArgs);

Ptrs[0] = KernelLaunchEnvironment;
Args[0] = &Ptrs[0];
if (KernelLaunchEnvironment) {
Ptrs[0] = KernelLaunchEnvironment;
Args[0] = &Ptrs[0];
}

for (int I = 1; I < NumArgs; ++I) {
Ptrs[I] = (void *)((intptr_t)ArgPtrs[I - 1] + ArgOffsets[I - 1]);
for (int I = KLEOffset; I < NumArgs; ++I) {
Ptrs[I] =
(void *)((intptr_t)ArgPtrs[I - KLEOffset] + ArgOffsets[I - KLEOffset]);
Args[I] = &Ptrs[I];
}
return &Args[0];
Expand Down Expand Up @@ -808,7 +826,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
return std::move(Err);

// Setup the global device memory pool if needed.
if (shouldSetupDeviceMemoryPool()) {
if (!RecordReplay.isReplaying() && shouldSetupDeviceMemoryPool()) {
uint64_t HeapSize;
auto SizeOrErr = getDeviceHeapSize(HeapSize);
if (SizeOrErr) {
Expand Down Expand Up @@ -1413,21 +1431,9 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
GenericKernelTy &GenericKernel =
*reinterpret_cast<GenericKernelTy *>(EntryPtr);

if (RecordReplay.isRecording()) {
RecordReplay.saveImage(GenericKernel.getName(), GenericKernel.getImage());
RecordReplay.saveKernelInput(GenericKernel.getName(),
GenericKernel.getImage());
}

auto Err = GenericKernel.launch(*this, ArgPtrs, ArgOffsets, KernelArgs,
AsyncInfoWrapper);

if (RecordReplay.isRecording())
RecordReplay.saveKernelDescr(GenericKernel.getName(), ArgPtrs, ArgOffsets,
KernelArgs.NumArgs, KernelArgs.NumTeams[0],
KernelArgs.ThreadLimit[0],
KernelArgs.Tripcount);

// 'finalize' here to guarantee next record-replay actions are in-sync
AsyncInfoWrapper.finalize(Err);

Expand Down

0 comments on commit 41566fb

Please sign in to comment.