From cf98a3c6fed0496b7732e15de087061357b0a709 Mon Sep 17 00:00:00 2001
From: bababuck
Date: Tue, 18 Nov 2025 13:10:54 -0800
Subject: [PATCH 1/5] [RISCV] Add test for lowering vector multiply add chains

Namely, test cases such as the following:

  %mul1 = mul %m00, %m01
  %mul0 = mul %m10, %m11
  %add0 = add %mul0, %constant
  %add1 = add %add0, %mul1
---
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 146 +++++++++++++++++++
 1 file changed, 146 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..d161c60c6c7bc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mul a0, a0, a1
+; CHECK-NEXT: mul a1, a2, a3
+; CHECK-NEXT: add a0, a0, a1
+; CHECK-NEXT: addiw a0, a0, 32
+; CHECK-NEXT: ret
+entry:
+  %mul0 = mul nsw i32 %m00, %m01
+  %mul1 = mul nsw i32 %m10, %m11
+  %add0 = add i32 %mul0, 32
+  %add1 = add i32 %add0, %mul1
+  ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmadd.vv v8, v10, v16
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, %addend
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v10, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+  ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v12
+; CHECK-NEXT: ret
+entry:
+  %mul = mul nsw <8 x i32> %m00, %m01
+  %add = add <8 x i32> %mul, splat (i32 32)
+  ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: vmacc.vv v8, v11, v10
+; CHECK-NEXT: vmacc.vv v8, v13, v12
+; CHECK-NEXT: vmacc.vv v8, v15, v14
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+  <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  %add2 = add <vscale x 1 x i32> %add1, %mul2
+  %add3 = add <vscale x 1 x i32> %add2, %mul3
+  ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmacc.vv v8, v14, v12
+; CHECK-NEXT: vmacc.vv v8, v18, v16
+; CHECK-NEXT: vmacc.vv v8, v22, v20
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: ret
+  <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %mul2 = mul nsw <8 x i32> %m20, %m21
+  %mul3 = mul nsw <8 x i32> %m30, %m31
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  %add2 = add <8 x i32> %add1, %mul2
+  %add3 = add <8 x i32> %add2, %mul3
+  ret <8 x i32> %add3
+}

From 4d5aa5dd815a98d6fe2ec503b64c4e14d3d771bb Mon Sep 17 00:00:00 2001
From: bababuck
Date: Tue, 18 Nov 2025 10:32:48 -0800
Subject: [PATCH 2/5] [RISCV] Properly lower multiply-accumulate chains containing a constant

Previously, the following:

  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1

lowered to:

  vsetivli zero, 8, e32, m2, ta, ma
  vmul.vv v8, v8, v10
  vmacc.vv v8, v14, v12
  li a0, 32
  vadd.vx v8, v8, a0

After this patch, it now lowers to:

  li a0, 32
  vsetivli zero, 8, e32, m2, ta, ma
  vmv.v.x v16, a0
  vmadd.vv v8, v10, v16
  vmacc.vv v8, v14, v12
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp  | 14 ++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h    |  5 ++++
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 28 ++++++++++----------
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..809abbc69ce90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to vector
+  // multiply accumulate (i.e. add (mul x, y), z)
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+    return false;
+
+  return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result. Factor is the spread amount, and Index is the
   /// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d161c60c6c7bc..d7618d1d2bcf7 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -64,11 +64,11 @@ define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x
 define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
 ; CHECK-LABEL: vmadd_vscale:
 ; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v9
-; CHECK-NEXT: vmacc.vv v8, v11, v10
 ; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v12, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v12
+; CHECK-NEXT: vmacc.vv v8, v11, v10
 ; CHECK-NEXT: ret
 entry:
   %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
@@ -82,11 +82,11 @@ define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8
 define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
 ; CHECK-LABEL: vmadd_fixed:
 ; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
 ; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v16
 ; CHECK-NEXT: vmacc.vv v8, v14, v12
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
 ; CHECK-NEXT: ret
 entry:
   %mul0 = mul nsw <8 x i32> %m00, %m01
@@ -99,13 +99,13 @@ define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x
 define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
 ; CHECK-LABEL: vmadd_vscale_long:
 ; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v9
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vmv.v.x v16, a0
+; CHECK-NEXT: vmadd.vv v8, v9, v16
 ; CHECK-NEXT: vmacc.vv v8, v11, v10
 ; CHECK-NEXT: vmacc.vv v8, v13, v12
 ; CHECK-NEXT: vmacc.vv v8, v15, v14
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
 ; CHECK-NEXT: ret
   <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
 entry:
@@ -124,13 +124,13 @@ define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m1
 define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
 ; CHECK-LABEL: vmadd_fixed_long:
 ; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: li a0, 32
 ; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmul.vv v8, v8, v10
+; CHECK-NEXT: vmv.v.x v24, a0
+; CHECK-NEXT: vmadd.vv v8, v10, v24
 ; CHECK-NEXT: vmacc.vv v8, v14, v12
 ; CHECK-NEXT: vmacc.vv v8, v18, v16
 ; CHECK-NEXT: vmacc.vv v8, v22, v20
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vadd.vx v8, v8, a0
 ; CHECK-NEXT: ret
   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
 entry:
From e6218f2d5d00c1e70901d00d5e7e8ea4ea8b1d0c Mon Sep 17 00:00:00 2001
From: bababuck
Date: Tue, 18 Nov 2025 22:37:36 -0800
Subject: [PATCH 3/5] Check hasVInstructions() rather than hasStdExtV

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 809abbc69ce90..e3f9e24555dae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25664,7 +25664,7 @@ bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
   // Avoid reassociating expressions that can be lowered to vector
   // multiply accumulate (i.e. add (mul x, y), z)
   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
-      (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
     return false;
 
   return true;

From 5f3502f325861da4114bcf4abd23e2d5a19429e2 Mon Sep 17 00:00:00 2001
From: bababuck
Date: Tue, 18 Nov 2025 22:43:53 -0800
Subject: [PATCH 4/5] Remove dead instructions from test

---
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index d7618d1d2bcf7..e2bcd5c08efd2 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -41,7 +41,6 @@ define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vsca
 ; CHECK-NEXT: vmadd.vv v8, v9, v10
 ; CHECK-NEXT: ret
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul = mul nsw <vscale x 1 x i32> %m00, %m01
   %add = add <vscale x 1 x i32> %mul, splat (i32 32)
   ret <vscale x 1 x i32> %add
@@ -71,7 +70,6 @@ define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x
 ; CHECK-NEXT: vmacc.vv v8, v11, v10
 ; CHECK-NEXT: ret
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
   %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
   %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
@@ -109,7 +107,6 @@ define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x
 ; CHECK-NEXT: ret
   <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
 entry:
-  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
   %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
   %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
   %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21

From 7e9bc66ab1eefa426240d2ae443bf3f521d2b18c Mon Sep 17 00:00:00 2001
From: bababuck
Date: Thu, 20 Nov 2025 09:50:00 -0800
Subject: [PATCH 5/5] Remove un-needed NSW flags from multiplies in test

---
 llvm/test/CodeGen/RISCV/vmadd-reassociate.ll | 36 ++++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
index e2bcd5c08efd2..9fa0cec0ea339 100644
--- a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -10,8 +10,8 @@ define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
 ; CHECK-NEXT: addiw a0, a0, 32
 ; CHECK-NEXT: ret
 entry:
-  %mul0 = mul nsw i32 %m00, %m01
-  %mul1 = mul nsw i32 %m10, %m11
+  %mul0 = mul i32 %m00, %m01
+  %mul1 = mul i32 %m10, %m11
   %add0 = add i32 %mul0, 32
   %add1 = add i32 %add0, %mul1
   ret i32 %add1
@@ -25,8 +25,8 @@ define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %
 ; CHECK-NEXT: vmacc.vv v8, v14, v12
 ; CHECK-NEXT: ret
 entry:
-  %mul0 = mul nsw <8 x i32> %m00, %m01
-  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
   %add0 = add <8 x i32> %mul0, %addend
   %add1 = add <8 x i32> %add0, %mul1
   ret <8 x i32> %add1
@@ -41,7 +41,7 @@ define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vsca
 ; CHECK-NEXT: vmadd.vv v8, v9, v10
 ; CHECK-NEXT: ret
 entry:
-  %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul = mul <vscale x 1 x i32> %m00, %m01
   %add = add <vscale x 1 x i32> %mul, splat (i32 32)
   ret <vscale x 1 x i32> %add
 }
@@ -55,7 +55,7 @@ define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
 ; CHECK-NEXT: vmadd.vv v8, v10, v12
 ; CHECK-NEXT: ret
 entry:
-  %mul = mul nsw <8 x i32> %m00, %m01
+  %mul = mul <8 x i32> %m00, %m01
   %add = add <8 x i32> %mul, splat (i32 32)
   ret <8 x i32> %add
 }
@@ -70,8 +70,8 @@ define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x
 ; CHECK-NEXT: vmacc.vv v8, v11, v10
 ; CHECK-NEXT: ret
 entry:
-  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
-  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
   %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
   %add1 = add <vscale x 1 x i32> %add0, %mul1
   ret <vscale x 1 x i32> %add1
@@ -87,8 +87,8 @@ define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8
 ; CHECK-NEXT: vmacc.vv v8, v14, v12
 ; CHECK-NEXT: ret
 entry:
-  %mul0 = mul nsw <8 x i32> %m00, %m01
-  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
   %add0 = add <8 x i32> %mul0, splat (i32 32)
   %add1 = add <8 x i32> %add0, %mul1
   ret <8 x i32> %add1
@@ -107,10 +107,10 @@ define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x
 ; CHECK-NEXT: ret
   <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
 entry:
-  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
-  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
-  %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
-  %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul <vscale x 1 x i32> %m30, %m31
   %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
   %add1 = add <vscale x 1 x i32> %add0, %mul1
   %add2 = add <vscale x 1 x i32> %add1, %mul2
@@ -131,10 +131,10 @@ define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m1
 ; CHECK-NEXT: ret
   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
 entry:
-  %mul0 = mul nsw <8 x i32> %m00, %m01
-  %mul1 = mul nsw <8 x i32> %m10, %m11
-  %mul2 = mul nsw <8 x i32> %m20, %m21
-  %mul3 = mul nsw <8 x i32> %m30, %m31
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %mul2 = mul <8 x i32> %m20, %m21
+  %mul3 = mul <8 x i32> %m30, %m31
   %add0 = add <8 x i32> %mul0, splat (i32 32)
   %add1 = add <8 x i32> %add0, %mul1
   %add2 = add <8 x i32> %add1, %mul2