diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d33c48397536a..e8b6560036f08 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -11473,6 +11473,58 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) { Op1, Merge, Mask, VL); } +static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG) { + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Merge = N->getOperand(2); + SDValue Mask = N->getOperand(3); + SDValue VL = N->getOperand(4); + + bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL; + + // Look for foldable FP_EXTENDS. + bool Op0IsExtend = + Op0.getOpcode() == RISCVISD::FP_EXTEND_VL && + (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0))); + bool Op1IsExtend = + (Op0 == Op1 && Op0IsExtend) || + (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse()); + + // Check the mask and VL. + if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)) + Op0IsExtend = false; + if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)) + Op1IsExtend = false; + + // Canonicalize. + if (!Op1IsExtend) { + // Sub requires at least operand 1 to be an extend. + if (!IsAdd) + return SDValue(); + + // Add is commutable, if the other operand is foldable, swap them. + if (!Op0IsExtend) + return SDValue(); + + std::swap(Op0, Op1); + std::swap(Op0IsExtend, Op1IsExtend); + } + + // Op1 is a foldable extend. Op0 might be foldable. + Op1 = Op1.getOperand(0); + if (Op0IsExtend) + Op0 = Op0.getOperand(0); + + unsigned Opc; + if (IsAdd) + Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL; + else + Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL; + + return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask, + VL); +} + static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { assert(N->getOpcode() == ISD::SRA && "Unexpected opcode"); @@ -12349,6 +12401,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return performVFMADD_VLCombine(N, DAG); case RISCVISD::FMUL_VL: return performVFMUL_VLCombine(N, DAG); + case RISCVISD::FADD_VL: + case RISCVISD::FSUB_VL: + return performFADDSUB_VLCombine(N, DAG); case ISD::LOAD: case ISD::STORE: { if (DCI.isAfterLegalizeDAG()) @@ -15460,6 +15515,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VWSUB_W_VL) NODE_NAME_CASE(VWSUBU_W_VL) NODE_NAME_CASE(VFWMUL_VL) + NODE_NAME_CASE(VFWADD_VL) + NODE_NAME_CASE(VFWSUB_VL) + NODE_NAME_CASE(VFWADD_W_VL) + NODE_NAME_CASE(VFWSUB_W_VL) NODE_NAME_CASE(VNSRL_VL) NODE_NAME_CASE(SETCC_VL) NODE_NAME_CASE(VSELECT_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index fb67ed5445068..69d5dffa15d98 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -285,6 +285,10 @@ enum NodeType : unsigned { VWSUBU_W_VL, VFWMUL_VL, + VFWADD_VL, + VFWSUB_VL, + VFWADD_W_VL, + VFWSUB_W_VL, // Narrowing logical shift right. // Operands are (source, shift, passthru, mask, vl) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 056c5ce61bbd7..71df6e4a6fce2 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -400,6 +400,8 @@ def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>, SDTCVecEltisVT<4, i1>, SDTCisVT<5, XLenVT>]>; def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>; +def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>; +def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>; def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>, SDTCisInt<1>, @@ -426,6 +428,19 @@ def riscv_vwaddu_w_vl : SDNode<"RISCVISD::VWADDU_W_VL", SDT_RISCVVWIntBinOpW_VL> def riscv_vwsub_w_vl : SDNode<"RISCVISD::VWSUB_W_VL", SDT_RISCVVWIntBinOpW_VL>; def riscv_vwsubu_w_vl : SDNode<"RISCVISD::VWSUBU_W_VL", SDT_RISCVVWIntBinOpW_VL>; +def SDT_RISCVVWFPBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>, + SDTCisSameAs<0, 1>, + SDTCisFP<2>, + SDTCisSameNumEltsAs<1, 2>, + SDTCisOpSmallerThanOp<2, 1>, + SDTCisSameAs<0, 3>, + SDTCisSameNumEltsAs<1, 4>, + SDTCVecEltisVT<4, i1>, + SDTCisVT<5, XLenVT>]>; + +def riscv_vfwadd_w_vl : SDNode<"RISCVISD::VFWADD_W_VL", SDT_RISCVVWFPBinOpW_VL>; +def riscv_vfwsub_w_vl : SDNode<"RISCVISD::VFWSUB_W_VL", SDT_RISCVVWFPBinOpW_VL>; + def SDTRVVVecReduce : SDTypeProfile<1, 6, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>, SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<2, 4>, SDTCisVT<5, XLenVT>, @@ -1375,70 +1390,24 @@ multiclass VPatBinaryFPWVL_VV_VF { } } -multiclass VPatWidenBinaryFPVL_VV_VF { - foreach fvtiToFWti = AllWidenableFloatVectors in { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - let Predicates = !listconcat(GetVTypePredicates.Predicates, - GetVTypePredicates.Predicates) in { - def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2), - (fvti.Mask true_mask), VLOpFrag)), - (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1), - (fvti.Mask true_mask), VLOpFrag)), - srcvalue, (fwti.Mask true_mask), VLOpFrag)), - (!cast(instruction_name#"_VV_"#fvti.LMul.MX) - fvti.RegClass:$rs2, fvti.RegClass:$rs1, - GPR:$vl, fvti.Log2SEW)>; - def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2), - (fvti.Mask true_mask), VLOpFrag)), - (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)), - (fvti.Mask true_mask), VLOpFrag)), - srcvalue, (fwti.Mask true_mask), VLOpFrag)), - (!cast(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX) - fvti.RegClass:$rs2, fvti.ScalarRegClass:$rs1, - GPR:$vl, fvti.Log2SEW)>; - } - } -} - -multiclass VPatWidenBinaryFPVL_WV_WF { +multiclass VPatBinaryFPWVL_VV_VF_WV_WF + : VPatBinaryFPWVL_VV_VF { foreach fvtiToFWti = AllWidenableFloatVectors in { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - let Predicates = !listconcat(GetVTypePredicates.Predicates, - GetVTypePredicates.Predicates) in { - def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2), - (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1), - (fvti.Mask true_mask), VLOpFrag)), - srcvalue, (fwti.Mask true_mask), VLOpFrag)), - (!cast(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED") - fwti.RegClass:$rs2, fvti.RegClass:$rs1, - GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>; - // Tail undisturbed - def : Pat<(riscv_vp_merge_vl true_mask, - (fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2), - (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1), - (fvti.Mask true_mask), VLOpFrag)), - srcvalue, (fwti.Mask true_mask), VLOpFrag)), - fwti.RegClass:$rs2, VLOpFrag), - (!cast(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED") - fwti.RegClass:$rs2, fvti.RegClass:$rs1, - GPR:$vl, fvti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>; - def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2), - (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)), - (fvti.Mask true_mask), VLOpFrag)), - srcvalue, (fwti.Mask true_mask), VLOpFrag)), - (!cast(instruction_name#"_W"#fvti.ScalarSuffix#"_"#fvti.LMul.MX) - fwti.RegClass:$rs2, fvti.ScalarRegClass:$rs1, - GPR:$vl, fvti.Log2SEW)>; + defvar vti = fvtiToFWti.Vti; + defvar wti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypePredicates.Predicates, + GetVTypePredicates.Predicates) in { + defm : VPatTiedBinaryNoMaskVL_V; + def : VPatBinaryVL_VF; } } } -multiclass VPatWidenBinaryFPVL_VV_VF_WV_WF - : VPatWidenBinaryFPVL_VV_VF, - VPatWidenBinaryFPVL_WV_WF; - multiclass VPatNarrowShiftSplatExt_WX { foreach vtiToWti = AllWidenableIntVectors in { defvar vti = vtiToWti.Vti; @@ -1938,8 +1907,8 @@ defm : VPatBinaryFPVL_VV_VF; defm : VPatBinaryFPVL_R_VF; // 13.3. Vector Widening Floating-Point Add/Subtract Instructions -defm : VPatWidenBinaryFPVL_VV_VF_WV_WF; -defm : VPatWidenBinaryFPVL_VV_VF_WV_WF; +defm : VPatBinaryFPWVL_VV_VF_WV_WF; +defm : VPatBinaryFPWVL_VV_VF_WV_WF; // 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions defm : VPatBinaryFPVL_VV_VF; diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll index 1c2ba683cd876..661d8cc5a468d 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll @@ -1,6 +1,21 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+experimental-zvfh | FileCheck %s +define @vfwadd_same_operand( %arg, i32 signext %vl) { +; CHECK-LABEL: vfwadd_same_operand: +; CHECK: # %bb.0: # %bb +; CHECK-NEXT: slli a0, a0, 32 +; CHECK-NEXT: srli a0, a0, 32 +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vfwadd.vv v9, v8, v8 +; CHECK-NEXT: vmv1r.v v8, v9 +; CHECK-NEXT: ret +bb: + %tmp = call @llvm.vp.fpext.nxv2f32.nxv2f16( %arg, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer), i32 %vl) + %tmp2 = call @llvm.vp.fadd.nxv2f32( %tmp, %tmp, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer), i32 %vl) + ret %tmp2 +} + define @vfwadd_tu( %arg, %arg1, i32 signext %arg2) { ; CHECK-LABEL: vfwadd_tu: ; CHECK: # %bb.0: # %bb