Skip to content

Commit

Permalink
[InstCombine] Fold select of srem and conditional add
Browse files Browse the repository at this point in the history
Simplify a pattern that may show up when computing
the remainder of Euclidean division. In particular,
when the divisor is a power of two and never negative,
the signed remainder can be folded into a bitwise AND.

Fixes 64305.

Proofs: https://alive2.llvm.org/ce/z/9_KG6c

Differential Revision: https://reviews.llvm.org/D156811
  • Loading branch information
antoniofrighetto committed Aug 8, 2023
1 parent f5cb626 commit 2116921
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
45 changes: 45 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2584,6 +2584,48 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
return nullptr;
}

/// Tries to reduce a pattern that arises when calculating the remainder of the
/// Euclidean division. When the divisor is a power of two and is guaranteed not
/// to be negative, a signed remainder can be folded with a bitwise and.
///
///   (x % n) < 0 ? (x % n) + n : (x % n)
///     -> x & (n - 1)
///
/// \param SI      The select instruction to inspect.
/// \param IC      Provides the sign-bit-check and power-of-two queries.
/// \param Builder Used to emit the `n - 1` mask feeding the resulting `and`.
/// \returns The replacement `and` instruction, or nullptr if \p SI does not
///          match the pattern.
static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
                                       IRBuilderBase &Builder) {
  Value *CondVal = SI.getCondition();
  Value *TrueVal = SI.getTrueValue();
  Value *FalseVal = SI.getFalseValue();

  ICmpInst::Predicate Pred;
  Value *Op, *RemRes, *Remainder;
  const APInt *C;
  bool TrueIfSigned = false;

  // The condition must be a sign-bit test of some value against a constant.
  // RemRes is bound here to the value whose sign is being tested.
  if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
        IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
    return nullptr;

  // If the sign bit is not set, we have a SGE/SGT comparison, and the operands
  // of the select are inverted.
  if (!TrueIfSigned)
    std::swap(TrueVal, FalseVal);

  // We are matching a quite specific pattern here:
  //   %rem = srem i32 %x, %n
  //   %cnd = icmp slt i32 %rem, 0
  //   %add = add i32 %rem, %n
  //   %sel = select i1 %cnd, i32 %add, i32 %rem
  //
  // The add must use the very same value the condition tested, hence
  // m_Specific(RemRes) rather than rebinding RemRes: otherwise a select whose
  // condition checks the sign of an unrelated value would be folded as well.
  // The add is matched commutatively so that `add %n, %rem` is accepted too.
  if (!(match(TrueVal, m_c_Add(m_Specific(RemRes), m_Value(Remainder))) &&
        match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
        IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
        FalseVal == RemRes))
    return nullptr;

  // n + (-1) == n - 1, the low-bit mask for a power-of-two n.
  Value *Add = Builder.CreateAdd(Remainder,
                                 Constant::getAllOnesValue(RemRes->getType()));
  return BinaryOperator::CreateAnd(Op, Add);
}

static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
if (!FI)
Expand Down Expand Up @@ -3430,6 +3472,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldSelectExtConst(SI))
return I;

if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
return I;

// Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
// Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
Expand Down
28 changes: 7 additions & 21 deletions llvm/test/Transforms/InstCombine/select-divrem.ll
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,7 @@ define i5 @urem_common_dividend_defined_cond(i1 noundef %b, i5 %x, i5 %y, i5 %z)

define i32 @rem_euclid_1(i32 %0) {
; CHECK-LABEL: @rem_euclid_1(
; CHECK-NEXT: [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
; CHECK-NEXT: [[COND:%.*]] = icmp slt i32 [[REM]], 0
; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[REM]], 8
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], i32 [[ADD]], i32 [[REM]]
; CHECK-NEXT: [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
; CHECK-NEXT: ret i32 [[SEL]]
;
%rem = srem i32 %0, 8
Expand All @@ -231,10 +228,7 @@ define i32 @rem_euclid_1(i32 %0) {

define i32 @rem_euclid_2(i32 %0) {
; CHECK-LABEL: @rem_euclid_2(
; CHECK-NEXT: [[REM:%.*]] = srem i32 [[TMP0:%.*]], 8
; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[REM]], 8
; CHECK-NEXT: [[COND1:%.*]] = icmp slt i32 [[REM]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND1]], i32 [[ADD]], i32 [[REM]]
; CHECK-NEXT: [[SEL:%.*]] = and i32 [[TMP0:%.*]], 7
; CHECK-NEXT: ret i32 [[SEL]]
;
%rem = srem i32 %0, 8
Expand Down Expand Up @@ -291,10 +285,7 @@ define i32 @rem_euclid_wrong_operands_select(i32 %0) {

define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {
; CHECK-LABEL: @rem_euclid_vec(
; CHECK-NEXT: [[REM:%.*]] = srem <2 x i32> [[TMP0:%.*]], <i32 8, i32 8>
; CHECK-NEXT: [[COND:%.*]] = icmp slt <2 x i32> [[REM]], zeroinitializer
; CHECK-NEXT: [[ADD:%.*]] = add nsw <2 x i32> [[REM]], <i32 8, i32 8>
; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[ADD]], <2 x i32> [[REM]]
; CHECK-NEXT: [[SEL:%.*]] = and <2 x i32> [[TMP0:%.*]], <i32 7, i32 7>
; CHECK-NEXT: ret <2 x i32> [[SEL]]
;
%rem = srem <2 x i32> %0, <i32 8, i32 8>
Expand All @@ -306,10 +297,7 @@ define <2 x i32> @rem_euclid_vec(<2 x i32> %0) {

define i128 @rem_euclid_i128(i128 %0) {
; CHECK-LABEL: @rem_euclid_i128(
; CHECK-NEXT: [[REM:%.*]] = srem i128 [[TMP0:%.*]], 8
; CHECK-NEXT: [[COND:%.*]] = icmp slt i128 [[REM]], 0
; CHECK-NEXT: [[ADD:%.*]] = add nsw i128 [[REM]], 8
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], i128 [[ADD]], i128 [[REM]]
; CHECK-NEXT: [[SEL:%.*]] = and i128 [[TMP0:%.*]], 7
; CHECK-NEXT: ret i128 [[SEL]]
;
%rem = srem i128 %0, 8
Expand All @@ -321,11 +309,9 @@ define i128 @rem_euclid_i128(i128 %0) {

define i8 @rem_euclid_non_const_pow2(i8 %0, i8 %1) {
; CHECK-LABEL: @rem_euclid_non_const_pow2(
; CHECK-NEXT: [[POW2:%.*]] = shl nuw i8 1, [[TMP0:%.*]]
; CHECK-NEXT: [[REM:%.*]] = srem i8 [[TMP1:%.*]], [[POW2]]
; CHECK-NEXT: [[COND:%.*]] = icmp slt i8 [[REM]], 0
; CHECK-NEXT: [[ADD:%.*]] = select i1 [[COND]], i8 [[POW2]], i8 0
; CHECK-NEXT: [[SEL:%.*]] = add i8 [[REM]], [[ADD]]
; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[TMP0:%.*]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[NOTMASK]], -1
; CHECK-NEXT: [[SEL:%.*]] = and i8 [[TMP3]], [[TMP1:%.*]]
; CHECK-NEXT: ret i8 [[SEL]]
;
%pow2 = shl i8 1, %0
Expand Down

0 comments on commit 2116921

Please sign in to comment.