Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Handle scalable ops with < EEW / 2 narrow types in combineBinOp_VLToVWBinOp_VL #84158

Merged

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Mar 6, 2024

We can remove the restriction that the narrow type needs to be exactly EEW / 2 for scalable ISD::{ADD,SUB,MUL} nodes. This allows us to perform the combine even if we can't fully fold the extend into the widening op.

VP intrinsics already do this, since they are lowered to _VL nodes which don't have this restriction.

The "exactly EEW / 2" narrow type restriction prevented us from emitting V{S,Z}EXT_VL nodes with i1 element types which crash when we try to select them, since no other legal type is double the size of i1, see the test case added in this PR i1_zext. So to preserve this, this adds a check for i1 narrow types instead.

Stacked on #84125

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

We can remove the restriction that the narrow type needs to be exactly EEW / 2 for scalable ISD::{ADD,SUB,MUL} nodes. This allows us to perform the combine even if we can't fully fold the extend into the widening op.

VP intrinsics already do this, since they are lowered to _VL nodes which don't have this restriction.

The "exactly EEW / 2" narrow type restriction prevented us from emitting V{S,Z}EXT_VL nodes with i1 element types which crash when we try to select them, since no other legal type is double the size of i1.

So to preserve this, this also restricts the combine to only run after the legalize vector ops phase, at which point all unselectable i1 vectors should be custom lowered away.

Stacked on #84125


Patch is 118.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84158.diff

5 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+12-24)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+20-18)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+238-224)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vwmul-sdnode.ll (+192-192)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vwsub-sdnode.ll (+160-160)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4c3dc63afd878d..f9bfaf01b235db 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13646,20 +13646,6 @@ struct NodeExtensionHelper {
       if (!VT.isVector())
         break;
 
-      SDValue NarrowElt = OrigOperand.getOperand(0);
-      MVT NarrowVT = NarrowElt.getSimpleValueType();
-
-      unsigned ScalarBits = VT.getScalarSizeInBits();
-      unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();
-
-      // Ensure the narrowing element type is legal
-      if (!Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType()))
-        break;
-
-      // Ensure the extension's semantic is equivalent to rvv vzext or vsext.
-      if (ScalarBits != NarrowScalarBits * 2)
-        break;
-
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
 
@@ -13727,14 +13713,11 @@ struct NodeExtensionHelper {
   }
 
   /// Check if \p Root supports any extension folding combines.
-  static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  static bool isSupportedRoot(const SDNode *Root) {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL: {
-      if (!TLI.isTypeLegal(Root->getValueType(0)))
-        return false;
       return Root->getValueType(0).isScalableVector();
     }
     // Vector Widening Integer Add/Sub/Mul Instructions
@@ -13751,7 +13734,7 @@ struct NodeExtensionHelper {
     case RISCVISD::FMUL_VL:
     case RISCVISD::VFWADD_W_VL:
     case RISCVISD::VFWSUB_W_VL:
-      return TLI.isTypeLegal(Root->getValueType(0));
+      return true;
     default:
       return false;
     }
@@ -13760,9 +13743,10 @@ struct NodeExtensionHelper {
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
   NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
                       const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
-                                         "unsupported root");
+    assert(isSupportedRoot(Root) && "Trying to build an helper with an "
+                                    "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
+    assert(DAG.getTargetLoweringInfo().isTypeLegal(Root->getValueType(0)));
     OrigOperand = Root->getOperand(OperandIdx);
 
     unsigned Opc = Root->getOpcode();
@@ -13812,7 +13796,7 @@ struct NodeExtensionHelper {
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
                const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root, DAG) && "Unexpected root");
+    assert(isSupportedRoot(Root) && "Unexpected root");
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
@@ -14112,8 +14096,12 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
+  // Don't perform this until types are legalized and any legal i1 types are
+  // custom lowered to avoid introducing unselectable V{S,Z}EXT_VLs.
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
 
-  if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
+  if (!NodeExtensionHelper::isSupportedRoot(N))
     return SDValue();
 
   SmallVector<SDNode *> Worklist;
@@ -14124,7 +14112,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
       return SDValue();
 
     NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index 972fa66917a568..e56dca0732bb4c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -283,18 +283,19 @@ define <vscale x 2 x i32> @vwop_vscale_sext_i8i32_multiple_users(ptr %x, ptr %y,
 ;
 ; FOLDING-LABEL: vwop_vscale_sext_i8i32_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
 ; FOLDING-NEXT:    vle8.v v8, (a0)
 ; FOLDING-NEXT:    vle8.v v9, (a1)
 ; FOLDING-NEXT:    vle8.v v10, (a2)
-; FOLDING-NEXT:    vsext.vf4 v11, v8
-; FOLDING-NEXT:    vsext.vf4 v8, v9
-; FOLDING-NEXT:    vsext.vf4 v9, v10
-; FOLDING-NEXT:    vmul.vv v8, v11, v8
-; FOLDING-NEXT:    vadd.vv v10, v11, v9
-; FOLDING-NEXT:    vsub.vv v9, v11, v9
-; FOLDING-NEXT:    vor.vv v8, v8, v10
-; FOLDING-NEXT:    vor.vv v8, v8, v9
+; FOLDING-NEXT:    vsext.vf2 v11, v8
+; FOLDING-NEXT:    vsext.vf2 v8, v9
+; FOLDING-NEXT:    vsext.vf2 v9, v10
+; FOLDING-NEXT:    vwmul.vv v10, v11, v8
+; FOLDING-NEXT:    vwadd.vv v8, v11, v9
+; FOLDING-NEXT:    vwsub.vv v12, v11, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vor.vv v8, v10, v8
+; FOLDING-NEXT:    vor.vv v8, v8, v12
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i8>, ptr %x
   %b = load <vscale x 2 x i8>, ptr %y
@@ -563,18 +564,19 @@ define <vscale x 2 x i32> @vwop_vscale_zext_i8i32_multiple_users(ptr %x, ptr %y,
 ;
 ; FOLDING-LABEL: vwop_vscale_zext_i8i32_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
 ; FOLDING-NEXT:    vle8.v v8, (a0)
 ; FOLDING-NEXT:    vle8.v v9, (a1)
 ; FOLDING-NEXT:    vle8.v v10, (a2)
-; FOLDING-NEXT:    vzext.vf4 v11, v8
-; FOLDING-NEXT:    vzext.vf4 v8, v9
-; FOLDING-NEXT:    vzext.vf4 v9, v10
-; FOLDING-NEXT:    vmul.vv v8, v11, v8
-; FOLDING-NEXT:    vadd.vv v10, v11, v9
-; FOLDING-NEXT:    vsub.vv v9, v11, v9
-; FOLDING-NEXT:    vor.vv v8, v8, v10
-; FOLDING-NEXT:    vor.vv v8, v8, v9
+; FOLDING-NEXT:    vzext.vf2 v11, v8
+; FOLDING-NEXT:    vzext.vf2 v8, v9
+; FOLDING-NEXT:    vzext.vf2 v9, v10
+; FOLDING-NEXT:    vwmulu.vv v10, v11, v8
+; FOLDING-NEXT:    vwaddu.vv v8, v11, v9
+; FOLDING-NEXT:    vwsubu.vv v12, v11, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vor.vv v8, v10, v8
+; FOLDING-NEXT:    vor.vv v8, v8, v12
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i8>, ptr %x
   %b = load <vscale x 2 x i8>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index a559fbf2bc8a7a..4152e61c0541ae 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -421,10 +421,10 @@ define <vscale x 8 x i64> @vwaddu_wx_nxv8i64_nxv8i32(<vscale x 8 x i64> %va, i32
 define <vscale x 1 x i64> @vwadd_vv_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
 ; CHECK-LABEL: vwadd_vv_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v8
-; CHECK-NEXT:    vsext.vf4 v8, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8
+; CHECK-NEXT:    vsext.vf2 v11, v9
+; CHECK-NEXT:    vwadd.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %vc = sext <vscale x 1 x i16> %va to <vscale x 1 x i64>
   %vd = sext <vscale x 1 x i16> %vb to <vscale x 1 x i64>
@@ -435,10 +435,10 @@ define <vscale x 1 x i64> @vwadd_vv_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, <vsc
 define <vscale x 1 x i64> @vwaddu_vv_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v8, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i16> %va to <vscale x 1 x i64>
   %vd = zext <vscale x 1 x i16> %vb to <vscale x 1 x i64>
@@ -451,10 +451,10 @@ define <vscale x 1 x i64> @vwadd_vx_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v8
-; CHECK-NEXT:    vsext.vf4 v8, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8
+; CHECK-NEXT:    vsext.vf2 v11, v9
+; CHECK-NEXT:    vwadd.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 1 x i16> %head, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -469,10 +469,10 @@ define <vscale x 1 x i64> @vwaddu_vx_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v8, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 1 x i16> %head, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -485,9 +485,9 @@ define <vscale x 1 x i64> @vwaddu_vx_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, i16
 define <vscale x 1 x i64> @vwadd_wv_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, <vscale x 1 x i16> %vb) {
 ; CHECK-LABEL: vwadd_wv_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v9
-; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v9
+; CHECK-NEXT:    vwadd.wv v8, v8, v10
 ; CHECK-NEXT:    ret
   %vc = sext <vscale x 1 x i16> %vb to <vscale x 1 x i64>
   %vd = add <vscale x 1 x i64> %va, %vc
@@ -497,9 +497,9 @@ define <vscale x 1 x i64> @vwadd_wv_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, <vsc
 define <vscale x 1 x i64> @vwaddu_wv_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, <vscale x 1 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_wv_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v9
-; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v9
+; CHECK-NEXT:    vwaddu.wv v8, v8, v10
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i16> %vb to <vscale x 1 x i64>
   %vd = add <vscale x 1 x i64> %va, %vc
@@ -511,9 +511,9 @@ define <vscale x 1 x i64> @vwadd_wx_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v9
-; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v9
+; CHECK-NEXT:    vwadd.wv v8, v8, v10
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 1 x i16> %head, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -527,9 +527,9 @@ define <vscale x 1 x i64> @vwaddu_wx_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v9
-; CHECK-NEXT:    vadd.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v9
+; CHECK-NEXT:    vwaddu.wv v8, v8, v10
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 1 x i16> %head, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -541,10 +541,10 @@ define <vscale x 1 x i64> @vwaddu_wx_nxv1i64_nxv1i16(<vscale x 1 x i64> %va, i16
 define <vscale x 2 x i64> @vwadd_vv_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
 ; CHECK-LABEL: vwadd_vv_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v8
-; CHECK-NEXT:    vsext.vf4 v12, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8
+; CHECK-NEXT:    vsext.vf2 v11, v9
+; CHECK-NEXT:    vwadd.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %vc = sext <vscale x 2 x i16> %va to <vscale x 2 x i64>
   %vd = sext <vscale x 2 x i16> %vb to <vscale x 2 x i64>
@@ -555,10 +555,10 @@ define <vscale x 2 x i64> @vwadd_vv_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, <vsc
 define <vscale x 2 x i64> @vwaddu_vv_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v12, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 2 x i16> %va to <vscale x 2 x i64>
   %vd = zext <vscale x 2 x i16> %vb to <vscale x 2 x i64>
@@ -571,10 +571,10 @@ define <vscale x 2 x i64> @vwadd_vx_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vsext.vf4 v10, v8
-; CHECK-NEXT:    vsext.vf4 v12, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8
+; CHECK-NEXT:    vsext.vf2 v11, v9
+; CHECK-NEXT:    vwadd.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 2 x i16> %head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -589,10 +589,10 @@ define <vscale x 2 x i64> @vwaddu_vx_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v12, v9
-; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 2 x i16> %head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -605,9 +605,9 @@ define <vscale x 2 x i64> @vwaddu_vx_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, i16
 define <vscale x 2 x i64> @vwadd_wv_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, <vscale x 2 x i16> %vb) {
 ; CHECK-LABEL: vwadd_wv_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v12
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v11, v10
+; CHECK-NEXT:    vwadd.wv v8, v8, v11
 ; CHECK-NEXT:    ret
   %vc = sext <vscale x 2 x i16> %vb to <vscale x 2 x i64>
   %vd = add <vscale x 2 x i64> %va, %vc
@@ -617,9 +617,9 @@ define <vscale x 2 x i64> @vwadd_wv_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, <vsc
 define <vscale x 2 x i64> @vwaddu_wv_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, <vscale x 2 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_wv_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v12
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v11, v10
+; CHECK-NEXT:    vwaddu.wv v8, v8, v11
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 2 x i16> %vb to <vscale x 2 x i64>
   %vd = add <vscale x 2 x i64> %va, %vc
@@ -631,9 +631,9 @@ define <vscale x 2 x i64> @vwadd_wx_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vmv.v.x v10, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v12
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v11, v10
+; CHECK-NEXT:    vwadd.wv v8, v8, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 2 x i16> %head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -647,9 +647,9 @@ define <vscale x 2 x i64> @vwaddu_wx_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vmv.v.x v10, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v12
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v11, v10
+; CHECK-NEXT:    vwaddu.wv v8, v8, v11
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 2 x i16> %head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -661,10 +661,10 @@ define <vscale x 2 x i64> @vwaddu_wx_nxv2i64_nxv2i16(<vscale x 2 x i64> %va, i16
 define <vscale x 4 x i64> @vwadd_vv_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
 ; CHECK-LABEL: vwadd_vv_nxv4i64_nxv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m4, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v8
-; CHECK-NEXT:    vsext.vf4 v16, v9
-; CHECK-NEXT:    vadd.vv v8, v12, v16
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v12, v8
+; CHECK-NEXT:    vsext.vf2 v14, v9
+; CHECK-NEXT:    vwadd.vv v8, v12, v14
 ; CHECK-NEXT:    ret
   %vc = sext <vscale x 4 x i16> %va to <vscale x 4 x i64>
   %vd = sext <vscale x 4 x i16> %vb to <vscale x 4 x i64>
@@ -675,10 +675,10 @@ define <vscale x 4 x i64> @vwadd_vv_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, <vsc
 define <vscale x 4 x i64> @vwaddu_vv_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv4i64_nxv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m4, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v8
-; CHECK-NEXT:    vzext.vf4 v16, v9
-; CHECK-NEXT:    vadd.vv v8, v12, v16
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v12, v8
+; CHECK-NEXT:    vzext.vf2 v14, v9
+; CHECK-NEXT:    vwaddu.vv v8, v12, v14
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 4 x i16> %va to <vscale x 4 x i64>
   %vd = zext <vscale x 4 x i16> %vb to <vscale x 4 x i64>
@@ -691,10 +691,10 @@ define <vscale x 4 x i64> @vwadd_vx_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, i16
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v8
-; CHECK-NEXT:    vsext.vf4 v16, v9
-; CHECK-NEXT:    vadd.vv v8, v12, v16
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v12, v8
+; CHECK-NEXT:    vsext.vf2 v14, v9
+; CHECK-NEXT:    vwadd.vv v8, v12, v14
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 4 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 4 x i16> %head, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
@@ -709,10 +709,10 @@ define <vscale x 4 x i64> @vwaddu_vx_...
[truncated]

break;

// Ensure the extension's semantic is equivalent to rvv vzext or vsext.
if (ScalarBits != NarrowScalarBits * 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If after the prior change which moves this transform after legalize types, the only case which needs this restriction to keep the transform between legalize types and legalize ops is the i1 vector case, why not simply check if the narrow vt is a i1 vector here? Wouldn't that be less disruptive than moving the combine after legalize ops?

Note that you should also be asserting that both narrow and wide are legal types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If after the prior change which moves this transform after legalize types, the only case which needs this restriction to keep the transform between legalize types and legalize ops is the i1 vector case, why not simply check if the narrow vt is a i1 vector here?

I moved it to after the legalize vector ops phase since we weren't checking for i1 vectors in any of the other _VL nodes. So I think there was already an implicit invariant here that the combine would only run after legalize ops, and it seemed safer to just be explicit about it.

Wouldn't that be less disruptive than moving the combine after legalize ops?

Since the combine was already happening after legalize ops for the _VL nodes, this should only affect the ISD::ADD/SUB/MUL nodes that were added in #76785

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that you should also be asserting that both narrow and wide are legal types.

I've moved the narrow type assert in 0ef61ed so that we now check the narrow type for all extend node types, and we have an assert that the wide type is legal here:

assert(DAG.getTargetLoweringInfo().isTypeLegal(Root->getValueType(0)));

…BinOp_VL

We can remove the restriction that the narrow type needs to be exactly EEW
/ 2 for scalable ISD::{ADD,SUB,MUL} nodes. This allows us to perform the
combine even if we can't fully fold the extend into the widening op.

VP intrinsics already do this, since they are lowered to _VL nodes which
don't have this restriction.

The "exactly EEW / 2" narrow type restriction prevented us from emitting
V{S,Z}EXT_VL nodes with i1 element types which crash when we try to select
them, since no other legal type is double the size of i1.

So to preserve this, this also restricts the combine to only run after the
legalize vector ops phase, at which point all unselectable i1 vectors
should be custom lowered away.
@lukel97 lukel97 force-pushed the combineBinOp_VLToVWBinOp_VL-removeRestriction branch from 51e9007 to 74ea8fe Compare March 11, 2024 10:04
Copy link
Member

@sun-jacobi sun-jacobi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for providing this elegant way. LGTM

llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll Outdated Show resolved Hide resolved
if (DCI.isBeforeLegalize())
// Don't perform this until types are legalized and any legal i1 types are
// custom lowered to avoid introducing unselectable V{S,Z}EXT_VLs.
if (DCI.isBeforeLegalizeOps())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is 100% reliable. Its theoretically possible for an i1 vector to be created by the DAG combiner after legalize ops. The last DAG combine stage also runs the legalizer on every node as part of its worklist. So its not illegal for an i1 zext to created as it would get legalized before isel.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it is better for us to still check whether the op is legal.
But as @lukel97 said, the original EEW / 2 check, which the VP intrinsics already does, could be removed, AFAIU.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the PR to instead check that the narrow element type isn't i1 across the different possible extend ops

unsigned Opc = OrigOperand.getOpcode();
switch (Opc) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND: {
NarrowOp = OrigOperand.getOperand(0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't se just check for i1 in this case? None of the other cases can have an i1 narrow op past type legalization.

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

To be particularly clear since this patch has evolved quite a bit - my LGTM applies specifically to the version of code currently on review. Any change beyond a simple rebase invalidates this LGTM and you should explicitly ask for new review.

@lukel97 lukel97 merged commit 51d5b65 into llvm:main Mar 21, 2024
4 checks passed
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…Op_VLToVWBinOp_VL (llvm#84158)

We can remove the restriction that the narrow type needs to be exactly
EEW / 2 for scalable ISD::{ADD,SUB,MUL} nodes. This allows us to perform
the combine even if we can't fully fold the extend into the widening op.

VP intrinsics already do this, since they are lowered to _VL nodes which
don't have this restriction.

The "exactly EEW / 2" narrow type restriction prevented us from emitting
V{S,Z}EXT_VL nodes with i1 element types which crash when we try to
select them, since no other legal type is double the size of i1, see the
test case added in this PR `i1_zext`. So to preserve this, this adds a
check for i1 narrow types instead.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants