Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopVectorize] Vectorize the compact pattern #68980

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

huhu233
Copy link
Contributor

@huhu233 huhu233 commented Oct 13, 2023

This patch tries to vectorize the compact pattern, as shown,

for(i=0; i<N; i++){
x = comp[i];
if(x<a) Out_ref[n++]=B[i];
}

It introduces some changes:

  1. Add a pattern matching in LoopVectorizationLegality to cache specific cases.
  2. Introduce two new recipes to hande the compact chain:
    VPCompactPHIRecipe: Handle the entry PHI of compact chain.
    VPWidenCompactInstructionRecipe: Handle other instructions in compact chain.
  3. Slightly adapt the cost model for compact pattern.

This patch tries to vectorize the compact pattern, as shown,

  for (i = 0; i < N; i++) {
    x = comp[i];
    if(x<a) Out_ref[n++]=B[i];
  }

It introduces some changes:
1.Add a pattern matching in LoopVectorizationLegality to cache
specific cases.
2.Introduce two new recipes to hande the compact chain:
VPCompactPHIRecipe: Handle the entry PHI of compact chain.
VPWidenCompactInstructionRecipe: Handle other instructions in compact chain.
3.Slightly adapt the cost model for compact pattern.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 13, 2023

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: None (huhu233)

Changes

This patch tries to vectorize the compact pattern, as shown,

for(i=0; i<N; i++){
x = comp[i];
if(x<a) Out_ref[n++]=B[i];
}

It introduces some changes:

  1. Add a pattern matching in LoopVectorizationLegality to cache specific cases.
  2. Introduce two new recipes to hande the compact chain:
    VPCompactPHIRecipe: Handle the entry PHI of compact chain.
    VPWidenCompactInstructionRecipe: Handle other instructions in compact chain.
  3. Slightly adapt the cost model for compact pattern.

Patch is 56.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68980.diff

20 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+25)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+7)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+4)
  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+8)
  • (modified) llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h (+35)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+12)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+10)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+29)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+157)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+111-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+17)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+62)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+135)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+2)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/compact-vplan.ll (+78)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/compact.ll (+153)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..c2851c10e6ff3ef 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1672,6 +1672,11 @@ class TargetTransformInfo {
   /// \return The maximum number of function arguments the target supports.
   unsigned getMaxNumArgs() const;
 
+  InstructionCost getCompactCost() const;
+  bool isTargetSupportedCompactStore() const;
+  unsigned getTargetSupportedCompact() const;
+  unsigned getTargetSupportedCNTP() const;
+
   /// @}
 
 private:
@@ -2041,6 +2046,10 @@ class TargetTransformInfo::Concept {
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
   virtual unsigned getMaxNumArgs() const = 0;
+  virtual bool isTargetSupportedCompactStore() const = 0;
+  virtual unsigned getTargetSupportedCompact() const = 0;
+  virtual unsigned getTargetSupportedCNTP() const = 0;
+  virtual InstructionCost getCompactCost() const = 0;
 };
 
 template <typename T>
@@ -2757,6 +2766,22 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxNumArgs() const override {
     return Impl.getMaxNumArgs();
   }
+
+  bool isTargetSupportedCompactStore() const override {
+    return Impl.isTargetSupportedCompactStore();
+  }
+
+  unsigned getTargetSupportedCompact() const override {
+    return Impl.getTargetSupportedCompact();
+  }
+
+  unsigned getTargetSupportedCNTP() const override {
+    return Impl.getTargetSupportedCNTP();
+  }
+
+  InstructionCost getCompactCost() const override {
+    return Impl.getCompactCost();
+  }
 };
 
 template <typename T>
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..e063f383980a724 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -895,6 +895,13 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxNumArgs() const { return UINT_MAX; }
 
+  bool isTargetSupportedCompactStore() const { return false; }
+  unsigned getTargetSupportedCompact() const { return 0; }
+  unsigned getTargetSupportedCNTP() const { return 0; }
+  InstructionCost getCompactCost() const {
+    return InstructionCost::getInvalid();
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 3dd16dafe3c42a7..737757761eca4ab 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -700,6 +700,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return getST()->getMaxPrefetchIterationsAhead();
   }
 
