diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 74bcd0c14827f9..2b42b50d5b12ea 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -2879,7 +2879,8 @@ struct BitPart { /// does not invalidate internal references (std::map instead of DenseMap). static const Optional & collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, - std::map> &BPS, int Depth) { + std::map> &BPS, int Depth, + bool &FoundRoot) { auto I = BPS.find(V); if (I != BPS.end()) return I->second; @@ -2904,13 +2905,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If this is an or instruction, it may be an inner node of the bswap. if (match(V, m_Or(m_Value(X), m_Value(Y)))) { // Check we have both sources and they are from the same provider. - const auto &A = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &A = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!A || !A->Provider) return Result; - const auto &B = - collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &B = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!B || A->Provider != B->Provider) return Result; @@ -2943,8 +2944,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, if (!MatchBitReversals && (BitShift.getZExtValue() % 8) != 0) return Result; - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; Result = Res; @@ -2973,8 +2974,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, if (!MatchBitReversals && (NumMaskedBits % 8) != 0) return Result; - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; Result = Res; @@ -2988,8 +2989,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If this is a zext instruction zero extend the result. if (match(V, m_ZExt(m_Value(X)))) { - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; @@ -3004,8 +3005,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If this is a truncate instruction, extract the lower bits. if (match(V, m_Trunc(m_Value(X)))) { - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; @@ -3018,8 +3019,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // BITREVERSE - most likely due to us previous matching a partial // bitreverse. if (match(V, m_BitReverse(m_Value(X)))) { - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; @@ -3031,8 +3032,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // BSWAP - most likely due to us previous matching a partial bswap. if (match(V, m_BSwap(m_Value(X)))) { - const auto &Res = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!Res) return Result; @@ -3063,13 +3064,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, return Result; // Check we have both sources and they are from the same provider. - const auto &LHS = - collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &LHS = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!LHS || !LHS->Provider) - return Result; + return Result; - const auto &RHS = - collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1); + const auto &RHS = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, + Depth + 1, FoundRoot); if (!RHS || LHS->Provider != RHS->Provider) return Result; @@ -3083,8 +3084,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, } } - // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be - // the input value to the bswap/bitreverse. + // If we've already found a root input value then we're never going to merge + // these back together. + if (FoundRoot) + return Result; + + // Okay, we got to something that isn't a shift, 'or', 'and', etc. This must + // be the root input value to the bswap/bitreverse. + FoundRoot = true; Result = BitPart(V, BitWidth); for (unsigned BitIdx = 0; BitIdx < BitWidth; ++BitIdx) Result->Provenance[BitIdx] = BitIdx; @@ -3126,8 +3133,10 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( DemandedTy = Trunc->getType(); // Try to find all the pieces corresponding to the bswap. + bool FoundRoot = false; std::map> BPS; - auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0); + const auto &Res = + collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0, FoundRoot); if (!Res) return false; ArrayRef BitProvenance = Res->Provenance;