Skip to content

Commit

Permalink
[RISCV][GISel] Select G_BRCOND and G_ICMP together when possible.
Browse files Browse the repository at this point in the history
This allows us to fold the G_ICMP operands into the conditional branch.

This reuses the helper function we have for folding a G_ICMP into
G_SELECT.
  • Loading branch information
topperc committed Nov 12, 2023
1 parent c2205ab commit e0e0891
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 99 deletions.
185 changes: 93 additions & 92 deletions llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,92 @@ RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
[=](MachineInstrBuilder &MIB) { MIB.addImm(0); }}};
}

/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC.
/// CC Must be an ICMP Predicate.
static RISCVCC::CondCode getRISCVCCFromICmp(CmpInst::Predicate CC) {
switch (CC) {
default:
llvm_unreachable("Expected ICMP CmpInst::Predicate.");
case CmpInst::Predicate::ICMP_EQ:
return RISCVCC::COND_EQ;
case CmpInst::Predicate::ICMP_NE:
return RISCVCC::COND_NE;
case CmpInst::Predicate::ICMP_ULT:
return RISCVCC::COND_LTU;
case CmpInst::Predicate::ICMP_SLT:
return RISCVCC::COND_LT;
case CmpInst::Predicate::ICMP_UGE:
return RISCVCC::COND_GEU;
case CmpInst::Predicate::ICMP_SGE:
return RISCVCC::COND_GE;
}
}

static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
RISCVCC::CondCode &CC, Register &LHS,
Register &RHS) {
// Try to fold an ICmp. If that fails, use a NE compare with X0.
CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
LHS = CondReg;
RHS = RISCV::X0;
CC = RISCVCC::COND_NE;
return;
}

// We found an ICmp, do some canonicalizations.

// Adjust comparisons to use comparison with 0 if possible.
if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) {
switch (Pred) {
case CmpInst::Predicate::ICMP_SGT:
// Convert X > -1 to X >= 0
if (*Constant == -1) {
CC = RISCVCC::COND_GE;
RHS = RISCV::X0;
return;
}
break;
case CmpInst::Predicate::ICMP_SLT:
// Convert X < 1 to 0 >= X
if (*Constant == 1) {
CC = RISCVCC::COND_GE;
RHS = LHS;
LHS = RISCV::X0;
return;
}
break;
default:
break;
}
}

switch (Pred) {
default:
llvm_unreachable("Expected ICMP CmpInst::Predicate.");
case CmpInst::Predicate::ICMP_EQ:
case CmpInst::Predicate::ICMP_NE:
case CmpInst::Predicate::ICMP_ULT:
case CmpInst::Predicate::ICMP_SLT:
case CmpInst::Predicate::ICMP_UGE:
case CmpInst::Predicate::ICMP_SGE:
// These CCs are supported directly by RISC-V branches.
break;
case CmpInst::Predicate::ICMP_SGT:
case CmpInst::Predicate::ICMP_SLE:
case CmpInst::Predicate::ICMP_UGT:
case CmpInst::Predicate::ICMP_ULE:
// These CCs are not supported directly by RISC-V branches, but changing the
// direction of the CC and swapping LHS and RHS are.
Pred = CmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
break;
}

CC = getRISCVCCFromICmp(Pred);
return;
}

