Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SVE] Wide active lane mask #76514

Closed

Conversation

momchil-velikov
Copy link
Collaborator

These patches make the LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectoriser emits one or more wide llvm.get.active.lane.mask.*
calls plus several llvm.vector.extract.* calls to yield the required
number of VF-wide masks.

The motivating example is a vectorised loop with unroll factor 2 that
can use the SVE2.1 whilelo instruction with predicate pair result, or
a SVE whilelo instruction with smaller element size plus
punpklo/punpkhi.

How wide is the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook getMaxPredicateLength.The default
impementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 28, 2023

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-analysis

Author: Momchil Velikov (momchil-velikov)

Changes

These patches make the LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectoriser emits one or more wide llvm.get.active.lane.mask.*
calls plus several llvm.vector.extract.* calls to yield the required
number of VF-wide masks.

The motivating example is a vectorised loop with unroll factor 2 that
can use the SVE2.1 whilelo instruction with predicate pair result, or
a SVE whilelo instruction with smaller element size plus
punpklo/punpkhi.

How wide is the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook getMaxPredicateLength.The default
impementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.


Patch is 426.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76514.diff

22 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+105-33)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+8)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+7-5)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+34-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-14)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-32x1.ll (+31)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1003)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+886-874)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+196-188)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+655)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+10-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 048912beaba5a1..8416f6138c6c46 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1216,6 +1216,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1952,6 +1954,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2557,6 +2562,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7ad3ce512a3552..e341220fa48b09 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -513,6 +513,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 5e7bdcdf72a49f..fcad6a86538256 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -881,6 +881,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 67246afa23147a..aaf39be8ff0634 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -789,6 +789,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..bf5d48c903307f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1772,15 +1772,17 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
   // whilelo instruction for generating fixed-width predicates too.
   if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
       ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
-      ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
+      ResVT != MVT::v8i1 && ResVT != MVT::v16i1 &&
+      (!(Subtarget->hasSVE2p1() || Subtarget->hasSME2()) ||
+       ResVT != MVT::nxv32i1)) // TODO: handle MVT::v32i1
     return true;
 
   // The whilelo instruction only works with i32 or i64 scalar inputs.
@@ -17760,22 +17762,49 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
-static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
-    return true;
+static SDValue getPredicateCCSettingOp(SDValue N) {
+  if (N.getOpcode() == ISD::SETCC) {
+    EVT VT = N.getValueType();
+    return VT.isScalableVector() && VT.getVectorElementType() == MVT::i1
+               ? N
+               : SDValue();
+  }
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return SDValue();
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return SDValue();
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return SDValue();
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    assert(N.getValueType().isScalableVector() &&
+           N.getValueType().getVectorElementType() == MVT::i1 &&
+           "Intrinsic expected to yield scalable i1 vector");
+    return N;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -17789,21 +17818,17 @@ performFirstTrueTestVectorCombine(SDNode *N,
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N0.getValueType();
-
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
-      !isNullConstant(N->getOperand(1)))
-    return SDValue();
-
-  // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation.
-  if (!isPredicateCCSettingOp(N0))
+  // Restrict the DAG combine to only cases where we're extracting the zero-th
+  // element from the result of a flag-setting operation.
+  SDValue N0;
+  if (!isNullConstant(N->getOperand(1)) ||
+      !(N0 = getPredicateCCSettingOp(N->getOperand(0))))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  SDValue Pg =
+      getPTrue(DAG, SDLoc(N), N0.getValueType(), AArch64SVEPredPattern::all);
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
@@ -19768,7 +19793,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   default:
     break;
   case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
     EVT VT = N->getValueType(0);
     if (VT.isFixedLengthVector()) {
       // We can use the SVE whilelo instruction to lower this intrinsic by
@@ -19791,15 +19815,63 @@ static SDValue performIntrinsicCombine(SDNode *N,
           EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
                            VT.getVectorElementCount());
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
+      SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                                N->getOperand(1), N->getOperand(2));
       Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
                         DAG.getConstant(0, DL, MVT::i64));
       Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
+
+      return Res;
     }
