-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) #161384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) #161384
Conversation
…UE, A) When input to a ptest_first is a vector concat and the mask is all active, performPTestFirstCombine returns a ptest_first using the first operand of the concat, looking through any reinterpret casts added by getPTest. This allows optimizePTestInstr to later remove the ptest when the first operand is a flag setting instruction such as whilelo.
@llvm/pr-subscribers-backend-aarch64 Author: Kerry McLaughlin (kmclaughlin-arm) ChangesWhen the input to ptest_first is a vector concat and the mask is all active, This allows optimizePTestInstr to later remove the ptest when the first Full diff: https://github.com/llvm/llvm-project/pull/161384.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 45f52352d45fd..1f641cd8c14fc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond);
+ AArch64CC::CondCode Cond, bool EmitCSel = true);
static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,6 +20495,7 @@ static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+
if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond) {
+ AArch64CC::CondCode Cond, bool EmitCSel) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc DL(Op);
@@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
// Set condition code (CC) flags.
SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
+ if (!EmitCSel)
+ return Test;
// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
return SDValue();
}
+static SDValue performPTestFirstCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ SDLoc DL(N);
+ auto Mask = N->getOperand(0);
+ auto Pred = N->getOperand(1);
+
+ if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ Mask = Mask->getOperand(0);
+
+ if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ Pred = Pred->getOperand(0);
+
+ if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
+ !isAllActivePredicate(DAG, Mask))
+ return SDValue();
+
+ if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
+ Pred = Pred->getOperand(0);
+ SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
+ DAG.getAllOnesConstant(DL, MVT::i64));
+ return getPTest(DAG, N->getValueType(0), Mask, Pred,
+ AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
+ }
+
+ return SDValue();
+}
+
static SDValue
performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
@@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::UMULL:
case AArch64ISD::PMULL:
return performMULLCombine(N, DCI, DAG);
+ case AArch64ISD::PTEST_FIRST:
+ return performPTestFirstCombine(N, DCI, DAG);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index b89f55188b0f2..e2c861b40e706 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
-; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
-; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
-; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
-; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
-; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
|
return getPTest(DAG, N->getValueType(0), Mask, Pred, | ||
AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I doubt you're using much in getPTest
to warrant the EmitCSel
change. Having already proven the DAG is testing the first lane of the predicate[1] there's no complexity related to "hidden lanes" and so this can be simplified to:
Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, nxv16i1, Pred)
Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, nxv16i1, Mask)
return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Pred, Mask);
[1] Is worth adding commentary detailing the "we know we are testing the first lane" requirement because that's key to the combine.
if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) | ||
Pred = Pred->getOperand(0); | ||
|
||
if (Pred->getValueType(0).getVectorElementType() != MVT::i1 || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need the element type check because AArch64ISD::REINTERPRET_CAST
is only allowed to change the element count.
SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL, | ||
DAG.getAllOnesConstant(DL, MVT::i64)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've already proven the existing Mask
does what you need so should be reusable. Perhaps it's worth adding an isLane0KnownOne
helper function, then you don't need to strip reinterpret from Mask
and can then reuse it directly.
- Extend canRemovePTestInstr for PTEST_PP_FIRST - Remove getVectorElementType() != MVT::i1 check for Mask
if (!isLane0KnownActive(Mask)) | ||
return SDValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps move this before looking through Pred's reinterpret_cast because this is what makes that code logically safe.
getElementSizeForOpcode(PredOpcode)) | ||
return PredOpcode; | ||
|
||
if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps add a comment to match the other cases?
- Add comment to changes in canRemovePTestInstr
…UE, A) (llvm#161384) When the input to ptest_first is a vector concat and the mask is all active, performPTestFirstCombine returns a ptest_first using the first operand of the concat, looking through any reinterpret casts. This allows optimizePTestInstr to later remove the ptest when the first operand is a flag setting instruction such as whilelo.
When the input to ptest_first is a vector concat and the mask is all active,
performPTestFirstCombine returns a ptest_first using the first operand
of the concat, looking through any reinterpret casts.
This allows optimizePTestInstr to later remove the ptest when the first
operand is a flag setting instruction such as whilelo.