[AArch64][GlobalISel] Add custom legalization for v4s8 = G_TRUNC v4s16 #85610

Closed. 3 commits, not merged.

Conversation

aemerson (Contributor)

We see a lot of fallbacks these days due to <4 x s8> types appearing in truncates; these commonly come from the new load/store bitcasting -> s32 rule.

We can keep that load/store rule if we handle the truncates properly, so for this custom action we adopt a strategy similar to DAG lowering's LowerTruncateVectorStore(). That is, we first widen the input <4 x s16> to <8 x s16> so we can generate a legal G_TRUNC to <8 x s8>, and from there extract the final 32-bit value.
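For context, the IR pattern that exercises this path is the one in the bitcast.ll test updated below: with GlobalISel the <4 x i8> add is performed on <4 x s16> lanes, and the bitcast of the result to i32 is what introduces the <4 x s8> = G_TRUNC <4 x s16> that previously fell back.

define i32 @bitcast_v4i8_i32(<4 x i8> %a, <4 x i8> %b) {
  %c = add <4 x i8> %a, %b
  %d = bitcast <4 x i8> %c to i32
  ret i32 %d
}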

@llvmbot (Collaborator) commented Mar 18, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-globalisel

Author: Amara Emerson (aemerson)



Full diff: https://github.com/llvm/llvm-project/pull/85610.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+36-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h (+1)
  • (added) llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir (+24)
  • (modified) llvm/test/CodeGen/AArch64/bitcast.ll (+17-11)
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 36adada2796531..04a228cf522ce7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -628,7 +628,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         return DstTy.isVector() && SrcTy.getSizeInBits() > 128 &&
                DstTy.getScalarSizeInBits() * 2 <= SrcTy.getScalarSizeInBits();
       })
-
+      .customIf(all(typeInSet(0, {v4s8}),
+                    typeInSet(1, {v4s16})))
       .alwaysLegal();
 
   getActionDefinitionsBuilder(G_SEXT_INREG)
@@ -1262,11 +1263,45 @@ bool AArch64LegalizerInfo::legalizeCustom(
     return legalizeDynStackAlloc(MI, Helper);
   case TargetOpcode::G_PREFETCH:
     return legalizePrefetch(MI, Helper);
+  case TargetOpcode::G_TRUNC:
+    return legalizeTrunc(MI, Helper);
   }
 
   llvm_unreachable("expected switch to return");
 }
 
+bool AArch64LegalizerInfo::legalizeTrunc(MachineInstr &MI,
+                                         LegalizerHelper &Helper) const {
+  assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
+
+  // Handle <4 x s8> = G_TRUNC <4 x s16> by widening to <8 x s16> first.
+  // So the sequence is:
+  // %orig_val(<4 x s16>) = ...
+  // %wide = G_MERGE_VALUES %orig_val, %undef:_(<4 x s16>)
+  // %wide_trunc:_(<8 x s8>) = G_TRUNC %wide
+  // %bc:_(<2 x s32>) = G_BITCAST %wide_trunc
+  // %eve:_(s32) = G_EXTRACT_VECTOR_ELT %bc, 0
+  // %final:_(<4 x s8>) = G_BITCAST %eve
+
+  MachineIRBuilder &MIB = Helper.MIRBuilder;
+
+  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+  assert(DstTy == LLT::fixed_vector(4, LLT::scalar(8)) &&
+         SrcTy == LLT::fixed_vector(4, LLT::scalar(16)));
+
+  auto WideTy = LLT::fixed_vector(8, LLT::scalar(16));
+  auto Undef = MIB.buildUndef(SrcTy);
+  auto Merge = MIB.buildMergeLikeInstr(WideTy, {SrcReg, Undef});
+  auto Trunc = MIB.buildTrunc(LLT::fixed_vector(8, LLT::scalar(8)), Merge);
+  auto BC = MIB.buildBitcast(LLT::fixed_vector(2, LLT::scalar(32)), Trunc);
+  auto Extract = MIB.buildExtractVectorElement(
+      LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(64), 0));
+  MIB.buildBitcast(DstReg, Extract);
+
+  MI.eraseFromParent();
+  return true;
+}
+
 bool AArch64LegalizerInfo::legalizeFunnelShift(MachineInstr &MI,
                                                MachineRegisterInfo &MRI,
                                                MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
index b69d9b015bd2b3..e9d8b54de9ef70 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
@@ -64,6 +64,7 @@ class AArch64LegalizerInfo : public LegalizerInfo {
                                 LegalizerHelper &Helper) const;
   bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
+  bool legalizeTrunc(MachineInstr &MI, LegalizerHelper &Helper) const;
   const AArch64Subtarget *ST;
 };
 } // End llvm namespace.
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
new file mode 100644
index 00000000000000..e3e12558c2242a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
@@ -0,0 +1,24 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple=aarch64  -run-pass=legalizer %s -o - | FileCheck %s
+---
+name:            trunc_v4s8_v4s16
+body:             |
+  bb.1:
+    liveins: $x0
+    ; CHECK-LABEL: name: trunc_v4s8_v4s16
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: %in:_(<4 x s16>) = COPY $x0
+    ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<4 x s16>) = G_IMPLICIT_DEF
+    ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<8 x s16>) = G_CONCAT_VECTORS %in(<4 x s16>), [[DEF]](<4 x s16>)
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(<8 x s8>) = G_TRUNC [[CONCAT_VECTORS]](<8 x s16>)
+    ; CHECK-NEXT: [[BITCAST:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[TRUNC]](<8 x s8>)
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+    ; CHECK-NEXT: [[EVEC:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[BITCAST]](<2 x s32>), [[C]](s64)
+    ; CHECK-NEXT: %trunc:_(<4 x s8>) = G_BITCAST [[EVEC]](s32)
+    ; CHECK-NEXT: $s0 = COPY %trunc(<4 x s8>)
+    %in:_(<4 x s16>) = COPY $x0
+    %trunc:_(<4 x s8>) = G_TRUNC %in
+    $s0 = COPY %trunc
+
+...
diff --git a/llvm/test/CodeGen/AArch64/bitcast.ll b/llvm/test/CodeGen/AArch64/bitcast.ll
index 9ebd570e687a01..2b7065fe450617 100644
--- a/llvm/test/CodeGen/AArch64/bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/bitcast.ll
@@ -4,8 +4,7 @@
 
 ; PR23065: SCALAR_TO_VECTOR implies the top elements 1 to N-1 of the N-element vector are undefined.
 
-; CHECK-GI:         warning: Instruction selection used fallback path for bitcast_v4i8_i32
-; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_i32_v4i8
+; CHECK-GI:         warning: Instruction selection used fallback path for bitcast_i32_v4i8
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_v2i16_i32
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_i32_v2i16
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_v2i16_v4i8
@@ -54,15 +53,22 @@ define <4 x i16> @foo2(<2 x i32> %a) {
 ; ===== To and From Scalar Types =====
 
 define i32 @bitcast_v4i8_i32(<4 x i8> %a, <4 x i8> %b){
-; CHECK-LABEL: bitcast_v4i8_i32:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    add v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
-; CHECK-NEXT:    fmov w0, s0
-; CHECK-NEXT:    add sp, sp, #16
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: bitcast_v4i8_i32:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    add v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: bitcast_v4i8_i32:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    add v0.4h, v0.4h, v1.4h
+; CHECK-GI-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
   %c = add <4 x i8> %a, %b
   %d = bitcast <4 x i8> %c to i32
   ret i32 %d

github-actions bot commented Mar 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Review comment from a Member on the new code in AArch64LegalizerInfo.cpp:

auto Trunc = MIB.buildTrunc(LLT::fixed_vector(8, LLT::scalar(8)), Merge);
auto BC = MIB.buildBitcast(LLT::fixed_vector(2, LLT::scalar(32)), Trunc);
auto Extract = MIB.buildExtractVectorElement(
    LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(32), 0));

The index should be s64, see mir test.

@tschuett (Member)

It looks complicated. InstCombine combines bitcasts on the vector operand of an extractelement; see:

Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
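For illustration, a sketch of the kind of fold that performs (function names here are just examples, not from the source): on little-endian, extracting lane 0 from a scalar-to-vector bitcast reduces to a plain trunc.

; Before: extractelement whose vector operand is a bitcast.
define i8 @extract_of_bitcast(i32 %x) {
  %v = bitcast i32 %x to <4 x i8>
  %e = extractelement <4 x i8> %v, i64 0
  ret i8 %e
}

; After instcombine (little-endian): lane 0 is the low byte of %x.
define i8 @extract_of_bitcast_folded(i32 %x) {
  %t = trunc i32 %x to i8
  ret i8 %t
}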

@davemgreen (Collaborator)

Hi. I had been looking at v4i8 truncate again recently, and had assumed that we would moreElements them. It had some inefficiencies that were stopping me from putting up the patch, though; my attempts to fix them have so far only led to more problems.

My understanding was that until we fixed v4i8 load/store recently, these would have fallen back due to the load/store? Sounds like we are moving in the right direction.

Whatever we do, it should ideally handle the other small types too (v2i8 and v2i16). I can put up the moreElements version if it is useful; I was hoping that the extra merge/unmerges it introduces could all be removed. My worry with extra bitcasts is that they get in the way of optimizations (especially under big-endian), but with enough combines either approach can probably be made to work cleanly, I would hope.

davemgreen added a commit to davemgreen/llvm-project that referenced this pull request Mar 18, 2024
This is an alternative to llvm#85610 that moreElements small G_TRUNC vectors to
widen the vectors. It needs to disable one of the existing Unmerge(Trunc(..))
combines, and some of the code is not as optimal as it could be. I believe with
some extra optimizations it could look better (I was thinking combining
trunc(buildvector) -> buildvector and possibly improving buildvector lowering
by generating insert_vector_element earlier).
@davemgreen (Collaborator)

I've put up #85625 that shows what it does. I'm not sure which is better in the long run between that method and this one, but I think it would be better to add support for the other types if we do use bitcasts.

@aemerson (Contributor, Author)

I think your approach is fine as long as restricting the combine didn't pessimize things too much, which it didn't look like it did.

aemerson closed this Mar 18, 2024
aemerson deleted the legalize-v4s8-trunc branch Mar 18, 2024 at 16:06
aemerson pushed a commit that referenced this pull request Mar 18, 2024
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024