-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 #161833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 #161833
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Sander de Smalen (sdesmalen-arm) ChangesRather than expanding, we can handle this case natively by Full diff: https://github.com/llvm/llvm-project/pull/161833.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 70d5ad7d660f1..056d367a11949 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
if (Subtarget->hasMatMulInt8()) {
@@ -30768,6 +30769,18 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
ResultVT.isFixedLengthVector() &&
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
+ // We can handle this case natively by accumulating into a wider
+ // zero-padded vector.
+ if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
+ SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
+ SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
+ SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32,
+ WideAcc, LHS, RHS);
+ SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0);
+ SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2);
+ return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+ }
+
if (ConvertToScalable) {
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 428750740fc56..fc9e3c8a52850 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1451,3 +1451,52 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
+
+define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
+; CHECK-NODOT-LABEL: udot_v16i8tov2i32:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v3.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v2.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: fmov d0, d0
+; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT-I8MM: // %bb.0: // %entry
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #1
+; CHECK-DOT-I8MM-NEXT: fmov d0, d0
+; CHECK-DOT-I8MM-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-I8MM-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-I8MM-NEXT: ret
+entry:
+ %input.wide = zext <16 x i8> %input to <16 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
+ ret <2 x i32> %partial.reduce
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with the clang-format error fixed, thanks!
Rather than expanding, we can handle this case natively by widening the accumulator.
aebcbc9
to
1dc580e
Compare
Rather than expanding, we can handle this case natively by
widening the accumulator.