-    return Res;
+
+    if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+      return SDValue();
+
+    if (!N->hasNUsesOfValue(2, 0))
+      return SDValue();
+
+    auto It = N->use_begin();
+    SDNode *Lo = *It++;
+    SDNode *Hi = *It;
+
+    const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+    uint64_t OffLo, OffHi;
+    if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Lo->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffLo = Lo->getConstantOperandVal(1)) != 0 && OffLo != HalfSize) ||
+        Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Hi->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffHi = Hi->getConstantOperandVal(1)) != 0 && OffHi != HalfSize))
+      return SDValue();
+
+    if (OffLo > OffHi) {
+      std::swap(Lo, Hi);
+      std::swap(OffLo, OffHi);
+    }
+
+    if (OffLo != 0 || OffHi != HalfSize)
+      return SDValue();
+
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+    SDValue Idx = N->getOperand(1);
+    SDValue TC = N->getOperand(2);
+    if (Idx.getValueType() != MVT::i64) {
+      Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+      TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+    }
+    auto R =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                    {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+    DCI.CombineTo(Lo, R.getValue(0));
+    DCI.CombineTo(Hi, R.getValue(1));
+
+    return SDValue(N, 0);
   }
+
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 1cfbf4737a6f72..91a9af727ffb0b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1356,11 +1356,22 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
+  unsigned NewOp;
   bool OpChanged = false;
 
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0)) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1476,9 +1487,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
   if (OpChanged) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b5b8b68291786d..58d7be9979fba8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3223,6 +3223,14 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks that are more than twice the VF.
+  unsigned N = ST->hasSVE2p1() ? 32 : ST->hasSVE() ? 16 : 0;
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0b220069a388b6..077320a7a23715 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index b7552541e950d9..963b62988f7001 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9754,7 +9754,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9772,16 +9772,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 577ce8000de27b..ff19245b429eff 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -178,6 +178,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f82e161fb846d1..48c87b9bb8c554 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -608,6 +608,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7604,7 +7608,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 94cb7688981361..dd761aacd11fda 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -232,15 +232,16 @@ struct VPIteration {
 /// VPTransformState holds information passed down when "executing" a VPlan,
 /// needed for generating the output IR.
 struct VPTransformState {
-  VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
-                   DominatorTree *DT, IRBuilderBase &Builder,
+  VPTransformState(ElementCount VF, unsigned UF, ElementCount MaxPred,
+                   LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
                    InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx)
-      : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan),
-        LVer(nullptr), TypeAnalysis(Ctx) {}
+      : VF(VF), UF(UF), MaxPred(MaxPred), LI(LI), DT(DT), Builder(Builder),
+        ILV(ILV), Plan(Plan), LVer(nullptr), TypeAnalysis(Ctx) {}
 
   /// The chosen Vectorization and Unroll Factors of the loop being vectorized.
   ElementCount VF;
   unsigned UF;
+  ElementCount MaxPred;
 
   /// Hold the indices to generate specific scalar instructions. Null indicates
   /// that all instances are to be generated, using either scalar or vector
@@ -1164,7 +1165,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
     switch (getOpcode()) {
     default:
       return false;
-    case VPInstruction::ActiveLaneMask:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrementForPart:
     case VPInstruction::BranchOnCount:
@@ -1189,6 +1189,35 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
   }
 };
 
+class VPActiveLaneMaskRecipe : public VPRecipeWithIRFlags, public VPValue {
+  const std::string Name;
+  ElementCount MaxLength;
+
+public:
+  VPActiveLaneMaskRecipe(VPValu...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 28, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Momchil Velikov (momchil-velikov)

Changes

These patches make the LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectoriser emits one or more wide llvm.get.active.lane.mask.*
calls plus several llvm.vector.extract.* calls to yield the required
number of VF-wide masks.

The motivating example is a vectorised loop with unroll factor 2 that
can use the SVE2.1 whilelo instruction with predicate pair result, or
a SVE whilelo instruction with smaller element size plus
punpklo/punpkhi.

How wide is the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook getMaxPredicateLength.The default
impementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.


Patch is 426.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76514.diff

22 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+105-33)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+8)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+7-5)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+34-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-14)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-32x1.ll (+31)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1003)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+886-874)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+196-188)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+655)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+10-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 048912beaba5a1..8416f6138c6c46 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1216,6 +1216,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1952,6 +1954,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2557,6 +2562,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7ad3ce512a3552..e341220fa48b09 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -513,6 +513,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 5e7bdcdf72a49f..fcad6a86538256 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -881,6 +881,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 67246afa23147a..aaf39be8ff0634 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -789,6 +789,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..bf5d48c903307f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1772,15 +1772,17 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
   // whilelo instruction for generating fixed-width predicates too.
   if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
       ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
