[OpenMPOpt][FIX] Ensure to propagate information about parallel regions
Before, we checked the parallel region only once and ignored later
updates to the KernelInfo of that parallel region. Among other things,
this caused us to conclude that nested parallel sections are absent even
when they are present.
jdoerfert committed Aug 25, 2023
1 parent 01a92f0 commit a013981
Showing 2 changed files with 103 additions and 28 deletions.
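For context on the commit message above, the sketch below is a trimmed-down
variant of the new regression test added by this commit: a target kernel
whose parallel region launches another parallel region. The trip counts and
the printed expectation are illustrative only, not taken from the patch.

#include <cstdio>

#define N 10

int main() {
  long aa = 0;

  // Kernel with a parallel region that itself contains another parallel
  // region; the fix makes OpenMPOpt notice the nested one.
#pragma omp target teams distribute num_teams(4) map(tofrom : aa)
  for (int gid = 0; gid < 4; ++gid) {
#pragma omp parallel for
    for (int g = 0; g < 2; ++g) {
      int a = 0;
      // Nested parallel region inside the outlined parallel function.
#pragma omp parallel for reduction(+ : a)
      for (int i = 0; i < N; ++i)
        a += i;
#pragma omp atomic
      aa += a;
    }
  }

  // 4 teams * 2 outer iterations * (0 + 1 + ... + 9) = 8 * 45 = 360.
  std::printf("aa = %ld (expected 360)\n", aa);
  return 0;
}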
74 changes: 46 additions & 28 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -43,6 +43,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -670,7 +671,7 @@ struct KernelInfoState : AbstractState {

/// The parallel regions (identified by the outlined parallel functions) that
/// can be reached from the associated function.
BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
ReachedKnownParallelRegions;

/// State to track what parallel region we might reach.
@@ -4455,11 +4456,15 @@ struct AAKernelInfoFunction : AAKernelInfo {
Value *ZeroArg =
Constant::getNullValue(ParallelRegionFnTy->getParamType(0));

const unsigned int WrapperFunctionArgNo = 6;

// Now that we have most of the CFG skeleton it is time for the if-cascade
// that checks the function pointer we got from the runtime against the
// parallel regions we expect, if there are any.
for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
auto *ParallelRegion = ReachedKnownParallelRegions[I];
auto *CB = ReachedKnownParallelRegions[I];
auto *ParallelRegion = dyn_cast<Function>(
CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
BasicBlock *PRExecuteBB = BasicBlock::Create(
Ctx, "worker_state_machine.parallel_region.execute", Kernel,
StateMachineEndParallelBB);
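
For orientation, WrapperFunctionArgNo = 6 refers to the wrapper-function
operand of the device runtime entry point. Paraphrasing its declaration
(parameter names and exact integer types here are approximate, not taken
from the patch):

void __kmpc_parallel_51(ident_t *loc, int32_t gtid, int32_t if_expr,
                        int32_t num_threads, int32_t proc_bind,
                        void *fn,         // argument 5: outlined function
                        void *wrapper_fn, // argument 6: wrapper function
                        void **args, int64_t nargs);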
@@ -4822,8 +4827,6 @@ struct AAKernelInfoCallSite : AAKernelInfo {
return;
}

const unsigned int NonWrapperFunctionArgNo = 5;
const unsigned int WrapperFunctionArgNo = 6;
RuntimeFunction RF = It->getSecond();
switch (RF) {
// All the functions we know are compatible with SPMD mode.
@@ -4902,28 +4905,10 @@ struct AAKernelInfoCallSite : AAKernelInfo {
case OMPRTL___kmpc_target_deinit:
KernelDeinitCB = &CB;
break;
case OMPRTL___kmpc_parallel_51: {
auto *ParallelRegionOp =
CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts();
if (isa<ConstantPointerNull>(ParallelRegionOp))
ParallelRegionOp =
CB.getArgOperand(NonWrapperFunctionArgNo)->stripPointerCasts();
if (auto *ParallelRegion = dyn_cast<Function>(ParallelRegionOp)) {
ReachedKnownParallelRegions.insert(ParallelRegion);
/// Check nested parallelism
auto *FnAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
!FnAA->ReachedKnownParallelRegions.empty() ||
!FnAA->ReachedUnknownParallelRegions.empty();
break;
}
// The condition above should usually get the parallel region function
// pointer and record it. In the off chance it doesn't we assume the
// worst.
ReachedUnknownParallelRegions.insert(&CB);
break;
}
case OMPRTL___kmpc_parallel_51:
if (!handleParallel51(A, CB))
indicatePessimisticFixpoint();
return;
case OMPRTL___kmpc_omp_task:
// We do not look into tasks right now, just give up.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
@@ -4969,14 +4954,21 @@ struct AAKernelInfoCallSite : AAKernelInfo {
return ChangeStatus::CHANGED;
}

KernelInfoState StateBefore = getState();
CallBase &CB = cast<CallBase>(getAssociatedValue());
if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
if (!handleParallel51(A, CB))
return indicatePessimisticFixpoint();
return StateBefore == getState() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}

// F is a runtime function that allocates or frees memory, check
// AAHeapToStack and AAHeapToShared.
KernelInfoState StateBefore = getState();
assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
It->getSecond() == OMPRTL___kmpc_free_shared) &&
"Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

CallBase &CB = cast<CallBase>(getAssociatedValue());

auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
*this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
@@ -5008,6 +5000,32 @@ struct AAKernelInfoCallSite : AAKernelInfo {
return StateBefore == getState() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}

/// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
/// handled; if a problem occurred, false is returned.
bool handleParallel51(Attributor &A, CallBase &CB) {
const unsigned int NonWrapperFunctionArgNo = 5;
const unsigned int WrapperFunctionArgNo = 6;
auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
? NonWrapperFunctionArgNo
: WrapperFunctionArgNo;

auto *ParallelRegion = dyn_cast<Function>(
CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
if (!ParallelRegion)
return false;

ReachedKnownParallelRegions.insert(&CB);
/// Check nested parallelism
auto *FnAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
!FnAA->ReachedKnownParallelRegions.empty() ||
!FnAA->ReachedKnownParallelRegions.isValidState() ||
!FnAA->ReachedUnknownParallelRegions.isValidState() ||
!FnAA->ReachedUnknownParallelRegions.empty();
return true;
}
};

struct AAFoldRuntimeCall
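The shape of the fix, in short: __kmpc_parallel_51 calls are now funneled
through handleParallel51 from both initialize and updateImpl, so the
call-site information is re-derived on every Attributor iteration rather
than recorded once up front, and ReachedKnownParallelRegions stores the
call itself (a CallBase) instead of the outlined Function. Below is a
minimal standalone sketch of why a single query is not enough under
fixpoint iteration; the types and two-round driver are illustrative, not
Attributor code.

#include <cstdio>
#include <vector>

// Callee-side state that may still grow while the solver iterates.
struct CalleeState {
  std::vector<int> ReachedParallelRegions;
};

// Call-site state; update() mirrors the updateImpl -> handleParallel51 path.
struct CallSiteState {
  bool NestedParallelism = false;
  void update(const CalleeState &Callee) {
    NestedParallelism |= !Callee.ReachedParallelRegions.empty();
  }
};

int main() {
  CalleeState Callee;
  CallSiteState CS;

  CS.update(Callee);                           // round 1: callee still looks empty
  Callee.ReachedParallelRegions.push_back(42); // callee learns about a region
  CS.update(Callee);                           // round 2: the new fact propagates

  // Querying only once (the old behavior) would have left this at 0.
  std::printf("NestedParallelism = %d\n", CS.NestedParallelism);
  return 0;
}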
57 changes: 57 additions & 0 deletions openmp/libomptarget/test/offloading/bug64959.c
@@ -0,0 +1,57 @@
// RUN: %libomptarget-compilexx-run-and-check-generic
// RUN: %libomptarget-compileoptxx-run-and-check-generic

// TODO: This requires malloc support for the thread states.
// UNSUPPORTED: amdgcn-amd-amdhsa

#include <omp.h>
#include <stdio.h>
#define N 10

int isCPU() { return 1; }

#pragma omp begin declare variant match(device = {kind(gpu)})
int isCPU() { return 0; }
#pragma omp end declare variant

int main(void) {
long int aa = 0;
int res = 0;

int ng = 12;
int cmom = 14;
int nxyz;

#pragma omp target map(from : nxyz, ng, cmom)
{
nxyz = isCPU() ? 2 : 5000;
ng = isCPU() ? 2 : 12;
cmom = isCPU() ? 2 : 14;
}

#pragma omp target teams distribute num_teams(nxyz) \
thread_limit(ng *(cmom - 1)) map(tofrom : aa)
for (int gid = 0; gid < nxyz; gid++) {
#pragma omp parallel for collapse(2)
for (unsigned int g = 0; g < ng; g++) {
for (unsigned int l = 0; l < cmom - 1; l++) {
int a = 0;
#pragma omp parallel for reduction(+ : a)
for (int i = 0; i < N; i++) {
a += i;
}
#pragma omp atomic
aa += a;
}
}
}
long exp = (long)ng * (cmom - 1) * nxyz * (N * (N - 1) / 2);
printf("The result is = %ld exp:%ld!\n", aa, exp);
if (aa != exp) {
printf("Failed %ld\n", aa);
return 1;
}
// CHECK: Success
printf("Success\n");
return 0;
}
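
For reference, the value the test checks against: on the GPU path the first
target region sets nxyz = 5000, ng = 12, and cmom = 14, so
exp = ng * (cmom - 1) * nxyz * (N * (N - 1) / 2) = 12 * 13 * 5000 * 45 =
35,100,000, while on the host-fallback path (all three set to 2) it is
2 * 1 * 2 * 45 = 180. The CHECK line is only reached when every nested inner
parallel loop actually ran and accumulated into aa.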
