
[RISCV][ISel] Combine scalable vector fadd/fsub/fmul with fp extend. #88615

Open

wants to merge 1 commit into main

Conversation

@sun-jacobi sun-jacobi commented Apr 13, 2024

Extend D133739, #76785 and #81248 to support combining scalable vector fadd/fsub/fmul with fp extend.

Specifically, this patch handles the optimization case below:

Source code

define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
  %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
  %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
  %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
  %e = fmul <vscale x 2 x double> %c, %d
  %f = fadd <vscale x 2 x double> %c, %d2
  %g = fsub <vscale x 2 x double> %d, %d2
  store <vscale x 2 x double> %e, ptr %x
  store <vscale x 2 x double> %f, ptr %y
  store <vscale x 2 x double> %g, ptr %z
  ret void
}

Before this patch

vfwmul_v2f32_multiple_users: 
        vsetvli a3, zero, e32, m1, ta, ma
        vfwcvt.f.f.v    v12, v8
        vfwcvt.f.f.v    v14, v9
        vfwcvt.f.f.v    v8, v10
        vsetvli zero, zero, e64, m2, ta, ma
        vfmul.vv        v10, v12, v14
        vfadd.vv        v12, v12, v8
        vfsub.vv        v8, v14, v8
        vs2r.v  v10, (a0)
        vs2r.v  v12, (a1)
        vs2r.v  v8, (a2)
        ret

After this patch

vfwmul_v2f32_multiple_users:
        vsetvli a3, zero, e32, m1, ta, ma
        vfwmul.vv v12, v8, v9
        vfwadd.vv v14, v8, v10
        vfwsub.vv v16, v9, v10
        vs2r.v v12, (a0)
        vs2r.v v14, (a1)
        vs2r.v v16, (a2)
        ret

@llvmbot
Collaborator

llvmbot commented Apr 13, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

Full diff: https://github.com/llvm/llvm-project/pull/88615.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+60-10)
  • (added) llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll (+99)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..b8b926a54ea908 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1430,6 +1430,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
                          ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
                          ISD::INSERT_VECTOR_ELT, ISD::ABS});
+  if (Subtarget.hasVInstructionsAnyF())
+    setTargetDAGCombine({ISD::FADD, ISD::FSUB, ISD::FMUL});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -13597,6 +13599,13 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
+    case ISD::SPLAT_VECTOR: {
+      SDValue Op = OrigOperand.getOperand(0);
+      if (Op.getOpcode() == ISD::FP_EXTEND)
+        return Op;
+      return OrigOperand;
+    }
+
     default:
       return OrigOperand;
     }
@@ -13735,12 +13744,15 @@ struct NodeExtensionHelper {
   /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
   static unsigned getFPExtOpcode(unsigned Opcode) {
     switch (Opcode) {
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
     case RISCVISD::VFWADD_W_VL:
       return RISCVISD::VFWADD_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
     case RISCVISD::VFWSUB_W_VL:
       return RISCVISD::VFWSUB_VL;
+    case ISD::FMUL:
     case RISCVISD::FMUL_VL:
       return RISCVISD::VFWMUL_VL;
     default:
@@ -13769,8 +13781,10 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
                                           : RISCVISD::VWSUBU_W_VL;
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
       return RISCVISD::VFWADD_W_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
       return RISCVISD::VFWSUB_W_VL;
     default:
@@ -13824,6 +13838,10 @@ struct NodeExtensionHelper {
                               APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
       SupportsZExt = true;
 
+    if (Op.getOpcode() == ISD::FP_EXTEND &&
+        NarrowSize >= (Subtarget.hasVInstructionsF16() ? 16 : 32))
+      SupportsFPExt = true;
+
     EnforceOneUse = false;
   }
 
@@ -13854,6 +13872,7 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
+      SupportsFPExt = Opc == ISD::FP_EXTEND;
       break;
     }
     case RISCVISD::VZEXT_VL:
@@ -13862,9 +13881,18 @@ struct NodeExtensionHelper {
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
       break;
-    case RISCVISD::FP_EXTEND_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      SDValue NarrowElt = OrigOperand.getOperand(0);
+      MVT NarrowVT = NarrowElt.getSimpleValueType();
+
+      if (!Subtarget.hasVInstructionsF16() &&
+          NarrowVT.getVectorElementType() == MVT::f16)
+        break;
+
       SupportsFPExt = true;
       break;
+    }
+
     case ISD::SPLAT_VECTOR:
     case RISCVISD::VMV_V_X_VL:
       fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -13880,13 +13908,16 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
       return Root->getValueType(0).isScalableVector();
-    }
-    case ISD::OR: {
+    case ISD::OR:
       return Root->getValueType(0).isScalableVector() &&
              Root->getFlags().hasDisjoint();
-    }
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
+      return Root->getValueType(0).isScalableVector() &&
+             Subtarget.hasVInstructionsAnyF();
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13963,7 +13994,10 @@ struct NodeExtensionHelper {
     case ISD::SUB:
     case ISD::MUL:
     case ISD::OR:
-    case ISD::SHL: {
+    case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13980,6 +14014,8 @@ struct NodeExtensionHelper {
     case ISD::ADD:
     case ISD::MUL:
     case ISD::OR:
+    case ISD::FADD:
+    case ISD::FMUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13989,6 +14025,7 @@ struct NodeExtensionHelper {
     case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
+    case ISD::FSUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -14050,6 +14087,9 @@ struct CombineResult {
     case ISD::MUL:
     case ISD::OR:
     case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14192,6 +14232,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::ADD:
   case ISD::SUB:
   case ISD::OR:
+  case ISD::FADD:
+  case ISD::FSUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14201,6 +14243,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w}|vfwadd_w|vfwsub_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::FMUL:
   case RISCVISD::FMUL_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     break;
@@ -14244,9 +14287,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// shl | shl_vl -> vwsll
-/// fadd_vl ->  vfwadd | vfwadd_w
-/// fsub_vl ->  vfwsub | vfwsub_w
-/// fmul_vl ->  vfwmul
+/// fadd | fadd_vl ->  vfwadd | vfwadd_w
+/// fsub | fsub_vl ->  vfwsub | vfwsub_w
+/// fmul | fmul_vl ->  vfwmul
 /// vwadd_w(u) -> vwadd(u)
 /// vwsub_w(u) -> vwsub(u)
 /// vfwadd_w -> vfwadd
@@ -15921,7 +15964,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOpOfZExt(N, DAG))
       return V;
     break;
-  case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  case ISD::FADD: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
+    [[fallthrough]];
+  }
   case ISD::UMAX:
   case ISD::UMIN:
   case ISD::SMAX:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..0d1713acfc0cd0
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFH
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vs1r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs1r.v v11, (a1)
+; NO_FOLDING-NEXT:    vs1r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; ZVFH-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFH-NEXT:    vfwmul.vv v11, v8, v9
+; ZVFH-NEXT:    vfwadd.vv v12, v8, v10
+; ZVFH-NEXT:    vfwsub.vv v8, v9, v10
+; ZVFH-NEXT:    vs1r.v v11, (a0)
+; ZVFH-NEXT:    vs1r.v v12, (a1)
+; ZVFH-NEXT:    vs1r.v v8, (a2)
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT:    vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT:    vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT:    vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT:    vs1r.v v10, (a0)
+; ZVFHMIN-NEXT:    vs1r.v v11, (a1)
+; ZVFHMIN-NEXT:    vs1r.v v8, (a2)
+; ZVFHMIN-NEXT:    ret
+  %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
+  %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
+  %d2 = fpext <vscale x 2 x half> %b2 to <vscale x 2 x float>
+  %e = fmul <vscale x 2 x float> %c, %d
+  %f = fadd <vscale x 2 x float> %c, %d2
+  %g = fsub <vscale x 2 x float> %d, %d2
+  store <vscale x 2 x float> %e, ptr %x
+  store <vscale x 2 x float> %f, ptr %y
+  store <vscale x 2 x float> %g, ptr %z
+  ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v14, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v12, v14
+; NO_FOLDING-NEXT:    vfadd.vv v12, v12, v8
+; NO_FOLDING-NEXT:    vfsub.vv v8, v14, v8
+; NO_FOLDING-NEXT:    vs2r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs2r.v v12, (a1)
+; NO_FOLDING-NEXT:    vs2r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v12, v8, v9
+; FOLDING-NEXT:    vfwadd.vv v14, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v16, v9, v10
+; FOLDING-NEXT:    vs2r.v v12, (a0)
+; FOLDING-NEXT:    vs2r.v v14, (a1)
+; FOLDING-NEXT:    vs2r.v v16, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
+  %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
+  %e = fmul <vscale x 2 x double> %c, %d
+  %f = fadd <vscale x 2 x double> %c, %d2
+  %g = fsub <vscale x 2 x double> %d, %d2
+  store <vscale x 2 x double> %e, ptr %x
+  store <vscale x 2 x double> %f, ptr %y
+  store <vscale x 2 x double> %g, ptr %z
+  ret void
+}

case ISD::SPLAT_VECTOR: {
  SDValue Op = OrigOperand.getOperand(0);
  if (Op.getOpcode() == ISD::FP_EXTEND)
    return Op;
Contributor

How does this currently work for fixed-length float splats if we didn't already handle RISCVISD::VFMV_V_F_VL? In #81248 it looks like we emit vfwadd.vf.

Collaborator

Why do we need to look through scalar FP_EXTEND here, but we don't need to look through ZERO_EXTEND or SIGN_EXTEND for integers?

Collaborator

How does this work for fixed length float splats currently if we didn't already handle RISCVISD::VFMV_V_F_VL? Since in #81248 it looks like we emit vfwadd.vf

The extends in those tests are happening in the vector domain, not the scalar domain, I think.

Member Author

How does this work for fixed length float splats currently if we didn't already handle RISCVISD::VFMV_V_F_VL? Since in #81248 it looks like we emit vfwadd.vf

The extend on those tests are happening in the vector domain not the scalar domain I think.

Yes, you are right. We also need to handle the extension in the scalar domain (i.e. similar to #87249) for RISCVISD::VFMV_V_F_VL. Thanks for pointing this out.
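
For illustration, a hypothetical fixed-length case (not one of this patch's tests) where the extension happens in the scalar domain before the splat is formed:

define <4 x double> @splat_fpext_scalar(<4 x double> %v, float %s) {
  %w = fpext float %s to double
  %head = insertelement <4 x double> poison, double %w, i64 0
  %splat = shufflevector <4 x double> %head, <4 x double> poison, <4 x i32> zeroinitializer
  %r = fadd <4 x double> %v, %splat
  ret <4 x double> %r
}

Here the splat of the already-extended scalar is lowered to RISCVISD::VFMV_V_F_VL, which this patch does not yet look through, so the fadd presumably stays a full-width vfadd.vv instead of becoming a widening op.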

Member Author

Why do we need to look through scalar FP_EXTEND here, but we don't need to look through ZERO_EXTEND or SIGN_EXTEND for integers?

For integers, I think DAG.getSplat does an implicit truncation for us, but it seems that this does not happen for floats.
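
A hypothetical scalable-vector sketch of the integer side of that contrast (illustrative only, not taken from this patch):

; The high 32 bits of %zext are known to be zero, so the existing known-bits
; check in fillUpExtensionSupportForSplat (the APInt::getBitsSetFrom(ScalarBits,
; NarrowSize) check visible in the hunk above) can already mark the splat as
; SupportsZExt without looking through the zext. An fpext'd float scalar carries
; no equivalent known-bits information, so the scalar FP_EXTEND has to be peeled
; off explicitly.
define <vscale x 2 x i64> @int_splat_add(<vscale x 2 x i64> %v, i32 %x) {
  %zext = zext i32 %x to i64
  %head = insertelement <vscale x 2 x i64> poison, i64 %zext, i64 0
  %splat = shufflevector <vscale x 2 x i64> %head, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
  %r = add <vscale x 2 x i64> %v, %splat
  ret <vscale x 2 x i64> %r
}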

@sun-jacobi sun-jacobi removed the request for review from luke957 April 17, 2024 15:34