From 6694ccd3c2e78ff3b0cac20c10e73dab134d1e8f Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Fri, 3 Oct 2025 11:35:07 +0100 Subject: [PATCH 1/4] Precommit test --- .../neon-partial-reduce-dot-product.ll | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index 428750740fc56..824a3708451ba 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -1451,3 +1451,34 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32> %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift) ret <4 x i32> %red } + +define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) { +; CHECK-COMMON-LABEL: udot_v16i8tov2i32: +; CHECK-COMMON: // %bb.0: // %entry +; CHECK-COMMON-NEXT: ushll v2.8h, v1.8b, #0 +; CHECK-COMMON-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-COMMON-NEXT: ushll2 v1.8h, v1.16b, #0 +; CHECK-COMMON-NEXT: ushll v3.4s, v2.4h, #0 +; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v2.4h +; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #0 +; CHECK-COMMON-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-COMMON-NEXT: ext v3.16b, v3.16b, v3.16b, #8 +; CHECK-COMMON-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-COMMON-NEXT: ext v3.16b, v4.16b, v4.16b, #8 +; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v2.4h +; CHECK-COMMON-NEXT: ushll v2.4s, v1.4h, #0 +; CHECK-COMMON-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-COMMON-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-COMMON-NEXT: ushll2 v3.4s, v1.8h, #0 +; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v1.4h +; CHECK-COMMON-NEXT: ext v1.16b, v1.16b, v1.16b, #8 +; CHECK-COMMON-NEXT: add v0.2s, v2.2s, v0.2s +; CHECK-COMMON-NEXT: ext v2.16b, v3.16b, v3.16b, #8 +; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v1.4h +; CHECK-COMMON-NEXT: add v0.2s, v2.2s, v0.2s +; CHECK-COMMON-NEXT: ret +entry: + %input.wide = zext <16 x i8> %input to <16 x i32> + %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide) + ret <2 x i32> %partial.reduce +} From 3455b405cbcd1eb1737784a17261456d7fb1f7aa Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Fri, 3 Oct 2025 11:36:13 +0100 Subject: [PATCH 2/4] [AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 Rather than expanding, we can handle this case natively by widening the accumulator. --- .../Target/AArch64/AArch64ISelLowering.cpp | 13 ++++ .../neon-partial-reduce-dot-product.ll | 66 ++++++++++++------- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index dc8e7c84f5e2c..a4b71dae68bb8 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal); setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal); + setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom); setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom); if (Subtarget->hasMatMulInt8()) { @@ -30769,6 +30770,18 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, ResultVT.isFixedLengthVector() && useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true); + // We can handle this case natively by accumulating into a wider + // zero-padded vector. + if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) { + SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32); + SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0); + SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, + WideAcc, LHS, RHS); + SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0); + SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2); + return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi); + } + if (ConvertToScalable) { ResultVT = getContainerForFixedLengthVector(DAG, ResultVT); OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType()); diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index 824a3708451ba..fc9e3c8a52850 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -1453,30 +1453,48 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32> } define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) { -; CHECK-COMMON-LABEL: udot_v16i8tov2i32: -; CHECK-COMMON: // %bb.0: // %entry -; CHECK-COMMON-NEXT: ushll v2.8h, v1.8b, #0 -; CHECK-COMMON-NEXT: // kill: def $d0 killed $d0 def $q0 -; CHECK-COMMON-NEXT: ushll2 v1.8h, v1.16b, #0 -; CHECK-COMMON-NEXT: ushll v3.4s, v2.4h, #0 -; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v2.4h -; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #0 -; CHECK-COMMON-NEXT: ext v2.16b, v2.16b, v2.16b, #8 -; CHECK-COMMON-NEXT: ext v3.16b, v3.16b, v3.16b, #8 -; CHECK-COMMON-NEXT: add v0.2s, v3.2s, v0.2s -; CHECK-COMMON-NEXT: ext v3.16b, v4.16b, v4.16b, #8 -; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v2.4h -; CHECK-COMMON-NEXT: ushll v2.4s, v1.4h, #0 -; CHECK-COMMON-NEXT: add v0.2s, v3.2s, v0.2s -; CHECK-COMMON-NEXT: ext v2.16b, v2.16b, v2.16b, #8 -; CHECK-COMMON-NEXT: ushll2 v3.4s, v1.8h, #0 -; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v1.4h -; CHECK-COMMON-NEXT: ext v1.16b, v1.16b, v1.16b, #8 -; CHECK-COMMON-NEXT: add v0.2s, v2.2s, v0.2s -; CHECK-COMMON-NEXT: ext v2.16b, v3.16b, v3.16b, #8 -; CHECK-COMMON-NEXT: uaddw v0.4s, v0.4s, v1.4h -; CHECK-COMMON-NEXT: add v0.2s, v2.2s, v0.2s -; CHECK-COMMON-NEXT: ret +; CHECK-NODOT-LABEL: udot_v16i8tov2i32: +; CHECK-NODOT: // %bb.0: // %entry +; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0 +; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0 +; CHECK-NODOT-NEXT: ushll v3.4s, v2.4h, #0 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h +; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #0 +; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8 +; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-NODOT-NEXT: ext v3.16b, v4.16b, v4.16b, #8 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h +; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0 +; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h +; CHECK-NODOT-NEXT: ext v1.16b, v1.16b, v1.16b, #8 +; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s +; CHECK-NODOT-NEXT: ext v2.16b, v3.16b, v3.16b, #8 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h +; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s +; CHECK-NODOT-NEXT: ret +; +; CHECK-DOT-LABEL: udot_v16i8tov2i32: +; CHECK-DOT: // %bb.0: // %entry +; CHECK-DOT-NEXT: movi v2.16b, #1 +; CHECK-DOT-NEXT: fmov d0, d0 +; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b +; CHECK-DOT-NEXT: ext v1.16b, v0.16b, v0.16b, #8 +; CHECK-DOT-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-DOT-NEXT: ret +; +; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32: +; CHECK-DOT-I8MM: // %bb.0: // %entry +; CHECK-DOT-I8MM-NEXT: movi v2.16b, #1 +; CHECK-DOT-I8MM-NEXT: fmov d0, d0 +; CHECK-DOT-I8MM-NEXT: udot v0.4s, v1.16b, v2.16b +; CHECK-DOT-I8MM-NEXT: ext v1.16b, v0.16b, v0.16b, #8 +; CHECK-DOT-I8MM-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-DOT-I8MM-NEXT: ret entry: %input.wide = zext <16 x i8> %input to <16 x i32> %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide) From 2a324f33e855b408acf582b18f45d560e80b538d Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Fri, 3 Oct 2025 14:18:58 +0100 Subject: [PATCH 3/4] Use addp instead of ext + add --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 ++--- .../CodeGen/AArch64/neon-partial-reduce-dot-product.ll | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a4b71dae68bb8..cdd074dcf839f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -30777,9 +30777,8 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0); SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, WideAcc, LHS, RHS); - SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0); - SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2); - return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi); + SDValue Reduced = DAG.getNode(AArch64ISD::ADDP, DL, MVT::v4i32, Wide, Wide); + return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0); } if (ConvertToScalable) { diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index fc9e3c8a52850..dfff35d9eb1b2 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -1483,8 +1483,8 @@ define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) { ; CHECK-DOT-NEXT: movi v2.16b, #1 ; CHECK-DOT-NEXT: fmov d0, d0 ; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b -; CHECK-DOT-NEXT: ext v1.16b, v0.16b, v0.16b, #8 -; CHECK-DOT-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-DOT-NEXT: addp v0.4s, v0.4s, v0.4s +; CHECK-DOT-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-DOT-NEXT: ret ; ; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32: @@ -1492,8 +1492,8 @@ define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) { ; CHECK-DOT-I8MM-NEXT: movi v2.16b, #1 ; CHECK-DOT-I8MM-NEXT: fmov d0, d0 ; CHECK-DOT-I8MM-NEXT: udot v0.4s, v1.16b, v2.16b -; CHECK-DOT-I8MM-NEXT: ext v1.16b, v0.16b, v0.16b, #8 -; CHECK-DOT-I8MM-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-DOT-I8MM-NEXT: addp v0.4s, v0.4s, v0.4s +; CHECK-DOT-I8MM-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-DOT-I8MM-NEXT: ret entry: %input.wide = zext <16 x i8> %input to <16 x i32> From 1dc580e3091dea79a384b714718595f17f003822 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Thu, 9 Oct 2025 12:43:52 +0000 Subject: [PATCH 4/4] clang-format --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index cdd074dcf839f..31b3d1807933b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -30775,8 +30775,8 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) { SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32); SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0); - SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, - WideAcc, LHS, RHS); + SDValue Wide = + DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, WideAcc, LHS, RHS); SDValue Reduced = DAG.getNode(AArch64ISD::ADDP, DL, MVT::v4i32, Wide, Wide); return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0); }