diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff..b8b926a54ea90 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1430,6 +1430,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
                          ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
                          ISD::INSERT_VECTOR_ELT, ISD::ABS});
+  if (Subtarget.hasVInstructionsAnyF())
+    setTargetDAGCombine({ISD::FADD, ISD::FSUB, ISD::FMUL});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -13597,6 +13599,13 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
+    case ISD::SPLAT_VECTOR: {
+      SDValue Op = OrigOperand.getOperand(0);
+      if (Op.getOpcode() == ISD::FP_EXTEND)
+        return Op;
+      return OrigOperand;
+    }
+
     default:
       return OrigOperand;
     }
@@ -13735,12 +13744,15 @@ struct NodeExtensionHelper {
   /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
   static unsigned getFPExtOpcode(unsigned Opcode) {
     switch (Opcode) {
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
     case RISCVISD::VFWADD_W_VL:
       return RISCVISD::VFWADD_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
     case RISCVISD::VFWSUB_W_VL:
       return RISCVISD::VFWSUB_VL;
+    case ISD::FMUL:
     case RISCVISD::FMUL_VL:
       return RISCVISD::VFWMUL_VL;
     default:
@@ -13769,8 +13781,10 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
                                           : RISCVISD::VWSUBU_W_VL;
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
       return RISCVISD::VFWADD_W_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
       return RISCVISD::VFWSUB_W_VL;
     default:
@@ -13824,6 +13838,10 @@ struct NodeExtensionHelper {
                               APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
       SupportsZExt = true;
 
+    if (Op.getOpcode() == ISD::FP_EXTEND &&
+        NarrowSize >= (Subtarget.hasVInstructionsF16() ? 16 : 32))
+      SupportsFPExt = true;
+
     EnforceOneUse = false;
   }
 
@@ -13854,6 +13872,7 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
+      SupportsFPExt = Opc == ISD::FP_EXTEND;
       break;
     }
     case RISCVISD::VZEXT_VL:
@@ -13862,9 +13881,18 @@ struct NodeExtensionHelper {
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
       break;
-    case RISCVISD::FP_EXTEND_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      SDValue NarrowElt = OrigOperand.getOperand(0);
+      MVT NarrowVT = NarrowElt.getSimpleValueType();
+
+      if (!Subtarget.hasVInstructionsF16() &&
+          NarrowVT.getVectorElementType() == MVT::f16)
+        break;
+
       SupportsFPExt = true;
       break;
+    }
+    case ISD::SPLAT_VECTOR:
     case RISCVISD::VMV_V_X_VL:
       fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -13880,13 +13908,16 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
       return Root->getValueType(0).isScalableVector();
-    }
-    case ISD::OR: {
+    case ISD::OR:
       return Root->getValueType(0).isScalableVector() &&
              Root->getFlags().hasDisjoint();
-    }
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
+      return Root->getValueType(0).isScalableVector() &&
+             Subtarget.hasVInstructionsAnyF();
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13963,7 +13994,10 @@ struct NodeExtensionHelper {
     case ISD::SUB:
     case ISD::MUL:
     case ISD::OR:
-    case ISD::SHL: {
+    case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13980,6 +14014,8 @@ struct NodeExtensionHelper {
     case ISD::ADD:
     case ISD::MUL:
     case ISD::OR:
+    case ISD::FADD:
+    case ISD::FMUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13989,6 +14025,7 @@ struct NodeExtensionHelper {
     case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
+    case ISD::FSUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -14050,6 +14087,9 @@ struct CombineResult {
     case ISD::MUL:
     case ISD::OR:
     case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14192,6 +14232,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::ADD:
   case ISD::SUB:
   case ISD::OR:
+  case ISD::FADD:
+  case ISD::FSUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14201,6 +14243,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w}|vfwadd_w|vfwsub_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::FMUL:
   case RISCVISD::FMUL_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     break;
@@ -14244,9 +14287,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// shl | shl_vl -> vwsll
-/// fadd_vl ->  vfwadd | vfwadd_w
-/// fsub_vl -> vfwsub | vfwsub_w
-/// fmul_vl -> vfwmul
+/// fadd | fadd_vl -> vfwadd | vfwadd_w
+/// fsub | fsub_vl -> vfwsub | vfwsub_w
+/// fmul | fmul_vl -> vfwmul
 /// vwadd_w(u) -> vwadd(u)
 /// vwsub_w(u) -> vwsub(u)
 /// vfwadd_w -> vfwadd
@@ -15921,7 +15964,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOpOfZExt(N, DAG))
       return V;
     break;
-  case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  case ISD::FADD: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
+    [[fallthrough]];
+  }
   case ISD::UMAX:
   case ISD::UMIN:
   case ISD::SMAX:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
new file mode 100644
index 0000000000000..0d1713acfc0cd
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFH
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vs1r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs1r.v v11, (a1)
+; NO_FOLDING-NEXT:    vs1r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; ZVFH-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFH-NEXT:    vfwmul.vv v11, v8, v9
+; ZVFH-NEXT:    vfwadd.vv v12, v8, v10
+; ZVFH-NEXT:    vfwsub.vv v8, v9, v10
+; ZVFH-NEXT:    vs1r.v v11, (a0)
+; ZVFH-NEXT:    vs1r.v v12, (a1)
+; ZVFH-NEXT:    vs1r.v v8, (a2)
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT:    vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT:    vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT:    vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT:    vs1r.v v10, (a0)
+; ZVFHMIN-NEXT:    vs1r.v v11, (a1)
+; ZVFHMIN-NEXT:    vs1r.v v8, (a2)
+; ZVFHMIN-NEXT:    ret
+  %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
+  %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
+  %d2 = fpext <vscale x 2 x half> %b2 to <vscale x 2 x float>
+  %e = fmul <vscale x 2 x float> %c, %d
+  %f = fadd <vscale x 2 x float> %c, %d2
+  %g = fsub <vscale x 2 x float> %d, %d2
+  store <vscale x 2 x float> %e, ptr %x
+  store <vscale x 2 x float> %f, ptr %y
+  store <vscale x 2 x float> %g, ptr %z
+  ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v14, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v12, v14
+; NO_FOLDING-NEXT:    vfadd.vv v12, v12, v8
+; NO_FOLDING-NEXT:    vfsub.vv v8, v14, v8
+; NO_FOLDING-NEXT:    vs2r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs2r.v v12, (a1)
+; NO_FOLDING-NEXT:    vs2r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v12, v8, v9
+; FOLDING-NEXT:    vfwadd.vv v14, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v16, v9, v10
+; FOLDING-NEXT:    vs2r.v v12, (a0)
+; FOLDING-NEXT:    vs2r.v v14, (a1)
+; FOLDING-NEXT:    vs2r.v v16, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
+  %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
+  %e = fmul <vscale x 2 x double> %c, %d
+  %f = fadd <vscale x 2 x double> %c, %d2
+  %g = fsub <vscale x 2 x double> %d, %d2
+  store <vscale x 2 x double> %e, ptr %x
+  store <vscale x 2 x double> %f, ptr %y
+  store <vscale x 2 x double> %g, ptr %z
+  ret void
+}
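
Note (not part of the patch): a minimal single-use sketch of the pattern the new combine targets. The function name and expected codegen below are illustrative assumptions only; with +v,+zvfh the fpext pair feeding the fadd is expected to fold into a single vfwadd.vv on the e16 source type, but the exact output should be regenerated with utils/update_llc_test_checks.py rather than taken from here.

; Hypothetical reduced example (assumes +v,+zvfh): both fadd operands are
; fpext'd from f16, so combineBinOp_VLToVWBinOp_VL should be able to form
; the widening op (RISCVISD::VFWADD_VL) instead of extending first.
define <vscale x 2 x float> @vfwadd_single_use(<vscale x 2 x half> %a, <vscale x 2 x half> %b) {
  %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
  %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
  %e = fadd <vscale x 2 x float> %c, %d
  ret <vscale x 2 x float> %e
}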