Skip to content

Commit

Permalink
[RISCV] Introduce a RISCV CondCode enum instead of using ISD:SET* in …
Browse files Browse the repository at this point in the history
…MIR. NFC

Previously we converted ISD condition codes to integers and stored
them directly in our MIR instructions. The ISD enum kind of belongs
to SelectionDAG so that seems like incorrect layering.

This patch instead uses a CondCode node on RISCV::SELECT_CC until
isel and then converts it from ISD encoding to a RISCV specific value.
This value can be converted to/from the RISCV branch opcodes in the
RISCV namespace.

My larger motivation is to possibly support a microarchitectural
feature of some CPUs where a short forward branch over a single
instruction can be predicated internally. This will require a new
pseudo instruction for select that needs to carry a branch condition
and live probably until RISCVExpandPseudos. At that point it can be
expanded to control flow without other instructions ending up in the
predicated basic block. Using an ISD encoding in RISCVExpandPseudos
doesn't seem like correct layering.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D107400
  • Loading branch information
topperc committed Aug 9, 2021
1 parent 20dfe05 commit 88bc29f
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 60 deletions.
22 changes: 22 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
Expand Up @@ -83,6 +83,28 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
void selectVSSEG(SDNode *Node, bool IsMasked, bool IsStrided);
void selectVSXSEG(SDNode *Node, bool IsMasked, bool IsOrdered);

// Return the RISC-V condition code that matches the given DAG integer
// condition code. The CondCode must be one of those supported by the RISC-V
// ISA (see translateSetCCForBranch).
static RISCVCC::CondCode getRISCVCCForIntCC(ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unsupported CondCode");
case ISD::SETEQ:
return RISCVCC::COND_EQ;
case ISD::SETNE:
return RISCVCC::COND_NE;
case ISD::SETLT:
return RISCVCC::COND_LT;
case ISD::SETGE:
return RISCVCC::COND_GE;
case ISD::SETULT:
return RISCVCC::COND_LTU;
case ISD::SETUGE:
return RISCVCC::COND_GEU;
}
}

// Include the pieces autogenerated from the target description.
#include "RISCVGenDAGISel.inc"

Expand Down
47 changes: 11 additions & 36 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -1071,28 +1071,6 @@ static void translateSetCCForBranch(const SDLoc &DL, SDValue &LHS, SDValue &RHS,
}
}

// Return the RISC-V branch opcode that matches the given DAG integer
// condition code. The CondCode must be one of those supported by the RISC-V
// ISA (see translateSetCCForBranch).
static unsigned getBranchOpcodeForIntCondCode(ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unsupported CondCode");
case ISD::SETEQ:
return RISCV::BEQ;
case ISD::SETNE:
return RISCV::BNE;
case ISD::SETLT:
return RISCV::BLT;
case ISD::SETGE:
return RISCV::BGE;
case ISD::SETULT:
return RISCV::BLTU;
case ISD::SETUGE:
return RISCV::BGEU;
}
}

RISCVII::VLMUL RISCVTargetLowering::getLMUL(MVT VT) {
assert(VT.isScalableVector() && "Expecting a scalable vector type");
unsigned KnownSize = VT.getSizeInBits().getKnownMinValue();
Expand Down Expand Up @@ -3002,7 +2980,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {

translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);

SDValue TargetCC = DAG.getTargetConstant(CCVal, DL, XLenVT);
SDValue TargetCC = DAG.getCondCode(CCVal);
SDValue Ops[] = {LHS, RHS, TargetCC, TrueV, FalseV};
return DAG.getNode(RISCVISD::SELECT_CC, DL, Op.getValueType(), Ops);
}
Expand All @@ -3011,7 +2989,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
// (select condv, truev, falsev)
// -> (riscvisd::select_cc condv, zero, setne, truev, falsev)
SDValue Zero = DAG.getConstant(0, DL, XLenVT);
SDValue SetNE = DAG.getTargetConstant(ISD::SETNE, DL, XLenVT);
SDValue SetNE = DAG.getCondCode(ISD::SETNE);

SDValue Ops[] = {CondV, Zero, SetNE, TrueV, FalseV};

Expand Down Expand Up @@ -6177,7 +6155,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
// Transform
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
auto CCVal = static_cast<ISD::CondCode>(N->getConstantOperandVal(2));
ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(2))->get();
if (!ISD::isIntEqualitySetCC(CCVal))
break;

Expand All @@ -6198,8 +6176,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
LHS = LHS.getOperand(0);
translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);

SDValue TargetCC =
DAG.getTargetConstant(CCVal, DL, Subtarget.getXLenVT());
SDValue TargetCC = DAG.getCondCode(CCVal);
return DAG.getNode(
RISCVISD::SELECT_CC, DL, N->getValueType(0),
{LHS, RHS, TargetCC, N->getOperand(3), N->getOperand(4)});
Expand All @@ -6219,8 +6196,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
SDLoc DL(N);
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDValue TargetCC =
DAG.getTargetConstant(CCVal, DL, Subtarget.getXLenVT());
SDValue TargetCC = DAG.getCondCode(CCVal);
RHS = DAG.getConstant(0, DL, LHS.getValueType());
return DAG.getNode(
RISCVISD::SELECT_CC, DL, N->getValueType(0),
Expand Down Expand Up @@ -6881,7 +6857,8 @@ static bool isSelectPseudo(MachineInstr &MI) {
}

static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
MachineBasicBlock *BB) {
MachineBasicBlock *BB,
const RISCVSubtarget &Subtarget) {
// To "insert" Select_* instructions, we actually have to insert the triangle
// control-flow pattern. The incoming instructions know the destination vreg
// to set, the condition code register to branch on, the true/false values to
Expand All @@ -6908,7 +6885,7 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
// related approach and more information.
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
auto CC = static_cast<ISD::CondCode>(MI.getOperand(3).getImm());
auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());

SmallVector<MachineInstr *, 4> SelectDebugValues;
SmallSet<Register, 4> SelectDests;
Expand Down Expand Up @@ -6941,7 +6918,7 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
}
}

const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
const BasicBlock *LLVM_BB = BB->getBasicBlock();
DebugLoc DL = MI.getDebugLoc();
MachineFunction::iterator I = ++BB->getIterator();
Expand Down Expand Up @@ -6970,9 +6947,7 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
HeadMBB->addSuccessor(TailMBB);

// Insert appropriate branch.
unsigned Opcode = getBranchOpcodeForIntCondCode(CC);

BuildMI(HeadMBB, DL, TII.get(Opcode))
BuildMI(HeadMBB, DL, TII.getBrCond(CC))
.addReg(LHS)
.addReg(RHS)
.addMBB(TailMBB);
Expand Down Expand Up @@ -7017,7 +6992,7 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case RISCV::Select_FPR16_Using_CC_GPR:
case RISCV::Select_FPR32_Using_CC_GPR:
case RISCV::Select_FPR64_Using_CC_GPR:
return emitSelectPseudo(MI, BB);
return emitSelectPseudo(MI, BB, Subtarget);
case RISCV::BuildPairF64Pseudo:
return emitBuildPairF64Pseudo(MI, BB);
case RISCV::SplitF64Pseudo:
Expand Down
76 changes: 58 additions & 18 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Expand Up @@ -469,6 +469,25 @@ void RISCVInstrInfo::movImm(MachineBasicBlock &MBB,
}
}

static RISCVCC::CondCode getCondFromBranchOpc(unsigned Opc) {
switch (Opc) {
default:
return RISCVCC::COND_INVALID;
case RISCV::BEQ:
return RISCVCC::COND_EQ;
case RISCV::BNE:
return RISCVCC::COND_NE;
case RISCV::BLT:
return RISCVCC::COND_LT;
case RISCV::BGE:
return RISCVCC::COND_GE;
case RISCV::BLTU:
return RISCVCC::COND_LTU;
case RISCV::BGEU:
return RISCVCC::COND_GEU;
}
}

// The contents of values added to Cond are not examined outside of
// RISCVInstrInfo, giving us flexibility in what to push to it. For RISCV, we
// push BranchOpcode, Reg1, Reg2.
Expand All @@ -478,27 +497,47 @@ static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target,
assert(LastInst.getDesc().isConditionalBranch() &&
"Unknown conditional branch");
Target = LastInst.getOperand(2).getMBB();
Cond.push_back(MachineOperand::CreateImm(LastInst.getOpcode()));
unsigned CC = getCondFromBranchOpc(LastInst.getOpcode());
Cond.push_back(MachineOperand::CreateImm(CC));
Cond.push_back(LastInst.getOperand(0));
Cond.push_back(LastInst.getOperand(1));
}

static unsigned getOppositeBranchOpcode(int Opc) {
switch (Opc) {
const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
switch (CC) {
default:
llvm_unreachable("Unknown condition code!");
case RISCVCC::COND_EQ:
return get(RISCV::BEQ);
case RISCVCC::COND_NE:
return get(RISCV::BNE);
case RISCVCC::COND_LT:
return get(RISCV::BLT);
case RISCVCC::COND_GE:
return get(RISCV::BGE);
case RISCVCC::COND_LTU:
return get(RISCV::BLTU);
case RISCVCC::COND_GEU:
return get(RISCV::BGEU);
}
}

RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unrecognized conditional branch");
case RISCV::BEQ:
return RISCV::BNE;
case RISCV::BNE:
return RISCV::BEQ;
case RISCV::BLT:
return RISCV::BGE;
case RISCV::BGE:
return RISCV::BLT;
case RISCV::BLTU:
return RISCV::BGEU;
case RISCV::BGEU:
return RISCV::BLTU;
case RISCVCC::COND_EQ:
return RISCVCC::COND_NE;
case RISCVCC::COND_NE:
return RISCVCC::COND_EQ;
case RISCVCC::COND_LT:
return RISCVCC::COND_GE;
case RISCVCC::COND_GE:
return RISCVCC::COND_LT;
case RISCVCC::COND_LTU:
return RISCVCC::COND_GEU;
case RISCVCC::COND_GEU:
return RISCVCC::COND_LTU;
}
}

Expand Down Expand Up @@ -624,9 +663,9 @@ unsigned RISCVInstrInfo::insertBranch(
}

// Either a one or two-way conditional branch.
unsigned Opc = Cond[0].getImm();
auto CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
MachineInstr &CondMI =
*BuildMI(&MBB, DL, get(Opc)).add(Cond[1]).add(Cond[2]).addMBB(TBB);
*BuildMI(&MBB, DL, getBrCond(CC)).add(Cond[1]).add(Cond[2]).addMBB(TBB);
if (BytesAdded)
*BytesAdded += getInstSizeInBytes(CondMI);

