Skip to content

Commit

Permalink
[DAGCombiner][AArch64] Enhance to fold CSNEG into CSINC instruction
Browse files Browse the repository at this point in the history
Perform the scalar expression combine in the form of:
  CSNEG(1, c, cc) + b  =>  cc  ? b+1 : b-c => CSINC(b-c, b, !cc)
  CSNEG(c, -1, cc) + b =>  cc  ? b+c : b+1 => CSINC(b+c, b, cc)

Fix #53071

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D119105
  • Loading branch information
vfdff authored and guopeilin committed Feb 16, 2022
1 parent 988a3ba commit 064b2a6
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 27 deletions.
54 changes: 38 additions & 16 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -14751,54 +14751,76 @@ static SDValue performAddUADDVCombine(SDNode *N, SelectionDAG &DAG) {
}

/// Perform the scalar expression combine in the form of:
/// CSEL (c, 1, cc) + b => CSINC(b+c, b, cc)
/// CSEL(c, 1, cc) + b => CSINC(b+c, b, cc)
/// CSNEG(c, -1, cc) + b => CSINC(b+c, b, cc)
///
/// \p N is expected to be a scalar-integer ISD::ADD. If it is not, or if its
/// operands do not match one of the patterns handled below, an empty SDValue
/// is returned and no combine is performed.
static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
// NOTE(review): this listing shows the pre-change lines (using `CSel`) and the
// post-change lines (using `LHS`) of the commit next to each other.
EVT VT = N->getValueType(0);
if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD)
return SDValue();

SDValue CSel = N->getOperand(0);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

// Handle commutativity: the CSEL/CSNEG may be either operand of the ADD.
// After this step LHS is the conditional node and RHS is the other addend.
if (CSel.getOpcode() != AArch64ISD::CSEL) {
std::swap(CSel, RHS);
if (CSel.getOpcode() != AArch64ISD::CSEL) {
if (LHS.getOpcode() != AArch64ISD::CSEL &&
LHS.getOpcode() != AArch64ISD::CSNEG) {
std::swap(LHS, RHS);
if (LHS.getOpcode() != AArch64ISD::CSEL &&
LHS.getOpcode() != AArch64ISD::CSNEG) {
return SDValue();
}
}

// Only combine when the ADD is the sole user of the conditional node.
if (!CSel.hasOneUse())
if (!LHS.hasOneUse())
return SDValue();

// Operand 2 of the conditional node is read as its condition code.
AArch64CC::CondCode AArch64CC =
static_cast<AArch64CC::CondCode>(CSel.getConstantOperandVal(2));
static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));

// The CSEL must have a constant-one operand.
ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(CSel.getOperand(0));
ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(CSel.getOperand(1));
if (!CTVal || !CFVal || (!CTVal->isOne() && !CFVal->isOne()))
// Both selected values must be constants. A CSEL must have a constant-one
// operand; a CSNEG must have a true value of one or a false value of minus
// one.
ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(LHS.getOperand(0));
ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(LHS.getOperand(1));
if (!CTVal || !CFVal)
return SDValue();

// Switch CSEL (1, c, cc) to CSEL (c, 1, !cc)
if (CTVal->isOne() && !CFVal->isOne()) {
if (!(LHS.getOpcode() == AArch64ISD::CSEL &&
(CTVal->isOne() || CFVal->isOne())) &&
!(LHS.getOpcode() == AArch64ISD::CSNEG &&
(CTVal->isOne() || CFVal->isAllOnes())))
return SDValue();

// Switch CSEL(1, c, cc) to CSEL(c, 1, !cc) so the constant one is always the
// false value.
if (LHS.getOpcode() == AArch64ISD::CSEL && CTVal->isOne() &&
!CFVal->isOne()) {
std::swap(CTVal, CFVal);
AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
}

SDLoc DL(N);
// Switch CSNEG(1, c, cc) to CSNEG(-c, -1, !cc). CSNEG negates its false
// value, so both forms select between 1 and -c; the rewritten form puts the
// -1 in the false slot, matching the CSNEG(c, -1, cc) pattern above.
if (LHS.getOpcode() == AArch64ISD::CSNEG && CTVal->isOne() &&
!CFVal->isAllOnes()) {
APInt C = -1 * CFVal->getAPIntValue();
CTVal = cast<ConstantSDNode>(DAG.getConstant(C, DL, VT));
CFVal = cast<ConstantSDNode>(DAG.getAllOnesConstant(DL, VT));
AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
}

// The combine might be neutral for larger constants, as the immediate then
// needs to be materialized in a register, so bail out unless the constant is
// a legal add immediate.
APInt ADDC = CTVal->getAPIntValue();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isLegalAddImmediate(ADDC.getSExtValue()))
return SDValue();