+  virtual InstructionCost getCompactCost() const {
+    return InstructionCost::getInvalid();
+  }
+
   virtual bool enableWritePrefetching() const {
     return getST()->enableWritePrefetching();
   }
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 0d99249be413762..348b8ad03de4179 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -409,6 +409,14 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
 Value *createTargetReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
                              Value *Src, PHINode *OrigPhi = nullptr);
 
+Value *createTargetCompact(IRBuilderBase &B, Module *M,
+                           const TargetTransformInfo *TTI, Value *Mask,
+                           Value *Val);
+
+Value *createTargetCNTP(IRBuilderBase &B, Module *M,
+                        const TargetTransformInfo *TTI, Value *Mask,
+                        Value *Val);
+
 /// Create an ordered reduction intrinsic using the given recurrence
 /// descriptor \p Desc.
 Value *createOrderedReduction(IRBuilderBase &B,
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 20cfc680e8f90b3..7f82154699e5174 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -224,6 +224,26 @@ class LoopVectorizationRequirements {
   Instruction *ExactFPMathInst = nullptr;
 };
 
+class CompactDescriptor {
+  PHINode *LiveOutPhi;
+  bool IsCompactSign;
+  SmallPtrSet<Value *, 8> Chain;
+
+public:
+  CompactDescriptor() = default;
+  CompactDescriptor(SmallPtrSetImpl<Value *> &CompactChain, PHINode *LiveOut,
+                    bool IsSign)
+      : LiveOutPhi(LiveOut), IsCompactSign(IsSign) {
+    Chain.insert(CompactChain.begin(), CompactChain.end());
+  }
+
+  bool isInCompactChain(Value *V) const { return Chain.find(V) != Chain.end(); }
+
+  PHINode *getLiveOutPhi() const { return LiveOutPhi; }
+
+  bool isSign() const { return IsCompactSign; }
+};
+
 /// LoopVectorizationLegality checks if it is legal to vectorize a loop, and
 /// to what vectorization factor.
 /// This class does not look at the profitability of vectorization, only the
@@ -261,6 +281,8 @@ class LoopVectorizationLegality {
   /// inductions and reductions.
   using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
 
+  using CompactList = MapVector<PHINode *, CompactDescriptor>;
+
   /// Returns true if it is legal to vectorize this loop.
   /// This does not mean that it is profitable to vectorize this
   /// loop, only that it is legal to do so.
@@ -397,6 +419,14 @@ class LoopVectorizationLegality {
 
   DominatorTree *getDominatorTree() const { return DT; }
 
+  const CompactList &getCompactList() const { return CpList; }
+
+  bool hasCompactChain() const { return CpList.size() > 0; }
+
+  PHINode *getCompactChainStart(Instruction *I) const;
+
+  bool isSign(PHINode *Phi) { return CpList[Phi].isSign(); };
+
 private:
   /// Return true if the pre-header, exiting and latch blocks of \p Lp and all
   /// its nested loops are considered legal for vectorization. These legal
@@ -425,6 +455,8 @@ class LoopVectorizationLegality {
   /// and we only need to check individual instructions.
   bool canVectorizeInstrs();
 
+  bool isMatchCompact(PHINode *Phi, Loop *TheLoop, CompactDescriptor &CpDesc);
+
   /// When we vectorize loops we may change the order in which
   /// we read and write from memory. This method checks if it is
   /// legal to vectorize the code, considering only memory constrains.
@@ -538,6 +570,9 @@ class LoopVectorizationLegality {
   /// BFI and PSI are used to check for profile guided size optimizations.
   BlockFrequencyInfo *BFI;
   ProfileSummaryInfo *PSI;
+
+  // Record compact chain in the loop.
+  CompactList CpList;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..b7596bb2e0dfc92 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1248,6 +1248,22 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::isTargetSupportedCompactStore() const {
+  return TTIImpl->isTargetSupportedCompactStore();
+}
+
+unsigned TargetTransformInfo::getTargetSupportedCompact() const {
+  return TTIImpl->getTargetSupportedCompact();
+}
+
+unsigned TargetTransformInfo::getTargetSupportedCNTP() const {
+  return TTIImpl->getTargetSupportedCNTP();
+}
+
+InstructionCost TargetTransformInfo::getCompactCost() const {
+  return TTIImpl->getCompactCost();
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fc9e3ff3734989d..d30d0b57b5d47b0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/StackMaps.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/raw_ostream.h"
@@ -301,6 +302,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FFREXP:
     Res = PromoteIntRes_FFREXP(N);
     break;
+  case ISD::INTRINSIC_WO_CHAIN:
+    if (N->getConstantOperandVal(0) == Intrinsic::aarch64_sve_compact) {
+      Res = PromoteIntRes_COMPACT(N);
+      break;
+    }
   }
 
   // If the result is null then the sub-method took care of registering it.
@@ -5942,6 +5948,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_CONCAT_VECTORS(SDNode *N) {
   return DAG.getBuildVector(N->getValueType(0), dl, NewOps);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_COMPACT(SDNode *N) {
+  SDValue OpExt = SExtOrZExtPromotedInteger(N->getOperand(2));
+  return DAG.getNode(N->getOpcode(), SDLoc(N), OpExt.getValueType(),
+                     N->getOperand(0), N->getOperand(1), OpExt);
+}
+
 SDValue DAGTypeLegalizer::ExpandIntOp_STACKMAP(SDNode *N, unsigned OpNo) {
   assert(OpNo > 1);
   SDValue Op = N->getOperand(OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index c802604a3470e13..d204169ed2327f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -364,6 +364,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FunnelShift(SDNode *N);
   SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
   SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
+  SDValue PromoteIntRes_COMPACT(SDNode *N);
 
   // Integer Operand Promotion.
   bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d8a0e68d7123759..5ca5f22525d3dd3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3889,3 +3889,5 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
     return AM.Scale != 0 && AM.Scale != 1;
   return -1;
 }
+
+InstructionCost AArch64TTIImpl::getCompactCost() const { return 6; }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..28bd48e8ed76c4a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include <cstdint>
 #include <optional>
 
@@ -412,6 +413,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
     return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
   }
+
+  bool isTargetSupportedCompactStore() const { return ST->hasSVE(); }
+  unsigned getTargetSupportedCompact() const {
+    return Intrinsic::aarch64_sve_compact;
+  }
+  unsigned getTargetSupportedCNTP() const {
+    return Intrinsic::aarch64_sve_cntp;
+  }
+  InstructionCost getCompactCost() const override;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 21affe7bdce406e..1373fb7931f0a7e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
@@ -1119,6 +1120,34 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
   return createSimpleTargetReduction(B, Src, RK);
 }
 
+Value *llvm::createTargetCompact(IRBuilderBase &B, Module *M,
+                                 const TargetTransformInfo *TTI, Value *Mask,
+                                 Value *Val) {
+  Intrinsic::ID IID = TTI->getTargetSupportedCompact();
+  switch (IID) {
+  default:
+    return nullptr;
+  case Intrinsic::aarch64_sve_compact:
+    Function *CompactMaskDecl = Intrinsic::getDeclaration(
+        M, Intrinsic::aarch64_sve_compact, Val->getType());
+    return B.CreateCall(CompactMaskDecl, {Mask, Val});
+  }
+}
+
+Value *llvm::createTargetCNTP(IRBuilderBase &B, Module *M,
+                              const TargetTransformInfo *TTI, Value *Mask,
+                              Value *Val) {
+  Intrinsic::ID IID = TTI->getTargetSupportedCNTP();
+  switch (IID) {
+  default:
+    return nullptr;
+  case Intrinsic::aarch64_sve_cntp:
+    Function *CNTPDecl = Intrinsic::getDeclaration(
+        M, Intrinsic::aarch64_sve_cntp, Val->getType());
+    return B.CreateCall(CNTPDecl, {Mask, Val});
+  }
+}
+
 Value *llvm::createOrderedReduction(IRBuilderBase &B,
                                     const RecurrenceDescriptor &Desc,
                                     Value *Src, Value *Start) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 35d69df56dc7220..dbab8af159a1621 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Utils/SizeOpts.h"
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 
@@ -78,6 +79,11 @@ static cl::opt<LoopVectorizeHints::ScalableForceKind>
                 "Scalable vectorization is available and favored when the "
                 "cost is inconclusive.")));
 
+static cl::opt<bool>
+    EnableCompactVectorization("enable-compact-vectorization", cl::init(true),
+                               cl::Hidden,
+                               cl::desc("Enable vectorizing compact pattern."));
+
 /// Maximum vectorization interleave count.
 static const unsigned MaxInterleaveFactor = 16;
 
@@ -785,6 +791,143 @@ static bool isTLIScalarize(const TargetLibraryInfo &TLI, const CallInst &CI) {
   return Scalarize;
 }
 
+static bool isUserOfCompactPHI(BasicBlock *BB, PHINode *Phi, Instruction *I) {
+  if (I->getParent() != BB)
+    return false;
+
+  // Operations on PHI should be affine.
+  if (I->getOpcode() != Instruction::Add &&
+      I->getOpcode() != Instruction::Sub &&
+      I->getOpcode() != Instruction::SExt &&
+      I->getOpcode() != Instruction::ZExt)
+    return false;
+
+  if (I == Phi)
+    return true;
+
+  for (unsigned i = 0; i < I->getNumOperands(); i++) {
+    if (auto *Instr = dyn_cast<Instruction>(I->getOperand(i)))
+      if (isUserOfCompactPHI(BB, Phi, Instr))
+        return true;
+  }
+  return false;
+}
+
+// Match the basic compact pattern:
+// for.body:
+//    %src.phi = phi i64 [ 0, %preheader ], [ %target.phi, %for.inc ]
+//    ...
+// if.then:
+//    ...
+//    %data = load i32, ptr %In
+//    (there may be additional sext/zext if %src.phi types i32)
+//    %addr = getelementptr i32, ptr %Out, i64 %src.phi
+//    store i32 %data, ptr %addr
+//    %inc = add i64 %src.phi, 1
+// for.inc
+//    %target.phi = phi i64 [ %inc, if.then ], [ %src.phi, %for.body ]
+bool LoopVectorizationLegality::isMatchCompact(PHINode *Phi, Loop *TheLoop,
+                                               CompactDescriptor &CpDesc) {
+  if (Phi->getNumIncomingValues() > 2)
+    return false;
+
+  // Don't support phis who is used as mask.
+  for (User *U : Phi->users()) {
+    if (isa<CmpInst>(U))
+      return false;
+  }
+
+  SmallPtrSet<Value *, 8> CompactChain;
+  CompactChain.insert(Phi);
+
+  BasicBlock *LoopPreHeader = TheLoop->getLoopPreheader();
+  int ExitIndex = Phi->getIncomingBlock(0) == LoopPreHeader ? 1 : 0;
+  BasicBlock *ExitBlock = Phi->getIncomingBlock(ExitIndex);
+  PHINode *CompactLiveOut = nullptr;
+  Value *IncValue = nullptr;
+  BasicBlock *IncBlock = nullptr;
+  bool IsCycle = false;
+  for (auto &CandPhi : ExitBlock->phis()) {
+    if (llvm::is_contained(CandPhi.incoming_values(), Phi) &&
+        CandPhi.getNumIncomingValues() == 2) {
+      IsCycle = true;
+      CompactLiveOut = &CandPhi;
+      int IncIndex = CandPhi.getIncomingBlock(0) == Phi->getParent() ? 1 : 0;
+      IncBlock = CandPhi.getIncomingBlock(IncIndex);
+      IncValue = CandPhi.getIncomingValueForBlock(IncBlock);
+      break;
+    }
+  }
+  // Similar with reduction PHI.
+  if (!IsCycle)
+    return false;
+  CompactChain.insert(CompactLiveOut);
+
+  // Match the pattern %inc = add i32 %src.phi, 1.
+  Value *Index = nullptr, *Step = nullptr;
+  if (!match(IncValue, m_Add(m_Value(Index), m_Value(Step))))
+    return false;
+  if (Index != Phi) {
+    std::swap(Index, Step);
+  }
+  if (Step != ConstantInt::get(Step->getType(), 1))
+    return false;
+  CompactChain.insert(IncValue);
+
+  const DataLayout &DL = Phi->getModule()->getDataLayout();
+  int CntCandStores = 0;
+  GetElementPtrInst *GEP = nullptr;
+  for (auto &Inst : *IncBlock) {
+    if (auto *SI = dyn_cast<StoreInst>(&Inst)) {
+      // TODO: Support llvm.sve.compact.nxv8i16, llvm.sve.compact.nxv16i18 in
+      // the future.
+      unsigned TySize = DL.getTypeSizeInBits(SI->getValueOperand()->getType());
+      if (TySize < 32)
+        return false;
+
+      GEP = dyn_cast<GetElementPtrInst>(SI->getPointerOperand());
+      if (GEP == nullptr)
+        continue;
+
+      // Only handle single pointer.
+      if (GEP->getNumOperands() != 2)
+        continue;
+
+      // Get the index of GEP, index could be phi or sext/zext (if phi types
+      // i32).
+      Value *Op1 = GEP->getOperand(1);
+      Value *X = nullptr;
+      SmallSet<Value *, 16> CandiInstrs;
+      if (match(Op1, m_SExt(m_Value(X))) || match(Op1, m_ZExt(m_Value(X)))) {
+        Op1 = X;
+      }
+      Instruction *Op1Instr = dyn_cast<Instruction>(Op1);
+      if (!Op1Instr || isUserOfCompactPHI(IncBlock, Phi, Op1Instr))
+        continue;
+      CompactChain.insert(GEP);
+      CompactChain.insert(SI);
+      CntCandStores++;
+    }
+  }
+  if (!CntCandStores)
+    return false;
+
+  KnownBits Bits = computeKnownBits(Phi, DL);
+  bool IsSign = !Bits.isNonNegative();
+  CompactDescriptor CompactDesc(CompactChain, CompactLiveOut, IsSign);
+  CpDesc = CompactDesc;
+  LLVM_DEBUG(dbgs() << "LV: Found a compact chain.\n");
+  return true;
+}
+
+PHINode *LoopVectorizationLegality::getCompactChainStart(Instruction *I) const {
+  for (auto &CpDesc : CpList) {
+    if (CpDesc.second.isInCompactChain(I))
+      return CpDesc.first;
+  }
+  return nullptr;
+}
+
 bool LoopVectorizationLegality::canVectorizeInstrs() {
   BasicBlock *Header = TheLoop->getHeader();
 
@@ -881,6 +1024,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
           continue;
         }
 
+        CompactDescriptor CpDesc;
+        if (EnableCompactVectorization &&
+            TTI->isTargetSupportedCompactStore() &&
+            isMatchCompact(Phi, TheLoop, CpDesc)) {
+          CpList[Phi] = CpDesc;
+          continue;
+        }
+
         reportVectorizationFailure("Found an unidentified PHI",
             "value that could not be identified as "
             "reduction is used outside the loop",
@@ -1525,16 +1676,22 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
   LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n");
 
   SmallPtrSet<const Value *, 8> ReductionLiveOuts;
+  SmallPtrSet<const Value *, 8> CompactLiveOuts;
 
   for (const auto &Reduction : getReductionVars())
     ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
 
+  for (const auto &Compact : getCompactList())
+    CompactLiveOuts.insert(Compact.second.getLiveOutPhi());
+
   // TODO: handle non-reduction outside users when tail is folded by masking.
   for (auto *AE : AllowedExit) {
    ...
[truncated]

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs resolve to main

@huntergr-arm huntergr-arm self-requested a review February 22, 2024 11:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants