Skip to content

Commit

Permalink
[InstCombine] fold more icmp + select patterns by distributive laws
Browse files Browse the repository at this point in the history
follow up D139076, add icmp with not only eq/ne, but also gt/lt/ge/le.

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D139253
  • Loading branch information
bcl5980 committed Dec 7, 2022
1 parent 10c3df7 commit b4c8cfc
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 64 deletions.
59 changes: 36 additions & 23 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Expand Up @@ -316,33 +316,42 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,

Value *OtherOpT, *OtherOpF;
bool MatchIsOpZero;
auto getCommonOp = [&](Instruction *TI, Instruction *FI,
bool Commute) -> Value * {
Value *CommonOp = nullptr;
if (TI->getOperand(0) == FI->getOperand(0)) {
CommonOp = TI->getOperand(0);
OtherOpT = TI->getOperand(1);
OtherOpF = FI->getOperand(1);
MatchIsOpZero = true;
} else if (TI->getOperand(1) == FI->getOperand(1)) {
CommonOp = TI->getOperand(1);
OtherOpT = TI->getOperand(0);
OtherOpF = FI->getOperand(0);
MatchIsOpZero = false;
} else if (!Commute) {
auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute,
bool Swapped = false) -> Value * {
assert(!(Commute && Swapped) &&
"Commute and Swapped can't set at the same time");
if (!Swapped) {
if (TI->getOperand(0) == FI->getOperand(0)) {
OtherOpT = TI->getOperand(1);
OtherOpF = FI->getOperand(1);
MatchIsOpZero = true;
return TI->getOperand(0);
} else if (TI->getOperand(1) == FI->getOperand(1)) {
OtherOpT = TI->getOperand(0);
OtherOpF = FI->getOperand(0);
MatchIsOpZero = false;
return TI->getOperand(1);
}
}

if (!Commute && !Swapped)
return nullptr;
} else if (TI->getOperand(0) == FI->getOperand(1)) {
CommonOp = TI->getOperand(0);

// If we are allowing commute or swap of operands, then
// allow a cross-operand match. In that case, MatchIsOpZero
// means that TI's operand 0 (FI's operand 1) is the common op.
if (TI->getOperand(0) == FI->getOperand(1)) {
OtherOpT = TI->getOperand(1);
OtherOpF = FI->getOperand(0);
MatchIsOpZero = true;
return TI->getOperand(0);
} else if (TI->getOperand(1) == FI->getOperand(0)) {
CommonOp = TI->getOperand(1);
OtherOpT = TI->getOperand(0);
OtherOpF = FI->getOperand(1);
MatchIsOpZero = true;
MatchIsOpZero = false;
return TI->getOperand(1);
}
return CommonOp;
return nullptr;
};

if (TI->hasOneUse() || FI->hasOneUse()) {
Expand Down Expand Up @@ -379,16 +388,20 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
}
}

