diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 55266ff7c7ebf4..b8d4dddbfd15b9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2337,8 +2337,8 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, static Instruction *foldSelectRotate(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { // The false value of the select must be a rotate of the true value. - Value *Or0, *Or1; - if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + BinaryOperator *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) return nullptr; Value *TVal = Sel.getTrueValue(); @@ -2346,16 +2346,20 @@ static Instruction *foldSelectRotate(SelectInst &Sel, if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_ZExtOrSelf(m_Value(SA0))))) || !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), - m_ZExtOrSelf(m_Value(SA1)))))) + m_ZExtOrSelf(m_Value(SA1))))) || + Or0->getOpcode() == Or1->getOpcode()) return nullptr; - auto ShiftOpcode0 = cast(Or0)->getOpcode(); - auto ShiftOpcode1 = cast(Or1)->getOpcode(); - if (ShiftOpcode0 == ShiftOpcode1) - return nullptr; + // Canonicalize to or(shl(TVal, SA0), lshr(TVal, SA1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(SA0, SA1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); - // We have one of these patterns so far: - // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1)) + // We should now have this pattern: // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) // This must be a power-of-2 rotate for a bitmasking transform to be valid. unsigned Width = Sel.getType()->getScalarSizeInBits(); @@ -2380,8 +2384,7 @@ static Instruction *foldSelectRotate(SelectInst &Sel, // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. // Convert to funnel shift intrinsic. - bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) || - (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); + bool IsFshl = (ShAmt == SA0); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); ShAmt = Builder.CreateZExt(ShAmt, Sel.getType());