
Conversation

@bababuck
Contributor

Previously, the following:
  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1

lowered to:
  vsetivli zero, 8, e32, m2, ta, ma
  vmul.vv v8, v8, v9
  vmacc.vv v8, v11, v10
  li a0, 32
  vadd.vx v8, v8, a0

After this patch, it now lowers to:
  li a0, 32
  vsetivli zero, 8, e32, m2, ta, ma
  vmv.v.x v12, a0
  vmadd.vv v8, v9, v12
  vmacc.vv v8, v11, v10

Modeled on 0cc981e from the AArch64 backend.
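
For context, the overridden hook gates the generic reassociation (op (op x, c1), y) -> (op (op x, y), c1). The sketch below is illustrative only: the helper name and structure are made up and this is not the actual DAGCombiner code; it just shows where the target's answer is consulted.

#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
using namespace llvm;

// Illustrative stand-in for the generic combine that isReassocProfitable()
// feeds into; the real logic lives in DAGCombiner.
static SDValue reassociateIfProfitable(SelectionDAG &DAG,
                                       const TargetLowering &TLI, unsigned Opc,
                                       const SDLoc &DL, EVT VT, SDValue N0,
                                       SDValue N1) {
  // N0 is (op x, c1) and N1 is y. If the target declines, keep the tree as-is
  // so isel can still form a multiply-accumulate from (add (mul a, b), z).
  if (!TLI.isReassocProfitable(DAG, N0, N1))
    return SDValue();
  SDValue X = N0.getOperand(0);
  SDValue C1 = N0.getOperand(1);
  SDValue Inner = DAG.getNode(Opc, DL, VT, X, N1); // (op x, y)
  return DAG.getNode(Opc, DL, VT, Inner, C1);      // (op (op x, y), c1)
}

With this patch, the RISC-V override answers false when N0 is a vector add and N1 is a mul, so the (add (mul ...), splat) tree survives and the constant is splatted once (vmv.v.x) as the addend of vmadd.vv, as in the assembly above.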

C code for the example case (clang -O3 -S -mcpu=sifive-x280):

void madd_fail(int a, int b, int * restrict src, int * restrict dst, int loop_bound) {
  for (int i = 0; i < loop_bound; i += 2) {
    dst[i] = src[i] * a + src[i + 1] * b + 32;
  }
}

Thanks for the guidance on this @preames.

Namely, test cases such as the following:
  %mul1 = mul %m00, %m01
  %mul0 = mul %m10, %m11
  %add0 = add %mul0, %constant
  %add1 = add %add0, %mul1
@llvmbot
Member

llvmbot commented Nov 19, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Ryan Buchner (bababuck)



Full diff: https://github.com/llvm/llvm-project/pull/168660.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+14)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+5)
  • (added) llvm/test/CodeGen/RISCV/vmadd-reassociate.ll (+146)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..809abbc69ce90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to vector
+  // multiply accumulate (i.e. add (mul x, y), z)
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasStdExtV()))
+    return false;
+
+  return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result.  Factor is the spread amount, and Index is the
   /// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..d7618d1d2bcf7
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    mul a1, a2, a3
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    addiw a0, a0, 32
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw i32 %m00, %m01
+  %mul1 = mul nsw i32 %m10, %m11
+  %add0 = add i32 %mul0, 32
+  %add1 = add i32 %add0, %mul1
+  ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, %addend
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul = mul nsw <vscale x 1 x i32> %m00, %m01
+  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+  ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul nsw <8 x i32> %m00, %m01
+  %add = add <8 x i32> %mul, splat (i32 32)
+  ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v12
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    ret
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v16
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    vmacc.vv v8, v13, v12
+; CHECK-NEXT:    vmacc.vv v8, v15, v14
+; CHECK-NEXT:    ret
+                                             <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+  %vset = tail call i32 @llvm.experimental.get.vector.length.i64(i64 8, i32 1, i1 true)
+  %mul0 = mul nsw <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul nsw <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul nsw <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul nsw <vscale x 1 x i32> %m30, %m31
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  %add2 = add <vscale x 1 x i32> %add1, %mul2
+  %add3 = add <vscale x 1 x i32> %add2, %mul3
+  ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v24, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v24
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    vmacc.vv v8, v18, v16
+; CHECK-NEXT:    vmacc.vv v8, v22, v20
+; CHECK-NEXT:    ret
+                                   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+  %mul0 = mul nsw <8 x i32> %m00, %m01
+  %mul1 = mul nsw <8 x i32> %m10, %m11
+  %mul2 = mul nsw <8 x i32> %m20, %m21
+  %mul3 = mul nsw <8 x i32> %m30, %m31
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  %add2 = add <8 x i32> %add1, %mul2
+  %add3 = add <8 x i32> %add2, %mul3
+  ret <8 x i32> %add3
+}

@github-actions

github-actions bot commented Nov 19, 2025

🐧 Linux x64 Test Results

  • 186425 tests passed
  • 4868 tests skipped

@bababuck
Contributor Author

Updated:

  • Check hasVInstructions() rather than hasStdExtV() (see the sketch after this list)
  • Removed dead instructions in test
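
A sketch of what the first bullet amounts to against the hunk in the posted diff, assuming the rest of the condition is unchanged:

  // Sketch of the updated check (per the bullet above, not the final committed
  // code): hasVInstructions() also covers the Zve* subsets, whereas
  // hasStdExtV() matches only the full V extension.
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
      N0.getValueType().isVector() && Subtarget.hasVInstructions())
    return false;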

@bababuck requested a review from topperc on November 20, 2025 16:05
@topperc
Collaborator

topperc commented Nov 20, 2025

Can we use something other than "Properly" in the title? Upon first reading, I thought this was fixing a miscompile.

@bababuck changed the title from "[RISCV] Properly lower multiply-accumulate chains containing a constant" to "[RISCV] Incorporate scalar addends to extend vector multiply accumulate chains" on Nov 20, 2025
@topperc
Collaborator

LGTM
