Skip to content

Commit

Permalink
[RISCV] Add RISCVISD nodes for vfwadd/vfwsub.
Browse files Browse the repository at this point in the history
Add a DAG combine to form these from FADD_VL/FSUB_VL and FP_EXTEND_VL.

This makes it similar to other widening ops and allows us to handle
using the same FP_EXTEND_VL for both operands.

Differential Revision: https://reviews.llvm.org/D151969
  • Loading branch information
topperc committed Jun 5, 2023
1 parent c422478 commit 4157bfb
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 61 deletions.
59 changes: 59 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11473,6 +11473,58 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) {
Op1, Merge, Mask, VL);
}

static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG) {
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Merge = N->getOperand(2);
SDValue Mask = N->getOperand(3);
SDValue VL = N->getOperand(4);

bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;

// Look for foldable FP_EXTENDS.
bool Op0IsExtend =
Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
(Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
bool Op1IsExtend =
(Op0 == Op1 && Op0IsExtend) ||
(Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());

// Check the mask and VL.
if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
Op0IsExtend = false;
if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
Op1IsExtend = false;

// Canonicalize.
if (!Op1IsExtend) {
// Sub requires at least operand 1 to be an extend.
if (!IsAdd)
return SDValue();

// Add is commutable, if the other operand is foldable, swap them.
if (!Op0IsExtend)
return SDValue();

std::swap(Op0, Op1);
std::swap(Op0IsExtend, Op1IsExtend);
}

// Op1 is a foldable extend. Op0 might be foldable.
Op1 = Op1.getOperand(0);
if (Op0IsExtend)
Op0 = Op0.getOperand(0);

unsigned Opc;
if (IsAdd)
Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
else
Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;

return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
VL);
}

static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
Expand Down Expand Up @@ -12349,6 +12401,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performVFMADD_VLCombine(N, DAG);
case RISCVISD::FMUL_VL:
return performVFMUL_VLCombine(N, DAG);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
return performFADDSUB_VLCombine(N, DAG);
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
Expand Down Expand Up @@ -15460,6 +15515,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWSUB_W_VL)
NODE_NAME_CASE(VWSUBU_W_VL)
NODE_NAME_CASE(VFWMUL_VL)
NODE_NAME_CASE(VFWADD_VL)
NODE_NAME_CASE(VFWSUB_VL)
NODE_NAME_CASE(VFWADD_W_VL)
NODE_NAME_CASE(VFWSUB_W_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ enum NodeType : unsigned {
VWSUBU_W_VL,

VFWMUL_VL,
VFWADD_VL,
VFWSUB_VL,
VFWADD_W_VL,
VFWSUB_W_VL,

// Narrowing logical shift right.
// Operands are (source, shift, passthru, mask, vl)
Expand Down
91 changes: 30 additions & 61 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
SDTCVecEltisVT<4, i1>,
SDTCisVT<5, XLenVT>]>;
def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>;

def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisInt<1>,
Expand All @@ -426,6 +428,19 @@ def riscv_vwaddu_w_vl : SDNode<"RISCVISD::VWADDU_W_VL", SDT_RISCVVWIntBinOpW_VL>
def riscv_vwsub_w_vl : SDNode<"RISCVISD::VWSUB_W_VL", SDT_RISCVVWIntBinOpW_VL>;
def riscv_vwsubu_w_vl : SDNode<"RISCVISD::VWSUBU_W_VL", SDT_RISCVVWIntBinOpW_VL>;

def SDT_RISCVVWFPBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
SDTCisSameAs<0, 1>,
SDTCisFP<2>,
SDTCisSameNumEltsAs<1, 2>,
SDTCisOpSmallerThanOp<2, 1>,
SDTCisSameAs<0, 3>,
SDTCisSameNumEltsAs<1, 4>,
SDTCVecEltisVT<4, i1>,
SDTCisVT<5, XLenVT>]>;

def riscv_vfwadd_w_vl : SDNode<"RISCVISD::VFWADD_W_VL", SDT_RISCVVWFPBinOpW_VL>;
def riscv_vfwsub_w_vl : SDNode<"RISCVISD::VFWSUB_W_VL", SDT_RISCVVWFPBinOpW_VL>;

def SDTRVVVecReduce : SDTypeProfile<1, 6, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>,
SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<2, 4>, SDTCisVT<5, XLenVT>,
Expand Down Expand Up @@ -1375,70 +1390,24 @@ multiclass VPatBinaryFPWVL_VV_VF<SDNode vop, string instruction_name> {
}
}

