[RISCV] Teach combineBinOpOfZExt to narrow based on known bits #86680
Conversation
This extends the existing narrowing transform for binop (zext, zext) to use known zero bits from the source of the zext when the zext is less than a 2x extension. This is essentially a generic narrowing for vector binops (currently add/sub) whose operands are known to be positive in the half bitwidth, with a restriction to the case where we eliminate a source zext. This patch is currently slightly WIP. I want to add a few more tests, and will rebase. I went ahead and posted it now as it seems to expose the same basic widening-op matching issue as llvm#86465.
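The arithmetic fact the narrowing relies on can be modeled outside LLVM. The following is my own Python sketch, not the SelectionDAG code: if both operands of a W-bit add are known to fit in fewer than W/2 bits, the W-bit result equals the zero extension of the same add performed at W/2 bits, so the wide op can be replaced by a half-width op.

```python
# Model of the narrowing (not the SelectionDAG code): a W-bit add whose
# operands each fit in fewer than W/2 bits can be done at W/2 bits and
# zero-extended, because the sum cannot overflow the half width.
W = 32
H = W // 2
mask_w = (1 << W) - 1
mask_h = (1 << H) - 1

# Sample operands that fit in H - 1 = 15 bits, i.e. have more than
# H leading zeros when viewed as W-bit values.
for a in (0, 1, 0x1234, 0x7FFF):
    for b in (0, 7, 0x7FFF):
        wide = (a + b) & mask_w    # add performed at full width W
        narrow = (a + b) & mask_h  # add performed at half width H
        assert wide == narrow      # zext(narrow) to W bits matches
```

This is exactly why the patch bails unless `countMinLeadingZeros()` exceeds the half bitwidth: one spare bit below the half width is what absorbs the carry of the add.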
@llvm/pr-subscribers-backend-risc-v Author: Philip Reames (preames). Full diff: https://github.com/llvm/llvm-project/pull/86680.diff. 4 files affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e6814c5f71a09b..507f5a600f51ab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12944,32 +12944,50 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
- if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+ if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+ VT.getScalarSizeInBits() <= 8)
return SDValue();
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
+ // TODO: Can relax these checks when we're not needing to insert a new extend
+ // on one side or the other..
if (!N0.hasOneUse() || !N1.hasOneUse())
return SDValue();
SDValue Src0 = N0.getOperand(0);
SDValue Src1 = N1.getOperand(0);
- EVT SrcVT = Src0.getValueType();
- if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
- SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
- SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+ EVT Src0VT = Src0.getValueType();
+ EVT Src1VT = Src1.getValueType();
+
+ if (!DAG.getTargetLoweringInfo().isTypeLegal(Src0VT) ||
+ !DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
return SDValue();
+ unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
+ if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
+ KnownBits Known = DAG.computeKnownBits(Src0);
+ if (Known.countMinLeadingZeros() <= HalfBitWidth)
+ return SDValue();
+ }
+ if (Src1VT.getScalarSizeInBits() >= HalfBitWidth) {
+ KnownBits Known = DAG.computeKnownBits(Src1);
+ if (Known.countMinLeadingZeros() <= HalfBitWidth)
+ return SDValue();
+ }
+
LLVMContext &C = *DAG.getContext();
- EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+ EVT ElemVT = EVT::getIntegerVT(C, HalfBitWidth);
EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
- Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
- Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+ if (Src0VT != NarrowVT)
+ Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+ if (Src1VT != NarrowVT)
+ Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
- // Src0 and Src1 are zero extended, so they're always positive if signed.
+ // Src0 and Src1 are always positive if signed.
//
// sub can produce a negative from two positive operands, so it needs sign
// extended. Other nodes produce a positive from two positive operands, so
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index a4ab67f41595d4..19ade65db59f43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -98,41 +98,38 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a1)
+; CHECK-NEXT: vminu.vv v10, v8, v9
+; CHECK-NEXT: vmaxu.vv v8, v8, v9
+; CHECK-NEXT: vsub.vv v8, v8, v10
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
-; CHECK-NEXT: vle8.v v10, (a0)
-; CHECK-NEXT: vle8.v v11, (a1)
-; CHECK-NEXT: vminu.vv v12, v8, v9
-; CHECK-NEXT: vmaxu.vv v8, v8, v9
-; CHECK-NEXT: vsub.vv v8, v8, v12
-; CHECK-NEXT: vminu.vv v9, v10, v11
+; CHECK-NEXT: vle8.v v9, (a0)
+; CHECK-NEXT: vle8.v v10, (a1)
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: add a1, a1, a3
+; CHECK-NEXT: vle8.v v11, (a0)
+; CHECK-NEXT: vle8.v v12, (a1)
+; CHECK-NEXT: vminu.vv v13, v9, v10
+; CHECK-NEXT: vmaxu.vv v9, v9, v10
+; CHECK-NEXT: vsub.vv v9, v9, v13
+; CHECK-NEXT: vminu.vv v10, v11, v12
+; CHECK-NEXT: vmaxu.vv v11, v11, v12
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
; CHECK-NEXT: vle8.v v12, (a0)
; CHECK-NEXT: vle8.v v13, (a1)
-; CHECK-NEXT: vmaxu.vv v10, v10, v11
-; CHECK-NEXT: vsub.vv v9, v10, v9
-; CHECK-NEXT: vwaddu.vv v10, v9, v8
+; CHECK-NEXT: vsub.vv v10, v11, v10
+; CHECK-NEXT: vwaddu.vv v14, v9, v8
+; CHECK-NEXT: vwaddu.wv v14, v14, v10
; CHECK-NEXT: vminu.vv v8, v12, v13
; CHECK-NEXT: vmaxu.vv v9, v12, v13
; CHECK-NEXT: vsub.vv v8, v9, v8
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: add a0, a0, a2
-; CHECK-NEXT: add a1, a1, a3
-; CHECK-NEXT: vle8.v v9, (a0)
-; CHECK-NEXT: vle8.v v12, (a1)
-; CHECK-NEXT: vzext.vf2 v14, v8
-; CHECK-NEXT: vwaddu.vv v16, v14, v10
-; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
-; CHECK-NEXT: vminu.vv v8, v9, v12
-; CHECK-NEXT: vmaxu.vv v9, v9, v12
-; CHECK-NEXT: vsub.vv v8, v9, v8
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vwaddu.wv v16, v16, v10
+; CHECK-NEXT: vwaddu.wv v14, v14, v8
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v16, v8
+; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT: vwredsumu.vs v8, v14, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index bc0bf5dd76ad45..ccf76f97ac8b6b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -403,11 +403,12 @@ define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
; CHECK-LABEL: vwaddu_v4i32_v4i8_v4i16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle16.v v9, (a1)
-; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vwaddu.vv v8, v10, v9
+; CHECK-NEXT: vwaddu.wv v9, v9, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vzext.vf2 v8, v9
; CHECK-NEXT: ret
%a = load <4 x i8>, ptr %x
%b = load <4 x i16>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index a084b5383b4030..7c53577309b576 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -407,7 +407,9 @@ define <4 x i32> @vwsubu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle16.v v9, (a1)
; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vwsubu.vv v8, v10, v9
+; CHECK-NEXT: vsub.vv v9, v10, v9
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vsext.vf2 v8, v9
; CHECK-NEXT: ret
%a = load <4 x i8>, ptr %x
%b = load <4 x i16>, ptr %y
I was going to drop a comment on #86680 warning that whichever patch landed first would need to be rebased carefully, as the required zero bits are different for mul than for add/sub, but Luke beat me to landing between when I decided to post this and actually doing so. As such, I will need to rebase this to account for the newly added mul case. :)
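The difference in required zero bits can be illustrated with a small Python check (my own sketch, not part of either patch): for an add to narrow from W to W/2 bits it is enough that each operand fits in W/2 - 1 bits, while for a mul the operands' bit widths must sum to at most W/2.

```python
# Why mul needs stricter known-zero bits than add/sub when narrowing
# a W-bit op to W/2 bits.
W, H = 32, 16

a = b = 0x7FFF            # each fits in 15 bits (H - 1)
assert a + b < (1 << H)   # the sum still fits in H bits: add can narrow
assert a * b >= (1 << H)  # the product does not: mul cannot

c = d = 0xFF              # 8 bits each; 8 + 8 = 16 = H bits of product
assert c * d < (1 << H)   # 255 * 255 = 65025 < 65536: mul can narrow
```

In other words, add/sub only consume one extra bit for the carry or borrow, while mul doubles the number of significant bits, so a rebase has to recheck the known-zero-bit thresholds per opcode.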
You can test this locally with the following command:

git-clang-format --diff a6b870db091830844431f77eb47aa30fc1d70bed 232dbc9e2becc7bcaff67753168704dccbc82378 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp

View the diff from clang-format here:

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 507f5a600f..d3cdb9826f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12966,7 +12966,7 @@ static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
!DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
return SDValue();
- unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
+ unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
KnownBits Known = DAG.computeKnownBits(Src0);
if (Known.countMinLeadingZeros() <= HalfBitWidth)
I just landed #86465; this might need to be rebased to pick up any changes in the mul tests?
This is probably tangential to this PR, but I just want to mention that I had tried using computeKnownBits in combineBinOpOfZExt last week. But it was to generalise whether we need a sign extend or zero extend for the result, i.e. to remove the hardcoded opcode == ISD::SUB.
The difficulty I ended up running into is that KnownBits doesn't store enough information about the potential range of the result. I.e. an i64 sub of two zero-extended i8 integers will be between 0x00000000000000FF and 0xFFFFFFFFFFFFFE02. So the range can fit into i16, but from KnownBits' perspective this is just 0x????????????????. So we can't actually tell what the minimum bit width needed is.
Presumably would need to use something based on ConstantRange instead, but I didn't find anything already in SelectionDAG.
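The range observation is easy to brute-force in Python (a throwaway check of my own, not SelectionDAG code): the differences span [-255, 255], which fits a signed i16, yet no single bit of the 64-bit result is constant across the set, so a KnownBits-style analysis alone cannot recover the bound.

```python
# Brute-force the range of an i64 sub of two zero-extended i8 values,
# and show that KnownBits-style per-bit tracking sees nothing.
MASK64 = (1 << 64) - 1
diffs = [a - b for a in range(256) for b in range(256)]

assert min(diffs) == -255 and max(diffs) == 255          # fits signed i16
assert (min(diffs) & MASK64) == 0xFFFFFFFFFFFFFF01       # -255 as i64
assert (max(diffs) & MASK64) == 0x00000000000000FF

# A bit is "known" only if it has the same value across all results.
known_one = known_zero = MASK64
for d in diffs:
    known_one &= d & MASK64
    known_zero &= ~d & MASK64
assert known_one == 0 and known_zero == 0                # all 64 bits unknown
```

A ConstantRange-style analysis tracks the interval [-255, 255] directly and would see the i16 fit immediately; KnownBits collapses it to 64 unknown bits.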
Shouldn't that be between 0x00000000000000FF and 0xFFFFFFFFFFFFFF01 (-255)? Does ComputeNumSignBits know?
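Sign-bit counting does capture it. A rough Python model of ComputeNumSignBits (my approximation, not LLVM's implementation) shows every such difference has at least 56 redundant sign bits in i64, comfortably more than the 49 needed to truncate to i16 and sign-extend back losslessly.

```python
def num_sign_bits(v, width=64):
    """Count leading bits equal to the sign bit, sign bit included."""
    u = v & ((1 << width) - 1)
    sign = (u >> (width - 1)) & 1
    n = 0
    for i in range(width - 1, -1, -1):
        if (u >> i) & 1 != sign:
            break
        n += 1
    return n

worst = min(num_sign_bits(a - b) for a in range(256) for b in range(256))
assert worst == 56            # every result fits in 64 - 56 + 1 = 9 bits
assert worst >= 64 - 16 + 1   # so truncating to i16 plus sext is lossless
```

The worst cases are 255 (56 leading zeros) and -255 (56 leading ones), so an i9 would already hold the result, which is why a sign-bits query can justify the narrowing that KnownBits cannot.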
I knew this was somewhat WIP when I posted it, but I hadn't realized I had several nasty bugs which invalidated all the test changes. I'm going to start this one from scratch with a new review once I actually have something that works.
Yes, I had originally written 0xFFFFFFFFFFFFFF01 (0x00 - 0xFF?) but then confused myself into thinking the i8s could be signed, and edited my comment.
That seems to be what I was looking for, thanks.