Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ class SIInsertWaitcnts {
mutable bool RelaxationApplied = false;
// Pointer to the last barrier in the loop (found during eligibility check)
const MachineInstr *LastBarrier = nullptr;
// The wait count "floor" established by same-iteration uses/overwrites.
// When a DS load result is used in the same iteration, the baseline inserts
// a wait. This floor indicates the expected counter state after that wait.
// WMMAs that only use flushed loads can rely on this floor.
unsigned FloorWaitCount = 0;
};

// Cache of loop DS wait optimization info, keyed by loop header MBB.
Expand Down Expand Up @@ -2775,9 +2780,21 @@ void SIInsertWaitcnts::analyzeSingleBBLoopDSLoads(MachineLoop *ML) {
// if one exists. LastBarrier was already found during eligibility check.
// These are likely to be prefetch loads whose results are used in the next
// iteration.
//
// If a load result is used or overwritten within the same iteration, the
// baseline will insert a wait before that instruction. Since DS loads
// complete in FIFO order, that wait also completes all earlier loads. So we
// can drop those "flushed" loads from our tracking and only consider
// subsequent loads as true prefetch loads. Overwrites also require the load
// to complete first to avoid write-after-write races.
const MachineInstr *LastBarrier = Info.LastBarrier;

// Single pass: track DS load destinations and handle same-iteration uses and
// overwrites, both of which flush the affected load and all earlier loads.
// TrackedLoads: (Register, Position) pairs for checking uses/overwrites
SmallVector<std::pair<Register, unsigned>, 64> TrackedLoads;
unsigned LoadPosition = 0;
unsigned LastFlushedPosition = 0; // Loads up to this position will be flushed
bool AfterLastBarrier = (LastBarrier == nullptr); // If no barrier, track all

for (const MachineInstr &MI : *MBB) {
Expand All @@ -2789,6 +2806,42 @@ void SIInsertWaitcnts::analyzeSingleBBLoopDSLoads(MachineLoop *ML) {
if (!AfterLastBarrier)
continue;

// Check for instructions that write to LDS through DMA (global_load_lds,
// etc). These write to LDS but aren't DS instructions.
// Bail out if any appear after the barrier.
if (SIInstrInfo::mayWriteLDSThroughDMA(MI)) {
LLVM_DEBUG(
dbgs() << "Loop DS Wait Opt: LDS DMA write after last barrier, "
<< "skipping\n");
Info.Valid = false;
return;
}

// Check for tensor_load_to_lds instructions (MIMG, not caught by above)
if (MI.getOpcode() == AMDGPU::TENSOR_LOAD_TO_LDS ||
MI.getOpcode() == AMDGPU::TENSOR_LOAD_TO_LDS_D2) {
LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: tensor_load_to_lds after last "
<< "barrier, skipping\n");
Info.Valid = false;
return;
}

// Check if this instruction uses or overwrites any tracked DS load
// destination. If so, baseline will have inserted a wait that flushes
// all loads up to that position (since DS loads complete in order).
// Overwrites also require the load to complete first to avoid races.
for (auto &[Reg, Position] : TrackedLoads) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto &[Reg, Position] : TrackedLoads) {
for (auto [Reg, Position] : TrackedLoads) {

if (Position <= LastFlushedPosition)
continue; // Already flushed

if (MI.readsRegister(Reg, TRI) || MI.modifiesRegister(Reg, TRI)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these two queries be fused into a single check for any reference to the register?

LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: DS load at position "
<< Position << " used/overwritten in same iteration, "
<< "flushing positions 1-" << Position << "\n");
LastFlushedPosition = std::max(LastFlushedPosition, Position);
}
}

// Check DS instructions
if (SIInstrInfo::isDS(MI)) {
// DS stores after barrier not allowed - same counter, may complete
Expand All @@ -2806,6 +2859,7 @@ void SIInsertWaitcnts::analyzeSingleBBLoopDSLoads(MachineLoop *ML) {
for (const MachineOperand &Op : MI.defs()) {
if (Op.isReg() && Op.getReg().isPhysical() &&
TRI->isVGPR(*MRI, Op.getReg())) {
TrackedLoads.emplace_back(Op.getReg(), LoadPosition);
for (MCRegUnit Unit : TRI->regunits(Op.getReg())) {
Info.VGPRToLoadPosition[static_cast<unsigned>(Unit)] =
LoadPosition;
Expand All @@ -2816,12 +2870,32 @@ void SIInsertWaitcnts::analyzeSingleBBLoopDSLoads(MachineLoop *ML) {
}
}

Info.TotalDSLoads = LoadPosition;
// Filter out flushed loads and renumber the remaining ones. Also compute the
// floor wait count, i.e. the number of loads still in flight after the wait
// established by a same-iteration use.
if (LastFlushedPosition > 0) {
DenseMap<unsigned, unsigned> NewMap;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Number of register units is a fixed size

for (auto &[RegUnit, Position] : Info.VGPRToLoadPosition) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto &[RegUnit, Position] : Info.VGPRToLoadPosition) {
for (auto [RegUnit, Position] : Info.VGPRToLoadPosition) {

if (Position > LastFlushedPosition) {
NewMap[RegUnit] = Position - LastFlushedPosition;
}
}
Info.VGPRToLoadPosition = std::move(NewMap);
// FloorWaitCount: when same-iteration use waits for load N, it leaves
// (TotalLoads - N) loads in flight. For the next iteration's WMMAs,
// any that only use flushed loads are already covered by this wait.
Info.FloorWaitCount = LoadPosition - LastFlushedPosition;
} else {
Info.FloorWaitCount = 0;
}

Info.TotalDSLoads = LoadPosition - LastFlushedPosition;
Info.Valid = Info.TotalDSLoads > 0;

LLVM_DEBUG(dbgs() << "Loop DS Wait Opt: Analyzed loop at ";
MBB->printName(dbgs());
dbgs() << " - " << Info.TotalDSLoads << " DS loads"
dbgs() << " - " << Info.TotalDSLoads << " DS loads, "
<< "FloorWaitCount=" << Info.FloorWaitCount
<< ", HasBarrier=" << (LastBarrier != nullptr)
<< ", Valid=" << Info.Valid << "\n");
}
Expand Down Expand Up @@ -2851,13 +2925,26 @@ SIInsertWaitcnts::getOptimalDSWaitCount(MachineBasicBlock *LoopHeader,
}
}

if (MaxLoadPosition == 0)
return std::nullopt;

// Optimal wait = TotalDSLoads - MaxLoadPosition
// This means we wait until all loads up to and including MaxLoadPosition
// have completed, but loads after it can still be in flight.
return Info.TotalDSLoads - MaxLoadPosition;
unsigned OptimalWait = Info.TotalDSLoads - MaxLoadPosition;

// If MaxLoadPosition == 0, this instruction only uses flushed loads
// (whose results are used in the same iteration). The same-iteration use
// will insert a wait that leaves FloorWaitCount loads in flight.
// So this instruction's needs are covered if OptimalWait >= FloorWaitCount.
// We return FloorWaitCount to indicate "can relax to this level".
if (MaxLoadPosition == 0 && Info.FloorWaitCount > 0) {
// All operands are from flushed loads - covered by same-iteration use's
// wait
return Info.FloorWaitCount;
}

if (MaxLoadPosition == 0)
return std::nullopt;

return OptimalWait;
}

// Try to apply DS loop wait optimization to relax conservative wait counts.
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-eligible.mir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# With opt: S_WAIT_DSCNT 12 (wait for only 4 loads, 12 remain in flight)
#
# DBG: Loop DS Wait Opt: Loop at bb.1 - 16 DS loads, 8 WMMA/MFMA, {{[0-9]+}} total insts, eligible
# DBG: Loop DS Wait Opt: Analyzed loop at bb.1 - 16 DS loads, HasBarrier=1, Valid=1
# DBG: Loop DS Wait Opt: Analyzed loop at bb.1 - 16 DS loads, FloorWaitCount=0, HasBarrier=1, Valid=1
# DBG: DS Loop Opt: Relaxing DsCnt from 0 to 12 for:
# DBG: DS Loop Opt: Inserted DS_CNT flush in preheader bb.0 for loop at bb.1

Expand Down
109 changes: 109 additions & 0 deletions llvm/test/CodeGen/AMDGPU/waitcnt-loop-ds-opt-no-improvement.mir
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=true -verify-machineinstrs -o - %s | FileCheck -check-prefix=OPT %s
# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1250 -run-pass=si-insert-waitcnts -amdgpu-waitcnt-loop-ds-opt=false -verify-machineinstrs -o - %s | FileCheck -check-prefix=NOOPT %s

# Test: preheader loads with slight reordering.
# The baseline is close to optimal, so the improvement is below the threshold of 4.
# Both OPT and NOOPT should produce the same wait count.

--- |
define amdgpu_kernel void @ds_loop_no_improvement() { ret void }
...

---
# OPT-LABEL: name: ds_loop_no_improvement
# NOOPT-LABEL: name: ds_loop_no_improvement
name: ds_loop_no_improvement
tracksRegLiveness: true
machineFunctionInfo:
isEntryFunction: true
waveLimiter: false
body: |
; Check preheader: neither OPT nor NOOPT adds flush (optimization doesn't apply)
; OPT: bb.0:
; OPT-NOT: S_WAIT_DSCNT
; OPT: S_BRANCH %bb.1

; NOOPT: bb.0:
; NOOPT-NOT: S_WAIT_DSCNT
; NOOPT: S_BRANCH %bb.1

bb.0:
successors: %bb.1
liveins: $sgpr0, $vgpr0

; Preheader: DS loads with slight reordering (loads #5 and #6 are pulled up to the front)
; This creates a small mismatch with loop body order, so baseline is
; close to optimal but not perfect. Improvement should be below threshold.
$vgpr26_vgpr27_vgpr28_vgpr29 = DS_READ_B128 $vgpr0, 64, 0, implicit $m0, implicit $exec
$vgpr30_vgpr31_vgpr32_vgpr33 = DS_READ_B128 $vgpr0, 80, 0, implicit $m0, implicit $exec
$vgpr10_vgpr11_vgpr12_vgpr13 = DS_READ_B128 $vgpr0, 0, 0, implicit $m0, implicit $exec
$vgpr14_vgpr15_vgpr16_vgpr17 = DS_READ_B128 $vgpr0, 16, 0, implicit $m0, implicit $exec
$vgpr18_vgpr19_vgpr20_vgpr21 = DS_READ_B128 $vgpr0, 32, 0, implicit $m0, implicit $exec
$vgpr22_vgpr23_vgpr24_vgpr25 = DS_READ_B128 $vgpr0, 48, 0, implicit $m0, implicit $exec
$vgpr34_vgpr35_vgpr36_vgpr37 = DS_READ_B128 $vgpr0, 96, 0, implicit $m0, implicit $exec
$vgpr38_vgpr39_vgpr40_vgpr41 = DS_READ_B128 $vgpr0, 112, 0, implicit $m0, implicit $exec
$vgpr42_vgpr43_vgpr44_vgpr45 = DS_READ_B128 $vgpr0, 128, 0, implicit $m0, implicit $exec
$vgpr46_vgpr47_vgpr48_vgpr49 = DS_READ_B128 $vgpr0, 144, 0, implicit $m0, implicit $exec
$vgpr50_vgpr51_vgpr52_vgpr53 = DS_READ_B128 $vgpr0, 160, 0, implicit $m0, implicit $exec
$vgpr54_vgpr55_vgpr56_vgpr57 = DS_READ_B128 $vgpr0, 176, 0, implicit $m0, implicit $exec
$vgpr58_vgpr59_vgpr60_vgpr61 = DS_READ_B128 $vgpr0, 192, 0, implicit $m0, implicit $exec
$vgpr62_vgpr63_vgpr64_vgpr65 = DS_READ_B128 $vgpr0, 208, 0, implicit $m0, implicit $exec
$vgpr66_vgpr67_vgpr68_vgpr69 = DS_READ_B128 $vgpr0, 224, 0, implicit $m0, implicit $exec
$vgpr70_vgpr71_vgpr72_vgpr73 = DS_READ_B128 $vgpr0, 240, 0, implicit $m0, implicit $exec
S_BRANCH %bb.1

; Preheader has loads #5,#6 first, then #1-4, then rest.
; First WMMA uses loads #1-4 (vgpr10-25), which are at positions 3-6 in preheader.
; Baseline produces S_WAIT_DSCNT 10 (wait for first 6 loads to complete).
; Optimal would be 12 (only need first 4 loads in loop body order).
; Improvement = 12 - 10 = 2, which is below threshold of 4, so no relaxation.
; OPT: bb.1:
; OPT: S_WAIT_DSCNT 10
; OPT-NEXT: early-clobber $vgpr80{{.*}} = V_WMMA

; NOOPT: bb.1:
; NOOPT: S_WAIT_DSCNT 10
; NOOPT-NEXT: early-clobber $vgpr80{{.*}} = V_WMMA

bb.1:
successors: %bb.1, %bb.2
liveins: $sgpr0, $vgpr0, $vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, $vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23_vgpr24_vgpr25, $vgpr26_vgpr27_vgpr28_vgpr29_vgpr30_vgpr31_vgpr32_vgpr33, $vgpr34_vgpr35_vgpr36_vgpr37_vgpr38_vgpr39_vgpr40_vgpr41, $vgpr42_vgpr43_vgpr44_vgpr45_vgpr46_vgpr47_vgpr48_vgpr49, $vgpr50_vgpr51_vgpr52_vgpr53_vgpr54_vgpr55_vgpr56_vgpr57, $vgpr58_vgpr59_vgpr60_vgpr61_vgpr62_vgpr63_vgpr64_vgpr65, $vgpr66_vgpr67_vgpr68_vgpr69_vgpr70_vgpr71_vgpr72_vgpr73, $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87, $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95

; First WMMA uses vgpr10-25 (loads #1-4)
early-clobber $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 8, $vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23_vgpr24_vgpr25, 8, killed $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr26_vgpr27_vgpr28_vgpr29_vgpr30_vgpr31_vgpr32_vgpr33, 8, $vgpr34_vgpr35_vgpr36_vgpr37_vgpr38_vgpr39_vgpr40_vgpr41, 8, killed $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr42_vgpr43_vgpr44_vgpr45_vgpr46_vgpr47_vgpr48_vgpr49, 8, $vgpr50_vgpr51_vgpr52_vgpr53_vgpr54_vgpr55_vgpr56_vgpr57, 8, killed $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr58_vgpr59_vgpr60_vgpr61_vgpr62_vgpr63_vgpr64_vgpr65, 8, $vgpr66_vgpr67_vgpr68_vgpr69_vgpr70_vgpr71_vgpr72_vgpr73, 8, killed $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 8, $vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23_vgpr24_vgpr25, 8, killed $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr26_vgpr27_vgpr28_vgpr29_vgpr30_vgpr31_vgpr32_vgpr33, 8, $vgpr34_vgpr35_vgpr36_vgpr37_vgpr38_vgpr39_vgpr40_vgpr41, 8, killed $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr42_vgpr43_vgpr44_vgpr45_vgpr46_vgpr47_vgpr48_vgpr49, 8, $vgpr50_vgpr51_vgpr52_vgpr53_vgpr54_vgpr55_vgpr56_vgpr57, 8, killed $vgpr80_vgpr81_vgpr82_vgpr83_vgpr84_vgpr85_vgpr86_vgpr87, 0, 0, 0, 0, implicit $exec
early-clobber $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95 = V_WMMA_F32_16X16X32_F16_w32_twoaddr 8, $vgpr58_vgpr59_vgpr60_vgpr61_vgpr62_vgpr63_vgpr64_vgpr65, 8, $vgpr66_vgpr67_vgpr68_vgpr69_vgpr70_vgpr71_vgpr72_vgpr73, 8, killed $vgpr88_vgpr89_vgpr90_vgpr91_vgpr92_vgpr93_vgpr94_vgpr95, 0, 0, 0, 0, implicit $exec

S_BARRIER

; Prefetch DS loads for next iteration (FORWARD order)
$vgpr10_vgpr11_vgpr12_vgpr13 = DS_READ_B128 $vgpr0, 0, 0, implicit $m0, implicit $exec
$vgpr14_vgpr15_vgpr16_vgpr17 = DS_READ_B128 $vgpr0, 16, 0, implicit $m0, implicit $exec
$vgpr18_vgpr19_vgpr20_vgpr21 = DS_READ_B128 $vgpr0, 32, 0, implicit $m0, implicit $exec
$vgpr22_vgpr23_vgpr24_vgpr25 = DS_READ_B128 $vgpr0, 48, 0, implicit $m0, implicit $exec
$vgpr26_vgpr27_vgpr28_vgpr29 = DS_READ_B128 $vgpr0, 64, 0, implicit $m0, implicit $exec
$vgpr30_vgpr31_vgpr32_vgpr33 = DS_READ_B128 $vgpr0, 80, 0, implicit $m0, implicit $exec
$vgpr34_vgpr35_vgpr36_vgpr37 = DS_READ_B128 $vgpr0, 96, 0, implicit $m0, implicit $exec
$vgpr38_vgpr39_vgpr40_vgpr41 = DS_READ_B128 $vgpr0, 112, 0, implicit $m0, implicit $exec
$vgpr42_vgpr43_vgpr44_vgpr45 = DS_READ_B128 $vgpr0, 128, 0, implicit $m0, implicit $exec
$vgpr46_vgpr47_vgpr48_vgpr49 = DS_READ_B128 $vgpr0, 144, 0, implicit $m0, implicit $exec
$vgpr50_vgpr51_vgpr52_vgpr53 = DS_READ_B128 $vgpr0, 160, 0, implicit $m0, implicit $exec
$vgpr54_vgpr55_vgpr56_vgpr57 = DS_READ_B128 $vgpr0, 176, 0, implicit $m0, implicit $exec
$vgpr58_vgpr59_vgpr60_vgpr61 = DS_READ_B128 $vgpr0, 192, 0, implicit $m0, implicit $exec
$vgpr62_vgpr63_vgpr64_vgpr65 = DS_READ_B128 $vgpr0, 208, 0, implicit $m0, implicit $exec
$vgpr66_vgpr67_vgpr68_vgpr69 = DS_READ_B128 $vgpr0, 224, 0, implicit $m0, implicit $exec
$vgpr70_vgpr71_vgpr72_vgpr73 = DS_READ_B128 $vgpr0, 240, 0, implicit $m0, implicit $exec

$sgpr0 = S_ADD_I32 $sgpr0, -1, implicit-def $scc
S_CBRANCH_SCC1 %bb.1, implicit $scc
S_BRANCH %bb.2

bb.2:
S_ENDPGM 0
...

Loading