diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 6c156057ccd7d..94d1994895325 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -2841,10 +2841,9 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt, static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo, unsigned Bits, const TargetInstrInfo *TII) { - const RISCVVPseudosTable::PseudoInfo *PseudoInfo = - RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode()); + unsigned MCOpcode = RISCV::getRVVMCOpcode(User->getMachineOpcode()); - if (!PseudoInfo) + if (!MCOpcode) return false; const MCInstrDesc &MCID = TII->get(User->getMachineOpcode()); @@ -2865,7 +2864,7 @@ static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo, return false; auto NumDemandedBits = - RISCV::getVectorLowDemandedScalarBits(PseudoInfo->BaseInstr, Log2SEW); + RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW); return NumDemandedBits && Bits >= *NumDemandedBits; } @@ -3404,21 +3403,11 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(MachineSDNode *N) { } static bool IsVMerge(SDNode *N) { - unsigned Opc = N->getMachineOpcode(); - return Opc == RISCV::PseudoVMERGE_VVM_MF8 || - Opc == RISCV::PseudoVMERGE_VVM_MF4 || - Opc == RISCV::PseudoVMERGE_VVM_MF2 || - Opc == RISCV::PseudoVMERGE_VVM_M1 || - Opc == RISCV::PseudoVMERGE_VVM_M2 || - Opc == RISCV::PseudoVMERGE_VVM_M4 || Opc == RISCV::PseudoVMERGE_VVM_M8; + return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM; } static bool IsVMv(SDNode *N) { - unsigned Opc = N->getMachineOpcode(); - return Opc == RISCV::PseudoVMV_V_V_MF8 || Opc == RISCV::PseudoVMV_V_V_MF4 || - Opc == RISCV::PseudoVMV_V_V_MF2 || Opc == RISCV::PseudoVMV_V_V_M1 || - Opc == RISCV::PseudoVMV_V_V_M2 || Opc == RISCV::PseudoVMV_V_V_M4 || - Opc == RISCV::PseudoVMV_V_V_M8; + return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V; } static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) { diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 27f281fff6fc5..f6d8b1f0a70e1 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -67,16 +67,8 @@ static bool isVLPreservingConfig(const MachineInstr &MI) { return RISCV::X0 == MI.getOperand(0).getReg(); } -static uint16_t getRVVMCOpcode(uint16_t RVVPseudoOpcode) { - const RISCVVPseudosTable::PseudoInfo *RVV = - RISCVVPseudosTable::getPseudoInfo(RVVPseudoOpcode); - if (!RVV) - return 0; - return RVV->BaseInstr; -} - static bool isFloatScalarMoveOrScalarSplatInstr(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return false; case RISCV::VFMV_S_F: @@ -86,7 +78,7 @@ static bool isFloatScalarMoveOrScalarSplatInstr(const MachineInstr &MI) { } static bool isScalarExtractInstr(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return false; case RISCV::VMV_X_S: @@ -96,7 +88,7 @@ static bool isScalarExtractInstr(const MachineInstr &MI) { } static bool isScalarInsertInstr(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return false; case RISCV::VMV_S_X: @@ -106,7 +98,7 @@ static bool isScalarInsertInstr(const MachineInstr &MI) { } static bool isScalarSplatInstr(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return false; case RISCV::VMV_V_I: @@ -117,7 +109,7 @@ static bool isScalarSplatInstr(const MachineInstr &MI) { } static bool isVSlideInstr(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return false; case RISCV::VSLIDEDOWN_VX: @@ -131,7 +123,7 @@ static bool isVSlideInstr(const MachineInstr &MI) { /// Get the EEW for a load or store instruction. Return std::nullopt if MI is /// not a load or store which ignores SEW. static std::optional getEEWForLoadStore(const MachineInstr &MI) { - switch (getRVVMCOpcode(MI.getOpcode())) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { default: return std::nullopt; case RISCV::VLE8_V: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index bfe43bae7cb12..996ef1c6f574a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -3103,3 +3103,11 @@ RISCV::getVectorLowDemandedScalarBits(uint16_t Opcode, unsigned Log2SEW) { return 1U << Log2SEW; } } + +unsigned RISCV::getRVVMCOpcode(unsigned RVVPseudoOpcode) { + const RISCVVPseudosTable::PseudoInfo *RVV = + RISCVVPseudosTable::getPseudoInfo(RVVPseudoOpcode); + if (!RVV) + return 0; + return RVV->BaseInstr; +} diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index a51dd0c316ef6..d0112a464677a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -283,6 +283,9 @@ bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2); std::optional getVectorLowDemandedScalarBits(uint16_t Opcode, unsigned Log2SEW); +// Returns the MC opcode of RVV pseudo instruction. +unsigned getRVVMCOpcode(unsigned RVVPseudoOpcode); + // Special immediate for AVL operand of V pseudo instructions to indicate VLMax. static constexpr int64_t VLMaxSentinel = -1LL; diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp index 63be3f88d48cc..a62c7b4bbae06 100644 --- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp +++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp @@ -84,10 +84,9 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() { static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, unsigned Bits) { const MachineInstr &MI = *UserOp.getParent(); - const RISCVVPseudosTable::PseudoInfo *PseudoInfo = - RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); + unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode()); - if (!PseudoInfo) + if (!MCOpcode) return false; const MCInstrDesc &MCID = MI.getDesc(); @@ -101,7 +100,7 @@ static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, return false; auto NumDemandedBits = - RISCV::getVectorLowDemandedScalarBits(PseudoInfo->BaseInstr, Log2SEW); + RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW); return NumDemandedBits && Bits >= *NumDemandedBits; }