diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp index 3f9a1f492ace5..76bfce8c0f6f9 100644 --- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp +++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp @@ -418,15 +418,14 @@ class WaitcntGeneratorGFX12Plus : public WaitcntGenerator { class SIInsertWaitcnts { public: const GCNSubtarget *ST; + const SIInstrInfo *TII = nullptr; + const SIRegisterInfo *TRI = nullptr; + const MachineRegisterInfo *MRI = nullptr; InstCounterType SmemAccessCounter; InstCounterType MaxCounter; const unsigned *WaitEventMaskForInst; private: - const SIInstrInfo *TII = nullptr; - const SIRegisterInfo *TRI = nullptr; - const MachineRegisterInfo *MRI = nullptr; - DenseMap SLoadAddresses; DenseMap PreheadersToFlush; MachineLoopInfo *MLI; @@ -631,8 +630,6 @@ class WaitcntBrackets { bool merge(const WaitcntBrackets &Other); RegInterval getRegInterval(const MachineInstr *MI, - const MachineRegisterInfo *MRI, - const SIRegisterInfo *TRI, const MachineOperand &Op) const; bool counterOutOfOrder(InstCounterType T) const; @@ -650,9 +647,7 @@ class WaitcntBrackets { void applyWaitcnt(const AMDGPU::Waitcnt &Wait); void applyWaitcnt(InstCounterType T, unsigned Count); void applyXcnt(const AMDGPU::Waitcnt &Wait); - void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, WaitEventType E, - MachineInstr &MI); + void updateByEvent(WaitEventType E, MachineInstr &MI); unsigned hasPendingEvent() const { return PendingEvents; } unsigned hasPendingEvent(WaitEventType E) const { @@ -761,10 +756,8 @@ class WaitcntBrackets { void setScoreByInterval(RegInterval Interval, InstCounterType CntTy, unsigned Score); - void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, - const MachineOperand &Op, InstCounterType CntTy, - unsigned Val); + void setScoreByOperand(const MachineInstr *MI, const MachineOperand &Op, + InstCounterType CntTy, unsigned Val); const SIInsertWaitcnts *Context; @@ -821,12 +814,13 @@ class SIInsertWaitcntsLegacy : public MachineFunctionPass { } // end anonymous namespace RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI, - const MachineRegisterInfo *MRI, - const SIRegisterInfo *TRI, const MachineOperand &Op) const { if (Op.getReg() == AMDGPU::SCC) return {SCC, SCC + 1}; + const SIRegisterInfo *TRI = Context->TRI; + const MachineRegisterInfo *MRI = Context->MRI; + if (!TRI->isInAllocatableClass(Op.getReg())) return {-1, -1}; @@ -891,11 +885,9 @@ void WaitcntBrackets::setScoreByInterval(RegInterval Interval, } void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI, - const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, const MachineOperand &Op, InstCounterType CntTy, unsigned Score) { - RegInterval Interval = getRegInterval(MI, MRI, TRI, Op); + RegInterval Interval = getRegInterval(MI, Op); setScoreByInterval(Interval, CntTy, Score); } @@ -927,10 +919,7 @@ bool WaitcntBrackets::hasPointSamplePendingVmemTypes( return hasOtherPendingVmemTypes(Interval, VMEM_NOSAMPLER); } -void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, - const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, - WaitEventType E, MachineInstr &Inst) { +void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) { InstCounterType T = eventCounter(Context->WaitEventMaskForInst, E); unsigned UB = getScoreUB(T); @@ -943,6 +932,10 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, PendingEvents |= 1 << E; setScoreUB(T, CurrScore); + const SIRegisterInfo *TRI = Context->TRI; + const MachineRegisterInfo *MRI = Context->MRI; + const SIInstrInfo *TII = Context->TII; + if (T == EXP_CNT) { // Put score on the source vgprs. If this is a store, just use those // specific register(s). @@ -950,59 +943,56 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, // All GDS operations must protect their address register (same as // export.) if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr)) - setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, *AddrOp, EXP_CNT, CurrScore); if (Inst.mayStore()) { if (const auto *Data0 = TII->getNamedOperand(Inst, AMDGPU::OpName::data0)) - setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, *Data0, EXP_CNT, CurrScore); if (const auto *Data1 = TII->getNamedOperand(Inst, AMDGPU::OpName::data1)) - setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, *Data1, EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) && Inst.getOpcode() != AMDGPU::DS_APPEND && Inst.getOpcode() != AMDGPU::DS_CONSUME && Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) { for (const MachineOperand &Op : Inst.all_uses()) { if (TRI->isVectorRegister(*MRI, Op.getReg())) - setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, Op, EXP_CNT, CurrScore); } } } else if (TII->isFLAT(Inst)) { if (Inst.mayStore()) { - setScoreByOperand(&Inst, TRI, MRI, + setScoreByOperand(&Inst, *TII->getNamedOperand(Inst, AMDGPU::OpName::data), EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setScoreByOperand(&Inst, TRI, MRI, + setScoreByOperand(&Inst, *TII->getNamedOperand(Inst, AMDGPU::OpName::data), EXP_CNT, CurrScore); } } else if (TII->isMIMG(Inst)) { if (Inst.mayStore()) { - setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, - CurrScore); + setScoreByOperand(&Inst, Inst.getOperand(0), EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setScoreByOperand(&Inst, TRI, MRI, + setScoreByOperand(&Inst, *TII->getNamedOperand(Inst, AMDGPU::OpName::data), EXP_CNT, CurrScore); } } else if (TII->isMTBUF(Inst)) { if (Inst.mayStore()) - setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, - CurrScore); + setScoreByOperand(&Inst, Inst.getOperand(0), EXP_CNT, CurrScore); } else if (TII->isMUBUF(Inst)) { if (Inst.mayStore()) { - setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, - CurrScore); + setScoreByOperand(&Inst, Inst.getOperand(0), EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setScoreByOperand(&Inst, TRI, MRI, + setScoreByOperand(&Inst, *TII->getNamedOperand(Inst, AMDGPU::OpName::data), EXP_CNT, CurrScore); } } else if (TII->isLDSDIR(Inst)) { // LDSDIR instructions attach the score to the destination. - setScoreByOperand(&Inst, TRI, MRI, + setScoreByOperand(&Inst, *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), EXP_CNT, CurrScore); } else { @@ -1013,18 +1003,18 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, // score. for (MachineOperand &DefMO : Inst.all_defs()) { if (TRI->isVGPR(*MRI, DefMO.getReg())) { - setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, DefMO, EXP_CNT, CurrScore); } } } for (const MachineOperand &Op : Inst.all_uses()) { if (TRI->isVectorRegister(*MRI, Op.getReg())) - setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore); + setScoreByOperand(&Inst, Op, EXP_CNT, CurrScore); } } } else if (T == X_CNT) { for (const MachineOperand &Op : Inst.all_uses()) - setScoreByOperand(&Inst, TRI, MRI, Op, T, CurrScore); + setScoreByOperand(&Inst, Op, T, CurrScore); } else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ { // Match the score to the destination registers. // @@ -1036,7 +1026,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, // Special cases where implicit register defs exists, such as M0 or VCC, // but none with memory instructions. for (const MachineOperand &Op : Inst.defs()) { - RegInterval Interval = getRegInterval(&Inst, MRI, TRI, Op); + RegInterval Interval = getRegInterval(&Inst, Op); if (T == LOAD_CNT || T == SAMPLE_CNT || T == BVH_CNT) { if (Interval.first >= NUM_ALL_VGPRS) continue; @@ -1928,7 +1918,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, const auto &CallAddrOp = *TII->getNamedOperand(MI, AMDGPU::OpName::src0); if (CallAddrOp.isReg()) { RegInterval CallAddrOpInterval = - ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp); + ScoreBrackets.getRegInterval(&MI, CallAddrOp); ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval, Wait); @@ -1936,7 +1926,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, if (const auto *RtnAddrOp = TII->getNamedOperand(MI, AMDGPU::OpName::dst)) { RegInterval RtnAddrOpInterval = - ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp); + ScoreBrackets.getRegInterval(&MI, *RtnAddrOp); ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval, Wait); @@ -2000,7 +1990,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, if (Op.isTied() && Op.isUse() && TII->doesNotReadTiedSource(MI)) continue; - RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op); + RegInterval Interval = ScoreBrackets.getRegInterval(&MI, Op); const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg()); if (IsVGPR) { @@ -2237,16 +2227,15 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst, if (TII->isDS(Inst) && TII->usesLGKM_CNT(Inst)) { if (TII->isAlwaysGDS(Inst.getOpcode()) || TII->hasModifiersSet(Inst, AMDGPU::OpName::gds)) { - ScoreBrackets->updateByEvent(TII, TRI, MRI, GDS_ACCESS, Inst); - ScoreBrackets->updateByEvent(TII, TRI, MRI, GDS_GPR_LOCK, Inst); + ScoreBrackets->updateByEvent(GDS_ACCESS, Inst); + ScoreBrackets->updateByEvent(GDS_GPR_LOCK, Inst); ScoreBrackets->setPendingGDS(); } else { - ScoreBrackets->updateByEvent(TII, TRI, MRI, LDS_ACCESS, Inst); + ScoreBrackets->updateByEvent(LDS_ACCESS, Inst); } } else if (TII->isFLAT(Inst)) { if (SIInstrInfo::isGFX12CacheInvOrWBInst(Inst.getOpcode())) { - ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst), - Inst); + ScoreBrackets->updateByEvent(getVmemWaitEventType(Inst), Inst); return; } @@ -2257,13 +2246,12 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst, if (TII->mayAccessVMEMThroughFlat(Inst)) { ++FlatASCount; IsVMEMAccess = true; - ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst), - Inst); + ScoreBrackets->updateByEvent(getVmemWaitEventType(Inst), Inst); } if (TII->mayAccessLDSThroughFlat(Inst)) { ++FlatASCount; - ScoreBrackets->updateByEvent(TII, TRI, MRI, LDS_ACCESS, Inst); + ScoreBrackets->updateByEvent(LDS_ACCESS, Inst); } // This is a flat memory operation that access both VMEM and LDS, so note it @@ -2274,16 +2262,15 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst, } else if (SIInstrInfo::isVMEM(Inst) && !llvm::AMDGPU::getMUBUFIsBufferInv(Inst.getOpcode())) { IsVMEMAccess = true; - ScoreBrackets->updateByEvent(TII, TRI, MRI, getVmemWaitEventType(Inst), - Inst); + ScoreBrackets->updateByEvent(getVmemWaitEventType(Inst), Inst); if (ST->vmemWriteNeedsExpWaitcnt() && (Inst.mayStore() || SIInstrInfo::isAtomicRet(Inst))) { - ScoreBrackets->updateByEvent(TII, TRI, MRI, VMW_GPR_LOCK, Inst); + ScoreBrackets->updateByEvent(VMW_GPR_LOCK, Inst); } } else if (TII->isSMRD(Inst)) { IsSMEMAccess = true; - ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_ACCESS, Inst); + ScoreBrackets->updateByEvent(SMEM_ACCESS, Inst); } else if (Inst.isCall()) { if (callWaitsOnFunctionReturn(Inst)) { // Act as a wait on everything @@ -2295,33 +2282,33 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst, ScoreBrackets->applyWaitcnt(AMDGPU::Waitcnt()); } } else if (SIInstrInfo::isLDSDIR(Inst)) { - ScoreBrackets->updateByEvent(TII, TRI, MRI, EXP_LDS_ACCESS, Inst); + ScoreBrackets->updateByEvent(EXP_LDS_ACCESS, Inst); } else if (TII->isVINTERP(Inst)) { int64_t Imm = TII->getNamedOperand(Inst, AMDGPU::OpName::waitexp)->getImm(); ScoreBrackets->applyWaitcnt(EXP_CNT, Imm); } else if (SIInstrInfo::isEXP(Inst)) { unsigned Imm = TII->getNamedOperand(Inst, AMDGPU::OpName::tgt)->getImm(); if (Imm >= AMDGPU::Exp::ET_PARAM0 && Imm <= AMDGPU::Exp::ET_PARAM31) - ScoreBrackets->updateByEvent(TII, TRI, MRI, EXP_PARAM_ACCESS, Inst); + ScoreBrackets->updateByEvent(EXP_PARAM_ACCESS, Inst); else if (Imm >= AMDGPU::Exp::ET_POS0 && Imm <= AMDGPU::Exp::ET_POS_LAST) - ScoreBrackets->updateByEvent(TII, TRI, MRI, EXP_POS_ACCESS, Inst); + ScoreBrackets->updateByEvent(EXP_POS_ACCESS, Inst); else - ScoreBrackets->updateByEvent(TII, TRI, MRI, EXP_GPR_LOCK, Inst); + ScoreBrackets->updateByEvent(EXP_GPR_LOCK, Inst); } else if (SIInstrInfo::isSBarrierSCCWrite(Inst.getOpcode())) { - ScoreBrackets->updateByEvent(TII, TRI, MRI, SCC_WRITE, Inst); + ScoreBrackets->updateByEvent(SCC_WRITE, Inst); } else { switch (Inst.getOpcode()) { case AMDGPU::S_SENDMSG: case AMDGPU::S_SENDMSG_RTN_B32: case AMDGPU::S_SENDMSG_RTN_B64: case AMDGPU::S_SENDMSGHALT: - ScoreBrackets->updateByEvent(TII, TRI, MRI, SQ_MESSAGE, Inst); + ScoreBrackets->updateByEvent(SQ_MESSAGE, Inst); break; case AMDGPU::S_MEMTIME: case AMDGPU::S_MEMREALTIME: case AMDGPU::S_GET_BARRIER_STATE_M0: case AMDGPU::S_GET_BARRIER_STATE_IMM: - ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_ACCESS, Inst); + ScoreBrackets->updateByEvent(SMEM_ACCESS, Inst); break; } } @@ -2330,10 +2317,10 @@ void SIInsertWaitcnts::updateEventWaitcntAfter(MachineInstr &Inst, return; if (IsVMEMAccess) - ScoreBrackets->updateByEvent(TII, TRI, MRI, VMEM_GROUP, Inst); + ScoreBrackets->updateByEvent(VMEM_GROUP, Inst); if (IsSMEMAccess) - ScoreBrackets->updateByEvent(TII, TRI, MRI, SMEM_GROUP, Inst); + ScoreBrackets->updateByEvent(SMEM_GROUP, Inst); } bool WaitcntBrackets::mergeScore(const MergeInfo &M, unsigned &Score, @@ -2637,7 +2624,7 @@ bool SIInsertWaitcnts::shouldFlushVmCnt(MachineLoop *ML, for (const MachineOperand &Op : MI.all_uses()) { if (Op.isDebug() || !TRI->isVectorRegister(*MRI, Op.getReg())) continue; - RegInterval Interval = Brackets.getRegInterval(&MI, MRI, TRI, Op); + RegInterval Interval = Brackets.getRegInterval(&MI, Op); // Vgpr use for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { // If we find a register that is loaded inside the loop, 1. and 2. @@ -2662,7 +2649,7 @@ bool SIInsertWaitcnts::shouldFlushVmCnt(MachineLoop *ML, // VMem load vgpr def if (isVMEMOrFlatVMEM(MI) && MI.mayLoad()) { for (const MachineOperand &Op : MI.all_defs()) { - RegInterval Interval = Brackets.getRegInterval(&MI, MRI, TRI, Op); + RegInterval Interval = Brackets.getRegInterval(&MI, Op); for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { // If we find a register that is loaded inside the loop, 1. and 2. // are invalidated and we can exit.