diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 2ae028477ac7e..62c977fc96a89 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/InitializePasses.h"
@@ -121,8 +122,13 @@ enum RegisterMapping {
   SQ_MAX_PGM_VGPRS = 512, // Maximum programmable VGPRs across all targets.
   AGPR_OFFSET = 256,      // Maximum programmable ArchVGPRs across all targets.
   SQ_MAX_PGM_SGPRS = 256, // Maximum programmable SGPRs across all targets.
-  NUM_EXTRA_VGPRS = 1,    // A reserved slot for DS.
-  EXTRA_VGPR_LDS = 0,     // An artificial register to track LDS writes.
+  NUM_EXTRA_VGPRS = 9,    // Reserved slots for DS.
+  // Artificial register slots to track LDS writes into specific LDS locations
+  // if a location is known. When slots are exhausted or the location is
+  // unknown, use the first slot. The first slot is also always updated in
+  // addition to the known location's slot so that waits are generated
+  // correctly when a dependent instruction's location is unknown.
+  EXTRA_VGPR_LDS = 0,
   NUM_ALL_VGPRS = SQ_MAX_PGM_VGPRS + NUM_EXTRA_VGPRS, // Where SGPR starts.
 };
 
@@ -297,6 +303,10 @@ class WaitcntBrackets {
     PendingEvents |= WaitEventMaskForInst[VS_CNT];
   }
 
+  ArrayRef<const MachineInstr *> getLDSDMAStores() const {
+    return LDSDMAStores;
+  }
+
   void print(raw_ostream &);
   void dump() { print(dbgs()); }
 
@@ -359,6 +369,9 @@ class WaitcntBrackets {
   // Bitmask of the VmemTypes of VMEM instructions that might have a pending
   // write to each vgpr.
   unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
+  // Store representative LDS DMA operations. The only useful info here is
+  // alias info. One store is kept per unique AAInfo.
+  SmallVector<const MachineInstr *, NUM_EXTRA_VGPRS - 1> LDSDMAStores;
 };
 
 class SIInsertWaitcnts : public MachineFunctionPass {
@@ -373,6 +386,7 @@ class SIInsertWaitcnts : public MachineFunctionPass {
   DenseMap<MachineBasicBlock *, bool> PreheadersToFlush;
   MachineLoopInfo *MLI;
   MachinePostDominatorTree *PDT;
+  AliasAnalysis *AA = nullptr;
 
   struct BlockInfo {
     std::unique_ptr<WaitcntBrackets> Incoming;
@@ -415,6 +429,8 @@ class SIInsertWaitcnts : public MachineFunctionPass {
     AU.setPreservesCFG();
     AU.addRequired<MachineLoopInfo>();
     AU.addRequired<MachinePostDominatorTree>();
+    AU.addUsedIfAvailable<AAResultsWrapperPass>();
+    AU.addPreserved<AAResultsWrapperPass>();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
 
@@ -707,7 +723,40 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
         (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
       // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
       // written can be accessed. A load from LDS to VMEM does not need a wait.
-      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
+      unsigned Slot = 0;
+      for (const auto *MemOp : Inst.memoperands()) {
+        if (!MemOp->isStore() ||
+            MemOp->getAddrSpace() != AMDGPUAS::LOCAL_ADDRESS)
+          continue;
+        // Comparing just AA info does not guarantee memoperands are equal
+        // in general, but this is so for LDS DMA in practice.
+        auto AAI = MemOp->getAAInfo();
+        // Alias scope information gives a way to definitively identify the
+        // original memory object; in practice it is produced by the module
+        // LDS lowering pass. If there is no scope available we will not be
+        // able to disambiguate LDS aliasing, because after the module
+        // lowering all LDS is squashed into a single big object.
+        // Do not attempt to use one of the limited LDSDMAStores slots for
+        // something we will not be able to use anyway.
+        if (!AAI || !AAI.Scope)
+          break;
+        for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
+          for (const auto *MemOp : LDSDMAStores[I]->memoperands()) {
+            if (MemOp->isStore() && AAI == MemOp->getAAInfo()) {
+              Slot = I + 1;
+              break;
+            }
+          }
+        }
+        if (Slot || LDSDMAStores.size() == NUM_EXTRA_VGPRS - 1)
+          break;
+        LDSDMAStores.push_back(&Inst);
+        Slot = LDSDMAStores.size();
+        break;
+      }
+      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS + Slot, T, CurrScore);
+      if (Slot)
+        setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
     }
   }
 }
@@ -1183,9 +1232,27 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         // No need to wait before load from VMEM to LDS.
         if (TII->mayWriteLDSThroughDMA(MI))
          continue;
-        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+
         // VM_CNT is only relevant to vgpr or LDS.
-        ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
+        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+        bool FoundAliasingStore = false;
+        // Only objects with alias scope info were added to the LDSDMAStores
+        // array. In the absence of scope info we will not be able to
+        // disambiguate aliasing here, so there is no need to search for a
+        // corresponding store slot. This is conservatively correct because in
+        // that case we will produce a wait using the first (general) LDS DMA
+        // wait slot which will wait on all of them anyway.
+        if (Ptr && Memop->getAAInfo() && Memop->getAAInfo().Scope) {
+          const auto &LDSDMAStores = ScoreBrackets.getLDSDMAStores();
+          for (unsigned I = 0, E = LDSDMAStores.size(); I != E; ++I) {
+            if (MI.mayAlias(AA, *LDSDMAStores[I], true)) {
+              FoundAliasingStore = true;
+              ScoreBrackets.determineWait(VM_CNT, RegNo + I + 1, Wait);
+            }
+          }
+        }
+        if (!FoundAliasingStore)
+          ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
         if (Memop->isStore()) {
           ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
         }
@@ -1834,6 +1901,8 @@ bool SIInsertWaitcnts::runOnMachineFunction(MachineFunction &MF) {
   const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
   MLI = &getAnalysis<MachineLoopInfo>();
   PDT = &getAnalysis<MachinePostDominatorTree>();
+  if (auto AAR = getAnalysisIfAvailable<AAResultsWrapperPass>())
+    AA = &AAR->getAAResults();
 
   ForceEmitZeroWaitcnts = ForceEmitZeroFlag;
   for (auto T : inst_counter_types())
diff --git a/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll b/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
index 1d6968a86a2e2..5cb3ca0b80b66 100644
--- a/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
+++ b/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
@@ -3,21 +3,23 @@
 
 @lds.0 = internal addrspace(3) global [64 x float] poison, align 16
 @lds.1 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.2 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.3 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.4 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.5 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.6 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.7 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.8 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.9 = internal addrspace(3) global [64 x float] poison, align 16
 
 declare void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) nocapture, i32 %size, i32 %voffset, i32 %soffset, i32 %offset, i32 %aux)
 declare void @llvm.amdgcn.global.load.lds(ptr addrspace(1) nocapture %gptr, ptr addrspace(3) nocapture %lptr, i32 %size, i32 %offset, i32 %aux)
 
-; FIXME: vmcnt(0) is too strong, it shall use vmcnt(2) before the first
-; ds_read_b32 and vmcnt(0) before the second.
-
 ; GCN-LABEL: {{^}}buffer_load_lds_dword_2_arrays:
 ; GCN-COUNT-4: buffer_load_dword
-; GCN: s_waitcnt vmcnt(0)
+; GCN: s_waitcnt vmcnt(2)
 ; GCN: ds_read_b32
-
-; FIXME:
-; GCN-NOT: s_waitcnt
-
+; GCN: s_waitcnt vmcnt(0)
 ; GCN: ds_read_b32
 define amdgpu_kernel void @buffer_load_lds_dword_2_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:
@@ -43,15 +45,9 @@ main_body:
 ; GCN-COUNT-4: global_load_dword
 ; GFX9: s_waitcnt vmcnt(0)
 ; GFX9-COUNT-2: ds_read_b32
-
-; FIXME: can be vmcnt(2)
-
-; GFX10: s_waitcnt vmcnt(0)
+; GFX10: s_waitcnt vmcnt(2)
 ; GFX10: ds_read_b32
-
-; FIXME:
-; GFX10-NOT: s_waitcnt
-
+; GFX10: s_waitcnt vmcnt(0)
 ; GFX10: ds_read_b32
 define amdgpu_kernel void @global_load_lds_dword_2_arrays(ptr addrspace(1) nocapture %gptr, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:
@@ -70,4 +66,89 @@ main_body:
   ret void
 }
 
+; There are 8 pseudo registers defined to track LDS DMA dependencies.
+; When exhausted we default to vmcnt(0).
+
+; GCN-LABEL: {{^}}buffer_load_lds_dword_10_arrays:
+; GCN-COUNT-10: buffer_load_dword
+; GCN: s_waitcnt vmcnt(8)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(7)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(6)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(5)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(4)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(3)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(2)
+; GCN-NOT: s_waitcnt vmcnt
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(0)
+; GCN: ds_read_b32
+define amdgpu_kernel void @buffer_load_lds_dword_10_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, i32 %i3, i32 %i4, i32 %i5, i32 %i6, i32 %i7, i32 %i8, i32 %i9, ptr addrspace(1) %out) {
+main_body:
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.0, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.1, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.2, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.3, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.4, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.5, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.6, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.7, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.8, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.9, i32 4, i32 0, i32 0, i32 0, i32 0)
+  %gep.0 = getelementptr float, ptr addrspace(3) @lds.0, i32 %i1
+  %gep.1 = getelementptr float, ptr addrspace(3) @lds.1, i32 %i2
+  %gep.2 = getelementptr float, ptr addrspace(3) @lds.2, i32 %i2
+  %gep.3 = getelementptr float, ptr addrspace(3) @lds.3, i32 %i2
+  %gep.4 = getelementptr float, ptr addrspace(3) @lds.4, i32 %i2
+  %gep.5 = getelementptr float, ptr addrspace(3) @lds.5, i32 %i2
+  %gep.6 = getelementptr float, ptr addrspace(3) @lds.6, i32 %i2
+  %gep.7 = getelementptr float, ptr addrspace(3) @lds.7, i32 %i2
+  %gep.8 = getelementptr float, ptr addrspace(3) @lds.8, i32 %i2
+  %gep.9 = getelementptr float, ptr addrspace(3) @lds.9, i32 %i2
+  %val.0 = load float, ptr addrspace(3) %gep.0, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.1 = load float, ptr addrspace(3) %gep.1, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.2 = load float, ptr addrspace(3) %gep.2, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.3 = load float, ptr addrspace(3) %gep.3, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.4 = load float, ptr addrspace(3) %gep.4, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.5 = load float, ptr addrspace(3) %gep.5, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.6 = load float, ptr addrspace(3) %gep.6, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.7 = load float, ptr addrspace(3) %gep.7, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.8 = load float, ptr addrspace(3) %gep.8, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.9 = load float, ptr addrspace(3) %gep.9, align 4
+  %out.gep.1 = getelementptr float, ptr addrspace(1) %out, i32 1
+  %out.gep.2 = getelementptr float, ptr addrspace(1) %out, i32 2
+  %out.gep.3 = getelementptr float, ptr addrspace(1) %out, i32 3
+  %out.gep.4 = getelementptr float, ptr addrspace(1) %out, i32 4
+  %out.gep.5 = getelementptr float, ptr addrspace(1) %out, i32 5
+  %out.gep.6 = getelementptr float, ptr addrspace(1) %out, i32 6
+  %out.gep.7 = getelementptr float, ptr addrspace(1) %out, i32 7
+  %out.gep.8 = getelementptr float, ptr addrspace(1) %out, i32 8
+  %out.gep.9 = getelementptr float, ptr addrspace(1) %out, i32 9
+  store float %val.0, ptr addrspace(1) %out
+  store float %val.1, ptr addrspace(1) %out.gep.1
+  store float %val.2, ptr addrspace(1) %out.gep.2
+  store float %val.3, ptr addrspace(1) %out.gep.3
+  store float %val.4, ptr addrspace(1) %out.gep.4
+  store float %val.5, ptr addrspace(1) %out.gep.5
+  store float %val.6, ptr addrspace(1) %out.gep.6
+  store float %val.7, ptr addrspace(1) %out.gep.7
+  store float %val.8, ptr addrspace(1) %out.gep.8
+  store float %val.9, ptr addrspace(1) %out.gep.9
+  ret void
+}
+
 declare void @llvm.amdgcn.wave.barrier()
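
For reference, the disambiguation above depends on scoped-alias metadata (!alias.scope / !noalias) reaching the LDS DMA memoperands; as the in-code comments note, in practice the module LDS lowering pass produces it. The fragment below is a hand-written sketch rather than output of that pass: the metadata node names and numbering are invented, %rsrc is assumed to be defined, and the exact metadata the lowering emits may differ. It only illustrates the shape of metadata under which MachineInstr::mayAlias can prove that a read from @lds.0 cannot alias the DMA store into @lds.1, so only the slot tracking @lds.0 needs a wait.

  ; Each DMA store and each ds read carries the scope of its own LDS object
  ; and a noalias list naming the other object's scope.
  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.0, i32 4, i32 0, i32 0, i32 0, i32 0), !alias.scope !3, !noalias !4
  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.1, i32 4, i32 0, i32 0, i32 0, i32 0), !alias.scope !4, !noalias !3
  %v0 = load float, ptr addrspace(3) @lds.0, align 4, !alias.scope !3, !noalias !4

; One alias domain for the kernel's LDS, one scope per original LDS global
; (names and numbering are illustrative only).
!0 = distinct !{!0, !"kernel.lds"}
!1 = distinct !{!1, !0, !"lds.0"}
!2 = distinct !{!2, !0, !"lds.1"}
!3 = !{!1} ; scope list containing only lds.0
!4 = !{!2} ; scope list containing only lds.1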