-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Use vwadd.vx for splat vector with extension #87249
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Chia (sun-jacobi) ChangesThis patch allows Source code
Before this patch
After this patch
Full diff: https://github.com/llvm/llvm-project/pull/87249.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f693cbd3bea51e..f422ee53874e32 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13589,6 +13589,8 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
+ case ISD::SPLAT_VECTOR:
+ return OrigOperand.getOperand(0).getOperand(0);
default:
return OrigOperand;
}
@@ -13640,6 +13642,8 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+ case ISD::SPLAT_VECTOR:
+ return DAG.getNode(ISD::SPLAT_VECTOR, DL, NarrowVT, Source, Mask, VL);
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -13817,6 +13821,35 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case ISD::SPLAT_VECTOR: {
+ SDValue ScalarOp = OrigOperand.getOperand(0);
+ unsigned ScalarOpc = ScalarOp.getOpcode();
+
+ MVT ScalarVT = ScalarOp.getSimpleValueType();
+ unsigned ScalarSize = ScalarVT.getScalarSizeInBits();
+ unsigned NarrowSize = ScalarSize / 2;
+
+ // Ensuring the scalar element is legal.
+ if (NarrowSize < 8)
+ break;
+
+ SupportsSExt = ScalarOpc == ISD::SIGN_EXTEND_INREG;
+
+ if (ScalarOpc == ISD::AND) {
+ if (ConstantSDNode *MaskNode =
+ dyn_cast<ConstantSDNode>(ScalarOp.getOperand(1)))
+ SupportsZExt = MaskNode->getAPIntValue() ==
+ APInt::getBitsSet(ScalarSize, 0, NarrowSize);
+ }
+
+ EnforceOneUse = false;
+ CheckMask = false;
+
+ MVT VT = OrigOperand.getSimpleValueType();
+ SDLoc DL(Root);
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ break;
+ }
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..eeb29285594477 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1466,3 +1466,156 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
ret <vscale x 2 x i32> %or
}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_zext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw zero, 12(sp)
+; RV32-NEXT: sw a0, 8(sp)
+; RV32-NEXT: addi a0, sp, 8
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
+; RV32-NEXT: vwaddu.wv v16, v16, v8
+; RV32-NEXT: vmv8r.v v8, v16
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwaddu.vx v16, v8, a0
+; RV64-NEXT: vmv8r.v v8, v16
+; RV64-NEXT: ret
+ %sb = zext i32 %b to i64
+ %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %vc = zext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+ %ve = add <vscale x 8 x i64> %vc, %splat
+ ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_zext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext_i1:
+; RV32: # %bb.0:
+; RV32-NEXT: slli a0, a0, 16
+; RV32-NEXT: srli a0, a0, 16
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT: vmv.v.x v8, a0
+; RV32-NEXT: vadd.vi v8, v8, 1, v0.t
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext_i1:
+; RV64: # %bb.0:
+; RV64-NEXT: slli a0, a0, 48
+; RV64-NEXT: srli a0, a0, 48
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vmv.v.i v8, 0
+; RV64-NEXT: vmerge.vim v8, v8, 1, v0
+; RV64-NEXT: vadd.vx v8, v8, a0
+; RV64-NEXT: ret
+ %sb = zext i16 %b to i32
+ %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+ %vc = zext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+ %ve = add <vscale x 8 x i32> %vc, %splat
+ ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_zext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_zext:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw zero, 12(sp)
+; RV32-NEXT: sw a0, 8(sp)
+; RV32-NEXT: addi a0, sp, 8
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
+; RV32-NEXT: vadd.vv v8, v8, v16
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_wx_splat_zext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwaddu.wx v8, v8, a0
+; RV64-NEXT: ret
+ %sb = zext i32 %b to i64
+ %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %ve = add <vscale x 8 x i64> %va, %splat
+ ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vmv.v.x v16, a0
+; RV32-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; RV32-NEXT: vwadd.wv v16, v16, v8
+; RV32-NEXT: vmv8r.v v8, v16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwadd.vx v16, v8, a0
+; RV64-NEXT: vmv8r.v v8, v16
+; RV64-NEXT: ret
+ %sb = sext i32 %b to i64
+ %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+ %ve = add <vscale x 8 x i64> %vc, %splat
+ ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_sext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext_i1:
+; RV32: # %bb.0:
+; RV32-NEXT: slli a0, a0, 16
+; RV32-NEXT: srai a0, a0, 16
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT: vmv.v.x v8, a0
+; RV32-NEXT: li a0, 1
+; RV32-NEXT: vsub.vx v8, v8, a0, v0.t
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext_i1:
+; RV64: # %bb.0:
+; RV64-NEXT: slli a0, a0, 48
+; RV64-NEXT: srai a0, a0, 48
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vmv.v.i v8, 0
+; RV64-NEXT: vmerge.vim v8, v8, 1, v0
+; RV64-NEXT: vrsub.vx v8, v8, a0
+; RV64-NEXT: ret
+ %sb = sext i16 %b to i32
+ %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+ %vc = sext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+ %ve = add <vscale x 8 x i32> %vc, %splat
+ ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_sext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_sext:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vadd.vx v8, v8, a0
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_wx_splat_sext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwadd.wx v8, v8, a0
+; RV64-NEXT: ret
+ %sb = sext i32 %b to i64
+ %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+ %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %ve = add <vscale x 8 x i64> %va, %splat
+ ret <vscale x 8 x i64> %ve
+}
|
|
||
SupportsSExt = ScalarOpc == ISD::SIGN_EXTEND_INREG; | ||
|
||
if (ScalarOpc == ISD::AND) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use ComputeNumSignBits and MaskedValueIsZero like RISCVISD::VMV_V_X_VL?
Nice, this would be useful to have. I had done something similar last week but just by copying and pasting the code from RISCVISD:VMV_V_X: diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f693cbd3bea5..24ecd5d57b6c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13596,7 +13596,8 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
- return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+ return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+ OrigOperand.getOpcode() == ISD::SPLAT_VECTOR;
}
/// Get the extended opcode.
@@ -13643,6 +13644,9 @@ struct NodeExtensionHelper {
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+ case ISD::SPLAT_VECTOR:
+ // Operand is implicitly truncated
+ return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
default:
// Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
// and that operand should already have the right NarrowVT so no
@@ -13817,6 +13821,37 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case ISD::SPLAT_VECTOR: {
+ MVT VT = OrigOperand.getSimpleValueType();
+ if (!VT.isVector())
+ break;
+ EnforceOneUse = false;
+ std::tie(Mask, VL) =
+ getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+
+ // Get the scalar value.
+ SDValue Op = OrigOperand.getOperand(0);
+
+ unsigned EltBits = VT.getScalarSizeInBits();
+ unsigned ScalarBits = Op.getValueSizeInBits();
+ // Make sure we're getting all element bits from the scalar register.
+ // FIXME: Support implicit sign extension of vmv.v.x?
+ if (ScalarBits < EltBits)
+ break;
+
+ unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+ // If the narrow type cannot be expressed with a legal VMV,
+ // this is not a valid candidate.
+ if (NarrowSize < 8)
+ break;
+
+ if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+ SupportsSExt = true;
+ if (DAG.MaskedValueIsZero(Op,
+ APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+ SupportsZExt = true;
+ break;
+ }
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
The trick was that I needed to make sure to add ISD::SPLAT_VECTOR to define <vscale x 1 x i16> @vadd_vx_nxv1i16_0(<vscale x 1 x i16> %va) {
; CHECK-LABEL: vadd_vx_nxv1i16_0:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e16, mf4, ta, ma
-; CHECK-NEXT: vadd.vi v8, v8, -1
+; CHECK-NEXT: li a0, -1
+; CHECK-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vwadd.wx v8, v8, a0
; CHECK-NEXT: ret
|
; RV32-LABEL: vwadd_wx_splat_sext: | ||
; RV32: # %bb.0: | ||
; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma | ||
; RV32-NEXT: vadd.vx v8, v8, a0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably don't need to handle it in this patch, but just want to point out on RV32 we end up with SPLAT_VECTOR_PARTS for i64 element vectors IIRC. But it isn't immediately obvious to me how we're supposed to check if it's zero/sign extending, since the splat value is split over multiple operands
If the subtarget has +zvbb then we can attempt folding shl and shl_vl to vwsll nodes. There are few test cases where we still don't pick up the vwsll: - For fixed vector vwsll.vi on RV32, see the FIXME for VMV_V_X_VL in fillUpExtensionSupport for support implicit sign extension - For scalable vector vwsll.vi we need to support ISD::SPLAT_VECTOR, see llvm#87249
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the code LGTM, just some questions about the test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -1229,22 +1229,6 @@ define <vscale x 1 x i64> @ctlz_nxv1i64(<vscale x 1 x i64> %va) { | |||
; RV64I-NEXT: vsrl.vx v8, v8, a0 | |||
; RV64I-NEXT: ret | |||
; | |||
; CHECK-F-LABEL: ctlz_nxv1i64: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are these lines deleted without any other lines being changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, I forgot to modify the prefix, thank you.
If the subtarget has +zvbb then we can attempt folding shl and shl_vl to vwsll nodes. There are few test cases where we still don't pick up the vwsll: - For fixed vector vwsll.vi on RV32, see the FIXME for VMV_V_X_VL in fillUpExtensionSupport for support implicit sign extension - For scalable vector vwsll.vi we need to support ISD::SPLAT_VECTOR, see llvm#87249
If the subtarget has +zvbb then we can attempt folding shl and shl_vl to vwsll nodes. There are few test cases where we still don't pick up the vwsll: - For fixed vector vwsll.vi on RV32, see the FIXME for VMV_V_X_VL in fillUpExtensionSupport for support implicit sign extension - For scalable vector vwsll.vi we need to support ISD::SPLAT_VECTOR, see #87249
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch allows
combineBinOp_VLToVWBinOp_VL
to handle patterns like(splat_vector (sext op))
or(splat_vector (zext op))
. Then we can usevwadd.vx
andvwadd.w
for such a case.Source code
Before this patch
Compiler Explorer
After this patch