From 0399473de886595d8ce3346f2cc99c94267496e5 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Sun, 19 Jun 2022 10:59:05 -0400 Subject: [PATCH] [InstCombine] add fold for (ShiftC >> X) isExact() && C.isZero()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); - const APInt *ShiftVal; - if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal); + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + const APInt *ShiftValC; + if (match(Shr->getOperand(0), m_APInt(ShiftValC))) { + if (Cmp.isEquality()) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC); - const APInt *ShiftAmt; - if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) + // If the shifted constant is a power-of-2, test the shift amount directly: + // (ShiftValC >> X) >u C --> X <u (LZ(C) - LZ(ShiftValC)) + if (!IsAShr && Pred == CmpInst::ICMP_UGT && ShiftValC->isPowerOf2()) { + assert(ShiftValC->ugt(C) && "Expected simplify of compare"); + unsigned CmpLZ = C.countLeadingZeros(); + unsigned ShiftLZ = ShiftValC->countLeadingZeros(); + Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ); + return new ICmpInst(ICmpInst::ICMP_ULT, Shr->User::getOperand(1), NewC); + } + } + + const APInt *ShiftAmtC; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC))) return nullptr; // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited it will be simplified. 
unsigned TypeBits = C.getBitWidth(); - unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); + unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) return nullptr; - bool IsAShr = Shr->getOpcode() == Instruction::AShr; bool IsExact = Shr->isExact(); Type *ShrTy = Shr->getType(); // TODO: If we could guarantee that InstSimplify would handle all of the diff --git a/llvm/test/Transforms/InstCombine/icmp-shr.ll b/llvm/test/Transforms/InstCombine/icmp-shr.ll index 2d78916356ab73..f013e6bdcbe2c4 100644 --- a/llvm/test/Transforms/InstCombine/icmp-shr.ll +++ b/llvm/test/Transforms/InstCombine/icmp-shr.ll @@ -1008,8 +1008,7 @@ define i1 @ashr_exact_ne_0_multiuse(i8 %x) { define i1 @lshr_pow2_ugt(i8 %x) { ; CHECK-LABEL: @lshr_pow2_ugt( -; CHECK-NEXT: [[S:%.*]] = lshr i8 2, [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[S]], 1 +; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[X:%.*]], 0 ; CHECK-NEXT: ret i1 [[R]] ; %s = lshr i8 2, %x @@ -1021,7 +1020,7 @@ define i1 @lshr_pow2_ugt_use(i8 %x) { ; CHECK-LABEL: @lshr_pow2_ugt_use( ; CHECK-NEXT: [[S:%.*]] = lshr i8 -128, [[X:%.*]] ; CHECK-NEXT: call void @use(i8 [[S]]) -; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[S]], 5 +; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[X]], 5 ; CHECK-NEXT: ret i1 [[R]] ; %s = lshr i8 128, %x @@ -1032,8 +1031,7 @@ define i1 @lshr_pow2_ugt_use(i8 %x) { define <2 x i1> @lshr_pow2_ugt_vec(<2 x i8> %x) { ; CHECK-LABEL: @lshr_pow2_ugt_vec( -; CHECK-NEXT: [[S:%.*]] = lshr <2 x i8> , [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = icmp ugt <2 x i8> [[S]], +; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer ; CHECK-NEXT: ret <2 x i1> [[R]] ; %s = lshr <2 x i8> , %x @@ -1041,6 +1039,8 @@ define <2 x i1> @lshr_pow2_ugt_vec(<2 x i8> %x) { ret <2 x i1> %r } +; negative test - need power-of-2 + define i1 @lshr_not_pow2_ugt(i8 %x) { ; CHECK-LABEL: @lshr_not_pow2_ugt( ; CHECK-NEXT: [[S:%.*]] = lshr i8 3, [[X:%.*]] @@ -1054,8 +1054,7 @@ define i1 @lshr_not_pow2_ugt(i8 %x) { 
define i1 @lshr_pow2_ugt1(i8 %x) { ; CHECK-LABEL: @lshr_pow2_ugt1( -; CHECK-NEXT: [[S:%.*]] = lshr i8 -128, [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[S]], 1 +; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[X:%.*]], 7 ; CHECK-NEXT: ret i1 [[R]] ; %s = lshr i8 128, %x @@ -1063,6 +1062,8 @@ define i1 @lshr_pow2_ugt1(i8 %x) { ret i1 %r } +; negative test - need logical shift + define i1 @ashr_pow2_ugt(i8 %x) { ; CHECK-LABEL: @ashr_pow2_ugt( ; CHECK-NEXT: [[S:%.*]] = ashr i8 -128, [[X:%.*]] @@ -1074,6 +1075,8 @@ define i1 @ashr_pow2_ugt(i8 %x) { ret i1 %r } +; negative test - need unsigned pred + define i1 @lshr_pow2_sgt(i8 %x) { ; CHECK-LABEL: @lshr_pow2_sgt( ; CHECK-NEXT: [[S:%.*]] = lshr i8 -128, [[X:%.*]]