[AArch64] Combine getActiveLaneMask with vector_extract #81139

Merged (4 commits) on May 10, 2024


momchil-velikov
Collaborator

Combine a getActiveLaneMask whose result is split by a pair of vector_extract operations into a single whilelo instruction that writes a pair of predicate registers.
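
To make the intent concrete, here is a small standalone model in plain C++ (not LLVM code). It assumes the documented semantics of `llvm.get.active.lane.mask`, where lane `i` is active iff `base + i < n` with a non-wrapping addition. The names `activeLaneMask` and `whilePair` are hypothetical, and the scalable vector length is modelled as a fixed lane count. It is only a sketch of why extracting the low and high halves of one mask can be replaced by a single operation that produces both halves.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Model of llvm.get.active.lane.mask: lane i is active iff base + i < n,
// with the addition treated as non-wrapping. The test is rewritten as
// (base < n && i < n - base) so no intermediate value can overflow.
std::vector<bool> activeLaneMask(uint64_t base, uint64_t n, std::size_t lanes) {
  std::vector<bool> mask(lanes);
  for (std::size_t i = 0; i < lanes; ++i)
    mask[i] = base < n && i < n - base;
  return mask;
}

// Hypothetical model of a paired "while" operation: it returns the low and
// high halves of the same mask as two separate results.
// Assumption: base + lanes / 2 does not wrap around.
std::pair<std::vector<bool>, std::vector<bool>>
whilePair(uint64_t base, uint64_t n, std::size_t lanes) {
  const std::size_t half = lanes / 2;
  // The low half covers lanes [0, half); the high half starts `half` lanes later.
  std::vector<bool> lo = activeLaneMask(base, n, half);
  std::vector<bool> hi = activeLaneMask(base + half, n, half);
  return {lo, hi};
}
```

In this model, `whilePair(base, n, 16)` returns the same two 8-lane masks that one would get by computing `activeLaneMask(base, n, 16)` and slicing it at offsets 0 and 8, which is the shape of the IR in the new test file.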

@llvmbot
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Momchil Velikov (momchil-velikov)

Changes

... into a whilelo instruction with a pair of predicate registers.


Full diff: https://github.com/llvm/llvm-project/pull/81139.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+86-35)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+90)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389f..4405e8d3f91df2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20004,47 +20004,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  unsigned IID = getIntrinsicID(N);
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::get_active_lane_mask:
+    return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
index a65c5d66677946..6a509b5f3afcae 100644
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
 
 ; == Scalable ==
 
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
new file mode 100644
index 00000000000000..e32dec91d2ff40
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mattr=+sve    < %s | FileCheck %s -check-prefix CHECK-SVE
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sme2   < %s | FileCheck %s -check-prefix CHECK-SME2
+target triple = "aarch64-linux"
+
+; Test combining of getActiveLaneMask with a pair of extract_vector operations.
+
+define void @f0(i32 %i, i32 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f0:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    ptrue p1.h
+; CHECK-SVE-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT:    and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT:    str p2, [x2]
+; CHECK-SVE-NEXT:    str p0, [x3]
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f0:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT:    str p0, [x2]
+; CHECK-SVE2p1-NEXT:    str p1, [x3]
+; CHECK-SVE2p1-NEXT:    ret
+;
+; CHECK-SME2-LABEL: f0:
+; CHECK-SME2:       // %bb.0:
+; CHECK-SME2-NEXT:    mov w8, w1
+; CHECK-SME2-NEXT:    mov w9, w0
+; CHECK-SME2-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SME2-NEXT:    str p0, [x2]
+; CHECK-SME2-NEXT:    str p1, [x3]
+; CHECK-SME2-NEXT:    ret
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+    %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+    store <vscale x 16 x i1> %pg0, ptr %p0
+    store <vscale x 16 x i1> %pg1, ptr %p1
+    ret void
+}
+
+define void @f1(i64 %i, i64 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f1:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT:    ptrue p1.h
+; CHECK-SVE-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT:    and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT:    str p2, [x2]
+; CHECK-SVE-NEXT:    str p0, [x3]
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f1:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-NEXT:    str p0, [x2]
+; CHECK-SVE2p1-NEXT:    str p1, [x3]
+; CHECK-SVE2p1-NEXT:    ret
+;
+; CHECK-SME2-LABEL: f1:
+; CHECK-SME2:       // %bb.0:
+; CHECK-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SME2-NEXT:    str p0, [x2]
+; CHECK-SME2-NEXT:    str p1, [x3]
+; CHECK-SME2-NEXT:    ret
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+    %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+    store <vscale x 16 x i1> %pg0, ptr %p0
+    store <vscale x 16 x i1> %pg1, ptr %p1
+    ret void
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32, i32)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64, i64)
+
+attributes #0 = { nounwind }
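
The use-matching step of `tryCombineGetActiveLaneMask` in the diff above can be illustrated in isolation. The sketch below is plain C++17, not LLVM code: `ExtractUse` and `matchHalves` are hypothetical names standing in for the two `EXTRACT_SUBVECTOR` users and their constant offsets. It mirrors the order-normalising swap followed by the check that the offsets are exactly 0 and `HalfSize`.

```cpp
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical stand-in for an EXTRACT_SUBVECTOR user of the mask:
// an id to tell the users apart, and the constant start offset.
struct ExtractUse {
  int id;
  uint64_t offset;
};

// Returns {lo, hi} when the two uses extract exactly the low half (offset 0)
// and the high half (offset halfSize), in either order; otherwise nullopt.
std::optional<std::pair<ExtractUse, ExtractUse>>
matchHalves(ExtractUse a, ExtractUse b, uint64_t halfSize) {
  // Normalise the order so that `a` is the use with the smaller offset.
  if (a.offset > b.offset)
    std::swap(a, b);
  // Reject duplicates (0, 0), (halfSize, halfSize) and any other offsets.
  if (a.offset != 0 || b.offset != halfSize)
    return std::nullopt;
  return std::make_pair(a, b);
}
```

Normalising the order first matters because the use list gives no guarantee about which extract is visited first; after the swap a single comparison rejects both the "same half twice" and the "unexpected offset" cases.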

Contributor

@david-arm david-arm left a comment


Looks like a nice codegen improvement! I just had a few minor comments ...

@momchil-velikov
Collaborator Author

Ping?

@paulwalker-arm
Collaborator

paulwalker-arm commented Mar 11, 2024

Sorry for the delay. To me it looks like you're doing type legalisation as a DAG combine. A similar crime is already being committed for fixed-length types, but I don't think you should repeat it, because it will prevent any legitimate DAG combines. Custom type legalisation really belongs in ReplaceNodeResults.

@momchil-velikov
Collaborator Author

> Sorry for the delay. To me it looks like you're doing type legalisation as a DAG combine, which while a similar crime is being committed for fixed length types I don't think you should repeat because it will prevent any legitimate DAG combines. Custom type legalisation really belongs in ReplaceNodeResults.

I didn't think I was doing type legalisation, as I wasn't changing any types. In any case, I meant for this transform to be a plain, simple DAG combine on a type-legalised DAG, so I moved it to a later phase (by pivoting on aarch64_sve_whilelo instead of get_active_lane_mask).

Contributor

@david-arm david-arm left a comment


I've not reviewed the tests yet, but I have a few comments on the code!

Collaborator

@paulwalker-arm paulwalker-arm left a comment


One last comment but otherwise looks good.

Collaborator

@paulwalker-arm paulwalker-arm left a comment


One more suggestion but am happy otherwise.

@momchil-velikov momchil-velikov merged commit 88da875 into llvm:main May 10, 2024
3 of 4 checks passed
4 participants