diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index e43181a76fc5c..cfa6ec691002a 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -783,6 +783,9 @@ class WaitcntBrackets {
         : CntT(CntT), Limits(&Limits) {}
     /// \returns the count of outstanding instrs tracked by this counter.
     unsigned getCount() const { return UB - LB; }
+    /// \returns the wait count, computed as UB - \p Score, needed for the
+    /// instruction with that score to complete, assuming in-order completion.
+    unsigned getWait(unsigned Score) const { return UB - Score; }
     // TODO: Make private: we should not provide raw access to the internals.
     void setLB(unsigned NewLB) { LB = NewLB; }
     // TODO: Make private: we should not provide raw access to the internals.
@@ -910,7 +913,7 @@ class WaitcntBrackets {
   }
 
   unsigned getPendingGDSWait() const {
-    return std::min(getScoreUB(AMDGPU::DS_CNT) - LastGDS,
+    return std::min(Counters[AMDGPU::DS_CNT].getWait(LastGDS),
                     getWaitCountMax(Context->getLimits(), AMDGPU::DS_CNT) - 1);
   }
 
@@ -1612,11 +1615,8 @@ void WaitcntBrackets::purgeEmptyTrackingData() {
 void WaitcntBrackets::determineWaitForScore(AMDGPU::InstCounterType T,
                                             unsigned ScoreToWait,
                                             AMDGPU::Waitcnt &Wait) const {
-  const unsigned LB = getScoreLB(T);
-  const unsigned UB = getScoreUB(T);
-
   // If the score falls within the bracket, we need a waitcnt.
-  if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
+  if (Counters[T].contains(ScoreToWait)) {
     if ((T == AMDGPU::LOAD_CNT || T == AMDGPU::DS_CNT) && hasPendingFlat() &&
         !Context->ST.hasFlatLgkmVMemCountInOrder()) {
       // If there is a pending FLAT operation, and this is a VMem or LGKM
@@ -1631,8 +1631,9 @@ void WaitcntBrackets::determineWaitForScore(AMDGPU::InstCounterType T,
     } else {
       // If a counter has been maxed out avoid overflow by waiting for
       // MAX(CounterType) - 1 instead.
-      unsigned NeededWait = std::min(
-          UB - ScoreToWait, getWaitCountMax(Context->getLimits(), T) - 1);
+      unsigned NeededWait =
+          std::min(Counters[T].getWait(ScoreToWait),
+                   getWaitCountMax(Context->getLimits(), T) - 1);
       addWait(Wait, T, NeededWait);
     }
   }