diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 0f17312b03827..9d5aaab82905d 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -24,6 +24,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitmaskEnum.h" +#include "llvm/ADT/Uniformity.h" #include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/InterestingMemoryOperand.h" #include "llvm/IR/FMF.h" @@ -468,6 +469,15 @@ class TargetTransformInfo { // even taking non-uniform arguments LLVM_ABI bool isAlwaysUniform(const Value *V) const; + /// Get target-specific uniformity information for an instruction. + /// This allows targets to provide more fine-grained control over + /// uniformity analysis by specifying whether specific instructions + /// should always or never be considered uniform, or require custom + /// operand-based analysis. + /// \param V The value to query for uniformity information. + /// \return InstructionUniformity. + LLVM_ABI InstructionUniformity getInstructionUniformity(const Value *V) const; + /// Query the target whether the specified address space cast from FromAS to /// ToAS is valid. LLVM_ABI bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const; diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index aacb88d2f9684..60b3c6f397e4f 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -135,6 +135,10 @@ class TargetTransformInfoImplBase { virtual bool isAlwaysUniform(const Value *V) const { return false; } + virtual InstructionUniformity getInstructionUniformity(const Value *V) const { + return InstructionUniformity::Default; + } + virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const { return false; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 0426ac7e62fab..301dcba5b1cb7 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -306,6 +306,11 @@ bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const { return TTIImpl->isAlwaysUniform(V); } +InstructionUniformity +llvm::TargetTransformInfo::getInstructionUniformity(const Value *V) const { + return TTIImpl->getInstructionUniformity(V); +} + bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const { return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS); diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp index 2e4063f5db14e..b56534935d7c2 100644 --- a/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -31,15 +31,22 @@ bool llvm::GenericUniformityAnalysisImpl::markDefsDivergent( template <> void llvm::GenericUniformityAnalysisImpl::initialize() { for (auto &I : instructions(F)) { - if (TTI->isSourceOfDivergence(&I)) - markDivergent(I); - else if (TTI->isAlwaysUniform(&I)) + InstructionUniformity IU = TTI->getInstructionUniformity(&I); + switch (IU) { + case InstructionUniformity::AlwaysUniform: addUniformOverride(I); + continue; + case InstructionUniformity::NeverUniform: + markDivergent(I); + continue; + case InstructionUniformity::Default: + break; + } } for (auto &Arg : F.args()) { - if (TTI->isSourceOfDivergence(&Arg)) { + if (TTI->getInstructionUniformity(&Arg) == + InstructionUniformity::NeverUniform) markDivergent(&Arg); - } } } diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp index e4b82ce83fda6..238d29d386574 100644 --- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp +++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp @@ -53,13 +53,16 @@ void llvm::GenericUniformityAnalysisImpl::initialize() { for (const MachineBasicBlock &block : F) { for (const MachineInstr &instr : block) { auto uniformity = InstrInfo.getInstructionUniformity(instr); - if (uniformity == InstructionUniformity::AlwaysUniform) { - addUniformOverride(instr); - continue; - } - if (uniformity == InstructionUniformity::NeverUniform) { + switch (uniformity) { + case InstructionUniformity::AlwaysUniform: + addUniformOverride(instr); + break; + case InstructionUniformity::NeverUniform: markDivergent(instr); + break; + case InstructionUniformity::Default: + break; } } } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 03d16fdd54c42..48b8376d23a66 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -1574,3 +1574,14 @@ unsigned GCNTTIImpl::getNumberOfParts(Type *Tp) const { } return BaseT::getNumberOfParts(Tp); } + +InstructionUniformity +GCNTTIImpl::getInstructionUniformity(const Value *V) const { + if (isAlwaysUniform(V)) + return InstructionUniformity::AlwaysUniform; + + if (isSourceOfDivergence(V)) + return InstructionUniformity::NeverUniform; + + return InstructionUniformity::Default; +} diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h index 20da8344c9d37..c2e102c9bab8e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h @@ -302,6 +302,8 @@ class GCNTTIImpl final : public BasicTTIImplBase { /// together under a single i32 value. Otherwise fall back to base /// implementation. unsigned getNumberOfParts(Type *Tp) const override; + + InstructionUniformity getInstructionUniformity(const Value *V) const override; }; } // end namespace llvm diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 64593e6439184..c9920ae320d86 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -635,3 +635,11 @@ void NVPTXTTIImpl::collectKernelLaunchBounds( if (MaxNTID.size() > 2) LB.push_back({"maxntidz", MaxNTID[2]}); } + +InstructionUniformity +NVPTXTTIImpl::getInstructionUniformity(const Value *V) const { + if (isSourceOfDivergence(V)) + return InstructionUniformity::NeverUniform; + + return InstructionUniformity::Default; +} diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h index 78eb751cf3c2e..bb4d9a05ad805 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h @@ -195,6 +195,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase { // Self-referential globals are not supported. return false; } + + InstructionUniformity getInstructionUniformity(const Value *V) const override; }; } // end namespace llvm