[RISCV] Incorporate scalar addends to extend vector multiply accumulate chains #168660
Conversation
Namely, test cases such as the following:
  %mul0 = mul %m00, %m01
  %mul1 = mul %m10, %m11
  %add0 = add %mul0, %constant
  %add1 = add %add0, %mul1
Previously, the following:
  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
lowered to:
  vsetivli zero, 8, e32, m2, ta, ma
  vmul.vv v8, v8, v9
  vmacc.vv v8, v11, v10
  li a0, 32
  vadd.vx v8, v8, a0
After this patch, it now lowers to:
  li a0, 32
  vsetivli zero, 8, e32, m2, ta, ma
  vmv.v.x v12, a0
  vmadd.vv v8, v9, v12
  vmacc.vv v8, v11, v10
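For context, a source-level pattern along these lines vectorizes into the mul/add chain above. This is an illustrative sketch only (the PR description mentions C code built with clang -O3 -S -mcpu=sifive-x280, but the exact example is not reproduced here); the function name and signature are hypothetical:

  /* Hypothetical example, not the C from the PR description: two elementwise
     products accumulated with a constant addend. Vectorized at -O3 with the V
     extension, this yields the mul/mul/add/add chain shown above, which the
     patch keeps foldable into vmadd.vv/vmacc.vv. */
  void madd_chain(int *restrict dst, const int *restrict a, const int *restrict b,
                  const int *restrict c, const int *restrict d, int n) {
    for (int i = 0; i < n; ++i)
      dst[i] = a[i] * b[i] + c[i] * d[i] + 32;
  }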
@llvm/pr-subscribers-backend-risc-v
Author: Ryan Buchner (bababuck)
Changes
Modeled on 0cc981e from the AArch64 backend. The C code for the example case was compiled with clang -O3 -S -mcpu=sifive-x280. Thanks for the guidance on this @preames.
Full diff: https://github.com/llvm/llvm-project/pull/168660.diff
3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..809abbc69ce90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
return VT.getSizeInBits() <= Subtarget.getXLen();
}
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const {
+ if (!N0.hasOneUse())
+ return false;
+
+ // Avoid reassociating expressions that can be lowered to vector
+ // multiply accumulate (i.e. add (mul x, y), z)
+ if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+ (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+ return false;
+
+ return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
+ /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+ /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+ bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const override;
+
/// Match a mask which "spreads" the leading elements of a vector evenly
/// across the result. Factor is the spread amount, and Index is the
/// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..d7618d1d2bcf7
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mul a0, a0, a1
+; CHECK-NEXT: mul a1, a2, a3
+; CHECK-NEXT: add a0, a0, a1
+; CHECK-NEXT: addiw a0, a0, 32
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw i32 %m00, %m01
+ %mul1 = mul nsw i32 %m10, %m11
+ %add0 = add i32 %mul0, 32
+ %add1 = add i32 %add0, %mul1
+ ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, %addend
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v10, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+ %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+ ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v12
+; CHECK-NEXT: ret
+entry:
+ %mul = mul nsw <8 x i32> %m00, %m01
+ %add = add <8 x i32> %mul, splat (i32 32)
+ ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v12
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: ret
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v16
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: vmacc.vv v8, v13, v12
+; CHECK-NEXT: vmacc.vv v8, v15, v14
+; CHECK-NEXT: ret
+ <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+ %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+ %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+ %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+ %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
+ %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+ %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+ %add1 = add <vscale x 1 x i32> %add0, %mul1
+ %add2 = add <vscale x 1 x i32> %add1, %mul2
+ %add3 = add <vscale x 1 x i32> %add2, %mul3
+ ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v24, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v24
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: vmacc.vv v8, v18, v16
+; CHECK-NEXT: vmacc.vv v8, v22, v20
+; CHECK-NEXT: ret
+ <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+ %mul0 = mul nsw <8 x i32> %m00, %m01
+ %mul1 = mul nsw <8 x i32> %m10, %m11
+ %mul2 = mul nsw <8 x i32> %m20, %m21
+ %mul3 = mul nsw <8 x i32> %m30, %m31
+ %add0 = add <8 x i32> %mul0, splat (i32 32)
+ %add1 = add <8 x i32> %add0, %mul1
+ %add2 = add <8 x i32> %add1, %mul2
+ %add3 = add <8 x i32> %add2, %mul3
+ ret <8 x i32> %add3
+}
Can we use something other than "Properly" in the title? Upon first reading, I thought this was fixing a miscompile.
topperc left a comment:
LGTM