Skip to content

Conversation

@4vtomat
Copy link
Member

@4vtomat 4vtomat commented Dec 1, 2025

There's no instruciton for vector shift amount, so we have to scalarize
it and rebuild the vector.

There's no instruciton for vector shift amount, so we have to scalarize
it and rebuild the vector.
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

Changes

There's no instruciton for vector shift amount, so we have to scalarize
it and rebuild the vector.


Full diff: https://github.com/llvm/llvm-project/pull/170074.diff

4 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+32)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+13)
  • (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll (+109)
  • (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll (+36)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a6212f5cc84be..62d659574892c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -529,6 +529,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
     setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
     setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
+    setOperationAction(ISD::SHL, VTs, Custom);
     setOperationAction(ISD::BITCAST, VTs, Custom);
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
   }
@@ -8592,6 +8593,37 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::VSELECT:
     return lowerToScalableOp(Op, DAG);
   case ISD::SHL:
+    if (Subtarget.enablePExtCodeGen() &&
+        Op.getSimpleValueType().isFixedLengthVector()) {
+      // There's no vector-vector version of shift instruction in P extension so
+      // we need to fallback to scalar computation and pack them back.
+      MVT VecVT = Op.getSimpleValueType();
+      unsigned NumElts = VecVT.getVectorNumElements();
+      MVT VecEltTy = VecVT.getVectorElementType();
+      SDValue Src0 = Op.getOperand(0);
+      SDValue Src1 = Op.getOperand(1);
+      SDLoc DL(Op);
+      SmallVector<SDValue, 2> Results;
+
+      if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
+        return Op;
+
+      for (unsigned I = 0; I < NumElts; ++I) {
+        // extract scalar value
+        SDValue SrcElt0 =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+                        {Src0, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+        SDValue SrcElt1 =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltTy,
+                        {Src1, DAG.getConstant(I, DL, Subtarget.getXLenVT())});
+        // perform computation
+        Results.push_back(
+            DAG.getNode(ISD::SHL, DL, VecEltTy, SrcElt0, SrcElt1));
+      }
+      // pack the results
+      return DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Results);
+    }
+    [[fallthrough]];
   case ISD::SRA:
   case ISD::SRL:
     if (Op.getSimpleValueType().isFixedLengthVector())
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 599358368594f..fd0d26657a88a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1532,6 +1532,15 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (sshlsat GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
            (PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
 
+  // 8-bit logical shift left
+  def: Pat<(XLenVecI8VT (shl GPR:$rs1,
+                             (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_BS GPR:$rs1, GPR:$rs2)>;
+  // 16-bit logical shift left
+  def: Pat<(XLenVecI16VT (shl GPR:$rs1,
+                              (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_HS GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit PLI SD node pattern
   def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
   // 16-bit PLI SD node pattern
@@ -1578,6 +1587,10 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit logical shift left
+  def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+           (PSLL_WS GPR:$rs1, GPR:$rs2)>;
+
   // splat pattern
   def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
 
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index f803f6aa09652..3f4366f0dc758 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -637,3 +637,112 @@ define void @test_psslai_h(ptr %ret_ptr, ptr %a_ptr) {
   store <2 x i16> %res, ptr %ret_ptr
   ret void
 }
+
+; Test logical shift left(scalar shamt)
+define void @test_psll_hs(ptr %ret_ptr, ptr %a_ptr, i16 %shamt) {
+; CHECK-LABEL: test_psll_hs:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    psll.hs a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %insert = insertelement <2 x i16> undef, i16 %shamt, i32 0
+  %b = shufflevector <2 x i16> %insert, <2 x i16> undef, <2 x i32> zeroinitializer
+  %res = shl <2 x i16> %a, %b
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_psll_bs(ptr %ret_ptr, ptr %a_ptr, i8 %shamt) {
+; CHECK-LABEL: test_psll_bs:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    psll.bs a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i8>, ptr %a_ptr
+  %insert = insertelement <4 x i8> undef, i8 %shamt, i32 0
+  %b = shufflevector <4 x i8> %insert, <4 x i8> undef, <4 x i32> zeroinitializer
+  %res = shl <4 x i8> %a, %b
+  store <4 x i8> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test logical shift left(vector shamt)
+define void @test_psll_hs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    lw a1, 0(a1)
+; CHECK-RV32-NEXT:    lw a2, 0(a2)
+; CHECK-RV32-NEXT:    sll a3, a1, a2
+; CHECK-RV32-NEXT:    srli a2, a2, 16
+; CHECK-RV32-NEXT:    srli a1, a1, 16
+; CHECK-RV32-NEXT:    sll a1, a1, a2
+; CHECK-RV32-NEXT:    pack a1, a3, a1
+; CHECK-RV32-NEXT:    sw a1, 0(a0)
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_psll_hs_vec_shamt:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    lw a1, 0(a1)
+; CHECK-RV64-NEXT:    lw a2, 0(a2)
+; CHECK-RV64-NEXT:    sll a3, a1, a2
+; CHECK-RV64-NEXT:    srli a2, a2, 16
+; CHECK-RV64-NEXT:    srli a1, a1, 16
+; CHECK-RV64-NEXT:    sll a1, a1, a2
+; CHECK-RV64-NEXT:    ppack.w a1, a3, a1
+; CHECK-RV64-NEXT:    sw a1, 0(a0)
+; CHECK-RV64-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %shamt_ptr
+  %res = shl <2 x i16> %a, %b
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-RV32-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    lw a2, 0(a2)
+; CHECK-RV32-NEXT:    lw a1, 0(a1)
+; CHECK-RV32-NEXT:    srli a3, a2, 24
+; CHECK-RV32-NEXT:    srli a4, a1, 24
+; CHECK-RV32-NEXT:    srli a5, a2, 8
+; CHECK-RV32-NEXT:    srli a6, a1, 8
+; CHECK-RV32-NEXT:    sll a7, a4, a3
+; CHECK-RV32-NEXT:    sll a6, a6, a5
+; CHECK-RV32-NEXT:    sll a4, a1, a2
+; CHECK-RV32-NEXT:    srli a2, a2, 16
+; CHECK-RV32-NEXT:    srli a1, a1, 16
+; CHECK-RV32-NEXT:    sll a5, a1, a2
+; CHECK-RV32-NEXT:    ppack.dh a2, a4, a6
+; CHECK-RV32-NEXT:    pack a1, a2, a3
+; CHECK-RV32-NEXT:    sw a1, 0(a0)
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_psll_bs_vec_shamt:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    lw a2, 0(a2)
+; CHECK-RV64-NEXT:    lw a1, 0(a1)
+; CHECK-RV64-NEXT:    srli a3, a2, 24
+; CHECK-RV64-NEXT:    srli a4, a1, 24
+; CHECK-RV64-NEXT:    srli a5, a2, 16
+; CHECK-RV64-NEXT:    sll a3, a4, a3
+; CHECK-RV64-NEXT:    srli a4, a1, 16
+; CHECK-RV64-NEXT:    sll a4, a4, a5
+; CHECK-RV64-NEXT:    sll a5, a1, a2
+; CHECK-RV64-NEXT:    srli a2, a2, 8
+; CHECK-RV64-NEXT:    srli a1, a1, 8
+; CHECK-RV64-NEXT:    sll a1, a1, a2
+; CHECK-RV64-NEXT:    ppack.h a2, a4, a3
+; CHECK-RV64-NEXT:    ppack.h a1, a5, a1
+; CHECK-RV64-NEXT:    ppack.w a1, a1, a2
+; CHECK-RV64-NEXT:    sw a1, 0(a0)
+; CHECK-RV64-NEXT:    ret
+  %a = load <4 x i8>, ptr %a_ptr
+  %b = load <4 x i8>, ptr %shamt_ptr
+  %res = shl <4 x i8> %a, %b
+  store <4 x i8> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 9b021df8dd452..721762fffdc9f 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -805,3 +805,39 @@ define void @test_psslai_w(ptr %ret_ptr, ptr %a_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test logical shift left(scalar shamt)
+define void @test_psll_ws(ptr %ret_ptr, ptr %a_ptr, i32 %shamt) {
+; CHECK-LABEL: test_psll_ws:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    psll.ws a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %insert = insertelement <2 x i32> undef, i32 %shamt, i32 0
+  %b = shufflevector <2 x i32> %insert, <2 x i32> undef, <2 x i32> zeroinitializer
+  %res = shl <2 x i32> %a, %b
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test logical shift left(vector shamt)
+define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
+; CHECK-LABEL: test_psll_ws_vec_shamt:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    sllw a3, a1, a2
+; CHECK-NEXT:    srli a2, a2, 32
+; CHECK-NEXT:    srli a1, a1, 32
+; CHECK-NEXT:    sllw a1, a1, a2
+; CHECK-NEXT:    pack a1, a3, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %shamt_ptr
+  %res = shl <2 x i32> %a, %b
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}

@github-actions
Copy link

github-actions bot commented Dec 1, 2025

✅ With the latest revision this PR passed the undef deprecator.

@4vtomat 4vtomat requested a review from topperc December 1, 2025 04:50
if (Src1.getNode()->getOpcode() == ISD::SPLAT_VECTOR)
return Op;

for (unsigned I = 0; I < NumElts; ++I) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use DAG.UnrollVectorOp?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's much more concise, I should have done this way!

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@4vtomat 4vtomat merged commit af4098b into llvm:main Dec 5, 2025
10 checks passed
@4vtomat 4vtomat deleted the p_ext_codegen4 branch December 5, 2025 03:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants