From a3167b5d438413715690df382fc0c3a0a3ecbef8 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin Date: Wed, 24 Sep 2025 09:02:09 +0000 Subject: [PATCH 1/4] [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) When input to a ptest_first is a vector concat and the mask is all active, performPTestFirstCombine returns a ptest_first using the first operand of the concat, looking through any reinterpret casts added by getPTest. This allows optimizePTestInstr to later remove the ptest when the first operand is a flag setting instruction such as whilelo. --- .../Target/AArch64/AArch64ISelLowering.cpp | 40 ++++++++++++++++++- .../AArch64/get-active-lane-mask-extract.ll | 20 +--------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 45f52352d45fd..1f641cd8c14fc 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) { } static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, - AArch64CC::CondCode Cond); + AArch64CC::CondCode Cond, bool EmitCSel = true); static bool isPredicateCCSettingOp(SDValue N) { if ((N.getOpcode() == ISD::SETCC) || @@ -20495,6 +20495,7 @@ static SDValue performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT); + if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget)) return Res; if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget)) @@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC, } static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, - AArch64CC::CondCode Cond) { + AArch64CC::CondCode Cond, bool EmitCSel) { const TargetLowering &TLI = 
DAG.getTargetLoweringInfo(); SDLoc DL(Op); @@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, // Set condition code (CC) flags. SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op); + if (!EmitCSel) + return Test; // Convert CC to integer based on requested condition. // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. @@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N, return SDValue(); } +static SDValue performPTestFirstCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + if (DCI.isBeforeLegalize()) + return SDValue(); + + SDLoc DL(N); + auto Mask = N->getOperand(0); + auto Pred = N->getOperand(1); + + if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST) + Mask = Mask->getOperand(0); + + if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) + Pred = Pred->getOperand(0); + + if (Pred->getValueType(0).getVectorElementType() != MVT::i1 || + !isAllActivePredicate(DAG, Mask)) + return SDValue(); + + if (Pred->getOpcode() == ISD::CONCAT_VECTORS) { + Pred = Pred->getOperand(0); + SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL, + DAG.getAllOnesConstant(DL, MVT::i64)); + return getPTest(DAG, N->getValueType(0), Mask, Pred, + AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false); + } + + return SDValue(); +} + static SDValue performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { @@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case AArch64ISD::UMULL: case AArch64ISD::PMULL: return performMULLCombine(N, DCI, DAG); + case AArch64ISD::PTEST_FIRST: + return performPTestFirstCombine(N, DCI, DAG); case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: switch (N->getConstantOperandVal(1)) { diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll index b89f55188b0f2..e2c861b40e706 100644 --- 
a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll +++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll @@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest: ; CHECK-SVE2p1-SME2: // %bb.0: // %entry ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1 -; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b -; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b -; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2 ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then ; CHECK-SVE2p1-SME2-NEXT: b use @@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts: ; CHECK-SVE2p1-SME2: // %bb.0: // %entry ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 -; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h -; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h -; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2 ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then ; CHECK-SVE2p1-SME2-NEXT: b use @@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { ; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 ; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 -; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1 -; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h -; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h -; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b -; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b -; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2 ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1 ; CHECK-SVE2p1-SME2-NEXT: b use ; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end ; CHECK-SVE2p1-SME2-NEXT: ret @@ -463,14 +452,9 @@ define void 
@test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n ; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 ; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1 -; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1 -; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s -; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s -; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h -; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h -; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2 ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1 ; CHECK-SVE2p1-SME2-NEXT: b use ; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end ; CHECK-SVE2p1-SME2-NEXT: ret From b99f48c44bbf8aeb126ebbafaadc0e09eef4769c Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin Date: Wed, 1 Oct 2025 10:22:23 +0000 Subject: [PATCH 2/4] - Add isLane1KnownActive helper - Extend canRemovePTestInstr for PTEST_PP_FIRST - Remove getVectorElementType() != MVT::i1 check for Pred --- .../Target/AArch64/AArch64ISelLowering.cpp | 35 +++++++++++-------- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 4 +++ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 1f641cd8c14fc..20b8e75040512 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) { } static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, - AArch64CC::CondCode Cond, bool EmitCSel = true); + AArch64CC::CondCode Cond); static bool isPredicateCCSettingOp(SDValue N) { if ((N.getOpcode() == ISD::SETCC) || @@ -20495,7 +20495,6 @@ static SDValue performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { assert(N->getOpcode() == 
ISD::EXTRACT_VECTOR_ELT); - if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget)) return Res; if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget)) @@ -22536,7 +22535,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC, } static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, - AArch64CC::CondCode Cond, bool EmitCSel) { + AArch64CC::CondCode Cond) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDLoc DL(Op); @@ -22569,8 +22568,6 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, // Set condition code (CC) flags. SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op); - if (!EmitCSel) - return Test; // Convert CC to integer based on requested condition. // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. @@ -27237,6 +27234,21 @@ static bool isLanes1toNKnownZero(SDValue Op) { } } +// Return true if the vector operation can guarantee that the first lane of its +// result is active. 
+static bool isLane1KnownActive(SDValue Op) { + switch (Op.getOpcode()) { + default: + return false; + case AArch64ISD::REINTERPRET_CAST: + return isLane1KnownActive(Op->getOperand(0)); + case ISD::SPLAT_VECTOR: + return isOneConstant(Op.getOperand(0)); + case AArch64ISD::PTRUE: + return Op.getConstantOperandVal(0) == AArch64SVEPredPattern::all; + }; +} + static SDValue removeRedundantInsertVectorElt(SDNode *N) { assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!"); SDValue InsertVec = N->getOperand(0); @@ -27532,22 +27544,17 @@ static SDValue performPTestFirstCombine(SDNode *N, auto Mask = N->getOperand(0); auto Pred = N->getOperand(1); - if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST) - Mask = Mask->getOperand(0); - if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) Pred = Pred->getOperand(0); - if (Pred->getValueType(0).getVectorElementType() != MVT::i1 || - !isAllActivePredicate(DAG, Mask)) + if (!isLane1KnownActive(Mask)) return SDValue(); if (Pred->getOpcode() == ISD::CONCAT_VECTORS) { Pred = Pred->getOperand(0); - SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL, - DAG.getAllOnesConstant(DL, MVT::i64)); - return getPTest(DAG, N->getValueType(0), Mask, Pred, - AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false); + Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred); + return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask, + Pred); } return SDValue(); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 5a51c812732e6..ef056002f085c 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -1503,6 +1503,10 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, getElementSizeForOpcode(PredOpcode)) return PredOpcode; + if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && + isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) + return 
PredOpcode; + return {}; } From c5402afee5b5578a44ec7714e234efae05410ed2 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin Date: Wed, 1 Oct 2025 14:18:04 +0000 Subject: [PATCH 3/4] - Rename isLane1KnownActive -> isLane0KnownActive --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 20b8e75040512..cc4baa474f99b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -27236,12 +27236,12 @@ static bool isLanes1toNKnownZero(SDValue Op) { // Return true if the vector operation can guarantee that the first lane of its // result is active. -static bool isLane1KnownActive(SDValue Op) { +static bool isLane0KnownActive(SDValue Op) { switch (Op.getOpcode()) { default: return false; case AArch64ISD::REINTERPRET_CAST: - return isLane1KnownActive(Op->getOperand(0)); + return isLane0KnownActive(Op->getOperand(0)); case ISD::SPLAT_VECTOR: return isOneConstant(Op.getOperand(0)); case AArch64ISD::PTRUE: @@ -27547,7 +27547,7 @@ static SDValue performPTestFirstCombine(SDNode *N, if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) Pred = Pred->getOperand(0); - if (!isLane1KnownActive(Mask)) + if (!isLane0KnownActive(Mask)) return SDValue(); if (Pred->getOpcode() == ISD::CONCAT_VECTORS) { From 095b1c37c9ab5663bb41ba3025a903515dd30f26 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin Date: Wed, 1 Oct 2025 15:49:32 +0000 Subject: [PATCH 4/4] - Move isLane0KnownActive before looking through reinterpret_cast - Add comment to changes in canRemovePTestInstr --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++--- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 
cc4baa474f99b..a1f4734f83562 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -27544,12 +27544,12 @@ static SDValue performPTestFirstCombine(SDNode *N, auto Mask = N->getOperand(0); auto Pred = N->getOperand(1); - if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) - Pred = Pred->getOperand(0); - if (!isLane0KnownActive(Mask)) return SDValue(); + if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST) + Pred = Pred->getOperand(0); + if (Pred->getOpcode() == ISD::CONCAT_VECTORS) { Pred = Pred->getOperand(0); Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index ef056002f085c..35b27ea2ec9dd 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -1503,6 +1503,9 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, getElementSizeForOpcode(PredOpcode)) return PredOpcode; + // For PTEST_FIRST(PTRUE_ALL, WHILE), the PTEST_FIRST is redundant since + // WHILEcc performs an implicit PTEST with an all active mask, setting + // the N flag as the PTEST_FIRST would. if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) return PredOpcode;