Skip to content

Commit

Permalink
[AArch64] Adjust operand sequence for Add+Sub to combine more inline …
Browse files Browse the repository at this point in the history
…shift

((X >> C) - Y) + Z --> (Z - Y) + (X >> C)

Fix AArch part: #55714

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D136158
  • Loading branch information
bcl5980 committed Oct 31, 2022
1 parent 72e9447 commit 325a308
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 9 deletions.
43 changes: 43 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -16854,6 +16854,44 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}

// ((X >> C) - Y) + Z --> (Z - Y) + (X >> C)
static SDValue performAddCombineSubShift(SDNode *N, SDValue SUB, SDValue Z,
SelectionDAG &DAG) {
auto IsOneUseShiftC = [&](SDValue Shift) {
if (!Shift.hasOneUse())
return false;

// TODO: support SRL and SRA also
if (Shift.getOpcode() != ISD::SHL)
return false;

if (!isa<ConstantSDNode>(Shift.getOperand(1)))
return false;
return true;
};

// DAGCombiner will revert the combination when Z is constant cause
// dead loop. So don't enable the combination when Z is constant.
// If Z is one use shift C, we also can't do the optimization.
// It will falling to self infinite loop.
if (isa<ConstantSDNode>(Z) || IsOneUseShiftC(Z))
return SDValue();

if (SUB.getOpcode() != ISD::SUB || !SUB.hasOneUse())
return SDValue();

SDValue Shift = SUB.getOperand(0);
if (!IsOneUseShiftC(Shift))
return SDValue();

SDLoc DL(N);
EVT VT = N->getValueType(0);

SDValue Y = SUB.getOperand(1);
SDValue NewSub = DAG.getNode(ISD::SUB, DL, VT, Z, Y);
return DAG.getNode(ISD::ADD, DL, VT, NewSub, Shift);
}

static SDValue performAddCombineForShiftedOperands(SDNode *N,
SelectionDAG &DAG) {
// NOTE: Swapping LHS and RHS is not done for SUB, since SUB is not
Expand All @@ -16871,6 +16909,11 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N,
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

if (SDValue Val = performAddCombineSubShift(N, LHS, RHS, DAG))
return Val;
if (SDValue Val = performAddCombineSubShift(N, RHS, LHS, DAG))
return Val;

uint64_t LHSImm = 0, RHSImm = 0;
// If both operand are shifted by imm and shift amount is not greater than 4
// for one operand, swap LHS and RHS to put operand with smaller shift amount
Expand Down
93 changes: 84 additions & 9 deletions llvm/test/CodeGen/AArch64/addsub.ll
Expand Up @@ -694,40 +694,115 @@ if.end: ; preds = %if.then, %lor.lhs.f
ret i32 undef
}

; ((X >> C) - Y) + Z --> (Z - Y) + (X >> C)
define i32 @commute_subop0(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w8, w2
; CHECK-NEXT: sub w8, w2, w1
; CHECK-NEXT: add w0, w8, w0, lsl #3
; CHECK-NEXT: ret
%shl = shl i32 %x, 3
%sub = sub i32 %shl, %y
%add = add i32 %sub, %z
ret i32 %add
}

; ((X << C) - Y) + Z --> (Z - Y) + (X << C)
define i32 @commute_subop0_lshr(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0_lshr:
; CHECK: // %bb.0:
; CHECK-NEXT: lsr w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w8, w2
; CHECK-NEXT: ret
%lshr = lshr i32 %x, 3
%sub = sub i32 %lshr, %y
%add = add i32 %sub, %z
ret i32 %add
}

; ((X << C) - Y) + Z --> (Z - Y) + (X << C)
define i32 @commute_subop0_ashr(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0_ashr:
; CHECK: // %bb.0:
; CHECK-NEXT: asr w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w8, w2
; CHECK-NEXT: ret
%ashr = ashr i32 %x, 3
%sub = sub i32 %ashr, %y
%add = add i32 %sub, %z
ret i32 %add
}

; Z + ((X >> C) - Y) --> (Z - Y) + (X >> C)
define i32 @commute_subop0_cadd(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0_cadd:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w2, w8
; CHECK-NEXT: sub w8, w2, w1
; CHECK-NEXT: add w0, w8, w0, lsl #3
; CHECK-NEXT: ret
%shl = shl i32 %x, 3
%sub = sub i32 %shl, %y
%add = add i32 %z, %sub
ret i32 %add
}

; Y + ((X >> C) - X) --> (Y - X) + (X >> C)
define i32 @commute_subop0_mul(i32 %x, i32 %y) {
; CHECK-LABEL: commute_subop0_mul:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #3
; CHECK-NEXT: sub w8, w8, w0
; CHECK-NEXT: add w0, w8, w1
; CHECK-NEXT: sub w8, w1, w0
; CHECK-NEXT: add w0, w8, w0, lsl #3
; CHECK-NEXT: ret
%mul = mul i32 %x, 7
%add = add i32 %mul, %y
ret i32 %add
}

; negative case for ((X >> C) - Y) + Z --> (Z - Y) + (X >> C)
; Y can't be constant to avoid dead loop
define i32 @commute_subop0_zconst(i32 %x, i32 %y) {
; CHECK-LABEL: commute_subop0_zconst:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w8, #1
; CHECK-NEXT: ret
%shl = shl i32 %x, 3
%sub = sub i32 %shl, %y
%add = add i32 %sub, 1
ret i32 %add
}

; negative case for ((X >> C) - Y) + Z --> (Z - Y) + (X >> C)
; Y can't be shift C also to avoid dead loop
define i32 @commute_subop0_zshiftc_oneuse(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0_zshiftc_oneuse:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #3
; CHECK-NEXT: sub w8, w8, w1
; CHECK-NEXT: add w0, w8, w2, lsl #2
; CHECK-NEXT: ret
%xshl = shl i32 %x, 3
%sub = sub i32 %xshl, %y
%zshl = shl i32 %z, 2
%add = add i32 %sub, %zshl
ret i32 %add
}

define i32 @commute_subop0_zshiftc(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: commute_subop0_zshiftc:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w2, #2
; CHECK-NEXT: sub w9, w8, w1
; CHECK-NEXT: add w9, w9, w0, lsl #3
; CHECK-NEXT: eor w0, w8, w9
; CHECK-NEXT: ret
%xshl = shl i32 %x, 3
%sub = sub i32 %xshl, %y
%zshl = shl i32 %z, 2
%add = add i32 %sub, %zshl
%r = xor i32 %zshl, %add
ret i32 %r
}

0 comments on commit 325a308

Please sign in to comment.