
[AArch64][SelectionDAG] Lower multiplication by a constant to shl+add+shl+add #89532

Merged 2 commits on Apr 25, 2024

Conversation

vfdff
Contributor

@vfdff vfdff commented Apr 21, 2024

Change the cost model to lower a = b * C where C = (1 + 2^m) * 2^n + 1 to
    add w8, w0, w0, lsl #m
    add w0, w0, w8, lsl #n
Note: the latency can vary depending on the shift amount.
Fix part of #89430

@llvmbot
Collaborator

llvmbot commented Apr 21, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Allen (vfdff)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/89532.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+23)
  • (modified) llvm/test/CodeGen/AArch64/mul_pow2.ll (+19-2)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3d1453e3beb9a1..e4d552dcf4f0f1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17602,12 +17602,31 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     return false;
   };
 
+  // Can the constant C be decomposed into (2^M + 1) * 2^N + 1, e.g.
+  // C = 11 is equal to (1+4)*2+1. We don't decompose it into (1+2)*4-1 as
+  // the (2^N - 1) can't be executed via a single instruction.
+  auto isPowPlusPlusOneConst = [](APInt C, APInt &M, APInt &N) {
+    APInt CVMinus1 = C - 1;
+    if (CVMinus1.isNegative())
+      return false;
+    unsigned TrailingZeroes = CVMinus1.countr_zero();
+    APInt SCVMinus1 = CVMinus1.ashr(TrailingZeroes) - 1;
+    if (SCVMinus1.isPowerOf2()) {
+      M = SCVMinus1.logBase2();
+      N = TrailingZeroes;
+      return true;
+    }
+    return false;
+  };
+
   if (ConstValue.isNonNegative()) {
     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
     // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
     // (mul x, (2^M + 1) * (2^N + 1))
     //     => MV = (add (shl x, M), x); (add (shl MV, N), MV)
+    // (mul x, (2^M + 1) * 2^N + 1)
+    //     => MV = (add (shl x, M), x); (add (shl MV, N), x)
     APInt SCVMinus1 = ShiftedConstValue - 1;
     APInt SCVPlus1 = ShiftedConstValue + 1;
     APInt CVPlus1 = ConstValue + 1;
@@ -17632,6 +17651,10 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
         SDValue MVal = Add(Shl(N0, ShiftM1), N0);
         return Add(Shl(MVal, ShiftN1), MVal);
       }
+    } else if (Subtarget->hasALULSLFast() &&
+               isPowPlusPlusOneConst(ConstValue, CVM, CVN)) {
+      SDValue MVal = Add(Shl(N0, CVM.getZExtValue()), N0);
+      return Add(Shl(MVal, CVN.getZExtValue()), N0);
     }
   } else {
     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 90e560af4465a9..6f49f38bf41a5c 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -410,6 +410,23 @@ define i32 @test11(i32 %x) {
   ret i32 %mul
 }
 
+define i32 @test11_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
+; CHECK-LABEL: test11_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add w8, w0, w0
+; CHECK-NEXT:    add w0, w0, w8, lsl #1
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test11_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #11 // =0xb
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 11
+  ret i32 %mul
+}
+
 define i32 @test12(i32 %x) {
 ; CHECK-LABEL: test12:
 ; CHECK:       // %bb.0:
@@ -858,9 +875,9 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 ;
 ; GISEL-LABEL: muladd_demand_commute:
 ; GISEL:       // %bb.0:
-; GISEL-NEXT:    adrp x8, .LCPI49_0
+; GISEL-NEXT:    adrp x8, .LCPI50_0
 ; GISEL-NEXT:    movi v3.4s, #1, msl #16
-; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI49_0]
+; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI50_0]
 ; GISEL-NEXT:    mla v1.4s, v0.4s, v2.4s
 ; GISEL-NEXT:    and v0.16b, v1.16b, v3.16b
 ; GISEL-NEXT:    ret

@vfdff vfdff changed the title [AArch64][SelectionDAG] Lower multiplication by a constant to shl+add… [AArch64][SelectionDAG] Lower multiplication by a constant to shl+add+shl+add Apr 21, 2024
@@ -17632,6 +17651,10 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
SDValue MVal = Add(Shl(N0, ShiftM1), N0);
return Add(Shl(MVal, ShiftN1), MVal);
}
} else if (Subtarget->hasALULSLFast() &&
Collaborator

hasALULSLFast() specifically refers to shift amounts of 4 or less.

Contributor Author

Thanks, added the check for the shift amounts.

Collaborator

I think the limit for ALULSLFast is supposed to be 4? Not sure where 3 came from; maybe some confusion with the old LSLFast for loads.

I'm a little concerned the structure of the nested if statements could miss some cases... is it possible for a number to pass both isPowPlusPlusConst and isPowPlusPlusOneConst? Maybe we can restructure the code to just be a series of if statements with early returns, instead of if/else if/etc.?

Contributor Author

Applied your comment, thanks.

define i32 @test11_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
; CHECK-LABEL: test11_fast_shift:
; CHECK: // %bb.0:
; CHECK-NEXT: add w8, w0, w0
Collaborator

Missing shift amount?

Could probably use a couple more tests.

Contributor Author

Sorry for missing the shift amount; fixed.
Also added 2 negative cases whose shift amounts are out of bounds.

@@ -510,6 +527,24 @@ define i32 @test25_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
ret i32 %mul
}

; Negative: 35 = (((1<<4) + 1) << 1) + 1, the shift amount 4 is out of bounds
Collaborator

Not that it's really relevant for this specific case given my other comment, but you can decompose this to x*7*5.

Contributor Author

The (2^N - 1) can't be executed via a single instruction, so we need 2 instructions to support the 7?

Contributor Author

Updated the tests with shift amount bound 4, thanks.

@davemgreen
Collaborator

As I said in the ticket, do we have any evidence that these are better as add+shift? As far as I understand, GCC optimized it this way because older CPUs had slower mul and faster add+lsl, but that has changed in more recent cores and mul is now usually relatively quick.

@vfdff
Contributor Author

vfdff commented Apr 24, 2024

Thank you for your reminder.
Yes, the transformation is not needed for recent cores, but some older cores, such as tsv110, will benefit from this conversion.
So it seems reasonable to associate this conversion with FeatureALULSLFast.

  • mul: the latency is 3~4 and the throughput is 1, depending on the register width
  • adds: the latency is 1 and the throughput is 2, depending on the register width

https://github.com/llvm/llvm-project/blob/b8e97f0768f2b537c45930f56f4027a4c0a07f24/llvm/test/tools/llvm-mca/AArch64/HiSilicon/tsv110-basic-instructions.s#L1830C10-L1830C71

@davemgreen
Collaborator

Thanks for the info, that sounds good then. I think for newer Arm cores the cost model between the two is going to be pretty close. One or the other might be better in different places.

Collaborator

@efriedma-quic efriedma-quic left a comment

LGTM

In terms of general cost modeling, these are clearly equivalent to existing patterns: two cheap ALU ops. If we wanted to disable all 2-instruction patterns for some targets, we could maybe consider that, but it would be a different discussion.

According to D152827, when the shift amount is 4 or less, they are
as cheap as a move.
…+shl+add

Change the cost model to lower a = b * C where C = (1 + 2^m) * 2^n + 1 to
          add   w8, w0, w0, lsl #m
          add   w0, w0, w8, lsl #n
Note: The latency of add can vary depending on the shift amount.
      They are as cheap as a move when the shift amount is 4 or less.
Fix part of llvm#89430
@vfdff vfdff merged commit a6bdd6d into llvm:main Apr 25, 2024
2 of 4 checks passed