diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index db27711f29b178..4b7606e6021b7c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2297,7 +2297,8 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, /// funnel shift intrinsic. Example: /// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) /// --> call llvm.fshl.i32(a, a, b) -static Instruction *foldSelectRotate(SelectInst &Sel) { +static Instruction *foldSelectRotate(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { // The false value of the select must be a rotate of the true value. Value *Or0, *Or1; if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) @@ -2305,8 +2306,10 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { Value *TVal = Sel.getTrueValue(); Value *SA0, *SA1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1))))) + if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), + m_ZExtOrSelf(m_Value(SA0))))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), + m_ZExtOrSelf(m_Value(SA1)))))) return nullptr; auto ShiftOpcode0 = cast(Or0)->getOpcode(); @@ -2344,6 +2347,7 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); + ShAmt = Builder.CreateZExt(ShAmt, Sel.getType()); return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); } @@ -2960,7 +2964,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) return Select; - if (Instruction *Rot = foldSelectRotate(SI)) + if (Instruction *Rot = foldSelectRotate(SI, Builder)) return Rot; if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll index 3dc65b01bc21bb..2b8d9fc1e3d15f 100644 --- a/llvm/test/Transforms/InstCombine/rotate.ll +++ b/llvm/test/Transforms/InstCombine/rotate.ll @@ -691,15 +691,8 @@ define i24 @rotl_select_weird_type(i24 %x, i24 %shamt) { define i32 @rotl_select_zext_shamt(i32 %x, i8 %y) { ; CHECK-LABEL: @rotl_select_zext_shamt( -; CHECK-NEXT: [[REM:%.*]] = and i8 [[Y:%.*]], 31 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[REM]], 0 -; CHECK-NEXT: [[SH_PROM:%.*]] = zext i8 [[REM]] to i32 -; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i8 32, [[REM]] -; CHECK-NEXT: [[SH_PROM1:%.*]] = zext i8 [[SUB]] to i32 -; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[X:%.*]], [[SH_PROM1]] -; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[X]], [[SH_PROM]] -; CHECK-NEXT: [[OR:%.*]] = or i32 [[SHL]], [[SHR]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32 +; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[X]], i32 [[TMP1]]) ; CHECK-NEXT: ret i32 [[R]] ; %rem = and i8 %y, 31 @@ -716,15 +709,8 @@ define i32 @rotl_select_zext_shamt(i32 %x, i8 %y) { define i64 @rotr_select_zext_shamt(i64 %x, i32 %y) { ; CHECK-LABEL: @rotr_select_zext_shamt( -; CHECK-NEXT: [[REM:%.*]] = and i32 [[Y:%.*]], 63 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0 -; CHECK-NEXT: [[SH_PROM:%.*]] = zext i32 [[REM]] to i64 -; CHECK-NEXT: [[SHR:%.*]] = lshr i64 [[X:%.*]], [[SH_PROM]] -; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 64, [[REM]] -; CHECK-NEXT: [[SH_PROM1:%.*]] = zext i32 [[SUB]] to i64 -; CHECK-NEXT: [[SHL:%.*]] = shl i64 [[X]], [[SH_PROM1]] -; CHECK-NEXT: [[OR:%.*]] = or i64 [[SHL]], [[SHR]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i64 [[X]], i64 [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[Y:%.*]] to i64 +; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.fshr.i64(i64 [[X:%.*]], i64 [[X]], i64 [[TMP1]]) ; CHECK-NEXT: ret i64 [[R]] ; %rem = and i32 %y, 63