Expand Down Expand Up @@ -680,7 +719,8 @@ unsigned RISCVInstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
bool RISCVInstrInfo::reverseBranchCondition(
SmallVectorImpl<MachineOperand> &Cond) const {
assert((Cond.size() == 3) && "Invalid branch condition!");
Cond[0].setImm(getOppositeBranchOpcode(Cond[0].getImm()));
auto CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
Cond[0].setImm(getOppositeBranchCondition(CC));
return false;
}

Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Expand Up @@ -24,12 +24,29 @@ namespace llvm {

class RISCVSubtarget;

namespace RISCVCC {

enum CondCode {
COND_EQ,
COND_NE,
COND_LT,
COND_GE,
COND_LTU,
COND_GEU,
COND_INVALID
};

CondCode getOppositeBranchCondition(CondCode);

} // end of namespace RISCVCC

class RISCVInstrInfo : public RISCVGenInstrInfo {

public:
explicit RISCVInstrInfo(RISCVSubtarget &STI);

MCInst getNop() const override;
const MCInstrDesc &getBrCond(RISCVCC::CondCode CC) const;

unsigned isLoadFromStackSlot(const MachineInstr &MI,
int &FrameIndex) const override;
Expand Down
19 changes: 17 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Expand Up @@ -23,6 +23,7 @@ def SDT_CallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>,
// Target-dependent type requirements.
def SDT_RISCVCall : SDTypeProfile<0, -1, [SDTCisVT<0, XLenVT>]>;
def SDT_RISCVSelectCC : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>,
SDTCisVT<3, OtherVT>,
SDTCisSameAs<0, 4>,
SDTCisSameAs<4, 5>]>;
def SDT_RISCVBrCC : SDTypeProfile<0, 4, [SDTCisSameAs<0, 1>,
Expand Down Expand Up @@ -974,13 +975,27 @@ def : Pat<(setgt GPR:$rs1, GPR:$rs2), (SLT GPR:$rs2, GPR:$rs1)>;
def : Pat<(setge GPR:$rs1, GPR:$rs2), (XORI (SLT GPR:$rs1, GPR:$rs2), 1)>;
def : Pat<(setle GPR:$rs1, GPR:$rs2), (XORI (SLT GPR:$rs2, GPR:$rs1), 1)>;

def IntCCtoRISCVCC : SDNodeXForm<riscv_selectcc, [{
ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
RISCVCC::CondCode BrCC = getRISCVCCForIntCC(CC);
return CurDAG->getTargetConstant(BrCC, SDLoc(N), Subtarget->getXLenVT());
}]>;

def riscv_selectcc_frag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc,
node:$truev, node:$falsev),
(riscv_selectcc node:$lhs, node:$rhs,
node:$cc, node:$truev,
node:$falsev), [{}],
IntCCtoRISCVCC>;

let usesCustomInserter = 1 in
class SelectCC_rrirr<RegisterClass valty, RegisterClass cmpty>
: Pseudo<(outs valty:$dst),
(ins cmpty:$lhs, cmpty:$rhs, ixlenimm:$imm,
valty:$truev, valty:$falsev),
[(set valty:$dst, (riscv_selectcc cmpty:$lhs, cmpty:$rhs,
(XLenVT timm:$imm), valty:$truev, valty:$falsev))]>;
[(set valty:$dst,
(riscv_selectcc_frag:$imm cmpty:$lhs, cmpty:$rhs, cond,
valty:$truev, valty:$falsev))]>;

def Select_GPR_Using_CC_GPR : SelectCC_rrirr<GPR, GPR>;

Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/RISCV/select-optimize-multiple.mir
Expand Up @@ -140,9 +140,9 @@ body: |
%0:gpr = COPY $x10
%5:gpr = ANDI %0, 1
%6:gpr = COPY $x0
%7:gpr = Select_GPR_Using_CC_GPR %5, %6, 22, %1, %2
%7:gpr = Select_GPR_Using_CC_GPR %5, %6, 1, %1, %2
%8:gpr = ADDI %7, 1
%9:gpr = Select_GPR_Using_CC_GPR %5, %6, 22, %3, %2
%9:gpr = Select_GPR_Using_CC_GPR %5, %6, 1, %3, %2
%10:gpr = ADD %7, killed %9
$x10 = COPY %10
PseudoRET implicit $x10
Expand Down Expand Up @@ -266,11 +266,11 @@ body: |
%0:gpr = COPY $x10
%5:gpr = ANDI %0, 1
%6:gpr = COPY $x0
%7:gpr = Select_GPR_Using_CC_GPR %5, %6, 22, %1, %2
%7:gpr = Select_GPR_Using_CC_GPR %5, %6, 1, %1, %2
DBG_VALUE %7, $noreg
%8:gpr = ADDI %0, 1
DBG_VALUE %8, $noreg
%9:gpr = Select_GPR_Using_CC_GPR %5, %6, 22, %3, %2
%9:gpr = Select_GPR_Using_CC_GPR %5, %6, 1, %3, %2
DBG_VALUE %9, $noreg
%10:gpr = ADD %7, killed %9
$x10 = COPY %10
Expand Down

0 comments on commit 88bc29f

Please sign in to comment.