diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp index 00948278401cb..1873057b3694c 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp @@ -533,13 +533,10 @@ std::pair AMDGPUSubtarget::getFlatWorkGroupSizes( } std::pair AMDGPUSubtarget::getWavesPerEU( - const Function &F) const { + const Function &F, std::pair FlatWorkGroupSizes) const { // Default minimum/maximum number of waves per execution unit. std::pair Default(1, getMaxWavesPerEU()); - // Default/requested minimum/maximum flat work group sizes. - std::pair FlatWorkGroupSizes = getFlatWorkGroupSizes(F); - // If minimum/maximum flat work group sizes were explicitly requested using // "amdgpu-flat-work-group-size" attribute, then set default minimum/maximum // number of waves per execution unit to values implied by requested diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h index b160cdf3a97aa..1d8a9e61a0857 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h @@ -91,7 +91,18 @@ class AMDGPUSubtarget { /// be converted to integer, violate subtarget's specifications, or are not /// compatible with minimum/maximum number of waves limited by flat work group /// size, register usage, and/or lds usage. - std::pair getWavesPerEU(const Function &F) const; + std::pair getWavesPerEU(const Function &F) const { + // Default/requested minimum/maximum flat work group sizes. + std::pair FlatWorkGroupSizes = getFlatWorkGroupSizes(F); + return getWavesPerEU(F, FlatWorkGroupSizes); + } + + /// Overload which uses the specified values for the flat work group sizes, + /// rather than querying the function itself. \p FlatWorkGroupSizes should + /// correspond to the function's value for getFlatWorkGroupSizes. 
+ std::pair + getWavesPerEU(const Function &F, + std::pair FlatWorkGroupSizes) const; /// Return the amount of LDS that can be used that will not restrict the /// occupancy lower than WaveCount.