multiclass VPatWidenBinaryFPVL_VV_VF<SDNode op, PatFrags extop, string instruction_name> {
foreach fvtiToFWti = AllWidenableFloatVectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
GetVTypePredicates<fwti>.Predicates) in {
def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2),
(fvti.Mask true_mask), VLOpFrag)),
(fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
(fvti.Mask true_mask), VLOpFrag)),
srcvalue, (fwti.Mask true_mask), VLOpFrag)),
(!cast<Instruction>(instruction_name#"_VV_"#fvti.LMul.MX)
fvti.RegClass:$rs2, fvti.RegClass:$rs1,
GPR:$vl, fvti.Log2SEW)>;
def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2),
(fvti.Mask true_mask), VLOpFrag)),
(fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
(fvti.Mask true_mask), VLOpFrag)),
srcvalue, (fwti.Mask true_mask), VLOpFrag)),
(!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fvti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
GPR:$vl, fvti.Log2SEW)>;
}
}
}

multiclass VPatWidenBinaryFPVL_WV_WF<SDNode op, PatFrags extop, string instruction_name> {
multiclass VPatBinaryFPWVL_VV_VF_WV_WF<SDNode vop, SDNode vop_w, string instruction_name>
: VPatBinaryFPWVL_VV_VF<vop, instruction_name> {
foreach fvtiToFWti = AllWidenableFloatVectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
GetVTypePredicates<fwti>.Predicates) in {
def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
(fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
(fvti.Mask true_mask), VLOpFrag)),
srcvalue, (fwti.Mask true_mask), VLOpFrag)),
(!cast<Instruction>(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED")
fwti.RegClass:$rs2, fvti.RegClass:$rs1,
GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>;
// Tail undisturbed
def : Pat<(riscv_vp_merge_vl true_mask,
(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
(fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
(fvti.Mask true_mask), VLOpFrag)),
srcvalue, (fwti.Mask true_mask), VLOpFrag)),
fwti.RegClass:$rs2, VLOpFrag),
(!cast<Instruction>(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED")
fwti.RegClass:$rs2, fvti.RegClass:$rs1,
GPR:$vl, fvti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
(fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
(fvti.Mask true_mask), VLOpFrag)),
srcvalue, (fwti.Mask true_mask), VLOpFrag)),
(!cast<Instruction>(instruction_name#"_W"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
fwti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
GPR:$vl, fvti.Log2SEW)>;
defvar vti = fvtiToFWti.Vti;
defvar wti = fvtiToFWti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatBinaryVL_VF<vop_w, instruction_name#"_W"#vti.ScalarSuffix,
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
vti.ScalarRegClass>;
}
}
}

multiclass VPatWidenBinaryFPVL_VV_VF_WV_WF<SDNode op, string instruction_name>
: VPatWidenBinaryFPVL_VV_VF<op, riscv_fpextend_vl_oneuse, instruction_name>,
VPatWidenBinaryFPVL_WV_WF<op, riscv_fpextend_vl_oneuse, instruction_name>;

multiclass VPatNarrowShiftSplatExt_WX<SDNode op, PatFrags extop, string instruction_name> {
foreach vtiToWti = AllWidenableIntVectors in {
defvar vti = vtiToWti.Vti;
Expand Down Expand Up @@ -1938,8 +1907,8 @@ defm : VPatBinaryFPVL_VV_VF<any_riscv_fsub_vl, "PseudoVFSUB">;
defm : VPatBinaryFPVL_R_VF<any_riscv_fsub_vl, "PseudoVFRSUB">;

// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
defm : VPatWidenBinaryFPVL_VV_VF_WV_WF<riscv_fadd_vl, "PseudoVFWADD">;
defm : VPatWidenBinaryFPVL_VV_VF_WV_WF<riscv_fsub_vl, "PseudoVFWSUB">;
defm : VPatBinaryFPWVL_VV_VF_WV_WF<riscv_vfwadd_vl, riscv_vfwadd_w_vl, "PseudoVFWADD">;
defm : VPatBinaryFPWVL_VV_VF_WV_WF<riscv_vfwsub_vl, riscv_vfwsub_w_vl, "PseudoVFWSUB">;

// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
defm : VPatBinaryFPVL_VV_VF<any_riscv_fmul_vl, "PseudoVFMUL">;
Expand Down
15 changes: 15 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+experimental-zvfh | FileCheck %s

define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
; CHECK-LABEL: vfwadd_same_operand:
; CHECK: # %bb.0: # %bb
; CHECK-NEXT: slli a0, a0, 32
; CHECK-NEXT: srli a0, a0, 32
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
; CHECK-NEXT: vfwadd.vv v9, v8, v8
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
bb:
%tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2f16(<vscale x 2 x half> %arg, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 %vl)
%tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 %vl)
ret <vscale x 2 x float> %tmp2
}

define <vscale x 2 x float> @vfwadd_tu(<vscale x 2 x half> %arg, <vscale x 2 x float> %arg1, i32 signext %arg2) {
; CHECK-LABEL: vfwadd_tu:
; CHECK: # %bb.0: # %bb
Expand Down

0 comments on commit 4157bfb

Please sign in to comment.