Skip to content

Commit

Permalink
[CodeGenPrepare] Limit recursion depth for collectBitParts
Browse files Browse the repository at this point in the history
Summary:
Seeing some issues for windows debug pathological cases with collectBitParts
recursion (1525 levels of recursion!)
Setting the limit to 64 as this should be sufficient - passes all lit cases

Subscribers: llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D61728

Change-Id: I7f44cdc6c1badf1c2ccbf1b0c4b6afe27ecb39a1
llvm-svn: 360347
  • Loading branch information
dstutt committed May 9, 2019
1 parent f58a5c8 commit 411488b
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions llvm/lib/Transforms/Utils/Local.cpp
Expand Up @@ -91,6 +91,10 @@ using namespace llvm::PatternMatch;

STATISTIC(NumRemoved, "Number of unreachable basic blocks removed");

// Max recursion depth for collectBitParts used when detecting bswap and
// bitreverse idioms
static const unsigned BitPartRecursionMaxDepth = 64;

//===----------------------------------------------------------------------===//
// Local constant propagation.
//
Expand Down Expand Up @@ -2619,21 +2623,27 @@ struct BitPart {
/// does not invalidate internal references (std::map instead of DenseMap).
static const Optional<BitPart> &
collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
std::map<Value *, Optional<BitPart>> &BPS) {
std::map<Value *, Optional<BitPart>> &BPS, int Depth) {
auto I = BPS.find(V);
if (I != BPS.end())
return I->second;

auto &Result = BPS[V] = None;
auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth();

// Prevent stack overflow by limiting the recursion depth
if (Depth == BitPartRecursionMaxDepth) {
LLVM_DEBUG(dbgs() << "collectBitParts max recursion depth reached.\n");
return Result;
}

if (Instruction *I = dyn_cast<Instruction>(V)) {
// If this is an or instruction, it may be an inner node of the bswap.
if (I->getOpcode() == Instruction::Or) {
auto &A = collectBitParts(I->getOperand(0), MatchBSwaps,
MatchBitReversals, BPS);
MatchBitReversals, BPS, Depth + 1);
auto &B = collectBitParts(I->getOperand(1), MatchBSwaps,
MatchBitReversals, BPS);
MatchBitReversals, BPS, Depth + 1);
if (!A || !B)
return Result;

Expand Down Expand Up @@ -2666,7 +2676,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;

auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
MatchBitReversals, BPS);
MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;
Result = Res;
Expand Down Expand Up @@ -2698,7 +2708,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;

auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
MatchBitReversals, BPS);
MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;
Result = Res;
Expand All @@ -2713,7 +2723,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a zext instruction zero extend the result.
if (I->getOpcode() == Instruction::ZExt) {
auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
MatchBitReversals, BPS);
MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;

Expand Down Expand Up @@ -2775,7 +2785,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(

// Try to find all the pieces corresponding to the bswap.
std::map<Value *, Optional<BitPart>> BPS;
auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS);
auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0);
if (!Res)
return false;
auto &BitProvenance = Res->Provenance;
Expand Down

0 comments on commit 411488b

Please sign in to comment.