From 760f3e4a42504bc866c8cfea73c7e9a8a55331fc Mon Sep 17 00:00:00 2001
From: Brandon Wu
Date: Fri, 21 Nov 2025 04:53:24 -0800
Subject: [PATCH 1/4] [RISCV][llvm] Support PSLL codegen for P extension

There's no instruction that takes a vector shift amount, so we have to
scalarize the shift and rebuild the vector.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  32 ++++++
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    |  13 +++
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll     | 109 ++++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll     |  36 +++++++
 4 files changed, 190 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a6212f5cc84be..62d659574892c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -529,6 +529,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
       setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
       setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
+      setOperationAction(ISD::SHL, VTs, Custom);
       setOperationAction(ISD::BITCAST, VTs, Custom);
       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
     }
@@ -8592,6 +8593,37 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::VSELECT:
     return lowerToScalableOp(Op, DAG);
   case ISD::SHL:
+    if (Subtarget.enablePExtCodeGen() &&
+        Op.getSimpleValueType().isFixedLengthVector()) {
+      // There's no vector-vector version of shift instruction in P extension so
+      // we need to fallback to scalar computation and pack them back.
+      MVT VecVT = Op.getSimpleValueType();
+      unsigned NumElts = VecVT.getVectorNumElements();
+      MVT VecEltTy = VecVT.getVectorElementType();
+      SDValue Src0 = Op.getOperand(0);
+      SDValue Src1 = Op.getOperand(1);
+      SDLoc DL(Op);
+      SmallVector<SDValue> Results;
+
+      if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
+        return Op;
+
+      for (unsigned I = 0; I < NumElts; ++I) {
+        // extract scalar value
+        SDValue SrcElt0 =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+                        {Src0, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+        SDValue SrcElt1 =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+                        {Src1, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+        // perform computation
+        Results.push_back(
+            DAG.getNode(ISD::SHL, DL, VecEltTy, SrcElt0, SrcElt1));
+      }
+      // pack the results
+      return DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Results);
+    }
+    [[fallthrough]];
   case ISD::SRA:
   case ISD::SRL:
     if (Op.getSimpleValueType().isFixedLengthVector())
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 599358368594f..fd0d26657a88a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1532,6 +1532,15 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (sshlsat GPR:$rs1,
                                   (XLenVecI16VT (splat_vector uimm4:$shamt)))),
            (PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
+  // 8-bit logical shift left
+  def: Pat<(XLenVecI8VT (shl GPR:$rs1,
+                             (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_BS GPR:$rs1, GPR:$rs2)>;
+  // 16-bit logical shift left
+  def: Pat<(XLenVecI16VT (shl GPR:$rs1,
+                              (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_HS GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit PLI SD node pattern
   def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
   // 16-bit PLI SD node pattern
@@ -1578,6 +1587,10 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
+  // 32-bit logical shift left
+  def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_WS GPR:$rs1, GPR:$rs2)>;
+
   // splat pattern
   def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))),
            (PADD_WS (XLenVT X0), GPR:$rs2)>;

diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index f803f6aa09652..3f4366f0dc758 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -637,3 +637,112 @@ define void @test_psslai_h(ptr %ret_ptr, ptr %a_ptr) {
   store <2 x i16> %res, ptr %ret_ptr
   ret void
 }
+
+; Test logical shift left(scalar shamt)
+define void @test_psll_hs(ptr %ret_ptr, ptr %a_ptr, i16 %shamt) {
+; CHECK-LABEL: test_psll_hs:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    psll.hs a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %insert = insertelement <2 x i16> undef, i16 %shamt, i32 0
+  %b = shufflevector <2 x i16> %insert, <2 x i16> undef, <2 x i32> zeroinitializer
+  %res = shl <2 x i16> %a, %b
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_psll_bs(ptr %ret_ptr, ptr %a_ptr, i8 %shamt) {
+; CHECK-LABEL: test_psll_bs:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    psll.bs a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i8>, ptr %a_ptr
+  %insert = insertelement <4 x i8> undef, i8 %shamt, i32 0
+  %b = shufflevector <4 x i8> %insert, <4 x i8> undef, <4 x i32> zeroinitializer
+  %res = shl <4 x i8> %a, %b
+  store <4 x i8> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test logical shift left(vector shamt)
+define void @test_psll_hs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    lw a1, 0(a1)
+; CHECK-RV32-NEXT:    lw a2, 0(a2)
+; CHECK-RV32-NEXT:    sll a3, a1, a2
+; CHECK-RV32-NEXT:    srli a2, a2, 16
+; CHECK-RV32-NEXT:    srli a1, a1, 16
+; CHECK-RV32-NEXT:    sll a1, a1, a2
+; CHECK-RV32-NEXT:    pack a1, a3, a1
+; CHECK-RV32-NEXT:    sw a1, 0(a0)
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    lw a1, 0(a1)
+; CHECK-RV64-NEXT:    lw a2, 0(a2)
+; CHECK-RV64-NEXT:    sll a3, a1, a2
+; CHECK-RV64-NEXT:    srli a2, a2, 16
+; CHECK-RV64-NEXT:    srli a1, a1, 16
+; CHECK-RV64-NEXT:    sll a1, a1, a2
+; CHECK-RV64-NEXT:    ppack.w a1, a3, a1
+; CHECK-RV64-NEXT:    sw a1, 0(a0)
+; CHECK-RV64-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %shamt_ptr
+  %res = shl <2 x i16> %a, %b
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    lw a2, 0(a2)
+; CHECK-RV32-NEXT:    lw a1, 0(a1)
+; CHECK-RV32-NEXT:    srli a3, a2, 24
+; CHECK-RV32-NEXT:    srli a4, a1, 24
+; CHECK-RV32-NEXT:    srli a5, a2, 8
+; CHECK-RV32-NEXT:    srli a6, a1, 8
+; CHECK-RV32-NEXT:    sll a7, a4, a3
+; CHECK-RV32-NEXT:    sll a6, a6, a5
+; CHECK-RV32-NEXT:    sll a4, a1, a2
+; CHECK-RV32-NEXT:    srli a2, a2, 16
+; CHECK-RV32-NEXT:    srli a1, a1, 16
+; CHECK-RV32-NEXT:    sll a5, a1, a2
+; CHECK-RV32-NEXT:    ppack.dh a2, a4, a6
+; CHECK-RV32-NEXT:    pack a1, a2, a3
+; CHECK-RV32-NEXT:    sw a1, 0(a0)
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    lw a2, 0(a2)
+; CHECK-RV64-NEXT:    lw a1, 0(a1)
+; CHECK-RV64-NEXT:    srli a3, a2, 24
+; CHECK-RV64-NEXT:    srli a4, a1, 24
+; CHECK-RV64-NEXT:    srli a5, a2, 16
+; CHECK-RV64-NEXT:    sll a3, a4, a3
+; CHECK-RV64-NEXT:    srli a4, a1, 16
+; CHECK-RV64-NEXT:    sll a4, a4, a5
+; CHECK-RV64-NEXT:    sll a5, a1, a2
+; CHECK-RV64-NEXT:    srli a2, a2, 8
+; CHECK-RV64-NEXT:    srli a1, a1, 8
+; CHECK-RV64-NEXT:    sll a1, a1, a2
+; CHECK-RV64-NEXT:    ppack.h a2, a4, a3
+; CHECK-RV64-NEXT:    ppack.h a1, a5, a1
+; CHECK-RV64-NEXT:    ppack.w a1, a1, a2
+; CHECK-RV64-NEXT:    sw a1, 0(a0)
+; CHECK-RV64-NEXT:    ret
+  %a = load <4 x i8>, ptr %a_ptr
+  %b = load <4 x i8>, ptr %shamt_ptr
+  %res = shl <4 x i8> %a, %b
+  store <4 x i8> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 9b021df8dd452..721762fffdc9f 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -805,3 +805,39 @@ define void @test_psslai_w(ptr %ret_ptr, ptr %a_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test logical shift left(scalar shamt)
+define void @test_psll_ws(ptr %ret_ptr, ptr %a_ptr, i32 %shamt) {
+; CHECK-LABEL: test_psll_ws:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    psll.ws a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %insert = insertelement <2 x i32> undef, i32 %shamt, i32 0
+  %b = shufflevector <2 x i32> %insert, <2 x i32> undef, <2 x i32> zeroinitializer
+  %res = shl <2 x i32> %a, %b
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test logical shift left(vector shamt)
+define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-LABEL: test_psll_ws_vec_shamt:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    sllw a3, a1, a2
+; CHECK-NEXT:    srli a2, a2, 32
+; CHECK-NEXT:    srli a1, a1, 32
+; CHECK-NEXT:    sllw a1, a1, a2
+; CHECK-NEXT:    pack a1, a3, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %shamt_ptr
+  %res = shl <2 x i32> %a, %b
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}

From 06b77024534a74b88f3c2c0e6b80966676355ba2 Mon Sep 17 00:00:00 2001
From: Brandon Wu
Date: Sun, 30 Nov 2025 20:28:11 -0800
Subject: [PATCH 2/4] fixup! more comments

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 62d659574892c..ca67d741b035d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8605,6 +8605,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       SDLoc DL(Op);
       SmallVector<SDValue> Results;
 
+      // We have patterns for scalar/immediate shift amount, so no lowering
+      // needed.
       if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
         return Op;
 

From db414485ccfb2b11c01dda504586430fc0b794c4 Mon Sep 17 00:00:00 2001
From: Brandon Wu
Date: Sun, 30 Nov 2025 20:30:34 -0800
Subject: [PATCH 3/4] fixup! replace undef

---
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll | 8 ++++----
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index 3f4366f0dc758..bc79d441f6407 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -647,8 +647,8 @@ define void @test_psll_hs(ptr %ret_ptr, ptr %a_ptr, i16 %shamt) {
 ; CHECK-NEXT:    sw a1, 0(a0)
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %a_ptr
-  %insert = insertelement <2 x i16> undef, i16 %shamt, i32 0
-  %b = shufflevector <2 x i16> %insert, <2 x i16> undef, <2 x i32> zeroinitializer
+  %insert = insertelement <2 x i16> poison, i16 %shamt, i32 0
+  %b = shufflevector <2 x i16> %insert, <2 x i16> poison, <2 x i32> zeroinitializer
   %res = shl <2 x i16> %a, %b
   store <2 x i16> %res, ptr %ret_ptr
   ret void
@@ -662,8 +662,8 @@ define void @test_psll_bs(ptr %ret_ptr, ptr %a_ptr, i8 %shamt) {
 ; CHECK-NEXT:    sw a1, 0(a0)
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %a_ptr
-  %insert = insertelement <4 x i8> undef, i8 %shamt, i32 0
-  %b = shufflevector <4 x i8> %insert, <4 x i8> undef, <4 x i32> zeroinitializer
+  %insert = insertelement <4 x i8> poison, i8 %shamt, i32 0
+  %b = shufflevector <4 x i8> %insert, <4 x i8> poison, <4 x i32> zeroinitializer
   %res = shl <4 x i8> %a, %b
   store <4 x i8> %res, ptr %ret_ptr
   ret void
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 721762fffdc9f..197a1869963c7 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -815,8 +815,8 @@ define void @test_psll_ws(ptr %ret_ptr, ptr %a_ptr, i32 %shamt) {
 ; CHECK-NEXT:    sd a1, 0(a0)
 ; CHECK-NEXT:    ret
   %a = load <2 x i32>, ptr %a_ptr
-  %insert = insertelement <2 x i32> undef, i32 %shamt, i32 0
-  %b = shufflevector <2 x i32> %insert, <2 x i32> undef, <2 x i32> zeroinitializer
+  %insert = insertelement <2 x i32> poison, i32 %shamt, i32 0
+  %b = shufflevector <2 x i32> %insert, <2 x i32> poison, <2 x i32> zeroinitializer
   %res = shl <2 x i32> %a, %b
   store <2 x i32> %res, ptr %ret_ptr
   ret void

From 59f5bcb50cc9141ef9ea8c2f63ac6a395d39bbc9 Mon Sep 17 00:00:00 2001
From: Brandon Wu
Date: Tue, 2 Dec 2025 01:20:21 -0800
Subject: [PATCH 4/4] fixup! use unrollVectorOp

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 29 ++++-------------------
 1 file changed, 4 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ca67d741b035d..1be04443856c8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8595,35 +8595,14 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::SHL:
     if (Subtarget.enablePExtCodeGen() &&
         Op.getSimpleValueType().isFixedLengthVector()) {
-      // There's no vector-vector version of shift instruction in P extension so
-      // we need to fallback to scalar computation and pack them back.
-      MVT VecVT = Op.getSimpleValueType();
-      unsigned NumElts = VecVT.getVectorNumElements();
-      MVT VecEltTy = VecVT.getVectorElementType();
-      SDValue Src0 = Op.getOperand(0);
-      SDValue Src1 = Op.getOperand(1);
-      SDLoc DL(Op);
-      SmallVector<SDValue> Results;
-
       // We have patterns for scalar/immediate shift amount, so no lowering
       // needed.
-      if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
+      if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR)
         return Op;
 
-      for (unsigned I = 0; I < NumElts; ++I) {
-        // extract scalar value
-        SDValue SrcElt0 =
-            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
-                        {Src0, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
-        SDValue SrcElt1 =
-            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
-                        {Src1, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
-        // perform computation
-        Results.push_back(
-            DAG.getNode(ISD::SHL, DL, VecEltTy, SrcElt0, SrcElt1));
-      }
-      // pack the results
-      return DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Results);
+      // There's no vector-vector version of shift instruction in P extension so
+      // we need to unroll to scalar computation and pack them back.
+      return DAG.UnrollVectorOp(Op.getNode());
     }
     [[fallthrough]];
   case ISD::SRA: