-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV][ISel] Combine scalable vector fadd/fsub/fmul with fp extend. #88615
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Chia (sun-jacobi) Changes: Extend D133739, #76785 and #81248 to support combining scalable vector fadd/fsub/fmul with fp extend. Specifically, this patch works for the below optimization case: Source code
Before this patch
After this patch
Full diff: https://github.com/llvm/llvm-project/pull/88615.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..b8b926a54ea908 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1430,6 +1430,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
ISD::INSERT_VECTOR_ELT, ISD::ABS});
+ if (Subtarget.hasVInstructionsAnyF())
+ setTargetDAGCombine({ISD::FADD, ISD::FSUB, ISD::FMUL});
if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
@@ -13597,6 +13599,13 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
+ case ISD::SPLAT_VECTOR: {
+ SDValue Op = OrigOperand.getOperand(0);
+ if (Op.getOpcode() == ISD::FP_EXTEND)
+ return Op;
+ return OrigOperand;
+ }
+
default:
return OrigOperand;
}
@@ -13735,12 +13744,15 @@ struct NodeExtensionHelper {
/// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
static unsigned getFPExtOpcode(unsigned Opcode) {
switch (Opcode) {
+ case ISD::FADD:
case RISCVISD::FADD_VL:
case RISCVISD::VFWADD_W_VL:
return RISCVISD::VFWADD_VL;
+ case ISD::FSUB:
case RISCVISD::FSUB_VL:
case RISCVISD::VFWSUB_W_VL:
return RISCVISD::VFWSUB_VL;
+ case ISD::FMUL:
case RISCVISD::FMUL_VL:
return RISCVISD::VFWMUL_VL;
default:
@@ -13769,8 +13781,10 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
: RISCVISD::VWSUBU_W_VL;
+ case ISD::FADD:
case RISCVISD::FADD_VL:
return RISCVISD::VFWADD_W_VL;
+ case ISD::FSUB:
case RISCVISD::FSUB_VL:
return RISCVISD::VFWSUB_W_VL;
default:
@@ -13824,6 +13838,10 @@ struct NodeExtensionHelper {
APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
SupportsZExt = true;
+ if (Op.getOpcode() == ISD::FP_EXTEND &&
+ NarrowSize >= (Subtarget.hasVInstructionsF16() ? 16 : 32))
+ SupportsFPExt = true;
+
EnforceOneUse = false;
}
@@ -13854,6 +13872,7 @@ struct NodeExtensionHelper {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
+ SupportsFPExt = Opc == ISD::FP_EXTEND;
break;
}
case RISCVISD::VZEXT_VL:
@@ -13862,9 +13881,18 @@ struct NodeExtensionHelper {
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
break;
- case RISCVISD::FP_EXTEND_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ SDValue NarrowElt = OrigOperand.getOperand(0);
+ MVT NarrowVT = NarrowElt.getSimpleValueType();
+
+ if (!Subtarget.hasVInstructionsF16() &&
+ NarrowVT.getVectorElementType() == MVT::f16)
+ break;
+
SupportsFPExt = true;
break;
+ }
+
case ISD::SPLAT_VECTOR:
case RISCVISD::VMV_V_X_VL:
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -13880,13 +13908,16 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
return Root->getValueType(0).isScalableVector();
- }
- case ISD::OR: {
+ case ISD::OR:
return Root->getValueType(0).isScalableVector() &&
Root->getFlags().hasDisjoint();
- }
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ return Root->getValueType(0).isScalableVector() &&
+ Subtarget.hasVInstructionsAnyF();
// Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
@@ -13963,7 +13994,10 @@ struct NodeExtensionHelper {
case ISD::SUB:
case ISD::MUL:
case ISD::OR:
- case ISD::SHL: {
+ case ISD::SHL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13980,6 +14014,8 @@ struct NodeExtensionHelper {
case ISD::ADD:
case ISD::MUL:
case ISD::OR:
+ case ISD::FADD:
+ case ISD::FMUL:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13989,6 +14025,7 @@ struct NodeExtensionHelper {
case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
+ case ISD::FSUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
@@ -14050,6 +14087,9 @@ struct CombineResult {
case ISD::MUL:
case ISD::OR:
case ISD::SHL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -14192,6 +14232,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::ADD:
case ISD::SUB:
case ISD::OR:
+ case ISD::FADD:
+ case ISD::FSUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
case RISCVISD::FADD_VL:
@@ -14201,6 +14243,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w}|vfwadd_w|vfwsub_w
Strategies.push_back(canFoldToVW_W);
break;
+ case ISD::FMUL:
case RISCVISD::FMUL_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
break;
@@ -14244,9 +14287,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
/// shl | shl_vl -> vwsll
-/// fadd_vl -> vfwadd | vfwadd_w
-/// fsub_vl -> vfwsub | vfwsub_w
-/// fmul_vl -> vfwmul
+/// fadd | fadd_vl -> vfwadd | vfwadd_w
+/// fsub | fsub_vl -> vfwsub | vfwsub_w
+/// fmul | fmul_vl -> vfwmul
/// vwadd_w(u) -> vwadd(u)
/// vwsub_w(u) -> vwsub(u)
/// vfwadd_w -> vfwadd
@@ -15921,7 +15964,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOpOfZExt(N, DAG))
return V;
break;
- case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ case ISD::FADD: {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
+ [[fallthrough]];
+ }
case ISD::UMAX:
case ISD::UMIN:
case ISD::SMAX:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..0d1713acfc0cd0
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFH
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vs1r.v v10, (a0)
+; NO_FOLDING-NEXT: vs1r.v v11, (a1)
+; NO_FOLDING-NEXT: vs1r.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; ZVFH-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFH: # %bb.0:
+; ZVFH-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFH-NEXT: vfwmul.vv v11, v8, v9
+; ZVFH-NEXT: vfwadd.vv v12, v8, v10
+; ZVFH-NEXT: vfwsub.vv v8, v9, v10
+; ZVFH-NEXT: vs1r.v v11, (a0)
+; ZVFH-NEXT: vs1r.v v12, (a1)
+; ZVFH-NEXT: vs1r.v v8, (a2)
+; ZVFH-NEXT: ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN: # %bb.0:
+; ZVFHMIN-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT: vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT: vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT: vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT: vs1r.v v10, (a0)
+; ZVFHMIN-NEXT: vs1r.v v11, (a1)
+; ZVFHMIN-NEXT: vs1r.v v8, (a2)
+; ZVFHMIN-NEXT: ret
+ %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
+ %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
+ %d2 = fpext <vscale x 2 x half> %b2 to <vscale x 2 x float>
+ %e = fmul <vscale x 2 x float> %c, %d
+ %f = fadd <vscale x 2 x float> %c, %d2
+ %g = fsub <vscale x 2 x float> %d, %d2
+ store <vscale x 2 x float> %e, ptr %x
+ store <vscale x 2 x float> %f, ptr %y
+ store <vscale x 2 x float> %g, ptr %z
+ ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetvli a3, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v14, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m2, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v12, v14
+; NO_FOLDING-NEXT: vfadd.vv v12, v12, v8
+; NO_FOLDING-NEXT: vfsub.vv v8, v14, v8
+; NO_FOLDING-NEXT: vs2r.v v10, (a0)
+; NO_FOLDING-NEXT: vs2r.v v12, (a1)
+; NO_FOLDING-NEXT: vs2r.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT: vfwmul.vv v12, v8, v9
+; FOLDING-NEXT: vfwadd.vv v14, v8, v10
+; FOLDING-NEXT: vfwsub.vv v16, v9, v10
+; FOLDING-NEXT: vs2r.v v12, (a0)
+; FOLDING-NEXT: vs2r.v v14, (a1)
+; FOLDING-NEXT: vs2r.v v16, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+ %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
+ %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
+ %e = fmul <vscale x 2 x double> %c, %d
+ %f = fadd <vscale x 2 x double> %c, %d2
+ %g = fsub <vscale x 2 x double> %d, %d2
+ store <vscale x 2 x double> %e, ptr %x
+ store <vscale x 2 x double> %f, ptr %y
+ store <vscale x 2 x double> %g, ptr %z
+ ret void
+}
|
case ISD::SPLAT_VECTOR: { | ||
SDValue Op = OrigOperand.getOperand(0); | ||
if (Op.getOpcode() == ISD::FP_EXTEND) | ||
return Op; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this work for fixed length float splats currently if we didn't already handle RISCVISD::VFMV_V_F_VL
? Since in #81248 it looks like we emit vfwadd.vf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to look through scalar FP_EXTEND here, but we don't need to look through ZERO_EXTEND or SIGN_EXTEND for integer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this work for fixed length float splats currently if we didn't already handle RISCVISD::VFMV_V_F_VL? Since in #81248 it looks like we emit vfwadd.vf
The extend on those tests are happening in the vector domain not the scalar domain I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this work for fixed length float splats currently if we didn't already handle RISCVISD::VFMV_V_F_VL? Since in #81248 it looks like we emit vfwadd.vf
The extend on those tests are happening in the vector domain not the scalar domain I think.
Yes, you are right. We also need to handle the extension in the scalar domain (i.e. similar to #87249) for RISCVISD::VFMV_V_F_VL
. Thanks for pointing out this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to look through scalar FP_EXTEND here, but we don't need to look through ZERO_EXTEND or SIGN_EXTEND for integer?
For integers, I think the DAG.getSplat
would do an implicit truncation for us. But it seems that this does not work for floats.
Extend D133739, #76785 and #81248 to support combining scalable vector fadd/fsub/fmul with fp extend.
Specifically, this patch works for the below optimization case:
Source code
Before this patch
Compiler Explorer
After this patch