Skip to content

Commit

Permalink
[Attributor][FIX] Avoid dangling stack references in map
Browse files Browse the repository at this point in the history
The old code did not account for new queries during an update, which
caused us to leave stack RQIs in the map. We are now explicit about
temporary vs non-temporary RQIs.

Fixes: #64959
  • Loading branch information
jdoerfert committed Aug 24, 2023
1 parent 3611300 commit d2c37fc
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 30 deletions.
64 changes: 34 additions & 30 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3541,24 +3541,24 @@ struct CachedReachabilityAA : public BaseTy {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
InUpdate = true;
for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) {
RQITy *RQI = QueryVector[u];
if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI))
if (RQI->Result == RQITy::Reachable::No &&
isReachableImpl(A, *RQI, /*IsTemporaryRQI=*/false))
Changed = ChangeStatus::CHANGED;
}
InUpdate = false;
return Changed;
}

virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0;
virtual bool isReachableImpl(Attributor &A, RQITy &RQI,
bool IsTemporaryRQI) = 0;

bool rememberResult(Attributor &A, typename RQITy::Reachable Result,
RQITy &RQI, bool UsedExclusionSet) {
RQITy &RQI, bool UsedExclusionSet, bool IsTemporaryRQI) {
RQI.Result = Result;

// Remove the temporary RQI from the cache.
if (!InUpdate)
if (IsTemporaryRQI)
QueryCache.erase(&RQI);

// Insert a plain RQI (w/o exclusion set) if that makes sense. Two options:
Expand All @@ -3576,7 +3576,7 @@ struct CachedReachabilityAA : public BaseTy {
}

// Check if we need to insert a new permanent RQI with the exclusion set.
if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
if (IsTemporaryRQI && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) &&
"Did not expect empty set!");
RQITy *RQIPtr = new (A.Allocator)
Expand All @@ -3588,7 +3588,7 @@ struct CachedReachabilityAA : public BaseTy {
QueryCache.insert(RQIPtr);
}

if (Result == RQITy::Reachable::No && !InUpdate)
if (Result == RQITy::Reachable::No && IsTemporaryRQI)
A.registerForUpdate(*this);
return Result == RQITy::Reachable::Yes;
}
Expand Down Expand Up @@ -3629,7 +3629,6 @@ struct CachedReachabilityAA : public BaseTy {
}

private:
bool InUpdate = false;
SmallVector<RQITy *> QueryVector;
DenseSet<RQITy *> QueryCache;
};
Expand All @@ -3653,7 +3652,8 @@ struct AAIntraFnReachabilityFunction final
RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
return NonConstThis->isReachableImpl(A, StackRQI);
return NonConstThis->isReachableImpl(A, StackRQI,
/*IsTemporaryRQI=*/true);
return Result == RQITy::Reachable::Yes;
}

Expand All @@ -3678,7 +3678,8 @@ struct AAIntraFnReachabilityFunction final
return Base::updateImpl(A);
}

bool isReachableImpl(Attributor &A, RQITy &RQI) override {
bool isReachableImpl(Attributor &A, RQITy &RQI,
bool IsTemporaryRQI) override {
const Instruction *Origin = RQI.From;
bool UsedExclusionSet = false;

Expand All @@ -3704,12 +3705,14 @@ struct AAIntraFnReachabilityFunction final
// possible.
if (FromBB == ToBB &&
WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet))
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
IsTemporaryRQI);

// Check if reaching the ToBB block is sufficient or if even that would not
// ensure reaching the target. In the latter case we are done.
if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet))
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
IsTemporaryRQI);

const Function *Fn = FromBB->getParent();
SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks;
Expand All @@ -3722,13 +3725,14 @@ struct AAIntraFnReachabilityFunction final
if (ExclusionBlocks.count(FromBB) &&
!WillReachInBlock(*RQI.From, *FromBB->getTerminator(),
RQI.ExclusionSet))
return rememberResult(A, RQITy::Reachable::No, RQI, true);
return rememberResult(A, RQITy::Reachable::No, RQI, true, IsTemporaryRQI);

