diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index d97776265ce64..25f4a217c0703 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -296,6 +296,92 @@ RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const { [=](MachineInstrBuilder &MIB) { MIB.addImm(0); }}}; } +/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC. +/// CC Must be an ICMP Predicate. +static RISCVCC::CondCode getRISCVCCFromICmp(CmpInst::Predicate CC) { + switch (CC) { + default: + llvm_unreachable("Expected ICMP CmpInst::Predicate."); + case CmpInst::Predicate::ICMP_EQ: + return RISCVCC::COND_EQ; + case CmpInst::Predicate::ICMP_NE: + return RISCVCC::COND_NE; + case CmpInst::Predicate::ICMP_ULT: + return RISCVCC::COND_LTU; + case CmpInst::Predicate::ICMP_SLT: + return RISCVCC::COND_LT; + case CmpInst::Predicate::ICMP_UGE: + return RISCVCC::COND_GEU; + case CmpInst::Predicate::ICMP_SGE: + return RISCVCC::COND_GE; + } +} + +static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI, + RISCVCC::CondCode &CC, Register &LHS, + Register &RHS) { + // Try to fold an ICmp. If that fails, use a NE compare with X0. + CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; + if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) { + LHS = CondReg; + RHS = RISCV::X0; + CC = RISCVCC::COND_NE; + return; + } + + // We found an ICmp, do some canonicalizations. + + // Adjust comparisons to use comparison with 0 if possible. + if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) { + switch (Pred) { + case CmpInst::Predicate::ICMP_SGT: + // Convert X > -1 to X >= 0 + if (*Constant == -1) { + CC = RISCVCC::COND_GE; + RHS = RISCV::X0; + return; + } + break; + case CmpInst::Predicate::ICMP_SLT: + // Convert X < 1 to 0 >= X + if (*Constant == 1) { + CC = RISCVCC::COND_GE; + RHS = LHS; + LHS = RISCV::X0; + return; + } + break; + default: + break; + } + } + + switch (Pred) { + default: + llvm_unreachable("Expected ICMP CmpInst::Predicate."); + case CmpInst::Predicate::ICMP_EQ: + case CmpInst::Predicate::ICMP_NE: + case CmpInst::Predicate::ICMP_ULT: + case CmpInst::Predicate::ICMP_SLT: + case CmpInst::Predicate::ICMP_UGE: + case CmpInst::Predicate::ICMP_SGE: + // These CCs are supported directly by RISC-V branches. + break; + case CmpInst::Predicate::ICMP_SGT: + case CmpInst::Predicate::ICMP_SLE: + case CmpInst::Predicate::ICMP_UGT: + case CmpInst::Predicate::ICMP_ULE: + // These CCs are not supported directly by RISC-V branches, but changing the + // direction of the CC and swapping LHS and RHS are. + Pred = CmpInst::getSwappedPredicate(Pred); + std::swap(LHS, RHS); + break; + } + + CC = getRISCVCCFromICmp(Pred); + return; +} + bool RISCVInstructionSelector::select(MachineInstr &MI) { MachineBasicBlock &MBB = *MI.getParent(); MachineFunction &MF = *MBB.getParent(); @@ -398,10 +484,12 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) { case TargetOpcode::G_GLOBAL_VALUE: return selectGlobalValue(MI, MIB, MRI); case TargetOpcode::G_BRCOND: { - // TODO: Fold with G_ICMP. - auto Bcc = - MIB.buildInstr(RISCV::BNE, {}, {MI.getOperand(0), Register(RISCV::X0)}) - .addMBB(MI.getOperand(1).getMBB()); + Register LHS, RHS; + RISCVCC::CondCode CC; + getOperandsForBranch(MI.getOperand(0).getReg(), MRI, CC, LHS, RHS); + + auto Bcc = MIB.buildInstr(RISCVCC::getBrCond(CC), {}, {LHS, RHS}) + .addMBB(MI.getOperand(1).getMBB()); MI.eraseFromParent(); return constrainSelectedInstRegOperands(*Bcc, TII, TRI, RBI); } @@ -719,93 +807,6 @@ bool RISCVInstructionSelector::selectSExtInreg(MachineInstr &MI, return true; } -/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC. -/// CC Must be an ICMP Predicate. -static RISCVCC::CondCode getRISCVCCFromICMP(CmpInst::Predicate CC) { - switch (CC) { - default: - llvm_unreachable("Expected ICMP CmpInst::Predicate."); - case CmpInst::Predicate::ICMP_EQ: - return RISCVCC::COND_EQ; - case CmpInst::Predicate::ICMP_NE: - return RISCVCC::COND_NE; - case CmpInst::Predicate::ICMP_ULT: - return RISCVCC::COND_LTU; - case CmpInst::Predicate::ICMP_SLT: - return RISCVCC::COND_LT; - case CmpInst::Predicate::ICMP_UGE: - return RISCVCC::COND_GEU; - case CmpInst::Predicate::ICMP_SGE: - return RISCVCC::COND_GE; - } -} - -static void getOperandsForBranch(Register CondReg, MachineIRBuilder &MIB, - MachineRegisterInfo &MRI, - RISCVCC::CondCode &CC, Register &LHS, - Register &RHS) { - // Try to fold an ICmp. If that fails, use a NE compare with X0. - CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; - if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) { - LHS = CondReg; - RHS = RISCV::X0; - CC = RISCVCC::COND_NE; - return; - } - - // We found an ICmp, do some canonicalizations. - - // Adjust comparisons to use comparison with 0 if possible. - if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) { - switch (Pred) { - case CmpInst::Predicate::ICMP_SGT: - // Convert X > -1 to X >= 0 - if (*Constant == -1) { - CC = RISCVCC::COND_GE; - RHS = RISCV::X0; - return; - } - break; - case CmpInst::Predicate::ICMP_SLT: - // Convert X < 1 to 0 >= X - if (*Constant == 1) { - CC = RISCVCC::COND_GE; - RHS = LHS; - LHS = RISCV::X0; - return; - } - break; - default: - break; - } - } - - switch (Pred) { - default: - llvm_unreachable("Expected ICMP CmpInst::Predicate."); - case CmpInst::Predicate::ICMP_EQ: - case CmpInst::Predicate::ICMP_NE: - case CmpInst::Predicate::ICMP_ULT: - case CmpInst::Predicate::ICMP_SLT: - case CmpInst::Predicate::ICMP_UGE: - case CmpInst::Predicate::ICMP_SGE: - // These CCs are supported directly by RISC-V branches. - break; - case CmpInst::Predicate::ICMP_SGT: - case CmpInst::Predicate::ICMP_SLE: - case CmpInst::Predicate::ICMP_UGT: - case CmpInst::Predicate::ICMP_ULE: - // These CCs are not supported directly by RISC-V branches, but changing the - // direction of the CC and swapping LHS and RHS are. - Pred = CmpInst::getSwappedPredicate(Pred); - std::swap(LHS, RHS); - break; - } - - CC = getRISCVCCFromICMP(Pred); - return; -} - bool RISCVInstructionSelector::selectSelect(MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) const { @@ -813,7 +814,7 @@ bool RISCVInstructionSelector::selectSelect(MachineInstr &MI, Register LHS, RHS; RISCVCC::CondCode CC; - getOperandsForBranch(SelectMI.getCondReg(), MIB, MRI, CC, LHS, RHS); + getOperandsForBranch(SelectMI.getCondReg(), MRI, CC, LHS, RHS); MachineInstr *Result = MIB.buildInstr(RISCV::Select_GPR_Using_CC_GPR) .addDef(SelectMI.getReg(0)) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 01f2bb9d73037..9271f807a8483 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -904,25 +904,29 @@ static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target, Cond.push_back(LastInst.getOperand(1)); } -const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const { +unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC) { switch (CC) { default: llvm_unreachable("Unknown condition code!"); case RISCVCC::COND_EQ: - return get(RISCV::BEQ); + return RISCV::BEQ; case RISCVCC::COND_NE: - return get(RISCV::BNE); + return RISCV::BNE; case RISCVCC::COND_LT: - return get(RISCV::BLT); + return RISCV::BLT; case RISCVCC::COND_GE: - return get(RISCV::BGE); + return RISCV::BGE; case RISCVCC::COND_LTU: - return get(RISCV::BLTU); + return RISCV::BLTU; case RISCVCC::COND_GEU: - return get(RISCV::BGEU); + return RISCV::BGEU; } } +const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const { + return get(RISCVCC::getBrCond(CC)); +} + RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) { switch (CC) { default: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 491278c2e017e..b33d8c2856159 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -43,6 +43,7 @@ enum CondCode { }; CondCode getOppositeBranchCondition(CondCode); +unsigned getBrCond(CondCode CC); } // end of namespace RISCVCC