bool RISCVInstructionSelector::select(MachineInstr &MI) {
MachineBasicBlock &MBB = *MI.getParent();
MachineFunction &MF = *MBB.getParent();
Expand Down Expand Up @@ -398,10 +484,12 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
case TargetOpcode::G_GLOBAL_VALUE:
return selectGlobalValue(MI, MIB, MRI);
case TargetOpcode::G_BRCOND: {
// TODO: Fold with G_ICMP.
auto Bcc =
MIB.buildInstr(RISCV::BNE, {}, {MI.getOperand(0), Register(RISCV::X0)})
.addMBB(MI.getOperand(1).getMBB());
Register LHS, RHS;
RISCVCC::CondCode CC;
getOperandsForBranch(MI.getOperand(0).getReg(), MRI, CC, LHS, RHS);

auto Bcc = MIB.buildInstr(RISCVCC::getBrCond(CC), {}, {LHS, RHS})
.addMBB(MI.getOperand(1).getMBB());
MI.eraseFromParent();
return constrainSelectedInstRegOperands(*Bcc, TII, TRI, RBI);
}
Expand Down Expand Up @@ -719,101 +807,14 @@ bool RISCVInstructionSelector::selectSExtInreg(MachineInstr &MI,
return true;
}

/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC.
/// CC Must be an ICMP Predicate.
static RISCVCC::CondCode getRISCVCCFromICMP(CmpInst::Predicate CC) {
switch (CC) {
default:
llvm_unreachable("Expected ICMP CmpInst::Predicate.");
case CmpInst::Predicate::ICMP_EQ:
return RISCVCC::COND_EQ;
case CmpInst::Predicate::ICMP_NE:
return RISCVCC::COND_NE;
case CmpInst::Predicate::ICMP_ULT:
return RISCVCC::COND_LTU;
case CmpInst::Predicate::ICMP_SLT:
return RISCVCC::COND_LT;
case CmpInst::Predicate::ICMP_UGE:
return RISCVCC::COND_GEU;
case CmpInst::Predicate::ICMP_SGE:
return RISCVCC::COND_GE;
}
}

static void getOperandsForBranch(Register CondReg, MachineIRBuilder &MIB,
MachineRegisterInfo &MRI,
RISCVCC::CondCode &CC, Register &LHS,
Register &RHS) {
// Try to fold an ICmp. If that fails, use a NE compare with X0.
CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
LHS = CondReg;
RHS = RISCV::X0;
CC = RISCVCC::COND_NE;
return;
}

// We found an ICmp, do some canonicalizations.

// Adjust comparisons to use comparison with 0 if possible.
if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) {
switch (Pred) {
case CmpInst::Predicate::ICMP_SGT:
// Convert X > -1 to X >= 0
if (*Constant == -1) {
CC = RISCVCC::COND_GE;
RHS = RISCV::X0;
return;
}
break;
case CmpInst::Predicate::ICMP_SLT:
// Convert X < 1 to 0 >= X
if (*Constant == 1) {
CC = RISCVCC::COND_GE;
RHS = LHS;
LHS = RISCV::X0;
return;
}
break;
default:
break;
}
}

switch (Pred) {
default:
llvm_unreachable("Expected ICMP CmpInst::Predicate.");
case CmpInst::Predicate::ICMP_EQ:
case CmpInst::Predicate::ICMP_NE:
case CmpInst::Predicate::ICMP_ULT:
case CmpInst::Predicate::ICMP_SLT:
case CmpInst::Predicate::ICMP_UGE:
case CmpInst::Predicate::ICMP_SGE:
// These CCs are supported directly by RISC-V branches.
break;
case CmpInst::Predicate::ICMP_SGT:
case CmpInst::Predicate::ICMP_SLE:
case CmpInst::Predicate::ICMP_UGT:
case CmpInst::Predicate::ICMP_ULE:
// These CCs are not supported directly by RISC-V branches, but changing the
// direction of the CC and swapping LHS and RHS are.
Pred = CmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
break;
}

CC = getRISCVCCFromICMP(Pred);
return;
}

bool RISCVInstructionSelector::selectSelect(MachineInstr &MI,
MachineIRBuilder &MIB,
MachineRegisterInfo &MRI) const {
auto &SelectMI = cast<GSelect>(MI);

Register LHS, RHS;
RISCVCC::CondCode CC;
getOperandsForBranch(SelectMI.getCondReg(), MIB, MRI, CC, LHS, RHS);
getOperandsForBranch(SelectMI.getCondReg(), MRI, CC, LHS, RHS);

MachineInstr *Result = MIB.buildInstr(RISCV::Select_GPR_Using_CC_GPR)
.addDef(SelectMI.getReg(0))
Expand Down
18 changes: 11 additions & 7 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,25 +904,29 @@ static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target,
Cond.push_back(LastInst.getOperand(1));
}

const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unknown condition code!");
case RISCVCC::COND_EQ:
return get(RISCV::BEQ);
return RISCV::BEQ;
case RISCVCC::COND_NE:
return get(RISCV::BNE);
return RISCV::BNE;
case RISCVCC::COND_LT:
return get(RISCV::BLT);
return RISCV::BLT;
case RISCVCC::COND_GE:
return get(RISCV::BGE);
return RISCV::BGE;
case RISCVCC::COND_LTU:
return get(RISCV::BLTU);
return RISCV::BLTU;
case RISCVCC::COND_GEU:
return get(RISCV::BGEU);
return RISCV::BGEU;
}
}

const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
return get(RISCVCC::getBrCond(CC));
}

RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
switch (CC) {
default:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ enum CondCode {
};

CondCode getOppositeBranchCondition(CondCode);
unsigned getBrCond(CondCode CC);

} // end of namespace RISCVCC

Expand Down

0 comments on commit e0e0891

Please sign in to comment.