diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index dc1e5a3b86046..d39d069d9eb2a 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -402,7 +402,11 @@ class TargetTransformInfo { /// Branch divergence has a significantly negative impact on GPU performance /// when threads in the same wavefront take different paths due to conditional /// branches. - bool hasBranchDivergence() const; + /// + /// If \p F is passed, provides a context function. If \p F is known to only + /// execute in a single threaded environment, the target may choose to skip + /// uniformity analysis and assume all values are uniform. + bool hasBranchDivergence(const Function *F = nullptr) const; /// Returns whether V is a source of divergence. /// @@ -1689,7 +1693,7 @@ class TargetTransformInfo::Concept { ArrayRef Operands, TargetCostKind CostKind) = 0; virtual BranchProbability getPredictableBranchThreshold() = 0; - virtual bool hasBranchDivergence() = 0; + virtual bool hasBranchDivergence(const Function *F = nullptr) = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; virtual bool isAlwaysUniform(const Value *V) = 0; virtual bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const = 0; @@ -2066,7 +2070,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { BranchProbability getPredictableBranchThreshold() override { return Impl.getPredictableBranchThreshold(); } - bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); } + bool hasBranchDivergence(const Function *F = nullptr) override { + return Impl.hasBranchDivergence(F); + } bool isSourceOfDivergence(const Value *V) override { return Impl.isSourceOfDivergence(V); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index fe133b315754f..64ac5701c03e4 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -91,7 +91,7 @@ class TargetTransformInfoImplBase { return BranchProbability(99, 100); } - bool hasBranchDivergence() const { return false; } + bool hasBranchDivergence(const Function *F = nullptr) const { return false; } bool isSourceOfDivergence(const Value *V) const { return false; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index a824fd45556f7..f57697da9aaf9 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -276,7 +276,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { E, AddressSpace, Alignment, MachineMemOperand::MONone, Fast); } - bool hasBranchDivergence() { return false; } + bool hasBranchDivergence(const Function *F = nullptr) { return false; } bool isSourceOfDivergence(const Value *V) { return false; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 06d97b5b2c0a4..4500e0615219c 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -259,8 +259,8 @@ BranchProbability TargetTransformInfo::getPredictableBranchThreshold() const { : TTIImpl->getPredictableBranchThreshold(); } -bool TargetTransformInfo::hasBranchDivergence() const { - return TTIImpl->hasBranchDivergence(); +bool TargetTransformInfo::hasBranchDivergence(const Function *F) const { + return TTIImpl->hasBranchDivergence(F); } bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 5a9e87deecc14..c1428e111e20e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -301,6 +301,10 @@ GCNTTIImpl::GCNTTIImpl(const AMDGPUTargetMachine *TM, const Function &F) HasFP64FP16Denormals = Mode.allFP64FP16Denormals(); } +bool GCNTTIImpl::hasBranchDivergence(const Function *F) const { + return true; +} + unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const { // NB: RCID is not an RCID. In fact it is 0 or 1 for scalar or vector // registers. See getRegisterClassForType for the implementation. diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h index 7c61429767eca..db223e1272a23 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h @@ -102,7 +102,7 @@ class GCNTTIImpl final : public BasicTTIImplBase { public: explicit GCNTTIImpl(const AMDGPUTargetMachine *TM, const Function &F); - bool hasBranchDivergence() { return true; } + bool hasBranchDivergence(const Function *F = nullptr) const; void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP, diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h index b1dc49c6f9a09..ec0fd454c8084 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -92,9 +92,7 @@ class HexagonTTIImpl : public BasicTTIImplBase { return true; } bool supportsEfficientVectorElementLoadStore() { return false; } - bool hasBranchDivergence() { - return false; - } + bool hasBranchDivergence(const Function *F = nullptr) { return false; } bool enableAggressiveInterleaving(bool LoopHasReductions) { return false; } diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h index 0cee130e1e114..3ce2675560c4d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h @@ -41,7 +41,7 @@ class NVPTXTTIImpl : public BasicTTIImplBase { : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl()), TLI(ST->getTargetLowering()) {} - bool hasBranchDivergence() { return true; } + bool hasBranchDivergence(const Function *F = nullptr) { return true; } bool isSourceOfDivergence(const Value *V);