-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[AMDGPU] Add DS loop wait analysis and relaxation (2/4) #171944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/hidekisaito/ds-loop-wait-opt-1
Are you sure you want to change the base?
[AMDGPU] Add DS loop wait analysis and relaxation (2/4) #171944
Conversation
Add the DS load position analysis and wait count relaxation for single-block loops with WMMA instructions (GFX12+). Assisted-by: Cursor / claude-4.5-opus-high
|
@llvm/pr-subscribers-backend-amdgpu Author: None (hidekisaito) ChangesAdd the DS load position analysis and wait count relaxation for single-block loops with WMMA instructions (GFX12+). Assisted-by: Cursor / claude-4.5-opus-high Depends on #171942. Full diff: https://github.com/llvm/llvm-project/pull/171944.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 140b79136227c..777491fb58b80 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -466,6 +466,8 @@ class SIInsertWaitcnts {
// Set to true when relaxation is actually applied in the loop body.
// Used to determine if preheader needs DS_CNT flush.
mutable bool RelaxationApplied = false;
+ // Pointer to the last barrier in the loop (found during eligibility check)
+ const MachineInstr *LastBarrier = nullptr;
};
// Cache of loop DS wait optimization info, keyed by loop header MBB.
@@ -600,6 +602,9 @@ class SIInsertWaitcnts {
// DS loop wait optimization functions
bool isEligibleForDSLoopOpt(MachineLoop *ML, LoopDSWaitOptInfo &Info) const;
void analyzeSingleBBLoopDSLoads(MachineLoop *ML);
+ std::optional<unsigned> getOptimalDSWaitCount(MachineBasicBlock *LoopHeader,
+ const MachineInstr &MI) const;
+ bool applyDSLoopWaitOpt(MachineInstr &MI, AMDGPU::Waitcnt &Wait);
};
// This objects maintains the current score brackets of each wait counter, and
@@ -2138,6 +2143,10 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
// Verify that the wait is actually needed.
ScoreBrackets.simplifyWaitcnt(Wait);
+ // DS Loop Wait Optimization (GFX12+):
+ // Try to relax conservative DS wait counts in single-block loops with WMMA.
+ applyDSLoopWaitOpt(MI, Wait);
+
// Since the translation for VMEM addresses occur in-order, we can apply the
// XCnt if the current instruction is of VMEM type and has a memory
// dependency with another VMEM instruction in flight.
@@ -2701,9 +2710,11 @@ bool SIInsertWaitcnts::isEligibleForDSLoopOpt(MachineLoop *ML,
MachineBasicBlock *MBB = ML->getHeader();
// Count DS loads, WMMA/MFMA instructions, and total non-meta instructions
+ // Also find the last barrier during this traversal to avoid re-traversing
unsigned NumDSLoads = 0;
unsigned NumWMMA = 0;
unsigned NumInsts = 0;
+ Info.LastBarrier = nullptr;
for (const MachineInstr &MI : *MBB) {
if (!MI.isMetaInstruction())
@@ -2715,6 +2726,13 @@ bool SIInsertWaitcnts::isEligibleForDSLoopOpt(MachineLoop *ML,
} else if (SIInstrInfo::isWMMA(MI) || SIInstrInfo::isMFMA(MI)) {
++NumWMMA;
}
+
+ // Track the last barrier instruction
+ if (MI.getOpcode() == AMDGPU::S_BARRIER ||
+ MI.getOpcode() == AMDGPU::S_BARRIER_SIGNAL_IMM ||
+ MI.getOpcode() == AMDGPU::S_BARRIER_SIGNAL_ISFIRST_IMM) {
+ Info.LastBarrier = &MI;
+ }
}
// Heuristics: need significant number of DS loads and WMMA/MFMA
@@ -2745,8 +2763,145 @@ void SIInsertWaitcnts::analyzeSingleBBLoopDSLoads(MachineLoop *ML) {
return;
}
- // For now, just mark as invalid - full analysis comes in a later PR.
- Info.Valid = false;
+ // Looking for something similar to software-pipelined GEMM loops,
+ // where the last part of the loop body is prefetching data for the next
+ // iteration. Such code also has loads in the preheader block whose ordering
+ // may be significantly different from the load ordering at the end of the
+ // loop body since their orderings are not co-optimized. That can end up in
+ // rather conservative LDS wait counts.
+
+ // We only care about the LDS loads after the last barrier in the loop body,
+ // if one exists. LastBarrier was already found during eligibility check.
+ // These are likely to be prefetch loads whose results are used in the next
+ // iteration.
+ const MachineInstr *LastBarrier = Info.LastBarrier;
+
+ unsigned LoadPosition = 0;
+ bool AfterLastBarrier = (LastBarrier == nullptr); // If no barrier, track all
+
+ for (const MachineInstr &MI : *MBB) {
+ if (&MI == LastBarrier) {
+ AfterLastBarrier = true;
+ continue;
+ }
+
+ if (!AfterLastBarrier)
+ continue;
+
+ // Check DS instructions
+ if (SIInstrInfo::isDS(MI)) {
+ // DS stores after barrier not allowed - same counter, may complete
+ // out of order with loads
+ if (MI.mayStore()) {
+ LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: DS store after last barrier, "
+ << "skipping\n");
+ Info.Valid = false;
+ return;
+ }
+
+ // Track DS loads - record position
+ if (MI.mayLoad()) {
+ ++LoadPosition;
+ for (const MachineOperand &Op : MI.defs()) {
+ if (Op.isReg() && Op.getReg().isPhysical() &&
+ TRI->isVGPR(*MRI, Op.getReg())) {
+ for (MCRegUnit Unit : TRI->regunits(Op.getReg())) {
+ Info.VGPRToLoadPosition[static_cast<unsigned>(Unit)] =
+ LoadPosition;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Info.TotalDSLoads = LoadPosition;
+ Info.Valid = Info.TotalDSLoads > 0;
+
+ LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: Analyzed loop at ";
+ MBB->printName(dbgs());
+ dbgs() << " - " << Info.TotalDSLoads << " DS loads"
+ << ", HasBarrier=" << (LastBarrier != nullptr)
+ << ", Valid=" << Info.Valid << "\n");
+}
+
+std::optional<unsigned>
+SIInsertWaitcnts::getOptimalDSWaitCount(MachineBasicBlock *LoopHeader,
+ const MachineInstr &MI) const {
+ auto It = LoopDSWaitOptCache.find(LoopHeader);
+ if (It == LoopDSWaitOptCache.end() || !It->second.Valid)
+ return std::nullopt;
+
+ const LoopDSWaitOptInfo &Info = It->second;
+
+ // Find the maximum load position among all VGPR operands used by MI
+ unsigned MaxLoadPosition = 0;
+ for (const MachineOperand &Op : MI.operands()) {
+ if (!Op.isReg() || !Op.isUse() || !Op.getReg().isPhysical())
+ continue;
+ if (!TRI->isVGPR(*MRI, Op.getReg()))
+ continue;
+
+ for (MCRegUnit Unit : TRI->regunits(Op.getReg())) {
+ auto PosIt = Info.VGPRToLoadPosition.find(static_cast<unsigned>(Unit));
+ if (PosIt != Info.VGPRToLoadPosition.end()) {
+ MaxLoadPosition = std::max(MaxLoadPosition, PosIt->second);
+ }
+ }
+ }
+
+ if (MaxLoadPosition == 0)
+ return std::nullopt;
+
+ // Optimal wait = TotalDSLoads - MaxLoadPosition
+ // This means we wait until all loads up to and including MaxLoadPosition
+ // have completed, but loads after it can still be in flight.
+ return Info.TotalDSLoads - MaxLoadPosition;
+}
+
+// Try to apply DS loop wait optimization to relax conservative wait counts.
+// Returns true if the wait count was modified.
+bool SIInsertWaitcnts::applyDSLoopWaitOpt(MachineInstr &MI,
+ AMDGPU::Waitcnt &Wait) {
+ // Only applies to GFX12+ with separate DS counter
+ if (!ST->hasExtendedWaitCounts())
+ return false;
+
+ // Only optimize if baseline wants a DS wait
+ if (Wait.DsCnt == ~0u)
+ return false;
+
+ MachineBasicBlock *MBB = MI.getParent();
+ MachineLoop *ML = MLI->getLoopFor(MBB);
+
+ // Only apply in single-block loop headers
+ if (!ML || ML->getNumBlocks() != 1 || ML->getHeader() != MBB)
+ return false;
+
+ auto CacheIt = LoopDSWaitOptCache.find(MBB);
+ if (CacheIt == LoopDSWaitOptCache.end() || !CacheIt->second.Valid)
+ return false;
+
+ // Only optimize if wait is conservative (less than half of loads in flight)
+ unsigned HalfLoads = CacheIt->second.TotalDSLoads / 2;
+ if (Wait.DsCnt >= HalfLoads)
+ return false;
+
+ auto OptWait = getOptimalDSWaitCount(MBB, MI);
+ if (!OptWait)
+ return false;
+
+ // Only relax the wait (increase the count), never tighten it
+ // and only when the relaxation is significant (at least 4 more)
+ if (*OptWait <= Wait.DsCnt || (*OptWait - Wait.DsCnt) < 4)
+ return false;
+
+ LLVM_DEBUG(dbgs() << "DS Loop Opt: Relaxing DsCnt from " << Wait.DsCnt
+ << " to " << *OptWait << " for: " << MI);
+ Wait.DsCnt = *OptWait;
+ // Mark that relaxation was applied so preheader flush is inserted
+ CacheIt->second.RelaxationApplied = true;
+ return true;
}
// Return true if it is better to flush the vmcnt counter in the preheader of
diff --git a/llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-eligible.mir b/llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-eligible.mir
index d7d2cf96ceac5..48fdabf255e6f 100644
--- a/llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-eligible.mir
+++ b/llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-eligible.mir
@@ -1,24 +1,44 @@
+# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=true -verify-machineinstrs -o - %s | FileCheck -check-prefix=OPT %s
+# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=false -verify-machineinstrs -o - %s | FileCheck -check-prefix=NOOPT %s
+
+# Debug output test (requires asserts build)
+# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=true -debug-only=si-insert-waitcnts -o /dev/null %s 2>&1 | FileCheck -check-prefix=DBG %s
# REQUIRES: asserts
-# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=true -debug-only=si-insert-waitcnts -o /dev/null %s 2>&1 | FileCheck %s
-# Test for DS loop wait optimization eligibility check.
-# Verifies that the pass correctly identifies single-block loops with
-# sufficient DS loads (>=16) and WMMA instructions (>=8) as eligible
-# for optimization.
+# Test for DS loop wait optimization in single-block loops with WMMA.
+# The preheader DS loads are in reverse order compared to loop body loads,
+# which causes the baseline to produce conservative waits that the optimization
+# can relax.
+#
+# Key improvement demonstrated:
+# Without opt: S_WAIT_DSCNT 0 (wait for ALL 16 loads) before first WMMA
+# With opt: S_WAIT_DSCNT 12 (wait for only 4 loads, 12 remain in flight)
#
-# CHECK: Loop DS Wait Opt: Loop at bb.1 - 16 DS loads, 8 WMMA/MFMA, {{[0-9]+}} total insts, eligible
+# DBG: Loop DS Wait Opt: Loop at bb.1 - 16 DS loads, 8 WMMA/MFMA, {{[0-9]+}} total insts, eligible
+# DBG: Loop DS Wait Opt: Analyzed loop at bb.1 - 16 DS loads, HasBarrier=1, Valid=1
+# DBG: DS Loop Opt: Relaxing DsCnt from 0 to 12 for:
--- |
define amdgpu_kernel void @ds_loop_eligible() { ret void }
...
---
+# OPT-LABEL: name: ds_loop_eligible
+# NOOPT-LABEL: name: ds_loop_eligible
name: ds_loop_eligible
tracksRegLiveness: true
machineFunctionInfo:
isEntryFunction: true
waveLimiter: false
body: |
+ ; OPT: bb.0:
+ ; OPT-NOT: S_WAIT_DSCNT
+ ; OPT: S_BRANCH %bb.1
+
+ ; NOOPT: bb.0:
+ ; NOOPT-NOT: S_WAIT_DSCNT
+ ; NOOPT: S_BRANCH %bb.1
+
bb.0:
successors: %bb.1
liveins: $sgpr0, $vgpr0
@@ -44,6 +64,17 @@ body: |
$vgpr10_vgpr11_vgpr12_vgpr13 = DS_READ_B128 $vgpr0, 0, 0, implicit $m0, implicit $exec
S_BRANCH %bb.1
+ ; OPT: bb.1:
+ ; OPT: S_WAIT_DSCNT 12
+ ; OPT-NEXT: early-clobber $vgpr80{{.*}} = V_WMMA
+ ; OPT: S_WAIT_DSCNT 8
+ ; OPT-NEXT: early-clobber $vgpr88{{.*}} = V_WMMA
+
+ ; NOOPT: bb.1:
+ ; NOOPT: S_WAIT_DSCNT 0
+ ; NOOPT-NEXT: early-clobber $vgpr80{{.*}} = V_WMMA
+ ; NOOPT-NOT: S_WAIT_DSCNT 8
+
bb.1:
; Single-block loop with WMMA and DS loads after barrier
successors: %bb.1, %bb.2
|
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: DS store after last barrier, " | ||
| << "skipping\n"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: DS store after last barrier, " | |
| << "skipping\n"); | |
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: DS store after last barrier, " | |
| "skipping\n"); |
| if (MI.mayLoad()) { | ||
| ++LoadPosition; | ||
| for (const MachineOperand &Op : MI.defs()) { | ||
| if (Op.isReg() && Op.getReg().isPhysical() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All registers must be physical at this point
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: Analyzed loop at "; | ||
| MBB->printName(dbgs()); | ||
| dbgs() << " - " << Info.TotalDSLoads << " DS loads" | ||
| << ", HasBarrier=" << (LastBarrier != nullptr) | ||
| << ", Valid=" << Info.Valid << "\n"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: Analyzed loop at "; | |
| MBB->printName(dbgs()); | |
| dbgs() << " - " << Info.TotalDSLoads << " DS loads" | |
| << ", HasBarrier=" << (LastBarrier != nullptr) | |
| << ", Valid=" << Info.Valid << "\n"); | |
| LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: Analyzed loop at " | |
| << printMBBReference(MBB) << " - " << Info.TotalDSLoads << " DS loads" | |
| << ", HasBarrier=" << (LastBarrier != nullptr) | |
| << ", Valid=" << Info.Valid << '\n'); |
|
|
||
| // Find the maximum load position among all VGPR operands used by MI | ||
| unsigned MaxLoadPosition = 0; | ||
| for (const MachineOperand &Op : MI.operands()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| for (const MachineOperand &Op : MI.operands()) { | |
| for (const MachineOperand &Op : MI.all_uses()) { |
Pierre-vh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I will have time to review this properly before I'm OOO.
For now I just have one question about this patch: Do we really need to do everything inside SIInsertWaitcnt ?
Is there no way to split this up a bit better, either by moving some things to utility functions in another file, or doing the heavy lifting in a prior pass that inserts pseudos, so that this pass has less work to do ?
I'd simply like to make sure the separation of concerns is respected and this is here because it makes sense, not because it's convenient
Add the DS load position analysis and wait count relaxation for single-block loops with WMMA instructions (GFX12+).
Assisted-by: Cursor / claude-4.5-opus-high
Depends on #171942.