// icmp eq/ne with a common operand also can have the common operand
// icmp with a common operand also can have the common operand
// pulled after the select.
ICmpInst::Predicate TPred, FPred;
if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
if (TPred == FPred && ICmpInst::isEquality(TPred)) {
if (Value *MatchOp = getCommonOp(TI, FI, true)) {
if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
bool Swapped = TPred != FPred;
if (Value *MatchOp =
getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
SI.getName() + ".v", &SI);
return new ICmpInst(TPred, NewSel, MatchOp);
return new ICmpInst(
MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
MatchOp, NewSel);
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions llvm/test/Transforms/InstCombine/select-bitext.ll
Expand Up @@ -485,10 +485,9 @@ define i32 @sel_zext_const_uses(i8 %a, i8 %x) {

define i32 @test_op_op(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: @test_op_op(
; CHECK-NEXT: [[CCA:%.*]] = icmp sgt i32 [[A:%.*]], 0
; CHECK-NEXT: [[CCB:%.*]] = icmp sgt i32 [[B:%.*]], 0
; CHECK-NEXT: [[CCC:%.*]] = icmp sgt i32 [[C:%.*]], 0
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[CCC]], i1 [[CCA]], i1 [[CCB]]
; CHECK-NEXT: [[R_V_V:%.*]] = select i1 [[CCC]], i32 [[A:%.*]], i32 [[B:%.*]]
; CHECK-NEXT: [[R_V:%.*]] = icmp sgt i32 [[R_V_V]], 0
; CHECK-NEXT: [[R:%.*]] = sext i1 [[R_V]] to i32
; CHECK-NEXT: ret i32 [[R]]
;
Expand Down
64 changes: 26 additions & 38 deletions llvm/test/Transforms/InstCombine/select-cmp.ll
Expand Up @@ -124,9 +124,8 @@ define i1 @icmp_common_one_use_1(i1 %c, i8 %x, i8 %y, i8 %z) {

define i1 @icmp_slt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_slt_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i6 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i6 [[X]], [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp sgt i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp slt i6 %x, %y
Expand All @@ -137,9 +136,8 @@ define i1 @icmp_slt_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_sgt_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i6 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i6 [[X]], [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp slt i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp sgt i6 %x, %y
Expand All @@ -150,9 +148,8 @@ define i1 @icmp_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_sle_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_sle_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sle i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp sle i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp sle i6 %y, %x
Expand All @@ -163,9 +160,8 @@ define i1 @icmp_sle_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_sge_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp sge i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sge i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp sge i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp sge i6 %y, %x
Expand All @@ -176,9 +172,8 @@ define i1 @icmp_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_slt_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_slt_sgt_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i6 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp sgt i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp slt i6 %x, %y
Expand All @@ -189,9 +184,8 @@ define i1 @icmp_slt_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_sle_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_sle_sge_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sge i6 [[X]], [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp sle i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp sle i6 %y, %x
Expand All @@ -202,9 +196,8 @@ define i1 @icmp_sle_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_ult_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_ult_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i6 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i6 [[X]], [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp ugt i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp ult i6 %x, %y
Expand All @@ -215,9 +208,8 @@ define i1 @icmp_ult_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_ule_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_ule_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ule i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp ule i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp ule i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp ule i6 %y, %x
Expand All @@ -228,9 +220,8 @@ define i1 @icmp_ule_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_ugt_common(i1 %c, i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @icmp_ugt_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i8 [[Y:%.*]], i8 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp ugt i8 %y, %x
Expand All @@ -241,9 +232,8 @@ define i1 @icmp_ugt_common(i1 %c, i8 %x, i8 %y, i8 %z) {

define i1 @icmp_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_uge_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp uge i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp uge i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp uge i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp uge i6 %y, %x
Expand All @@ -254,9 +244,8 @@ define i1 @icmp_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_ult_ugt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_ult_ugt_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i6 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i6 [[Z:%.*]], [[X]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp ugt i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp ult i6 %x, %y
Expand All @@ -267,9 +256,8 @@ define i1 @icmp_ult_ugt_common(i1 %c, i6 %x, i6 %y, i6 %z) {

define i1 @icmp_ule_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
; CHECK-LABEL: @icmp_ule_uge_common(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ule i6 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp uge i6 [[X]], [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
; CHECK-NEXT: [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
; CHECK-NEXT: [[R:%.*]] = icmp ule i6 [[R_V]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%cmp1 = icmp ule i6 %y, %x
Expand All @@ -293,7 +281,7 @@ define i1 @icmp_common_pred_different(i1 %c, i8 %x, i8 %y, i8 %z) {
ret i1 %r
}

; negative test: two pred is not swap
; negative test for non-equality: two pred is not swap

define i1 @icmp_common_pred_not_swap(i1 %c, i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @icmp_common_pred_not_swap(
Expand All @@ -308,7 +296,7 @@ define i1 @icmp_common_pred_not_swap(i1 %c, i8 %x, i8 %y, i8 %z) {
ret i1 %r
}

; negative test: not commute pred
; negative test for non-equality: not commute pred

define i1 @icmp_common_pred_not_commute_pred(i1 %c, i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @icmp_common_pred_not_commute_pred(
Expand Down

0 comments on commit b4c8cfc

Please sign in to comment.