Skip to content

Conversation

sdesmalen-arm
Copy link
Collaborator

Rather than expanding, we can handle this case natively by
widening the accumulator.

@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

Rather than expanding, we can handle this case natively by
widening the accumulator.


Full diff: https://github.com/llvm/llvm-project/pull/161833.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+13)
  • (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+49)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 70d5ad7d660f1..056d367a11949 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
       setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
       setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
       setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
 
       if (Subtarget->hasMatMulInt8()) {
@@ -30768,6 +30769,18 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
       ResultVT.isFixedLengthVector() &&
       useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
 
+  // We can handle this case natively by accumulating into a wider
+  // zero-padded vector.
+  if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
+    SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
+    SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
+    SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32,
+                               WideAcc, LHS, RHS);
+    SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0);
+    SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2);
+    return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  }
+
   if (ConvertToScalable) {
     ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
     OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 428750740fc56..fc9e3c8a52850 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1451,3 +1451,52 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
   %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
   ret <4 x i32> %red
 }
+
+define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
+; CHECK-NODOT-LABEL: udot_v16i8tov2i32:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll v3.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    ushll2 v4.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v3.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-DOT-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    fmov d0, d0
+; CHECK-DOT-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT-I8MM:       // %bb.0: // %entry
+; CHECK-DOT-I8MM-NEXT:    movi v2.16b, #1
+; CHECK-DOT-I8MM-NEXT:    fmov d0, d0
+; CHECK-DOT-I8MM-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-I8MM-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-I8MM-NEXT:    ret
+entry:
+    %input.wide = zext <16 x i8> %input to <16 x i32>
+    %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
+    ret <2 x i32> %partial.reduce
+}

Copy link

github-actions bot commented Oct 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with the clang-format error fixed, thanks!

@sdesmalen-arm sdesmalen-arm force-pushed the improve-partial-reduce-v2i32 branch from aebcbc9 to 1dc580e Compare October 9, 2025 14:18
@sdesmalen-arm sdesmalen-arm merged commit e160b2a into llvm:main Oct 9, 2025
9 checks passed
svkeerthy pushed a commit that referenced this pull request Oct 9, 2025
)

Rather than expanding, we can handle this case natively by
widening the accumulator.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants