diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 661c50062223c..3b7875dd761bc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2584,6 +2584,48 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
   return nullptr;
 }
 
+/// Tries to reduce a pattern that arises when calculating the remainder of the
+/// Euclidean division. When the divisor is a power of two and is guaranteed not
+/// to be negative, a signed remainder can be folded with a bitwise and.
+///
+/// (x % n) < 0 ? (x % n) + n : (x % n)
+///    -> x & (n - 1)
+static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
+                                       IRBuilderBase &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  ICmpInst::Predicate Pred;
+  Value *Op, *RemRes, *Remainder;
+  const APInt *C;
+  bool TrueIfSigned = false;
+
+  if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
+        IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
+    return nullptr;
+
+  // If the sign bit is not set, we have a SGE/SGT comparison, and the operands
+  // of the select are inverted.
+  if (!TrueIfSigned)
+    std::swap(TrueVal, FalseVal);
+
+  // We are matching a quite specific pattern here:
+  // %rem = srem i32 %x, %n
+  // %cnd = icmp slt i32 %rem, 0
+  // %add = add i32 %rem, %n
+  // %sel = select i1 %cnd, i32 %add, i32 %rem
+  if (!(match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) &&
+        match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
+        IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
+        FalseVal == RemRes))
+    return nullptr;
+
+  Value *Add = Builder.CreateAdd(Remainder,
+                                 Constant::getAllOnesValue(RemRes->getType()));
+  return BinaryOperator::CreateAnd(Op, Add);
+}
+
 static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
   FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
   if (!FI)
@@ -3430,6 +3472,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
+    return I;
+
   // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
   // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
   auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
diff --git a/llvm/test/Transforms/InstCombine/select-divrem.ll b/llvm/test/Transforms/InstCombine/select-divrem.ll
index 1343191e349d7..a5b56609d6062 100644
--- a/llvm/test/Transforms/InstCombine/select-divrem.ll
+++ b/llvm/test/Transforms/InstCombine/select-divrem.ll
@@ -216,10 +216,7 @@ define i5 @urem_common_dividend_defined_cond(i1 noundef %b, i5 %x, i5 %y, i5 %z)
 
 define i32 @rem_euclid_1(i32 %0) {
 ; CHECK-LABEL: @rem_euclid_1(
-; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i32 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[REM]], 8
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], i32 [[ADD]], i32 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %rem = srem i32 %0, 8
@@ -231,10 +228,7 @@ define i32 @rem_euclid_1(i32 %0) {
 
 define i32 @rem_euclid_2(i32 %0) {
 ; CHECK-LABEL: @rem_euclid_2(
-; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[REM]], 8
-; CHECK-NEXT:    [[COND1:%.*]] = icmp slt i32 [[REM]], 0
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND1]], i32 [[ADD]], i32 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %rem = srem i32 %0, 8
@@ -291,10 +285,7 @@ define i32 @rem_euclid_wrong_operands_select(i32 %0) {
 
 define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {
 ; CHECK-LABEL: @rem_euclid_vec(
-; CHECK-NEXT:    [[REM:%.*]] = srem <2 x i32> [[TMP0:%.*]], <i32 8, i32 8>
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt <2 x i32> [[REM]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw <2 x i32> [[REM]], <i32 8, i32 8>
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[ADD]], <2 x i32> [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and <2 x i32> [[TMP0:%.*]], <i32 7, i32 7>
 ; CHECK-NEXT:    ret <2 x i32> [[SEL]]
 ;
   %rem = srem <2 x i32> %0, <i32 8, i32 8>
@@ -306,10 +297,7 @@ define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {
 
 define i128 @rem_euclid_i128(i128 %0) {
 ; CHECK-LABEL: @rem_euclid_i128(
-; CHECK-NEXT:    [[REM:%.*]] = srem i128 [[TMP0:%.*]], 8
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i128 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i128 [[REM]], 8
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], i128 [[ADD]], i128 [[REM]]
+; CHECK-NEXT:    [[SEL:%.*]] = and i128 [[TMP0:%.*]], 7
 ; CHECK-NEXT:    ret i128 [[SEL]]
 ;
   %rem = srem i128 %0, 8
@@ -321,11 +309,9 @@ define i128 @rem_euclid_i128(i128 %0) {
 
 define i8 @rem_euclid_non_const_pow2(i8 %0, i8 %1) {
 ; CHECK-LABEL: @rem_euclid_non_const_pow2(
-; CHECK-NEXT:    [[POW2:%.*]] = shl nuw i8 1, [[TMP0:%.*]]
-; CHECK-NEXT:    [[REM:%.*]] = srem i8 [[TMP1:%.*]], [[POW2]]
-; CHECK-NEXT:    [[COND:%.*]] = icmp slt i8 [[REM]], 0
-; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[COND]], i8 [[POW2]], i8 0
-; CHECK-NEXT:    [[SEL:%.*]] = add i8 [[REM]], [[ADD]]
+; CHECK-NEXT:    [[NOTMASK:%.*]] = shl nsw i8 -1, [[TMP0:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i8 [[NOTMASK]], -1
+; CHECK-NEXT:    [[SEL:%.*]] = and i8 [[TMP3]], [[TMP1:%.*]]
 ; CHECK-NEXT:    ret i8 [[SEL]]
 ;
   %pow2 = shl i8 1, %0
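As a side note, not part of the patch itself: a minimal standalone sketch of the equivalence the new fold relies on, assuming a positive power-of-two divisor. The select-based Euclidean remainder built from a truncated (srem-style) remainder agrees with the masked form for every dividend, including negative ones. The function names below are illustrative only.

```cpp
// Illustrative check: (x % n) < 0 ? (x % n) + n : (x % n)  ==  x & (n - 1)
// for a power-of-two n, mirroring the pattern the fold matches.
#include <cassert>
#include <cstdint>

// Euclidean remainder the way frontends commonly emit it:
// srem followed by a compare-and-select fixup for negative dividends.
static int32_t euclidRemSelect(int32_t X, int32_t N) {
  int32_t Rem = X % N;             // srem: negative when X is negative
  return Rem < 0 ? Rem + N : Rem;  // select i1 %cnd, i32 %add, i32 %rem
}

// The folded form produced by the transform: a single bitwise and.
static int32_t euclidRemMask(int32_t X, int32_t N) {
  return X & (N - 1);              // valid because N is a power of two
}

int main() {
  for (int32_t X : {-17, -9, -8, -1, 0, 1, 7, 8, 42})
    assert(euclidRemSelect(X, 8) == euclidRemMask(X, 8));
  return 0;
}
```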