-      ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
+      ResVT != MVT::v8i1 && ResVT != MVT::v16i1 &&
+      (!(Subtarget->hasSVE2p1() || Subtarget->hasSME2()) ||
+       ResVT != MVT::nxv32i1)) // TODO: handle MVT::v32i1
     return true;
 
   // The whilelo instruction only works with i32 or i64 scalar inputs.
@@ -17760,22 +17762,49 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
-static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
-    return true;
+static SDValue getPredicateCCSettingOp(SDValue N) {
+  if (N.getOpcode() == ISD::SETCC) {
+    EVT VT = N.getValueType();
+    return VT.isScalableVector() && VT.getVectorElementType() == MVT::i1
+               ? N
+               : SDValue();
+  }
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return SDValue();
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return SDValue();
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return SDValue();
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    assert(N.getValueType().isScalableVector() &&
+           N.getValueType().getVectorElementType() == MVT::i1 &&
+           "Intrinsic expected to yield scalable i1 vector");
+    return N;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -17789,21 +17818,17 @@ performFirstTrueTestVectorCombine(SDNode *N,
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N0.getValueType();
-
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
-      !isNullConstant(N->getOperand(1)))
-    return SDValue();
-
-  // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation.
-  if (!isPredicateCCSettingOp(N0))
+  // Restrict the DAG combine to only cases where we're extracting the zero-th
+  // element from the result of a flag-setting operation.
+  SDValue N0;
+  if (!isNullConstant(N->getOperand(1)) ||
+      !(N0 = getPredicateCCSettingOp(N->getOperand(0))))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  SDValue Pg =
+      getPTrue(DAG, SDLoc(N), N0.getValueType(), AArch64SVEPredPattern::all);
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
@@ -19768,7 +19793,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   default:
     break;
   case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
     EVT VT = N->getValueType(0);
     if (VT.isFixedLengthVector()) {
       // We can use the SVE whilelo instruction to lower this intrinsic by
@@ -19791,15 +19815,63 @@ static SDValue performIntrinsicCombine(SDNode *N,
           EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
                            VT.getVectorElementCount());
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
+      SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                                N->getOperand(1), N->getOperand(2));
       Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
                         DAG.getConstant(0, DL, MVT::i64));
       Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
+
+      return Res;
     }
-    return Res;
+
+    if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+      return SDValue();
+
+    if (!N->hasNUsesOfValue(2, 0))
+      return SDValue();
+
+    auto It = N->use_begin();
+    SDNode *Lo = *It++;
+    SDNode *Hi = *It;
+
+    const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+    uint64_t OffLo, OffHi;
+    if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Lo->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffLo = Lo->getConstantOperandVal(1)) != 0 && OffLo != HalfSize) ||
+        Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Hi->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffHi = Hi->getConstantOperandVal(1)) != 0 && OffHi != HalfSize))
+      return SDValue();
+
+    if (OffLo > OffHi) {
+      std::swap(Lo, Hi);
+      std::swap(OffLo, OffHi);
+    }
+
+    if (OffLo != 0 || OffHi != HalfSize)
+      return SDValue();
+
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+    SDValue Idx = N->getOperand(1);
+    SDValue TC = N->getOperand(2);
+    if (Idx.getValueType() != MVT::i64) {
+      Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+      TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+    }
+    auto R =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                    {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+    DCI.CombineTo(Lo, R.getValue(0));
+    DCI.CombineTo(Hi, R.getValue(1));
+
+    return SDValue(N, 0);
   }
+
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 1cfbf4737a6f72..91a9af727ffb0b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1356,11 +1356,22 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
+  unsigned NewOp;
   bool OpChanged = false;
 
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0)) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1476,9 +1487,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
   if (OpChanged) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b5b8b68291786d..58d7be9979fba8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3223,6 +3223,14 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks that are more than twice the VF.
