diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 97f129e200de7..dd1e8da2eb485 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1121,7 +1121,7 @@ static const unsigned MaxDepth = 6; // actual instructions, otherwise return a non-null dummy value. Return nullptr // on failure. static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold) { + bool AssumeNonZero, bool DoFold) { auto IfFold = [DoFold](function_ref Fn) { if (!DoFold) return reinterpret_cast(-1); @@ -1147,14 +1147,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(X << Y) -> log2(X) + Y // FIXME: Require one use unless X is 1? - if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) - return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) { + auto *BO = cast(Op); + // nuw will be set if the `shl` is trivially non-zero. + if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: missed optimization: if one of the hands of select is/contains @@ -1162,8 +1166,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: can both hands contain undef? // FIXME: Require one use? if (SelectInst *SI = dyn_cast(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1171,13 +1177,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) auto *MinMax = dyn_cast(Op); - if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) { + // Use AssumeNonZero as false here. Otherwise we can hit case where + // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) return IfFold([&]() { - return Builder.CreateBinaryIntrinsic( - MinMax->getIntrinsicID(), LogX, LogY); + return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, + LogY); }); + } return nullptr; } @@ -1297,8 +1308,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. - if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, + /*AssumeNonZero*/ true, /*DoFold*/ true); return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); }