Skip to content

Conversation

@sdesmalen-arm
Copy link
Collaborator

This generates more optimal codegen when using partial reductions with predication.

partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)

partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-selectiondag

Author: Sander de Smalen (sdesmalen-arm)

Changes

This generates more optimal codegen when using partial reductions with predication.

partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)

partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))

Full diff: https://github.com/llvm/llvm-project/pull/167857.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+59-13)
  • (added) llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll (+159)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   return SDValue();
 }
 
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
 // -> partial_reduce_*mla(acc, a, b)
 //
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
 //
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
-
   unsigned Opc = Op1->getOpcode();
+
+  // Handle predication by moving the SELECT into the operand of the MUL.
+  SDValue Pred;
+  if (Opc == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Opc = Op1->getOpcode();
+  }
+
   if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
     return SDValue();
 
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
+  // Return 'select(P, Op, splat(0))' if P is nonzero,
+  // or 'P' otherwise.
+  auto tryPredicate = [&](SDValue P, SDValue Op) {
+    if (!P)
+      return Op;
+    EVT OpVT = Op.getValueType();
+    SDValue Zero = OpVT.isFloatingPoint()
+                       ? DAG.getConstantFP(0.0, DL, OpVT)
+                       : DAG.getConstant(0, DL, OpVT);
+    return DAG.getSelect(DL, OpVT, P, Op, Zero);
+  };
+
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
             TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
       return SDValue();
 
+    SDValue Constant =
+        tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+                       Constant);
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
+  RHSExtOp = tryPredicate(Pred, RHSExtOp);
   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
 // partial.reduce.sumla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
     return SDValue();
 
+  SDValue Pred;
   unsigned Op1Opcode = Op1.getOpcode();
+  if (Op1Opcode == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Op1Opcode = Op1->getOpcode();
+  }
+
   if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
     return SDValue();
 
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
                          ? DAG.getConstantFP(1, DL, UnextOp1VT)
                          : DAG.getConstant(1, DL, UnextOp1VT);
 
+  if (Pred) {
+    SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                         ? DAG.getConstantFP(0, DL, UnextOp1VT)
+                         : DAG.getConstant(0, DL, UnextOp1VT);
+    Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+  }
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                      Constant);
 }
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    movi v3.16b, #127
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.16b, #1
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    movi v3.8h, #60, lsl #8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

This generates more optimal codegen when using partial reductions with predication.

partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-&gt; partial_reduce_*mla(acc, sel(p, a, splat(0)), b)

partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-&gt; partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))

Full diff: https://github.com/llvm/llvm-project/pull/167857.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+59-13)
  • (added) llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll (+159)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   return SDValue();
 }
 
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
 // -> partial_reduce_*mla(acc, a, b)
 //
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
 //
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
-
   unsigned Opc = Op1->getOpcode();
+
+  // Handle predication by moving the SELECT into the operand of the MUL.
+  SDValue Pred;
+  if (Opc == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Opc = Op1->getOpcode();
+  }
+
   if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
     return SDValue();
 
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
+  // Return 'select(P, Op, splat(0))' if P is nonzero,
+  // or 'P' otherwise.
+  auto tryPredicate = [&](SDValue P, SDValue Op) {
+    if (!P)
+      return Op;
+    EVT OpVT = Op.getValueType();
+    SDValue Zero = OpVT.isFloatingPoint()
+                       ? DAG.getConstantFP(0.0, DL, OpVT)
+                       : DAG.getConstant(0, DL, OpVT);
+    return DAG.getSelect(DL, OpVT, P, Op, Zero);
+  };
+
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
             TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
       return SDValue();
 
+    SDValue Constant =
+        tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+                       Constant);
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
+  RHSExtOp = tryPredicate(Pred, RHSExtOp);
   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
 // partial.reduce.sumla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
     return SDValue();
 
+  SDValue Pred;
   unsigned Op1Opcode = Op1.getOpcode();
+  if (Op1Opcode == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Op1Opcode = Op1->getOpcode();
+  }
+
   if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
     return SDValue();
 
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
                          ? DAG.getConstantFP(1, DL, UnextOp1VT)
                          : DAG.getConstant(1, DL, UnextOp1VT);
 
+  if (Pred) {
+    SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                         ? DAG.getConstantFP(0, DL, UnextOp1VT)
+                         : DAG.getConstant(0, DL, UnextOp1VT);
+    Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+  }
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                      Constant);
 }
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    movi v3.16b, #127
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.16b, #1
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    movi v3.8h, #60, lsl #8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }

@github-actions
Copy link

github-actions bot commented Nov 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@sdesmalen-arm sdesmalen-arm force-pushed the users/sdesmalen-arm/predicated-partial-reduce-add branch from 14e935c to a126fe7 Compare November 13, 2025 10:10
Base automatically changed from users/sdesmalen-arm/partial-reduce-fdot to main November 13, 2025 10:50
lukel97 added a commit that referenced this pull request Nov 13, 2025
@lukel97
Copy link
Contributor

lukel97 commented Nov 13, 2025

I've precommitted a RISC-V test case here 2a53949, can you rebase this PR on top of it?

This generates more optimal codegen when using partial reductions
with predication.

partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)

partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
@sdesmalen-arm sdesmalen-arm force-pushed the users/sdesmalen-arm/predicated-partial-reduce-add branch from a126fe7 to 7506ebc Compare November 13, 2025 15:02
Comment on lines 13042 to 13046
if (Opc == ISD::VSELECT) {
APInt C;
if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
!C.isZero())
return SDValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (Opc == ISD::VSELECT) {
APInt C;
if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
!C.isZero())
return SDValue();
if (Opc == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) {

Comment on lines 13187 to 13191
if (Op1Opcode == ISD::VSELECT) {
APInt C;
if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
!C.isZero())
return SDValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (Op1Opcode == ISD::VSELECT) {
APInt C;
if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
!C.isZero())
return SDValue();
if (Op1Opcode == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems isNullOrNullSplat does not recognise FP constants, so I've added a similar FP variant (similar to what was done for isOneOrOneSplatFP)

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@sdesmalen-arm sdesmalen-arm merged commit f369a53 into main Nov 18, 2025
10 checks passed
@sdesmalen-arm sdesmalen-arm deleted the users/sdesmalen-arm/predicated-partial-reduce-add branch November 18, 2025 09:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants