[AArch64][GlobalISel] Add custom legalization for v4s8 = G_TRUNC v4s16 #85610
Conversation
We see a *lot* of fallbacks these days due to <4 x s8> types appearing in truncates, and these are commonly generated by the new load/store bitcasting -> s32 rule. We can keep that load/store rule if we make sure to handle the truncates properly, so we adopt a similar strategy for this custom action as DAG lowering's LowerTruncateVectorStore(). That is, we first widen the input <4 x s16> to <8 x s16> so that we can generate a legal G_TRUNC to <8 x s8>, and from there extract the final 32-bit-sized value.
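Concretely, the legalized sequence ends up looking like the following (a sketch that mirrors the autogenerated MIR test added in this patch; register names are illustrative):

  %in:_(<4 x s16>) = COPY $x0
  %undef:_(<4 x s16>) = G_IMPLICIT_DEF
  %wide:_(<8 x s16>) = G_CONCAT_VECTORS %in(<4 x s16>), %undef(<4 x s16>)
  %trunc:_(<8 x s8>) = G_TRUNC %wide(<8 x s16>)
  %bc:_(<2 x s32>) = G_BITCAST %trunc(<8 x s8>)
  %zero:_(s64) = G_CONSTANT i64 0
  %low:_(s32) = G_EXTRACT_VECTOR_ELT %bc(<2 x s32>), %zero(s64)
  %out:_(<4 x s8>) = G_BITCAST %low(s32)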
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-globalisel Author: Amara Emerson (aemerson) Changes: (same as the PR description above.) Full diff: https://github.com/llvm/llvm-project/pull/85610.diff 4 Files Affected:
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 36adada2796531..04a228cf522ce7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -628,7 +628,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
return DstTy.isVector() && SrcTy.getSizeInBits() > 128 &&
DstTy.getScalarSizeInBits() * 2 <= SrcTy.getScalarSizeInBits();
})
-
+ .customIf(all(typeInSet(0, {v4s8}),
+ typeInSet(1, {v4s16})))
.alwaysLegal();
getActionDefinitionsBuilder(G_SEXT_INREG)
@@ -1262,11 +1263,45 @@ bool AArch64LegalizerInfo::legalizeCustom(
return legalizeDynStackAlloc(MI, Helper);
case TargetOpcode::G_PREFETCH:
return legalizePrefetch(MI, Helper);
+ case TargetOpcode::G_TRUNC:
+ return legalizeTrunc(MI, Helper);
}
llvm_unreachable("expected switch to return");
}
+bool AArch64LegalizerInfo::legalizeTrunc(MachineInstr &MI,
+ LegalizerHelper &Helper) const {
+ assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
+
+ // Handle <4 x s8> = G_TRUNC <4 x s16> by widening to <8 x s16> first.
+ // So the sequence is:
+ // %orig_val(<4 x s16>) = ...
+ // %wide = G_MERGE_VALUES %orig_val, %undef:_(<4 x s16>)
+ // %wide_trunc:_(<8 x s8>) = G_TRUNC %wide
+ // %bc:_(<2 x s32>) = G_BITCAST %wide_trunc
+ // %eve:_(s32) = G_EXTRACT_VECTOR_ELT %bc, 0
+ // %final:_(<4 x s8>) = G_BITCAST %eve
+
+ MachineIRBuilder &MIB = Helper.MIRBuilder;
+
+ auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+ assert(DstTy == LLT::fixed_vector(4, LLT::scalar(8)) &&
+ SrcTy == LLT::fixed_vector(4, LLT::scalar(16)));
+
+ auto WideTy = LLT::fixed_vector(8, LLT::scalar(16));
+ auto Undef = MIB.buildUndef(SrcTy);
+ auto Merge = MIB.buildMergeLikeInstr(WideTy, {SrcReg, Undef});
+ auto Trunc = MIB.buildTrunc(LLT::fixed_vector(8, LLT::scalar(8)), Merge);
+ auto BC = MIB.buildBitcast(LLT::fixed_vector(2, LLT::scalar(32)), Trunc);
+ auto Extract = MIB.buildExtractVectorElement(
+ LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(32), 0));
+ MIB.buildBitcast(DstReg, Extract);
+
+ MI.eraseFromParent();
+ return true;
+}
+
bool AArch64LegalizerInfo::legalizeFunnelShift(MachineInstr &MI,
MachineRegisterInfo &MRI,
MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
index b69d9b015bd2b3..e9d8b54de9ef70 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
@@ -64,6 +64,7 @@ class AArch64LegalizerInfo : public LegalizerInfo {
LegalizerHelper &Helper) const;
bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
+ bool legalizeTrunc(MachineInstr &MI, LegalizerHelper &Helper) const;
const AArch64Subtarget *ST;
};
} // End llvm namespace.
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
new file mode 100644
index 00000000000000..e3e12558c2242a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
@@ -0,0 +1,24 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple=aarch64 -run-pass=legalizer %s -o - | FileCheck %s
+---
+name: trunc_v4s8_v4s16
+body: |
+ bb.1:
+ liveins: $x0
+ ; CHECK-LABEL: name: trunc_v4s8_v4s16
+ ; CHECK: liveins: $x0
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %in:_(<4 x s16>) = COPY $x0
+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<4 x s16>) = G_IMPLICIT_DEF
+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<8 x s16>) = G_CONCAT_VECTORS %in(<4 x s16>), [[DEF]](<4 x s16>)
+ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(<8 x s8>) = G_TRUNC [[CONCAT_VECTORS]](<8 x s16>)
+ ; CHECK-NEXT: [[BITCAST:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[TRUNC]](<8 x s8>)
+ ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+ ; CHECK-NEXT: [[EVEC:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[BITCAST]](<2 x s32>), [[C]](s64)
+ ; CHECK-NEXT: %trunc:_(<4 x s8>) = G_BITCAST [[EVEC]](s32)
+ ; CHECK-NEXT: $s0 = COPY %trunc(<4 x s8>)
+ %in:_(<4 x s16>) = COPY $x0
+ %trunc:_(<4 x s8>) = G_TRUNC %in
+ $s0 = COPY %trunc
+
+...
diff --git a/llvm/test/CodeGen/AArch64/bitcast.ll b/llvm/test/CodeGen/AArch64/bitcast.ll
index 9ebd570e687a01..2b7065fe450617 100644
--- a/llvm/test/CodeGen/AArch64/bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/bitcast.ll
@@ -4,8 +4,7 @@
; PR23065: SCALAR_TO_VECTOR implies the top elements 1 to N-1 of the N-element vector are undefined.
-; CHECK-GI: warning: Instruction selection used fallback path for bitcast_v4i8_i32
-; CHECK-GI-NEXT: warning: Instruction selection used fallback path for bitcast_i32_v4i8
+; CHECK-GI: warning: Instruction selection used fallback path for bitcast_i32_v4i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for bitcast_v2i16_i32
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for bitcast_i32_v2i16
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for bitcast_v2i16_v4i8
@@ -54,15 +53,22 @@ define <4 x i16> @foo2(<2 x i32> %a) {
; ===== To and From Scalar Types =====
define i32 @bitcast_v4i8_i32(<4 x i8> %a, <4 x i8> %b){
-; CHECK-LABEL: bitcast_v4i8_i32:
-; CHECK: // %bb.0:
-; CHECK-NEXT: sub sp, sp, #16
-; CHECK-NEXT: .cfi_def_cfa_offset 16
-; CHECK-NEXT: add v0.4h, v0.4h, v1.4h
-; CHECK-NEXT: uzp1 v0.8b, v0.8b, v0.8b
-; CHECK-NEXT: fmov w0, s0
-; CHECK-NEXT: add sp, sp, #16
-; CHECK-NEXT: ret
+; CHECK-SD-LABEL: bitcast_v4i8_i32:
+; CHECK-SD: // %bb.0:
+; CHECK-SD-NEXT: sub sp, sp, #16
+; CHECK-SD-NEXT: .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT: add v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT: uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: add sp, sp, #16
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: bitcast_v4i8_i32:
+; CHECK-GI: // %bb.0:
+; CHECK-GI-NEXT: add v0.4h, v0.4h, v1.4h
+; CHECK-GI-NEXT: uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
%c = add <4 x i8> %a, %b
%d = bitcast <4 x i8> %c to i32
ret i32 %d
✅ With the latest revision this PR passed the C/C++ code formatter.
auto Trunc = MIB.buildTrunc(LLT::fixed_vector(8, LLT::scalar(8)), Merge);
auto BC = MIB.buildBitcast(LLT::fixed_vector(2, LLT::scalar(32)), Trunc);
auto Extract = MIB.buildExtractVectorElement(
    LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(32), 0));
The index should be s64, see mir test.
It looks complicated. InstCombine combines bitcasts on the vector register of extractVectorElement,
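A minimal sketch of what the index-type fix suggested above might look like (an assumption about the follow-up, not code taken from this patch):

  // Build the index as s64 rather than s32, matching the
  // G_CONSTANT i64 0 that the autogenerated MIR test expects.
  auto Extract = MIB.buildExtractVectorElement(
      LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(64), 0));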
Hi. I had been looking at v4i8 truncate again recently, and had assumed that we would moreElements them. It had some inefficiencies that were stopping me from putting up the patch though, and my attempts to fix them had only led to more problems so far.
My understanding was that until we fixed v4i8 load/store recently, these would have fallen back due to the load/store? Sounds like we are moving in the right direction. Whatever we do, it should ideally handle other small types too: v2i8 and v2i16. I can put up the moreElements version if it is useful; I was hoping that the extra merge/unmerges introduced could all be removed. My worry with extra bitcasts is that they get in the way of optimizations (especially under BE), but with enough combines either approach can probably be made to work cleanly, I would hope.
This is an alternative to llvm#85610 that moreElements small G_TRUNC vectors to widen them. It needs to disable one of the existing Unmerge(Trunc(..)) combines, and some of the code is not as optimal as it could be. I believe that with some extra optimizations it could look better (I was thinking of combining trunc(buildvector) -> buildvector and possibly improving buildvector lowering by generating insert_vector_element earlier).
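For context, a hypothetical sketch of the moreElements shape being described (not taken from either patch; register names are illustrative):

  ; moreElements-style legalization of %dst:_(<4 x s8>) = G_TRUNC %src:_(<4 x s16>)
  %undef:_(<4 x s16>) = G_IMPLICIT_DEF
  %wide_src:_(<8 x s16>) = G_CONCAT_VECTORS %src(<4 x s16>), %undef(<4 x s16>)
  %wide_dst:_(<8 x s8>) = G_TRUNC %wide_src(<8 x s16>)
  ; The high half is dead; this is the unmerge the combines would need to clean up.
  %dst:_(<4 x s8>), %dead:_(<4 x s8>) = G_UNMERGE_VALUES %wide_dst(<8 x s8>)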
I've put up #85625 that shows what it does. I'm not sure which is better in the long run between that method and this one, but I think it would be better to add support for the other types if we do use bitcasts.
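If the bitcast route is kept, one hypothetical way to cover the other small types is to widen the customIf rule; a sketch only (typePairInSet is the existing LegalityPredicates helper, and the v2s8/v2s16/v2s32 LLT shorthands are assumed to be defined in this file):

  // Hypothetical generalization, not part of this patch: route the other
  // small vector truncates through the same widen-then-extract action.
  .customIf(typePairInSet(0, 1, {{v4s8, v4s16}, {v2s8, v2s16}, {v2s16, v2s32}}))
  // Each pair would then need a matching case in legalizeTrunc(), using the
  // same widen -> trunc -> bitcast -> extract path.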
I think your approach is fine as long as restricting the combine didn't pessimize things too much, which it didn't look like it did.