Skip to content

Commit

Permalink
[RISCV] Add support for fixed vector floating point setcc.
Browse files Browse the repository at this point in the history
This is annoying because the condition code legalization belongs
to LegalizeDAG, but our custom handler runs in Legalize vector ops
which occurs earlier.

This adds some of the mask binary operations so that we can combine
multiple compares that we need for expansion.

I've also fixed up RISCVISelDAGToDAG.cpp to handle copies of masks.

This patch contains a subset of the integer setcc patch as well.
That patch is dependent on the integer binary ops patch. I'll rebase
based on what order the patches go in.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D96567
  • Loading branch information
topperc committed Feb 15, 2021
1 parent eb75f25 commit 7ba2e1c
Show file tree
Hide file tree
Showing 6 changed files with 1,652 additions and 38 deletions.
18 changes: 14 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Expand Up @@ -849,13 +849,18 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
break;

// Bail when normal isel should do the job.
EVT InVT = Node->getOperand(1).getValueType();
MVT InVT = Node->getOperand(1).getSimpleValueType();
if (VT.isFixedLengthVector() || InVT.isScalableVector())
break;

unsigned RegClassID;
if (VT.getVectorElementType() == MVT::i1)
RegClassID = RISCV::VRRegClassID;
else
RegClassID = getRegClassIDForLMUL(getLMUL(VT));

SDValue V = Node->getOperand(1);
SDLoc DL(V);
unsigned RegClassID = getRegClassIDForLMUL(getLMUL(VT));
SDValue RC =
CurDAG->getTargetConstant(RegClassID, DL, Subtarget->getXLenVT());
SDNode *NewNode =
Expand All @@ -869,13 +874,18 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
break;

// Bail when normal isel can do the job.
EVT InVT = Node->getOperand(0).getValueType();
MVT InVT = Node->getOperand(0).getSimpleValueType();
if (VT.isScalableVector() || InVT.isFixedLengthVector())
break;

unsigned RegClassID;
if (InVT.getVectorElementType() == MVT::i1)
RegClassID = RISCV::VRRegClassID;
else
RegClassID = getRegClassIDForLMUL(getLMUL(InVT));

SDValue V = Node->getOperand(0);
SDLoc DL(V);
unsigned RegClassID = getRegClassIDForLMUL(getLMUL(InVT.getSimpleVT()));
SDValue RC =
CurDAG->getTargetConstant(RegClassID, DL, Subtarget->getXLenVT());
SDNode *NewNode =
Expand Down
89 changes: 87 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -580,6 +580,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, VT, Custom);
setOperationAction(ISD::FSQRT, VT, Custom);
setOperationAction(ISD::FMA, VT, Custom);

for (auto CC : VFPCCToExpand)
setCondCodeAction(CC, VT, Expand);
}
}
}
Expand Down Expand Up @@ -2137,10 +2140,89 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());

ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();

bool Invert = false;
Optional<unsigned> LogicOpc;
if (ContainerVT.isFloatingPoint()) {
bool Swap = false;
switch (CC) {
default:
break;
case ISD::SETULE:
case ISD::SETULT:
Swap = true;
LLVM_FALLTHROUGH;
case ISD::SETUGE:
case ISD::SETUGT:
CC = getSetCCInverse(CC, ContainerVT);
Invert = true;
break;
case ISD::SETOGE:
case ISD::SETOGT:
case ISD::SETGE:
case ISD::SETGT:
Swap = true;
break;
case ISD::SETUEQ:
// Use !((OLT Op1, Op2) || (OLT Op2, Op1))
Invert = true;
LogicOpc = RISCVISD::VMOR_VL;
CC = ISD::SETOLT;
break;
case ISD::SETONE:
// Use ((OLT Op1, Op2) || (OLT Op2, Op1))
LogicOpc = RISCVISD::VMOR_VL;
CC = ISD::SETOLT;
break;
case ISD::SETO:
// Use (OEQ Op1, Op1) && (OEQ Op2, Op2)
LogicOpc = RISCVISD::VMAND_VL;
CC = ISD::SETOEQ;
break;
case ISD::SETUO:
// Use (UNE Op1, Op1) || (UNE Op2, Op2)
LogicOpc = RISCVISD::VMOR_VL;
CC = ISD::SETUNE;
break;
}

if (Swap) {
CC = getSetCCSwappedOperands(CC);
std::swap(Op1, Op2);
}
}

MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
SDValue Cmp = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
Op.getOperand(2), Mask, VL);

// There are 3 cases we need to emit.
// 1. For (OEQ Op1, Op1) && (OEQ Op2, Op2) or (UNE Op1, Op1) || (UNE Op2, Op2)
// we need to compare each operand with itself.
// 2. For (OLT Op1, Op2) || (OLT Op2, Op1) we need to compare Op1 and Op2 in
// both orders.
// 3. For any other case we just need one compare with Op1 and Op2.
SDValue Cmp;
if (LogicOpc && (CC == ISD::SETOEQ || CC == ISD::SETUNE)) {
Cmp = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op1,
DAG.getCondCode(CC), Mask, VL);
SDValue Cmp2 = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op2, Op2,
DAG.getCondCode(CC), Mask, VL);
Cmp = DAG.getNode(*LogicOpc, DL, MaskVT, Cmp, Cmp2, VL);
} else {
Cmp = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
DAG.getCondCode(CC), Mask, VL);
if (LogicOpc) {
SDValue Cmp2 = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op2, Op1,
DAG.getCondCode(CC), Mask, VL);
Cmp = DAG.getNode(*LogicOpc, DL, MaskVT, Cmp, Cmp2, VL);
}
}

if (Invert) {
SDValue AllOnes = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
Cmp = DAG.getNode(RISCVISD::VMXOR_VL, DL, MaskVT, Cmp, AllOnes, VL);
}

return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
}
Expand Down Expand Up @@ -4778,6 +4860,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UMIN_VL)
NODE_NAME_CASE(UMAX_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VMAND_VL)
NODE_NAME_CASE(VMOR_VL)
NODE_NAME_CASE(VMXOR_VL)
NODE_NAME_CASE(VMCLR_VL)
NODE_NAME_CASE(VMSET_VL)
NODE_NAME_CASE(VRGATHER_VX_VL)
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -175,6 +175,11 @@ enum NodeType : unsigned {
// operand is VL.
SETCC_VL,

// Mask binary operators.
VMAND_VL,
VMOR_VL,
VMXOR_VL,

// Set mask vector to all zeros or ones.
VMCLR_VL,
VMSET_VL,
Expand Down
39 changes: 12 additions & 27 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Expand Up @@ -273,43 +273,28 @@ multiclass VPatIntegerSetCCSDNode_VX_VI<CondCode cc,
SplatPat_simm5, simm5, swap>;
}

multiclass VPatFPSetCCSDNode_VV<CondCode cc, string instruction_name> {
foreach fvti = AllFloatVectors in
multiclass VPatFPSetCCSDNode_VV_VF_FV<CondCode cc,
string inst_name,
string swapped_op_inst_name> {
foreach fvti = AllFloatVectors in {
def : Pat<(fvti.Mask (setcc (fvti.Vector fvti.RegClass:$rs1),
(fvti.Vector fvti.RegClass:$rs2),
cc)),
(!cast<Instruction>(instruction_name#"_VV_"#fvti.LMul.MX)
(!cast<Instruction>(inst_name#"_VV_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.RegClass:$rs2, fvti.AVL, fvti.SEW)>;
}

multiclass VPatFPSetCCSDNode_VF<CondCode cc, string instruction_name> {
foreach fvti = AllFloatVectors in
def : Pat<(fvti.Mask (setcc (fvti.Vector fvti.RegClass:$rs1),
(fvti.Vector (splat_vector fvti.ScalarRegClass:$rs2)),
(splat_vector fvti.ScalarRegClass:$rs2),
cc)),
(!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1,
(fvti.Scalar fvti.ScalarRegClass:$rs2),
(!cast<Instruction>(inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
fvti.AVL, fvti.SEW)>;
}

multiclass VPatFPSetCCSDNode_FV<CondCode cc, string swapped_op_instruction_name> {
foreach fvti = AllFloatVectors in
def : Pat<(fvti.Mask (setcc (fvti.Vector (splat_vector fvti.ScalarRegClass:$rs2)),
def : Pat<(fvti.Mask (setcc (splat_vector fvti.ScalarRegClass:$rs2),
(fvti.Vector fvti.RegClass:$rs1),
cc)),
(!cast<Instruction>(swapped_op_instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1,
(fvti.Scalar fvti.ScalarRegClass:$rs2),
(!cast<Instruction>(swapped_op_inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
fvti.AVL, fvti.SEW)>;
}

multiclass VPatFPSetCCSDNode_VV_VF_FV<CondCode cc,
string inst_name,
string swapped_op_inst_name> {
defm : VPatFPSetCCSDNode_VV<cc, inst_name>;
defm : VPatFPSetCCSDNode_VF<cc, inst_name>;
defm : VPatFPSetCCSDNode_FV<cc, swapped_op_inst_name>;
}
}

multiclass VPatExtendSDNode_V<list<SDNode> ops, string inst_name, string suffix,
Expand Down
83 changes: 78 additions & 5 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Expand Up @@ -95,8 +95,7 @@ def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
def riscv_fma_vl : SDNode<"RISCVISD::FMA_VL", SDT_RISCVVecFMA_VL>;

def riscv_setcc_vl : SDNode<"RISCVISD::SETCC_VL",
SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCVecEltisVT<0, i1>,
SDTypeProfile<1, 5, [SDTCVecEltisVT<0, i1>,
SDTCisVec<1>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisSameAs<1, 2>,
Expand All @@ -112,8 +111,15 @@ def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL",
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;

def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCisVec<0>,
SDTCVecEltisVT<0, i1>,
def SDT_RISCVMaskBinOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
SDTCVecEltisVT<0, i1>,
SDTCisVT<3, XLenVT>]>;
def riscv_vmand_vl : SDNode<"RISCVISD::VMAND_VL", SDT_RISCVMaskBinOp_VL, [SDNPCommutative]>;
def riscv_vmor_vl : SDNode<"RISCVISD::VMOR_VL", SDT_RISCVMaskBinOp_VL, [SDNPCommutative]>;
def riscv_vmxor_vl : SDNode<"RISCVISD::VMXOR_VL", SDT_RISCVMaskBinOp_VL, [SDNPCommutative]>;

def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
SDTCisVT<1, XLenVT>]>;
def riscv_vmclr_vl : SDNode<"RISCVISD::VMCLR_VL", SDT_RISCVVMSETCLR_VL>;
def riscv_vmset_vl : SDNode<"RISCVISD::VMSET_VL", SDT_RISCVVMSETCLR_VL>;
Expand Down Expand Up @@ -302,6 +308,36 @@ multiclass VPatIntegerSetCCVL_VI_Swappable<VTypeInfo vti, string instruction_nam
(instruction vti.RegClass:$rs1, simm5:$rs2, GPR:$vl, vti.SEW)>;
}

multiclass VPatFPSetCCVL_VV_VF_FV<CondCode cc,
string inst_name,
string swapped_op_inst_name> {
foreach fvti = AllFloatVectors in {
def : Pat<(fvti.Mask (riscv_setcc_vl (fvti.Vector fvti.RegClass:$rs1),
fvti.RegClass:$rs2,
cc,
(fvti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>(inst_name#"_VV_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.RegClass:$rs2, GPR:$vl, fvti.SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (fvti.Vector fvti.RegClass:$rs1),
(SplatFPOp fvti.ScalarRegClass:$rs2),
cc,
(fvti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>(inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
GPR:$vl, fvti.SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (SplatFPOp fvti.ScalarRegClass:$rs2),
(fvti.Vector fvti.RegClass:$rs1),
cc,
(fvti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>(swapped_op_inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
GPR:$vl, fvti.SEW)>;
}
}

//===----------------------------------------------------------------------===//
// Patterns.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -451,6 +487,21 @@ foreach vti = AllFloatVectors in {
GPR:$vl, vti.SEW)>;
}

// 14.11. Vector Floating-Point Compare Instructions
defm "" : VPatFPSetCCVL_VV_VF_FV<SETEQ, "PseudoVMFEQ", "PseudoVMFEQ">;
defm "" : VPatFPSetCCVL_VV_VF_FV<SETOEQ, "PseudoVMFEQ", "PseudoVMFEQ">;

defm "" : VPatFPSetCCVL_VV_VF_FV<SETNE, "PseudoVMFNE", "PseudoVMFNE">;
defm "" : VPatFPSetCCVL_VV_VF_FV<SETUNE, "PseudoVMFNE", "PseudoVMFNE">;

defm "" : VPatFPSetCCVL_VV_VF_FV<SETLT, "PseudoVMFLT", "PseudoVMFGT">;
defm "" : VPatFPSetCCVL_VV_VF_FV<SETOLT, "PseudoVMFLT", "PseudoVMFGT">;

defm "" : VPatFPSetCCVL_VV_VF_FV<SETLE, "PseudoVMFLE", "PseudoVMFGE">;
defm "" : VPatFPSetCCVL_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">;

// 14.12. Vector Floating-Point Sign-Injection Instructions
// Handle fneg with VFSGNJN using the same input for both operands.
foreach vti = AllFloatVectors in {
// 14.8. Vector Floating-Point Square-Root Instruction
def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask true_mask),
Expand Down Expand Up @@ -496,6 +547,28 @@ foreach mti = AllMasks in {
(!cast<Instruction>("PseudoVMSET_M_" # mti.BX) GPR:$vl, mti.SEW)>;
def : Pat<(mti.Mask (riscv_vmclr_vl (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMCLR_M_" # mti.BX) GPR:$vl, mti.SEW)>;

def : Pat<(mti.Mask (riscv_vmand_vl VR:$rs1, VR:$rs2, (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMAND_MM_" # mti.LMul.MX)
VR:$rs1, VR:$rs2, GPR:$vl, mti.SEW)>;
def : Pat<(mti.Mask (riscv_vmor_vl VR:$rs1, VR:$rs2, (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMOR_MM_" # mti.LMul.MX)
VR:$rs1, VR:$rs2, GPR:$vl, mti.SEW)>;
def : Pat<(mti.Mask (riscv_vmxor_vl VR:$rs1, VR:$rs2, (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMXOR_MM_" # mti.LMul.MX)
VR:$rs1, VR:$rs2, GPR:$vl, mti.SEW)>;

// FIXME: Add remaining mask instructions.
def : Pat<(mti.Mask (riscv_vmxor_vl (riscv_vmor_vl VR:$rs1, VR:$rs2,
(XLenVT (VLOp GPR:$vl))),
true_mask, (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMNOR_MM_" # mti.LMul.MX)
VR:$rs1, VR:$rs2, GPR:$vl, mti.SEW)>;

// Match the not idiom to the vnot.mm pseudo.
def : Pat<(mti.Mask (riscv_vmxor_vl VR:$rs, true_mask, (XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMNAND_MM_" # mti.LMul.MX)
VR:$rs, VR:$rs, GPR:$vl, mti.SEW)>;
}

} // Predicates = [HasStdExtV]
Expand Down Expand Up @@ -540,7 +613,7 @@ foreach vti = AllFloatVectors in {
//===----------------------------------------------------------------------===//

def riscv_vid_vl : SDNode<"RISCVISD::VID_VL", SDTypeProfile<1, 2,
[SDTCisVec<0>, SDTCisVec<1>, SDTCVecEltisVT<1, i1>,
[SDTCisVec<0>, SDTCVecEltisVT<1, i1>,
SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, XLenVT>]>, []>;

def SDTRVVSlide : SDTypeProfile<1, 5, [
Expand Down

0 comments on commit 7ba2e1c

Please sign in to comment.