Skip to content

Commit

Permalink
[RISCV] Add RISCVISD::BR_CC similar to RISCVISD::SELECT_CC.
Browse files Browse the repository at this point in the history
This allows me to introduce similar combines for branches as
we have recently added for SELECT_CC. Some of them are less
useful for standalone setccs and only help branch instructions.
By having a BR_CC node it's easier to only affect branches.

I'm using CondCodeSDNode to make isel patterns easier to
write so we can refer to the codes by name. SELECT_CC uses a
constant instead.

I've translated the condition code just like SELECT_CC so
we need fewer patterns for the swapped conditions. This
includes special cases for X < 1 and X > -1 that get translated
to blez and bgez by using a 0 constant.

computeKnownBitsForTargetNode support for SELECT_CC is added
to allow MaskedValueIsZero to work for cases where the true
and false values of the SELECT_CC are setccs and the
result of the SELECT_CC is used by a BR_CC. This was needed
to avoid regressions in some of the overflow tests.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D98159
  • Loading branch information
topperc committed Mar 15, 2021
1 parent f675b3d commit 41759c3
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 57 deletions.
96 changes: 83 additions & 13 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BR_CC, XLenVT, Expand);
setOperationAction(ISD::BRCOND, MVT::Other, Custom);
setOperationAction(ISD::SELECT_CC, XLenVT, Expand);

setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
Expand Down Expand Up @@ -676,7 +677,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// We can use any register for comparisons
setHasMultipleConditionRegisters();

setTargetDAGCombine(ISD::SETCC);
if (Subtarget.hasStdExtZbp()) {
setTargetDAGCombine(ISD::OR);
}
Expand Down Expand Up @@ -1207,6 +1207,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerGlobalTLSAddress(Op, DAG);
case ISD::SELECT:
return lowerSELECT(Op, DAG);
case ISD::BRCOND:
return lowerBRCOND(Op, DAG);
case ISD::VASTART:
return lowerVASTART(Op, DAG);
case ISD::FRAMEADDR:
Expand Down Expand Up @@ -1906,6 +1908,29 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(RISCVISD::SELECT_CC, DL, Op.getValueType(), Ops);
}

/// Custom lowering for ISD::BRCOND: always produces a RISCVISD::BR_CC node
/// of the form (br_cc chain, lhs, rhs, condcode, dest).
///
/// When the branch condition is a SETCC whose compared operands are XLenVT,
/// the comparison is folded straight into the branch. Any other condition
/// value is branched on as "cond != 0".
SDValue RISCVTargetLowering::lowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
  SDLoc DL(Op);
  const MVT XLenVT = Subtarget.getXLenVT();

  // Operand 0 is forwarded as the first BR_CC operand and operand 2 as the
  // last one; operand 1 is the condition being tested.
  SDValue Chain = Op.getOperand(0);
  SDValue Cond = Op.getOperand(1);
  SDValue Dest = Op.getOperand(2);

  const bool IsXLenSetCC = Cond.getOpcode() == ISD::SETCC &&
                           Cond.getOperand(0).getValueType() == XLenVT;

  if (!IsXLenSetCC) {
    // No foldable comparison: branch when the condition value is non-zero.
    return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Chain, Cond,
                       DAG.getConstant(0, DL, XLenVT),
                       DAG.getCondCode(ISD::SETNE), Dest);
  }

  SDValue CmpLHS = Cond.getOperand(0);
  SDValue CmpRHS = Cond.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();

  // NOTE(review): translateSetCCForBranch is defined outside this view; it
  // presumably rewrites the operands and condition code in place into a form
  // the branch patterns can match -- confirm against its definition.
  translateSetCCForBranch(DL, CmpLHS, CmpRHS, CC, DAG);

  SDValue BranchCC = DAG.getCondCode(CC);
  return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Chain, CmpLHS,
                     CmpRHS, BranchCC, Dest);
}

SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
RISCVMachineFunctionInfo *FuncInfo = MF.getInfo<RISCVMachineFunctionInfo>();
Expand Down Expand Up @@ -4234,21 +4259,54 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

break;
}
case ISD::SETCC: {
// (setcc X, 1, setne) -> (setcc X, 0, seteq) if we can prove X is 0/1.
// Comparing with 0 may allow us to fold into bnez/beqz.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
if (LHS.getValueType().isScalableVector())
case RISCVISD::BR_CC: {
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(3))->get();
if (!ISD::isIntEqualitySetCC(CCVal))
break;
auto CC = cast<CondCodeSDNode>(N->getOperand(2))->get();

// Fold (br_cc (setlt X, Y), 0, ne, dest) ->
// (br_cc X, Y, lt, dest)
// Sometimes the setcc is introduced after br_cc has been formed.
if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
// If we're looking for eq 0 instead of ne 0, we need to invert the
// condition.
bool Invert = CCVal == ISD::SETEQ;
CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
if (Invert)
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());

SDLoc DL(N);
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);

return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
N->getOperand(0), LHS, RHS, DAG.getCondCode(CCVal),
N->getOperand(4));
}

// Fold (br_cc (xor X, Y), 0, eq/ne, dest) ->
// (br_cc X, Y, eq/ne, dest)
if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS))
return DAG.getNode(RISCVISD::BR_CC, SDLoc(N), N->getValueType(0),
N->getOperand(0), LHS.getOperand(0), LHS.getOperand(1),
N->getOperand(3), N->getOperand(4));

// (br_cc X, 1, setne, dest) ->
// (br_cc X, 0, seteq, dest) if we can prove X is 0/1.
// This can occur when legalizing some floating point comparisons.
APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
if (isOneConstant(RHS) && ISD::isIntEqualitySetCC(CC) &&
DAG.MaskedValueIsZero(LHS, Mask)) {
if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
SDLoc DL(N);
SDValue Zero = DAG.getConstant(0, DL, LHS.getValueType());
CC = ISD::getSetCCInverse(CC, LHS.getValueType());
return DAG.getSetCC(DL, N->getValueType(0), LHS, Zero, CC);
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDValue TargetCC = DAG.getCondCode(CCVal);
RHS = DAG.getConstant(0, DL, LHS.getValueType());
return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
N->getOperand(0), LHS, RHS, TargetCC,
N->getOperand(4));
}
break;
}
Expand Down Expand Up @@ -4409,6 +4467,17 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known.resetAll();
switch (Opc) {
default: break;
case RISCVISD::SELECT_CC: {
Known = DAG.computeKnownBits(Op.getOperand(4), Depth + 1);
// If we don't know any bits, early out.
if (Known.isUnknown())
break;
KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(3), Depth + 1);

// Only known if known in both the true and false values.
Known = KnownBits::commonBits(Known, Known2);
break;
}
case RISCVISD::REMUW: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Expand Down Expand Up @@ -6155,6 +6224,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(MRET_FLAG)
NODE_NAME_CASE(CALL)
NODE_NAME_CASE(SELECT_CC)
NODE_NAME_CASE(BR_CC)
NODE_NAME_CASE(BuildPairF64)
NODE_NAME_CASE(SplitF64)
NODE_NAME_CASE(TAIL)
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum NodeType : unsigned {
/// The lhs and rhs are XLenVT integers. The true and false values can be
/// integer or floating point.
SELECT_CC,
BR_CC,
BuildPairF64,
SplitF64,
TAIL,
Expand Down Expand Up @@ -452,6 +453,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerBRCOND(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
Expand Down
50 changes: 14 additions & 36 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def SDT_RISCVCall : SDTypeProfile<0, -1, [SDTCisVT<0, XLenVT>]>;
def SDT_RISCVSelectCC : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>,
SDTCisSameAs<0, 4>,
SDTCisSameAs<4, 5>]>;
// Type profile used by the RISCVISD::BR_CC node: no results and four
// operands. Operands 0 and 1 must have the same type; operands 2 and 3 are
// both OtherVT.
// NOTE(review): operands 2 and 3 appear to be the condition code and the
// destination basic block, going by how riscv_brcc is matched in BccPat --
// confirm.
def SDT_RISCVBrCC : SDTypeProfile<0, 4, [SDTCisSameAs<0, 1>,
SDTCisVT<2, OtherVT>,
SDTCisVT<3, OtherVT>]>;
def SDT_RISCVReadCycleWide : SDTypeProfile<2, 0, [SDTCisVT<0, i32>,
SDTCisVT<1, i32>]>;
def SDT_RISCVIntBinOpW : SDTypeProfile<1, 2, [
Expand All @@ -50,6 +53,8 @@ def riscv_sret_flag : SDNode<"RISCVISD::SRET_FLAG", SDTNone,
def riscv_mret_flag : SDNode<"RISCVISD::MRET_FLAG", SDTNone,
[SDNPHasChain, SDNPOptInGlue]>;
def riscv_selectcc : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC>;
def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
[SDNPHasChain]>;
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
Expand Down Expand Up @@ -961,44 +966,17 @@ def Select_GPR_Using_CC_GPR : SelectCC_rrirr<GPR, GPR>;

/// Branches and jumps

// Match `(brcond (CondOp ..), ..)` and lower to the appropriate RISC-V branch
// instruction.
class BccPat<PatFrag CondOp, RVInstB Inst>
: Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12),
// Match `riscv_brcc` and lower to the appropriate RISC-V branch instruction.
class BccPat<CondCode Cond, RVInstB Inst>
: Pat<(riscv_brcc GPR:$rs1, GPR:$rs2, Cond, bb:$imm12),
(Inst GPR:$rs1, GPR:$rs2, simm13_lsb0:$imm12)>;

def : BccPat<seteq, BEQ>;
def : BccPat<setne, BNE>;
def : BccPat<setlt, BLT>;
def : BccPat<setge, BGE>;
def : BccPat<setult, BLTU>;
def : BccPat<setuge, BGEU>;

class BccSwapPat<PatFrag CondOp, RVInst InstBcc>
: Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12),
(InstBcc GPR:$rs2, GPR:$rs1, bb:$imm12)>;

// Condition codes that don't have matching RISC-V branch instructions, but
// are trivially supported by swapping the two input operands
def : BccSwapPat<setgt, BLT>;
def : BccSwapPat<setle, BGE>;
def : BccSwapPat<setugt, BLTU>;
def : BccSwapPat<setule, BGEU>;

// Extra patterns are needed for a brcond without a setcc (i.e. where the
// condition was calculated elsewhere).
def : Pat<(brcond GPR:$cond, bb:$imm12), (BNE GPR:$cond, X0, bb:$imm12)>;
// In this pattern, the `(xor $cond, 1)` functions like (boolean) `not`, as the
// `brcond` only uses the lowest bit.
def : Pat<(brcond (XLenVT (xor GPR:$cond, 1)), bb:$imm12),
(BEQ GPR:$cond, X0, bb:$imm12)>;

// Match X > -1, the canonical form of X >= 0, to the bgez pattern.
def : Pat<(brcond (XLenVT (setgt GPR:$rs1, -1)), bb:$imm12),
(BGE GPR:$rs1, X0, bb:$imm12)>;
// Lower (a < 1) as (0 >= a) into the blez pattern.
def : Pat<(brcond (XLenVT (setlt GPR:$lhs, 1)), bb:$imm12),
(BGE X0, GPR:$lhs, bb:$imm12)>;
def : BccPat<SETEQ, BEQ>;
def : BccPat<SETNE, BNE>;
def : BccPat<SETLT, BLT>;
def : BccPat<SETGE, BGE>;
def : BccPat<SETULT, BLTU>;
def : BccPat<SETUGE, BGEU>;

let isBarrier = 1, isBranch = 1, isTerminator = 1 in
def PseudoBR : Pseudo<(outs), (ins simm21_lsb0_jal:$imm20), [(br bb:$imm20)]>,
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/CodeGen/RISCV/xaluo.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1423,8 +1423,7 @@ define zeroext i1 @saddo.br.i32(i32 %v1, i32 %v2) {
; RV32-NEXT: add a2, a0, a1
; RV32-NEXT: slt a0, a2, a0
; RV32-NEXT: slti a1, a1, 0
; RV32-NEXT: xor a0, a1, a0
; RV32-NEXT: beqz a0, .LBB46_2
; RV32-NEXT: beq a1, a0, .LBB46_2
; RV32-NEXT: # %bb.1: # %overflow
; RV32-NEXT: mv a0, zero
; RV32-NEXT: ret
Expand Down Expand Up @@ -1482,8 +1481,7 @@ define zeroext i1 @saddo.br.i64(i64 %v1, i64 %v2) {
; RV64-NEXT: add a2, a0, a1
; RV64-NEXT: slt a0, a2, a0
; RV64-NEXT: slti a1, a1, 0
; RV64-NEXT: xor a0, a1, a0
; RV64-NEXT: beqz a0, .LBB47_2
; RV64-NEXT: beq a1, a0, .LBB47_2
; RV64-NEXT: # %bb.1: # %overflow
; RV64-NEXT: mv a0, zero
; RV64-NEXT: ret
Expand Down Expand Up @@ -1587,8 +1585,7 @@ define zeroext i1 @ssubo.br.i32(i32 %v1, i32 %v2) {
; RV32-NEXT: sgtz a2, a1
; RV32-NEXT: sub a1, a0, a1
; RV32-NEXT: slt a0, a1, a0
; RV32-NEXT: xor a0, a2, a0
; RV32-NEXT: beqz a0, .LBB50_2
; RV32-NEXT: beq a2, a0, .LBB50_2
; RV32-NEXT: # %bb.1: # %overflow
; RV32-NEXT: mv a0, zero
; RV32-NEXT: ret
Expand Down Expand Up @@ -1644,8 +1641,7 @@ define zeroext i1 @ssubo.br.i64(i64 %v1, i64 %v2) {
; RV64-NEXT: sgtz a2, a1
; RV64-NEXT: sub a1, a0, a1
; RV64-NEXT: slt a0, a1, a0
; RV64-NEXT: xor a0, a2, a0
; RV64-NEXT: beqz a0, .LBB51_2
; RV64-NEXT: beq a2, a0, .LBB51_2
; RV64-NEXT: # %bb.1: # %overflow
; RV64-NEXT: mv a0, zero
; RV64-NEXT: ret
Expand Down

0 comments on commit 41759c3

Please sign in to comment.