Skip to content

Commit

Permalink
[RISCV] Move allWUsers from RISCVInstrInfo to RISCVOptWInstrs.
Browse files Browse the repository at this point in the history
It was only in RISCVInstrInfo because it was used by 2 passes, but those
passes have been merged in D147173.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D147174
  • Loading branch information
topperc committed Mar 29, 2023
1 parent 4c10a61 commit 3d7fa6d
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 241 deletions.
220 changes: 0 additions & 220 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Expand Up @@ -2614,226 +2614,6 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
}
}

// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
// TODO: handle multiple interdependent transformations
bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI,
const MachineRegisterInfo &MRI,
unsigned OrigBits) const {

SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

Worklist.push_back(std::make_pair(&OrigMI, OrigBits));

while (!Worklist.empty()) {
auto P = Worklist.pop_back_val();
const MachineInstr *MI = P.first;
unsigned Bits = P.second;

if (!Visited.insert(P).second)
continue;

// Only handle instructions with one def.
if (MI->getNumExplicitDefs() != 1)
return false;

for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
const MachineInstr *UserMI = UserOp.getParent();
unsigned OpIdx = UserOp.getOperandNo();

switch (UserMI->getOpcode()) {
default:
return false;

case RISCV::ADDIW:
case RISCV::ADDW:
case RISCV::DIVUW:
case RISCV::DIVW:
case RISCV::MULW:
case RISCV::REMUW:
case RISCV::REMW:
case RISCV::SLLIW:
case RISCV::SLLW:
case RISCV::SRAIW:
case RISCV::SRAW:
case RISCV::SRLIW:
case RISCV::SRLW:
case RISCV::SUBW:
case RISCV::ROLW:
case RISCV::RORW:
case RISCV::RORIW:
case RISCV::CLZW:
case RISCV::CTZW:
case RISCV::CPOPW:
case RISCV::SLLI_UW:
case RISCV::FMV_W_X:
case RISCV::FCVT_H_W:
case RISCV::FCVT_H_WU:
case RISCV::FCVT_S_W:
case RISCV::FCVT_S_WU:
case RISCV::FCVT_D_W:
case RISCV::FCVT_D_WU:
if (Bits >= 32)
break;
return false;
case RISCV::SEXT_B:
case RISCV::PACKH:
if (Bits >= 8)
break;
return false;
case RISCV::SEXT_H:
case RISCV::FMV_H_X:
case RISCV::ZEXT_H_RV32:
case RISCV::ZEXT_H_RV64:
case RISCV::PACKW:
if (Bits >= 16)
break;
return false;

case RISCV::PACK:
if (Bits >= (STI.getXLen() / 2))
break;
return false;

case RISCV::SRLI: {
// If we are shifting right by less than Bits, and users don't demand
// any bits that were shifted into [Bits-1:0], then we can consider this
// as an N-Bit user.
unsigned ShAmt = UserMI->getOperand(2).getImm();
if (Bits > ShAmt) {
Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
break;
}
return false;
}

// these overwrite higher input bits, otherwise the lower word of output
// depends only on the lower word of input. So check their uses read W.
case RISCV::SLLI:
if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm()))
break;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;
case RISCV::ANDI: {
uint64_t Imm = UserMI->getOperand(2).getImm();
if (Bits >= (unsigned)llvm::bit_width(Imm))
break;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;
}
case RISCV::ORI: {
uint64_t Imm = UserMI->getOperand(2).getImm();
if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
break;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;
}

case RISCV::SLL:
case RISCV::BSET:
case RISCV::BCLR:
case RISCV::BINV:
// Operand 2 is the shift amount which uses log2(xlen) bits.
if (OpIdx == 2) {
if (Bits >= Log2_32(STI.getXLen()))
break;
return false;
}
Worklist.push_back(std::make_pair(UserMI, Bits));
break;

case RISCV::SRA:
case RISCV::SRL:
case RISCV::ROL:
case RISCV::ROR:
// Operand 2 is the shift amount which uses 6 bits.
if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen()))
break;
return false;

case RISCV::ADD_UW:
case RISCV::SH1ADD_UW:
case RISCV::SH2ADD_UW:
case RISCV::SH3ADD_UW:
// Operand 1 is implicitly zero extended.
if (OpIdx == 1 && Bits >= 32)
break;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;

case RISCV::BEXTI:
if (UserMI->getOperand(2).getImm() >= Bits)
return false;
break;

case RISCV::SB:
// The first argument is the value to store.
if (OpIdx == 0 && Bits >= 8)
break;
return false;
case RISCV::SH:
// The first argument is the value to store.
if (OpIdx == 0 && Bits >= 16)
break;
return false;
case RISCV::SW:
// The first argument is the value to store.
if (OpIdx == 0 && Bits >= 32)
break;
return false;

// For these, lower word of output in these operations, depends only on
// the lower word of input. So, we check all uses only read lower word.
case RISCV::COPY:
case RISCV::PHI:

case RISCV::ADD:
case RISCV::ADDI:
case RISCV::AND:
case RISCV::MUL:
case RISCV::OR:
case RISCV::SUB:
case RISCV::XOR:
case RISCV::XORI:

case RISCV::ANDN:
case RISCV::BREV8:
case RISCV::CLMUL:
case RISCV::ORC_B:
case RISCV::ORN:
case RISCV::SH1ADD:
case RISCV::SH2ADD:
case RISCV::SH3ADD:
case RISCV::XNOR:
case RISCV::BSETI:
case RISCV::BCLRI:
case RISCV::BINVI:
Worklist.push_back(std::make_pair(UserMI, Bits));
break;

case RISCV::PseudoCCMOVGPR:
// Either operand 4 or operand 5 is returned by this instruction. If
// only the lower word of the result is used, then only the lower word
// of operand 4 and 5 is used.
if (OpIdx != 4 && OpIdx != 5)
return false;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;

case RISCV::VT_MASKC:
case RISCV::VT_MASKCN:
if (OpIdx != 1)
return false;
Worklist.push_back(std::make_pair(UserMI, Bits));
break;
}
}
}

return true;
}

// Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
bool RISCV::isSEXT_W(const MachineInstr &MI) {
return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
Expand Down
11 changes: 0 additions & 11 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Expand Up @@ -227,17 +227,6 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;

// Returns true if all uses of OrigMI only depend on the lower \p NBits bits
// of its output.
bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI,
unsigned NBits) const;
// Returns true if all uses of OrigMI only depend on the lower word of its
// output, so we can transform OrigMI to the corresponding W-version.
bool hasAllWUsers(const MachineInstr &MI,
const MachineRegisterInfo &MRI) const {
return hasAllNBitUsers(MI, MRI, 32);
}

protected:
const RISCVSubtarget &STI;
};
Expand Down

0 comments on commit 3d7fa6d

Please sign in to comment.