Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTI]Add support for strided loads/stores. #80329

Conversation

alexey-bataev
Copy link
Member

@alexey-bataev alexey-bataev commented Feb 1, 2024

Added basic legality check and cost estimation functions for strided loads and stores.

These interfaces will be built upon in #80310.

Created using spr 1.3.5
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-risc-v

Author: Alexey Bataev (alexey-bataev)

Changes

Added basic legality check and cost estimation functions.


Full diff: https://github.com/llvm/llvm-project/pull/80329.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+34)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+13)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+14)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+24)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+11)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3b615bc700bbb..58577a6b6eb5c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -781,6 +781,9 @@ class TargetTransformInfo {
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if the target supports strided load.
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
+
   /// Return true if this is an alternating opcode pattern that can be lowered
   /// to a single instruction on the target. In X86 this is for the addsub
   /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1412,6 +1415,20 @@ class TargetTransformInfo {
       Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
       const Instruction *I = nullptr) const;
 
+  /// \return The cost of strided memory operations.
+  /// \p Opcode - is a type of memory access Load or Store
+  /// \p DataTy - a vector type of the data to be loaded or stored
+  /// \p Ptr - pointer [or vector of pointers] - address[es] in memory
+  /// \p VariableMask - true when the memory access is predicated with a mask
+  ///                   that is not a compile-time constant
+  /// \p Alignment - alignment of single element
+  /// \p I - the optional original context instruction, if one exists, e.g. the
+  ///        load/store to transform or the call to the gather/scatter intrinsic
+  InstructionCost getStridedMemoryOpCost(
+      unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+      Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+      const Instruction *I = nullptr) const;
+
   /// \return The cost of the interleaved memory operation.
   /// \p Opcode is the memory operation code
   /// \p VecTy is the vector type of the interleaved access.
@@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept {
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
                                const SmallBitVector &OpcodeMask) const = 0;
@@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept {
                          bool VariableMask, Align Alignment,
                          TTI::TargetCostKind CostKind,
                          const Instruction *I = nullptr) = 0;
+  virtual InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) = 0;
 
   virtual InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
@@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
+    return Impl.isLegalStridedLoadStore(DataType, Alignment);
+  }
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const override {
     return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
@@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                        Alignment, CostKind, I);
   }
+  InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) override {
+    return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                       Alignment, CostKind, I);
+  }
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9958b4daa6ed8..1fe126379ac75 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -304,6 +304,10 @@ class TargetTransformInfoImplBase {
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
+    return false;
+  }
+
   bool enableOrderedReductions() const { return false; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
@@ -687,6 +691,15 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I = nullptr) const {
+    return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
+                                                : TTI::TCC_Basic;
+  }
+
   unsigned getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8902dde37cbca..8158c519562ee 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
+                                                  Align Alignment) const {
+  return TTIImpl->isLegalStridedLoadStore(DataType, Alignment);
+}
+
 bool TargetTransformInfo::enableOrderedReductions() const {
   return TTIImpl->enableOrderedReductions();
 }
@@ -1041,6 +1046,15 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+  InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
+      Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 InstructionCost TargetTransformInfo::getInterleavedMemoryOpCost(
     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fe1cdb2dfa423..3e349458ec0db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -658,6 +658,30 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
   return NumLoads * MemOpCost;
 }
 
+InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment, CostKind, I);
+
+  if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
+       !isLegalStridedLoadStore(DataTy, Alignment)) ||
+      (Opcode != Instruction::Load && Opcode != Instruction::Store))
+    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment, CostKind, I);
+
+  // Cost is proportional to the number of memory operations implied.  For
+  // scalable vectors, we use an estimate on that number since we don't
+  // know exactly what VL will be.
+  auto &VTy = *cast<VectorType>(DataTy);
+  InstructionCost MemOpCost =
+      getMemoryOpCost(Opcode, VTy.getElementType(), Alignment, 0, CostKind,
+                      {TTI::OK_AnyValue, TTI::OP_None}, I);
+  unsigned NumLoads = getEstimatedVLFor(&VTy);
+  return NumLoads * MemOpCost;
+}
+
 // Currently, these represent both throughput and codesize costs
 // for the respective intrinsics.  The costs in this table are simply
 // instruction counts with the following adjustments made:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0747a778fe9a2..af36e9d5d5e88 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
                                          TTI::TargetCostKind CostKind,
                                          const Instruction *I);
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I);
+
   InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                    TTI::CastContextHint CCH,
                                    TTI::TargetCostKind CostKind,
@@ -250,6 +256,11 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return ST->is64Bit() && !ST->hasVInstructionsI64();
   }
 
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) {
+    EVT DataTypeVT = TLI->getValueType(DL, DataType);
+    return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
+  }
+
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }

Align Alignment,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr) const {
return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said in the review before you split, I would strongly prefer if this returned InstructionCost::getInvalid by default. TCC_Basic is fine for codesize of a legal strided access, but all others should return invalid.

Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
assert(Cost >= 0 && "TTI should not produce negative costs!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assert will need adjusted.

@preames
Copy link
Collaborator

preames commented Feb 1, 2024

For context of anyone reading along...

I asked this to be split off and intend to approve this despite the last of testing. It is tested by the dependent SLP change, and part of my motivation for getting this in is unblocking a change I'm going to write to adjust the costs returned via the printer so that we can test this more easily. I need the API which this change has already added for my change to make sense.

Created using spr 1.3.5
Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@alexey-bataev alexey-bataev merged commit 8ad14b6 into main Feb 1, 2024
3 of 4 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/ttiadd-support-for-strided-loadsstores branch February 1, 2024 21:07
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
Added basic legality check and cost estimation functions for strided loads and stores.

These interfaces will be built upon in llvm#80310.

Reviewers: preames

Reviewed By: preames

Pull Request: llvm#80329
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants