kmclaughlin-arm

The optimisation in canRemovePTestInstr tries to remove ptest instructions when
the predicate is the result of a WHILEcc. This patch extends the support to
WHILEcc (predicate pair) by:

  • Including the WHILEcc_x2 intrinsics in isPredicateCCSettingOp, allowing
    performFirstTrueTestVectorCombine to create the PTEST.
  • Setting the isWhile flag for the predicate pair instructions in tablegen.
  • Looking through copies in canRemovePTestInstr to test isWhileOpcode.


github-actions bot commented Sep 2, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.


llvmbot commented Sep 3, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Patch is 28.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156478.diff

13 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+9-1)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+16-2)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.td (+12-15)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+3-1)
  • (modified) llvm/test/CodeGen/AArch64/sve-cmp-folds.ll (+106-1)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilege.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilegt.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehi.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehs.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilele.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilelo.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilels.mir (+49)
  • (modified) llvm/test/CodeGen/AArch64/sve-ptest-removal-whilelt.mir (+49)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7011e0ea1669..675f78825f612 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20001,13 +20001,21 @@ static bool isPredicateCCSettingOp(SDValue N) {
       (N.getOpcode() == ISD::GET_ACTIVE_LANE_MASK) ||
       (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
        (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo_x2 ||
         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt)))
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels_x2 ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt_x2)))
     return true;
 
   return false;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 3ce7829207cb6..db18183b4950b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1487,6 +1487,21 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
+  uint64_t PredEltSize = 0;
+  if (PredIsWhileLike)
+    PredEltSize = getElementSizeForOpcode(PredOpcode);
+
+  if (Pred->isCopy()) {
+    // Instructions which return a multi-vector (e.g. WHILECC_x2) require copies
+    // before the branch to extract each subregister.
+    auto Op = Pred->getOperand(1);
+    if (Op.isReg() && Op.getReg().isVirtual() && Op.getSubReg() != 0) {
+      MachineInstr *DefMI = MRI->getVRegDef(Op.getReg());
+      PredIsWhileLike = isWhileOpcode(DefMI->getOpcode());
+      PredEltSize = getElementSizeForOpcode(DefMI->getOpcode());
+    }
+  }
+
   if (PredIsWhileLike) {
     // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
     // instruction and the condition is "any" since WHILcc does an implicit
@@ -1498,8 +1513,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     // redundant since WHILE performs an implicit PTEST with an all active
     // mask.
     if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
+        getElementSizeForOpcode(MaskOpcode) == PredEltSize)
       return PredOpcode;
 
     return {};
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 1a7609bfee8a1..72c303fcbc55b 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -1164,25 +1164,22 @@ class PPRVectorListMul<int ElementWidth, int NumRegs> : PPRVectorList<ElementWid
                                                                 ", AArch64::PPRMul2RegClassID>";
 }
 
+class PPR2MulRegOp<string Suffix, int Size, ElementSizeEnum ES>
+    : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'"#Suffix#"'>"> {
+  ElementSizeEnum ElementSize;
+  let ElementSize = ES;
+  let ParserMatchClass = PPRVectorListMul<Size, 2>;
+}
+
 let EncoderMethod = "EncodeRegMul_MinMax<2, 0, 14>",
     DecoderMethod = "DecodePPR2Mul2RegisterClass" in {
-  def PP_b_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'b'>"> {
-    let ParserMatchClass = PPRVectorListMul<8, 2>;
-  }
-
-  def PP_h_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'h'>"> {
-    let ParserMatchClass = PPRVectorListMul<16, 2>;
-  }
 
-  def PP_s_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'s'>"> {
-    let ParserMatchClass = PPRVectorListMul<32, 2>;
-  }
-
-  def PP_d_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'d'>"> {
-    let ParserMatchClass = PPRVectorListMul<64, 2>;
-  }
-}  // end let EncoderMethod/DecoderMethod
+  def PP_b_mul_r : PPR2MulRegOp<"b", 8,  ElementSizeB>;
+  def PP_h_mul_r : PPR2MulRegOp<"h", 16, ElementSizeH>;
+  def PP_s_mul_r : PPR2MulRegOp<"s", 32, ElementSizeS>;
+  def PP_d_mul_r : PPR2MulRegOp<"d", 64, ElementSizeD>;
 
+} // end let EncoderMethod/DecoderMethod
 
 //===----------------------------------------------------------------------===//
 // SVE vector register classes
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 74e4a7feb49d0..620c446dde4af 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -10397,7 +10397,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                               PPR2MulRegOp ppr_ty>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -10417,6 +10417,8 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
 
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let ElementSize = ppr_ty.ElementSize;
+  let isWhile = 1;
 }
 
 
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
index 981cc88298a3e..0d964a488e9ec 100644
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve2 -o - < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve2p1 -o - < %s | FileCheck %s
 
 define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: not_icmp_sle_nxv8i16:
@@ -220,6 +220,102 @@ define i1 @lane_mask_first(i64 %next, i64 %end) {
   ret i1 %bit
 }
 
+define i1 @whilege_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilege_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilege { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilege.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilegt_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilegt_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilegt { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilegt.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilehi_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilehi_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilehi { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilehi.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilehs_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilehs_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilehs { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilehs.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilele_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilele_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilele { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilele.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilelo_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilelo_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilelo.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilels_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilels_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilels { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilels.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilelt_x2_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilelt_x2_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelt { p0.s, p1.s }, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predpair = call { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv4i1.i64(i64 %next, i64 %end)
+  %predicate = extractvalue { <vscale x 4 x i1>, <vscale x 4 x i1> } %predpair, 0
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
 declare i64 @llvm.vscale.i64()
 declare <vscale x 4 x i1> @llvm.aarch64.sve.whilege.nxv4i1.i64(i64, i64)
 declare <vscale x 4 x i1> @llvm.aarch64.sve.whilegt.nxv4i1.i64(i64, i64)
@@ -230,3 +326,12 @@ declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64, i64)
 declare <vscale x 4 x i1> @llvm.aarch64.sve.whilels.nxv4i1.i64(i64, i64)
 declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64, i64)
 declare <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64, i64)
+
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilege.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilegt.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilehi.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilehs.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilele.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilelo.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilels.x2.nxv4i1(i64, i64)
+declare { <vscale x 4 x i1>, <vscale x 4 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv4i1(i64, i64)
diff --git a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilege.mir b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilege.mir
index 69a2c88d7dbad..d3a1be3de17fb 100644
--- a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilege.mir
+++ b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilege.mir
@@ -538,3 +538,52 @@ body:             |
     RET_ReallyLR implicit $w0
 
 ...
+
+# WHILEGE (predicate pair)
+---
+name:            whilege_x2_b64_s64
+alignment:       2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gpr64 }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: ppr }
+  - { id: 3, class: ppr2mul2 }
+  - { id: 4, class: ppr }
+  - { id: 5, class: ppr }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+liveins:
+  - { reg: '$x0', virtual-reg: '%0' }
+  - { reg: '$x1', virtual-reg: '%1' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0.entry:
+    liveins: $x0, $x1
+
+    ; CHECK-LABEL: name: whilege_x2_b64_s64
+    ; CHECK: liveins: $x0, $x1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr64 = COPY $x1
+    ; CHECK-NEXT: [[PTRUE_D:%[0-9]+]]:ppr = PTRUE_D 31, implicit $vg
+    ; CHECK-NEXT: [[WHILEGE_2PXX_D:%[0-9]+]]:ppr2mul2 = WHILEGE_2PXX_D [[COPY]], [[COPY1]], implicit-def $nzcv
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:ppr = COPY [[WHILEGE_2PXX_D]].psub0
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:ppr = COPY [[WHILEGE_2PXX_D]].psub1
+    ; CHECK-NEXT: [[COPY4:%[0-9]+]]:gpr32 = COPY $wzr
+    ; CHECK-NEXT: [[CSINCWr:%[0-9]+]]:gpr32 = CSINCWr [[COPY4]], $wzr, 0, implicit $nzcv
+    ; CHECK-NEXT: $w0 = COPY [[CSINCWr]]
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:gpr64 = COPY $x0
+    %1:gpr64 = COPY $x1
+    %2:ppr = PTRUE_D 31, implicit $vg
+    %3:ppr2mul2 = WHILEGE_2PXX_D %0, %1, implicit-def $nzcv
+    %4:ppr = COPY %3.psub0
+    %5:ppr = COPY %3.psub1
+    PTEST_PP killed %2, killed %4, implicit-def $nzcv
+    %6:gpr32 = COPY $wzr
+    %7:gpr32 = CSINCWr %6, $wzr, 0, implicit $nzcv
+    $w0 = COPY %7
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilegt.mir b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilegt.mir
index 58db85aba80ad..fb92955f02d52 100644
--- a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilegt.mir
+++ b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilegt.mir
@@ -578,3 +578,52 @@ body:             |
     RET_ReallyLR implicit $w0
 
 ...
+
+# WHILEGT (predicate pair)
+---
+name:            whilegt_x2_b64_s64
+alignment:       2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gpr64 }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: ppr }
+  - { id: 3, class: ppr2mul2 }
+  - { id: 4, class: ppr }
+  - { id: 5, class: ppr }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+liveins:
+  - { reg: '$x0', virtual-reg: '%0' }
+  - { reg: '$x1', virtual-reg: '%1' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0.entry:
+    liveins: $x0, $x1
+
+    ; CHECK-LABEL: name: whilegt_x2_b64_s64
+    ; CHECK: liveins: $x0, $x1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr64 = COPY $x1
+    ; CHECK-NEXT: [[PTRUE_D:%[0-9]+]]:ppr = PTRUE_D 31, implicit $vg
+    ; CHECK-NEXT: [[WHILEGT_2PXX_D:%[0-9]+]]:ppr2mul2 = WHILEGT_2PXX_D [[COPY]], [[COPY1]], implicit-def $nzcv
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:ppr = COPY [[WHILEGT_2PXX_D]].psub0
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:ppr = COPY [[WHILEGT_2PXX_D]].psub1
+    ; CHECK-NEXT: [[COPY4:%[0-9]+]]:gpr32 = COPY $wzr
+    ; CHECK-NEXT: [[CSINCWr:%[0-9]+]]:gpr32 = CSINCWr [[COPY4]], $wzr, 0, implicit $nzcv
+    ; CHECK-NEXT: $w0 = COPY [[CSINCWr]]
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:gpr64 = COPY $x0
+    %1:gpr64 = COPY $x1
+    %2:ppr = PTRUE_D 31, implicit $vg
+    %3:ppr2mul2 = WHILEGT_2PXX_D %0, %1, implicit-def $nzcv
+    %4:ppr = COPY %3.psub0
+    %5:ppr = COPY %3.psub1
+    PTEST_PP killed %2, killed %4, implicit-def $nzcv
+    %6:gpr32 = COPY $wzr
+    %7:gpr32 = CSINCWr %6, $wzr, 0, implicit $nzcv
+    $w0 = COPY %7
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehi.mir b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehi.mir
index 03d9768258ebc..97f242b852eb8 100644
--- a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehi.mir
+++ b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehi.mir
@@ -538,3 +538,52 @@ body:             |
     RET_ReallyLR implicit $w0
 
 ...
+
+# WHILEHI (predicate pair)
+---
+name:            whilehi_x2_b64_s64
+alignment:       2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gpr64 }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: ppr }
+  - { id: 3, class: ppr2mul2 }
+  - { id: 4, class: ppr }
+  - { id: 5, class: ppr }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+liveins:
+  - { reg: '$x0', virtual-reg: '%0' }
+  - { reg: '$x1', virtual-reg: '%1' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0.entry:
+    liveins: $x0, $x1
+
+    ; CHECK-LABEL: name: whilehi_x2_b64_s64
+    ; CHECK: liveins: $x0, $x1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr64 = COPY $x1
+    ; CHECK-NEXT: [[PTRUE_D:%[0-9]+]]:ppr = PTRUE_D 31, implicit $vg
+    ; CHECK-NEXT: [[WHILEHI_2PXX_D:%[0-9]+]]:ppr2mul2 = WHILEHI_2PXX_D [[COPY]], [[COPY1]], implicit-def $nzcv
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:ppr = COPY [[WHILEHI_2PXX_D]].psub0
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:ppr = COPY [[WHILEHI_2PXX_D]].psub1
+    ; CHECK-NEXT: [[COPY4:%[0-9]+]]:gpr32 = COPY $wzr
+    ; CHECK-NEXT: [[CSINCWr:%[0-9]+]]:gpr32 = CSINCWr [[COPY4]], $wzr, 0, implicit $nzcv
+    ; CHECK-NEXT: $w0 = COPY [[CSINCWr]]
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:gpr64 = COPY $x0
+    %1:gpr64 = COPY $x1
+    %2:ppr = PTRUE_D 31, implicit $vg
+    %3:ppr2mul2 = WHILEHI_2PXX_D %0, %1, implicit-def $nzcv
+    %4:ppr = COPY %3.psub0
+    %5:ppr = COPY %3.psub1
+    PTEST_PP killed %2, killed %4, implicit-def $nzcv
+    %6:gpr32 = COPY $wzr
+    %7:gpr32 = CSINCWr %6, $wzr, 0, implicit $nzcv
+    $w0 = COPY %7
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehs.mir b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehs.mir
index 68ecd79c8325b..0ec4788957335 100644
--- a/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehs.mir
+++ b/llvm/test/CodeGen/AArch64/sve-ptest-removal-whilehs.mir
@@ -538,3 +538,52 @@ body:             |
     RET_ReallyLR implicit $w0
 
 ...
+
+# WHILEHS (predicate pair)
+---
+name:            whilehs_x2_b64_s64
+alignment:       2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gpr64 }
+  - { id: 1, class: gpr64 }
+  - { id: 2, class: ppr }
+  - { id: 3, class: ppr2mul2 }
+  - { id: 4, class: ppr }
+  - { id: 5, class: ppr }
+  - { id: 6, class: gpr32 }
+  - { id: 7, class: gpr32 }
+liveins:
+  - { reg: '$x0', virtual-reg: '%0' }
+  - { reg: '$x1', virtual-reg: '%1' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0.entry:
+    liveins: $x0, $x1
+
+    ; CHECK-LABEL: name: whilehs_x2_b64_s64
+    ; CHECK: liveins: $x0, $x1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr64 = COPY $x1
+    ; CHECK-NEXT: [[PTRUE_D:%[0-9]+]]:ppr = PTRUE_D 31, implicit $vg
+    ; CHECK-NEXT: [[WHILEHS_2PXX_D:%[0-9]+]]:ppr2mul2 = WHILEHS_2PXX_D [[COPY]], [[COPY1]], implicit-def $nzcv
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:ppr = COPY [[WHILEHS_2PXX_D]].psub0
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:ppr = COPY [[WHILEHS_2PXX_D]].psub1
+    ; CHECK-NEXT: [[COPY4:%[0-9]+]]:gpr32 = COPY $wzr
+    ; CHECK-NEXT: [[CSINCWr:%[0-9]+]]:gpr32 = CSINCWr [[COPY4]], $wzr, 0, implicit $nzcv
+    ; CHECK-NEXT: $w0 = COPY [[CSINCWr]]
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:gpr64 = COPY $x0
+    %1:gpr64 = COPY $x1
+    %2:ppr = PTRUE_D 31, implicit $vg
+    %3:ppr2mul2 = WHILEHS_2PXX_D %0, %1, implicit-def $nzcv
+    %4:ppr = COPY %3.psub0
+    %5:ppr = COPY %3.psub1
+    PTEST_PP killed %2, killed %4, implicit-def $nzcv
+    %6:gpr32 = COPY $wzr
+    %7:gpr32 = CSINCWr %6, $wzr, 0, implicit $nzcv
+    $w0 = COPY %7
+    RET_ReallyLR implicit $w0
+...
dif...
[truncated]

Comment on lines 1494 to 1496

Looking at optimizePTestInstr, there's code that scans for CC uses between PTest and Pred, so I don't think we can just look through a COPY like this. Perhaps this code belongs in optimizePTestInstr, so that Pred always points to the flag-setting instruction?


Is "Op.getSubReg() != 0" sufficient? Should it be exactly AArch64::psub0?

@paulwalker-arm

As discussed, the while pair instructions end like:

PSTATE.<N,Z,C,V> = PredTest(mask, result, esize);
P[d0, PL] = result<PL-1:0>;
P[d1, PL] = result<PL*2-1:PL>;

So the condition codes are set based on the result of testing the concatenation of the two register results. This means ptest(while_x2(a,b)) is not the same operation as ptest(svget(while_x2(a,b), 0)). There is overlap, however: for example, when branching on pfirst the two produce the same result, because only the first result of the while pair is relevant.
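The difference between testing the concatenated pair and testing only the first register can be illustrated with a toy model. The sketch below is a simplification under stated assumptions, not the architecture's pseudocode: it stores one bool per predicate element, omits the V flag, ignores integer overflow, and uses the hypothetical names `Flags`, `predTest`, `whileltX2`, and `ptestFirstOnly`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model of the flags written by an SVE predicate test (V omitted).
struct Flags {
  bool N; // first active element is true
  bool Z; // no active element is true
  bool C; // last active element is NOT true
};

// Hypothetical helper: one bool per predicate element, not per bit.
Flags predTest(const std::vector<bool> &mask, const std::vector<bool> &result) {
  assert(mask.size() == result.size());
  bool seenActive = false, first = false, any = false, last = false;
  for (std::size_t i = 0; i < mask.size(); ++i) {
    if (!mask[i])
      continue;
    if (!seenActive) {
      first = result[i];
      seenActive = true;
    }
    any = any || result[i];
    last = result[i];
  }
  return Flags{first, !any, !last};
}

// The two result registers plus the flags the pair instruction would set.
struct WhilePair {
  std::vector<bool> p0, p1;
  Flags flags;
};

// Hypothetical model of a "while less than" pair: element i is true when
// a + i < b, and the flags come from testing the concatenated result.
WhilePair whileltX2(long a, long b, std::size_t elemsPerReg) {
  std::vector<bool> result(2 * elemsPerReg);
  for (std::size_t i = 0; i < result.size(); ++i)
    result[i] = a + static_cast<long>(i) < b;
  const std::vector<bool> allActive(result.size(), true);
  const auto mid = result.begin() + static_cast<std::ptrdiff_t>(elemsPerReg);
  return WhilePair{std::vector<bool>(result.begin(), mid),
                   std::vector<bool>(mid, result.end()),
                   predTest(allActive, result)};
}

// Flags a separate ptest of only the first register would produce.
Flags ptestFirstOnly(const WhilePair &w) {
  const std::vector<bool> allActive(w.p0.size(), true);
  return predTest(allActive, w.p0);
}
```

Under this model, with a = 0, b = 6 and four elements per register, both tests agree on N (the first element is true) but disagree on C, because the last element of the pair is false while the last element of the first register is true.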

The problem is that when calling canRemovePTestInstr we know there's a ptest we want to remove, but we don't know which part of the condition code will be used, and thus we cannot identify which of the while pair results is relevant.

We've hit something similar to this before and solved it by creating the PTEST_PP_ANY pseudo node (there's a matching PTEST_ANY ISD node) that essentially encodes this condition code information. I believe your work requires us to create a PTEST_PP_FIRST equivalent which you can match against to prove it's safe to look through the COPY.

I recommend adding the PTEST_PP_FIRST pseudo and AArch64ISD::PTEST_FIRST support as a separate refactoring PR, which this PR can then build upon.

kmclaughlin-arm added a commit that referenced this pull request Sep 9, 2025
The pseudo is created when the condition of a ptest is FIRST_ACTIVE.

This allows optimizePTestInstr to be extended to handle whilecc intrinsics
that return a predicate pair, where it is necessary to identify the
condition code used to remove a ptest on the first result of the pair.
(See #156478)
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 9, 2025
…only

  consider extracts/copies from the first result of whilecc_x2
- Add negative tests

paulwalker-arm left a comment


One last question but otherwise this looks good to me.

// Instructions which return a multi-vector (e.g. WHILECC_x2) require copies
// before the branch to extract each subregister.
auto Op = Pred->getOperand(1);
if (Op.isReg() && Op.getSubReg() == AArch64::psub0)

Not sure if I'm being paranoid but you might need && Op.getReg().isVirtual() as a requirement for calling getUniqueVRegDef?

Contributor Author

I think it is required that the register is virtual. I was checking isVirtual previously when this was part of canRemovePTestInstr, but must have removed it when I moved this to optimizePTestInstr.
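The conditions raised across this review (a FIRST-style ptest, a copy from exactly psub0, and a virtual source register) can be collected into a small decision function. This is only a toy sketch and not the merged code: `PTestKind`, `SubReg`, `CopySource`, and `canLookThroughCopy` are hypothetical stand-ins rather than LLVM classes, and rejecting every kind other than `First` is a conservative assumption made for illustration.

```cpp
#include <cassert>

// Hypothetical stand-ins for the information the real pass would query.
enum class PTestKind { Generic, Any, First };
enum class SubReg { None, PSub0, PSub1 };

struct CopySource {
  bool isReg;          // the copy's source operand is a register
  bool isVirtual;      // ...and that register is virtual
  SubReg subReg;       // sub-register index read by the copy
  bool defIsWhilePair; // the register's unique def is a while pair instruction
};

// Returns true only when every condition discussed in the review holds.
bool canLookThroughCopy(PTestKind kind, const CopySource &src) {
  // Only the "first active" condition depends solely on the first result.
  if (kind != PTestKind::First)
    return false;
  // A unique virtual-register definition can only be looked up for a
  // virtual register.
  if (!src.isReg || !src.isVirtual)
    return false;
  // Only the first register of the pair is acceptable.
  if (src.subReg != SubReg::PSub0)
    return false;
  return src.defIsWhilePair;
}

// Small usage example covering the accepted case and three rejections.
void checkExamples() {
  assert(canLookThroughCopy(PTestKind::First,
                            CopySource{true, true, SubReg::PSub0, true}));
  assert(!canLookThroughCopy(PTestKind::Generic,
                             CopySource{true, true, SubReg::PSub0, true}));
  assert(!canLookThroughCopy(PTestKind::First,
                             CopySource{true, true, SubReg::PSub1, true}));
  assert(!canLookThroughCopy(PTestKind::First,
                             CopySource{true, false, SubReg::PSub0, true}));
}
```

Ordering the checks from cheapest to most specific keeps the definition lookup, modelled here by `defIsWhilePair`, behind the virtual-register guard.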

@kmclaughlin-arm kmclaughlin-arm merged commit b9fd1e6 into llvm:main Sep 11, 2025
9 checks passed