-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Combine getActiveLaneMask with vector_extract #81139
[AArch64] Combine getActiveLaneMask with vector_extract #81139
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Momchil Velikov (momchil-velikov) Changes... into a Full diff: https://github.com/llvm/llvm-project/pull/81139.diff 3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389f..4405e8d3f91df2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
EVT OpVT) const {
- // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
- if (!Subtarget->hasSVE())
+ // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+ if (!Subtarget->hasSVEorSME())
return true;
// We can only support legal predicate result types. We can use the SVE
@@ -20004,47 +20004,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
return SDValue();
}
-static SDValue performIntrinsicCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
SelectionDAG &DAG = DCI.DAG;
- unsigned IID = getIntrinsicID(N);
- switch (IID) {
- default:
- break;
- case Intrinsic::get_active_lane_mask: {
- SDValue Res = SDValue();
- EVT VT = N->getValueType(0);
- if (VT.isFixedLengthVector()) {
- // We can use the SVE whilelo instruction to lower this intrinsic by
- // creating the appropriate sequence of scalable vector operations and
- // then extracting a fixed-width subvector from the scalable vector.
+ EVT VT = N->getValueType(0);
+ if (VT.isFixedLengthVector()) {
+ // We can use the SVE whilelo instruction to lower this intrinsic by
+ // creating the appropriate sequence of scalable vector operations and
+ // then extracting a fixed-width subvector from the scalable vector.
+ SDLoc DL(N);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
- SDLoc DL(N);
- SDValue ID =
- DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+ EVT WhileVT =
+ EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+ ElementCount::getScalable(VT.getVectorNumElements()));
- EVT WhileVT = EVT::getVectorVT(
- *DAG.getContext(), MVT::i1,
- ElementCount::getScalable(VT.getVectorNumElements()));
+ // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+ EVT PromVT = getPromotedVTForPredicate(WhileVT);
- // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
- EVT PromVT = getPromotedVTForPredicate(WhileVT);
+ // Get the fixed-width equivalent of PromVT for extraction.
+ EVT ExtVT =
+ EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+ VT.getVectorElementCount());
- // Get the fixed-width equivalent of PromVT for extraction.
- EVT ExtVT =
- EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
- VT.getVectorElementCount());
+ SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+ N->getOperand(1), N->getOperand(2));
+ Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+ Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+ DAG.getConstant(0, DL, MVT::i64));
+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
- Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
- N->getOperand(1), N->getOperand(2));
- Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
- Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
- DAG.getConstant(0, DL, MVT::i64));
- Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
- }
return Res;
}
+
+ if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+ return SDValue();
+
+ if (!N->hasNUsesOfValue(2, 0))
+ return SDValue();
+
+ auto It = N->use_begin();
+ SDNode *Lo = *It++;
+ SDNode *Hi = *It;
+
+ const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+ uint64_t OffLo, OffHi;
+ if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+ !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+ (OffLo != 0 && OffLo != HalfSize) ||
+ Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+ !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+ (OffHi != 0 && OffHi != HalfSize))
+ return SDValue();
+
+ if (OffLo > OffHi) {
+ std::swap(Lo, Hi);
+ std::swap(OffLo, OffHi);
+ }
+
+ if (OffLo != 0 || OffHi != HalfSize)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+ SDValue Idx = N->getOperand(1);
+ SDValue TC = N->getOperand(2);
+ if (Idx.getValueType() != MVT::i64) {
+ Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+ TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+ }
+ auto R =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+ {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+ DCI.CombineTo(Lo, R.getValue(0));
+ DCI.CombineTo(Hi, R.getValue(1));
+
+ return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+ SelectionDAG &DAG = DCI.DAG;
+ unsigned IID = getIntrinsicID(N);
+ switch (IID) {
+ default:
+ break;
+ case Intrinsic::get_active_lane_mask:
+ return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
index a65c5d66677946..6a509b5f3afcae 100644
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
; == Scalable ==
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
new file mode 100644
index 00000000000000..e32dec91d2ff40
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sme2 < %s | FileCheck %s -check-prefix CHECK-SME2
+target triple = "aarch64-linux"
+
+; Test combining of getActiveLaneMask with a pair of extract_vector operations.
+
+define void @f0(i32 %i, i32 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f0:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: ptrue p1.h
+; CHECK-SVE-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT: and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT: and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT: str p2, [x2]
+; CHECK-SVE-NEXT: str p0, [x3]
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f0:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT: str p0, [x2]
+; CHECK-SVE2p1-NEXT: str p1, [x3]
+; CHECK-SVE2p1-NEXT: ret
+;
+; CHECK-SME2-LABEL: f0:
+; CHECK-SME2: // %bb.0:
+; CHECK-SME2-NEXT: mov w8, w1
+; CHECK-SME2-NEXT: mov w9, w0
+; CHECK-SME2-NEXT: whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SME2-NEXT: str p0, [x2]
+; CHECK-SME2-NEXT: str p1, [x3]
+; CHECK-SME2-NEXT: ret
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+ %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+ store <vscale x 16 x i1> %pg0, ptr %p0
+ store <vscale x 16 x i1> %pg1, ptr %p1
+ ret void
+}
+
+define void @f1(i64 %i, i64 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f1:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT: ptrue p1.h
+; CHECK-SVE-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT: and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT: and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT: str p2, [x2]
+; CHECK-SVE-NEXT: str p0, [x3]
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f1:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-NEXT: str p0, [x2]
+; CHECK-SVE2p1-NEXT: str p1, [x3]
+; CHECK-SVE2p1-NEXT: ret
+;
+; CHECK-SME2-LABEL: f1:
+; CHECK-SME2: // %bb.0:
+; CHECK-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SME2-NEXT: str p0, [x2]
+; CHECK-SME2-NEXT: str p1, [x3]
+; CHECK-SME2-NEXT: ret
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+ %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+ store <vscale x 16 x i1> %pg0, ptr %p0
+ store <vscale x 16 x i1> %pg1, ptr %p1
+ ret void
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32, i32)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64, i64)
+
+attributes #0 = { nounwind }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a nice codegen improvement! I just had a few minor comments ...
af024dd
to
28cdbbe
Compare
Ping? |
Sorry for the delay. To me it looks like you're doing type legalisation as a DAG combine, which while a similar crime is being committed for fixed length types I don't think you should repeat because it will prevent any legitimate DAG combines. Custom type legalisation really belongs in |
28cdbbe
to
3373074
Compare
I didn't think I was doing type legalisation as I wasn't changing any types. Anyway, certainly I meant for this transform to be a plain |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've not reviewed the tests yet, but I have a few comments on the code!
3373074
to
c4dbea5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One last comment but otherwise looks good.
c4dbea5
to
cb97c37
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more suggestion but am happy otherwise.
db12a04
to
56a6f84
Compare
... into a `whilelo` instruction with a pair of predicate registers.
56a6f84
to
ced70f6
Compare
... into a
whilelo
instruction with a pair of predicate registers.