
[RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) #82455

Merged: 3 commits into llvm:main, Feb 23, 2024

Conversation

@preames (Collaborator) commented on Feb 21, 2024

This is legal as long as the inner zext adds at least one bit over the source width, so that the sub overflow case (0 - UINT_MAX) can be represented. Alive2 proof: https://alive2.llvm.org/ce/z/BKeV3W

For RVV, restrict this to power-of-two sizes with the operation type being at least e8, to stick to legal extends. We could arguably handle i1 source types with some care if we wanted to.

This is likely profitable because it may allow us to perform the sub instruction at a narrower LMUL (equivalently, in fewer DLEN-sized pieces) before widening for the user. We could arguably avoid narrowing below DLEN, but the transform should at worst introduce one extra extend and one extra vsetvli toggle if the source could previously be handled via loads with an explicit EEW.
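As a sanity check of the legality claim (separate from the Alive2 proof, and purely illustrative rather than part of the patch), the identity can be verified exhaustively for i8 sources in plain standard C++: the narrow sub is done at i16, which has at least one bit of headroom, so sign-extending it matches the sub done at the full width.

#include <cassert>
#include <cstdint>

// Exhaustive check that for all i8 inputs:
//   sub (zext i8 X to i32), (zext i8 Y to i32)
//     == sext (sub (zext i8 X to i16), (zext i8 Y to i16)) to i32
// The i16 sub cannot wrap: the worst case 0 - 255 = -255 still fits.
int main() {
  for (int X = 0; X <= 255; ++X) {
    for (int Y = 0; Y <= 255; ++Y) {
      int32_t Wide = static_cast<int32_t>(static_cast<uint32_t>(X)) -
                     static_cast<int32_t>(static_cast<uint32_t>(Y));
      int16_t Narrow = static_cast<int16_t>(static_cast<uint16_t>(X) -
                                            static_cast<uint16_t>(Y));
      assert(Wide == static_cast<int32_t>(Narrow)); // sext of the narrow sub
    }
  }
  return 0;
}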

@llvmbot (Collaborator) commented on Feb 21, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/82455.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+23-1)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (+16-16)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 874c851cd9147a..64e06e2648dc23 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12846,6 +12846,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineSubOfBoolean(N, DAG))
     return V;
 
+  EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
@@ -12853,7 +12854,6 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
       isNullConstant(N1.getOperand(1))) {
     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
     if (CCVal == ISD::SETLT) {
-      EVT VT = N->getValueType(0);
       SDLoc DL(N);
       unsigned ShAmt = N0.getValueSizeInBits() - 1;
       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
@@ -12861,6 +12861,28 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // sub (zext, zext) -> sext (sub (zext, zext))
+  //   where the sum of the extend widths match, and the inner zexts
+  //   add at least one bit.  (For profitability on rvv, we use a
+  //   power of two for both inner and outer extend.)
+  if (VT.isVector() && N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() &&
+      N1.hasOneUse() && N0.getOpcode() == ISD::ZERO_EXTEND) {
+    SDValue Src0 = N0.getOperand(0);
+    SDValue Src1 = N1.getOperand(0);
+    EVT SrcVT = Src0.getValueType();
+    if (SrcVT == Src1.getValueType() &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2 &&
+        SrcVT.getScalarSizeInBits() >= 8) {
+      LLVMContext &C = *DAG.getContext();
+      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+    }
+  }
+
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index 574c2652ccfacd..a084b5383b4030 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -385,12 +385,12 @@ define <32 x i64> @vwsubu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -899,12 +899,12 @@ define <2 x i64> @vwsubu_vx_v2i64_i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -917,12 +917,12 @@ define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -935,12 +935,12 @@ define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y

// sub (zext, zext) -> sext (sub (zext, zext))
// where the sum of the extend widths match, and the inner zexts
// add at least one bit. (For profitability on rvv, we use a
// power of two for both inner and outer extend.)
Collaborator:
how are we guaranteeing power of 2 here?

Author (preames):
This is a case of a comment being out of sync with my mental model. I will fix it, but let me explain what I'm thinking and you can tell me whether this is sane or not.

I was originally intending to have an isTypeLegal check on both VT and SrcVT. That, combined with the SrcVT >= 8 check, should ensure that all of the types are e8, e16, e32, or e64.

Then I started thinking about illegal types. I think they fall into two camps: reasonable ones such as i128, and odd ones such as i34. For the former, narrowing before legalization (splitting, I think?) seems likely profitable. For the latter, we might end up with, e.g., an e17 intermediate type, but the intermediate and the result would get promoted to i32 and i64 respectively. So, a reasonable overall result? (Though I now notice there's an edge case here, with e.g. i33 not having a half-sized type.)

What do you think: should I fix the edge case and allow illegal types? Require legal types? Something else?

Collaborator:
getHalfSizedIntegerVT rounds up odd types so that the new type covers at least half of the original width; it will return i17 for i33.

Collaborator:
If the result type is i128, any operation with the i128 element type as either source or dest will get split repeatedly until it can be scalarized, and then the resulting scalar ops with illegal scalar types will get further legalized to XLen. CodeGen will be so bad that I'm not sure it's worth optimizing.

For other illegal types, they should get promoted to the next power of two. After that, your combine would run again and have another chance at it. So it might be fine to check for legal types?

Author (preames):
Added the legality checks.

For my understanding, doesn't DAG combine run before type legalization (as well as after)? Given that, wouldn't narrowing an i128 add to i64 mean that only the sext would be legalized?

Collaborator:
> For my understanding, doesn't DAG combine run before type legalization (as well as after)? Given that, wouldn't narrowing an i128 add to i64 mean that only the sext would be legalized?

Yes. We'd only scalarize the sext. My thinking was that if we have to generate slidedowns and extracts to scalarize some part of the calculation, it doesn't make much sense to use vectors in the first place. But I guess it depends on how many elements there are and how much computation happens before the sub(zext, zext).

@wangpc-pp requested review from wangpc-pp and removed the request for pcwang-thead on February 22, 2024.
if (VT.isVector() && N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() &&
N1.hasOneUse() && N0.getOpcode() == ISD::ZERO_EXTEND) {
if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() && N1.hasOneUse() &&
Collaborator:
Check the opcode before checking the use count. hasOneUse is expensive for nodes that produce multiple results. It can't stop at the second use; it has to check all uses of all results. This can get especially bad with the chain output of loads. So we've found that it's best to first make sure the opcode is one that only produces one result.

Author (preames):
> Check the opcode before checking the use count. hasOneUse is expensive for nodes that produce multiple results. It can't stop at the second use; it has to check all uses of all results. This can get especially bad with the chain output of loads. So we've found that it's best to first make sure the opcode is one that only produces one result.

Your reasoning for why this is expensive doesn't make sense to me, but sure, I'll revise the check order. I'll go read the code in hasOneUse to understand the issue you're mentioning.

Collaborator:
The list of uses for all results is stored in one linked list that is not sorted. SDValue::hasOneUse calls SDNode::hasNUsesOfValue. hasNUsesOfValue will walk the list until it finds more than N uses of the result or it reaches the end of the list. If there is more than one result, there may be many uses of the other results in the list. So the list length can be more than 1 even when the result we're looking for only has a single use. We would have to scan through all those other uses before the loop can terminate.

If the node only has 1 result, the list will only contain 1 use when that result has one use. If a second use exists it would be the next one in the linked list and the loop would terminate there.

Checking the opcode first ensures we are in the single result case.
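For reference, the walk being described looks roughly like the sketch below. This is a simplified paraphrase of SDNode::hasNUsesOfValue, not the exact LLVM source; the helper name is made up for illustration and details of the real implementation may differ.

#include "llvm/CodeGen/SelectionDAGNodes.h"

// Simplified paraphrase: all results of a node share one unsorted use list,
// so counting uses of a single result may have to scan past uses of every
// other result before the loop can give up.
static bool hasNUsesOfValueSketch(const llvm::SDNode *N, unsigned NUses,
                                  unsigned ResNo) {
  for (llvm::SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
       UI != UE; ++UI) {
    if (UI.getUse().getResNo() != ResNo)
      continue;     // use of some other result; still has to be walked past
    if (NUses == 0)
      return false; // found one use too many for the result we care about
    --NUses;
  }
  return NUses == 0; // true iff exactly NUses uses of this result were seen
}

With a single-result opcode such as ISD::ZERO_EXTEND, the list only ever holds uses of that one result, so the loop can stop as soon as a second use is seen.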

SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
LLVMContext &C = *DAG.getContext();
EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
Collaborator:
If the types are (i64 (sub (zext i8), (zext i8))) this will produce (i64 (sext (i32 (sub (zext i8), (zext i8))))). Would it be better to do the sub at i16?

Author (preames):
The combiner will iterate, won't it? So first we'd produce (i64 (sext (i32 (sub (zext i8), (zext i8))))) and then the transform would run again, producing (i64 (sext (i16 (sub (zext i8), (zext i8))))). (There'd be one extra step to fold the two sexts together.)

Collaborator:
You're right, I didn't think about that.

@topperc (Collaborator) left a comment:
LGTM

@preames merged commit ac518c7 into llvm:main on Feb 23, 2024
3 of 4 checks passed
@preames deleted the pr-riscv-narrow-sub-of-zext branch on February 23, 2024
lukel97 added a commit to lukel97/llvm-project that referenced this pull request on Mar 22, 2024
This generalizes the combine added in llvm#82455 to other binary ops, beginning
with adds in this patch.

Because the two zext operands are always non-negative when treated as signed, and the add cannot overflow since it is carried out in at least twice the width of the narrow type, the result of the add is always non-negative. So we can use a zext for the outer extend, unlike sub, which may produce a negative result from two non-negative operands.

Although we could still use sext for add, I plan to add support for other binary ops like mul in a later patch, and mul requires zext to be correct (because the maximum value can take up the full 2 * N bits). So I've opted to use zext here too for consistency.

Alive2 proof: https://alive2.llvm.org/ce/z/PRNsUM
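The same kind of exhaustive spot-check (again illustrative only, and separate from the Alive2 proof) shows why zext is safe as the outer extend for add: the i16 sum of two zero-extended i8 values never exceeds 510 and never goes negative.

#include <cassert>
#include <cstdint>

// Exhaustive check that for all i8 inputs:
//   add (zext i8 X to i32), (zext i8 Y to i32)
//     == zext (add (zext i8 X to i16), (zext i8 Y to i16)) to i32
// The i16 add cannot wrap (max 255 + 255 = 510) and is never negative,
// so zero-extending the narrow result is enough.
int main() {
  for (int X = 0; X <= 255; ++X) {
    for (int Y = 0; Y <= 255; ++Y) {
      uint32_t Wide = static_cast<uint32_t>(X) + static_cast<uint32_t>(Y);
      uint16_t Narrow = static_cast<uint16_t>(static_cast<uint16_t>(X) +
                                              static_cast<uint16_t>(Y));
      assert(Wide == static_cast<uint32_t>(Narrow)); // zext of the narrow add
    }
  }
  return 0;
}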
lukel97 added a commit that referenced this pull request on Mar 25, 2024: …#86248)