auto *LivenessAA =
A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
if (LivenessAA && LivenessAA->isAssumedDead(ToBB)) {
DeadBlocks.insert(ToBB);
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
IsTemporaryRQI);
}

SmallPtrSet<const BasicBlock *, 16> Visited;
Expand All @@ -3747,11 +3751,11 @@ struct AAIntraFnReachabilityFunction final
}
// We checked before if we just need to reach the ToBB block.
if (SuccBB == ToBB)
return rememberResult(A, RQITy::Reachable::Yes, RQI,
UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
IsTemporaryRQI);
if (DT && ExclusionBlocks.empty() && DT->dominates(BB, ToBB))
return rememberResult(A, RQITy::Reachable::Yes, RQI,
UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
IsTemporaryRQI);

if (ExclusionBlocks.count(SuccBB)) {
UsedExclusionSet = true;
Expand All @@ -3762,7 +3766,8 @@ struct AAIntraFnReachabilityFunction final
}

DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end());
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
IsTemporaryRQI);
}

/// See AbstractAttribute::trackStatistics()
Expand Down Expand Up @@ -10646,22 +10651,19 @@ struct AAInterFnReachabilityFunction
RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
return NonConstThis->isReachableImpl(A, StackRQI);
return NonConstThis->isReachableImpl(A, StackRQI,
/*IsTemporaryRQI=*/true);
return Result == RQITy::Reachable::Yes;
}

bool isReachableImpl(Attributor &A, RQITy &RQI) override {
return isReachableImpl(A, RQI, nullptr);
}

bool isReachableImpl(Attributor &A, RQITy &RQI,
SmallPtrSet<const Function *, 16> *Visited) {

bool IsTemporaryRQI) override {
const Instruction *EntryI =
&RQI.From->getFunction()->getEntryBlock().front();
if (EntryI != RQI.From &&
!instructionCanReach(A, *EntryI, *RQI.To, nullptr))
return rememberResult(A, RQITy::Reachable::No, RQI, false);
return rememberResult(A, RQITy::Reachable::No, RQI, false,
IsTemporaryRQI);

auto CheckReachableCallBase = [&](CallBase *CB) {
auto *CBEdges = A.getAAFor<AACallEdges>(
Expand Down Expand Up @@ -10721,9 +10723,11 @@ struct AAInterFnReachabilityFunction
if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this,
UsedAssumedInformation,
/* CheckBBLivenessOnly */ true))
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
IsTemporaryRQI);

return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
IsTemporaryRQI);
}

void trackStatistics() const override {}
Expand Down
40 changes: 40 additions & 0 deletions openmp/libomptarget/test/offloading/bug64959_compile_only.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: %libomptarget-compile-generic
// RUN: %libomptarget-compileopt-generic

#include <stdio.h>
#define N 10

int main(void) {
long int aa = 0;
int res = 0;

int ng = 12;
int cmom = 14;
int nxyz = 5000;

#pragma omp target teams distribute num_teams(nxyz) \
thread_limit(ng *(cmom - 1)) map(tofrom : aa)
for (int gid = 0; gid < nxyz; gid++) {
#pragma omp parallel for collapse(2)
for (unsigned int g = 0; g < ng; g++) {
for (unsigned int l = 0; l < cmom - 1; l++) {
int a = 0;
#pragma omp parallel for reduction(+ : a)
for (int i = 0; i < N; i++) {
a += i;
}
#pragma omp atomic
aa += a;
}
}
}
long exp = (long)ng * (cmom - 1) * nxyz * (N * (N - 1) / 2);
printf("The result is = %ld exp:%ld!\n", aa, exp);
if (aa != exp) {
printf("Failed %ld\n", aa);
return 1;
}
// CHECK: Success
printf("Success\n");
return 0;
}

0 comments on commit d2c37fc

Please sign in to comment.