
[RISCV] Teach combineBinOpOfZExt to narrow based on known bits #86680

Closed

Conversation

preames
Collaborator

@preames preames commented Mar 26, 2024

This extends the existing narrowing transform for binop (zext, zext) to use known zero bits from the source of the zext when the zext does not widen by more than 2x. This is essentially a generic narrowing for vector binops (currently add/sub) whose operands are known to be positive in the half bit width, with a restriction to the case where we eliminate a source zext.

This patch is currently slightly WIP: I want to add a few more tests, and will rebase. I went ahead and posted it now as it seems to expose the same basic widening-op matching issue as #86465.
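For intuition, here is a minimal scalar sketch of the identity being exploited (illustration only; the function names are mine, and the patch itself works on vector SDNodes via computeKnownBits). The asserted precondition is roughly the known-bits fact the patch tries to establish before narrowing.

#include <cassert>
#include <cstdint>

// If both i32 operands are known to fit in the positive half of i16
// (i.e. <= 0x7FFF), an i32 add/sub can be done at i16 and then extended.
// add of two such values stays non-negative, so it zero extends; sub can
// go negative, so it needs a sign extend.
uint32_t narrow_add(uint32_t a, uint32_t b) {
  assert(a <= 0x7FFF && b <= 0x7FFF);
  uint16_t n = (uint16_t)a + (uint16_t)b; // no overflow: sum <= 0xFFFE
  return n;                               // zero extend back to i32
}

uint32_t narrow_sub(uint32_t a, uint32_t b) {
  assert(a <= 0x7FFF && b <= 0x7FFF);
  int16_t n = (int16_t)((uint16_t)a - (uint16_t)b); // in [-0x7FFF, 0x7FFF]
  return (uint32_t)(int32_t)n;                      // sign extend back to i32
}

int main() {
  assert(narrow_add(0x7FFF, 0x7FFF) == 0x7FFFu + 0x7FFFu);
  assert(narrow_sub(0x0000, 0x7FFF) == 0x0000u - 0x7FFFu); // mod 2^32
}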

@llvmbot
Collaborator

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)



Full diff: https://github.com/llvm/llvm-project/pull/86680.diff

4 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+27-9)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (+21-24)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll (+4-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (+3-1)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e6814c5f71a09b..507f5a600f51ab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12944,32 +12944,50 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
 static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
 
   EVT VT = N->getValueType(0);
-  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      VT.getScalarSizeInBits() <= 8)
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
     return SDValue();
+  // TODO: Can relax these checks when we're not needing to insert a new extend
+  // on one side or the other..
   if (!N0.hasOneUse() || !N1.hasOneUse())
     return SDValue();
 
   SDValue Src0 = N0.getOperand(0);
   SDValue Src1 = N1.getOperand(0);
-  EVT SrcVT = Src0.getValueType();
-  if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
-      SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
-      SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+  EVT Src0VT = Src0.getValueType();
+  EVT Src1VT = Src0.getValueType();
+
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(Src0VT) ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
     return SDValue();
 
+  unsigned HalfBitWidth =  VT.getScalarSizeInBits() / 2;
+  if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src0);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+  if (Src1VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src0);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+
   LLVMContext &C = *DAG.getContext();
-  EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+  EVT ElemVT = EVT::getIntegerVT(C, HalfBitWidth);
   EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
 
-  Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
-  Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+  if (Src0VT != NarrowVT)
+    Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+  if (Src1VT != NarrowVT)
+    Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
 
-  // Src0 and Src1 are zero extended, so they're always positive if signed.
+  // Src0 and Src1 are always positive if signed.
   //
   // sub can produce a negative from two positive operands, so it needs sign
   // extended. Other nodes produce a positive from two positive operands, so
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index a4ab67f41595d4..19ade65db59f43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -98,41 +98,38 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vminu.vv v10, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vsub.vv v8, v8, v10
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v10, (a0)
-; CHECK-NEXT:    vle8.v v11, (a1)
-; CHECK-NEXT:    vminu.vv v12, v8, v9
-; CHECK-NEXT:    vmaxu.vv v8, v8, v9
-; CHECK-NEXT:    vsub.vv v8, v8, v12
-; CHECK-NEXT:    vminu.vv v9, v10, v11
+; CHECK-NEXT:    vle8.v v9, (a0)
+; CHECK-NEXT:    vle8.v v10, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vle8.v v12, (a1)
+; CHECK-NEXT:    vminu.vv v13, v9, v10
+; CHECK-NEXT:    vmaxu.vv v9, v9, v10
+; CHECK-NEXT:    vsub.vv v9, v9, v13
+; CHECK-NEXT:    vminu.vv v10, v11, v12
+; CHECK-NEXT:    vmaxu.vv v11, v11, v12
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
 ; CHECK-NEXT:    vle8.v v12, (a0)
 ; CHECK-NEXT:    vle8.v v13, (a1)
-; CHECK-NEXT:    vmaxu.vv v10, v10, v11
-; CHECK-NEXT:    vsub.vv v9, v10, v9
-; CHECK-NEXT:    vwaddu.vv v10, v9, v8
+; CHECK-NEXT:    vsub.vv v10, v11, v10
+; CHECK-NEXT:    vwaddu.vv v14, v9, v8
+; CHECK-NEXT:    vwaddu.wv v14, v14, v10
 ; CHECK-NEXT:    vminu.vv v8, v12, v13
 ; CHECK-NEXT:    vmaxu.vv v9, v12, v13
 ; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v9, (a0)
-; CHECK-NEXT:    vle8.v v12, (a1)
-; CHECK-NEXT:    vzext.vf2 v14, v8
-; CHECK-NEXT:    vwaddu.vv v16, v14, v10
-; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
-; CHECK-NEXT:    vminu.vv v8, v9, v12
-; CHECK-NEXT:    vmaxu.vv v9, v9, v12
-; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.wv v16, v16, v10
+; CHECK-NEXT:    vwaddu.wv v14, v14, v8
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v16, v8
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwredsumu.vs v8, v14, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index bc0bf5dd76ad45..ccf76f97ac8b6b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -403,11 +403,12 @@ define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
 define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v4i32_v4i8_v4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.vv v8, v10, v9
+; CHECK-NEXT:    vwaddu.wv v9, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index a084b5383b4030..7c53577309b576 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -407,7 +407,9 @@ define <4 x i32> @vwsubu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwsubu.vv v8, v10, v9
+; CHECK-NEXT:    vsub.vv v9, v10, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y

@preames
Collaborator Author

preames commented Mar 26, 2024

I was going to drop a comment on #86680 warning that whichever patch landed first would need to be rebased carefully, as the required zero bits are different for mul than for add/sub, but Luke beat me to it, landing his patch between when I decided to post this and when I actually did. As such, I will need to rebase this to account for the newly added mul case. :)
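To illustrate why the required zero bits differ (a scalar sketch written for this note, not part of the patch): narrowing an i32 op to i16 requires the full result to fit in 16 bits, so add only needs each operand to fit in 15 bits, while mul needs the operands' active bit widths to sum to at most 16.

#include <cassert>
#include <cstdint>

int main() {
  // add: two 15-bit operands cannot overflow 16 bits.
  uint32_t a = 0x7FFF, b = 0x7FFF;
  assert(a + b <= 0xFFFF);
  // mul: the same operands need ~30 bits for their product...
  assert((uint64_t)a * b > 0xFFFF);
  // ...so mul instead needs activeBits(a) + activeBits(b) <= 16.
  uint32_t c = 0xFF, d = 0xFF; // 8 + 8 active bits
  assert((uint64_t)c * d <= 0xFFFF);
}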


⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff a6b870db091830844431f77eb47aa30fc1d70bed 232dbc9e2becc7bcaff67753168704dccbc82378 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 507f5a600f..d3cdb9826f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12966,7 +12966,7 @@ static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
       !DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
     return SDValue();
 
-  unsigned HalfBitWidth =  VT.getScalarSizeInBits() / 2;
+  unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
   if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
     KnownBits Known = DAG.computeKnownBits(Src0);
     if (Known.countMinLeadingZeros() <= HalfBitWidth)

@lukel97
Contributor

lukel97 commented Mar 26, 2024

I just landed #86465; this might need to be rebased to pick up any changes in the mul tests?
Edit: Perfect timing :)

Contributor

@lukel97 lukel97 left a comment


This is probably tangential to this PR, but I just want to mention that I had tried using computeKnownBits in combineBinOpOfZExt last week, though to generalise whether we need a sign extend or a zero extend for the result, i.e. to remove the hardcoded opcode == ISD::SUB.

The difficulty I ended up running into is that KnownBits doesn't store enough information about the potential range of the result. I.e. an i64 sub of two zero-extended i8 integers will be between 0x00000000000000FF and 0xFFFFFFFFFFFFFE02. So the range can fit into i16, but from KnownBits' perspective this is just 0x????????????????. So we can't actually tell what the minimum bit width needed is.

Presumably we'd need something based on ConstantRange instead, but I didn't find anything like that already in SelectionDAG.
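A brute-force illustration of that point (plain C++ written for this discussion, not an LLVM API): intersecting the bit patterns of every possible result of the i64 sub leaves no bit known, even though the whole range fits in i16.

#include <cstdint>
#include <cstdio>

int main() {
  // All possible results of (zext i8 %a to i64) - (zext i8 %b to i64).
  uint64_t allOnes = ~UINT64_C(0); // bits that are 1 in every result
  uint64_t anyOne = 0;             // bits that are 1 in some result
  for (int a = 0; a <= 255; ++a)
    for (int b = 0; b <= 255; ++b) {
      uint64_t d = (uint64_t)((int64_t)a - (int64_t)b);
      allOnes &= d;
      anyOne |= d;
    }
  // KnownBits can only record per-bit facts common to every result.
  // Both masks come out empty: no bit is known 0 or known 1, so the
  // minimum width is invisible to KnownBits.
  printf("known one: %016llx, known zero: %016llx\n",
         (unsigned long long)allOnes, (unsigned long long)~anyOne);
}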

@topperc
Collaborator

topperc commented Mar 26, 2024

The difficulty I ended up running into is that KnownBits doesn't store enough information about the potential range of the result. I.e. an i64 sub of two zero-extended i8 integers will be between 0x00000000000000FF and 0xFFFFFFFFFFFFFE02. So the range can fit into i16, but from KnownBits' perspective this is just 0x????????????????. So we can't actually tell what the minimum bit width needed is.

Shouldn't that be between 0x00000000000000FF and 0xFFFFFFFFFFFFFF01 (-255)?

Does ComputeNumSignBits know?
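For what it's worth, a brute-force check with a plain C++ model of the sign-bit count (not the SelectionDAG API) suggests it does: every possible result has at least 56 identical leading bits, pinning the value to 9 significant bits.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Leading bits equal to the sign bit, as ComputeNumSignBits models it.
static int numSignBits(int64_t v) {
  int n = 1;
  for (int i = 62; i >= 0 && ((v >> i) & 1) == ((v >> 63) & 1); --i)
    ++n;
  return n;
}

int main() {
  int minSign = 64;
  for (int a = 0; a <= 255; ++a)
    for (int b = 0; b <= 255; ++b)
      minSign = std::min(minSign, numSignBits((int64_t)a - (int64_t)b));
  // Prints 56: the sub needs only 64 - 56 + 1 = 9 bits, so it fits in i16.
  printf("guaranteed sign bits: %d\n", minSign);
}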

@preames
Collaborator Author

preames commented Mar 26, 2024

I knew this was somewhat WIP when I posted it, but I hadn't realized I had several nasty bugs which invalidated all the test changes. Going to start this one from scratch with a new review once I actually have something which works.

@preames preames closed this Mar 26, 2024
@lukel97
Contributor

lukel97 commented Mar 27, 2024

The difficulty I ended up running into is that KnownBits doesn't store enough information about the potential range of the result. I.e. an i64 sub of two zero-extended i8 integers will be between 0x00000000000000FF and 0xFFFFFFFFFFFFFE02. So the range can fit into i16, but from KnownBits' perspective this is just 0x????????????????. So we can't actually tell what the minimum bit width needed is.

Shouldn't that be between 0x00000000000000FF and 0xFFFFFFFFFFFFFF01 (-255)?

Yes, I had originally written 0xFFFFFFFFFFFFFF01 (0x00 - 0xFF?), but then confused myself into thinking the i8s could be signed and edited my comment.

Does ComputeNumSignBits know?

That seems to be what I was looking for, thanks
