Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] Optimise test of the LSB of a paired whileCC instruction #81141

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

Try to directly use the flags set by a whileCC instruction.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-analysis

Author: Momchil Velikov (momchil-velikov)

Changes

Try to directly use the flags set by a whileCC instruction.


Patch is 503.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81141.diff

25 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+135-61)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+7-5)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+42-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+90)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1053)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+171-186)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+890-878)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+197-189)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+664)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+77-89)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+10-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c..67e1b45cce29c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1228,6 +1228,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1981,6 +1983,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2601,6 +2606,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b80..b6d01e0764ab1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -528,6 +528,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bb17298daba03..2b0d0f3ed6f70 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -881,6 +881,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620..daea8e48981ec 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -808,6 +808,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389..7d40721b24fcc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -18032,22 +18032,49 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
-static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
-    return true;
+static SDValue getPredicateCCSettingOp(SDValue N) {
+  if (N.getOpcode() == ISD::SETCC) {
+    EVT VT = N.getValueType();
+    return VT.isScalableVector() && VT.getVectorElementType() == MVT::i1
+               ? N
+               : SDValue();
+  }
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return SDValue();
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return SDValue();
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return SDValue();
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    assert(N.getValueType().isScalableVector() &&
+           N.getValueType().getVectorElementType() == MVT::i1 &&
+           "Intrinsic expected to yield scalable i1 vector");
+    return N;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -18061,21 +18088,17 @@ performFirstTrueTestVectorCombine(SDNode *N,
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N0.getValueType();
-
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
-      !isNullConstant(N->getOperand(1)))
-    return SDValue();
-
-  // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation.
-  if (!isPredicateCCSettingOp(N0))
+  // Restrict the DAG combine to only cases where we're extracting the zero-th
+  // element from the result of a flag-setting operation.
+  SDValue N0;
+  if (!isNullConstant(N->getOperand(1)) ||
+      !(N0 = getPredicateCCSettingOp(N->getOperand(0))))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  SDValue Pg =
+      getPTrue(DAG, SDLoc(N), N0.getValueType(), AArch64SVEPredPattern::all);
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
@@ -20004,47 +20027,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  unsigned IID = getIntrinsicID(N);
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::get_active_lane_mask:
+    return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 9add7d87017a7..e2068b2d88ec9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1358,11 +1358,22 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
+  unsigned NewOp;
   bool OpChanged = false;
 
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0)) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1478,9 +1489,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
   if (OpChanged) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cdd2750521d2c..73aca77305df1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3285,6 +3285,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e..6501cc4a85e8d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 789ec817d3d8b..718b245c6d829 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9754,7 +9754,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9772,16 +9772,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index a7ebf78e54ceb..0e681c8080bfd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -184,6 +184,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1a7b301c35f2b..bac66e633a6f3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -599,6 +599,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7550,7 +7554,8 @@ LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Momchil Velikov (momchil-velikov)

Changes

Try to directly use the flags set by a whileCC instruction.


Patch is 503.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81141.diff

25 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+135-61)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+7-5)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+42-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+90)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1053)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+171-186)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+890-878)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+197-189)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+664)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+77-89)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+10-10)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..67e1b45cce29c5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1228,6 +1228,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1981,6 +1983,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2601,6 +2606,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..b6d01e0764ab14 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -528,6 +528,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bb17298daba03a..2b0d0f3ed6f706 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -881,6 +881,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..daea8e48981ecb 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -808,6 +808,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389f..7d40721b24fcc1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -18032,22 +18032,49 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
-static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
-    return true;
+static SDValue getPredicateCCSettingOp(SDValue N) {
+  if (N.getOpcode() == ISD::SETCC) {
+    EVT VT = N.getValueType();
+    return VT.isScalableVector() && VT.getVectorElementType() == MVT::i1
+               ? N
+               : SDValue();
+  }
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return SDValue();
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return SDValue();
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return SDValue();
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    assert(N.getValueType().isScalableVector() &&
+           N.getValueType().getVectorElementType() == MVT::i1 &&
+           "Intrinsic expected to yield scalable i1 vector");
+    return N;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -18061,21 +18088,17 @@ performFirstTrueTestVectorCombine(SDNode *N,
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N0.getValueType();
-
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
-      !isNullConstant(N->getOperand(1)))
-    return SDValue();
-
-  // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation.
-  if (!isPredicateCCSettingOp(N0))
+  // Restrict the DAG combine to only cases where we're extracting the zero-th
+  // element from the result of a flag-setting operation.
+  SDValue N0;
+  if (!isNullConstant(N->getOperand(1)) ||
+      !(N0 = getPredicateCCSettingOp(N->getOperand(0))))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  SDValue Pg =
+      getPTrue(DAG, SDLoc(N), N0.getValueType(), AArch64SVEPredPattern::all);
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
@@ -20004,47 +20027,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  unsigned IID = getIntrinsicID(N);
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::get_active_lane_mask:
+    return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 9add7d87017a73..e2068b2d88ec9f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1358,11 +1358,22 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
+  unsigned NewOp;
   bool OpChanged = false;
 
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0)) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1478,9 +1489,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
   if (OpChanged) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cdd2750521d2c9..73aca77305df1f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3285,6 +3285,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..6501cc4a85e8d3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 789ec817d3d8b8..718b245c6d8290 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9754,7 +9754,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9772,16 +9772,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index a7ebf78e54ceb6..0e681c8080bfd1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -184,6 +184,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1a7b301c35f2b8..bac66e633a6f3f 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -599,6 +599,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7550,7 +7554,8 @@ LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent c...
[truncated]

Copy link

github-actions bot commented Apr 5, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff ae8627809076390dbab04e01f3bf9d384c9e124e ccaf625347f47674d839a7d5b2bf1517d193c4a4 -- llvm/include/llvm/Analysis/TargetTransformInfo.h llvm/include/llvm/Analysis/TargetTransformInfoImpl.h llvm/include/llvm/CodeGen/BasicTTIImpl.h llvm/lib/Analysis/TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.h llvm/lib/Target/AArch64/AArch64InstrInfo.cpp llvm/lib/Target/AArch64/AArch64InstrInfo.h llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h llvm/lib/Transforms/Vectorize/LoopVectorize.cpp llvm/lib/Transforms/Vectorize/VPlan.cpp llvm/lib/Transforms/Vectorize/VPlan.h llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp llvm/lib/Transforms/Vectorize/VPlanValue.h
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 5cd18ea7ba..0ccd94f0cb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -221,7 +221,6 @@ m_BranchOnCond(const Op0_t &Op0) {
   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
 }
 
-
 template <typename Op0_t, typename Op1_t>
 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
@@ -292,8 +291,7 @@ m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
   return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
 }
 
-template <typename Op0_t, typename Op1_t>
-struct VPActiveLaneMask_match {
+template <typename Op0_t, typename Op1_t> struct VPActiveLaneMask_match {
   Op0_t Op0;
   Op1_t Op1;
 
@@ -315,8 +313,8 @@ struct VPActiveLaneMask_match {
 };
 
 template <typename Op0_t, typename Op1_t>
-inline VPActiveLaneMask_match<Op0_t, Op1_t>
-m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
+inline VPActiveLaneMask_match<Op0_t, Op1_t> m_ActiveLaneMask(const Op0_t &Op0,
+                                                             const Op1_t &Op1) {
   return {Op0, Op1};
 }
 

This patch makes the LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectorizer emit one or more wide `llvm.get.active.lane.mask.*`
calls plus several `llvm.vector.extract.*` calls to yield the required
number of VF-wide masks.

The motivating exammple is a vectorised loop with unroll factor 2 that
can use the SVE2.1 `whilelo` instruction with predicate pair result, or
a SVE `whilelo` instruction with smaller element size plus
`punpklo`/`punpkhi`.

How wide is the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook `getMaxPredicateLength`.The default
impementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants