Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Combine (or disjoint ext, ext) -> vwadd #86929

Merged
merged 1 commit into from
Mar 29, 2024

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Mar 28, 2024

DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.

This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 28, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.

This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.


Full diff: https://github.com/llvm/llvm-project/pull/86929.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+20-6)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+4-7)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 564fda674317f4..e068e9e72a26b3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13527,7 +13527,7 @@ struct CombineResult;
 enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd -> vfwadd | vfwadd_w
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADD_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13697,6 +13698,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADDU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13742,6 +13744,7 @@ struct NodeExtensionHelper {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
+    case ISD::OR:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
                                           : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
@@ -13862,6 +13865,10 @@ struct NodeExtensionHelper {
     case ISD::MUL: {
       return Root->getValueType(0).isScalableVector();
     }
+    case ISD::OR: {
+      return Root->getValueType(0).isScalableVector() &&
+             Root->getFlags().hasDisjoint();
+    }
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13942,7 +13949,8 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::OR: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13965,6 +13973,7 @@ struct NodeExtensionHelper {
     switch (N->getOpcode()) {
     case ISD::ADD:
     case ISD::MUL:
+    case ISD::OR:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -14031,6 +14040,7 @@ struct CombineResult {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL:
+    case ISD::OR:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14181,6 +14191,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   switch (Root->getOpcode()) {
   case ISD::ADD:
   case ISD::SUB:
+  case ISD::OR:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14224,9 +14235,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 
 /// Combine a binary operation to its equivalent VW or VW_W form.
 /// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd_vl ->  vfwadd | vfwadd_w
 /// fsub_vl ->  vfwsub | vfwsub_w
 /// fmul_vl ->  vfwmul
@@ -15886,8 +15897,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
-  case ISD::OR:
+  case ISD::OR: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
     return performORCombine(N, DCI, Subtarget);
+  }
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 36bc10f055b84b..569d1bbbfa5f2d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
 }
 
 ; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or.
-; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
-; flag is set.
+; Check that we combine disjoint ors into vwaddu.
 define <vscale x 2 x i32> @disjoint_or(<vscale x 2 x i8> %x.i8, <vscale x 2 x i8> %y.i8) {
 ; CHECK-LABEL: disjoint_or:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vsll.vi v8, v10, 8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf4 v8, v9
-; CHECK-NEXT:    vor.vv v8, v10, v8
+; CHECK-NEXT:    vsll.vi v10, v10, 8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
   %x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)

Copy link
Member

@sun-jacobi sun-jacobi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case ISD::OR:
return RISCVISD::VWADD_VL;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for vwadd?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't think of a good way to add a test case that had the disjoint flag inferred, since in order for the bits to be disjoint the highest bit of the sexted operand needs to be cleared, which causes the sext to be combined to a zext. Using @llvm.assume didn't help either. I'll add a test case where we're just explicitly setting the disjoint flag though.

@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
}

; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or.
; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
; flag is set.
; Check that we combine disjoint ors into vwaddu.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a couple tests - one for vwadd.vv and one for vwadd.wv - in terms of disjoint or in the source IR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 131be5d

Interestingly enough, if the two arms of a bitwise operation are the same DAGCombiner will pull the op through the extend, so we get narrowing for free. But this prevents the current combine from converting it to a vwadd[u].vv. This might be worth adding a pattern for.

DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.

This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.
@lukel97 lukel97 force-pushed the combineBinOp_VLToVWBinOp_VL-disjoint-or branch from fd642dd to 5218340 Compare March 29, 2024 11:12
@lukel97 lukel97 merged commit 2a315d8 into llvm:main Mar 29, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants