[RISCV][llvm] Support PSLL codegen for P extension #170074
Conversation
There's no instruction for a vector shift amount, so we have to scalarize it and rebuild the vector.
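For context, a rough source-level sketch of the two cases this lowering distinguishes (illustrative only, not part of the patch; it assumes Clang vector extensions on an rv32 target with the P extension enabled, and the typedef and function names are made up):

    #include <stdint.h>

    // Fixed-length vector type matching the <2 x i16> used in the tests below.
    typedef int16_t v2i16 __attribute__((ext_vector_type(2)));

    // Uniform shift amount: the amount becomes a splat_vector, so the SHL stays
    // a single vector node and can select psll.hs via the new TableGen pattern.
    v2i16 shift_uniform(v2i16 a, int16_t sh) { return a << sh; }

    // Per-element shift amounts: no P-extension instruction takes a vector
    // shift amount, so the custom lowering scalarizes the shift and rebuilds
    // the result with BUILD_VECTOR (see the diff below).
    v2i16 shift_per_element(v2i16 a, v2i16 b) { return a << b; }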
@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

Changes

There's no instruction for a vector shift amount, so we have to scalarize it and rebuild the vector.

Full diff: https://github.com/llvm/llvm-project/pull/170074.diff

4 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a6212f5cc84be..62d659574892c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -529,6 +529,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
+ setOperationAction(ISD::SHL, VTs, Custom);
setOperationAction(ISD::BITCAST, VTs, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
}
@@ -8592,6 +8593,37 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::VSELECT:
return lowerToScalableOp(Op, DAG);
case ISD::SHL:
+ if (Subtarget.enablePExtCodeGen() &&
+ Op.getSimpleValueType().isFixedLengthVector()) {
+ // There's no vector-vector version of the shift instruction in the P
+ // extension, so we need to fall back to scalar computation and pack the
+ // results back into a vector.
+ MVT VecVT = Op.getSimpleValueType();
+ unsigned NumElts = VecVT.getVectorNumElements();
+ MVT VecEltTy = VecVT.getVectorElementType();
+ SDValue Src0 = Op.getOperand(0);
+ SDValue Src1 = Op.getOperand(1);
+ SDLoc DL(Op);
+ SmallVector<SDValue, 2> Results;
+
+ if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
+ return Op;
+
+ for (unsigned I = 0; I < NumElts; ++I) {
+ // extract scalar value
+ SDValue SrcElt0 =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+ {Src0, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+ SDValue SrcElt1 =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+ {Src1, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+ // perform computation
+ Results.push_back(
+ DAG.getNode(ISD::SHL, DL, VecEltTy, SrcElt0, SrcElt1));
+ }
+ // pack the results
+ return DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Results);
+ }
+ [[fallthrough]];
case ISD::SRA:
case ISD::SRL:
if (Op.getSimpleValueType().isFixedLengthVector())
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 599358368594f..fd0d26657a88a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1532,6 +1532,15 @@ let Predicates = [HasStdExtP] in {
def: Pat<(XLenVecI16VT (sshlsat GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
(PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
+ // 8-bit logical shift left
+ def: Pat<(XLenVecI8VT (shl GPR:$rs1,
+ (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+ (PSLL_BS GPR:$rs1, GPR:$rs2)>;
+ // 16-bit logical shift left
+ def: Pat<(XLenVecI16VT (shl GPR:$rs1,
+ (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+ (PSLL_HS GPR:$rs1, GPR:$rs2)>;
+
// 8-bit PLI SD node pattern
def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
// 16-bit PLI SD node pattern
@@ -1578,6 +1587,10 @@ let Predicates = [HasStdExtP, IsRV64] in {
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
+ // 32-bit logical shift left
+ def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+ (PSLL_WS GPR:$rs1, GPR:$rs2)>;
+
// splat pattern
def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index f803f6aa09652..3f4366f0dc758 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -637,3 +637,112 @@ define void @test_psslai_h(ptr %ret_ptr, ptr %a_ptr) {
store <2 x i16> %res, ptr %ret_ptr
ret void
}
+
+; Test logical shift left (scalar shamt)
+define void @test_psll_hs(ptr %ret_ptr, ptr %a_ptr, i16 %shamt) {
+; CHECK-LABEL: test_psll_hs:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: psll.hs a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %insert = insertelement <2 x i16> undef, i16 %shamt, i32 0
+ %b = shufflevector <2 x i16> %insert, <2 x i16> undef, <2 x i32> zeroinitializer
+ %res = shl <2 x i16> %a, %b
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psll_bs(ptr %ret_ptr, ptr %a_ptr, i8 %shamt) {
+; CHECK-LABEL: test_psll_bs:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: psll.bs a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %insert = insertelement <4 x i8> undef, i8 %shamt, i32 0
+ %b = shufflevector <4 x i8> %insert, <4 x i8> undef, <4 x i32> zeroinitializer
+ %res = shl <4 x i8> %a, %b
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test logical shift left (vector shamt)
+define void @test_psll_hs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: lw a1, 0(a1)
+; CHECK-RV32-NEXT: lw a2, 0(a2)
+; CHECK-RV32-NEXT: sll a3, a1, a2
+; CHECK-RV32-NEXT: srli a2, a2, 16
+; CHECK-RV32-NEXT: srli a1, a1, 16
+; CHECK-RV32-NEXT: sll a1, a1, a2
+; CHECK-RV32-NEXT: pack a1, a3, a1
+; CHECK-RV32-NEXT: sw a1, 0(a0)
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: lw a1, 0(a1)
+; CHECK-RV64-NEXT: lw a2, 0(a2)
+; CHECK-RV64-NEXT: sll a3, a1, a2
+; CHECK-RV64-NEXT: srli a2, a2, 16
+; CHECK-RV64-NEXT: srli a1, a1, 16
+; CHECK-RV64-NEXT: sll a1, a1, a2
+; CHECK-RV64-NEXT: ppack.w a1, a3, a1
+; CHECK-RV64-NEXT: sw a1, 0(a0)
+; CHECK-RV64-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %shamt_ptr
+ %res = shl <2 x i16> %a, %b
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: lw a2, 0(a2)
+; CHECK-RV32-NEXT: lw a1, 0(a1)
+; CHECK-RV32-NEXT: srli a3, a2, 24
+; CHECK-RV32-NEXT: srli a4, a1, 24
+; CHECK-RV32-NEXT: srli a5, a2, 8
+; CHECK-RV32-NEXT: srli a6, a1, 8
+; CHECK-RV32-NEXT: sll a7, a4, a3
+; CHECK-RV32-NEXT: sll a6, a6, a5
+; CHECK-RV32-NEXT: sll a4, a1, a2
+; CHECK-RV32-NEXT: srli a2, a2, 16
+; CHECK-RV32-NEXT: srli a1, a1, 16
+; CHECK-RV32-NEXT: sll a5, a1, a2
+; CHECK-RV32-NEXT: ppack.dh a2, a4, a6
+; CHECK-RV32-NEXT: pack a1, a2, a3
+; CHECK-RV32-NEXT: sw a1, 0(a0)
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: lw a2, 0(a2)
+; CHECK-RV64-NEXT: lw a1, 0(a1)
+; CHECK-RV64-NEXT: srli a3, a2, 24
+; CHECK-RV64-NEXT: srli a4, a1, 24
+; CHECK-RV64-NEXT: srli a5, a2, 16
+; CHECK-RV64-NEXT: sll a3, a4, a3
+; CHECK-RV64-NEXT: srli a4, a1, 16
+; CHECK-RV64-NEXT: sll a4, a4, a5
+; CHECK-RV64-NEXT: sll a5, a1, a2
+; CHECK-RV64-NEXT: srli a2, a2, 8
+; CHECK-RV64-NEXT: srli a1, a1, 8
+; CHECK-RV64-NEXT: sll a1, a1, a2
+; CHECK-RV64-NEXT: ppack.h a2, a4, a3
+; CHECK-RV64-NEXT: ppack.h a1, a5, a1
+; CHECK-RV64-NEXT: ppack.w a1, a1, a2
+; CHECK-RV64-NEXT: sw a1, 0(a0)
+; CHECK-RV64-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %shamt_ptr
+ %res = shl <4 x i8> %a, %b
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 9b021df8dd452..721762fffdc9f 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -805,3 +805,39 @@ define void @test_psslai_w(ptr %ret_ptr, ptr %a_ptr) {
store <2 x i32> %res, ptr %ret_ptr
ret void
}
+
+; Test logical shift left (scalar shamt)
+define void @test_psll_ws(ptr %ret_ptr, ptr %a_ptr, i32 %shamt) {
+; CHECK-LABEL: test_psll_ws:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: psll.ws a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %insert = insertelement <2 x i32> undef, i32 %shamt, i32 0
+ %b = shufflevector <2 x i32> %insert, <2 x i32> undef, <2 x i32> zeroinitializer
+ %res = shl <2 x i32> %a, %b
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test logical shift left (vector shamt)
+define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-LABEL: test_psll_ws_vec_shamt:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: sllw a3, a1, a2
+; CHECK-NEXT: srli a2, a2, 32
+; CHECK-NEXT: srli a1, a1, 32
+; CHECK-NEXT: sllw a1, a1, a2
+; CHECK-NEXT: pack a1, a3, a1
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %shamt_ptr
+ %res = shl <2 x i32> %a, %b
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
✅ With the latest revision this PR passed the undef deprecator.
    if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
      return Op;

    for (unsigned I = 0; I < NumElts; ++I) {
Can we use DAG.UnrollVectorOp?
Oh, that's much more concise, I should have done it this way!
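For reference, a minimal sketch of how that suggestion could replace the hand-written loop (this is not the code that landed; it only shows the shape of the change). SelectionDAG::UnrollVectorOp performs the same extract/compute/BUILD_VECTOR expansion internally:

    case ISD::SHL:
      if (Subtarget.enablePExtCodeGen() &&
          Op.getSimpleValueType().isFixedLengthVector()) {
        // Splat shift amounts are left intact so the PSLL_{B,H,W}S patterns
        // can match them during instruction selection.
        if (Op.getOperand(1).getOpcode() == ISD::SPLAT_VECTOR)
          return Op;
        // Non-splat shift amounts: expand to per-element scalar shifts and a
        // BUILD_VECTOR, which is what the manual loop in the diff spells out.
        return DAG.UnrollVectorOp(Op.getNode());
      }
      [[fallthrough]];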
topperc left a comment
LGTM
There's no instruction for a vector shift amount, so we have to scalarize
it and rebuild the vector.