// After the normalisation above, the false value is 1 for CSEL and -1 for
// CSNEG.
assert(CFVal->isOne() && "Unexpected constant value");
assert(((LHS.getOpcode() == AArch64ISD::CSEL && CFVal->isOne()) ||
(LHS.getOpcode() == AArch64ISD::CSNEG && CFVal->isAllOnes())) &&
"Unexpected constant value");

SDLoc DL(N);
// Build b + c explicitly; the CSINC node below is given (b + c, b, cc, Cmp),
// i.e. the CSINC(b+c, b, cc) form from the header comment.
SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0));
SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32);
// Operand 3 of the original conditional node is reused as the compare input.
SDValue Cmp = CSel.getOperand(3);
SDValue Cmp = LHS.getOperand(3);

return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp);
}
Expand Down
38 changes: 34 additions & 4 deletions llvm/test/CodeGen/AArch64/aarch64-isel-csinc-type.ll
Expand Up @@ -7,7 +7,7 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-unknown-linux-gnu"

; char csinc1 (char a, char b) { return !a ? b+1 : b+3; }
define dso_local i8 @csinc1(i8 %a, i8 %b) local_unnamed_addr #0 {
define i8 @csinc1(i8 %a, i8 %b) local_unnamed_addr #0 {
; CHECK-LABEL: csinc1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: tst w0, #0xff
Expand All @@ -22,7 +22,7 @@ entry:
}

; short csinc2 (short a, short b) { return !a ? b+1 : b+3; }
define dso_local i16 @csinc2(i16 %a, i16 %b) local_unnamed_addr #0 {
define i16 @csinc2(i16 %a, i16 %b) local_unnamed_addr #0 {
; CHECK-LABEL: csinc2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: tst w0, #0xffff
Expand All @@ -37,7 +37,7 @@ entry:
}

; int csinc3 (int a, int b) { return !a ? b+1 : b+3; }
define dso_local i32 @csinc3(i32 %a, i32 %b) local_unnamed_addr #0 {
define i32 @csinc3(i32 %a, i32 %b) local_unnamed_addr #0 {
; CHECK-LABEL: csinc3:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmp w0, #0
Expand All @@ -52,7 +52,7 @@ entry:
}

; long long csinc4 (long long a, long long b) { return !a ? b+1 : b+3; }
define dso_local i64 @csinc4(i64 %a, i64 %b) local_unnamed_addr #0 {
define i64 @csinc4(i64 %a, i64 %b) local_unnamed_addr #0 {
; CHECK-LABEL: csinc4:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmp x0, #0
Expand All @@ -65,3 +65,33 @@ entry:
%cond = add nsw i64 %cond.v, %b
ret i64 %cond
}

; long long csinc8 (long long a, long long b) { return a ? b-1 : b+1; }
; The select yields 1 when %a == 0 and -1 otherwise. Adding %b is expected to
; lower to sub + cmp + csinc (ne), as pinned by the CHECK lines below.
define i64 @csinc8(i64 %a, i64 %b) {
; CHECK-LABEL: csinc8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub x8, x1, #1
; CHECK-NEXT: cmp x0, #0
; CHECK-NEXT: csinc x0, x8, x1, ne
; CHECK-NEXT: ret
entry:
%tobool.not = icmp eq i64 %a, 0
%cond.v = select i1 %tobool.not, i64 1, i64 -1
%cond = add nsw i64 %cond.v, %b
ret i64 %cond
}

; long long csinc9 (long long a, long long b) { return a ? b+1 : b-1; }
; The select yields -1 when %a == 0 and 1 otherwise. Adding %b is expected to
; lower to sub + cmp + csinc (eq), as pinned by the CHECK lines below.
define i64 @csinc9(i64 %a, i64 %b) {
; CHECK-LABEL: csinc9:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub x8, x1, #1
; CHECK-NEXT: cmp x0, #0
; CHECK-NEXT: csinc x0, x8, x1, eq
; CHECK-NEXT: ret
entry:
%tobool.not = icmp eq i64 %a, 0
%cond.v = select i1 %tobool.not, i64 -1, i64 1
%cond = add nsw i64 %cond.v, %b
ret i64 %cond
}
44 changes: 37 additions & 7 deletions llvm/test/CodeGen/AArch64/aarch64-isel-csinc.ll
Expand Up @@ -7,7 +7,7 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-unknown-linux-gnu"

; int csinc1 (int a, int b) { return !a ? b+3 : b+1; }
define dso_local i32 @csinc1(i32 %a, i32 %b) {
define i32 @csinc1(i32 %a, i32 %b) {
; CHECK-LABEL: csinc1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmp w0, #0
Expand All @@ -22,7 +22,7 @@ entry:
}

; int csinc2 (int a, int b) { return a ? b+3 : b+1; }
define dso_local i32 @csinc2(i32 %a, i32 %b) {
define i32 @csinc2(i32 %a, i32 %b) {
; CHECK-LABEL: csinc2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmp w0, #0
Expand All @@ -37,7 +37,7 @@ entry:
}

; int csinc3 (int a, int b) { return !a ? b+1 : b-3; }
define dso_local i32 @csinc3(i32 %a, i32 %b) {
define i32 @csinc3(i32 %a, i32 %b) {
; CHECK-LABEL: csinc3:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #3
Expand All @@ -52,7 +52,7 @@ entry:
}

; int csinc4 (int a, int b) { return a ? b+1 : b-3; }
define dso_local i32 @csinc4(i32 %a, i32 %b) {
define i32 @csinc4(i32 %a, i32 %b) {
; CHECK-LABEL: csinc4:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #3
Expand All @@ -67,7 +67,7 @@ entry:
}

; int csinc5 (int a, int b) { return a ? b+1 : b-4095; }
define dso_local i32 @csinc5(i32 %a, i32 %b) {
define i32 @csinc5(i32 %a, i32 %b) {
; CHECK-LABEL: csinc5:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #4095
Expand All @@ -82,7 +82,7 @@ entry:
}

; int csinc6 (int a, int b) { return a ? b+1 : b-4096; }
define dso_local i32 @csinc6(i32 %a, i32 %b) {
define i32 @csinc6(i32 %a, i32 %b) {
; CHECK-LABEL: csinc6:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #1, lsl #12 // =4096
Expand All @@ -98,7 +98,7 @@ entry:

; Do not fold larger constants (the add is placed after the csinc)
; int csinc7 (int a, int b) { return a ? b+1 : b-4097; }
define dso_local i32 @csinc7(i32 %a, i32 %b) {
define i32 @csinc7(i32 %a, i32 %b) {
; CHECK-LABEL: csinc7:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmp w0, #0
Expand All @@ -112,3 +112,33 @@ entry:
%cond = add nsw i32 %cond.v, %b
ret i32 %cond
}

; int csinc8 (int a, int b) { return a ? b-1 : b+1; }
; The select yields 1 when %a == 0 and -1 otherwise. Adding %b is expected to
; lower to sub + cmp + csinc (ne), as pinned by the CHECK lines below.
define i32 @csinc8(i32 %a, i32 %b) {
; CHECK-LABEL: csinc8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #1
; CHECK-NEXT: cmp w0, #0
; CHECK-NEXT: csinc w0, w8, w1, ne
; CHECK-NEXT: ret
entry:
%tobool.not = icmp eq i32 %a, 0
%cond.v = select i1 %tobool.not, i32 1, i32 -1
%cond = add nsw i32 %cond.v, %b
ret i32 %cond
}

; int csinc9 (int a, int b) { return a ? b+1 : b-1; }
; The select yields -1 when %a == 0 and 1 otherwise. Adding %b is expected to
; lower to sub + cmp + csinc (eq), as pinned by the CHECK lines below.
define i32 @csinc9(i32 %a, i32 %b) {
; CHECK-LABEL: csinc9:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub w8, w1, #1
; CHECK-NEXT: cmp w0, #0
; CHECK-NEXT: csinc w0, w8, w1, eq
; CHECK-NEXT: ret
entry:
%tobool.not = icmp eq i32 %a, 0
%cond.v = select i1 %tobool.not, i32 -1, i32 1
%cond = add nsw i32 %cond.v, %b
ret i32 %cond
}

0 comments on commit 064b2a6

Please sign in to comment.