diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 02881109f17d2..fd4b416ec8792 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2706,9 +2706,8 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I, return LastInst; } -/// Match UB-safe variants of the funnel shift intrinsic. -static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, - const DominatorTree &DT) { +std::optional>> +InstCombinerImpl::convertShlOrLShrToFShlOrFShr(Instruction &Or) { // TODO: Can we reduce the code duplication between this and the related // rotate matching code under visitSelect and visitTrunc? unsigned Width = Or.getType()->getScalarSizeInBits(); @@ -2716,7 +2715,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, Instruction *Or0, *Or1; if (!match(Or.getOperand(0), m_Instruction(Or0)) || !match(Or.getOperand(1), m_Instruction(Or1))) - return nullptr; + return std::nullopt; bool IsFshl = true; // Sub on LSHR. SmallVector FShiftArgs; @@ -2730,7 +2729,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || Or0->getOpcode() == Or1->getOpcode()) - return nullptr; + return std::nullopt; // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). if (Or0->getOpcode() == BinaryOperator::LShr) { @@ -2766,7 +2765,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, // might remove it after this fold). This still doesn't guarantee that the // final codegen will match this original pattern. if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { - KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); + KnownBits KnownL = computeKnownBits(L, /*Depth*/ 0, &Or); return KnownL.getMaxValue().ult(Width) ? L : nullptr; } @@ -2810,7 +2809,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, IsFshl = false; // Sub on SHL. } if (!ShAmt) - return nullptr; + return std::nullopt; FShiftArgs = {ShVal0, ShVal1, ShAmt}; } else if (isa(Or0) || isa(Or1)) { @@ -2832,18 +2831,18 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, const APInt *ZextHighShlAmt; if (!match(Or0, m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt))))) - return nullptr; + return std::nullopt; if (!match(Or1, m_ZExt(m_Value(Low))) || !match(ZextHigh, m_ZExt(m_Value(High)))) - return nullptr; + return std::nullopt; unsigned HighSize = High->getType()->getScalarSizeInBits(); unsigned LowSize = Low->getType()->getScalarSizeInBits(); // Make sure High does not overlap with Low and most significant bits of // High aren't shifted out. if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize)) - return nullptr; + return std::nullopt; for (User *U : ZextHigh->users()) { Value *X, *Y; @@ -2874,11 +2873,22 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, } if (FShiftArgs.empty()) - return nullptr; + return std::nullopt; Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; - Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); - return CallInst::Create(F, FShiftArgs); + return std::make_tuple(IID, FShiftArgs); +} + +/// Match UB-safe variants of the funnel shift intrinsic. +static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, + const DominatorTree &DT) { + if (auto Opt = IC.convertShlOrLShrToFShlOrFShr(Or)) { + auto [IID, FShiftArgs] = *Opt; + Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); + return CallInst::Create(F, FShiftArgs); + } + + return nullptr; } /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 0bbb22be71569..303d02cc24fc9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -236,6 +236,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final return getLosslessTrunc(C, TruncTy, Instruction::SExt); } + std::optional>> + convertShlOrLShrToFShlOrFShr(Instruction &Or); + private: bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); bool isDesirableIntType(unsigned BitWidth) const;