-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Combine (or disjoint ext, ext) -> vwadd #86929
[RISCV] Combine (or disjoint ext, ext) -> vwadd #86929
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) Changes: DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd. This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add. Full diff: https://github.com/llvm/llvm-project/pull/86929.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 564fda674317f4..e068e9e72a26b3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13527,7 +13527,7 @@ struct CombineResult;
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd -> vfwadd | vfwadd_w
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADD_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13697,6 +13698,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADDU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13742,6 +13744,7 @@ struct NodeExtensionHelper {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
+ case ISD::OR:
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
: RISCVISD::VWADDU_W_VL;
case ISD::SUB:
@@ -13862,6 +13865,10 @@ struct NodeExtensionHelper {
case ISD::MUL: {
return Root->getValueType(0).isScalableVector();
}
+ case ISD::OR: {
+ return Root->getValueType(0).isScalableVector() &&
+ Root->getFlags().hasDisjoint();
+ }
// Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
@@ -13942,7 +13949,8 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::OR: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13965,6 +13973,7 @@ struct NodeExtensionHelper {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::MUL:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -14031,6 +14040,7 @@ struct CombineResult {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::OR:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -14181,6 +14191,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
case RISCVISD::FADD_VL:
@@ -14224,9 +14235,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// Combine a binary operation to its equivalent VW or VW_W form.
/// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd_vl -> vfwadd | vfwadd_w
/// fsub_vl -> vfwsub | vfwsub_w
/// fmul_vl -> vfwmul
@@ -15886,8 +15897,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
- case ISD::OR:
+ case ISD::OR: {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
return performORCombine(N, DCI, Subtarget);
+ }
case ISD::XOR:
return performXORCombine(N, DAG, Subtarget);
case ISD::MUL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 36bc10f055b84b..569d1bbbfa5f2d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
}
; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or.
-; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
-; flag is set.
+; Check that we combine disjoint ors into vwaddu.
define <vscale x 2 x i32> @disjoint_or(<vscale x 2 x i8> %x.i8, <vscale x 2 x i8> %y.i8) {
; CHECK-LABEL: disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vsll.vi v8, v10, 8
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf4 v8, v9
-; CHECK-NEXT: vor.vv v8, v10, v8
+; CHECK-NEXT: vsll.vi v10, v10, 8
+; CHECK-NEXT: vzext.vf2 v11, v9
+; CHECK-NEXT: vwaddu.vv v8, v10, v11
; CHECK-NEXT: ret
%x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
%x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper { | |||
case RISCVISD::ADD_VL: | |||
case RISCVISD::VWADD_W_VL: | |||
case RISCVISD::VWADDU_W_VL: | |||
case ISD::OR: | |||
return RISCVISD::VWADD_VL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test for vwadd?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't think of a good way to add a test case that had the disjoint flag inferred, since in order for the bits to be disjoint the highest bit of the sign-extended operand needs to be cleared, which causes the sext to be combined into a zext. Using @llvm.assume didn't help either. I'll add a test case where we're just explicitly setting the disjoint flag though.
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb | |||
} | |||
|
|||
; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or. | |||
; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint | |||
; flag is set. | |||
; Check that we combine disjoint ors into vwaddu. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add a couple tests - one for vwadd.vv and one for vwadd.wv - in terms of disjoint or in the source IR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 131be5d
Interestingly enough, if the two arms of a bitwise operation are the same, DAGCombiner will pull the op through the extend, so we get narrowing for free. But this prevents the current combine from converting it to a vwadd[u].vv. This might be worth adding a pattern for.
DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd. This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.
fd642dd
to
5218340
Compare
DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.
This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.