Skip to content

Commit

Permalink
[RISCV][VP] Add basic RVV codegen for vp.fcmp
Browse files Browse the repository at this point in the history
This patch adds the necessary infrastructure to lower vp.fcmp via
ISD::VP_SETCC to RVV instructions.

Most notably this patch adds cond-code legalization for VP_SETCC,
reusing the existing TargetLowering::LegalizeSetCCCondCode by passing in
additional SDValue parameters for the Mask and EVL. This method then
uses VP operations to legalize the condcode.

There is still a general lack of canonicalization on VP_SETCC as opposed
to SETCC which results in worse code than is theoretically possible.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D123051
  • Loading branch information
frasercrmck committed Apr 7, 2022
1 parent b8f50ab commit 8216255
Show file tree
Hide file tree
Showing 10 changed files with 3,484 additions and 41 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Expand Up @@ -905,6 +905,11 @@ class SelectionDAG {
/// Create a logical NOT operation as (XOR Val, BooleanOne).
SDValue getLogicalNOT(const SDLoc &DL, SDValue Val, EVT VT);

/// Create a vector-predicated logical NOT operation as (VP_XOR Val,
/// BooleanOne, Mask, EVL).
SDValue getVPLogicalNOT(const SDLoc &DL, SDValue Val, SDValue Mask,
SDValue EVL, EVT VT);

/// Returns sum of the base pointer and offset.
/// Unlike getObjectPtrOffset this does not set NoUnsignedWrap by default.
SDValue getMemBasePlusOffset(SDValue Base, TypeSize Offset, const SDLoc &DL,
Expand Down
34 changes: 19 additions & 15 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Expand Up @@ -4745,28 +4745,32 @@ class TargetLowering : public TargetLoweringBase {
/// method accepts vectors as its arguments.
SDValue expandVectorSplice(SDNode *Node, SelectionDAG &DAG) const;

/// Legalize a SETCC with given LHS and RHS and condition code CC on the
/// current target.
/// Legalize a SETCC or VP_SETCC with given LHS and RHS and condition code CC
/// on the current target. A VP_SETCC will additionally be given a Mask
/// and/or EVL not equal to SDValue().
///
/// If the SETCC has been legalized using AND / OR, then the legalized node
/// will be stored in LHS. RHS and CC will be set to SDValue(). NeedInvert
/// will be set to false.
/// will be set to false. This will also hold if the VP_SETCC has been
/// legalized using VP_AND / VP_OR.
///
/// If the SETCC has been legalized by using getSetCCSwappedOperands(),
/// then the values of LHS and RHS will be swapped, CC will be set to the
/// new condition, and NeedInvert will be set to false.
/// If the SETCC / VP_SETCC has been legalized by using
/// getSetCCSwappedOperands(), then the values of LHS and RHS will be
/// swapped, CC will be set to the new condition, and NeedInvert will be set
/// to false.
///
/// If the SETCC has been legalized using the inverse condcode, then LHS and
/// RHS will be unchanged, CC will set to the inverted condcode, and
/// NeedInvert will be set to true. The caller must invert the result of the
/// SETCC with SelectionDAG::getLogicalNOT() or take equivalent action to swap
/// the effect of a true/false result.
/// If the SETCC / VP_SETCC has been legalized using the inverse condcode,
/// then LHS and RHS will be unchanged, CC will set to the inverted condcode,
/// and NeedInvert will be set to true. The caller must invert the result of
/// the SETCC with SelectionDAG::getLogicalNOT() or take equivalent action to
/// swap the effect of a true/false result.
///
/// \returns true if the SetCC has been legalized, false if it hasn't.
/// \returns true if the SETCC / VP_SETCC has been legalized, false if it
/// hasn't.
bool LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT, SDValue &LHS,
SDValue &RHS, SDValue &CC, bool &NeedInvert,
const SDLoc &dl, SDValue &Chain,
bool IsSignaling = false) const;
SDValue &RHS, SDValue &CC, SDValue Mask,
SDValue EVL, bool &NeedInvert, const SDLoc &dl,
SDValue &Chain, bool IsSignaling = false) const;

//===--------------------------------------------------------------------===//
// Instruction Emitting Hooks
Expand Down
6 changes: 3 additions & 3 deletions llvm/include/llvm/IR/VPIntrinsics.def
Expand Up @@ -302,11 +302,11 @@ HELPER_REGISTER_INT_CAST_VP(inttoptr, VP_INTTOPTR, IntToPtr)

///// Comparisons {
// llvm.vp.fcmp(x,y,cc,mask,vlen)
BEGIN_REGISTER_VP(vp_fcmp, 3, 4, VP_FCMP, -1)
BEGIN_REGISTER_VP_INTRINSIC(vp_fcmp, 3, 4)
VP_PROPERTY_FUNCTIONAL_OPC(FCmp)
VP_PROPERTY_CMP(2, true)
VP_PROPERTY_CONSTRAINEDFP(0, 1, experimental_constrained_fcmp)
END_REGISTER_VP(vp_fcmp, VP_FCMP)
END_REGISTER_VP_INTRINSIC(vp_fcmp)

// llvm.vp.icmp(x,y,cc,mask,vlen)
BEGIN_REGISTER_VP_INTRINSIC(vp_icmp, 3, 4)
Expand All @@ -315,7 +315,7 @@ VP_PROPERTY_CMP(2, false)
END_REGISTER_VP_INTRINSIC(vp_icmp)

// VP_SETCC (ISel only)
BEGIN_REGISTER_VP_SDNODE(VP_SETCC, -1, vp_setcc, 3, 4)
BEGIN_REGISTER_VP_SDNODE(VP_SETCC, 0, vp_setcc, 3, 4)
END_REGISTER_VP_SDNODE(VP_SETCC)

///// } Comparisons
Expand Down
37 changes: 27 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Expand Up @@ -3589,18 +3589,26 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Tmp1);
break;
case ISD::SETCC:
case ISD::VP_SETCC:
case ISD::STRICT_FSETCC:
case ISD::STRICT_FSETCCS: {
bool IsStrict = Node->getOpcode() != ISD::SETCC;
bool IsVP = Node->getOpcode() == ISD::VP_SETCC;
bool IsStrict = Node->getOpcode() == ISD::STRICT_FSETCC ||
Node->getOpcode() == ISD::STRICT_FSETCCS;
bool IsSignaling = Node->getOpcode() == ISD::STRICT_FSETCCS;
SDValue Chain = IsStrict ? Node->getOperand(0) : SDValue();
unsigned Offset = IsStrict ? 1 : 0;
Tmp1 = Node->getOperand(0 + Offset);
Tmp2 = Node->getOperand(1 + Offset);
Tmp3 = Node->getOperand(2 + Offset);
bool Legalized =
TLI.LegalizeSetCCCondCode(DAG, Node->getValueType(0), Tmp1, Tmp2, Tmp3,
NeedInvert, dl, Chain, IsSignaling);
SDValue Mask, EVL;
if (IsVP) {
Mask = Node->getOperand(3 + Offset);
EVL = Node->getOperand(4 + Offset);
}
bool Legalized = TLI.LegalizeSetCCCondCode(
DAG, Node->getValueType(0), Tmp1, Tmp2, Tmp3, Mask, EVL, NeedInvert, dl,
Chain, IsSignaling);

if (Legalized) {
// If we expanded the SETCC by swapping LHS and RHS, or by inverting the
Expand All @@ -3610,6 +3618,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Tmp1 = DAG.getNode(Node->getOpcode(), dl, Node->getVTList(),
{Chain, Tmp1, Tmp2, Tmp3}, Node->getFlags());
Chain = Tmp1.getValue(1);
} else if (IsVP) {
Tmp1 = DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0),
{Tmp1, Tmp2, Tmp3, Mask, EVL}, Node->getFlags());
} else {
Tmp1 = DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Tmp1,
Tmp2, Tmp3, Node->getFlags());
Expand All @@ -3618,8 +3629,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {

// If we expanded the SETCC by inverting the condition code, then wrap
// the existing SETCC in a NOT to restore the intended condition.
if (NeedInvert)
Tmp1 = DAG.getLogicalNOT(dl, Tmp1, Tmp1->getValueType(0));
if (NeedInvert) {
if (!IsVP)
Tmp1 = DAG.getLogicalNOT(dl, Tmp1, Tmp1->getValueType(0));
else
Tmp1 =
DAG.getVPLogicalNOT(dl, Tmp1, Mask, EVL, Tmp1->getValueType(0));
}

Results.push_back(Tmp1);
if (IsStrict)
Expand All @@ -3634,6 +3650,7 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {

// Otherwise, SETCC for the given comparison type must be completely
// illegal; expand it into a SELECT_CC.
// FIXME: This drops the mask/evl for VP_SETCC.
EVT VT = Node->getValueType(0);
EVT Tmp1VT = Tmp1.getValueType();
Tmp1 = DAG.getNode(ISD::SELECT_CC, dl, VT, Tmp1, Tmp2,
Expand Down Expand Up @@ -3694,7 +3711,7 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
if (!Legalized) {
Legalized = TLI.LegalizeSetCCCondCode(
DAG, getSetCCResultType(Tmp1.getValueType()), Tmp1, Tmp2, CC,
NeedInvert, dl, Chain);
/*Mask*/ SDValue(), /*EVL*/ SDValue(), NeedInvert, dl, Chain);

assert(Legalized && "Can't legalize SELECT_CC with legal condition!");

Expand Down Expand Up @@ -3727,9 +3744,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Tmp3 = Node->getOperand(3); // RHS
Tmp4 = Node->getOperand(1); // CC

bool Legalized =
TLI.LegalizeSetCCCondCode(DAG, getSetCCResultType(Tmp2.getValueType()),
Tmp2, Tmp3, Tmp4, NeedInvert, dl, Chain);
bool Legalized = TLI.LegalizeSetCCCondCode(
DAG, getSetCCResultType(Tmp2.getValueType()), Tmp2, Tmp3, Tmp4,
/*Mask*/ SDValue(), /*EVL*/ SDValue(), NeedInvert, dl, Chain);
(void)Legalized;
assert(Legalized && "Can't legalize BR_CC with legal condition!");

Expand Down
38 changes: 31 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Expand Up @@ -461,6 +461,12 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VPID: { \
EVT LegalizeVT = LEGALPOS < 0 ? Node->getValueType(-(1 + LEGALPOS)) \
: Node->getOperand(LEGALPOS).getValueType(); \
if (ISD::VPID == ISD::VP_SETCC) { \
ISD::CondCode CCCode = cast<CondCodeSDNode>(Node->getOperand(2))->get(); \
Action = TLI.getCondCodeAction(CCCode, LegalizeVT.getSimpleVT()); \
if (Action != TargetLowering::Legal) \
break; \
} \
Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \
} break;
#include "llvm/IR/VPIntrinsics.def"
Expand Down Expand Up @@ -744,6 +750,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
ExpandFSUB(Node, Results);
return;
case ISD::SETCC:
case ISD::VP_SETCC:
ExpandSETCC(Node, Results);
return;
case ISD::ABS:
Expand Down Expand Up @@ -1417,6 +1424,7 @@ void VectorLegalizer::ExpandFSUB(SDNode *Node,
void VectorLegalizer::ExpandSETCC(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
bool NeedInvert = false;
bool IsVP = Node->getOpcode() == ISD::VP_SETCC;
SDLoc dl(Node);
MVT OpVT = Node->getOperand(0).getSimpleValueType();
ISD::CondCode CCCode = cast<CondCodeSDNode>(Node->getOperand(2))->get();
Expand All @@ -1430,20 +1438,36 @@ void VectorLegalizer::ExpandSETCC(SDNode *Node,
SDValue LHS = Node->getOperand(0);
SDValue RHS = Node->getOperand(1);
SDValue CC = Node->getOperand(2);
bool Legalized = TLI.LegalizeSetCCCondCode(DAG, Node->getValueType(0), LHS,
RHS, CC, NeedInvert, dl, Chain);
SDValue Mask, EVL;
if (IsVP) {
Mask = Node->getOperand(3);
EVL = Node->getOperand(4);
}

bool Legalized =
TLI.LegalizeSetCCCondCode(DAG, Node->getValueType(0), LHS, RHS, CC, Mask,
EVL, NeedInvert, dl, Chain);

if (Legalized) {
// If we expanded the SETCC by swapping LHS and RHS, or by inverting the
// condition code, create a new SETCC node.
if (CC.getNode())
LHS = DAG.getNode(ISD::SETCC, dl, Node->getValueType(0), LHS, RHS, CC,
Node->getFlags());
if (CC.getNode()) {
if (!IsVP)
LHS = DAG.getNode(ISD::SETCC, dl, Node->getValueType(0), LHS, RHS, CC,
Node->getFlags());
else
LHS = DAG.getNode(ISD::VP_SETCC, dl, Node->getValueType(0),
{LHS, RHS, CC, Mask, EVL}, Node->getFlags());
}

// If we expanded the SETCC by inverting the condition code, then wrap
// the existing SETCC in a NOT to restore the intended condition.
if (NeedInvert)
LHS = DAG.getLogicalNOT(dl, LHS, LHS->getValueType(0));
if (NeedInvert) {
if (!IsVP)
LHS = DAG.getLogicalNOT(dl, LHS, LHS->getValueType(0));
else
LHS = DAG.getVPLogicalNOT(dl, LHS, Mask, EVL, LHS->getValueType(0));
}
} else {
// Otherwise, SETCC for the given comparison type must be completely
// illegal; expand it into a SELECT_CC.
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -1425,6 +1425,12 @@ SDValue SelectionDAG::getLogicalNOT(const SDLoc &DL, SDValue Val, EVT VT) {
return getNode(ISD::XOR, DL, VT, Val, TrueValue);
}

SDValue SelectionDAG::getVPLogicalNOT(const SDLoc &DL, SDValue Val,
SDValue Mask, SDValue EVL, EVT VT) {
SDValue TrueValue = getBoolConstant(true, DL, VT, VT);
return getNode(ISD::VP_XOR, DL, VT, Val, TrueValue, Mask, EVL);
}

SDValue SelectionDAG::getBoolConstant(bool V, const SDLoc &DL, EVT VT,
EVT OpVT) {
if (!V)
Expand Down
32 changes: 26 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Expand Up @@ -9199,13 +9199,16 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,

bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
SDValue &LHS, SDValue &RHS,
SDValue &CC, bool &NeedInvert,
SDValue &CC, SDValue Mask,
SDValue EVL, bool &NeedInvert,
const SDLoc &dl, SDValue &Chain,
bool IsSignaling) const {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
MVT OpVT = LHS.getSimpleValueType();
ISD::CondCode CCCode = cast<CondCodeSDNode>(CC)->get();
NeedInvert = false;
assert(!EVL == !Mask && "VP Mask and EVL must either both be set or unset");
bool IsNonVP = !EVL;
switch (TLI.getCondCodeAction(CCCode, OpVT)) {
default:
llvm_unreachable("Unknown condition code action!");
Expand Down Expand Up @@ -9312,17 +9315,34 @@ bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
if (CCCode != ISD::SETO && CCCode != ISD::SETUO) {
// If we aren't the ordered or unorder operation,
// then the pattern is (LHS CC1 RHS) Opc (LHS CC2 RHS).
SetCC1 = DAG.getSetCC(dl, VT, LHS, RHS, CC1, Chain, IsSignaling);
SetCC2 = DAG.getSetCC(dl, VT, LHS, RHS, CC2, Chain, IsSignaling);
if (IsNonVP) {
SetCC1 = DAG.getSetCC(dl, VT, LHS, RHS, CC1, Chain, IsSignaling);
SetCC2 = DAG.getSetCC(dl, VT, LHS, RHS, CC2, Chain, IsSignaling);
} else {
SetCC1 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC1, Mask, EVL);
SetCC2 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC2, Mask, EVL);
}
} else {
// Otherwise, the pattern is (LHS CC1 LHS) Opc (RHS CC2 RHS)
SetCC1 = DAG.getSetCC(dl, VT, LHS, LHS, CC1, Chain, IsSignaling);
SetCC2 = DAG.getSetCC(dl, VT, RHS, RHS, CC2, Chain, IsSignaling);
if (IsNonVP) {
SetCC1 = DAG.getSetCC(dl, VT, LHS, LHS, CC1, Chain, IsSignaling);
SetCC2 = DAG.getSetCC(dl, VT, RHS, RHS, CC2, Chain, IsSignaling);
} else {
SetCC1 = DAG.getSetCCVP(dl, VT, LHS, LHS, CC1, Mask, EVL);
SetCC2 = DAG.getSetCCVP(dl, VT, RHS, RHS, CC2, Mask, EVL);
}
}
if (Chain)
Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, SetCC1.getValue(1),
SetCC2.getValue(1));
LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2);
if (IsNonVP)
LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2);
else {
// Transform the binary opcode to the VP equivalent.
assert((Opc == ISD::OR || Opc == ISD::AND) && "Unexpected opcode");
Opc = Opc == ISD::OR ? ISD::VP_OR : ISD::VP_AND;
LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2, Mask, EVL);
}
RHS = SDValue();
CC = SDValue();
return true;
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Expand Up @@ -577,6 +577,15 @@ multiclass VPatFPSetCCVL_VV_VF_FV<CondCode cc,
VLOpFrag)),
(!cast<Instruction>(inst_name#"_VV_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.RegClass:$rs2, GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (fvti.Vector fvti.RegClass:$rs1),
fvti.RegClass:$rs2,
cc,
(fvti.Mask V0),
VLOpFrag)),
(!cast<Instruction>(inst_name#"_VV_"#fvti.LMul.MX#"_MASK")
(fvti.Mask (IMPLICIT_DEF)), fvti.RegClass:$rs1,
fvti.RegClass:$rs2, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (fvti.Vector fvti.RegClass:$rs1),
(SplatFPOp fvti.ScalarRegClass:$rs2),
cc,
Expand All @@ -585,6 +594,15 @@ multiclass VPatFPSetCCVL_VV_VF_FV<CondCode cc,
(!cast<Instruction>(inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (fvti.Vector fvti.RegClass:$rs1),
(SplatFPOp fvti.ScalarRegClass:$rs2),
cc,
(fvti.Mask V0),
VLOpFrag)),
(!cast<Instruction>(inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_MASK")
(fvti.Mask (IMPLICIT_DEF)), fvti.RegClass:$rs1,
fvti.ScalarRegClass:$rs2, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (SplatFPOp fvti.ScalarRegClass:$rs2),
(fvti.Vector fvti.RegClass:$rs1),
cc,
Expand All @@ -593,6 +611,15 @@ multiclass VPatFPSetCCVL_VV_VF_FV<CondCode cc,
(!cast<Instruction>(swapped_op_inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fvti.Mask (riscv_setcc_vl (SplatFPOp fvti.ScalarRegClass:$rs2),
(fvti.Vector fvti.RegClass:$rs1),
cc,
(fvti.Mask V0),
VLOpFrag)),
(!cast<Instruction>(swapped_op_inst_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_MASK")
(fvti.Mask (IMPLICIT_DEF)), fvti.RegClass:$rs1,
fvti.ScalarRegClass:$rs2, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;
}
}

Expand Down

0 comments on commit 8216255

Please sign in to comment.