+  unsigned N = ST->hasSVE2p1() ? 32 : ST->hasSVE() ? 16 : 0;
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0b220069a388b6..077320a7a23715 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index b7552541e950d9..963b62988f7001 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9754,7 +9754,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9772,16 +9772,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 577ce8000de27b..ff19245b429eff 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -178,6 +178,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f82e161fb846d1..48c87b9bb8c554 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -608,6 +608,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7604,7 +7608,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 94cb7688981361..dd761aacd11fda 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -232,15 +232,16 @@ struct VPIteration {
 /// VPTransformState holds information passed down when "executing" a VPlan,
 /// needed for generating the output IR.
 struct VPTransformState {
-  VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
-                   DominatorTree *DT, IRBuilderBase &Builder,
+  VPTransformState(ElementCount VF, unsigned UF, ElementCount MaxPred,
+                   LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
                    InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx)
-      : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan),
-        LVer(nullptr), TypeAnalysis(Ctx) {}
+      : VF(VF), UF(UF), MaxPred(MaxPred), LI(LI), DT(DT), Builder(Builder),
+        ILV(ILV), Plan(Plan), LVer(nullptr), TypeAnalysis(Ctx) {}
 
   /// The chosen Vectorization and Unroll Factors of the loop being vectorized.
   ElementCount VF;
   unsigned UF;
+  ElementCount MaxPred;
 
   /// Hold the indices to generate specific scalar instructions. Null indicates
   /// that all instances are to be generated, using either scalar or vector
@@ -1164,7 +1165,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
     switch (getOpcode()) {
     default:
       return false;
-    case VPInstruction::ActiveLaneMask:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrementForPart:
     case VPInstruction::BranchOnCount:
@@ -1189,6 +1189,35 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
   }
 };
 
+class VPActiveLaneMaskRecipe : public VPRecipeWithIRFlags, public VPValue {
+  const std::string Name;
+  ElementCount MaxLength;
+
+public:
+  VPActiveLaneMaskRecipe(VPValu...
[truncated]

@fhahn
Copy link
Contributor

fhahn commented Jan 3, 2024

This should probably be split up in 3 separate PRs to keep the reviews and comments focused and separate

@david-arm
Copy link
Contributor

This should probably be split up in 3 separate PRs to keep the reviews and comments focused and separate

I was just about to say the same thing, but then you beat me to it. :)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
Hi->getOperand(1)->getOpcode() != ISD::Constant ||
((OffHi = Hi->getConstantOperandVal(1)) != 0 && OffHi != HalfSize))
return SDValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all places where we return SDValue() here do we have negative tests to show what happens? I just want to make sure we don't crash with non-matching IR. Even if the code is terrible it needs to work correctly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we will not be able to process arbitrary LLVM IR that uses 32-bit lane masks. That's pretty much the current state - we cannot process correctly in the AArch64 backend arbitrary valid LLVM IR.

The current solution is to not generate LLVM IR that the backend cannot process.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we can perhaps move the logic from here

if (!TLI.shouldExpandGetActiveLaneMask(CCVT, ElementVT)) {

to CodeGenPrepare and do the expansion in LLVM IR instead of DAG (and it'll work for GlobalISel too, maybe now it doesn't).

I'd rather not emit LLVM IR that results in terrible code in the first place, though.

Support lowering of `llvm.get.active.lane.mask.*` yielding a vector with
32 elements when the result is immediately split in half.
This patch makes the LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectorizer emit one or more wide `llvm.get.active.lane.mask.*`
calls plus several `llvm.vector.extract.*` calls to yield the required
number of VF-wide masks.

The motivating exammple is a vectorised loop with unroll factor 2 that
can use the SVE2.1 `whilelo` instruction with predicate pair result, or
a SVE `whilelo` instruction with smaller element size plus
`punpklo`/`punpkhi`.

How wide is the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook `getMaxPredicateLength`.The default
impementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.
@momchil-velikov
Copy link
Collaborator Author

Thanks for the reviews!

I've uploaded an update with some changes, next I'm going to create a new set of PRs.

@CarolineConcatto
Copy link
Contributor

Just in case: s/insntruction/instruction in the title bellow
[AArch64] Optimise test of the LSB of a paired whileCC insntruction
About this:
"..., next I'm going to create a new set of PRs."
Should we wait for you to split to review the patches again?

@momchil-velikov momchil-velikov deleted the wide-active-lane-mask branch February 8, 2024 14:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants