diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 9eaef3cd37397a..5f705f2964457f 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -2490,135 +2490,24 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF, } } -// Returns true if this is the sext.w pattern, addiw rd, rs1, 0. -bool RISCV::isSEXT_W(const MachineInstr &MI) { - return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() && - MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0; -} - -// Returns true if this is the zext.w pattern, adduw rd, rs1, x0. -bool RISCV::isZEXT_W(const MachineInstr &MI) { - return MI.getOpcode() == RISCV::ADD_UW && MI.getOperand(1).isReg() && - MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0; -} - -// Returns true if this is the zext.b pattern, andi rd, rs1, 255. -bool RISCV::isZEXT_B(const MachineInstr &MI) { - return MI.getOpcode() == RISCV::ANDI && MI.getOperand(1).isReg() && - MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 255; -} - -static bool isRVVWholeLoadStore(unsigned Opcode) { - switch (Opcode) { - default: - return false; - case RISCV::VS1R_V: - case RISCV::VS2R_V: - case RISCV::VS4R_V: - case RISCV::VS8R_V: - case RISCV::VL1RE8_V: - case RISCV::VL2RE8_V: - case RISCV::VL4RE8_V: - case RISCV::VL8RE8_V: - case RISCV::VL1RE16_V: - case RISCV::VL2RE16_V: - case RISCV::VL4RE16_V: - case RISCV::VL8RE16_V: - case RISCV::VL1RE32_V: - case RISCV::VL2RE32_V: - case RISCV::VL4RE32_V: - case RISCV::VL8RE32_V: - case RISCV::VL1RE64_V: - case RISCV::VL2RE64_V: - case RISCV::VL4RE64_V: - case RISCV::VL8RE64_V: - return true; - } -} - -bool RISCV::isRVVSpill(const MachineInstr &MI) { - // RVV lacks any support for immediate addressing for stack addresses, so be - // conservative. - unsigned Opcode = MI.getOpcode(); - if (!RISCVVPseudosTable::getPseudoInfo(Opcode) && - !isRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode)) - return false; - return true; -} - -std::optional> -RISCV::isRVVSpillForZvlsseg(unsigned Opcode) { - switch (Opcode) { - default: - return std::nullopt; - case RISCV::PseudoVSPILL2_M1: - case RISCV::PseudoVRELOAD2_M1: - return std::make_pair(2u, 1u); - case RISCV::PseudoVSPILL2_M2: - case RISCV::PseudoVRELOAD2_M2: - return std::make_pair(2u, 2u); - case RISCV::PseudoVSPILL2_M4: - case RISCV::PseudoVRELOAD2_M4: - return std::make_pair(2u, 4u); - case RISCV::PseudoVSPILL3_M1: - case RISCV::PseudoVRELOAD3_M1: - return std::make_pair(3u, 1u); - case RISCV::PseudoVSPILL3_M2: - case RISCV::PseudoVRELOAD3_M2: - return std::make_pair(3u, 2u); - case RISCV::PseudoVSPILL4_M1: - case RISCV::PseudoVRELOAD4_M1: - return std::make_pair(4u, 1u); - case RISCV::PseudoVSPILL4_M2: - case RISCV::PseudoVRELOAD4_M2: - return std::make_pair(4u, 2u); - case RISCV::PseudoVSPILL5_M1: - case RISCV::PseudoVRELOAD5_M1: - return std::make_pair(5u, 1u); - case RISCV::PseudoVSPILL6_M1: - case RISCV::PseudoVRELOAD6_M1: - return std::make_pair(6u, 1u); - case RISCV::PseudoVSPILL7_M1: - case RISCV::PseudoVRELOAD7_M1: - return std::make_pair(7u, 1u); - case RISCV::PseudoVSPILL8_M1: - case RISCV::PseudoVRELOAD8_M1: - return std::make_pair(8u, 1u); - } -} - -bool RISCV::isFaultFirstLoad(const MachineInstr &MI) { - return MI.getNumExplicitDefs() == 2 && MI.modifiesRegister(RISCV::VL) && - !MI.isInlineAsm(); -} - -bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) { - int16_t MI1FrmOpIdx = - RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm); - int16_t MI2FrmOpIdx = - RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm); - if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0) - return false; - MachineOperand FrmOp1 = MI1.getOperand(MI1FrmOpIdx); - MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx); - return FrmOp1.getImm() == FrmOp2.getImm(); -} - -// Checks if all users only demand the lower word of the original instruction's -// result. +// Checks if all users only demand the lower \p OrigBits of the original +// instruction's result. // TODO: handle multiple interdependent transformations -bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, - const MachineRegisterInfo &MRI) { +bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI, + const MachineRegisterInfo &MRI, + unsigned OrigBits) const { - SmallPtrSet Visited; - SmallVector Worklist; + SmallSet, 4> Visited; + SmallVector, 4> Worklist; - Worklist.push_back(&OrigMI); + Worklist.push_back(std::make_pair(&OrigMI, OrigBits)); while (!Worklist.empty()) { - const MachineInstr *MI = Worklist.pop_back_val(); + auto P = Worklist.pop_back_val(); + const MachineInstr *MI = P.first; + unsigned Bits = P.second; - if (!Visited.insert(MI).second) + if (!Visited.insert(P).second) continue; // Only handle instructions with one def. @@ -2654,7 +2543,6 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, case RISCV::CTZW: case RISCV::CPOPW: case RISCV::SLLI_UW: - case RISCV::FMV_H_X: case RISCV::FMV_W_X: case RISCV::FCVT_H_W: case RISCV::FCVT_H_WU: @@ -2662,40 +2550,59 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, case RISCV::FCVT_S_WU: case RISCV::FCVT_D_W: case RISCV::FCVT_D_WU: + if (Bits >= 32) + break; + return false; case RISCV::SEXT_B: + case RISCV::PACKH: + if (Bits >= 8) + break; + return false; case RISCV::SEXT_H: + case RISCV::FMV_H_X: + case RISCV::ZEXT_H_RV32: case RISCV::ZEXT_H_RV64: - case RISCV::PACK: - case RISCV::PACKH: case RISCV::PACKW: - break; + if (Bits >= 16) + break; + return false; + + case RISCV::PACK: + if (Bits >= (STI.getXLen() / 2)) + break; + return false; // these overwrite higher input bits, otherwise the lower word of output // depends only on the lower word of input. So check their uses read W. case RISCV::SLLI: - if (UserMI->getOperand(2).getImm() >= 32) + if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm())) break; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::ANDI: - if (isUInt<11>(UserMI->getOperand(2).getImm())) + if (Bits >= + (64 - countLeadingZeros((uint64_t)UserMI->getOperand(2).getImm()))) break; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::ORI: - if (!isUInt<11>(UserMI->getOperand(2).getImm())) + if (Bits >= + (64 - countLeadingOnes((uint64_t)UserMI->getOperand(2).getImm()))) break; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::SLL: case RISCV::BSET: case RISCV::BCLR: case RISCV::BINV: - // Operand 2 is the shift amount which uses 6 bits. - if (OpIdx == 2) - break; - Worklist.push_back(UserMI); + // Operand 2 is the shift amount which uses log2(xlen) bits. + if (OpIdx == 2) { + if (Bits >= Log2_32(STI.getXLen())) + break; + return false; + } + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::SRA: @@ -2703,7 +2610,7 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, case RISCV::ROL: case RISCV::ROR: // Operand 2 is the shift amount which uses 6 bits. - if (OpIdx == 2) + if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen())) break; return false; @@ -2712,23 +2619,31 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, case RISCV::SH2ADD_UW: case RISCV::SH3ADD_UW: // Operand 1 is implicitly zero extended. - if (OpIdx == 1) + if (OpIdx == 1 && Bits >= 32) break; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::BEXTI: - if (UserMI->getOperand(2).getImm() >= 32) + if (UserMI->getOperand(2).getImm() >= Bits) return false; break; case RISCV::SB: + // The first argument is the value to store. + if (OpIdx == 0 && Bits >= 8) + break; + return false; case RISCV::SH: + // The first argument is the value to store. + if (OpIdx == 0 && Bits >= 16) + break; + return false; case RISCV::SW: // The first argument is the value to store. - if (OpIdx != 0) - return false; - break; + if (OpIdx == 0 && Bits >= 32) + break; + return false; // For these, lower word of output in these operations, depends only on // the lower word of input. So, we check all uses only read lower word. @@ -2756,7 +2671,7 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, case RISCV::BSETI: case RISCV::BCLRI: case RISCV::BINVI: - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::PseudoCCMOVGPR: @@ -2765,14 +2680,14 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, // of operand 4 and 5 is used. if (OpIdx != 4 && OpIdx != 5) return false; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; case RISCV::VT_MASKC: case RISCV::VT_MASKCN: if (OpIdx != 1) return false; - Worklist.push_back(UserMI); + Worklist.push_back(std::make_pair(UserMI, Bits)); break; } } @@ -2780,3 +2695,117 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI, return true; } + +// Returns true if this is the sext.w pattern, addiw rd, rs1, 0. +bool RISCV::isSEXT_W(const MachineInstr &MI) { + return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() && + MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0; +} + +// Returns true if this is the zext.w pattern, adduw rd, rs1, x0. +bool RISCV::isZEXT_W(const MachineInstr &MI) { + return MI.getOpcode() == RISCV::ADD_UW && MI.getOperand(1).isReg() && + MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0; +} + +// Returns true if this is the zext.b pattern, andi rd, rs1, 255. +bool RISCV::isZEXT_B(const MachineInstr &MI) { + return MI.getOpcode() == RISCV::ANDI && MI.getOperand(1).isReg() && + MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 255; +} + +static bool isRVVWholeLoadStore(unsigned Opcode) { + switch (Opcode) { + default: + return false; + case RISCV::VS1R_V: + case RISCV::VS2R_V: + case RISCV::VS4R_V: + case RISCV::VS8R_V: + case RISCV::VL1RE8_V: + case RISCV::VL2RE8_V: + case RISCV::VL4RE8_V: + case RISCV::VL8RE8_V: + case RISCV::VL1RE16_V: + case RISCV::VL2RE16_V: + case RISCV::VL4RE16_V: + case RISCV::VL8RE16_V: + case RISCV::VL1RE32_V: + case RISCV::VL2RE32_V: + case RISCV::VL4RE32_V: + case RISCV::VL8RE32_V: + case RISCV::VL1RE64_V: + case RISCV::VL2RE64_V: + case RISCV::VL4RE64_V: + case RISCV::VL8RE64_V: + return true; + } +} + +bool RISCV::isRVVSpill(const MachineInstr &MI) { + // RVV lacks any support for immediate addressing for stack addresses, so be + // conservative. + unsigned Opcode = MI.getOpcode(); + if (!RISCVVPseudosTable::getPseudoInfo(Opcode) && + !isRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode)) + return false; + return true; +} + +std::optional> +RISCV::isRVVSpillForZvlsseg(unsigned Opcode) { + switch (Opcode) { + default: + return std::nullopt; + case RISCV::PseudoVSPILL2_M1: + case RISCV::PseudoVRELOAD2_M1: + return std::make_pair(2u, 1u); + case RISCV::PseudoVSPILL2_M2: + case RISCV::PseudoVRELOAD2_M2: + return std::make_pair(2u, 2u); + case RISCV::PseudoVSPILL2_M4: + case RISCV::PseudoVRELOAD2_M4: + return std::make_pair(2u, 4u); + case RISCV::PseudoVSPILL3_M1: + case RISCV::PseudoVRELOAD3_M1: + return std::make_pair(3u, 1u); + case RISCV::PseudoVSPILL3_M2: + case RISCV::PseudoVRELOAD3_M2: + return std::make_pair(3u, 2u); + case RISCV::PseudoVSPILL4_M1: + case RISCV::PseudoVRELOAD4_M1: + return std::make_pair(4u, 1u); + case RISCV::PseudoVSPILL4_M2: + case RISCV::PseudoVRELOAD4_M2: + return std::make_pair(4u, 2u); + case RISCV::PseudoVSPILL5_M1: + case RISCV::PseudoVRELOAD5_M1: + return std::make_pair(5u, 1u); + case RISCV::PseudoVSPILL6_M1: + case RISCV::PseudoVRELOAD6_M1: + return std::make_pair(6u, 1u); + case RISCV::PseudoVSPILL7_M1: + case RISCV::PseudoVRELOAD7_M1: + return std::make_pair(7u, 1u); + case RISCV::PseudoVSPILL8_M1: + case RISCV::PseudoVRELOAD8_M1: + return std::make_pair(8u, 1u); + } +} + +bool RISCV::isFaultFirstLoad(const MachineInstr &MI) { + return MI.getNumExplicitDefs() == 2 && MI.modifiesRegister(RISCV::VL) && + !MI.isInlineAsm(); +} + +bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) { + int16_t MI1FrmOpIdx = + RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm); + int16_t MI2FrmOpIdx = + RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm); + if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0) + return false; + MachineOperand FrmOp1 = MI1.getOperand(MI1FrmOpIdx); + MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx); + return FrmOp1.getImm() == FrmOp2.getImm(); +} diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index e03582efc65216..c663af75a5579a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -220,6 +220,17 @@ class RISCVInstrInfo : public RISCVGenInstrInfo { std::optional getInverseOpcode(unsigned Opcode) const override; + // Returns true if all uses of OrigMI only depend on the lower \p NBits bits + // of its output. + bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI, + unsigned NBits) const; + // Returns true if all uses of OrigMI only depend on the lower word of its + // output, so we can transform OrigMI to the corresponding W-version. + bool hasAllWUsers(const MachineInstr &MI, + const MachineRegisterInfo &MRI) const { + return hasAllNBitUsers(MI, MRI, 32); + } + protected: const RISCVSubtarget &STI; }; @@ -250,9 +261,6 @@ bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2); // Special immediate for AVL operand of V pseudo instructions to indicate VLMax. static constexpr int64_t VLMaxSentinel = -1LL; -// Returns true if all uses of OrigMI only depend on the lower word of its -// output, so we can transform OrigMI to the corresponding W-version. -bool hasAllWUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI); } // namespace RISCV namespace RISCVVPseudosTable { diff --git a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp index 061b4defef1cfd..2ee228d728258c 100644 --- a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp +++ b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp @@ -95,6 +95,7 @@ static bool isSignExtendingOpW(const MachineInstr &MI, } static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI, + const RISCVInstrInfo &TII, SmallPtrSetImpl &FixableDef) { SmallPtrSet Visited; @@ -282,7 +283,7 @@ static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI, case RISCV::LWU: case RISCV::MUL: case RISCV::SUB: - if (RISCV::hasAllWUsers(*MI, MRI)) { + if (TII.hasAllWUsers(*MI, MRI)) { FixableDef.insert(MI); break; } @@ -343,8 +344,8 @@ bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) { // If all users only use the lower bits, this sext.w is redundant. // Or if all definitions reaching MI sign-extend their output, // then sext.w is redundant. - if (!RISCV::hasAllWUsers(*MI, MRI) && - !isSignExtendedW(SrcReg, MRI, FixableDefs)) + if (!TII.hasAllWUsers(*MI, MRI) && + !isSignExtendedW(SrcReg, MRI, TII, FixableDefs)) continue; Register DstReg = MI->getOperand(0).getReg(); diff --git a/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp b/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp index e818cc5459d191..14ab9c2dd6557f 100644 --- a/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp +++ b/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp @@ -72,7 +72,7 @@ bool RISCVStripWSuffix::runOnMachineFunction(MachineFunction &MF) { switch (MI.getOpcode()) { case RISCV::ADDW: case RISCV::SLLIW: - if (RISCV::hasAllWUsers(MI, MRI)) { + if (TII.hasAllWUsers(MI, MRI)) { unsigned Opc = MI.getOpcode() == RISCV::ADDW ? RISCV::ADD : RISCV::SLLI; MI.setDesc(TII.get(Opc));