kmclaughlin-arm
Contributor

When the input to ptest_first is a vector concat and the mask is all active,
performPTestFirstCombine returns a ptest_first using the first operand
of the concat, looking through any reinterpret casts.

This allows optimizePTestInstr to later remove the ptest when the first
operand is a flag setting instruction such as whilelo.
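
To see why using only the first concat operand is sound, here is a minimal toy model in plain C++ (not LLVM code). It assumes a predicate can be modelled as a list of lanes and that the "first active" condition reports the tested predicate's value at the first lane enabled by the mask. The names `firstActive`, `concat`, `fromBits` and `concatMatchesFirstOperand` are hypothetical helpers invented for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model only: a predicate is a plain list of lanes, not an SVE register.
using Pred = std::vector<bool>;

// Models the "first active" condition: find the first lane enabled by the
// governing mask and report whether the tested predicate is set there.
bool firstActive(const Pred &mask, const Pred &pred) {
  for (std::size_t i = 0; i < mask.size(); ++i)
    if (mask[i])
      return pred[i];
  return false; // No active lane at all.
}

// Models a vector concat: lanes of `a` followed by lanes of `b`.
Pred concat(const Pred &a, const Pred &b) {
  Pred out(a);
  out.insert(out.end(), b.begin(), b.end());
  return out;
}

// Expands the low `lanes` bits of `bits` into a predicate.
Pred fromBits(unsigned bits, std::size_t lanes) {
  Pred out(lanes);
  for (std::size_t i = 0; i < lanes; ++i)
    out[i] = (bits >> i) & 1u;
  return out;
}

// Exhaustively checks that, under an all-active mask, testing concat(a, b)
// gives the same answer as testing `a` alone.
bool concatMatchesFirstOperand(std::size_t lanes) {
  const Pred allA(lanes, true), allAB(2 * lanes, true);
  for (unsigned a = 0; a < (1u << lanes); ++a)
    for (unsigned b = 0; b < (1u << lanes); ++b) {
      Pred pa = fromBits(a, lanes), pb = fromBits(b, lanes);
      if (firstActive(allAB, concat(pa, pb)) != firstActive(allA, pa))
        return false;
    }
  return true;
}
```

Calling `concatMatchesFirstOperand(4)` walks every pair of 4-lane predicates. In this model the second operand never influences the result, which is the property the combine relies on.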

…UE, A)

When the input to a ptest_first is a vector concat and the mask is all
active, performPTestFirstCombine returns a ptest_first using the
first operand of the concat, looking through any reinterpret casts
added by getPTest.

This allows optimizePTestInstr to later remove the ptest when the
first operand is a flag setting instruction such as whilelo.
@llvmbot
Member

llvmbot commented Sep 30, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/161384.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+38-2)
  • (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+2-18)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 45f52352d45fd..1f641cd8c14fc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 }
 
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
-                        AArch64CC::CondCode Cond);
+                        AArch64CC::CondCode Cond, bool EmitCSel = true);
 
 static bool isPredicateCCSettingOp(SDValue N) {
   if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,6 +20495,7 @@ static SDValue
 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64Subtarget *Subtarget) {
   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+
   if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
     return Res;
   if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
 }
 
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
-                        AArch64CC::CondCode Cond) {
+                        AArch64CC::CondCode Cond, bool EmitCSel) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   SDLoc DL(Op);
@@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
 
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
+  if (!EmitCSel)
+    return Test;
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performPTestFirstCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(N);
+  auto Mask = N->getOperand(0);
+  auto Pred = N->getOperand(1);
+
+  if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    Mask = Mask->getOperand(0);
+
+  if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    Pred = Pred->getOperand(0);
+
+  if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
+      !isAllActivePredicate(DAG, Mask))
+    return SDValue();
+
+  if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
+    Pred = Pred->getOperand(0);
+    SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
+                                      DAG.getAllOnesConstant(DL, MVT::i64));
+    return getPTest(DAG, N->getValueType(0), Mask, Pred,
+                    AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
+  }
+
+  return SDValue();
+}
+
 static SDValue
 performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              SelectionDAG &DAG) {
@@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
     return performMULLCombine(N, DCI, DAG);
+  case AArch64ISD::PTEST_FIRST:
+    return performPTestFirstCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index b89f55188b0f2..e2c861b40e706 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.b
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.b, p0.b, p1.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.h, p2.h, p3.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.b, p4.b, p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB13_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB13_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.s, p0.s, p1.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.s, p2.s, p3.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p4.h, p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB14_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB14_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret

Comment on lines 27549 to 27550
return getPTest(DAG, N->getValueType(0), Mask, Pred,
AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
Collaborator

I doubt you're using much in getPTest to warrant the EmitCSel change. Having already proven the DAG is testing the first lane of the predicate[1] there's no complexity related to "hidden lanes" and so this can be simplified to:

Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred);
Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Mask);
return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask, Pred);

[1] It is worth adding commentary detailing the "we know we are testing the first lane" requirement, because that is key to the combine.
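
The shape of the pattern match under discussion can be sketched with a toy node type. This is a hedged illustration, not the code in AArch64ISelLowering.cpp: `Opcode`, `Node` and `firstConcatOperand` are hypothetical stand-ins, and the sketch assumes the caller has already established that lane 0 of the mask is active.

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins for SelectionDAG nodes; none of these are LLVM types.
enum class Opcode { Leaf, ReinterpretCast, ConcatVectors };

struct Node {
  Opcode op;
  std::vector<const Node *> operands;
};

// Mirrors the shape of the combine: look through one reinterpret cast on the
// tested predicate and, if a concat sits underneath, return its first operand.
// Returns nullptr when the pattern does not match (no combine).
// Precondition (assumed to be checked by the caller): lane 0 of the mask is
// active, so the test only ever observes lane 0 of the predicate.
const Node *firstConcatOperand(const Node *pred) {
  if (pred->op == Opcode::ReinterpretCast)
    pred = pred->operands[0];
  if (pred->op != Opcode::ConcatVectors)
    return nullptr;
  return pred->operands[0];
}
```

Given a cast wrapping a concat of two leaves, the function returns the first leaf. Given a bare leaf, it returns `nullptr`, which corresponds to the combine returning an empty `SDValue()`.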

if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
Pred = Pred->getOperand(0);

if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
Collaborator

You don't need the element type check because AArch64ISD::REINTERPRET_CAST is only allowed to change the element count.

Comment on lines 27547 to 27548
SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
DAG.getAllOnesConstant(DL, MVT::i64));
Collaborator

@paulwalker-arm paulwalker-arm Sep 30, 2025

You've already proven the existing Mask does what you need, so it should be reusable. Perhaps it's worth adding an isLane0KnownOne helper function; then you don't need to strip the reinterpret cast from Mask and can reuse it directly.
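
The weaker "lane 0 is known to be one" condition is enough for a first-active test, and a small toy model (plain C++, not LLVM code) shows why. It assumes predicates are lists of lanes and that the first-active condition reads the tested predicate at the first lane enabled by the mask. `isLane0Active` here is a hypothetical stand-in for the suggested helper, not its real implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model only: a predicate is a plain list of lanes.
using Pred = std::vector<bool>;

// Models the "first active" condition: value of `pred` at the first lane
// enabled by `mask`, or false when no lane is enabled.
bool firstActive(const Pred &mask, const Pred &pred) {
  for (std::size_t i = 0; i < mask.size(); ++i)
    if (mask[i])
      return pred[i];
  return false;
}

// Hypothetical stand-in for the suggested helper: true when lane 0 is set.
bool isLane0Active(const Pred &mask) { return !mask.empty() && mask[0]; }

// Expands the low `lanes` bits of `bits` into a predicate.
Pred fromBits(unsigned bits, std::size_t lanes) {
  Pred out(lanes);
  for (std::size_t i = 0; i < lanes; ++i)
    out[i] = (bits >> i) & 1u;
  return out;
}

// Checks every mask/predicate pair of the given width: whenever lane 0 of
// the mask is active, the result equals lane 0 of the tested predicate.
bool lane0DecidesResult(std::size_t lanes) {
  for (unsigned m = 0; m < (1u << lanes); ++m)
    for (unsigned p = 0; p < (1u << lanes); ++p) {
      Pred mask = fromBits(m, lanes), pred = fromBits(p, lanes);
      if (isLane0Active(mask) && firstActive(mask, pred) != pred[0])
        return false;
    }
  return true;
}
```

In this model the remaining mask lanes are irrelevant once lane 0 is active, so an all-active mask is a sufficient but not a necessary precondition.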

- Extend canRemovePTestInstr for PTEST_PP_FIRST
- Remove getVectorElementType() != MVT::i1 check for Mask
Comment on lines 27550 to 27551
if (!isLane0KnownActive(Mask))
return SDValue();
Collaborator

Perhaps move this before looking through Pred's reinterpret_cast because this is what makes that code logically safe.

getElementSizeForOpcode(PredOpcode))
return PredOpcode;

if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST &&
Collaborator

Perhaps add a comment to match the other cases?

 - Add comment to changes in canRemovePTestInstr
@kmclaughlin-arm kmclaughlin-arm merged commit 20e0e80 into llvm:main Oct 2, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…UE, A) (llvm#161384)

@kmclaughlin-arm kmclaughlin-arm deleted the ptest-concat-combine branch October 7, 2025 08:39